Pytorch CNN培训中的“ RuntimeError:预期标量类型为Double,但找到了Float”

饼干大仙

我刚刚开始学习Pytorch并创建了我的第一个CNN。数据集包含3360张RGB图像,我将它们转换为[3360, 3, 224, 224]张量。数据和标签在中dataset(torch.utils.data.TensorDataset)以下是培训代码。

    def train_net():
        dataset = ld.load()
        data_iter = Data.DataLoader(dataset, batch_size=168, shuffle=True)
        net = model.VGG_19()
        summary(net, (3, 224, 224), device="cpu")
        loss_func = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, dampening=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
        for epoch in range(5):
            print("epoch:", epoch + 1)
            train_loss = 0
            for i, data in enumerate(data_iter, 0):
                x, y = data
                print(x.dtype)
                optimizer.zero_grad()
                out = net(x)
                loss = loss_func(out, y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                if i % 100 == 99:
                    print("loss:", train_loss / 100)
                    train_loss = 0.0
        print("finish train")

然后我有这个错误:

    Traceback (most recent call last):
              File "D:/python/DeepLearning/VGG/train.py", line 52, in <module>
                train_net()
              File "D:/python/DeepLearning/VGG/train.py", line 29, in train_net
                out = net(x)
              File "D:\python\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
                result = self.forward(*input, **kwargs)
              File "D:\python\DeepLearning\VGG\model.py", line 37, in forward
                out = self.conv3_64(x)
              File "D:\python\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
                result = self.forward(*input, **kwargs)
              File "D:\python\lib\site-packages\torch\nn\modules\container.py", line 117, in forward
                input = module(input)
              File "D:\python\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
                result = self.forward(*input, **kwargs)
              File "D:\python\lib\site-packages\torch\nn\modules\conv.py", line 423, in forward
                return self._conv_forward(input, self.weight)
              File "D:\python\lib\site-packages\torch\nn\modules\conv.py", line 419, in _conv_forward
                return F.conv2d(input, weight, self.bias, self.stride,
            RuntimeError: expected scalar type Double but found Float

我认为x出问题了,我通过print(x.dtype)以下方式打印其类型

torch.float64

它是double而不是float。你知道怎么了吗?谢谢你的帮助!

Prajot Kuvalekar

该错误实际上是指在float32调用矩阵乘法时默认情况下处于conv层的权重由于您的输入是doublefloat64在pytorch中),而在conv中的权重是,float
因此您的情况的解决方案是:

def train_net():
    dataset = ld.load()
    data_iter = Data.DataLoader(dataset, batch_size=168, shuffle=True)
    net = model.VGG_19()
    summary(net, (3, 224, 224), device="cpu")
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, dampening=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
    for epoch in range(5):
        print("epoch:", epoch + 1)
        train_loss = 0
        for i, data in enumerate(data_iter, 0):
            x, y = data
            x = x.float()
            print(x.dtype)
            optimizer.zero_grad()
            out = net(x)
            loss = loss_func(out, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            if i % 100 == 99:
                print("loss:", train_loss / 100)
                train_loss = 0.0
    print("finish train")

这肯定可以工作

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

Pytorch RuntimeError:预期标量类型为Float,但找到字节

来自分类Dev

pytorch RuntimeError:标量类型为Double的预期对象,但标量类型为Float

来自分类Dev

RuntimeError:预期的标量类型为Long,但找到了Float

来自分类Dev

PyTorch 中的训练错误 - RuntimeError:FloatTensor 与 ByteTensor 类型的预期对象

来自分类Dev

师资培训PyTorch

来自分类Dev

RuntimeError:标量类型为Double的预期对象,但参数#2的标量类型为Float

来自分类Dev

如何在PyTorch中为数字数据正确实现一维CNN?

来自分类Dev

如何通过PyTorch中的Mask R-CNN预测为图像生成准确的Mask?

来自分类Dev

pytorch中的类型不匹配

来自分类Dev

RuntimeError:标量类型为Long的预期对象,但参数#2'target'的标量类型为Float

来自分类Dev

PyTorch中的生成对抗网络(GAN)的培训生成器

来自分类Dev

有没有办法查看 Pytorch 中的培训课程出了什么问题?

来自分类Dev

LSTM培训期间不断损失-PyTorch

来自分类Dev

培训损失根本没有改变(PyTorch)

来自分类Dev

如何在pytorch中向CNN输入矩阵

来自分类Dev

Pytorch CNN无法学习

来自分类Dev

Pytorch CNN 错误:预期输入batch_size (4) 匹配目标batch_size (64)

来自分类Dev

PyTorch为标量提供了不同的价值

来自分类Dev

pytorch中.to(device)的效率

来自分类Dev

在Pytorch中实施LeNet

来自分类Dev

在conda YAML文件中为pytorch指定仅cpu

来自分类Dev

重新使用分类CNN模型autoencoding - pytorch

来自分类Dev

Pytorch CNN损失没有改变,

来自分类Dev

转换后的PyTorch CNN线性层形状

来自分类Dev

PyTorch:用于培训和测试/验证的不同转发方法

来自分类Dev

加快培训-在PyTorch中使用LSTM进行RNN

来自分类Dev

在PyTorch中进行部分培训后添加样本

来自分类Dev

PyTorch C ++ API中的`randperm`不应该返回默认类型为int的张量吗?

来自分类Dev

PyTorch LSTM:RuntimeError:无效参数0:张量的大小必须匹配,但维度0除外。在维度1中为1219和440

Related 相关文章

  1. 1

    Pytorch RuntimeError:预期标量类型为Float,但找到字节

  2. 2

    pytorch RuntimeError:标量类型为Double的预期对象,但标量类型为Float

  3. 3

    RuntimeError:预期的标量类型为Long,但找到了Float

  4. 4

    PyTorch 中的训练错误 - RuntimeError:FloatTensor 与 ByteTensor 类型的预期对象

  5. 5

    师资培训PyTorch

  6. 6

    RuntimeError:标量类型为Double的预期对象,但参数#2的标量类型为Float

  7. 7

    如何在PyTorch中为数字数据正确实现一维CNN?

  8. 8

    如何通过PyTorch中的Mask R-CNN预测为图像生成准确的Mask?

  9. 9

    pytorch中的类型不匹配

  10. 10

    RuntimeError:标量类型为Long的预期对象,但参数#2'target'的标量类型为Float

  11. 11

    PyTorch中的生成对抗网络(GAN)的培训生成器

  12. 12

    有没有办法查看 Pytorch 中的培训课程出了什么问题?

  13. 13

    LSTM培训期间不断损失-PyTorch

  14. 14

    培训损失根本没有改变(PyTorch)

  15. 15

    如何在pytorch中向CNN输入矩阵

  16. 16

    Pytorch CNN无法学习

  17. 17

    Pytorch CNN 错误:预期输入batch_size (4) 匹配目标batch_size (64)

  18. 18

    PyTorch为标量提供了不同的价值

  19. 19

    pytorch中.to(device)的效率

  20. 20

    在Pytorch中实施LeNet

  21. 21

    在conda YAML文件中为pytorch指定仅cpu

  22. 22

    重新使用分类CNN模型autoencoding - pytorch

  23. 23

    Pytorch CNN损失没有改变,

  24. 24

    转换后的PyTorch CNN线性层形状

  25. 25

    PyTorch:用于培训和测试/验证的不同转发方法

  26. 26

    加快培训-在PyTorch中使用LSTM进行RNN

  27. 27

    在PyTorch中进行部分培训后添加样本

  28. 28

    PyTorch C ++ API中的`randperm`不应该返回默认类型为int的张量吗?

  29. 29

    PyTorch LSTM:RuntimeError:无效参数0:张量的大小必须匹配,但维度0除外。在维度1中为1219和440

热门标签

归档