Bert(huggingface)分類子を使用したtf.kerasモデルの保存の問題

ボルツマン

Bert(huggingface)を使用するバイナリ分類器をトレーニングしています。モデルは次のようになります。

def get_model(lr=0.00001):
    inp_bert = Input(shape=(512), dtype="int32")
    bert = TFBertModel.from_pretrained('bert-base-multilingual-cased')(inp_bert)[0]
    doc_encodings = tf.squeeze(bert[:, 0:1, :], axis=1)
    out = Dense(1, activation="sigmoid")(doc_encodings)
    model = Model(inp_bert, out)
    adam = optimizers.Adam(lr=lr)
    model.compile(optimizer=adam, loss="binary_crossentropy", metrics=["accuracy"])
    return model

分類タスクを微調整した後、モデルを保存したいと思います。

model.save("best_model.h5")

ただし、これによりNotImplementedErrorが発生します。

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-55-8c5545f0cd9b> in <module>()
----> 1 model.save("best_spam.h5")
      2 # import transformers

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
    973     """
    974     saving.save_model(self, filepath, overwrite, include_optimizer, save_format,
--> 975                       signatures, options)
    976 
    977   def save_weights(self, filepath, overwrite=True, save_format=None):

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    110           'or using `save_weights`.')
    111     hdf5_format.save_model_to_hdf5(
--> 112         model, filepath, overwrite, include_optimizer)
    113   else:
    114     saved_model_save.save(model, filepath, overwrite, include_optimizer,

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
     97 
     98   try:
---> 99     model_metadata = saving_utils.model_metadata(model, include_optimizer)
    100     for k, v in model_metadata.items():
    101       if isinstance(v, (dict, list, tuple)):

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    163   except NotImplementedError as e:
    164     if require_config:
--> 165       raise e
    166 
    167   metadata = dict(

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    160   model_config = {'class_name': model.__class__.__name__}
    161   try:
--> 162     model_config['config'] = model.get_config()
    163   except NotImplementedError as e:
    164     if require_config:

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    885     if not self._is_graph_network:
    886       raise NotImplementedError
--> 887     return copy.deepcopy(get_network_config(self))
    888 
    889   @classmethod

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_network_config(network, serialize_layer_fn)
   1940           filtered_inbound_nodes.append(node_data)
   1941 
-> 1942     layer_config = serialize_layer_fn(layer)
   1943     layer_config['name'] = layer.name
   1944     layer_config['inbound_nodes'] = filtered_inbound_nodes

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    138   if hasattr(instance, 'get_config'):
    139     return serialize_keras_class_and_config(instance.__class__.__name__,
--> 140                                             instance.get_config())
    141   if hasattr(instance, '__name__'):
    142     return instance.__name__

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    884   def get_config(self):
    885     if not self._is_graph_network:
--> 886       raise NotImplementedError
    887     return copy.deepcopy(get_network_config(self))
    888 

NotImplementedError: 

huggingfaceがTFBertModelのmodel.save_pretrained()メソッドを提供することは知っていますが、このネットワークに他のコンポーネント/機能を追加する予定なので、tf.keras.Modelでラップすることを好みます。誰かが現在のモデルを保存するための解決策を提案できますか?

Ashwin Geet D'Sa

これは確かにtensorflow2.0の問題です。

使ってください: model.save("model_name",save_format='tf')

または、テンソルフローをアップグレードまたはダウングレードしてみることもできます。

この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。

侵害の場合は、連絡してください[email protected]

編集
0

コメントを追加

0

関連記事

分類Dev

空白のネストされたモデルの保存の問題

分類Dev

TensorflowでKerasモデルを実装する際の問題

分類Dev

Keras:モデルの保存を早期に停止

分類Dev

Kerasモデルでmodel.predict()の正しい入力がある問題

分類Dev

keras ResNet50モデルを使用したバイナリ分類の出力層

分類Dev

joblib / pickleを使用して保存されたMLモデルの読み込みの問題

分類Dev

私自身の訓練されたKerasモデルでの私の画像の予測における問題

分類Dev

既存のTensorflowモデルを使用した予測の問題

分類Dev

既存のTensorflowモデルを使用した予測の問題

分類Dev

Django RestFrameworkを使用したスルーモデルの問題

分類Dev

Tensorflow2.2.0とKerasがモデルを保存/モデルをロードする問題

分類Dev

マルチラベル問題のケラスモデルを使用したscikitlearnチェーン分類器の適合方法のエラー

分類Dev

jsonschema2pojoを使用したモデルの問題

分類Dev

KerasのEarlyStoppingは最良のモデルを保存しますか?

分類Dev

Androidでの分類にKerasモデルを使用する

分類Dev

再トレーニングされたテンソルフローモデルの保存に関する問題

分類Dev

Firebaseデータベースの子作成の問題(モデルを使用)

分類Dev

TF-IDFスコアを使用したテキスト分類のKNN

分類Dev

Tensorflow 2.0 / Kerasの他のデータ機能を使用したテキスト分類子の作成

分類Dev

カスタムレイヤーを使用したKerasモデルの保存

分類Dev

Kerasで最高のウェイトとモデルを保存する

分類Dev

モデルをkerasに保存する際のNonetypeエラー

分類Dev

特定のエポックでKerasモデルを保存する

分類Dev

メモリの問題:大量のデータをマップに保存する

分類Dev

R-Caretの回帰問題における「分類のための間違ったモデルタイプ」

分類Dev

変換されたGLTF2.0モデルを使用したTHREEJSの問題

分類Dev

分類タスクにhuggingfaceのpytorch-transformersGPT-2を使用する

分類Dev

複数の出力を持つモデルに対してtrain_on_batchを試行するときのKerasのsample_weightの問題

分類Dev

Fluentを使用して2つのモデルを保存し、データベースからレコードを取得する際の問題

Related 関連記事

  1. 1

    空白のネストされたモデルの保存の問題

  2. 2

    TensorflowでKerasモデルを実装する際の問題

  3. 3

    Keras:モデルの保存を早期に停止

  4. 4

    Kerasモデルでmodel.predict()の正しい入力がある問題

  5. 5

    keras ResNet50モデルを使用したバイナリ分類の出力層

  6. 6

    joblib / pickleを使用して保存されたMLモデルの読み込みの問題

  7. 7

    私自身の訓練されたKerasモデルでの私の画像の予測における問題

  8. 8

    既存のTensorflowモデルを使用した予測の問題

  9. 9

    既存のTensorflowモデルを使用した予測の問題

  10. 10

    Django RestFrameworkを使用したスルーモデルの問題

  11. 11

    Tensorflow2.2.0とKerasがモデルを保存/モデルをロードする問題

  12. 12

    マルチラベル問題のケラスモデルを使用したscikitlearnチェーン分類器の適合方法のエラー

  13. 13

    jsonschema2pojoを使用したモデルの問題

  14. 14

    KerasのEarlyStoppingは最良のモデルを保存しますか?

  15. 15

    Androidでの分類にKerasモデルを使用する

  16. 16

    再トレーニングされたテンソルフローモデルの保存に関する問題

  17. 17

    Firebaseデータベースの子作成の問題(モデルを使用)

  18. 18

    TF-IDFスコアを使用したテキスト分類のKNN

  19. 19

    Tensorflow 2.0 / Kerasの他のデータ機能を使用したテキスト分類子の作成

  20. 20

    カスタムレイヤーを使用したKerasモデルの保存

  21. 21

    Kerasで最高のウェイトとモデルを保存する

  22. 22

    モデルをkerasに保存する際のNonetypeエラー

  23. 23

    特定のエポックでKerasモデルを保存する

  24. 24

    メモリの問題:大量のデータをマップに保存する

  25. 25

    R-Caretの回帰問題における「分類のための間違ったモデルタイプ」

  26. 26

    変換されたGLTF2.0モデルを使用したTHREEJSの問題

  27. 27

    分類タスクにhuggingfaceのpytorch-transformersGPT-2を使用する

  28. 28

    複数の出力を持つモデルに対してtrain_on_batchを試行するときのKerasのsample_weightの問題

  29. 29

    Fluentを使用して2つのモデルを保存し、データベースからレコードを取得する際の問題

ホットタグ

アーカイブ