给定索引ndarray和标志ndarray的任何numpy / torch样式来设置值吗?

TX Shi

我正在PyTorch上工作,目前遇到了一个我不知道如何以割炬/ numpy样式解决它的问题。例如,假设我有三个PyTorch张量

import torch
import numpy as np

indices = torch.from_numpy(np.array([[2, 1, 3, 0], [1, 0, 3, 2]]))
flags = torch.from_numpy(np.array([[False, False, False, True], [False, False, True, True]]))
tensor = torch.from_numpy(np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]]))

flags是一个布尔标志张量,用于显示indices应提取其中的元素给定提取的索引,我想将相应的元素设置为tensor指示的const(例如1e-30)。根据上面显示的示例,我想要

>>> sub_indices = indices.op1(flags)
>>> sub_indices
tensor([[0], [3, 2]])
>>> tensor.op2(sub_indices, 1e-30)
>>> tensor
tensor([[1e-30, 0.5, 1.2, 0.9], [3.1, 2.8, 1e-30, 1e-30]])

谁能提供解决方案?我正在使用列表理解,但是我认为这种方式有点难看。我试过了,indices[flags]但它只返回一个1d数组,[0, 3, 2]因此应用它会更改同一列0、2、3上的所有行

一些补充说明:

  • flags无法确定每行中“ True”值的数量
  • indices确保每一行都是序列的置换0 ... N - 1

以下是示例代码的Numpy版本,以方便复制粘贴。我怀疑这是否可以以纯粹的麻木的方式完成

import numpy as np

indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

独奏

您可以flags根据进行排序indices以创建mask,然后将mask用作多路复用器。这是示例代码:

indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

indices_sorted = indices.argsort(axis=1)
mask = np.take_along_axis(flags, indices_sorted, axis=1)
result = tensor * (1 - mask) + 1e-30 * mask

我对pytorch不太熟悉,但是我想收集一个参差不齐的张量不是一个好主意。虽然,即使在最坏的情况下,您也可以在numpy数组之间进行转换。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

使用mask为numpy ndarray设置值

来自分类Dev

使用mask为numpy ndarray设置值

来自分类Dev

最大奇异值NumPy`ndarray`

来自分类Dev

找到最大值及其索引以移动numpy ndarray的子数组的Python方法是什么?

来自分类Dev

numpy:在ndarray的每一行中查找第二高的值的索引

来自分类Dev

在np.ndarray中检索最小值和最大值的索引

来自分类Dev

numpy.ndarray 的索引

来自分类Dev

如何为嵌套的numpy ndarray设置dtype?

来自分类Dev

numpy.ndarray中的值计数

来自分类Dev

如何在numpy ndArray中插入值?

来自分类Dev

numpy.ndarray作为字典值

来自分类Dev

将相同的值分配给不同的ndarray索引-Python

来自分类Dev

获取ndarray中N个最高值的索引

来自分类Dev

numpy ndarray索引-从元组检索索引

来自分类Dev

NumPy ndarray与等效列表的绘制方式不同吗?

来自分类Dev

numpy append()函数不会更改我的ndarray吗?

来自分类Dev

沿轴插值numpy.ndarray的最佳方法

来自分类Dev

将numpy.ndarray值从字节转换为浮点

来自分类Dev

Numpy ndarray-如何根据多列的值选择行

来自分类Dev

根据其他数组的值选择 numpy.ndarray 的子集

来自分类Dev

如何有条件地沿numpy ndarray的特定轴在特定位置设置值

来自分类Dev

在numpy中,什么是在两个轴上查找3D ndarray的最大值及其索引的有效方法?

来自分类Dev

Python3 / Numpy:ndarray条件索引

来自分类Dev

给定字节缓冲区,dtype,形状和跨距,如何创建Numpy ndarray

来自分类Dev

Microsoft Word中的直接格式设置和字符样式之间有任何实际区别吗?

来自分类Dev

ndarray值的总和

来自分类Dev

ndarray值的总和

来自分类Dev

如何根据列名和单元格值设置样式?

来自分类Dev

使用对象键和值设置html元素的样式

Related 相关文章

热门标签

归档