我正在PyTorch上工作,目前遇到了一个我不知道如何以割炬/ numpy样式解决它的问题。例如,假设我有三个PyTorch张量
import torch
import numpy as np
indices = torch.from_numpy(np.array([[2, 1, 3, 0], [1, 0, 3, 2]]))
flags = torch.from_numpy(np.array([[False, False, False, True], [False, False, True, True]]))
tensor = torch.from_numpy(np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]]))
这flags
是一个布尔标志张量,用于显示indices
应提取其中的元素。给定提取的索引,我想将相应的元素设置为tensor
指示的const(例如1e-30)。根据上面显示的示例,我想要
>>> sub_indices = indices.op1(flags)
>>> sub_indices
tensor([[0], [3, 2]])
>>> tensor.op2(sub_indices, 1e-30)
>>> tensor
tensor([[1e-30, 0.5, 1.2, 0.9], [3.1, 2.8, 1e-30, 1e-30]])
谁能提供解决方案?我正在使用列表理解,但是我认为这种方式有点难看。我试过了,indices[flags]
但它只返回一个1d数组,[0, 3, 2]
因此应用它会更改同一列0、2、3上的所有行
一些补充说明:
flags
无法确定每行中“ True”值的数量indices
确保每一行都是序列的置换0 ... N - 1
以下是示例代码的Numpy版本,以方便复制粘贴。我怀疑这是否可以以纯粹的麻木的方式完成
import numpy as np
indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])
您可以flags
根据进行排序indices
以创建mask
,然后将mask
用作多路复用器。这是示例代码:
indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])
indices_sorted = indices.argsort(axis=1)
mask = np.take_along_axis(flags, indices_sorted, axis=1)
result = tensor * (1 - mask) + 1e-30 * mask
我对pytorch不太熟悉,但是我想收集一个参差不齐的张量不是一个好主意。虽然,即使在最坏的情况下,您也可以在numpy数组之间进行转换。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句