pytorch 如何通过 argmax 反向传播?

果酱图

我正在使用质心位置的梯度下降而不是期望最大化在 pytorch 中构建 Kmeans。损失是每个点到其最近质心的平方距离之和。为了确定哪个质心离每个点最近,我使用 argmin,它在任何地方都不可微。然而,pytorch 仍然能够反向传播和更新权重(质心位置),在数据上提供与 sklearn kmeans 相似的性能。

任何想法这是如何工作的,或者我如何在pytorch中解决这个问题?关于 pytorch github 的讨论表明 argmax 是不可微的:https : //github.com/pytorch/pytorch/issues/1339

下面的示例代码(随机点):

import numpy as np
import torch

num_pts, batch_size, n_dims, num_clusters, lr = 1000, 100, 200, 20, 1e-5

# generate random points
vector = torch.from_numpy(np.random.rand(num_pts, n_dims)).float()

# randomly pick starting centroids
idx = np.random.choice(num_pts, size=num_clusters)
kmean_centroids = vector[idx][:,None,:] # [num_clusters,1,n_dims]
kmean_centroids = torch.tensor(kmean_centroids, requires_grad=True)

for t in range(4001):
    # get batch
    idx = np.random.choice(num_pts, size=batch_size)
    vector_batch = vector[idx]

    distances = vector_batch - kmean_centroids # [num_clusters, #pts, #dims]
    distances = torch.sum(distances**2, dim=2) # [num_clusters, #pts]

    # argmin
    membership = torch.min(distances, 0)[1] # [#pts]

    # cluster distances
    cluster_loss = 0
    for i in range(num_clusters):
        subset = torch.transpose(distances,0,1)[membership==i]
        if len(subset)!=0: # to prevent NaN
            cluster_loss += torch.sum(subset[:,i])

    cluster_loss.backward()
    print(cluster_loss.item())

    with torch.no_grad():
        kmean_centroids -= lr * kmean_centroids.grad
        kmean_centroids.grad.zero_()
生天烧

正如 alvas 在评论中指出的那样,argmax是不可区分的。然而,一旦你计算它并将每个数据点分配给一个集群,损失相对于这些集群位置的导数是明确定义的。这就是你的算法所做的。

为什么有效?如果您只有一个集群(因此argmax操作无关紧要),您的损失函数将是二次的,数据点的平均值为最小值。现在有了多个集群,您可以看到您的损失函数是分段的(在更高维度上认为是体积)二次 - 对于任何质心集,[C1, C2, C3, ...]每个数据点都分配给某个质心CN,并且损失是局部二次的。该局部性的范围由所有替代质心[C1', C2', C3', ...]给出,其分配来自argmax保持不变;在这个区域内,argmax可以被视为一个常数,而不是一个函数,因此 的导数loss是明确定义的。

现在,实际上,您不太可能将其argmax视为常数,但您仍然可以将朴素的“argmax-is-a-constant”导数视为近似指向最小值,因为大多数数据点可能确实属于迭代之间的相同集群。一旦接近局部最小值,点不再改变它们的分配,过程就可以收敛到最小值。

另一种更理论的看待它的方法是你正在做一个期望最大化的近似。通常,您会有“计算分配”步骤,它由 反映argmax,而“最小化”步骤归结为在给定当前分配的情况下找到最小化的聚类中心。最小值由 给出d(loss)/d([C1, C2, ...]) == 0,对于二次损失,通过每个集群内的数据点分析给出。在您的实现中,您正在求解相同的方程,但使用梯度下降步骤。事实上,如果您使用二阶 (Newton) 更新方案而不是一阶梯度下降,您将隐式地精确复制基线 EM 方案。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

PyTorch中的截断的反向传播(代码检查)

来自分类Dev

通过nginx反向传播gitbucket

来自分类Dev

通过多次向前传播进行反向传播

来自分类Dev

通过rnn ValueError进行Tensorflow反向传播

来自分类Dev

反向传播中的Gradient函数如何工作?

来自分类Dev

如何在 Pytorch 中进行反向传播(autograd.backward(loss) vs loss.backward())以及在哪里设置 requires_grad=True?

来自分类Dev

通过梯度下降的反向传播如何表示每次前向传播后的误差

来自分类Dev

通过嵌套tf.map_fn反向传播渐变

来自分类Dev

反向传播如何在割炬7中工作?

来自分类Dev

有效地找到存储为PyTorch张量的两个向量列表的点积并保留反向传播

来自分类Dev

如何用sympy计算argmax?

来自分类Dev

如何用sympy计算argmax?

来自分类Dev

使用SGD的神经网络能否通过反向传播更改多个输出中的一个?

来自分类Dev

通过跨通道本地响应归一化(LRN)层的反向传播算法

来自分类Dev

如何通过链表反向

来自分类Dev

神经网络如何使用遗传算法和反向传播玩游戏?

来自分类Dev

当需要两次反向传播时,如何避免重新计算一个函数?

来自分类Dev

给定输出节点和权重上的错误,如何使用反向传播更新隐藏节点上的错误

来自分类Dev

反向传播算法的实现

来自分类Dev

调试反向传播算法

来自分类Dev

澄清反向传播

来自分类Dev

反向传播错误

来自分类Dev

伯特的反向传播

来自分类Dev

Matlab GPU反向传播

来自分类Dev

正向与反向模式差异-Pytorch

来自分类Dev

如何找到最近2个轴的argmax

来自分类Dev

Python:如何循环 10 个,然后反向传播,然后循环下一个 10 个

来自分类Dev

如何从多维张量中提取值而不丢失反向信息-PyTorch

来自分类Dev

脾气暴躁的argmax。如何同时计算max和argmax?

Related 相关文章

  1. 1

    PyTorch中的截断的反向传播(代码检查)

  2. 2

    通过nginx反向传播gitbucket

  3. 3

    通过多次向前传播进行反向传播

  4. 4

    通过rnn ValueError进行Tensorflow反向传播

  5. 5

    反向传播中的Gradient函数如何工作?

  6. 6

    如何在 Pytorch 中进行反向传播(autograd.backward(loss) vs loss.backward())以及在哪里设置 requires_grad=True?

  7. 7

    通过梯度下降的反向传播如何表示每次前向传播后的误差

  8. 8

    通过嵌套tf.map_fn反向传播渐变

  9. 9

    反向传播如何在割炬7中工作?

  10. 10

    有效地找到存储为PyTorch张量的两个向量列表的点积并保留反向传播

  11. 11

    如何用sympy计算argmax?

  12. 12

    如何用sympy计算argmax?

  13. 13

    使用SGD的神经网络能否通过反向传播更改多个输出中的一个?

  14. 14

    通过跨通道本地响应归一化(LRN)层的反向传播算法

  15. 15

    如何通过链表反向

  16. 16

    神经网络如何使用遗传算法和反向传播玩游戏?

  17. 17

    当需要两次反向传播时,如何避免重新计算一个函数?

  18. 18

    给定输出节点和权重上的错误,如何使用反向传播更新隐藏节点上的错误

  19. 19

    反向传播算法的实现

  20. 20

    调试反向传播算法

  21. 21

    澄清反向传播

  22. 22

    反向传播错误

  23. 23

    伯特的反向传播

  24. 24

    Matlab GPU反向传播

  25. 25

    正向与反向模式差异-Pytorch

  26. 26

    如何找到最近2个轴的argmax

  27. 27

    Python:如何循环 10 个,然后反向传播,然后循环下一个 10 个

  28. 28

    如何从多维张量中提取值而不丢失反向信息-PyTorch

  29. 29

    脾气暴躁的argmax。如何同时计算max和argmax?

热门标签

归档