为了账号安全,请及时绑定邮箱和手机立即绑定

如何让数据生成器更有效率?

如何让数据生成器更有效率?

墨色风雨 2022-12-14 17:38:07
为了训练神经网络,我修改了在 YouTube 上找到的一段代码。它看起来如下:def data_generator(samples, batch_size, shuffle_data = True, resize=224):  num_samples = len(samples)  while True:    random.shuffle(samples)    for offset in range(0, num_samples, batch_size):      batch_samples = samples[offset: offset + batch_size]      X_train = []      y_train = []      for batch_sample in batch_samples:        img_name = batch_sample[0]        label = batch_sample[1]        img = cv2.imread(os.path.join(root_dir, img_name))        #img, label = preprocessing(img, label, new_height=224, new_width=224, num_classes=37)        img = preprocessing(img, new_height=224, new_width=224)        label = my_onehot_encoded(label)        X_train.append(img)        y_train.append(label)      X_train = np.array(X_train)      y_train = np.array(y_train)      yield X_train, y_train现在,我尝试使用此代码训练神经网络,训练样本大小为 105.000(图像文件包含 37 种可能性中的 8 个字符、AZ、0-9 和空格)。我使用了相对较小的批次大小(32,我认为这已经太小了)来提高效率,但是训练第一个时期的四分之一却花了很长时间(我每个时期有 826 步,花了 90 分钟199 步... steps_per_epoch = num_train_samples // batch_size)。数据生成器中包含以下功能:def shuffle_data(data):  data=random.shuffle(data)  return data我不认为我们可以使这个函数更有效或将它从生成器中排除。def preprocessing(img, new_height, new_width):  img = cv2.resize(img,(new_height, new_width))  img = img/255  return img为了预处理/调整数据大小,我使用此代码将图像设置为唯一大小,例如 (224, 224, 3)。我认为,生成器的这一部分花费的时间最多,但我看不到将其从生成器中排除的可能性(因为如果我们在批次之外调整图像的大小,我的内存将满)。#One Hot Encoding of the Labelsfrom numpy import argmax# define input string我认为,在这一部分中,可能有一种方法可以提高效率。我正在考虑从生成器中排除此代码并在生成器外部生成数组 y_train,这样生成器就不必每次都对标签进行热编码。你怎么看?还是我应该采用完全不同的方法?
查看完整描述

1 回答

?
LEATH

TA贡献1936条经验 获得超6个赞

我发现你的问题非常有趣,因为你只提供了线索。所以这是我的调查。

使用您的代码片段,我在 YouTube 上找到了GitHub 存储库和 3 部分视频教程,主要关注在 Python 中使用生成器函数的好处。数据基于这个 kaggle(我建议检查该问题的不同内核,以将您已经尝试过的方法与另一个 CNN 网络进行比较,并查看正在使用的 API)。

您不需要从头开始编写数据生成器,虽然这并不难,但是发明轮子效率不高。

尽管如此,为了解决 kaggle 的任务,该模型只需要感知单个图像,因此该模型是一个简单的深度 CNN。但据我了解,您将 8 个随机字符(类别)组合到一张图像中以一次识别多个类别。对于该任务,您需要 R-CNN 或 YOLO 作为模型。我最近刚刚为自己打开了YOLO v4,可以让它非常快速地用于特定任务。

关于您的设计和代码的一般建议。

  • 确保库使用 GPU。它节省了很多时间。(尽管我在 CPU 上非常快地重复了存储库中的花实验——大约 10 分钟,但结果预测并不比随机猜测好多少。所以完整的训练需要大量的 CPU 时间。)

  • 比较不同的版本以找到瓶颈。尝试包含 48 张图像(每个类 1 张)的数据集,增加每个类的图像数量,然后进行比较。缩小图像尺寸,改变模型结构等。

  • 在小的人工数据上测试全新模型以证明想法或使用迭代过程,从可以转换为您的任务(手写识别?)的项目开始。


查看完整回答
反对 回复 2022-12-14
  • 1 回答
  • 0 关注
  • 55 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信