TF2-GradientTape与Model.fit()-为什么GradientTape不起作用?

即使你反向传播

晚上好,

我想用tf2和Gradient Tape函数实现一个简单的回归问题的玩具示例。使用Model.fit,它可以正常学习,但是使用GradientTape可以做到,但与model.fit()相比,损失不会增加。这是我的示例代码和结果。我找不到问题。

model_opt = tf.keras.optimizers.Adam() 
loss_fn = tf.keras.losses.MeanSquaredError()
with tf.GradientTape() as tape:
    y = model(X, training=True)
    loss_value = loss_fn(y_true, y)
grads = tape.gradient(loss_value, model.trainable_variables)
model_opt.apply_gradients(zip(grads, model.trainable_variables))

#Results:
42.47433806265809
42.63973672226078
36.687397360178586
38.744844324717526
36.59080452300609
...

这是带有model.fit()的常规情况

model.compile(optimizer=tf.keras.optimizers.Adam(),loss=tf.keras.losses.MSE,metrics="mse")
...
model.fit(X,y_true,verbose=0)
#Results
[40.97759069299212]
[28.04145720307729]
[17.643483147375473]
[7.575242056454791]
[5.83682193867299]

精度应该大致相同,但看起来根本无法学习。输入X是张量,y_true也是。

编辑测试

import pathlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

dataset_path = keras.utils.get_file("auto-mpg.data", "http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data")

column_names = ['MPG','Cylinders','Displacement','Horsepower','Weight',
                'Acceleration', 'Model Year', 'Origin']
dataset = pd.read_csv(dataset_path, names=column_names,
                      na_values = "?", comment='\t',
                      sep=" ", skipinitialspace=True)

dataset = dataset.dropna()
dataset['Origin'] = dataset['Origin'].map({1: 'USA', 2: 'Europe', 3: 'Japan'})
dataset = pd.get_dummies(dataset, prefix='', prefix_sep='')

train_dataset = dataset.sample(frac=0.8,random_state=0)
test_dataset = dataset.drop(train_dataset.index)

train_stats = train_dataset.describe()
train_stats.pop("MPG")
train_stats = train_stats.transpose()

train_labels = train_dataset.pop('MPG')
test_labels = test_dataset.pop('MPG')

def norm(x):
  return (x - train_stats['mean']) / train_stats['std']

normed_train_data = norm(train_dataset)
normed_test_data = norm(test_dataset)

def build_model_fit():
  model = keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=[len(train_dataset.keys())]),
    layers.Dense(64, activation='relu'),
    layers.Dense(1)])
  optimizer = tf.keras.optimizers.RMSprop(0.001)
  model.compile(loss='mse',optimizer=optimizer)
  return model

def build_model_tape():
  model = keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=[len(train_dataset.keys())]),
    layers.Dense(64, activation='relu'),
    layers.Dense(1)])
  opt = tf.keras.optimizers.RMSprop(0.001)
  return model, opt

model_f = build_model_fit()
model_g, opt_g = build_model_tape()

EPOCHS = 20

#Model.fit() - Test
history = model_f.fit(normed_train_data, train_labels, epochs=EPOCHS, verbose=2)

X = tf.convert_to_tensor(normed_train_data.to_numpy())
y_true = tf.convert_to_tensor(train_labels.to_numpy())

#GradientTape - Test
loss_fn = tf.keras.losses.MeanSquaredError()
for i in range(0,EPOCHS):
    with tf.GradientTape() as tape:
        y = model_g(X, training=True)
        loss_value = loss_fn(y_true, y)
    grads = tape.gradient(loss_value, model_g.trainable_variables)
    opt_g.apply_gradients(zip(grads, model_g.trainable_variables))
    print(loss_value)
雅各布

OP在损失值中看到的差异是由于在model.fittf.GradientTape训练循环中使用了不同的批次大小如果未指定batch_size关键字参数to model.fit,则将使用32的批处理大小。tf.GradientTape训练循环中,批次大小等于训练集中的样本数量(即314)。

要解决此问题,请在训练循环中实施批处理。一种方法是使用tf.dataAPI,如下所示。

loss_fn = tf.keras.losses.MeanSquaredError()
for i in range(0,EPOCHS):
    epoch_losses = []
    for x_batch, y_batch in tf.data.Dataset.from_tensor_slices((X, y_true)).batch(32):
        with tf.GradientTape() as tape:
            y = model_g(x_batch, training=True)
            loss_value = loss_fn(y_batch, y)
            epoch_losses.append(loss_value.numpy())
        grads = tape.gradient(loss_value, model_g.trainable_variables)
        opt_g.apply_gradients(zip(grads, model_g.trainable_variables))
    print(np.mean(loss_value))

还要注意,model.fit每次迭代都会对数据进行混洗,而自定义训练循环则不会(需要由开发人员实现)。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

为什么在Tensorflow 2中使用tf.GradientTape进行训练与使用fit API进行训练有不同的行为?

来自分类Dev

tf.GradientTape()的__exit__函数的参数是什么?

来自分类Dev

为什么tf2无法将tf_function模型另存为.pb文件?

来自分类Dev

$ {#$ 2}为什么不起作用?

来自分类Dev

tf.GradientTape不返回渐变

来自分类Dev

为什么TF2的Dataset.map不能像正常的for循环遍历一样?

来自分类Dev

model.predict()和model.fit()有什么作用?

来自分类Dev

为什么Https在EC2上不起作用?

来自分类Dev

为什么我的Prolog谓词invert / 2不起作用?

来自分类Dev

为什么FragmentStateAdapter在ViewPager2中不起作用?

来自分类Dev

为什么RAND()对于doctrine2不起作用

来自分类Dev

为什么phpUnit测试不起作用?ZF2

来自分类Dev

为什么> / dev / null 2>&1不起作用?

来自分类Dev

为什么“ 2>&1”在此管道中不起作用?

来自分类Dev

为什么我的 if 条件不起作用(if (n>2))?

来自分类Dev

为什么在Tensorflow中将batch_size乘以GradientTape结果?

来自分类Dev

tf.GradientTape无法在“ with”块之外观看

来自分类Dev

为什么model.fit_generator()出现归因错误?

来自分类Dev

tf.keras.Model.fit 不训练模型

来自分类Dev

我用 'tf.keras.Sequential()' 构建的模型不起作用,为什么?

来自分类Dev

Active Model Serializer 10缓存。它似乎不起作用。为什么?

来自分类Dev

如果(model.Building.Address == null)不起作用,则无法编译-为什么?

来自分类Dev

为什么没有ng-model的Angular 1表单验证不起作用

来自分类Dev

MySQL为什么cursor.execute(sql,multi = True)不起作用,但是2 cursor.execute(sql)起作用?

来自分类Dev

MySQL为什么cursor.execute(sql,multi = True)不起作用,但是2 cursor.execute(sql)起作用?

来自分类Dev

从.tfrecord到tf.data.Dataset到tf.keras.model.fit

来自分类Dev

tf model.fit()中的batch_size与tf.data.Dataset中的batch_size

来自分类Dev

使用 model.fit() InvalidArgumentError 训练自定义 tf.keras.model

来自分类Dev

为什么Graphics2D.setStoke()对Graphics2D.drawString不起作用?

Related 相关文章

  1. 1

    为什么在Tensorflow 2中使用tf.GradientTape进行训练与使用fit API进行训练有不同的行为?

  2. 2

    tf.GradientTape()的__exit__函数的参数是什么?

  3. 3

    为什么tf2无法将tf_function模型另存为.pb文件?

  4. 4

    $ {#$ 2}为什么不起作用?

  5. 5

    tf.GradientTape不返回渐变

  6. 6

    为什么TF2的Dataset.map不能像正常的for循环遍历一样?

  7. 7

    model.predict()和model.fit()有什么作用?

  8. 8

    为什么Https在EC2上不起作用?

  9. 9

    为什么我的Prolog谓词invert / 2不起作用?

  10. 10

    为什么FragmentStateAdapter在ViewPager2中不起作用?

  11. 11

    为什么RAND()对于doctrine2不起作用

  12. 12

    为什么phpUnit测试不起作用?ZF2

  13. 13

    为什么> / dev / null 2>&1不起作用?

  14. 14

    为什么“ 2>&1”在此管道中不起作用?

  15. 15

    为什么我的 if 条件不起作用(if (n>2))?

  16. 16

    为什么在Tensorflow中将batch_size乘以GradientTape结果?

  17. 17

    tf.GradientTape无法在“ with”块之外观看

  18. 18

    为什么model.fit_generator()出现归因错误?

  19. 19

    tf.keras.Model.fit 不训练模型

  20. 20

    我用 'tf.keras.Sequential()' 构建的模型不起作用,为什么?

  21. 21

    Active Model Serializer 10缓存。它似乎不起作用。为什么?

  22. 22

    如果(model.Building.Address == null)不起作用,则无法编译-为什么?

  23. 23

    为什么没有ng-model的Angular 1表单验证不起作用

  24. 24

    MySQL为什么cursor.execute(sql,multi = True)不起作用,但是2 cursor.execute(sql)起作用?

  25. 25

    MySQL为什么cursor.execute(sql,multi = True)不起作用,但是2 cursor.execute(sql)起作用?

  26. 26

    从.tfrecord到tf.data.Dataset到tf.keras.model.fit

  27. 27

    tf model.fit()中的batch_size与tf.data.Dataset中的batch_size

  28. 28

    使用 model.fit() InvalidArgumentError 训练自定义 tf.keras.model

  29. 29

    为什么Graphics2D.setStoke()对Graphics2D.drawString不起作用?

热门标签

归档