I want to paste an existing TensorFlow graph into a new graph.
Say I create a graph that computes y = tanh(x @ w):
import tensorflow as tf
import numpy as np

def some_function(x):
    w = tf.Variable(initial_value=np.random.randn(4, 5), dtype=tf.float32)
    return tf.tanh(x @ w)

x = tf.placeholder(shape=(None, 4), dtype=tf.float32)
y = some_function(x)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
val_x = np.random.randn(3, 4)
val_y, = sess.run([y], feed_dict={x: val_x})
Great. Now suppose I have lost the code that generated this graph, but I still have references to the tensors (x, y). Now I want to take this graph (with the current value of w) and copy it twice into a new graph (the two paths should share the same w), so that I can compute d = tf.reduce_sum((tanh(x1 @ w) - tanh(x2 @ w))**2) by adding the following lines:
# Starting with access to tensors: x, y
<SOMETHING HERE>
d = tf.reduce_sum((y1 - y2)**2)
val_x1 = np.random.randn(3, 4)
val_x2 = np.random.randn(3, 4)
val_d = sess.run([d], feed_dict={x1: val_x1, x2: val_x2})
What do I fill in for <SOMETHING HERE> to make this work? (Obviously, without recreating the first graph.)
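For clarity, the quantity I am after can be written in plain NumPy, with a stand-in weight matrix in place of the variable's current value:

```python
import numpy as np

# The target quantity: d = sum of (tanh(x1 @ w) - tanh(x2 @ w))**2,
# where both paths share the same w. The w here is a hypothetical
# stand-in for the variable's current value in the session.
rng = np.random.RandomState(0)
w = rng.randn(4, 5)
val_x1 = rng.randn(3, 4)
val_x2 = rng.randn(3, 4)
d = np.sum((np.tanh(val_x1 @ w) - np.tanh(val_x2 @ w)) ** 2)
```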
There is the Graph Editor module that can help with this kind of operation. Its main drawback is that you cannot have a running session while you modify the graph. However, you can checkpoint the session, modify the graph, and restore it if you need to.
The problem with what you want is that you basically need to copy a subgraph, except you do not want to copy the variables. So you can simply exclude the variable types (mainly Variable, VariableV2 and maybe VarHandleOp, although I have found a few more in the TensorFlow code). You can do that with a function like this:
import tensorflow as tf

# Receives the outputs to recalculate and the input replacements
def replicate_subgraph(outputs, mappings):
    # Types of operation that should not be replicated
    # Taken from tensorflow/python/training/device_setter.py
    NON_REPLICABLE = {'Variable', 'VariableV2', 'AutoReloadVariable',
                      'MutableHashTable', 'MutableHashTableV2',
                      'MutableHashTableOfTensors', 'MutableHashTableOfTensorsV2',
                      'MutableDenseHashTable', 'MutableDenseHashTableV2',
                      'VarHandleOp', 'BoostedTreesEnsembleResourceHandleOp'}
    # Find subgraph ops
    ops = tf.contrib.graph_editor.get_backward_walk_ops(outputs, stop_at_ts=mappings.keys())
    # Exclude non-replicable operations
    ops_replicate = [op for op in ops if op.type not in NON_REPLICABLE]
    # Make subgraph view
    sgv = tf.contrib.graph_editor.make_view(*ops_replicate)
    # Make the copy
    _, info = tf.contrib.graph_editor.copy_with_input_replacements(sgv, mappings)
    # Return new outputs
    return info.transformed(outputs)
Here is an example similar to yours (I edited it a bit so it is easy to see that the output is correct, because the second value is ten times the first one):
import tensorflow as tf

def some_function(x):
    w = tf.Variable(initial_value=tf.random_normal((5,)), dtype=tf.float32)
    return 2 * (x * w)

x1 = tf.placeholder(shape=(), dtype=tf.float32, name='X1')
x2 = tf.placeholder(shape=(), dtype=tf.float32, name='X2')
y1 = some_function(x1)
y2, = replicate_subgraph([y1], {x1: x2})
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(*sess.run([y1, y2], feed_dict={x1: 1, x2: 10}), sep='\n')
Output:
[ 2.3356955 2.277849 0.58513653 2.0919807 -0.15102367]
[23.356955 22.77849 5.851365 20.919807 -1.5102367]
EDIT:

Here is another way to do the same thing, using tf.make_template. This requires you to actually have the code for the function, but it is a cleaner and more "official" way of supporting subgraph reuse.
import tensorflow as tf

def some_function(x):
    w = tf.get_variable('W', (5,), initializer=tf.random_normal_initializer())
    # Or if the variable is only local and not trainable
    # w = tf.Variable(initial_value=tf.random_normal((5,)), dtype=tf.float32, trainable=False)
    return 2 * (x * w)

x1 = tf.placeholder(shape=(), dtype=tf.float32, name='X1')
x2 = tf.placeholder(shape=(), dtype=tf.float32, name='X2')
some_function_tpl = tf.make_template('some_function', some_function)
y1 = some_function_tpl(x1)
y2 = some_function_tpl(x2)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(*sess.run([y1, y2], feed_dict={x1: 1, x2: 10}), sep='\n')