我有两个等效的PyTorch模型(我认为),它们之间的唯一区别是填充:
import torch
import torch.nn as nn
i = torch.arange(9, dtype=torch.float).reshape(1,1,3,3)
# First model:
model1 = nn.Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflection')
# tensor([[[[-0.6095, -0.0321, 2.2022],
# [ 0.1018, 1.7650, 5.5392],
# [ 1.7988, 3.9165, 5.6506]]]], grad_fn=<MkldnnConvolutionBackward>)
# Second model:
model2 = nn.Sequential(nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(1, 1, kernel_size=3))
# tensor([[[[1.4751, 1.5513, 2.6566],
# [4.0281, 4.1043, 5.2096],
# [2.6149, 2.6911, 3.7964]]]], grad_fn=<MkldnnConvolutionBackward>)
我想知道为什么以及当您使用这两种方法时,两者的输出是不同的,但是正如我所看到的,它们应该是相同的,因为填充是反射类型。
希望能对您有所帮助。
@Ash说完之后,我想检查一下权重是否有影响,所以我将所有权重固定为相同的值,但是这两种方法之间仍然存在差异:
import torch
import torch.nn as nn
i = torch.arange(9, dtype=torch.float).reshape(1,1,3,3)
# First model:
model1 = nn.Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflection')
model1.weight.data = torch.full(model1.weight.data.shape, 0.4)
print(model1(i))
print(model1.weight)
# tensor([[[[ 3.4411, 6.2411, 5.0412],
# [ 8.6411, 14.6411, 11.0412],
# [ 8.2411, 13.4411, 9.8412]]]], grad_fn=<MkldnnConvolutionBackward>)
# Parameter containing:
# tensor([[[[0.4000, 0.4000, 0.4000],
# [0.4000, 0.4000, 0.4000],
# [0.4000, 0.4000, 0.4000]]]], requires_grad=True)
# Second model:
model2 = [nn.ReflectionPad2d((1, 1, 1, 1)),
nn.Conv2d(1, 1, kernel_size=3)]
model2[1].weight.data = torch.full(model2[1].weight.data.shape, 0.4)
model2 = nn.Sequential(*model2)
print(model2(i))
print(model2[1].weight)
# tensor([[[[ 9.8926, 11.0926, 12.2926],
# [13.4926, 14.6926, 15.8926],
# [17.0926, 18.2926, 19.4926]]]], grad_fn=<MkldnnConvolutionBackward>)
# Parameter containing:
# tensor([[[[0.4000, 0.4000, 0.4000],
# [0.4000, 0.4000, 0.4000],
# [0.4000, 0.4000, 0.4000]]]], requires_grad=True)
两者的输出是不同的,但正如我所见,它们应该是相同的
我认为您获得的不同输出仅与反射性填充的实现方式无关。在您提供的代码片段中,卷积的权重和偏差的值model1
与model2
有所不同,因为它们是随机初始化的,而且您似乎没有在代码中固定它们的值。
编辑:
进行新的编辑后,对于之前的版本,似乎1.5
看一下正向传递的实现<your_torch_install>/nn/modules/conv.py
表明不支持“反射”。它也不会抱怨任意字符串而不是“反射”,但默认为零填充。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句