我有这个函数可以为2d张量创建排序热图,但是当使用更大的张量输入时,它的速度很慢。如何加快速度并提高效率?
import torch
import numpy as np
import matplotlib.pyplot as plt
def heatmap(
tensor: torch.Tensor,
) -> torch.Tensor:
assert tensor.dim() == 2
def color_tensor(x: torch.Tensor) -> torch.Tensor:
if x < 0:
x = -x
if x < 0.5:
x = x * 2
return (1 - x) * torch.tensor(
[0.9686, 0.9686, 0.9686]
) + x * torch.tensor([0.5725, 0.7725, 0.8706])
else:
x = (x - 0.5) * 2
return (1 - x) * torch.tensor(
[0.5725, 0.7725, 0.8706]
) + x * torch.tensor([0.0196, 0.4431, 0.6902])
else:
if x < 0.5:
x = x * 2
return (1 - x) * torch.tensor(
[0.9686, 0.9686, 0.9686]
) + x * torch.tensor([0.9569, 0.6471, 0.5098])
else:
x = (x - 0.5) * 2
return (1 - x) * torch.tensor(
[0.9569, 0.6471, 0.5098]
) + x * torch.tensor([0.7922, 0.0000, 0.1255])
return torch.stack(
[torch.stack([color_tensor(x) for x in t]) for t in tensor]
).permute(2, 0, 1)
x = torch.randn(3,3)
x = x / x.max()
x_out = heatmap(x)
x_out = (x_out.permute(1, 2, 0) * 255).numpy()
plt.imshow(x_out.astype(np.uint8))
plt.axis("off")
plt.show()
输出示例:
您需要摆脱if
s和for循环,并制作向量化函数。为此,您可以使用遮罩并一并计算。这里是:
def heatmap(tensor: torch.Tensor) -> torch.Tensor:
assert tensor.dim() == 2
# We're expanding to create one more dimension, for mult. to work.
xt = x.expand((3, x.shape[0], x.shape[1])).permute(1, 2, 0)
# this part is the mask: (xt >= 0) * (xt < 0.5) ...
# ... the rest is the original function translated
color_tensor = (
(xt >= 0) * (xt < 0.5) * ((1 - xt * 2) * torch.tensor([0.9686, 0.9686, 0.9686]) + xt * 2 * torch.tensor([0.9569, 0.6471, 0.5098]))
+
(xt >= 0) * (xt >= 0.5) * ((1 - (xt - 0.5) * 2) * torch.tensor([0.9569, 0.6471, 0.5098]) + (xt - 0.5) * 2 * torch.tensor([0.7922, 0.0000, 0.1255]))
+
(xt < 0) * (xt > -0.5) * ((1 - (-xt * 2)) * torch.tensor([0.9686, 0.9686, 0.9686]) + (-xt * 2) * torch.tensor([0.5725, 0.7725, 0.8706]))
+
(xt < 0) * (xt <= -0.5) * ((1 - (-xt - 0.5) * 2) * torch.tensor([0.5725, 0.7725, 0.8706]) + (-xt - 0.5) * 2 * torch.tensor([0.0196, 0.4431, 0.6902]))
).permute(2, 0, 1)
return color_tensor
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句