
TensorFlow: Batching a TFRecord dataset with tensors of arbitrary dimensions


慕的地10843 2021-07-20 17:15:09
How can I batch tensors of arbitrary shape using TFRecordDataset?

I am currently working on the input pipeline for an object detection network and am struggling to batch my labels. The labels consist of the bounding box coordinates and the classes of the objects in an image. Since an image may contain multiple objects, the label dimensions are arbitrary.

When working with tf.train.batch, you can set dynamic_pad=True to pad the shapes to the same dimensions. However, data.TFRecordDataset.batch() has no such option.

The shapes I would like to batch are [batch_size, arbitrary, 4] for my boxes and [batch_size, arbitrary, 1] for the classes.

    import tensorflow as tf  # TF 1.x API

    def decode(serialized_example):
        """
        Decodes the information of the TFRecords to image, label_coord, label_classes
        Later on will also contain the Image Sequence!

        :param serialized_example: Serialized Example read from the TFRecords
        :return: image, label_coordinates list, label_classes list
        """
        features = {'image/shape': tf.FixedLenFeature([], tf.string),
                    'train/image': tf.FixedLenFeature([], tf.string),
                    'label/coordinates': tf.VarLenFeature(tf.float32),
                    'label/classes': tf.VarLenFeature(tf.string)}

        features = tf.parse_single_example(serialized_example, features=features)

        image_shape = tf.decode_raw(features['image/shape'], tf.int64)
        image = tf.decode_raw(features['train/image'], tf.float32)
        image = tf.reshape(image, image_shape)

        # Contains the Bounding Box coordinates in a flattened tensor
        label_coord = features['label/coordinates']
        label_coord = label_coord.values
        label_coord = tf.reshape(label_coord, [1, -1, 4])

        # Contains the Classes of the BBox in a flattened Tensor
        label_classes = features['label/classes']
        label_classes = label_classes.values
        label_classes = tf.reshape(label_classes, [1, -1, 1])

        return image, label_coord, label_classes

    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.map(decode)
    dataset = dataset.map(augment)
    dataset = dataset.map(normalize)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)

The error thrown is: Cannot batch tensors with different shapes in component 1. First element had shape [1,1,4] and element 1 had shape [1,7,4].

At the moment the augment and normalize functions are just placeholders.

1 Answer

小唯快跑啊

1863 experience points contributed, more than 2 likes received

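It turns out that tf.data.TFRecordDataset has another method called padded_batch, which does essentially what tf.train.batch(dynamic_pad=True) does: it pads each component of every element to the largest shape in its batch. That solves the problem quite easily: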
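To make the padding idea concrete, here is a minimal pure-Python sketch of what happens to the variable-length box lists inside one batch. It is only an illustration of the resulting shapes, not TensorFlow's implementation; the helper name pad_batch and the sample boxes are hypothetical.

```python
def pad_batch(batch, pad_row):
    """Pad every element (a list of rows) to the longest element in the batch."""
    longest = max(len(rows) for rows in batch)
    return [
        rows + [list(pad_row) for _ in range(longest - len(rows))]
        for rows in batch
    ]


# Hypothetical labels: one image with 1 box, one with 3 boxes (4 coordinates each).
boxes_per_image = [
    [[0.1, 0.2, 0.3, 0.4]],
    [[0.0, 0.0, 0.5, 0.5], [0.2, 0.2, 0.9, 0.9], [0.4, 0.1, 0.6, 0.3]],
]

padded = pad_batch(boxes_per_image, pad_row=[0.0, 0.0, 0.0, 0.0])

# Every element now has the same number of rows, so the batch is rectangular.
print([len(rows) for rows in padded])  # [3, 3]
```

After padding, the batch can be stacked into a single [batch_size, max_boxes, 4] array, which is exactly what a plain batch() cannot do with ragged inputs.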
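    dataset = tf.data.TFRecordDataset(filename)

    dataset = dataset.map(decode)
    dataset = dataset.map(augment)
    dataset = dataset.map(normalize)

    dataset = dataset.shuffle(1000 + 3 * batch_size)
    dataset = dataset.repeat(num_epochs)
    # Pad every component to the largest shape within each batch.
    # Note: these padded_shapes assume decode returns labels of rank 2,
    # i.e. reshaped to [-1, 4] and [-1, 1] rather than [1, -1, 4] and [1, -1, 1].
    dataset = dataset.padded_batch(batch_size,
                                   drop_remainder=False,
                                   padded_shapes=([None, None, None],
                                                  [None, 4],
                                                  [None, 1]))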
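For anyone who wants to try padded_batch without a TFRecord file, the following is a small self-contained sketch. It assumes TensorFlow 2.x with eager execution (the thread itself uses the TF 1.x API), and it swaps the TFRecord decoding for a hypothetical generator, make_examples, that yields toy images with 1, 3 and 2 boxes.

```python
import numpy as np
import tensorflow as tf


def make_examples():
    """Hypothetical stand-in for decoded TFRecords: images with 1, 3 and 2 boxes."""
    for n_boxes in (1, 3, 2):
        image = np.zeros((8, 8, 3), dtype=np.float32)
        boxes = np.ones((n_boxes, 4), dtype=np.float32)
        classes = np.full((n_boxes, 1), 7, dtype=np.int64)
        yield image, boxes, classes


dataset = tf.data.Dataset.from_generator(
    make_examples,
    output_signature=(
        tf.TensorSpec(shape=(8, 8, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 1), dtype=tf.int64),
    ),
)

# None marks the dimension that is padded to the longest element in the batch.
batched = dataset.padded_batch(
    3,
    padded_shapes=([8, 8, 3], [None, 4], [None, 1]),
    drop_remainder=False,
)

images, boxes, classes = next(iter(batched))
print(images.shape)   # (3, 8, 8, 3)
print(boxes.shape)    # (3, 3, 4)
print(classes.shape)  # (3, 3, 1)
```

The padding value defaults to zero for numeric tensors, so downstream code usually needs some way to tell real boxes from padded rows, for example by also storing the number of boxes per image.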


2021-07-28