我想使用 .tfrecord 批量加载 .tfrecord 条目tf.contrib.learn.read_batch_features(...)
。当我read_and_decode()
在训练 Estimator时使用以下代码(类似于 的内容)时,它会起作用。当我load_samples()
在eval(session=sess)
没有 Estimator 的情况下运行脚本时,它会挂起。我想这是管道的问题,但我不知何故无法确定问题。我遵循了tensorflow 网站上的指南,但没有任何运气。
def read_and_decode(sess, cnt):
def get_reader():
return tf.TFRecordReader()
features = tf.contrib.learn.read_batch_features(
file_pattern=os.path.join('.', 'test.tfrecord'),
batch_size=cnt,
reader=get_reader,
features={
'label': tf.FixedLenFeature([], tf.int64),
'data': tf.FixedLenFeature([], tf.string),
})
label = tf.cast(features['label'], tf.int64)
data = tf.decode_raw(features['data'], tf.float32)
patch = tf.reshape(data, tf.stack( [cnt, 6, 20, 20] ))
patch.set_shape( [cnt, 6, 20, 20] )
return label.eval(session=sess), patch.eval(session=sess)
def load_samples():
with tf.Session() as sess:
sess.run([
tf.local_variables_initializer(),
tf.global_variables_initializer()
])
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
while not coord.should_stop():
samples = read_and_decode(sess,100)
except tf.errors.OutOfRangeError as error:
coord.request_stop(error)
finally:
coord.request_stop()
coord.join(threads)
你能解释一下我做错了什么吗?
原因是因为您正在启动队列运行程序,然后您正在定义队列。您需要先定义函数read_and_decode()
,然后启动队列运行器,这将解决问题。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句