如何使PyTorch热图功能更快,更有效?

ProGamerGov

我有这个函数可以为2d张量创建排序热图,但是当使用更大的张量输入时,它的速度很慢。如何加快速度并提高效率?

import torch
import numpy as np
import matplotlib.pyplot as plt


def heatmap(
    tensor: torch.Tensor,
) -> torch.Tensor:
    assert tensor.dim() == 2

    def color_tensor(x: torch.Tensor) -> torch.Tensor:
        if x < 0:
            x = -x
            if x < 0.5:
                x = x * 2
                return (1 - x) * torch.tensor(
                    [0.9686, 0.9686, 0.9686]
                ) + x * torch.tensor([0.5725, 0.7725, 0.8706])
            else:
                x = (x - 0.5) * 2
                return (1 - x) * torch.tensor(
                    [0.5725, 0.7725, 0.8706]
                ) + x * torch.tensor([0.0196, 0.4431, 0.6902])
        else:
            if x < 0.5:
                x = x * 2
                return (1 - x) * torch.tensor(
                    [0.9686, 0.9686, 0.9686]
                ) + x * torch.tensor([0.9569, 0.6471, 0.5098])
            else:
                x = (x - 0.5) * 2
                return (1 - x) * torch.tensor(
                    [0.9569, 0.6471, 0.5098]
                ) + x * torch.tensor([0.7922, 0.0000, 0.1255])

    return torch.stack(
        [torch.stack([color_tensor(x) for x in t]) for t in tensor]
    ).permute(2, 0, 1)

x = torch.randn(3,3)
x = x / x.max()
x_out = heatmap(x)

x_out = (x_out.permute(1, 2, 0) * 255).numpy()
plt.imshow(x_out.astype(np.uint8))
plt.axis("off")
plt.show()

输出示例:

在此处输入图片说明

阿玛穆特

您需要摆脱ifs和for循环,并制作向量化函数。为此,您可以使用遮罩并一并计算。这里是:


def heatmap(tensor: torch.Tensor) -> torch.Tensor:
    assert tensor.dim() == 2

    # We're expanding to create one more dimension, for mult. to work.
    xt = x.expand((3, x.shape[0], x.shape[1])).permute(1, 2, 0)

    # this part is the mask: (xt >= 0) * (xt < 0.5) ...
    # ... the rest is the original function translated
    color_tensor = (
        (xt >= 0) * (xt < 0.5) * ((1 - xt * 2) * torch.tensor([0.9686, 0.9686, 0.9686]) + xt * 2 * torch.tensor([0.9569, 0.6471, 0.5098]))
        +
        (xt >= 0) * (xt >= 0.5) * ((1 - (xt - 0.5) * 2) * torch.tensor([0.9569, 0.6471, 0.5098]) + (xt - 0.5) * 2 * torch.tensor([0.7922, 0.0000, 0.1255]))
        +
        (xt < 0) * (xt > -0.5) * ((1 - (-xt * 2)) * torch.tensor([0.9686, 0.9686, 0.9686]) + (-xt * 2) * torch.tensor([0.5725, 0.7725, 0.8706]))
        +
        (xt < 0) * (xt <= -0.5) * ((1 - (-xt - 0.5) * 2) * torch.tensor([0.5725, 0.7725, 0.8706]) + (-xt - 0.5) * 2 * torch.tensor([0.0196, 0.4431, 0.6902]))
    ).permute(2, 0, 1)
    
    return color_tensor

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

更快/更有效的绑定方式?

来自分类Dev

如何实现更有效的搜索功能?

来自分类Dev

如何使Rust功能更通用,更有效?

来自分类Dev

选择功能的更有效方法

来自分类Dev

使排列功能更有效

来自分类Dev

php | 更有效的功能

来自分类Dev

什么是更有效/更快的rsync压缩或ssh压缩?

来自分类Dev

使大量的单个行更新更快或更有效

来自分类Dev

更快,更有效的方式更改日期格式

来自分类Dev

流如何更有效?

来自分类Dev

如何使此查询更有效?

来自分类Dev

如何使循环更有效?

来自分类Dev

如何使Listview更有效?

来自分类Dev

如何使GridView更有效?

来自分类Dev

如何使此循环更有效?

来自分类Dev

如何使for内部循环更有效?

来自分类Dev

更有效的Matplotlib堆叠条形图-如何计算底值

来自分类Dev

更有效的Matplotlib堆叠条形图-如何计算底值

来自分类Dev

for循环并更有效地生成图

来自分类Dev

如何使用算法使白名单功能更有效?

来自分类Dev

如何使用功能更有效的样式重写此代码?

来自分类Dev

如何使这种“裁剪”/背景颜色更改功能更有效?

来自分类Dev

有没有更快,更有效的方法来保存python字典?

来自分类Dev

哪个更有效?

来自分类Dev

使循环更有效

来自分类Dev

更有效的循环

来自分类Dev

使循环更有效

来自分类常见问题

如何使PyTorch张量(B,C,H,W)平铺和混合代码更简单,更有效?

来自分类Dev

如何使PyTorch张量(B,C,H,W)平铺和混合代码更简单,更有效?