关于 Tensorflow 和 PyTorch 中的自定义操作

sanjeev mk

我要实现的能量函数,称为刚性能源,如本文的公式7这里
能量函数将两个 3D 对象网格作为输入,并返回它们之间的能量。第一个网格是源网格,第二个网格是源网格的变形版本。在粗略的伪代码中,计算过程如下:

迭代源网格中的所有顶点。

  1. 对于每个顶点,计算其与其相邻顶点的协方差矩阵。
  2. 对计算出的协方差矩阵执行 SVD 并找到顶点的旋转矩阵。
  3. 使用计算出的旋转矩阵、原始网格中的点坐标和变形网格中的相应坐标来计算顶点的能量偏差。

因此,这个能量函数要求我迭代网格中的每个点,并且网格可能有超过 2k 个这样的点。在 Tensorflow 中,有两种方法可以做到这一点。我可以有 2 个形状 (N,3) 的张量,一个代表源点,另一个代表变形网格。

  1. 纯粹使用 Tensorflow 张量来做。也就是说,tf.gather使用现有的 TF 操作迭代上述张量的元素并在每个点上执行计算。这种方法,会非常慢。我之前曾尝试定义迭代超过 1000 个点的损失函数,而图构建本身需要太多时间而无法实用。
  2. 按照此处的 TF 文档中的说明添加新的 TF OP 这涉及在 CPP(和 Cuda,用于 GPU 支持)中编写函数,并使用 TF 注册新 OP。

第一种方法很容易编写,但速度很慢。第二种方法写起来很痛苦。

我已经使用 TF 3 年了,之前从未使用过 PyTorch,但此时我正在考虑切换到它,如果它为这种情况提供了更好的替代方案。

PyTorch 是否有一种方法可以轻松实现此类损失函数,并且执行速度与在 GPU 上一样快。即,编写我自己的在 GPU 上运行的损失函数的 pythonic 方式,我没有任何 C 或 Cuda 代码?

生天烧

据我了解,您本质上是在问这个操作是否可以矢量化。答案是否定的,至少不完全是,因为PyTorch 中的svd实现不是矢量化的。

如果您展示了 tensorflow 实现,它将有助于理解您的起点。我不知道你通过找到顶点的旋转矩阵是什么意思,但我猜这可以被矢量化。这意味着 svd 是唯一的非矢量化操作,您也许可以只编写一个自定义 OP,即矢量化 svd - 这可能很容易,因为它相当于在循环中调用一些库例程在 C++ 中。

我看到的两个可能的问题来源是

  1. 如果N(i)等式 7 中的邻域的大小可能有显着差异(这意味着协方差矩阵的大小不同,并且矢量化需要一些肮脏的技巧)
  2. 处理网格和邻域的一般问题可能很困难。这是不规则网格的固有属性,但 PyTorch 支持稀疏矩阵和专用包torch_geometry,至少有帮助。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

我可以使用现有的操作(例如conv2d和张量操作)在python中的tensorflow中编写自定义层吗?

来自分类Dev

如何在 tensorflow 或 pytorch 中使用自定义权重初始化创建自定义神经网络

来自分类Dev

tensorflow lite 添加自定义操作

来自分类Dev

Pytorch中的dim和Tensorflow中的axis有什么区别?

来自分类Dev

如何保存和使用在 PyTorch/TensorFlow/Keras 中开发的训练好的神经网络?

来自分类Dev

Woocommerce 中的自定义产品模板和操作挂钩

来自分类Dev

Boost.Log:关于文件轮换的自定义操作

来自分类Dev

关于OAUTH2和自定义身份提供者的困惑

来自分类Dev

动态链接和托管关于自定义域是否正确连接存在分歧

来自分类Dev

在Pytorch内置的自定义Batchnorm中更新running_mean和running_var问题吗?

来自分类Dev

在 Pytorch 中优化自定义参数

来自分类Dev

自定义conv2d操作Pytorch

来自分类Dev

使用PyTorch和TorchVision对自定义数据集进行训练有效测试拆分

来自分类Dev

tensorflow在pyTorch中的时间分布等效项

来自分类Dev

关于android定义和概念的帮助

来自分类Dev

关于信息和熵定义的性质

来自分类Dev

TensorFlow中卷积的自定义填充

来自分类Dev

自定义圈关于反应原生中的内容长度

来自分类Dev

关于 Swift 中自定义开关按钮的问题

来自分类Dev

关于JavaScript中的绑定和调用的困惑

来自分类Dev

关于PHP中的内存和变量分配

来自分类Dev

关于Java中的泛型和列表

来自分类Dev

Haskell中关于<$>和<*>的优先混淆

来自分类Dev

关于c中的*(星号)和++的混淆

来自分类Dev

关于计算r中的行和列

来自分类Dev

关于MySQL和PostgreSQL中的子查询

来自分类Dev

关于C#中的事件和代表

来自分类Dev

关于C ++中的i ++和++ i

来自分类Dev

关于UNION和MySQL中的JOIN

Related 相关文章

  1. 1

    我可以使用现有的操作(例如conv2d和张量操作)在python中的tensorflow中编写自定义层吗?

  2. 2

    如何在 tensorflow 或 pytorch 中使用自定义权重初始化创建自定义神经网络

  3. 3

    tensorflow lite 添加自定义操作

  4. 4

    Pytorch中的dim和Tensorflow中的axis有什么区别?

  5. 5

    如何保存和使用在 PyTorch/TensorFlow/Keras 中开发的训练好的神经网络?

  6. 6

    Woocommerce 中的自定义产品模板和操作挂钩

  7. 7

    Boost.Log:关于文件轮换的自定义操作

  8. 8

    关于OAUTH2和自定义身份提供者的困惑

  9. 9

    动态链接和托管关于自定义域是否正确连接存在分歧

  10. 10

    在Pytorch内置的自定义Batchnorm中更新running_mean和running_var问题吗?

  11. 11

    在 Pytorch 中优化自定义参数

  12. 12

    自定义conv2d操作Pytorch

  13. 13

    使用PyTorch和TorchVision对自定义数据集进行训练有效测试拆分

  14. 14

    tensorflow在pyTorch中的时间分布等效项

  15. 15

    关于android定义和概念的帮助

  16. 16

    关于信息和熵定义的性质

  17. 17

    TensorFlow中卷积的自定义填充

  18. 18

    自定义圈关于反应原生中的内容长度

  19. 19

    关于 Swift 中自定义开关按钮的问题

  20. 20

    关于JavaScript中的绑定和调用的困惑

  21. 21

    关于PHP中的内存和变量分配

  22. 22

    关于Java中的泛型和列表

  23. 23

    Haskell中关于<$>和<*>的优先混淆

  24. 24

    关于c中的*(星号)和++的混淆

  25. 25

    关于计算r中的行和列

  26. 26

    关于MySQL和PostgreSQL中的子查询

  27. 27

    关于C#中的事件和代表

  28. 28

    关于C ++中的i ++和++ i

  29. 29

    关于UNION和MySQL中的JOIN

热门标签

归档