LSTMで表現するprobs
形状のテンソルがあります。の3次元から選択する要素のインデックスを含む形状を持つ別のテンソルがあります。(None, None, 110)
(batch_size, sequence_length, 110)
indices
(None, None)
probs
indices
テンソルのインデックス付けに使用したいprobs
。
Numpy相当:
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]
以来shape[0]
とshape[1]
のprobs
知られていないが、tf.meshgrid()
オプションではありません。私が見つかりましたtf.gather
、tf.gather_nd
そしてtf.batch_gather
、彼らはすべて私がやりたいように思えません。
誰かがこれを行う方法を知っていますか?
あなたはtf.gather_nd
このようにそれを行うことができます:
indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)
ちなみに、NumPynp.take_along_axis
では同じことを行うために使用できます:
indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加