我有以下数据生成器。它工作并返回预期数据。除了我将 epochs 或 batchsize 设置为等于什么之外,它只执行 12 次迭代然后给出错误(见下文)我曾尝试更改纪元数和批量大小。# initialize the number of epochs to train for and batch sizeNUM_EPOCHS = 10 #100BS = 32 #64 #32NUM_TRAIN_IMAGES = len(train_uxo_scrap)NUM_TEST_IMAGES = len(test_uxo_scrap)def datagenerator(imgfns, imglabels, batchsize, mode="train", class_mode='binary'): cnt=0 while True: images = [] labels = [] #cnt=0 while len(images) < batchsize and cnt < len(imgfns): images.append(imgfns[cnt]) labels.append(imglabels[cnt]) cnt=cnt+1 print(images) print(labels) print('********** cnt = ', cnt) yield images, labelstrain_gen = datagenerator(train_uxo_scrap, train_uxo_scrap_labels, batchsize=BS, class_mode='binary')valid_gen = datagenerator(test_uxo_scrap, test_uxo_scrap_labels, batchsize=BS, class_mode='binary')# train the networkH = model.fit_generator( train_gen, steps_per_epoch=NUM_TRAIN_IMAGES // BS, validation_data=valid_gen, validation_steps=NUM_TEST_IMAGES // BS, epochs=NUM_EPOCHS)我希望代码在每次迭代中通过 32 个样本经历 10 个时期。我每次迭代得到 32 个样本,但在第一个时期我只得到 12 个迭代,然后我得到以下错误。无论设置什么批次大小或纪元,都会发生这种情况。---------------------------------------------------------------------------IndexError Traceback (most recent call last)<ipython-input-83-26f81894773d> in <module>() 5 validation_data=valid_gen, 6 validation_steps=NUM_TEST_IMAGES // BS,----> 7 epochs=NUM_EPOCHS)~\AppData\Local\Continuum\anaconda3\envs\dltf1\lib\site-packages\tensorflow\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch) 1424 use_multiprocessing=use_multiprocessing, 1425 shuffle=shuffle,-> 1426 initial_epoch=initial_epoch) 1427 1428 def evaluate_generator(self,
1 回答
qq_笑_17
TA贡献1818条经验 获得超7个赞
看看这是否有效:
def datagenerator(imgfns, imglabels, batchsize, mode="train", class_mode='binary'):
while True:
start = 0
end = batchsize
while start < len(imgfns):
x = imgfns[start:end]
y = imglabels[start:end]
yield x, y
start += batchsize
end += batchsize
假设imgfns, imglabels是 numpy 数组。
添加回答
举报
0/150
提交
取消
