Tensorfow 2.0.0-alpha0でテキスト生成モデル(RNN)を実行しています。モデルをフィッティングすると損失メトリックが得られますが、精度を挿入すると次のエラーが発生します。
InvalidArgumentError:互換性のない形状:[64]と[64,200]
[[{{nodemetrics_4 / accuracy / Equal}}]] [Op:__ inference_keras_scratch_graph_6491]
単一のバッチで精度を手動で定義しようとしました(事前トレーニング):
def loss(labels, logits):
return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
def accuracy(labels, logits):
return tf.keras.metrics.sparse_categorical_accuracy(labels,l ogits)
example_batch_loss = loss(target_example_batch, example_batch_predictions)
example_batch_acc = accuracy(target_example_batch, example_batch_predictions)
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Loss: ", example_batch_loss.numpy().mean())
print("Accuracy: ", example_batch_acc.numpy().mean())
出力は次のとおりです。
予測形状:(64、200、34)#(batch_size、sequence_length、vocab_size)損失:3.5263805精度:0.01265625
それから私は続いた:
optimizer = tf.keras.optimizers.RMSprop(lr=lr)
model.compile(optimizer=optimizer, loss=loss, metrics =['accuracy'])
history = model.fit(dataset, epochs=epochs, callbacks[checkpoint_callback])
上記のエラーが発生しました(損失は正常に機能します)。コンパイル内で「accuracy = accure」を試してみると、次のようになります。
raise ValueError( 'セッションキーワード引数は、熱心な実行中はサポートされていません。渡されました:%s'%(kwargs、))
何か考え/提案はありますか?
accuracy
はの標準的な引数ではありませんModel.fit
-それは受け入れられ、グラフモードで**kwargs
渡されsession.run
ます。試してみてくださいmetrics=[accuracy]
。
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加