为了账号安全,请及时绑定邮箱和手机立即绑定

如何正确组合TensorFlow的数据集API和Keras?

如何正确组合TensorFlow的数据集API和Keras?

Keras的fit_generator()模型方法期望生成器生成形状(输入,目标)的元组,其中两个元素都是NumPy数组。该文档似乎暗示着,如果我将Dataset迭代器简单地包装在生成器中,并确保将Tensors转换为NumPy数组,那我应该很好。这段代码给我一个错误:import numpy as npimport osimport keras.backend as Kfrom keras.layers import Dense, Inputfrom keras.models import Modelimport tensorflow as tffrom tensorflow.contrib.data import Datasetos.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'with tf.Session() as sess:    def create_data_generator():        dat1 = np.arange(4).reshape(-1, 1)        ds1 = Dataset.from_tensor_slices(dat1).repeat()        dat2 = np.arange(5, 9).reshape(-1, 1)        ds2 = Dataset.from_tensor_slices(dat2).repeat()        ds = Dataset.zip((ds1, ds2)).batch(4)        iterator = ds.make_one_shot_iterator()        while True:            next_val = iterator.get_next()            yield sess.run(next_val)datagen = create_data_generator()input_vals = Input(shape=(1,))output = Dense(1, activation='relu')(input_vals)model = Model(inputs=input_vals, outputs=output)model.compile('rmsprop', 'mean_squared_error')model.fit_generator(datagen, steps_per_epoch=1, epochs=5,                    verbose=2, max_queue_size=2)这是我得到的错误:Using TensorFlow backend.Epoch 1/5Exception in thread Thread-1:Traceback (most recent call last):  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__    fetch, allow_tensor=True, allow_operation=True))  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)奇怪的是,next(datagen)在我初始化的位置之后直接添加包含一行datagen的代码会使代码运行正常,没有错误。为什么我的原始代码不起作用?将行添加到代码中后,为什么它开始起作用?是否有一种更有效的方式将TensorFlow的Dataset API与Keras结合使用,而无需将Tensors转换为NumPy数组然后再次返回?
查看完整描述

3 回答

  • 3 回答
  • 0 关注
  • 957 浏览

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信