RuntimeError:给定组= 1,权重为[64、1、4、4],预期输入[256、3、32、32]具有1个通道,但改为3个通道

克里斯托帕

您能帮我解决以上错误吗?如果加载的是 MNIST 数据集,就不会出现该错误。该错误与 CIFAR10、FMNIST 等其他数据集的维度有关,在这些数据集上运行时程序无法执行。感谢任何帮助。


# noinspection PyUnresolvedReferences
import os
# imports
# noinspection PyUnresolvedReferences
import pickle
from time import time

from torchvision import datasets, transforms
from torchvision.utils import save_image

import site
site.addsitedir('/content/gw_gan/model') 
from loss import gwnorm_distance, loss_total_variation, loss_procrustes
from model_cnn import Generator, Adversary
from model_cnn import weights_init_generator, weights_init_adversary

# internal imports
from utils import *

# get command-line arguments (parsed by the project-local get_args() helper;
# expected fields used here: beta, num_epochs, cuda, n_channels, id, data)
args = get_args()

# system preferences: seed all RNGs so a single run is reproducible
seed = np.random.randint(100)
torch.set_default_dtype(torch.double)
np.random.seed(seed)
torch.manual_seed(seed)

# hyper-parameter settings
batch_size = 256
z_dim = 100               # latent dimension fed to the generator
lr = 0.0002
ngen = 3                  # generator steps per adversary step
beta = args.beta          # weight of the Procrustes (orthogonality) term
lam = 0.5                 # weight of the total-variation term
niter = 10                # inner iterations of the GW distance solver
epsilon = 0.005           # entropic regularisation of the GW distance
num_epochs = args.num_epochs
cuda = args.cuda
channels = args.n_channels
id1 = args.id

# The loader below is hard-coded to CIFAR10, whose images have 3 channels.
# --n_channels defaults to 1 (the FMNIST example), which later fails inside
# the adversary's first conv with
#   "expected input[256, 3, 32, 32] to have 1 channels".
# Fail fast with an actionable message instead of that cryptic error.
if channels != 3:
    raise ValueError(
        'CIFAR10 images have 3 channels; run with --n_channels 3 '
        '(got {}).'.format(channels))

model = 'gwgan_{}_eps_{}_tv_{}_procrustes_{}_ngen_{}_channels_{}_{}' \
        .format(args.data, epsilon, lam, beta, ngen, channels, id1)
save_fig_path = 'out_' + model
if not os.path.exists(save_fig_path):
    os.makedirs(save_fig_path)

# data import: CIFAR10 training split, each channel rescaled to [-1, 1]
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_set = datasets.CIFAR10('./data/cifar10', train=True, download=True,
                             transform=preprocess)
dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                         drop_last=True, shuffle=True)

# save a 5x5 grid of real samples as a visual reference
real_images, _ = next(iter(dataloader))
save_image(real_images[:25],
           os.path.join(save_fig_path, 'real.pdf'), nrow=5, normalize=True)

# instantiate the two networks with matching channel counts
generator = Generator(output_dim=channels)
adversary = Adversary(input_dim=channels)

# apply the project's custom weight initialisation schemes
generator.apply(weights_init_generator)
adversary.apply(weights_init_adversary)

if cuda:
    generator = generator.cuda()
    adversary = adversary.cuda()

# one Adam optimiser per network; start with cleared gradients
g_optimizer = torch.optim.Adam(generator.parameters(), lr, betas=(0.5, 0.99))
generator.zero_grad()

c_optimizer = torch.optim.Adam(adversary.parameters(), lr, betas=(0.5, 0.99))
adversary.zero_grad()

# fixed latent batch, reused after every epoch so sample plots are comparable
num_test_samples = batch_size
z_ex = torch.randn(num_test_samples, z_dim)
if cuda:
    z_ex = z_ex.cuda()

# per-epoch traces of the individual loss terms
loss_history = []
loss_tv = []
loss_orth = []
loss_og = 0
is_hist = []

for epoch in range(num_epochs):
    t0 = time()

    for it, (image, _) in enumerate(dataloader):
        # alternate updates: ngen generator steps, then one adversary step
        train_c = ((it + 1) % (ngen + 1) == 0)

        x = image.double()
        if cuda:
            x = x.cuda()

        # sample a latent batch z ~ N(0, I)
        z = torch.randn(image.shape[0], z_dim)

        if cuda:
            z = z.cuda()

        # freeze the network that is NOT being trained this step
        for q in generator.parameters():
            q.requires_grad = not train_c
        for p in adversary.parameters():
            p.requires_grad = train_c

        # generator output
        g = generator.forward(z)

        # adversary features of real and generated samples
        f_x = adversary.forward(x)
        f_g = adversary.forward(g)

        # pairwise intra-batch distance matrices in feature space
        D_g = get_inner_distances(f_g, metric='euclidean', concat=False)
        D_x = get_inner_distances(f_x, metric='euclidean', concat=False)

        # distance matrix normalisation
        D_x_norm = normalise_matrices(D_x)
        D_g_norm = normalise_matrices(D_g)

        # entropic, normalised Gromov-Wasserstein distance and coupling T
        loss, T = gwnorm_distance((D_x, D_x_norm), (D_g, D_g_norm),
                                  epsilon, niter, loss_fun='square_loss',
                                  coupling=True, cuda=cuda)

        if train_c:
            # adversary ascends the GW distance, regularised by a
            # Procrustes term keeping f close to an orthogonal map
            loss_og = loss_procrustes(f_x, x.view(x.shape[0], -1), cuda)
            loss_to = -loss + beta * loss_og
            loss_to.backward()

            # parameter updates
            c_optimizer.step()
            # zero gradients of both networks
            reset_grad(generator, adversary)

        else:
            # generator descends the GW distance plus a total-variation
            # smoothness penalty on the generated images
            loss_t = loss_total_variation(g)
            loss_to = loss + lam * loss_t
            loss_to.backward()

            # parameter updates
            g_optimizer.step()
            # zero gradients of both networks
            reset_grad(generator, adversary)

    # plotting: run the fixed example batch without building an autograd
    # graph — the output is only visualised, never backpropagated through
    with torch.no_grad():
        g_ex = generator.forward(z_ex)
    g_plot = g_ex.cpu().detach()

    # plot generated samples for this epoch
    save_image(g_plot.data[:25],
               os.path.join(save_fig_path, 'g_%d.pdf' % epoch),
               nrow=5, normalize=True)

    fig1, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax0 = ax[0].imshow(T.cpu().detach().numpy(), cmap='RdBu_r')
    colorbar(ax0)
    ax1 = ax[1].imshow(D_x.cpu().detach().numpy(), cmap='Blues')
    colorbar(ax1)
    ax2 = ax[2].imshow(D_g.cpu().detach().numpy(), cmap='Blues')
    colorbar(ax2)
    ax[0].set_title(r'$T$')
    ax[1].set_title(r'inner distances of $D$')
    ax[2].set_title(r'inner distances of $G$')
    plt.tight_layout(h_pad=1)
    fig1.savefig(os.path.join(save_fig_path, '{}_ccc.pdf'.format(
            str(epoch).zfill(3))), bbox_inches='tight')

    # store plain floats, not tensors: appending the live loss tensors would
    # keep their autograd graphs referenced and leak memory across epochs
    loss_history.append(float(loss))
    loss_tv.append(float(loss_t))
    loss_orth.append(float(loss_og))
    plt.close('all')

# plot the per-epoch loss traces; one small figure per loss term
def _plot_trace(values, ylabel, filename):
    # scatter-style trace of one loss term over epochs, saved as a PDF
    fig = plt.figure(figsize=(2.4, 2))
    axis = fig.add_subplot(111)
    axis.plot(values, 'k.')
    axis.set_xlabel('Iterations')
    axis.set_ylabel(ylabel)
    plt.tight_layout()
    plt.grid()
    fig.savefig(save_fig_path + '/' + filename)


_plot_trace(loss_history, r'$\overline{GW}_\epsilon$ Loss',
            'loss_history.pdf')
_plot_trace(loss_tv, r'Total Variation Loss', 'loss_tv.pdf')
_plot_trace(loss_orth, r'$R_\beta(f_\omega(X), X)$ Loss', 'loss_orth.pdf')

错误显示:

Traceback (most recent call last):
  File "/content/gw_gan/main_gwgan_cnn.py", line 160, in <module>
    f_x = adversary.forward(x)
  File "/content/gw_gan/model/model_cnn.py", line 62, in forward
    x = self.conv(input)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py", line 423, in forward
    return self._conv_forward(input, self.weight)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py", line 420, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [64, 1, 4, 4], expected input[256, 3, 32, 32] to have 1 channels, but got 3 channels instead

这是一个基于 CNN 的生成模型应用。代码引用自 https://github.com/bunnech/gw_gan 中的 main_gwgan_cnn,该项目提出了一种能在不可比空间(incomparable spaces)之间学习并生成结果的 GAN。

伊万

您必须设置 --n_channels,否则 args.n_channels 将默认为 1:此处给出的示例是针对只有单个通道的 FMNIST 的。您是在 CIFAR 上运行,由于它有三个通道,因此应将 --n_channels 设置为 3。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

Related 相关文章

热门标签

归档