为了账号安全,请及时绑定邮箱和手机立即绑定

TensorFlow:如何将图像解码器节点添加到我的图形中?

TensorFlow:如何将图像解码器节点添加到我的图形中?

守着一只汪 2021-12-17 16:30:28
我有一个张量流模型作为冻结图,它接受一个图像张量作为输入。但是,我想向该图中添加一个新的输入图像解码器节点,以便模型也接受 jpg 图像的编码字节字符串,并最终自行解码图像。到目前为止,我已经尝试过这种方法:model = './frozen_graph.pb'with tf.gfile.FastGFile(model, 'rb') as f:    # read graph    graph_def = tf.GraphDef()    graph_def.ParseFromString(f.read())    tf.import_graph_def(graph_def, name="")    g = tf.get_default_graph()    # fetch old input    old_input = g.get_tensor_by_name('image_tensor:0')    # define new input    new_input = graph_def.node.add()    new_input.name = 'encoded_image_string_tensor'    new_input.op = 'Substr'    # add new input attr    image = tf.image.decode_image(new_input, channels=3)    # link new input to old input    old_input.input = 'encoded_image_string_tensor'  #  must match with the name above上面的代码返回这个异常:Expected string passed to parameter 'input' of op 'Substr', got name: "encoded_image_string_tensor" op: "Substr"  of type 'NodeDef' instead.我不太确定我是否可以tf.image.decode_image在图表中使用 ,所以也许有另一种方法可以解决这个问题。有人有提示吗?
查看完整描述

1 回答

?
守着星空守着你

TA贡献1799条经验 获得超8个赞

使用该input_map参数,我成功地将一个仅解码 jpg 图像的新图形映射到我的原始图形的输入(此处:)node.name='image_tensor:0'。只需确保重命名name_scope解码器图的 的(此处:)decoder。之后,您可以使用 tensorflow SavedModelBuilder 保存新的连接图。


这是一个物体检测网络的例子:


import tensorflow as tf

from tensorflow.python.saved_model import signature_constants

from tensorflow.python.saved_model import tag_constants



# The export path contains the name and the version of the model

model = 'path/to/model.pb'

export_path = './output/dir/'


sigs = {}


with tf.gfile.FastGFile(model, 'rb') as f:

        with tf.name_scope('decoder'):

                image_str_tensor = tf.placeholder(tf.string, shape=[None], name= 'encoded_image_string_tensor')

                # The CloudML Prediction API always "feeds" the Tensorflow graph with

                # dynamic batch sizes e.g. (?,).  decode_jpeg only processes scalar

                # strings because it cannot guarantee a batch of images would have

                # the same output size.  We use tf.map_fn to give decode_jpeg a scalar

                # string from dynamic batches.

                def decode_and_resize(image_str_tensor):

                        """Decodes jpeg string, resizes it and returns a uint8 tensor."""

                        image = tf.image.decode_jpeg(image_str_tensor, channels=3)


                        # do additional image manipulation here (like resize etc...)


                        image = tf.cast(image, dtype=tf.uint8)

                        return image


                image = tf.map_fn(decode_and_resize, image_str_tensor, back_prop=False, dtype=tf.uint8)


        with tf.name_scope('net'):

                # load .pb file

                graph_def = tf.GraphDef()

                graph_def.ParseFromString(f.read())


                # concatenate decoder graph and original graph

                tf.import_graph_def(graph_def, name="", input_map={'image_tensor:0':image})

                g = tf.get_default_graph()


with tf.Session() as sess:

        # load graph into session and save to new .pb file


        # define model input

        inp = g.get_tensor_by_name('decoder/encoded_image_string_tensor:0')


        # define model outputs

        num_detections = g.get_tensor_by_name('num_detections:0')

        detection_scores = g.get_tensor_by_name('detection_scores:0')

        detection_boxes = g.get_tensor_by_name('detection_boxes:0')

        out = {'num_detections': num_detections, 'detection_scores': detection_scores, 'detection_boxes': detection_boxes}



        builder = tf.saved_model.builder.SavedModelBuilder(export_path)


        tensor_info_inputs = {

                'inputs': tf.saved_model.utils.build_tensor_info(inp)}

        tensor_info_outputs = {}

        for k, v in out.items():

                tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(v)


        # assign detection signature for tensorflow serving

        detection_signature = (

        tf.saved_model.signature_def_utils.build_signature_def(

                inputs=tensor_info_inputs,

                outputs=tensor_info_outputs,

                method_name=signature_constants.PREDICT_METHOD_NAME))


        # "build" graph

        builder.add_meta_graph_and_variables(

                sess, [tf.saved_model.tag_constants.SERVING],

                signature_def_map={

                'detection_signature':

                        detection_signature,

                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:

                        detection_signature,

                },

                main_op=tf.tables_initializer()

        )

        # save graph

        builder.save()

另外:如果您难以找到正确的输入和输出节点,您可以运行它来显示图形:


graph_op = g.get_operations()

for i in graph_op:

    print(i.node_def)


查看完整回答
反对 回复 2021-12-17
  • 1 回答
  • 0 关注
  • 157 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号