作为TF的新手,我对在训练模型时使用BatchDataset感到有些困惑。
让我们以MNIST为例。在此分类任务中,我们可以加载数据并将x_trian,y_train的ndarray直接输入模型中。
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train,y_train, epochs=5)
培训结果是:
Epoch 1/5
2021-02-17 15:43:02.621749: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library cublas64_10.dll
1/1875 [..............................] - ETA: 0s - loss: 2.2977 - accuracy: 0.0938WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0000s vs `on_train_batch_end` time: 0.0010s). Check your callbacks.
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3047 - accuracy: 0.9117
Epoch 2/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.1473 - accuracy: 0.9569
Epoch 3/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.1097 - accuracy: 0.9673
Epoch 4/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0905 - accuracy: 0.9724
Epoch 5/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0759 - accuracy: 0.9764
我们还可以使用tf.data.Dataset.from_tensor_slices生成BatchDataset并将其输入以适合函数。
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
train_ds = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_ds, epochs=5)
训练过程的结果如下。
Epoch 1/5
2021-02-17 15:30:34.698718: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library cublas64_10.dll
1875/1875 [==============================] - 3s 1ms/step - loss: 0.2969 - accuracy: 0.9140
Epoch 2/5
1875/1875 [==============================] - 3s 1ms/step - loss: 0.1462 - accuracy: 0.9566
Epoch 3/5
1875/1875 [==============================] - 3s 1ms/step - loss: 0.1087 - accuracy: 0.9669
Epoch 4/5
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0881 - accuracy: 0.9730
Epoch 5/5
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0765 - accuracy: 0.9759
可以使用两种方法成功训练模型,但是两者之间有什么区别吗?使用数据集进行培训是否还有其他优势?如果在这种情况下这两种方法之间没有区别,那么生成用于训练的数据集的典型用法是什么?何时应使用此方法?
谢谢你。
使用时Model.fit(x=None, y=None, ...
,我们可以将训练对参数传递为纯numpy
数组或keras.utils.Sequence
或tf.data
。
当我们按如下方式使用时,我们将每个训练对(x
和y
)作为直接的numpy数组分别传递给该fit
函数。
# data
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
# fit
model.fit(x = x_train, y = y_train, ...
# check
print(x_train.shape, y_train.shape)
print(type(x_train), type(y_train))
# (60000, 28, 28) (60000,)
# <class 'numpy.ndarray'> <class 'numpy.ndarray'>
在另一方面tf.data
和Sequence
我们通过培训对作为元组的形状,目前仍是数据类型是ndarray
。根据文档,
tf.data
数据集。应该返回的任何一个元组(inputs
,targets
)keras.utils.Sequence
返回(inputs
,targets
)IE
# data
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(2)
# check
next(iter(train_ds))
(<tf.Tensor: shape=(2, 28, 28), dtype=uint8, numpy= array([[[...], [[...]]], dtype=uint8)>,
<tf.Tensor: shape=(2,), dtype=uint8, numpy=array([7, 8], dtype=uint8)>)
这就是为什么,如果x
是tf.data
,generator
或keras.utils.Sequence
实例,y
不应该被指定(因为目标将从中获得x
)。
# fit
model.fit(train_ds, ...
在这三种方法中,tf.data
数据管道是紧随其后的最有效方法generator
。当数据集足够小时,将首先选择第一种方法(x
和y
)。但是,当数据集足够大时,您将考虑tf.data
或generator
寻求有效的输入管道。因此,这些的选择完全取决于。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句