为了账号安全,请及时绑定邮箱和手机立即绑定

如何将训练数据分成更小的批次来解决内存错误

如何将训练数据分成更小的批次来解决内存错误

白衣非少年 2023-07-27 10:24:52
我有一个包含两个多维数组的训练数据 [prev_sentences, current_sentences],当我使用简单的 model.fit 方法时,它给了我内存错误。我现在想使用 fit_generator,但我不知道如何将训练数据分成批次以输入 model.fit_generator。训练数据的形状为(111356,126,1024)和(111356,126,1024),y_train形状为(111356,19)。这是简单拟合方法的代码行。history=model.fit([previous_sentences, current_sentences], y_train,                   epochs=15,batch_size=256,                   shuffle = False, verbose = 1,                   validation_split=0.2,                   class_weight=custom_weight_dict,                   callbacks=[early_stopping_cb])我从未使用过 fit_generator 和数据生成器,所以我不知道如何拆分这些训练数据以使用 fit_generator。任何人都可以帮助我使用 fit_generator 创建批次吗?
查看完整描述

2 回答

?
慕村9548890

TA贡献1884条经验 获得超4个赞

您只需要拨打:

model.fit_generator(generator, steps_per_epoch)

其中steps_per_epoch是通常ceil(num_samples / batch_size)并且generator是一个 python 生成器,它迭代数据并批量生成数据。每次调用生成器都应该产生batch_size许多元素。生成器的示例:

def generate_data(directory, batch_size):

    """Replaces Keras' native ImageDataGenerator."""

    i = 0

    file_list = os.listdir(directory)

    while True:

        image_batch = []

        for b in range(batch_size):

            if i == len(file_list):

                i = 0

                random.shuffle(file_list)

            sample = file_list[i]

            i += 1

            image = cv2.resize(cv2.imread(sample[0]), INPUT_SHAPE)

            image_batch.append((image.astype(float) - 128) / 128)


        yield np.array(image_batch)

由于这绝对是特定于问题的,因此您必须编写自己的生成器,尽管使用此模板应该很简单。


查看完整回答
反对 回复 2023-07-27
?
慕虎7371278

TA贡献1802条经验 获得超4个赞

这是将训练数据分割成小批量的数据生成器:


def generate_data(X1,X2,Y,batch_size):

  p_input=[]

  c_input=[]

  target=[]

  batch_count=0

  for i in range(len(X1)):

    p_input.append(X1[i])

    c_input.append(X2[i])

    target.append(Y[i])

    batch_count+=1

    if batch_count>batch_size:

      prev_X=np.array(p_input,dtype=np.int64)

      cur_X=np.array(c_input,dtype=np.int64)

      cur_y=np.array(target,dtype=np.int32)

      print(len(prev_X),len(cur_X))

      yield ([prev_X,cur_X],cur_y ) 

      p_input=[]

      c_input=[]

      target=[]

      batch_count=0

  return

这里是 fit_generator 函数调用而不是 model.fit 方法:


batch_size=256

epoch_steps=math.ceil(len(previous_sentences)/ batch_size)

hist = model.fit_generator(generate_data(previous_sentences,current_sentences, y_train, batch_size),

                steps_per_epoch=epoch_steps,

                callbacks = [early_stopping_cb],

                validation_data=generate_data(val_prev, val_curr,y_val,batch_size),

                validation_steps=val_steps,  class_weight=custom_weight_dict,

                 verbose=1)


查看完整回答
反对 回复 2023-07-27
  • 2 回答
  • 0 关注
  • 73 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信