全部开发者教程

TensorFlow 入门教程

Python / 使用 tf.function 提升效率

使用 tf.function 提升效率

在之前的入门介绍之中,我们曾经介绍过 TensorFlow1.x 采用的并不是 Eager execution 执行模型;而 TensorFlow2.x 默认采用的是 Eager execution 模式。

这种改变使得我们可以更加容易地学习,但是也会造成性能的损失,因此, TensorFlow 在 2.0 版本之后引入了 tf.function 。

1. 什么是 tf.funtion

在 TensorFlow1.x 之中,如果我们想要运行一个学习任务,那么我们需要首先创建一个 tf.Sesstion (),然后再调用 Session.run () 进行运行。

其实在 TensorFlow1.x 内部,当我们在 TensorFlow 之中进行工作的时候, TensorFlow 会帮助我们创建一个计算图 tf.graph ,然后通过 tf.Session 对计算图进行计算

而在 TensorFlow2.x 之中,其默认采用的是 Eager execution 执行方式,在该执行方式之中,我们不再需要定义一个计算图来进行

这样就产生了一些问题:

  • 使用 tf.Sesstion () 的运行效率非常高,但是代码很难懂
  • 使用 Eager execution 方式的代码很简单,但是执行效率比较低

有什么方法能够兼顾两者吗?

那就是 tf.function 。

2. tf.funtion 的用法

tf.function 是一个函数标注修饰,也就是如下的形式:

@tf.function
def my_function():
    ...

其实如你所见,这就是 tf.function 的全部用法。

我们只需要在我们要修饰的函数之前加上 tf.function 标注既可

采用 tf.function,TensorFlow 会将该函数转变为计算图 tf.graph 的形式来进行运算,这会使得该函数在进行大量运算的时候会加速非常多。

是不是所有的函数都适合 tf.function 进行修饰呢?

答案是否定的,以下两种情况不适合使用 tf.function 进行修饰:

  • 函数本身的计算非常简单,那么构建计算图本身的时间就会相对非常浪费;
  • 当我们需要在函数之中定义 tf.Variable 的时候,因为 tf.function 可能会被调用多次,因此定义 tf.Variable 会产生重复定义的情况。

3. tf.function 的性能

既然了解了 tf.function 的用法,那么我们便来测试一下 tf.function 的性能,我们采用一个简单的卷积神经网络来进行测试:

import tensorflow as tf
import timeit


def f1(layer, image):
    y = layer(image)
    return y

@tf.function
def f2(layer, image):
    y = layer(image)
    return y

layer = tf.keras.layers.Conv2D(300, 3)
image = tf.zeros([64, 32, 32, 3])

model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),  
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')])

print(timeit.timeit(lambda: f1(model, image), number=500))
print(timeit.timeit(lambda: f2(model, image), number=500))

在这里,我们定义了两个相同的函数,其中一个使用了 tf.function 进行修饰,而另外一个没有。

在这里我们使用 lambda 函数来让函数重复执行 500 次,并且使用 timeit 来进行时间的统计,得到两个函数的执行时间,从而进行比较。

最终,我们可以得到结果:

17.20403664399987
12.07886406200032

由此可以看出,我们的 tf.function 已经提升了一定的速度,但是提升的速度有限,目前大概提升了 25 % 的速度。这是因为我们的计算仍然还是太简单了,当我们计算非常大的时候,性能会有很大的提升。

4. 小结

在这节课之中,我们学习到了什么是 tf.function ,以及 tf.function 的基本原理,然后我们了解了 tf.function 的使用方法;最后我们通过一个简单的神经网络来进行了性能的测试,最终我们发现我们的 tf.function 确实能给我们性能带来很大的提升。