为了账号安全,请及时绑定邮箱和手机立即绑定

在卷积神经网络中设置层的维度

在卷积神经网络中设置层的维度

桃花长相依 2021-09-14 17:44:24
假设我有 4 个批次的 3x100x100 图像作为输入,并且我正在尝试使用 pytorch 制作我的第一个卷积神经网络。我真的不确定我的卷积神经网络是否正确,因为当我通过以下安排训练我的输入时,我遇到了错误:Expected input batch_size (1) to match target batch_size (4).以下是我的转发nnet:那么如果我要通过它:nn.Conv2d(3, 6, 5)我会得到 6 层地图,每层都有尺寸(100-5+1)。那么如果我要通过它:nn.MaxPool2d(2, 2)我会得到 6 层地图,每层都有尺寸 (96/2)然后,如果我要通过它:nn.Conv2d(6, 16, 5)我会得到 16 层地图,每层都有尺寸 (48-5+1)那么如果我要通过它:self.fc1 = nn.Linear(44*44*16, 120)我会得到 120 个神经元那么如果我要通过它:self.fc2 = nn.Linear(120, 84)我会得到 84 个神经元那么如果我要通过它:self.fc3 = nn.Linear(84, 3)我会得到 3 个输出,这将是完美的,因为我有 3 类标签。但正如我之前所说,这会导致一个非常令人惊讶的错误,因为这对我来说很有意义。完整的神经网络代码:import torch.nn as nnimport torch.nn.functional as Fclass Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        self.conv1 = nn.Conv2d(3, 6, 5)        self.pool = nn.MaxPool2d(2, 2)        self.conv2 = nn.Conv2d(6, 16, 5)        self.fc1 = nn.Linear(44*44*16, 120)        self.fc2 = nn.Linear(120, 84)        self.fc3 = nn.Linear(84, 3)    def forward(self, x):        x = self.pool(F.relu(self.conv1(x)))        x = self.pool(F.relu(self.conv2(x)))        x = x.view(-1, 16 *44*44)        x = F.relu(self.fc1(x))        x = F.relu(self.fc2(x))        x = self.fc3(x)        return xnet = Net()net.to(device)
查看完整描述

1 回答

?
慕村225694

TA贡献1880条经验 获得超4个赞

你的理解是正确的,非常详细。


但是,您使用了两个池化层(请参阅下面的相关代码)。所以第二步之后的输出将是16个44/2=22维度的地图。


x = self.pool(F.relu(self.conv1(x)))

x = self.pool(F.relu(self.conv2(x)))

要解决此问题,要么不池化,要么将全连接层的维度更改为22*22*16。


要通过不池化来修复,请修改您的转发功能,如下所示。


def forward(self, x):

    x = self.pool(F.relu(self.conv1(x)))

    x = F.relu(self.conv2(x))

    x = x.view(-1, 16 *44*44)

    x = F.relu(self.fc1(x))

    x = F.relu(self.fc2(x))

    x = self.fc3(x)

    return x

要通过更改全连接层的维度来修复,请更改网络的声明如下。


def __init__(self):

    super(Net, self).__init__()

    self.conv1 = nn.Conv2d(3, 6, 5)

    self.pool = nn.MaxPool2d(2, 2)

    self.conv2 = nn.Conv2d(6, 16, 5)

    self.fc1 = nn.Linear(22*22*16, 120)

    self.fc2 = nn.Linear(120, 84)

    self.fc3 = nn.Linear(84, 10)


查看完整回答
反对 回复 2021-09-14
  • 1 回答
  • 0 关注
  • 268 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信