tf.tensorから1つの操作で複数の列を抽出します

レゼブロン

最近のTensorFlow(1.13または2.0)では、1回のパスでテンソルから非連続スライスを抽出する方法はありますか?どうやるか?たとえば、次のテンソルを使用します。

1 2 3 4
5 6 7 8 

1つの操作で列1と3を抽出して、次のようにします。

2 4
6 8

しかし、スライスを使った1回の操作ではできないようです。これを行うための正しい/最速/最もエレガントな方法は何ですか?

ウラド

1.tf.gather(tensor, columns, axis=1)TF1.xTF2)の使用

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3]

print(tf.gather(tensor, columns, axis=1).numpy())
%timeit -n 10000 tf.gather(tensor, columns, axis=1)
# [[2. 4.]
#  [6. 8.]]
82.6 µs ± 5.76 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

2.インデックス付き(TF1.xTF2):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract

transposed = tf.transpose(tensor)
sliced = [transposed[c] for c in columns]
stacked = tf.transpose(tf.stack(sliced, axis=0))
# print(stacked.numpy()) # <-- TF2, TF1.x-eager

with tf.Session() as sess:  # <-- TF1.x
    print(sess.run(stacked))
# [[2. 4.]
#  [6. 8.]]

それを関数にラップして実行%timeittf.__version__=='2.0.0-alpha0'ます:

154 µs ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

それを飾ること@tf.functionは2倍以上速いです:

import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract
@tf.function
def extract_columns(tensor=tensor, columns=columns):
    transposed = tf.transpose(tensor)
    sliced = [transposed[c] for c in columns]
    stacked = tf.transpose(tf.stack(sliced, axis=0))
    return stacked

%timeit -n 10000 extract_columns()
66.8 µs ± 2.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

3.熱心な実行のためのワンライナーTF2TF1.x-eager):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract

res = tf.transpose(tf.stack([t for i, t in enumerate(tf.transpose(tensor))
                             if i in columns], 0))
print(res.numpy())
# [[2. 4.]
#  [6. 8.]]

%timeittf.__version__=='2.0.0-alpha0'

242 µs ± 2.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

4.を使用tf.one_hot()して行/列を指定し、次にtf.boolean_mask()これらの行/列を抽出します(TF1.xTF2):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract

mask = tf.one_hot(columns, tensor.get_shape().as_list()[-1])
mask = tf.reduce_sum(mask, axis=0)
res = tf.transpose(tf.boolean_mask(tf.transpose(tensor), mask))
# print(res.numpy()) # <-- TF2, TF1.x-eager

with tf.Session() as sess: # TF1.x
    print(sess.run(res))
# [[2. 4.]
#  [6. 8.]]

%timeittf.__version__=='2.0.0-alpha0'

494 µs ± 4.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。

侵害の場合は、連絡してください[email protected]

編集
0

コメントを追加

0

関連記事

分類Dev

TF 2.0 @ tf.functionの例

分類Dev

tf.constantとtf.convert_to_tensorの違いは何ですか

分類Dev

tf__norm()は1つの位置引数を取りますが、2つが与えられました

分類Dev

TF 2.0のtf.GradientTapeはtf.gradientsと同等ですか?

分類Dev

TensorFlow:tf.Datasetをtf.Tensorに変換します

分類Dev

PyTorchでのtf.concat操作

分類Dev

TFを使用してKerasで複数のTPUを使用する

分類Dev

複数のtf.data.Datasetを混在させていますか?

分類Dev

tf.scanはTensorの形状を覆い隠します

分類Dev

1つの大きなXMLを複数のtfリソースに分割する

分類Dev

tf.map_fn(...)を複数の入力/出力に適用できますか?

分類Dev

dtypeが文字列であるtf.tensorから文字列値を取得する方法

分類Dev

'Tensor' object has no attribute 'numpy' in tf.function in TF 2.0

分類Dev

tf TensorからNumpy配列を取得する方法は?

分類Dev

TF2.0の `tf.Module`で変数を初期化する方法

分類Dev

tf.data.Dataset.zip((images、labels))から2つのtf.datasetを取得する方法

分類Dev

`tf.keras.Model.compile`はTF2.0で何をしますか?

分類Dev

`tf.keras.losses`と` tf.losses`、または `tf.keras.optimizers`と` tf.optimizers`の違いは何ですか?

分類Dev

@ tf.functionで装飾された関数内のtf.data.Dataset上のforループを使用して、tf.Variableを操作して返す方法は?

分類Dev

tf.Tensorが可変かどうかを確認します

分類Dev

tensorflowのtf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

分類Dev

tf.Tensorのゼロ以外の値のみを評価します

分類Dev

なぜtf.convert_to_tensorが必要なのですか?

分類Dev

tf.train.Serverを使用して複数のtf.Session()を並行して実行する必要があるのはなぜですか?

分類Dev

tf.image_summaryで複数の画像を表示する方法

分類Dev

tf.zeros()はtf.get_variable()を返しますか?

分類Dev

tf.gradientsはtf.condを通過しますか?

分類Dev

TF1からTF2で記述されたmatmulベースのnnを実装する方法

分類Dev

なぜ `が見つからないのですか。-名前 '* .jar' | xargs jar tf`は機能しますか?

Related 関連記事

  1. 1

    TF 2.0 @ tf.functionの例

  2. 2

    tf.constantとtf.convert_to_tensorの違いは何ですか

  3. 3

    tf__norm()は1つの位置引数を取りますが、2つが与えられました

  4. 4

    TF 2.0のtf.GradientTapeはtf.gradientsと同等ですか?

  5. 5

    TensorFlow:tf.Datasetをtf.Tensorに変換します

  6. 6

    PyTorchでのtf.concat操作

  7. 7

    TFを使用してKerasで複数のTPUを使用する

  8. 8

    複数のtf.data.Datasetを混在させていますか?

  9. 9

    tf.scanはTensorの形状を覆い隠します

  10. 10

    1つの大きなXMLを複数のtfリソースに分割する

  11. 11

    tf.map_fn(...)を複数の入力/出力に適用できますか?

  12. 12

    dtypeが文字列であるtf.tensorから文字列値を取得する方法

  13. 13

    'Tensor' object has no attribute 'numpy' in tf.function in TF 2.0

  14. 14

    tf TensorからNumpy配列を取得する方法は?

  15. 15

    TF2.0の `tf.Module`で変数を初期化する方法

  16. 16

    tf.data.Dataset.zip((images、labels))から2つのtf.datasetを取得する方法

  17. 17

    `tf.keras.Model.compile`はTF2.0で何をしますか?

  18. 18

    `tf.keras.losses`と` tf.losses`、または `tf.keras.optimizers`と` tf.optimizers`の違いは何ですか?

  19. 19

    @ tf.functionで装飾された関数内のtf.data.Dataset上のforループを使用して、tf.Variableを操作して返す方法は?

  20. 20

    tf.Tensorが可変かどうかを確認します

  21. 21

    tensorflowのtf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

  22. 22

    tf.Tensorのゼロ以外の値のみを評価します

  23. 23

    なぜtf.convert_to_tensorが必要なのですか?

  24. 24

    tf.train.Serverを使用して複数のtf.Session()を並行して実行する必要があるのはなぜですか?

  25. 25

    tf.image_summaryで複数の画像を表示する方法

  26. 26

    tf.zeros()はtf.get_variable()を返しますか?

  27. 27

    tf.gradientsはtf.condを通過しますか?

  28. 28

    TF1からTF2で記述されたmatmulベースのnnを実装する方法

  29. 29

    なぜ `が見つからないのですか。-名前 '* .jar' | xargs jar tf`は機能しますか?

ホットタグ

アーカイブ