Tensorflow 2中的fit方法中使用Dataset和ndarray有什么区别?

Zhongzheng_11

作为TF的新手,我对在训练模型时使用BatchDataset感到有些困惑。

让我们以MNIST为例。在此分类任务中,我们可以加载数据并将x_trian,y_train的ndarray直接输入模型中。

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train,y_train, epochs=5)

培训结果是:

Epoch 1/5
2021-02-17 15:43:02.621749: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library cublas64_10.dll
   1/1875 [..............................] - ETA: 0s - loss: 2.2977 - accuracy: 0.0938WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0000s vs `on_train_batch_end` time: 0.0010s). Check your callbacks.
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3047 - accuracy: 0.9117
Epoch 2/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.1473 - accuracy: 0.9569
Epoch 3/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.1097 - accuracy: 0.9673
Epoch 4/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0905 - accuracy: 0.9724
Epoch 5/5
1875/1875 [==============================] - 2s 1ms/step - loss: 0.0759 - accuracy: 0.9764

我们还可以使用tf.data.Dataset.from_tensor_slices生成BatchDataset并将其输入以适合函数。

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(10000).batch(32)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_ds, epochs=5)

训练过程的结果如下。

Epoch 1/5
2021-02-17 15:30:34.698718: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library cublas64_10.dll
1875/1875 [==============================] - 3s 1ms/step - loss: 0.2969 - accuracy: 0.9140
Epoch 2/5
1875/1875 [==============================] - 3s 1ms/step - loss: 0.1462 - accuracy: 0.9566
Epoch 3/5
1875/1875 [==============================] - 3s 1ms/step - loss: 0.1087 - accuracy: 0.9669
Epoch 4/5
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0881 - accuracy: 0.9730
Epoch 5/5
1875/1875 [==============================] - 3s 1ms/step - loss: 0.0765 - accuracy: 0.9759

可以使用两种方法成功训练模型,但是两者之间有什么区别吗?使用数据集进行培训是否还有其他优势?如果在这种情况下这两种方法之间没有区别,那么生成用于训练的数据集的典型用法是什么?何时应使用此方法?

谢谢你。

因纳特

使用时Model.fit(x=None, y=None, ...,我们可以将训练对参数传递为纯numpy数组或keras.utils.Sequencetf.data

当我们按如下方式使用时,我们将每个训练对(xy)作为直接的numpy数组分别传递给该fit函数。

# data 
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

# fit
model.fit(x = x_train, y = y_train, ... 

# check
print(x_train.shape, y_train.shape)
print(type(x_train), type(y_train))

# (60000, 28, 28) (60000,)
# <class 'numpy.ndarray'> <class 'numpy.ndarray'>

在另一方面tf.dataSequence我们通过培训对作为元组的形状,目前仍是数据类型是ndarray根据文档

  • 一个tf.data数据集。应该返回的任何一个元组(inputstargets
  • 生成器或keras.utils.Sequence返回(inputstargets

IE

# data
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(2)

# check
next(iter(train_ds))

(<tf.Tensor: shape=(2, 28, 28), dtype=uint8, numpy= array([[[...], [[...]]], dtype=uint8)>,
 <tf.Tensor: shape=(2,), dtype=uint8, numpy=array([7, 8], dtype=uint8)>)

这就是为什么,如果xtf.datageneratorkeras.utils.Sequence实例,y不应该被指定(因为目标将从中获得x)。

# fit 
model.fit(train_ds, ...

在这三种方法中,tf.data数据管道是紧随其后的最有效方法generator当数据集足够小时,将首先选择第一种方法(xy)。但是,当数据集足够大时,您将考虑tf.datagenerator寻求有效的输入管道。因此,这些的选择完全取决于。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

使用CPU的“ Keras后端+ Tensorflow”和“来自Tensorflow的Keras”之间有什么区别(在Tensorflow 2.x中)

来自分类Dev

在 tensorflow 中使用 session 和只使用 session 有什么区别

来自分类Dev

在Tensorflow中,变量和张量之间有什么区别?

来自分类Dev

tensorflow.nn.np和numpy有什么区别?

来自分类Dev

Pytorch中的dim和Tensorflow中的axis有什么区别?

来自分类Dev

tf.sub和tensorflow中的正负运算之间有什么区别?

来自分类Dev

在Tensorflow中,sampled_softmax_loss和softmax_cross_entropy_with_logits有什么区别

来自分类Dev

PyPi中的tf-nightly和tensorflow有什么区别?

来自分类Dev

Tensorflow 中的 tf.variable_scope 和 variable_scope.variable_scope 有什么区别?

来自分类Dev

scikit-learn和tensorflow有什么区别?可以一起使用吗?

来自分类Dev

Android TensorFlow 支持和 TensorFlow Lite for Android 有什么区别?

来自分类Dev

为什么在Tensorflow 2中使用tf.GradientTape进行训练与使用fit API进行训练有不同的行为?

来自分类Dev

Tensorflow.js图层模型和Graph模型有什么区别?

来自分类Dev

TensorFlow 2.0:sparse_categorical_crossentropy和SparseCategoricalCrossentropy有什么区别?

来自分类Dev

tf-nightly-gpu和tensorflow-gpu有什么区别

来自分类Dev

“导入 keras”和“导入 tensorflow.keras”有什么区别

来自分类Dev

在Tensorflow中,类型以_ref结尾的张量与没有_ref结尾的张量之间有什么区别?

来自分类Dev

TensorFlow中的这两个命令有什么区别

来自分类Dev

在TensorFlow 2.0中使用tf.Dataset进行训练

来自分类Dev

TensorFlow`fit()`方法上的`AssertionError`

来自分类Dev

tensorflow_hub 和 tensorflow 1.10 的问题

来自分类Dev

什么是tensorflow float ref?

来自分类Dev

什么是tensorflow.matmul?

来自分类Dev

Tensorflow 和 Pycharm

来自分类Dev

在Windows的IDE中使用Tensorflow

来自分类Dev

在Tensorflow函数中使用@登录

来自分类Dev

无法在PyCharm中使用tensorflow

来自分类Dev

Tensorflow 2:嵌套TensorArray

来自分类Dev

Tensorflow中的多维聚集

Related 相关文章

  1. 1

    使用CPU的“ Keras后端+ Tensorflow”和“来自Tensorflow的Keras”之间有什么区别(在Tensorflow 2.x中)

  2. 2

    在 tensorflow 中使用 session 和只使用 session 有什么区别

  3. 3

    在Tensorflow中,变量和张量之间有什么区别?

  4. 4

    tensorflow.nn.np和numpy有什么区别?

  5. 5

    Pytorch中的dim和Tensorflow中的axis有什么区别?

  6. 6

    tf.sub和tensorflow中的正负运算之间有什么区别?

  7. 7

    在Tensorflow中,sampled_softmax_loss和softmax_cross_entropy_with_logits有什么区别

  8. 8

    PyPi中的tf-nightly和tensorflow有什么区别?

  9. 9

    Tensorflow 中的 tf.variable_scope 和 variable_scope.variable_scope 有什么区别?

  10. 10

    scikit-learn和tensorflow有什么区别?可以一起使用吗?

  11. 11

    Android TensorFlow 支持和 TensorFlow Lite for Android 有什么区别?

  12. 12

    为什么在Tensorflow 2中使用tf.GradientTape进行训练与使用fit API进行训练有不同的行为?

  13. 13

    Tensorflow.js图层模型和Graph模型有什么区别?

  14. 14

    TensorFlow 2.0:sparse_categorical_crossentropy和SparseCategoricalCrossentropy有什么区别?

  15. 15

    tf-nightly-gpu和tensorflow-gpu有什么区别

  16. 16

    “导入 keras”和“导入 tensorflow.keras”有什么区别

  17. 17

    在Tensorflow中,类型以_ref结尾的张量与没有_ref结尾的张量之间有什么区别?

  18. 18

    TensorFlow中的这两个命令有什么区别

  19. 19

    在TensorFlow 2.0中使用tf.Dataset进行训练

  20. 20

    TensorFlow`fit()`方法上的`AssertionError`

  21. 21

    tensorflow_hub 和 tensorflow 1.10 的问题

  22. 22

    什么是tensorflow float ref?

  23. 23

    什么是tensorflow.matmul?

  24. 24

    Tensorflow 和 Pycharm

  25. 25

    在Windows的IDE中使用Tensorflow

  26. 26

    在Tensorflow函数中使用@登录

  27. 27

    无法在PyCharm中使用tensorflow

  28. 28

    Tensorflow 2:嵌套TensorArray

  29. 29

    Tensorflow中的多维聚集

热门标签

归档