最近のTensorFlow(1.13
または2.0
)では、1回のパスでテンソルから非連続スライスを抽出する方法はありますか?どうやるか?たとえば、次のテンソルを使用します。
1 2 3 4
5 6 7 8
1つの操作で列1と3を抽出して、次のようにします。
2 4
6 8
しかし、スライスを使った1回の操作ではできないようです。これを行うための正しい/最速/最もエレガントな方法は何ですか?
1.tf.gather(tensor, columns, axis=1)
(TF1.x
、TF2
)の使用:
import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3]
print(tf.gather(tensor, columns, axis=1).numpy())
%timeit -n 10000 tf.gather(tensor, columns, axis=1)
# [[2. 4.]
# [6. 8.]]
82.6 µs ± 5.76 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
2.インデックス付き(TF1.x
、TF2
):
import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract
transposed = tf.transpose(tensor)
sliced = [transposed[c] for c in columns]
stacked = tf.transpose(tf.stack(sliced, axis=0))
# print(stacked.numpy()) # <-- TF2, TF1.x-eager
with tf.Session() as sess: # <-- TF1.x
print(sess.run(stacked))
# [[2. 4.]
# [6. 8.]]
それを関数にラップして実行%timeit
しtf.__version__=='2.0.0-alpha0'
ます:
154 µs ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
それを飾ること@tf.function
は2倍以上速いです:
import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract
@tf.function
def extract_columns(tensor=tensor, columns=columns):
transposed = tf.transpose(tensor)
sliced = [transposed[c] for c in columns]
stacked = tf.transpose(tf.stack(sliced, axis=0))
return stacked
%timeit -n 10000 extract_columns()
66.8 µs ± 2.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
3.熱心な実行のためのワンライナー(TF2
、TF1.x-eager
):
import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract
res = tf.transpose(tf.stack([t for i, t in enumerate(tf.transpose(tensor))
if i in columns], 0))
print(res.numpy())
# [[2. 4.]
# [6. 8.]]
%timeit
でtf.__version__=='2.0.0-alpha0'
:
242 µs ± 2.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
4.を使用tf.one_hot()
して行/列を指定し、次にtf.boolean_mask()
これらの行/列を抽出します(TF1.x
、TF2
):
import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract
mask = tf.one_hot(columns, tensor.get_shape().as_list()[-1])
mask = tf.reduce_sum(mask, axis=0)
res = tf.transpose(tf.boolean_mask(tf.transpose(tensor), mask))
# print(res.numpy()) # <-- TF2, TF1.x-eager
with tf.Session() as sess: # TF1.x
print(sess.run(res))
# [[2. 4.]
# [6. 8.]]
%timeit
でtf.__version__=='2.0.0-alpha0'
:
494 µs ± 4.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加