为了账号安全,请及时绑定邮箱和手机立即绑定

Generator 只进行 12 次迭代 - 无论批量大小

Generator 只进行 12 次迭代 - 无论批量大小

PIPIONE 2022-01-05 11:29:38
我有以下数据生成器。它工作并返回预期数据。除了我将 epochs 或 batchsize 设置为等于什么之外,它只执行 12 次迭代然后给出错误(见下文)我曾尝试更改纪元数和批量大小。# initialize the number of epochs to train for and batch sizeNUM_EPOCHS = 10 #100BS = 32 #64 #32NUM_TRAIN_IMAGES = len(train_uxo_scrap)NUM_TEST_IMAGES = len(test_uxo_scrap)def datagenerator(imgfns, imglabels, batchsize, mode="train", class_mode='binary'):    cnt=0    while True:        images = []        labels = []        #cnt=0        while len(images) < batchsize and cnt < len(imgfns):            images.append(imgfns[cnt])            labels.append(imglabels[cnt])            cnt=cnt+1        print(images)        print(labels)        print('********** cnt = ', cnt)        yield images, labelstrain_gen = datagenerator(train_uxo_scrap, train_uxo_scrap_labels, batchsize=BS, class_mode='binary')valid_gen = datagenerator(test_uxo_scrap, test_uxo_scrap_labels, batchsize=BS, class_mode='binary')# train the networkH = model.fit_generator(    train_gen,    steps_per_epoch=NUM_TRAIN_IMAGES // BS,    validation_data=valid_gen,    validation_steps=NUM_TEST_IMAGES // BS,    epochs=NUM_EPOCHS)我希望代码在每次迭代中通过 32 个样本经历 10 个时期。我每次迭代得到 32 个样本,但在第一个时期我只得到 12 个迭代,然后我得到以下错误。无论设置什么批次大小或纪元,都会发生这种情况。---------------------------------------------------------------------------IndexError                                Traceback (most recent call last)<ipython-input-83-26f81894773d> in <module>()      5     validation_data=valid_gen,      6     validation_steps=NUM_TEST_IMAGES // BS,----> 7     epochs=NUM_EPOCHS)~\AppData\Local\Continuum\anaconda3\envs\dltf1\lib\site-packages\tensorflow\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)   1424         use_multiprocessing=use_multiprocessing,   1425         shuffle=shuffle,-> 1426         initial_epoch=initial_epoch)   1427    1428   def evaluate_generator(self,
查看完整描述

1 回答

?
qq_笑_17

TA贡献1818条经验 获得超7个赞

看看这是否有效:


def datagenerator(imgfns, imglabels, batchsize, mode="train", class_mode='binary'):

    while True:

        start = 0

        end = batchsize


        while start  < len(imgfns): 

            x = imgfns[start:end]

            y = imglabels[start:end]

            yield x, y


            start += batchsize

            end += batchsize

假设imgfns, imglabels是 numpy 数组。


查看完整回答
反对 回复 2022-01-05
  • 1 回答
  • 0 关注
  • 176 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号