Tensorflow:多次将现有图形复制到新图形中

彼得

我想将现有的张量流图粘贴到新图中。

假设我创建了一个图计算 y = tanh(x @ w)

import tensorflow as tf
import numpy as np

def some_function(x):
    w = tf.Variable(initial_value=np.random.randn(4, 5), dtype=tf.float32)
    return tf.tanh(x @ w)

x = tf.placeholder(shape=(None, 4), dtype = tf.float32)
y = some_function(x)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
val_x = np.random.randn(3, 4)
val_y, = sess.run([y], feed_dict={x: val_x})

伟大的。现在假设我丢失了生成该图的代码,但我仍然可以访问变量 ( x, y)。现在我想获取这个图(使用 w 的当前值),并将它复制两次到一个新图中(两条路径应该共享相同的w),以便我现在d = tf.reduce_sum((tanh(x1 @ w)-tanh(x2 @ w))**2)通过添加以下行进行计算

# Starting with access to tensors: x, y
<SOMETHING HERE>
d = tf.reduce_sum((y1-y2)**2)
val_x1 = np.random.randn(3, 4)
val_x2 = np.random.randn(3, 4)
val_d = sess.run([d], feed_dict = {x1: val_x1, x2: val_x2})

我要填写什么才能完成<SOMETHING HERE>这项工作?(显然,没有重新创建第一张图)

杰德赫萨

Graph Editor模块可以帮助进行此类操作。它的主要缺点是在修改图形时不能有正在运行的会话。但是,您可以检查会话、修改图形并在需要时将其恢复。

你想要的问题是你基本上需要复制一个子图,除非你不想复制变量。所以你可以简单地排除变量类型(主要是VariableVariableV2也许是VarHandleOp,尽管我在TensorFlow 代码中发现了更多)。你可以用这样的函数来做到这一点:

import tensorflow as tf

# Receives the outputs to recalculate and the input replacements
def replicate_subgraph(outputs, mappings):
    # Types of operation that should not be replicated
    # Taken from tensorflow/python/training/device_setter.py
    NON_REPLICABLE = {'Variable', 'VariableV2', 'AutoReloadVariable',
                      'MutableHashTable', 'MutableHashTableV2',
                      'MutableHashTableOfTensors', 'MutableHashTableOfTensorsV2',
                      'MutableDenseHashTable', 'MutableDenseHashTableV2',
                      'VarHandleOp', 'BoostedTreesEnsembleResourceHandleOp'}
    # Find subgraph ops
    ops = tf.contrib.graph_editor.get_backward_walk_ops(outputs, stop_at_ts=mappings.keys())
    # Exclude non-replicable operations
    ops_replicate = [op for op in ops if op.type not in NON_REPLICABLE]
    # Make subgraph viewitems
    sgv = tf.contrib.graph_editor.make_view(*ops_replicate)
    # Make the copy
    _, info = tf.contrib.graph_editor.copy_with_input_replacements(sgv, mappings)
    # Return new outputs
    return info.transformed(outputs)

对于类似于您的示例(我对其进行了一些编辑,因此很容易看出输出是正确的,因为第二个值是第一个值的十倍)。

import tensorflow as tf

def some_function(x):
    w = tf.Variable(initial_value=tf.random_normal((5,)), dtype=tf.float32)
    return 2 * (x * w)

x1 = tf.placeholder(shape=(), dtype=tf.float32, name='X1')
x2 = tf.placeholder(shape=(), dtype=tf.float32, name='X2')
y1 = some_function(x1)
y2, = replicate_subgraph([y1], {x1: x2})
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(*sess.run([y1, y2], feed_dict={x1: 1, x2: 10}), sep='\n')

输出:

[ 2.3356955   2.277849    0.58513653  2.0919807  -0.15102367]
[23.356955  22.77849    5.851365  20.919807  -1.5102367]

编辑:

这是使用tf.make_template. 这要求您实际拥有该函数的代码,但它是一种支持子图重用的更清晰、“更正式”的方式。

import tensorflow as tf

def some_function(x):
    w = tf.get_variable('W', (5,), initializer=tf.random_normal_initializer())
    # Or if the variable is only local and not trainable
    # w = tf.Variable(initial_value=tf.random_normal(5,), dtype=tf.float32, trainable=False)
    return 2 * (x * w)

x1 = tf.placeholder(shape=(), dtype=tf.float32, name='X1')
x2 = tf.placeholder(shape=(), dtype=tf.float32, name='X2')
some_function_tpl = tf.make_template('some_function', some_function)
y1 = some_function_tpl(x1)
y2 = some_function_tpl(x2)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(*sess.run([y1, y2], feed_dict={x1: 1, x2: 10}), sep='\n')

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

将R图形复制到具有自定义尺寸的剪贴板

来自分类Dev

如何将现有git存储库的副本复制到新存储库

来自分类Dev

将现有的 jenkins 配置复制到新的 jenkins

来自分类Dev

将节点从现有 XML 文件复制到新 XML 文件的简单方法

来自分类Dev

Matplotlib将子图绘制到现有图形

来自分类Dev

如何在现有的 matplotlib 中绘制新图形

来自分类Dev

使用phpMyAdmin将存储过程的结果复制到现有表中

来自分类Dev

将大型github文件复制到现有存储库中,而无需下载/克隆

来自分类Dev

尝试将工作表复制到 vba 中的现有工作簿

来自分类Dev

VBA:将数据复制到现有列表对象中

来自分类Dev

Powershell复制到-如果存在目标目录,则复制到现有目录中

来自分类Dev

八度5.1.0不再支持将图形复制到剪贴板

来自分类Dev

将现有数组的元素复制到一个新数组而不使用拼接?

来自分类Dev

将新的 json 数据合并到 d3.js 和 cola.js 中的现有图形

来自分类Dev

读取现有的CSV并将0、6和9列复制到新的CSV中

来自分类Dev

SQL(服务器):将查询结果复制到现有表中,并带有查询结果的结构

来自分类Dev

将通用protobuf复制到堆上的新对象中

来自分类Dev

将syslog文件复制到Linux中的新目录

来自分类Dev

将数组复制到Java中的新数组

来自分类Dev

YouTrack(Jetbrains)将卡复制到新板中

来自分类Dev

将行表复制到“新表”中。| jQuery

来自分类Dev

Powershell复制项目:无法将容器复制到现有叶项目上

来自分类Dev

将某些字段(仅结构)从另一个表复制到现有表中

来自分类Dev

如何将实例数据复制到现有的类引用中/之上?C#

来自分类Dev

VBA将每个工作表中的非空白单元格复制到现有工作表

来自分类Dev

SQL将现有的列值复制到select语句中的自定义列中

来自分类Dev

使用PowerShell将文本文件数据复制到现有的Excel工作簿中

来自分类Dev

将数组键复制到另一个现有的数组键中

来自分类Dev

将先前存在的AutoCAD图形插入到当前图形中

Related 相关文章

  1. 1

    将R图形复制到具有自定义尺寸的剪贴板

  2. 2

    如何将现有git存储库的副本复制到新存储库

  3. 3

    将现有的 jenkins 配置复制到新的 jenkins

  4. 4

    将节点从现有 XML 文件复制到新 XML 文件的简单方法

  5. 5

    Matplotlib将子图绘制到现有图形

  6. 6

    如何在现有的 matplotlib 中绘制新图形

  7. 7

    使用phpMyAdmin将存储过程的结果复制到现有表中

  8. 8

    将大型github文件复制到现有存储库中,而无需下载/克隆

  9. 9

    尝试将工作表复制到 vba 中的现有工作簿

  10. 10

    VBA:将数据复制到现有列表对象中

  11. 11

    Powershell复制到-如果存在目标目录,则复制到现有目录中

  12. 12

    八度5.1.0不再支持将图形复制到剪贴板

  13. 13

    将现有数组的元素复制到一个新数组而不使用拼接?

  14. 14

    将新的 json 数据合并到 d3.js 和 cola.js 中的现有图形

  15. 15

    读取现有的CSV并将0、6和9列复制到新的CSV中

  16. 16

    SQL(服务器):将查询结果复制到现有表中,并带有查询结果的结构

  17. 17

    将通用protobuf复制到堆上的新对象中

  18. 18

    将syslog文件复制到Linux中的新目录

  19. 19

    将数组复制到Java中的新数组

  20. 20

    YouTrack(Jetbrains)将卡复制到新板中

  21. 21

    将行表复制到“新表”中。| jQuery

  22. 22

    Powershell复制项目:无法将容器复制到现有叶项目上

  23. 23

    将某些字段(仅结构)从另一个表复制到现有表中

  24. 24

    如何将实例数据复制到现有的类引用中/之上?C#

  25. 25

    VBA将每个工作表中的非空白单元格复制到现有工作表

  26. 26

    SQL将现有的列值复制到select语句中的自定义列中

  27. 27

    使用PowerShell将文本文件数据复制到现有的Excel工作簿中

  28. 28

    将数组键复制到另一个现有的数组键中

  29. 29

    将先前存在的AutoCAD图形插入到当前图形中

热门标签

归档