为了账号安全,请及时绑定邮箱和手机立即绑定

图像生成器缺少 unet keras 的位置参数

图像生成器缺少 unet keras 的位置参数

当年话下 2021-08-24 15:24:33
我不断收到以下错误下面的代码,当我尝试训练模型:TypeError: fit_generator() missing 1 required positional argument: 'generator'。对于我的生活,我无法弄清楚是什么导致了这个错误。x_train 是一个形状为 (400, 256, 256, 3) 的 rgb 图像,对于 y_train,我有 10 个输出类使其具有形状 (400, 256, 256, 10)。这里出了什么问题?如有必要,可通过以下链接下载数据:https : //www49.zippyshare.com/v/5pR3GPv3/file.htmlimport skimagefrom skimage.io import imread, imshow, imread_collection, concatenate_imagesfrom skimage.transform import resizefrom skimage.morphology import labelimport numpy as npimport matplotlib.pyplot as pltfrom keras.models import Modelfrom keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D, Reshape, core, Dropoutfrom keras.optimizers import Adamfrom keras.callbacks import ModelCheckpoint, LearningRateSchedulerfrom keras import backend as Kfrom sklearn.metrics import jaccard_similarity_scorefrom shapely.geometry import MultiPolygon, Polygonimport shapely.wktimport shapely.affinityfrom collections import defaultdictfrom keras.preprocessing.image import ImageDataGeneratorfrom keras.utils.np_utils import to_categoricalfrom keras import utils as np_utilsimport osfrom keras.preprocessing.image import ImageDataGeneratorgen = ImageDataGenerator()#Importing image and labelslabels = skimage.io.imread("ede_subset_293_wegen.tif")images = skimage.io.imread("ede_subset_293_20180502_planetscope.tif")[...,:-1]#scaling imageimg_scaled = images / images.max()#Make non-roads 0labels[labels == 15] = 0#Resizing image and mask and labelsimg_scaled_resized = img_scaled[:6400, :6400 ]print(img_scaled_resized.shape)labels_resized = labels[:6400, :6400]print(labels_resized.shape)#splitting imagessplit_img = [    np.split(array, 25, axis=0)     for array in np.split(img_scaled_resized, 25, axis=1)]split_img[-1][-1].shape#splitting labelssplit_labels = [    np.split(array, 25, axis=0)     for array in np.split(labels_resized, 25, axis=1)]
查看完整描述

1 回答

?
宝慕林4294392

TA贡献2021条经验 获得超8个赞

您的代码中有一些错误,但考虑到您的错误:

类型错误:fit_generator() 缺少 1 个必需的位置参数:'generator'

这是因为 fit_generator 调用 XYaugmentGenerator 但内部没有调用增强生成器。

gen.flow(...

不会工作,因为未声明 gen。您应该将 image_datagen 重命名为 gen 为:

gen = ImageDataGenerator(**data_gen_args)

或者,用 image_datagen 替换 gen

genX1 = image_datagen.flow(X1, y, batch_size=batch_size, seed=seed)
genX2 = image_datagen.flow(y, X1, batch_size=batch_size, seed=seed)


查看完整回答
反对 回复 2021-08-24
  • 1 回答
  • 0 关注
  • 172 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信