为了账号安全,请及时绑定邮箱和手机立即绑定

Tensorflow模型保存与加载

标签:
深度学习

一.保存模型

所谓的模型保存,也就是冻结(freeze)模型,将该模型的图结构和该模型的权重固化到一起。

二.加载模型

在恢复模型的时候,通过get_tensor_by_name获得模型中的变量,然后对变量进行赋值。

三.代码实例

Demo1:单纯地训练,不生成模型

#-*- coding:utf-8 -*-
import tensorflow as tf 
import numpy as np 

with tf.variable_scope('Placeholder'):
	inputs_placeholder = tf.placeholder(tf.float32,name = 'inputs_placeholder',shape = [None,10])
	labels_placeholder = tf.placeholder(tf.float32,name = 'labels_placeholder',shape = [None,1])

with tf.variable_scope('NN'):
	W1 = tf.get_variable('W1',shape = [10,1],initializer = tf.random_normal_initializer(stddev = 1e-1))
	b1 = tf.get_variable('b1',shape = [1],initializer = tf.constant_initializer(0.1))
	W2 = tf.get_variable('W2',shape = [10,1],initializer = tf.random_normal_initializer(stddev = 1e-1))
	b2 = tf.get_variable('b2',shape = [1],initializer = tf.constant_initializer(0.1))

	a1 = tf.nn.relu(tf.matmul(inputs_placeholder,W1) + b1)
	a2 = tf.nn.relu(tf.matmul(inputs_placeholder,W2) + b2)

	y  = tf.div(tf.add(a1,a2),2)

with tf.variable_scope('Loss'):
	loss = tf.reduce_sum(tf.square(y - labels_placeholder) / 2)

with tf.variable_scope('Accuracy'):
	predictions = tf.greater(y,0.5,name = 'predictions')
	correct_predictions = tf.equal(predictions,tf.cast(labels_placeholder,tf.bool),name = "correct_predictions")
	accuracy = tf.reduce_mean(tf.cast(correct_predictions,tf.float32))

adam = tf.train.AdamOptimizer(learning_rate = 1e-3)
train_op = adam.minimize(loss)

#generate_data  
inputs = np.random.choice(10,size = [10000,10])
labels = (np.sum(inputs,axis = 1) > 45).reshape(-1,1).astype(np.float32)
print('inputs.shape:',inputs.shape)
print('labels.shape:',labels.shape)

test_inputs = np.random.choice(10,size = [100,10])
test_labels = (np.sum(test_inputs,axis = 1) > 45).reshape(-1,1).astype(np.float32)
print('test_inputs.shape:',test_inputs.shape)
print('test_labels.shape:',test_labels.shape)

batch_size = 32
epochs     = 10

batches = []
for i in range(len(inputs) // batch_size):
	batch = [ inputs[batch_size * i:batch_size * i + batch_size],labels[batch_size * i:batch_size * i + batch_size]]
	batches.append(list(batch))
if (i + 1) * batch_size < len(inputs):
	batch = [inputs[batch_size*(i + 1):],labels[batch_size*(i + 1):]]
	batches.append(list(batch))

print("Number of batches: %d" % len(batches))
print("Size of full batch: %d" % len(batches[0]))
print("Size of final batch: %d" % len(batches[-1]))

global_count = 0 

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	for i in range(epochs):
		for batch in batches:
			train_loss, _ = sess.run(
				[loss,train_op],
				feed_dict = {
					inputs_placeholder:batch[0],
					labels_placeholder:batch[1]
				})
			if global_count % 100 == 0:
				acc = sess.run(accuracy, feed_dict = {
						inputs_placeholder: test_inputs,
						labels_placeholder: test_labels
					})
				print('accuracy: %f' % acc)
			global_count += 1

	acc = sess.run(accuracy,feed_dict = {
		inputs_placeholder:test_inputs,
		labels_placeholder:test_labels
	  })
	print("final accuracy: %f" % acc)

运行结果:
图片描述
图片描述

Demo2:基本的保存与加载模型
保存模型:

import tensorflow as tf
#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#Create a saver object which will save all the variables
saver = tf.train.Saver()
#Run the operation by feeding input
print (sess.run(w4,feed_dict))
#Prints 24 which is sum of (w1+w2)*b1 
#Now, save the graph
saver.save(sess, './model/tiny_model',global_step=1000)

运行结果:
图片描述

这里4,5,6,11行中的name=’w1′, name=’w2′, name=’bias’, name=’op_to_restore’ 千万不能省略,这是恢复还原模型的关键。

加载模型:

import tensorflow as tf
sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('./model/tiny_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print (sess.run(op_to_restore,feed_dict))
#This will print 60 which is calculated

运行结果:
图片描述

ckpt文件不方便模型迁移,比如在windows上训练好的模型放在Linux环境可能加载不了,原因是里面的checkpoints中的路径参数会改变,为了更好的部署和上线,应该考虑将模型保存为pb文件,本文的方式只适合入门学习。
另外:
“Note that when the network is saved, values of the placeholders are not saved.”

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消