使用条件裁剪或阈值张量并在PyTorch中将结果零填充

奥梅尔·萨班

假设我有一个这样的张量

w = [[0.1, 0.7, 0.7, 0.8, 0.3],
    [0.3, 0.2, 0.9, 0.1, 0.5],
    [0.1, 0.4, 0.8, 0.3, 0.4]]

现在我想根据某些条件消除某些值(例如大于或小于0.5)

w = [[0.1, 0.3],
     [0.3, 0.2, 0.1],
     [0.1, 0.4, 0.3, 0.4]]

然后将其填充到相等的长度:

w = [[0.1, 0.3, 0, 0],
     [0.3, 0.2, 0.1, 0],
     [0.1, 0.4, 0.3, 0.4]]

这就是我在pytorch中实现它的方式:

w = torch.rand(3, 5)
condition = w <= 0.5
w = [w[i][condition[i]] for i in range(3)]
w = torch.nn.utils.rnn.pad_sequence(w)

但是显然这将非常慢,主要是由于列表理解。有什么更好的办法吗?

kmario23

这是使用布尔值掩蔽张量分裂,然后最终使用填充分裂的张量的一种直接方法torch.nn.utils.rnn.pad_sequence(...)

# input tensor to work with
In [213]: w 
Out[213]: 
tensor([[0.1000, 0.7000, 0.7000, 0.8000, 0.3000],
        [0.3000, 0.2000, 0.9000, 0.1000, 0.5000],
        [0.1000, 0.4000, 0.8000, 0.3000, 0.4000]])

# values above this should be clipped from the input tensor
In [214]: clip_value = 0.5 

# generate a boolean mask that satisfies the condition
In [215]: boolean_mask = (w <= clip_value) 

# we need to sum the mask along axis 1 (needed for splitting)
In [216]: summed_mask = boolean_mask.sum(dim=1) 

# a sequence of splitted tensors
In [217]: splitted_tensors = torch.split(w[boolean_mask], summed_mask.tolist())  

# finally pad them along dimension 1 (or axis 1)
In [219]: torch.nn.utils.rnn.pad_sequence(splitted_tensors, 1) 
Out[219]: 
tensor([[0.1000, 0.3000, 0.0000, 0.0000],
        [0.3000, 0.2000, 0.1000, 0.5000],
        [0.1000, 0.4000, 0.3000, 0.4000]])

关于效率的简短说明:使用torch.split()是超级高效的,因为它会将分割后的张量作为原始张量的视图返回(即不进行任何复制)。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

翻转张量并在张量流中填充零

来自分类Dev

使用阈值的条件列

来自分类Dev

Tensorflow.js裁剪图像返回零张量

来自分类Dev

如何基于模运算在PyTorch张量的末端添加零?

来自分类Dev

如何在某个轴上用张量零填充(Python)

来自分类Dev

如何裁剪未知值和大小的张量的常值填充(填充高度和宽度相同)?

来自分类Dev

Numpy PIL Python:在空格上裁剪图像或使用直方图阈值裁剪文本

来自分类Dev

如何在Pytorch中将张量转换为复杂类型?

来自分类Dev

在pytorch中将哈希值另存为张量

来自分类Dev

PyTorch Conv2D返回零输入张量的非零输出?

来自分类Dev

Python:使用rsub填充零填充zfill

来自分类Dev

使用FFTW的零填充FFT

来自分类Dev

使用零填充重塑矩阵

来自分类Dev

使用AVFoundation获得裁剪视频的意外结果

来自分类Dev

在PyTorch中将5D张量转换为4D张量

来自分类Dev

如何在pytorch中将3D张量转换为2D张量?

来自分类Dev

在PyTorch中有条件地应用张量操作

来自分类Dev

过零与阈值的区别

来自分类Dev

测试减法并在结果大于零时应用它

来自分类Dev

生成cmd命令以从任务列表中提取PID并在结果上使用条件

来自分类Dev

SQL JOIN 并在第二个表中使用某些条件限制结果

来自分类Dev

在php中将零填充为十六进制数

来自分类Dev

如何在批处理FOR循环中将数字零填充?

来自分类Dev

Scikit Image Otsu阈值产生零阈值

来自分类Dev

从Pytorch的4D张量中查找具有零的索引

来自分类Dev

如何摆脱Pytorch张量中充满零的每一列?

来自分类Dev

如何将零填充到十进制结果

来自分类Dev

具有零填充数字的意外算术结果

来自分类Dev

使用Pytorch如何使用索引和相应值定义张量

Related 相关文章

热门标签

归档