Tensorflow可以计算出积分逼近的梯度吗?

鹅卵石

我正在尝试使用哈密顿量蒙特卡罗(HMC),来自Tensorflow概率,但我的目标分布包含一个难解的一维积分,我用梯形法则对此进行了近似。我对HMC的理解是,它可以计算目标分布的梯度以构建更有效的过渡内核。我的问题是Tensorflow是否可以根据函数的参数计算出梯度,它们是否有意义?

例如,这是目标分布的对数概率,其中“ A”是模型参数:

# integrate e^At * f[t] with respect to t between 0 and t, for all t

t = tf.linspace(0., 10., 100)
f = tf.ones(100)
delta = t[1]-t[0]
sum_term = tfm.multiply(tfm.exp(A*t), f)
integrals = 0.5*delta*tfm.cumsum(sum_term[:-1] + sum_term[1:], axis=0) 
pred = integrals
sq_diff = tfm.square(observed_data - pred)
sq_diff = tf.reduce_sum(sq_diff, axis=0)
log_lik = -0.5*tfm.log(2*PI*variance) - 0.5*sq_diff/variance
return log_lik

以A为单位的该函数的梯度有意义吗?

宫殿火车

是的,您可以使用tensorflow GradientTape计算出梯度。我假设您有一个log_lik带有许多输入的数学函数输出,其中之一是A

GradientTape获取A的渐变

该得到的梯度log_lik相对于A,您可以使用tf.GradientTape在tensorflow

例如:

with tf.GradientTape(persistent=True) as g:
  g.watch(A)

  t = tf.linspace(0., 10., 100)
  f = tf.ones(100)
  delta = t[1]-t[0]
  sum_term = tfm.multiply(tfm.exp(A*t), f)
  integrals = 0.5*delta*tfm.cumsum(sum_term[:-1] + sum_term[1:], axis=0) 
  pred = integrals
  sq_diff = tfm.square(observed_data - pred)
  sq_diff = tf.reduce_sum(sq_diff, axis=0)
  log_lik = -0.5*tfm.log(2*PI*variance) - 0.5*sq_diff/variance
  z = log_lik

## then, you can get the gradients of log_lik with respect to A like this
dz_dA = g.gradient(z, A)

dz_dA 包含变量中的所有部分导数 A

我只是通过上面的代码向您展示这个想法。为了使它起作用,您需要通过Tensor操作进行计算。因此更改以修改您的函数以使用张量类型进行计算

另一个例子但在张量运算中

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  with tf.GradientTape() as gg:
    gg.watch(x)
    y = x * x
  dy_dx = gg.gradient(y, x)     # Will compute to 6.0
d2y_dx2 = g.gradient(dy_dx, x)  # Will compute to 2.0

在这里,您可以看到文档中的更多示例,以了解更多https://www.tensorflow.org/api_docs/python/tf/GradientTape

关于“有意义”的进一步讨论

首先让我将python代码转换为数学(我使用https://www.codecogs.com/latex/eqneditor.php,希望它可以正确显示):

# integrate e^At * f[t] with respect to t between 0 and t, for all t

从上面开始,这意味着您具有功能。我叫它g(t, A)

然后,您正在做一个积分。我叫它G(t,A)

从您的代码开始,t不再是变量,将其设置为10。因此,我们简化为仅具有一个变量的函数h(A)

到目前为止,函数h内部具有确定的积分。但是,由于您是近似值,因此我们不应该将其视为实积分(dt-> 0),这只是简单数学的另一链这里没有神秘。

然后,最后一个输出log_lik(它只是一些带有一个新输入变量的简单数学运算)被observed_data称为y

然后z计算的函数log_lik是:

z与张量流中的其他普通数学运算链没有什么不同。因此,dz_dAzwrtA的梯度为A您提供可以最小化的更新梯度的意义上讲,这是有意义的z

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

有什么方法可以计算出卫星(ISS)的视觉大小?

来自分类Dev

Solr可以从子字段中检索值,并将计算出的值与父结果相关联吗?

来自分类Dev

Golang Fibonacci计算出现

来自分类Dev

计算出现次数

来自分类Dev

证明旅行商(TSP)的2倍最佳逼近算法无法计算出最佳解

来自分类Dev

在SQLAlchemy中打印计算出的距离

来自分类Dev

是否可以计算出卢森堡的“当前半小时”的开始时间?

来自分类Dev

如何使用大熊猫积分计算出高于功率值的能量?

来自分类Dev

是否在层上添加前向钩子以确保使用该层的输出计算出的损耗梯度会自动计算出来?

来自分类Dev

还有其他方法可以将计算出的信息添加到表中吗?

来自分类Dev

JS中的Rem单位是计算出来的吗?

来自分类Dev

有什么方法可以计算出卫星(ISS)的视觉大小?

来自分类Dev

java的,当不是数组时,将计算出的值打印到文件中吗?

来自分类Dev

如何在python 3.3.4中编写可以计算出矩形区域的程序?

来自分类Dev

从数组计算出的KO可更新

来自分类Dev

可以从Android密钥库的指纹中计算出证书吗?

来自分类Dev

Ember的“观察”会引发断言错误,而计算出的属性不会吗?

来自分类Dev

ExtJs4 +:计算出的模型字段可以编辑吗?

来自分类Dev

使用grep计算出现的总数

来自分类Dev

FPGA上的乘法器功能需要多长时间?可以计算出这个时间吗?

来自分类Dev

如何计算出勤率

来自分类Dev

是否可以将通过条件计算出的值插入到mysql中?

来自分类Dev

在行中显示计算出的度量?

来自分类Dev

ALU计算出的Mips架构地址

来自分类Dev

我可以在不调用对象函数的情况下获得(可能是计算出的)属性值吗?

来自分类Dev

XSL-计算出的金额总和

来自分类Dev

Makefile:计算出的变量名

来自分类Dev

是否可以在 django REST 序列化程序中添加计算出的超链接?

来自分类Dev

无法计算出可以得到我想要的结果的 SQL 查询

Related 相关文章

  1. 1

    有什么方法可以计算出卫星(ISS)的视觉大小?

  2. 2

    Solr可以从子字段中检索值,并将计算出的值与父结果相关联吗?

  3. 3

    Golang Fibonacci计算出现

  4. 4

    计算出现次数

  5. 5

    证明旅行商(TSP)的2倍最佳逼近算法无法计算出最佳解

  6. 6

    在SQLAlchemy中打印计算出的距离

  7. 7

    是否可以计算出卢森堡的“当前半小时”的开始时间?

  8. 8

    如何使用大熊猫积分计算出高于功率值的能量?

  9. 9

    是否在层上添加前向钩子以确保使用该层的输出计算出的损耗梯度会自动计算出来?

  10. 10

    还有其他方法可以将计算出的信息添加到表中吗?

  11. 11

    JS中的Rem单位是计算出来的吗?

  12. 12

    有什么方法可以计算出卫星(ISS)的视觉大小?

  13. 13

    java的,当不是数组时,将计算出的值打印到文件中吗?

  14. 14

    如何在python 3.3.4中编写可以计算出矩形区域的程序?

  15. 15

    从数组计算出的KO可更新

  16. 16

    可以从Android密钥库的指纹中计算出证书吗?

  17. 17

    Ember的“观察”会引发断言错误,而计算出的属性不会吗?

  18. 18

    ExtJs4 +:计算出的模型字段可以编辑吗?

  19. 19

    使用grep计算出现的总数

  20. 20

    FPGA上的乘法器功能需要多长时间?可以计算出这个时间吗?

  21. 21

    如何计算出勤率

  22. 22

    是否可以将通过条件计算出的值插入到mysql中?

  23. 23

    在行中显示计算出的度量?

  24. 24

    ALU计算出的Mips架构地址

  25. 25

    我可以在不调用对象函数的情况下获得(可能是计算出的)属性值吗?

  26. 26

    XSL-计算出的金额总和

  27. 27

    Makefile:计算出的变量名

  28. 28

    是否可以在 django REST 序列化程序中添加计算出的超链接?

  29. 29

    无法计算出可以得到我想要的结果的 SQL 查询

热门标签

归档