如何在使用@ tf.function装饰的函数中的tf.data.Dataset上使用for循环来操纵和返回tf.Variable?

o

我正在尝试创建一个在TensorFlow数据集上包含for循环的函数,该函数在每次迭代中为TensorFlow变量分配一个新值。变量也应作为函数的输出返回。启用了急切执行后,就没有问题了,但是,在图形模式下,似乎发生了一些意想不到的事情。考虑以下简单的伪代码:

import tensorflow as tf


class Test(object):
    def __init__(self):
        self.var = tf.Variable(0, trainable=False, dtype=tf.float32)
        self.increment = tf.constant(1, dtype=tf.float32)
        self.dataset = tf.data.Dataset.from_tensor_slices([0, 1, 2])

    @tf.function
    def fn1(self):
        self.var.assign(0)
        for _ in tf.range(3):
            self.var.assign(self.var+self.increment)
            tf.print(self.var)
        tf.print(self.var)
        return self.var

    @tf.function
    def fn2(self):
        self.var.assign(0)
        for _ in self.dataset:
            self.var.assign(self.var+self.increment)
            tf.print(self.var)
        tf.print(self.var)
        return self.var

    @tf.function
    def fn3(self):
        self.var.assign(0)
        y = self.var
        for _ in self.dataset:
            self.var.assign(self.var+self.increment)
            y = self.var
            tf.print(y)
        tf.print(y)
        return y

    @tf.function
    def fn4(self):
        var = 0.0
        for _ in self.dataset:
            var += 1.0
            tf.print(var)
        tf.print(var)
        return var

test.fn1()test.fn3()并且test.fn4()全部返回以下(所需)输出:

1
2
3
3
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>

但是,test.fn2()其行为有所不同:

1
2
3
0
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

有趣的是,执行之后test.fn2test.var似乎确实包含正确的值:

<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>

我不确定为什么会test.fn2失败。显然,它在做某些事情是正确的(因为test.var函数执行后包含正确的值),但是它不能提供正确的结果。您能帮助我了解导致此代码失败的原因吗?

当在CentOS 7上将TensorFlow 2.1.0用于Python 3.6时,会发生上述行为。

TF_Support

TensorFlow 2.1.0上运行此脚本可重现您的方案。

它打印1 2 3 0test.fn2(),但你也应该考虑到,当您打印self.vartest.fn3()它还会告诉你self.var = 0在函数调用期间。

修改的 fn3()

    @tf.function
    def fn3(self):
        self.var.assign(0)
        y = self.var
        for _ in self.dataset:
            self.var.assign(self.var+self.increment)
            y = self.var
            tf.print(y)
        tf.print(self.var)  # Inspect self.var value
        tf.print(y)
        return y

输出:

# Executed in Tensorflow 2.1.0
# test.fn3()
1
2
3
0  << self.var
3

如果您在Tensorflow 2.2.0-rc2中执行此操作,此问题已修复
即使在图形执行期间打印输出,也将是您想要的结果。

要快速模拟这个你可以使用谷歌Colab和使用%tensorflow_version 2.x,以获得最新版本Tensorflow

输出:

# Executed in Tensorflow 2.2.0-rc2
Function 1
1
2
3
3
Function 2
1
2
3
3
Function 3
1
2
3
3 << Value of self.var in test.fn3()
3
Function 4
1
2
3
3

您可以在此链接中查看有关最新Tensorflow更新更改的更多信息

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

Related 相关文章

热门标签

归档