我正在尝试tf.Graph
根据某些条件逐步构建一个并在我完成附加后运行一次。
代码如下所示:
class Model:
def __init__(self):
self.graph = tf.Graph()
... some code ...
def build_initial_graph(self):
with self.graph.as_default():
X = tf.placeholder(tf.float32, shape=some_shape)
... some code ...
def add_to_existing_graph(self):
with self.graph.as_default():
... some code adding more ops to the graph ...
def transform(self, data):
with tf.Session(graph=self.graph) as session:
y = session.run(Y, feed_dict={X: data})
return y
调用方法看起来像这样
model = Model()
model.build_initial_graph()
model.add_to_existing_graph()
model.add_to_existing_graph()
result = model.transform(data)
所以,两个问题
X
在feed_dict
运行代码时无法识别,实现该目标的正确方法是什么?Q1:这当然是构建模型的合法方式,但更多的是意见问题。我只会建议将您的张量存储为属性(请参阅 Q2 的答案。)self.X=...
。
您可以查看这篇关于如何以面向对象的方式构建 TensorFlow 模型的非常好的文章。
Q2:原因很简单,因为变量X
不在你的transform
方法范围内。
如果您执行以下操作,一切都会正常进行:
def build_initial_graph(self):
with self.graph.as_default():
self.X = tf.placeholder(tf.float32, shape=some_shape)
... some code ...
def transform(self, data):
with tf.Session(graph=self.graph) as session:
return session.run(self.Y, feed_dict={self.X: data})
更详细地说,在 TensorFlow 中,您定义的所有张量或操作(例如tf.placeholder
或tf.matmul
)都在tf.Graph()
you re working on. You might want to store them in Python variable, as you did by doing
X = tf.placeholder`中定义,但这不是强制性的。
如果您想访问您定义的张量之一,您可以
X
不在方法的范围内)或者,tf.get_variable
方法)。本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句