为了账号安全,请及时绑定邮箱和手机立即绑定

tf.GradientTape() 的 __exit__ 函数的参数是什么?

tf.GradientTape() 的 __exit__ 函数的参数是什么?

一只萌萌小番薯 2023-01-04 16:34:39
根据 的文档,tf.GradientTape其__exit__()方法采用三个位置参数:typ, value, traceback.这些参数究竟是什么?该语句如何with推断它们?我应该在下面的代码中给它们什么值(我没有使用with语句的地方):x = tf.Variable(5)gt = tf.GradientTape()gt.__enter__()y = x ** 2gt.__exit__(typ = __, value = __, traceback = __)
查看完整描述

1 回答

?
大话西游666

TA贡献1817条经验 获得超14个赞

sys.exc_info()返回具有三个值的元组(type, value, traceback)

  1. 这里type获取正在处理的Exception的异常类型

  2. value是传递给异常类的构造函数的参数。

  3. traceback包含堆栈信息,如发生异常的位置等。

在 GradientTape 上下文中,当异常发生时,sys.exc_info()详细信息将传递给exit () 函数,后者将Exits the recording context, no further operations are traced

下面是说明相同的示例。

让我们考虑一个简单的函数。

def f(w1, w2):
    return 3 * w1 ** 2 + 2 * w1 * w2

通过不使用with语句:

w1, w2 = tf.Variable(5.), tf.Variable(3.)


tape = tf.GradientTape()

z = f(w1, w2)

tape.__enter__()

dz_dw1 = tape.gradient(z, w1)

try:

    dz_dw2 = tape.gradient(z, w2)

except Exception as ex:

    print(ex)

    exec_tup = sys.exc_info()

    tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])

印刷:


GradientTape.gradient 只能在非持久性磁带上调用一次。


即使你没有通过传递值显式退出,程序也会传递这些值来退出GradientTaoe记录,下面是示例。


w1, w2 = tf.Variable(5.), tf.Variable(3.)


tape = tf.GradientTape()

z = f(w1, w2)

tape.__enter__()

dz_dw1 = tape.gradient(z, w1)

try:

    dz_dw2 = tape.gradient(z, w2)

except Exception as ex:

    print(ex)

打印相同的异常消息。


通过使用with语句。


with tf.GradientTape() as tape:

    z = f(w1, w2)


dz_dw1 = tape.gradient(z, w1)

try:

    dz_dw2 = tape.gradient(z, w2)

except Exception as ex:

    print(ex)

    exec_tup = sys.exc_info()

    tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])

以下是sys.exc_info()对上述异常的响应。


(RuntimeError,

 RuntimeError('GradientTape.gradient can only be called once on non-persistent tapes.'),

 <traceback at 0x7fcd42dd4208>)

编辑 1:


如user2357112 supports Monica评论中所述。为非异常情况提供解决方案。


在非异常情况下,规范要求传递给的值都__exit__应该是None.


示例 1:


x = tf.constant(3.0)

g = tf.GradientTape()

g.__enter__()

g.watch(x)

y = x * x

g.__exit__(None,None,None)

z  = x*x

dy_dx = g.gradient(y, x) 

# dz_dx = g.gradient(z, x) 

print(dy_dx)

# print(dz_dx)

印刷:


tf.Tensor(6.0, shape=(), dtype=float32) 

由于在它返回梯度值 y之前已经被捕获。__exit__


示例 2:


x = tf.constant(3.0)

g = tf.GradientTape()

g.__enter__()

g.watch(x)

y = x * x

g.__exit__(None,None,None)

z  = x*x

# dy_dx = g.gradient(y, x) 

dz_dx = g.gradient(z, x) 

# print(dy_dx)

print(dz_dx)

印刷:


None 

这是因为在梯度停止记录z之后被捕获。__exit__


查看完整回答
反对 回复 2023-01-04
  • 1 回答
  • 0 关注
  • 108 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信