全部开发者教程

TensorFlow 入门教程

Python / 在 TensorBoard 之中查看模型结构图

在 TensorBoard 之中查看模型结构图

在之前的学习过程之中,我们学习了如何自定义查看训练过程之中的各项指标。在实际的应用过程之中,为了保证模型构建的准确性,我们也会经常查看网络的模型结构图。那么这节课我们就来看一下如何在 TensorBoard 之中查看模型图。

1. 如何在 TensorBoard 之中生成 Keras 模型结构图

倘若我们通过 tf.keras API 来自定义了一个网络模型,那么我们在 TensorBoard 来查看模型图是非常简单的一件事情。

当我们使用 tf.keras 的模型的 fit() 方法的时候,框架会自动帮我们绘制模型结构图

如下代码所示:

首先我们定义模型、数据与相应的参数。

import tensorflow as tf

(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train, x_test = x_train / 255.0, x_test / 255.0

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
  ])

model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=[])

然后我们定义相应的 TensorBoard 日志目录,同时对模型使用 fit() 进行训练:

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs')

model.fit(x=x_train, y=y_train, 
          epochs=3, 
          validation_data=(x_test, y_test), 
          callbacks=[tensorboard_callback])

最后我们就可以打开 TensorBoard 并在浏览器查看:

tensorboard --logdir logs

我们就可以在浏览器的 Graph 标签页之中看到模型图了:

图片描述

2. 如何在 TensorBoard 之中生成使用 tf.function 函数定义的图

在实际的应用过程之中,有很多的情况下,我们需要使用 tf.function 来加速模型的速度并自定义训练过程。那么这个时候我们要如何才能查看网络的模型结构图呢?

其实也很简单,我们只需经过如下几个步骤:

  • 确保 tf.function 函数修饰了我们需要进行可视化的操作,这边就是模型的过程;
  • 创建一个 TensorBoard 的日志写入器 tf.summary.create_file_writer() ;
  • 通过 tf.summary.trace_on() API 进行变量路径的追踪
  • 执行我们需要可视化的操作;
  • 使用 tf.summary.trace_export() API 将图写入日志

在这里,我们可以使用一个很简单的例子来查看操作的结构:


# 定义网络的操作
@tf.function
def test_func(x, y):
  z = tf.matmul(x, y)
  z = z * 5.0
  z = tf.nn.relu(z)
  return z

# 创建写入器
writer = tf.summary.create_file_writer('./logs/3')

# 创建初试数据
x = tf.random.uniform((5, 5))
y = tf.random.uniform((5, 5))

# 开启变量追踪
tf.summary.trace_on(graph=True, profiler=True)

# 运行程序
z = test_func(x, y)

# 将日志输出
with writer.as_default():
  tf.summary.trace_export(
      name="test_func_graph",
      step=1,
      profiler_outdir='./logs/3')

在这里,我们首先定义了一个基本的模型操作,该模型操作由一个矩阵乘法、一个常量乘法、外加一个 Relu 激活层组成。
在运行完操作之后,我们便使用 tf.summary.trace_export() API 来将模型图输入道日志之中。

然后我们便可以在浏览器之中查看到相应的模型图:

图片描述

可以看到,该模型图完整的反映了我们的操作。

3. TensorBoard 之中基本、基本的操作

既然了解了如何将模型图输出到日志,那么接下来我们就应该查看在 TensorBoard 之中对模型图的基本操作。

3.1 平移、缩放以及详细信息的查看

在 TensorBoard 之中,使用鼠标滚轮即可实现模型图的缩放,当我们一直放大,会看到操作内部的细节。

并且按住鼠标左键,移动鼠标,即可实现模型图的移动操作。

双击网络节点,即可展开网络节点,从而查看到网络内部的细节操作

3.2 模型的节点的搜索

图片描述

在左侧的最上方,可以搜索自己想要查看的节点,这里是支持正则表达式的。

3.3 模型的下载

点击左侧的 Download PNG 即可下载带有透明度的、网络模型的图片。

3.4 切换网络模型

点击左侧的 Run 按钮,即可选择不同的网络模型进行查看,前提是我们已经将网络模型输入到日志之中去。

3.5 切换查看方式

点击左侧的 Tag 选项,即可查看网络的查看方式。

默认是查看中粒度的网络模型,如果我们的模型是使用 Keras 定义的,那么我们可以选择查看 Keras 结构,这是一个总体的概览,可以帮助我们掌握大体的网络结构

3.6 模型图的图例

当我们遇到一些不理解的图标的时候,我们可以通过左下角的图例进行查询:

图片描述

4. 小结

在这节课之中,我们学习率如何在 TensorBoard 之中查看 Keras 模型,同时也了解了如何产看自定义的操作过程,最后我们了解了 TensorBoard 的一些基本操作。 TensorBoard 也在持续更新,未来一定会有更多新的功能。