假设我有一个张量
[[0.3, 0.7],
[0.9, 0.1]]
如何在沿轴的最大位置创建一个 1.0 的张量,所以结果应该是轴 = 1
[[0., 1.],
[1., 0.]]
在我的情况下,第一维是批量大小,所以它是 '?'
提出的两个答案在内存/计算方面都是低效的。
您可以在线性时间内(no-matmul)计算它,而无需在一行中分配不必要的内存:
tf.cast(tf.equal(a, tf.reshape(tf.reduce_max(a, axis=1), (-1, 1))), tf.int16)
完整的例子在这里:
import tensorflow as tf
a = tf.constant([
[1, 9, 1, 6],
[6, 5, 0, 6],
[4, 0, 7, 6],
[1, 5, 9, 1]
])
b = tf.cast(tf.equal(a, tf.reshape(tf.reduce_max(a, axis=1), (-1, 1))), tf.int16)
with tf.Session() as sess:
print sess.run(b)
哪个会给你
[[0 1 0 0]
[1 0 0 1]
[0 0 1 0]
[0 0 1 0]]
正如你看到它使用广播在tf.equal
减少的内存数量。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句