Kerasジェネレーターをマイモデルの__call__メソッドに渡す

サージ・イオネスク

私は独自のkerasモデルを作成し、model.fitへの入力としてKerasジェネレーターを渡そうとしています。問題は、MyModelの呼び出しメソッドを使用しているときに、ジェネレーターを処理する方法がわからないことですエンコーダーとデコーダーネットワークへの入力としてそれらを渡すために、ジェネレーターからxとyにアクセスし、ジェネレーターを魔法のように動作させ続け、各エポックにバッチをロードするにはどうすればよいですか?

さて、tf.keras.Modelを継承するこのMyModelクラス

class MyModel(tf.keras.Model):

def __init__(self):
    super(MyModel, self).__init__()
    self.enc = Encoder()
    self.dec1 = Decoder1()
    self.dec2 = Decoder2()

def __call__(self, data_generator, **kwargs):

    ################################################ 
    ? how do I acces x and y in order to pass them to the encoder and decoder ? 
    and also keep the generator proprieties
    ###############################################
    x_train, y_train = data_generator # ?????????
    #####################################

    dec_inputs = tf.concat((tf.zeros_like(y_train[:, :1, :]), y_train[:, :-1, :]), 1)  
                                                                                      
    dec_inputs = dec_inputs[:, :, -hp.n_mels:] 

    print("########ENC INPUTS #####")
    #print(tf.shape(x_train))
    print("######################")

    print("#########DEC INPUTS #####")
    #print(tf.shape(dec_inputs))
    print("######################")

    memory = self.enc(x_train)
    y_hat = self.dec1(dec_inputs, memory)
    #z_hat = self.dec2(y_hat)
    return y_hat

そしてこれが私のジェネレーター機能です

class DataGenerator(keras.utils.Sequence):

def __init__(self, list_IDs, ID_dictionary, labels, batch_size=8, dim1=(32, 32, 32), dim2=(32, 32, 32),
             n_channels=None, n_classes=None, shuffle=True):
    'Initialization'

    self.dim1 = dim1  # dimensiune X
    self.dim2 = dim2  # dimensiune Y
    self.batch_size = batch_size

    self.ID_dictionary = ID_dictionary
    self.labels = labels
    self.list_IDs = list_IDs

    self.n_channels = n_channels
    self.n_classes = n_classes

    self.shuffle = shuffle
    self.on_epoch_end()

def __len__(self):
    'Denotes the number of batches per epoch'
    return int(np.floor(len(self.list_IDs) / self.batch_size))

# 3
def __getitem__(self, index):
    'Generate one batch of data'
    # Generate indexes of the batch
    indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size] 

    # Find list of IDs
    list_IDs_temp = [self.list_IDs[k] for k in indexes]         

    # Generate data
    x, y = self.__data_generation(list_IDs_temp)  

    return x, y

# 1
def on_epoch_end(self):
    'Updates indexes after each epoch'
    self.indexes = np.arange(len(self.list_IDs))  
    if self.shuffle == True:
        np.random.shuffle(self.indexes)

# 2
def __data_generation(self, list_IDs_temp):
    # Initialization            
    x = np.empty((self.batch_size, self.dim1)) 
    y = np.empty((self.batch_size, *self.dim2), dtype=float)

    # Generate data
    for i, ID in enumerate(list_IDs_temp):  
        # Store sample
        x[i, ] = self.ID_dictionary[ID]       
        # Store class
        y[i] = self.labels[ID]                  
    return x, y

そしてこれは私がメインでMyModelを呼び出す方法です

 listID, dict1, dict2, text_shape, mel_shape = get_batch()
 # dict1 has the inputs ( text ) and dict2 has the labels ( the mels )
 
 training_generator = DataGenerator(listID, dict1, dict2, dim1=text_shape, dim2=mel_shape)

 model = MyModel()


 model.compile(
      optimizer=keras.optimizers.Adam(),
      metrics=["accuracy"],

      )

  #model.fit_generator(generator=training_generator, use_multiprocessing=True, workers=6)
  model.fit(training_generator, epochs=2)
Sanchit.Jain

call methodはで呼び出されますmodel.fit。これは、への1つの入力を期待するx_inputため、model.fitメソッドを使用する場合、呼び出しメソッドの入力としてジェネレーターを期待することはできません。物事がどのように機能しているかをより深く理解するには、tensorflow.org / guide / keras / custom_layers_and_modelsとtensorflow.org/tutorials/text/nmt_with_attentionをお読みください。

編集1:呼び出しメソッドで2つの変数を渡す方法

# pass list [x,y] to your call function instead of only x, we will club x and y into one variable
def __call__(self, inputs):
    x = inputs[0]
    y = inputs[1]
    # now you can use x and y coming from your generator without changing much


# update your generator to return [x,y] and y
def generator
    yield [x, y], y

# simply call model.fit like you were doing before
model.fit(generator)

この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。

侵害の場合は、連絡してください[email protected]

編集
0

コメントを追加

0

関連記事

分類Dev

ジェネレーター/モック関数をレガシーモジュールのメソッドに渡します

分類Dev

djangoモデルマネージャーのget_querysetメソッドにパラメーターを渡す方法は?

分類Dev

Angular:$ http postメソッドを使用して大規模データを.ashxファイル(ジェネリックハンドラー)に渡す方法

分類Dev

ジェネリックパラメーターをメソッドに渡す

分類Dev

__call__メソッドをモジュール/パッケージに割り当てる方法は?

分類Dev

Keras:model.fit_generatorでマルチ出力モデルにジェネレーターを使用する

分類Dev

kerasモデルに関する混乱:__ call__ vs. call vs.predictメソッド

分類Dev

データ移行におけるモデルマネージャーメソッドへのアクセス

分類Dev

モデルクラスをネストされたメソッドに渡す際のSwiftジェネリックメソッドの問題

分類Dev

インターフェイスのデフォルトメソッドでSpringマネージドBeanを使用していますか?

分類Dev

キャメル-ルートからジェネリックBeanメソッドに特定のパラメーターを渡す

分類Dev

Windowsの仮想メモリマネージャがメモリマップトファイルデータをフェッチする方法の明確化

分類Dev

__call__メソッドのパラメーターとして* argsと** kargsを使用するデコレーター関数。

分類Dev

2(マルチ)レイヤーのジェネリック型の仮想メソッドを呼び出すときに、追加の型パラメーターを削除するにはどうすればよいですか?

分類Dev

ジェネリックメソッドの具象型パラメーターを渡すときにジェネリックIEnumerableを返す

分類Dev

クラス内のジェネレーターメソッドをオーバーライドする

分類Dev

メソッドのローカル呼び出しにJavaジェネリックLambdaパラメーターを渡す

分類Dev

adminでカスタムモデルマネージャーとオーバーライドされた削除メソッドを無視する方法は?

分類Dev

共通のcreateValueメソッドを定義するパラメーターを持つケースクラスのシェイプレスのジェネリックデフォルトインスタンスを使用してビルドします

分類Dev

Djangoモデルマネージャーのメソッドをモックするにはどうすればよいですか?

分類Dev

検索を実行するために、ジェネリックセレクターFunc <T、U>を何らかのメソッドに渡す方法は?

分類Dev

オブジェクトをパラメータとしてハンドルバーテンプレートのonclickメソッドに渡します

分類Dev

ジェネリックメソッドをパラメーターとして別のメソッドに渡す

分類Dev

ハドソンジョブのシェルコマンドにパラメータを渡す方法

分類Dev

ジェネリックパラメーターを受け入れるメソッドにRawタイプのコレクションオブジェクトが渡されると、ジェネリックIterator <E>の動作が異なる

分類Dev

JavaScriptジェネレーターのthrowメソッドを理解する方法は?

分類Dev

ループバック:リモートメソッドで複数のオブジェクトタイプを渡す

分類Dev

Vue.jsのモデルからメソッドにデータを渡すことができません

分類Dev

あるメソッドから別のメソッドレールに値パラメーターを渡す

Related 関連記事

  1. 1

    ジェネレーター/モック関数をレガシーモジュールのメソッドに渡します

  2. 2

    djangoモデルマネージャーのget_querysetメソッドにパラメーターを渡す方法は?

  3. 3

    Angular:$ http postメソッドを使用して大規模データを.ashxファイル(ジェネリックハンドラー)に渡す方法

  4. 4

    ジェネリックパラメーターをメソッドに渡す

  5. 5

    __call__メソッドをモジュール/パッケージに割り当てる方法は?

  6. 6

    Keras:model.fit_generatorでマルチ出力モデルにジェネレーターを使用する

  7. 7

    kerasモデルに関する混乱:__ call__ vs. call vs.predictメソッド

  8. 8

    データ移行におけるモデルマネージャーメソッドへのアクセス

  9. 9

    モデルクラスをネストされたメソッドに渡す際のSwiftジェネリックメソッドの問題

  10. 10

    インターフェイスのデフォルトメソッドでSpringマネージドBeanを使用していますか?

  11. 11

    キャメル-ルートからジェネリックBeanメソッドに特定のパラメーターを渡す

  12. 12

    Windowsの仮想メモリマネージャがメモリマップトファイルデータをフェッチする方法の明確化

  13. 13

    __call__メソッドのパラメーターとして* argsと** kargsを使用するデコレーター関数。

  14. 14

    2(マルチ)レイヤーのジェネリック型の仮想メソッドを呼び出すときに、追加の型パラメーターを削除するにはどうすればよいですか?

  15. 15

    ジェネリックメソッドの具象型パラメーターを渡すときにジェネリックIEnumerableを返す

  16. 16

    クラス内のジェネレーターメソッドをオーバーライドする

  17. 17

    メソッドのローカル呼び出しにJavaジェネリックLambdaパラメーターを渡す

  18. 18

    adminでカスタムモデルマネージャーとオーバーライドされた削除メソッドを無視する方法は?

  19. 19

    共通のcreateValueメソッドを定義するパラメーターを持つケースクラスのシェイプレスのジェネリックデフォルトインスタンスを使用してビルドします

  20. 20

    Djangoモデルマネージャーのメソッドをモックするにはどうすればよいですか?

  21. 21

    検索を実行するために、ジェネリックセレクターFunc <T、U>を何らかのメソッドに渡す方法は?

  22. 22

    オブジェクトをパラメータとしてハンドルバーテンプレートのonclickメソッドに渡します

  23. 23

    ジェネリックメソッドをパラメーターとして別のメソッドに渡す

  24. 24

    ハドソンジョブのシェルコマンドにパラメータを渡す方法

  25. 25

    ジェネリックパラメーターを受け入れるメソッドにRawタイプのコレクションオブジェクトが渡されると、ジェネリックIterator <E>の動作が異なる

  26. 26

    JavaScriptジェネレーターのthrowメソッドを理解する方法は?

  27. 27

    ループバック:リモートメソッドで複数のオブジェクトタイプを渡す

  28. 28

    Vue.jsのモデルからメソッドにデータを渡すことができません

  29. 29

    あるメソッドから別のメソッドレールに値パラメーターを渡す

ホットタグ

アーカイブ