我在训练集中发现了一些错误的数据(标签错误的示例),并且在固定源代码的同时,我想继续尝试使用相同的数据集,因此我需要跳过这些记录。
我正在使用TFRecordReader并使用parse_single_example&shuffle_batch加载。我可以在某处提供过滤器吗?
在文档中使用tf.train.shuffle_batch()
和做了简短的参考enqueue_many=True
。如果您可以使用图形操作确定示例是否贴错标签,则可以像这样过滤结果(改编自另一个SO答案):
X, y = tf.parse_single_example(...)
is_correctly_labelled = correctly_labelled(X, y)
X = tf.expand_dims(X, 0)
y = tf.expand_dims(y, 0)
empty = tf.constant([], tf.int32)
X, y = tf.cond(is_correctly_labelled,
lambda: [X, y],
lambda: [tf.gather(X, empty), tf.gather(y, empty)])
Xs, ys = tf.train.shuffle_batch(
[X, y], batch_size, capacity, min_after_dequeue,
enqueue_many=True)
这tf.gather
只是获取零尺寸切片的一种方法。在numpy中,它将是X[[], ...]
。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句