1 回答
TA贡献1802条经验 获得超6个赞
在不了解模型详细信息的情况下,以下代码段可能会有所帮助:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input
# Train your initial model
def get_initial_model():
...
return model
model = get_initial_model()
model.fit(...)
model.save_weights('initial_model_weights.h5')
# Use Model API to create another model, built on your initial model
initial_model = get_initial_model()
initial_model.load_weights('initial_model_weights.h5')
nn_input = Input(...)
x = initial_model(nn_input)
x = Dense(...)(x) # This is the additional layer, connected to your initial model
nn_output = Dense(...)(x)
# Combine your model
full_model = Model(inputs=nn_input, outputs=nn_output)
# Compile and train as usual
full_model.compile(...)
full_model.fit(...)
基本上,你训练你的初始模型,保存它。然后再次重新加载它,并使用 API 将其与其他层包装在一起。如果您不熟悉API,可以在此处查看Keras文档(afaik API对于Tensorflow.Keras 2.0保持不变)。ModelModel
请注意,您需要检查初始模型的最终层的输出形状是否与其他层兼容(例如,如果您只是执行特征提取,则可能需要从初始模型中删除最终的密集层)。
添加回答
举报
