K.learning_phase()
获取值,而不是张量本身。我需要学习阶段张量以K.function
获取图层梯度,输出等。w /可以正常工作import keras.backend as K
,但对于则失败import tensorflow.keras.backend as K
。相关的Git /部分解决方法
我如何获取张量本身?
可重现的示例:
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
import numpy as np
ipt = Input((16,))
out = Dense(16)(ipt)
model = Model(ipt, out)
model.compile('adam', 'mse')
x = np.random.randn(32, 16)
model.train_on_batch(x, x)
grads = model.optimizer.get_gradients(model.total_loss, model.layers[-1].output)
grads_fn = K.function(inputs=[model.inputs[0], model._feed_targets[0], K.learning_phase()],
outputs=grads)
完整的错误跟踪:
File "<ipython-input-2-7f74922d7492>", line 3, in <module>
outputs=grads)
File "D:\Anaconda\envs\tf2_env\lib\site-packages\tensorflow_core\python\keras\backend.py", line 3773, in function
return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
File "D:\Anaconda\envs\tf2_env\lib\site-packages\tensorflow_core\python\keras\backend.py", line 3670, in __init__
base_graph=source_graph)
File "D:\Anaconda\envs\tf2_env\lib\site-packages\tensorflow_core\python\eager\lift_to_graph.py", line 249, in lift_to_graph
visited_ops = set([x.op for x in sources])
File "D:\Anaconda\envs\tf2_env\lib\site-packages\tensorflow_core\python\eager\lift_to_graph.py", line 249, in <listcomp>
visited_ops = set([x.op for x in sources])
AttributeError: 'int' object has no attribute 'op'
作为一种(不太好的)解决方法,您可以使用symbolic_learning_phase()
from tensorflow.python.keras.backend
:
from tensorflow.python.keras import backend
# ...
grads_fn = K.function(inputs=[model.inputs[0],
model._feed_targets[0],
backend.symbolic_learning_phase()],
outputs=grads)
g_learning = grads_fn([x, x, True])
g_not_learning = grads_fn([x, x, False])
我不确定为什么没有将此功能learning_phase()
导出到中tensorflow.keras.backend
。也许有充分的理由不这样做。
此外,请注意,仅当您的模型包含一些在训练和推理模式下表现不同(例如,辍学)的图层/操作时,才在此处使用学习阶段才有意义。否则,函数的输出将相同。
Update:backend.symbolic_learning_phase()
用于tensorflow.keras
代码中(示例),建议不要过多地公开使用它。它用作K.learning_phase()
Eager执行中的直接替代品,将在中使用K.function()
。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句