使用 PyTorch 实现 ResNet

Tag: 动手实现经典神经网络 Posted on 2021-08-25 10:00:52 Edited on 2021-08-25 17:47:43 Views: 709

概述

ILSVR 2015 冠军。

使用残差连接解决了深层网络性能反而不如浅层网络的问题。

该网络分为多个版本,后面接的数字代表网络中可学习参数的网络层的个数。

不同版本的网络架构非常相似,可以分为以下 6 个部分:

  1. 起始的卷积层,一层卷积(kernel size 取 7,stride 取 2,特征图降采样一倍),Batch Normalization,激活。
  2. 4 个由残差块组成的网络块,个数不一(具体见下面的表),其中网络块中的第一个残差块负责降采样特征图,例外是第一个网络块是通过一个最大池化层来进行降采样。
  3. 输出部分,自适应平均池化层将特征图大小降至 1x1,之后接一个全连接层,最后 Softmax 激活。

对应仓库地址:https://github.com/songquanpeng/pytorch-classifiers

PyTorch 官方实现:https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

实现细节

BasicResBlock

用于 ResNet-19,ResNet-34 的残差块。

具体实现如下:

class BasicResBlock(nn.Module):
    """ A building block for ResNet34 """
    expansion = 1

    def __init__(self, in_channels, planes, stride=1):
        """
        when stride=2, this block is used for downsampling
        """
        super().__init__()
        out_channels = planes * BasicResBlock.expansion
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.main = nn.Sequential(*layers)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        shortcut = x
        out = self.main(x)
        if self.downsample:
            shortcut = self.downsample(x)
        out += shortcut
        out = F.relu(out)
        return out

可以看到里面网络分为两路:

  1. 主体是两个 kernel size 为 3 的卷积层,注意后面一个卷积层后面没有激活。
  2. 残差支路如果需要的话,则使用一个 1x1 卷积来降采样和调整通道数目。 需要降采样的情况:后三个网络块中第一个残差块要降采样特征图时,此时需要同 stride(2)的 1x1 卷积来降采样。

之后两路通道数相同,特征图大小相同,直接对应位置相加,之后再接激活函数。

BottleneckResBlock

用于 ResNet-50,ResNet-101 以及 ResNet-152 的残差块

class BottleneckResBlock(nn.Module):
    """ A bottleneck building block for ResNet-50/101/152. """
    expansion = 4

    def __init__(self, in_channels, planes, stride=1):
        """
        when stride=2, this block is used for downsampling
        """
        super().__init__()
        out_channels = planes * BottleneckResBlock.expansion
        layers = [
            nn.Conv2d(in_channels, planes, 1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, planes, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.main = nn.Sequential(*layers)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        shortcut = x
        out = self.main(x)
        if self.downsample:
            shortcut = self.downsample(x)
        out += shortcut
        out = F.relu(out)
        return out

与基础版本的区别主要在于其干路换成了三个卷积层:1x1 卷积,3x3 卷积和 1x1 卷积。

第一个 1x1 卷积降低通道数量,最后一个 1x1 卷积再将通道数量恢复。 这样就使得 3x3 卷积成为了一个瓶颈层。

注意,如果单看一个 BottleneckResBlock,会发现其将通道数扩大了四倍,并没有降低通道数量。 其实不然,要注意我们是堆叠了多个 BottleneckResBlock,这些 BottleneckResBlock 的架构一样, 也就是说实际上第一个 1x1 卷积的输入通道数量是输出通道数量的四倍。

该 ResBlock 用于较深的网络,原论文中的解释如下:

Deeper non-bottleneck ResNets (e.g., Fig. 5 left) also gain accuracy from increased depth (as shown on CIFAR-10), but are not as economical as the bottleneck ResNets. So the usage of bottleneck designs is mainly due to practical considerations. We further note that the degradation problem of plain nets is also witnessed for the bottleneck designs.

所以理由就是节省了参数量和计算量,同时性能也好,没什么理论上的优势。

架构

arch table

ResNet-34 arch

PyTorch 实现

class ResNet(nn.Module):
    def __init__(self, block: Type[Union[BasicResBlock, BottleneckResBlock]], num_layers: List[int], num_classes: int,
                 in_channels: int = 3):
        super().__init__()
        # input shape: 3x224x224
        self.in_planes = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, self.in_planes, 7, 2, 3, bias=False),  # 64x112x112
            nn.BatchNorm2d(self.in_planes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1)  # 64x56x56
        )
        self.conv2 = self.make_layer(block, num_layers[0], 64, downsample=False)  # 64x56x56
        self.conv3 = self.make_layer(block, num_layers[1], 128, downsample=True)  # 128x28x28
        self.conv4 = self.make_layer(block, num_layers[2], 256, downsample=True)  # 256x14x14
        self.conv5 = self.make_layer(block, num_layers[3], 512, downsample=True)  # 512x7x7
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # 512x1x1
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def make_layer(self, block, num, planes, downsample=True):
        layers = [block(self.in_planes, planes, stride=2 if downsample else 1)]
        self.in_planes = planes * block.expansion
        for _ in range(1, num):
            layers.append(block(self.in_planes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.avg_pool(x)
        out = self.fc(x.view(x.shape[0], -1))
        return out

未经允许,禁止转载,本文源站链接:https://iamazing.cn/