在TensorFlow中批量访问单个梯度的最佳方法是什么?

它的马沃洛

我目前正在分析在使用Tensorflow 2.x训练CNN的过程中渐变如何发展。我想要做的是将批次中的每个渐变与整个批次中的渐变进行比较。目前,我在每个训练步骤中都使用了以下简单代码段:

[...]
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
[...]

# One training step
# x_train is a batch of input data, y_train the corresponding labels
def train_step(model, optimizer, x_train, y_train):

    # Process batch
    with tf.GradientTape() as tape:
        batch_predictions = model(x_train, training=True)
        batch_loss = loss_object(y_train, batch_predictions)
    batch_grads = tape.gradient(batch_loss, model.trainable_variables)
    # Do something with gradient of whole batch
    # ...

    # Process each data point in the current batch
    for index in range(len(x_train)):
        with tf.GradientTape() as single_tape:
            single_prediction = model(x_train[index:index+1], training=True)
            single_loss = loss_object(y_train[index:index+1], single_prediction)
        single_grad = single_tape.gradient(single_loss, model.trainable_variables)
        # Do something with gradient of single data input
        # ...

    # Use batch gradient to update network weights
    optimizer.apply_gradients(zip(batch_grads, model.trainable_variables))

    train_loss(batch_loss)
    train_accuracy(y_train, batch_predictions)

我的主要问题是,单手计算每个梯度时,计算时间会激增,尽管在计算批次的梯度时,Tensorflow应该已经进行了这些计算。原因是无论是否给出单个或多个数据点GradientTapecompute_gradients总是返回单个梯度。因此,必须对每个数据点进行此计算。

我知道我可以通过使用为每个数据点计算的所有单个梯度来计算批次的梯度以更新网络,但这在节省计算时间方面仅起很小的作用。

有没有更有效的方法来计算单个梯度?

Jdehesa

您可以使用jacobian梯度带方法来获取雅可比矩阵,该矩阵将为您提供每个单个损耗值的梯度:

import tensorflow as tf

# Make a random linear problem
tf.random.set_seed(0)
# Random input batch of ten four-vector examples
x = tf.random.uniform((10, 4))
# Random weights
w = tf.random.uniform((4, 2))
# Random batch label
y = tf.random.uniform((10, 2))
with tf.GradientTape() as tape:
    tape.watch(w)
    # Prediction
    p = x @ w
    # Loss
    loss = tf.losses.mean_squared_error(y, p)
# Compute Jacobian
j = tape.jacobian(loss, w)
# The Jacobian gives you the gradient for each loss value
print(j.shape)
# (10, 4, 2)
# Gradient of the loss wrt the weights for the first example
tf.print(j[0])
# [[0.145728424 0.0756840706]
#  [0.103099883 0.0535449386]
#  [0.267220169 0.138780832]
#  [0.280130595 0.145485848]]

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

在 SQL Server 中批量插入数据的最佳方法是什么?

来自分类Dev

在MySQL DB中存储单个值的最佳方法是什么?

来自分类Dev

在 Sass 中重复单个声明的最佳方法是什么?

来自分类Dev

在Java类中访问实例变量的最佳方法是什么?

来自分类Dev

在JPA中批量删除行同时将删除级联到子记录的最佳方法是什么

来自分类Dev

在angularjs中实现单个三态复选框的最佳方法是什么?

来自分类Dev

在bash中评估代表单个管道命令的两个变量的最佳方法是什么?

来自分类Dev

根据Python中的模式将单个列表分成多个列表的最佳方法是什么?

来自分类Dev

在C ++中,动态分配单个类的最佳方法是什么?

来自分类Dev

使用单个 http 请求在 varnish 4 中填充/刷新 url 的最佳方法是什么?

来自分类Dev

从List <>获取单个随机元素的最佳方法是什么?

来自分类Dev

在ASP MVC身份中,基于条件限制访问的最佳方法是什么?

来自分类Dev

在django中,以OnetoMany关系访问数据以输出到模板的最佳方法是什么?

来自分类Dev

访问python数据框中的前几行的最佳方法是什么?

来自分类Dev

在Yii框架中,登录用户访问的数据库URL的最佳方法是什么?

来自分类Dev

使控制器中的DbContext类无法访问的最佳方法是什么?

来自分类Dev

在ASP MVC身份中,基于条件限制访问的最佳方法是什么?

来自分类Dev

从保存为哈希值的二维数组中访问元素的最佳方法是什么?

来自分类Dev

在Java中IPC的最佳方法是什么?

来自分类Dev

在日志中搜索的最佳方法是什么?

来自分类Dev

将单个浮点值更新到GPU以在CUDA内核中访问它的最快方法是什么?

来自分类Dev

梯度下降的替代方法是什么?

来自分类Dev

在春季访问AOP代理的最佳方法是什么?

来自分类Dev

访问Windows安装过程的最佳方法是什么?

来自分类Dev

限制访问成人内容的最佳方法是什么?

来自分类Dev

限制对继承属性的访问的最佳方法是什么?

来自分类Dev

管理每个用户的访问日志的最佳方法是什么?

来自分类Dev

在Flutter中运行tensorflow-lite的最佳选择是什么?

来自分类Dev

将超长html代码段注入到单个页面应用程序中的最佳方法是什么?

Related 相关文章

  1. 1

    在 SQL Server 中批量插入数据的最佳方法是什么?

  2. 2

    在MySQL DB中存储单个值的最佳方法是什么?

  3. 3

    在 Sass 中重复单个声明的最佳方法是什么?

  4. 4

    在Java类中访问实例变量的最佳方法是什么?

  5. 5

    在JPA中批量删除行同时将删除级联到子记录的最佳方法是什么

  6. 6

    在angularjs中实现单个三态复选框的最佳方法是什么?

  7. 7

    在bash中评估代表单个管道命令的两个变量的最佳方法是什么?

  8. 8

    根据Python中的模式将单个列表分成多个列表的最佳方法是什么?

  9. 9

    在C ++中,动态分配单个类的最佳方法是什么?

  10. 10

    使用单个 http 请求在 varnish 4 中填充/刷新 url 的最佳方法是什么?

  11. 11

    从List <>获取单个随机元素的最佳方法是什么?

  12. 12

    在ASP MVC身份中,基于条件限制访问的最佳方法是什么?

  13. 13

    在django中,以OnetoMany关系访问数据以输出到模板的最佳方法是什么?

  14. 14

    访问python数据框中的前几行的最佳方法是什么?

  15. 15

    在Yii框架中,登录用户访问的数据库URL的最佳方法是什么?

  16. 16

    使控制器中的DbContext类无法访问的最佳方法是什么?

  17. 17

    在ASP MVC身份中,基于条件限制访问的最佳方法是什么?

  18. 18

    从保存为哈希值的二维数组中访问元素的最佳方法是什么?

  19. 19

    在Java中IPC的最佳方法是什么?

  20. 20

    在日志中搜索的最佳方法是什么?

  21. 21

    将单个浮点值更新到GPU以在CUDA内核中访问它的最快方法是什么?

  22. 22

    梯度下降的替代方法是什么?

  23. 23

    在春季访问AOP代理的最佳方法是什么?

  24. 24

    访问Windows安装过程的最佳方法是什么?

  25. 25

    限制访问成人内容的最佳方法是什么?

  26. 26

    限制对继承属性的访问的最佳方法是什么?

  27. 27

    管理每个用户的访问日志的最佳方法是什么?

  28. 28

    在Flutter中运行tensorflow-lite的最佳选择是什么?

  29. 29

    将超长html代码段注入到单个页面应用程序中的最佳方法是什么?

热门标签

归档