我有一个A
带有形状的2D张量[batch_size, D]
和一个B
带有形状的1D张量[batch_size]
。的每个元素对于的每一行B
都是的列索引,例如。。A
A
B[i] in [0,D)
张量流中获取值的最佳方法是什么 A[B]
例如:
A = tf.constant([[0,1,2],
[3,4,5]])
B = tf.constant([2,1])
具有所需的输出:
some_slice_func(A, B) -> [2,4]
还有另一个约束。实际上,batch_size
实际上是None
。
提前致谢!
我能够使用线性索引使其工作:
def vector_slice(A, B):
""" Returns values of rows i of A at column B[i]
where A is a 2D Tensor with shape [None, D]
and B is a 1D Tensor with shape [None]
with type int32 elements in [0,D)
Example:
A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4]
[3,4]]
"""
linear_index = (tf.shape(A)[1]
* tf.range(0,tf.shape(A)[0]))
linear_A = tf.reshape(A, [-1])
return tf.gather(linear_A, B + linear_index)
不过,这感觉有点。
如果有人知道更好(更清晰或更快速),请留下答案!(我暂时不会接受我自己的)
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句