我对这一切有点陌生所以你能帮我吗?我试图找到这个问题的答案,但一无所获。
我正在尝试在单独的函数中在 python 中加载 Tensorflow 模型,以便我可以在循环中使用该模型,而不必在 for 循环的每次迭代中加载它。
这是我现在的代码:
def load_network():
prediction = neural_network_model(x)
return (prediction)
def use_neural_network(data, prediction):
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph(model_name+'.meta')
saver.restore(sess,model_name)
pred = sess.run(prediction, feed_dict={x: data})
pred = np.asarray(pred)
return pred
if __name__ == '__main__':
result=[]
Load= start_network()
for i in data:
result.append(use_neural_network(i,Load))
我想得到这样的东西:
def load_network():
prediction = neural_network_model(x)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph(model_name+'.meta')
saver.restore(sess,model_name)
return (prediction)
def use_neural_network(data, prediction):
with tf.Session() as sess:
pred = sess.run(prediction, feed_dict={x: data})
pred = np.asarray(pred)
return pred
if __name__ == '__main__':
result=[]
Load= start_network()
for i in data:
result.append(use_neural_network(i,Load))
通常,您想要实现的目标很容易实现,并且您走在正确的轨道上。在主块中,你有start_network()
而不是load_network()
在你的第一行。我还建议不要将其Load
用作变量名,但这应该不是问题。此外,TensorFlow 会话(sess
在您的代码中)应该是一个全局变量,或者您应该在主块或load_network()
函数中初始化它,然后将其传递给use_neural_network()
函数。当前写入sess
两个函数中的两个变量的方式是本地的,因此引用不同的会话。
如果您想避免使用该neural_network_model( x )
功能,即在开始时构建模型,您可能希望冻结模型并以这种方式加载它,同时嵌入架构。最容易遵循指南,就像这个。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句