我正在尝试创建一个在TensorFlow数据集上包含for循环的函数,该函数在每次迭代中为TensorFlow变量分配一个新值。变量也应作为函数的输出返回。启用了急切执行后,就没有问题了,但是,在图形模式下,似乎发生了一些意想不到的事情。考虑以下简单的伪代码:
import tensorflow as tf
class Test(object):
def __init__(self):
self.var = tf.Variable(0, trainable=False, dtype=tf.float32)
self.increment = tf.constant(1, dtype=tf.float32)
self.dataset = tf.data.Dataset.from_tensor_slices([0, 1, 2])
@tf.function
def fn1(self):
self.var.assign(0)
for _ in tf.range(3):
self.var.assign(self.var+self.increment)
tf.print(self.var)
tf.print(self.var)
return self.var
@tf.function
def fn2(self):
self.var.assign(0)
for _ in self.dataset:
self.var.assign(self.var+self.increment)
tf.print(self.var)
tf.print(self.var)
return self.var
@tf.function
def fn3(self):
self.var.assign(0)
y = self.var
for _ in self.dataset:
self.var.assign(self.var+self.increment)
y = self.var
tf.print(y)
tf.print(y)
return y
@tf.function
def fn4(self):
var = 0.0
for _ in self.dataset:
var += 1.0
tf.print(var)
tf.print(var)
return var
test.fn1()
,test.fn3()
并且test.fn4()
全部返回以下(所需)输出:
1
2
3
3
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>
但是,test.fn2()
其行为有所不同:
1
2
3
0
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
有趣的是,执行之后test.fn2
,test.var
似乎确实包含正确的值:
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>
我不确定为什么会test.fn2
失败。显然,它在做某些事情是正确的(因为test.var
函数执行后包含正确的值),但是它不能提供正确的结果。您能帮助我了解导致此代码失败的原因吗?
当在CentOS 7上将TensorFlow 2.1.0用于Python 3.6时,会发生上述行为。
在TensorFlow 2.1.0上运行此脚本可重现您的方案。
它打印1 2 3 0
了test.fn2()
,但你也应该考虑到,当您打印self.var
在test.fn3()
它还会告诉你self.var = 0
在函数调用期间。
修改的 fn3():
@tf.function
def fn3(self):
self.var.assign(0)
y = self.var
for _ in self.dataset:
self.var.assign(self.var+self.increment)
y = self.var
tf.print(y)
tf.print(self.var) # Inspect self.var value
tf.print(y)
return y
输出:
# Executed in Tensorflow 2.1.0
# test.fn3()
1
2
3
0 << self.var
3
如果您在Tensorflow 2.2.0-rc2中执行此操作,则此问题已修复。
即使在图形执行期间打印输出,也将是您想要的结果。
要快速模拟这个你可以使用谷歌Colab和使用%tensorflow_version 2.x
,以获得最新版本的Tensorflow。
输出:
# Executed in Tensorflow 2.2.0-rc2
Function 1
1
2
3
3
Function 2
1
2
3
3
Function 3
1
2
3
3 << Value of self.var in test.fn3()
3
Function 4
1
2
3
3
您可以在此链接中查看有关最新Tensorflow更新中更改的更多信息。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句