Tensorflow2のこのサンプルコード
writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")
@tf.function
def my_func(step):
with writer.as_default():
# other model code would go here
tf.summary.scalar("my_metric", 0.5, step=step)
for step in range(100):
my_func(step)
writer.flush()
しかし、それは警告を投げています。
警告:tensorflow:トリガーされたtf.functionリトレースへの最後の5回の呼び出しのうち5回。トレースはコストがかかり、トレースの数が多すぎるのは、テンソルの代わりにpythonオブジェクトを渡すことが原因である可能性があります。また、tf.functionにはexperimental_relax_shapes = Trueオプションがあり、引数の形状を緩和して、不要な再トレースを回避できます。詳細については、https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_argsおよびhttps://www.tensorflow.org/api_docs/python/tf/functionを参照してください。
これを行うためのより良い方法はありますか?
tf.function
いくつかの「特異性」があります。この記事を読むことを強くお勧めします:https://www.tensorflow.org/tutorials/customization/performance
この場合、問題は、異なる入力シグネチャで呼び出すたびに、関数が「リトレース」される(つまり、新しいグラフが作成される)ことです。テンソルの場合、入力シグネチャは形状とdtypeを参照しますが、Python番号の場合、すべての新しい値は「異なる」と解釈されます。この場合、step
毎回変化する変数を使用して関数を呼び出すため、関数も毎回リトレースされます。これは、「実際の」コード(関数内でモデルを呼び出すなど)では非常に遅くなります。
step
テンソルに変換するだけで修正できます。その場合、異なる値は新しい入力署名としてカウントされません。
for step in range(100):
step = tf.convert_to_tensor(step, dtype=tf.int64)
my_func(step)
writer.flush()
またはtf.range
、テンソルを直接取得するために使用します。
for step in tf.range(100):
step = tf.cast(step, tf.int64)
my_func(step)
writer.flush()
これは警告を生成するべきではありません(そしてはるかに速くなります)。
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加