How to use Keras LSTM batch_input_size properly

peter bence

I'm using Keras framework to build a stacked LSTM model as follows:

model.add(layers.LSTM(units=32,
                      batch_input_shape=(1, 100, 64),
                      stateful=True,
                      return_sequences=True))
model.add(layers.LSTM(units=32, stateful=True, return_sequences=True))
model.add(layers.LSTM(units=32, stateful=True, return_sequences=False))
model.add(layers.Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(train_dataset,
          train_labels,
          epochs=1,
          validation_split = 0.2,
          verbose=1,
          batch_size=1,
          shuffle=False)

Knowing that the default batch_size for mode.fit, model.predict and model.evaluate is 32, the model forces me to change this default batch_size to the samebatch_size value used in batch_input_shape (batch_size, time_steps, input_dims).

My questions are:

  1. What is the difference between passing the batch_size into batch_input_shape or into the model.fit?
  2. Could I train with batch_size, lets say 10, and evaluate on a single batch (rather than 10 batches) if I passes the batch_size into the structure of the LSTM layer through batch_input_shape?
Vlad

Sequential()モデルを作成すると、任意のバッチサイズをサポートするように定義されます。特に、TensorFlow 1.*入力にはNone、最初の次元として次のようなプレースホルダーがあります。

import tensorflow as tf

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(units=2, input_shape=(2, )))
print(model.inputs[0].get_shape().as_list()) # [None, 2] <-- supports any batch size
print(model.inputs[0].op.type == 'Placeholder') # True

を使用するtf.keras.InputLayer()場合は、次のように固定バッチサイズを定義できます。

import tensorflow as tf

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer((2,), batch_size=50)) # <-- same as using batch_input_shape
model.add(tf.keras.layers.Dense(units=2, input_shape=(2, )))
print(model.inputs[0].get_shape().as_list()) # [50, 2] <-- supports only batch_size==50
print(model.inputs[0].op.type == 'Placeholder') # True

model.fit()メソッドのバッチサイズは、データをバッチに分割するために使用されます。たとえばInputLayer()、固定バッチサイズを使用および定義し、バッチサイズの異なる値をmodel.fit()メソッドに提供すると、次のようになりますValueError

import tensorflow as tf
import numpy as np

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer((2,), batch_size=2)) # <--batch_size==2
model.add(tf.keras.layers.Dense(units=2, input_shape=(2, )))
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='categorical_crossentropy')
x_train = np.random.normal(size=(10, 2))
y_train = np.array([[0, 1] for _ in range(10)])

model.fit(x_train, y_train, batch_size=3) # <--batch_size==3 

これにより、 ValueError: Thebatch_sizeが発生します。argument value 3 is incompatible with the specified batch size of your Input Layer: 2

要約すると、バッチサイズを定義すると、Noneトレーニングまたは評価のために任意の数のサンプルを渡すことができます。バッチに分割せずに一度にすべてのサンプルを渡すこともできます(データが大きすぎる場合は取得しますOutOfMemoryError)。固定バッチサイズを定義する場合は、トレーニングと評価に同じ固定バッチサイズを使用する必要があります。

この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。

侵害の場合は、連絡してください[email protected]

編集
0

コメントを追加

0

関連記事

分類Dev

How to use TF IDF vectorizer with LSTM in Keras Python

分類Dev

Keras LSTM training. How to shape my input data?

分類Dev

How to properly use IReadOnlyDictionary?

分類Dev

How to use Proxy properly?

分類Dev

How to use Keras TimeseriesGenerator

分類Dev

How do I fit the model of two concatenate LSTM in keras?

分類Dev

LSTM, Keras : How many layers should the inference model have?

分類Dev

The mathematical formulation of LSTM in Keras?

分類Dev

How to use keras embedding layer with 3D tensor input?

分類Dev

How to use fgets properly in a structure?

分類Dev

How to use the 'using' command properly?

分類Dev

How to use future functions properly

分類Dev

How to properly use async await

分類Dev

Train a model using lstm and keras

分類Dev

How to properly use Code Contracts in .NET Core

分類Dev

How to properly use Code Contracts in .NET Core

分類Dev

How to properly use Code Contracts in .NET Core

分類Dev

How to properly use pattern matching in java

分類Dev

How to properly use formGroupName in Angular forms

分類Dev

how to properly use jq to sort json output

分類Dev

How use C Dll in C# properly

分類Dev

How to properly use postUpdate, postRemove, postPersist in Doctrine?

分類Dev

How to use multiple (re-)inheritance properly

分類Dev

How use orWhere in laravel elequent properly?

分類Dev

Keras LSTMについて

分類Dev

Kerasの1対多のLSTM

分類Dev

Train Keras LSTM model with a variable number of features

分類Dev

Keras LSTM Larger Feature Overwhelm Smaller Ones?

分類Dev

python - Implementing an LSTM network with Keras and TensorFlow