为了账号安全,请及时绑定邮箱和手机立即绑定

无法获得 TensorFlow 自定义层的正确形状

无法获得 TensorFlow 自定义层的正确形状

潇潇雨雨 2023-10-31 14:31:38
我正在尝试使用自定义层在 TensorFlow 中训练模型。我在最后一层遇到问题,我正在尝试构建一个层来获取一批图像 [None,100,100,1] 并返回 10 个不同方形区域的总和,因此输出应该是 [None] 的形状,10]。我尝试了一些不同的方法但没有成功。我试过了:        output = tf.concat([tf.reshape(tf.math.reduce_sum(inputs[34:42, 28:40,0]), [1,]),                            tf.reshape(tf.math.reduce_sum(inputs[34:42, 44:56,0]), [1,]),                            tf.reshape(tf.math.reduce_sum(inputs[34:42, 60:72,0]), [1,]),                            tf.reshape(tf.math.reduce_sum(inputs[46:54, 20:32,0]), [1,]),                            tf.reshape(tf.math.reduce_sum(inputs[46:54, 36:48,0]), [1,]),                            tf.reshape(tf.math.reduce_sum(inputs[46:54, 52:64,0]), [1,]),                            tf.reshape(tf.math.reduce_sum(inputs[46:54, 68:80,0]), [1,]),                            tf.reshape(tf.math.reduce_sum(inputs[58:66, 28:40,0]), [1,]),                            tf.reshape(tf.math.reduce_sum(inputs[58:66, 44:56,0]), [1,]),                            tf.reshape(tf.math.reduce_sum(inputs[58:66, 60:72,0]), [1,])], axis= 0)和类似的求和函数,但无法将形状的第一维设置为“无”。我尝试过“作弊”,将输入重塑为正确的形状,然后乘以 0 并添加大小为 [10] 的张量。这得到了正确的形状,但模型没有训练。有没有正确的方法来做到这一点?我在这个问题上被困了好几个星期,但没有运气。如果我放置一个不做我想要的事情的不同层,模型训练得很好,但它具有正确的输出形状:class output_layer(tf.keras.Model):    def __init__(self, shape_input):        self.shape_input = shape_input        super(output_layer, self).__init__()    def call(self, inputs):        inputs = tf.math.multiply(inputs,tf.math.conj(inputs))        temp = tf.math.reduce_sum(inputs, axis=2)        temp = tf.reshape(temp, [-1,10,10])        temp = tf.math.reduce_sum(temp, axis=2)                output = tf.cast(temp, tf.float32)        output = tf.keras.activations.softmax(output, axis=-1)        return output如果有人能帮助我,我将非常感激!
查看完整描述

1 回答

?
幕布斯7119047

TA贡献1794条经验 获得超8个赞

我修改了你的代码并提出以下内容:


output = tf.concat(

                  [tf.math.reduce_sum(inputs[:, 34:42, 28:40,:], axis=[1,2]),

                   tf.math.reduce_sum(inputs[:, 34:42, 44:56,:], axis=[1,2]),

                   tf.math.reduce_sum(inputs[:, 34:42, 60:72,:], axis=[1,2]),

                   tf.math.reduce_sum(inputs[:, 46:54, 20:32,:], axis=[1,2]),

                   tf.math.reduce_sum(inputs[:, 46:54, 36:48,:], axis=[1,2]),

                   tf.math.reduce_sum(inputs[:, 46:54, 52:64,:], axis=[1,2]),

                   tf.math.reduce_sum(inputs[:, 46:54, 68:80,:], axis=[1,2]),

                   tf.math.reduce_sum(inputs[:, 58:66, 28:40,:], axis=[1,2]),

                   tf.math.reduce_sum(inputs[:, 58:66, 44:56,:], axis=[1,2]),

                   tf.math.reduce_sum(inputs[:, 58:66, 60:72,:], axis=[1,2])], axis=-1)

请注意,我更改inputs[34:42, 28:40, 0]为inputs[:, 34:42, 28:40,:]. 您可以用于:想要保持相同的尺寸。我还指定了应减少哪个轴,因此,仅保留没有要减少的规格的尺寸 - 在本例中,它是第一个也是最后一个尺寸。在你的情况下,tf.math.reduce_sum将产生形状[无,1]。与此同时,我将 的轴更改tf.concat为 -1,这是最后一层,因此它产生形状 [None, 10]。


为了完整起见,您可以创建自己的图层。为此,您必须继承 tf.keras.layers.Layer。


然后,您可以将其用作任何其他层。


class ReduceZones(tf.keras.layers.Layer):

    def __init__(self):

        super(ReduceZones, self).__init__()

      

    def build(self, input_shapes):

        return

      

    def call(self, inputs):

        output = tf.concat(

                [tf.math.reduce_sum(inputs[:, 34:42, 28:40,:], axis=[1,2]),

                 tf.math.reduce_sum(inputs[:, 34:42, 44:56,:], axis=[1,2]),

                 tf.math.reduce_sum(inputs[:, 34:42, 60:72,:], axis=[1,2]),

                 tf.math.reduce_sum(inputs[:, 46:54, 20:32,:], axis=[1,2]),

                 tf.math.reduce_sum(inputs[:, 46:54, 36:48,:], axis=[1,2]),

                 tf.math.reduce_sum(inputs[:, 46:54, 52:64,:], axis=[1,2]),

                 tf.math.reduce_sum(inputs[:, 46:54, 68:80,:], axis=[1,2]),

                 tf.math.reduce_sum(inputs[:, 58:66, 28:40,:], axis=[1,2]),

                 tf.math.reduce_sum(inputs[:, 58:66, 44:56,:], axis=[1,2]),

                 tf.math.reduce_sum(inputs[:, 58:66, 60:72,:], axis=[1,2])], axis=-1)

        return output


查看完整回答
反对 回复 2023-10-31
  • 1 回答
  • 0 关注
  • 75 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信