1 回答

TA贡献1799条经验 获得超8个赞
通过 `tf.import_graph_def` 的 `input_map` 参数,我成功地把一个只负责解码 jpg 图像的新图,接到了原始图的输入张量上(这里是 `image_tensor:0`)。只需确保给解码器图使用一个独立的 `name_scope`(这里是 `decoder`),以免其节点与原图节点重名。之后,就可以用 TensorFlow 的 `SavedModelBuilder` 保存拼接后的新图。
这是一个物体检测网络的例子:
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
# Splice a JPEG-decoding sub-graph in front of a frozen object-detection graph
# and export the combined graph as a TensorFlow SavedModel for serving.
#
# model:       path to the frozen inference graph (.pb) to wrap.
# export_path: SavedModel output directory.
# NOTE(review): the original comment says the export path should encode the
# model name and version; TensorFlow Serving conventionally expects a numeric
# version sub-directory (e.g. './output/dir/1') — confirm for your deployment.
# NOTE(review): SavedModelBuilder refuses to write into an existing export
# directory in TF 1.x — confirm for your TF version and clear it if needed.
model = 'path/to/model.pb'
export_path = './output/dir/'

# Read the frozen graph up front so the file handle is only held for the read.
with tf.gfile.GFile(model, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Build the decoder under its own name scope so its node names cannot collide
# with the nodes of the imported detection graph.
with tf.name_scope('decoder'):
    image_str_tensor = tf.placeholder(
        tf.string, shape=[None], name='encoded_image_string_tensor')

    # The CloudML Prediction API always "feeds" the Tensorflow graph with
    # dynamic batch sizes e.g. (?,). decode_jpeg only processes scalar
    # strings because it cannot guarantee a batch of images would have
    # the same output size. We use tf.map_fn to give decode_jpeg a scalar
    # string from dynamic batches.
    def decode_and_resize(encoded_image):
        """Decode one JPEG string into a 3-channel uint8 image tensor.

        No resizing is done here; add any extra preprocessing (resize etc.)
        at the marked spot below.
        """
        decoded = tf.image.decode_jpeg(encoded_image, channels=3)
        # do additional image manipulation here (like resize etc...)
        return tf.cast(decoded, dtype=tf.uint8)

    image = tf.map_fn(decode_and_resize, image_str_tensor,
                      back_prop=False, dtype=tf.uint8)

# Concatenate the decoder graph and the original graph: the decoder output
# replaces the original 'image_tensor:0' input. name="" imports without any
# prefix (an enclosing name_scope would not apply either), so the original
# tensor names such as 'num_detections:0' are preserved and looked up below.
tf.import_graph_def(graph_def, name="", input_map={'image_tensor:0': image})
g = tf.get_default_graph()

with tf.Session() as sess:
    # Model input: the batch of encoded JPEG strings fed to the decoder.
    inp = g.get_tensor_by_name('decoder/encoded_image_string_tensor:0')

    # Model outputs, looked up by their (unprefixed) names in the frozen graph.
    # Extend this tuple if the graph exposes further outputs you want served.
    output_names = ('num_detections', 'detection_scores', 'detection_boxes')
    out = {name: g.get_tensor_by_name(name + ':0') for name in output_names}

    builder = tf.saved_model.builder.SavedModelBuilder(export_path)

    tensor_info_inputs = {
        'inputs': tf.saved_model.utils.build_tensor_info(inp),
    }
    tensor_info_outputs = {
        k: tf.saved_model.utils.build_tensor_info(v) for k, v in out.items()
    }

    # Assign the detection signature for TensorFlow Serving.
    detection_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=tensor_info_inputs,
        outputs=tensor_info_outputs,
        method_name=signature_constants.PREDICT_METHOD_NAME,
    )

    # Register the graph under the SERVING tag, exposing the same signature
    # both by name and as the default serving signature.
    builder.add_meta_graph_and_variables(
        sess,
        [tag_constants.SERVING],
        signature_def_map={
            'detection_signature': detection_signature,
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                detection_signature,
        },
        main_op=tf.tables_initializer(),
    )

    # Write the SavedModel to export_path.
    builder.save()
另外:如果难以确定正确的输入和输出节点名称,可以运行下面的代码,打印图中所有节点的定义:
# Debug helper: print the NodeDef of every operation in the combined graph `g`,
# to find the tensor names to use for the signature's inputs and outputs.
graph_op = g.get_operations()
for op in graph_op:
    print(op.node_def)
添加回答
举报