3 回答

TA贡献1818条经验 获得超3个赞
您可以通过在调用之前将标签转换为数组来避免此错误model.fit():
train_x = np.asarray(train_x)
train_y = np.asarray(train_y)
validation_x = np.asarray(validation_x)
validation_y = np.asarray(validation_y)

TA贡献1829条经验 获得超7个赞
如果您在处理从该类继承的自定义生成器时遇到此问题keras.utils.Sequence,您可能必须确保不要混合使用 aKeras或tensorflow - Keras-import。
当您必须切换到以前的tensorflow版本以实现兼容性时(例如 with cuDNN),这种情况尤其可能发生。
例如,如果您将其与tensorflow-version > 2 一起使用...
from keras.utils import Sequence
class generatorClass(Sequence):
def __init__(self, x_set, y_set, batch_size):
...
def __len__(self):
...
def __getitem__(self, idx):
return ...
...但是您实际上尝试将此生成器安装在tensorflow-version < 2 中,您必须确保Sequence从该版本导入 -class,例如:
keras = tf.compat.v1.keras
Sequence = keras.utils.Sequence
class generatorClass(Sequence):
...
添加回答
举报