如何给定形状为(batch_size,200,256)的张量索引以获得给定的长度= batch_size的索引张量列表(batch_size,1,256)?

吨胡同

我有形状(batch_size,200,256)的LSTM层的输出,其中200是令牌序列的长度,而256是LSTM输出尺寸。我还具有另一个形状为(batch_size)的张量,这是我要从批次中的每个样本序列中切出的标记索引的列表。

如果令牌索引不是-1,我将切出令牌向量表示形式(长度= 256)。如果令牌索引为-1,我将给出零个向量(长度= 256)。

预期的输出结果的形状为(batch_size,1,256)。我该怎么办?

谢谢

到目前为止,这是我尝试过的

bidir = concatenate([forward, backward]) # shape = (batch_size, 200, 256) 
dropout = Dropout(params['dropout_rate'])(bidir)
def slice_by_tensor(x):
    matrix_to_slice = x[0]
    index_tensor = x[1]


    out_tensor = tf.where(index_tensor == -1, 
                          tf.zeros(tf.shape(tf.gather(matrix_to_slice, 
                                                      index_tensor, axis=1))), 
                          tf.gather(matrix_to_slice, index_tensor, axis=1))



    return out_tensor


representation_stack0 = Lambda(lambda x: slice_by_tensor(x))([dropout,stack_idx0]) 
# stack_idx0 shape is (batch_size) 
# I got output with shape (batch_size, batch_size, 256) with this code
zihaozhihao
a=tf.reshape(tf.range(2*3*4),shape=(2,3,4))
#     [[[ 0,  1,  2,  3],
#        [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]],

#      [[12, 13, 14, 15],
#      [16, 17, 18, 19],
#       [20, 21, 22, 23]]]

b=tf.constant([-1,2]) 

aa=tf.pad(a,[[0,0],[1,0],[0,0]]) 

bb=b+1 

index=tf.stack([tf.range(tf.size(b)),bb],axis=-1) 
res=tf.expand_dims(tf.gather_nd(aa, index),axis=1)
#[[[ 0,  0,  0,  0]],
#[[20, 21, 22, 23]]]

当index为-1时,我们需要像张量那样的零。因此,我们可以先沿第二个轴填充原始张量。然后将索引增加1。之后,使用tf.gather_nd将返回答案。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

Related 相关文章

热门标签

归档