试图了解可计算Torch中LogSoftMax输入的梯度wrt的代码

拉尔斯

代码来自:https : //github.com/torch/nn/blob/master/lib/THNN/generic/LogSoftMax.c

我看不到这段代码如何计算到LogSoftMax模块输入的梯度wrt。我很困惑的是两个for循环在做什么。

for (t = 0; t < nframe; t++)
{
sum = 0;
gradInput_data = gradInput_data0 + dim*t;
output_data = output_data0 + dim*t;
gradOutput_data = gradOutput_data0 + dim*t;

for (d = 0; d < dim; d++)
  sum += gradOutput_data[d];

for (d = 0; d < dim; d++)
  gradInput_data[d] = gradOutput_data[d] - exp(output_data[d])*sum;
 }
}
地尔

在前进时间,我们有(其中x =输入向量,y =输出向量,f = logsoftmax,i =第i个分量):

yi = f(xi)
   = log( exp(xi) / sum_j(exp(xj)) )
   = xi - log( sum_j(exp(xj)) )

计算f的雅可比式Jf时(第i行):

dyi/dxi = 1 - exp(xi) / sum_j(exp(xj))

对于与我不同的k:

dyi/dxk = - exp(xk) / sum_j(exp(xj))

这给Jf:

1-E(x1)     -E(x2)     -E(x3)    ...
 -E(x1)    1-E(x2)     -E(x3)    ...
 -E(x1)     -E(x2)    1-E(x3)    ...
...

E(xi) = exp(xi) / sum_j(exp(xj))

如果我们将gradInput命名为梯度输入,而gradOutput命名为梯度输出,则反向传播给出(链式规则):

gradInputi = sum_j( gradOutputj . dyj/dxi )

这等效于:

gradInput = transpose(Jf) . gradOutput

最后给出第i个组件:

gradInputi = gradOutputi - E(xi) . sum_j( gradOutputj )

因此,第一个循环进行预计算sum_j( gradOutputj ),最后一个循环进行上述计算,即grad的第i个分量。输入-除了1 / sum_j(exp(xj))在Torch实现中缺少指数项之外(上面的演算可能听起来应该正确并解释了当前的实现),也应该仔细检查

更新缺少 1 / sum_j(exp(xj))术语没有问题由于jacobian是根据输出计算的,并且由于此先前计算的输出正好是log-softmax分布,因此该分布的sum-exp为1:

sum_j(exp(outputj)) = sum_j(exp( log(exp(inputj) / sum_k(exp(inputk) ))
                    = sum_j(         exp(inputj) / sum_k(exp(inputk)  )
                    = 1

因此,无需在实现中明确显示该术语,它给出了(对于x =输出):

gradInputi = gradOutputi - exp(outputi) . sum_j( gradOutputj )

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

如何在Tensorflow 2.0中计算输出WRT输入的梯度

来自分类Dev

使用Keras / Tensorflow或autograd计算验证误差wrt输入的梯度

来自分类Dev

试图了解我的代码中的问题

来自分类Dev

试图了解示例代码中的优先顺序

来自分类Dev

试图了解示例代码中的优先顺序

来自分类Dev

铰链损失函数梯度 wrt 输入预测

来自分类Dev

多个输出维度的 Keras 梯度 wrt 输入

来自分类Dev

了解Numpy中的梯度下降算法的梯度

来自分类Dev

在tensorflow.js中,如何计算模型输入的梯度?

来自分类Dev

试图了解此AffineTransform代码以在y中翻转图像

来自分类Dev

试图了解简单的大数计算

来自分类Dev

试图了解此代码块

来自分类Dev

试图了解此Python代码

来自分类Dev

试图了解此代码行为

来自分类Dev

梯度下降计算中的误差

来自分类Dev

Tensorflow 2.0中的梯度计算

来自分类Dev

梯度下降计算中的误差

来自分类Dev

Julia中的并行梯度计算

来自分类Dev

Keras,计算 LSTM 上输入的损失梯度

来自分类Dev

试图从用户C ++输入的代码中删除注释

来自分类Dev

试图了解计算机总线的图片

来自分类Dev

试图了解.net中的任务

来自分类Dev

试图了解R中的cdplot

来自分类Dev

试图了解javascript中的for循环

来自分类Dev

试图了解此代码对$ PSBoundParameters对象的作用

来自分类Dev

试图了解RSA加密代码示例

来自分类Dev

试图了解链表插入功能的代码

来自分类Dev

试图了解代码的最后一行

来自分类Dev

试图了解RSA加密代码示例