使用 PyTorch 实现 LeNet

标签: 动手实现经典神经网络 发布于:2021-08-19 09:54:47 编辑于:2021-08-20 20:22:26 浏览量:1379

概述

LeNet 是卷积神经网络开山之作,发表于 1989 年。

这个网络本身非常简单,作为本系列的第一篇,就拿这个入手咯。

对应仓库地址:https://github.com/songquanpeng/pytorch-classifiers

网络架构

lenet

PyTorch 实现

https://github.com/songquanpeng/pytorch-classifiers/blob/master/models/LeNet.py

class LeNet5(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        assert args.img_size == 32
        layers = [  # 1x32x32
            nn.Conv2d(args.img_dim, 6, 5, 1),  # 6x28x28
            nn.Tanh(),
            nn.MaxPool2d(2, 2),  # 6x14x14
            nn.Conv2d(6, 16, 5, 1),  # 16x10x10
            nn.Tanh(),
            nn.MaxPool2d(2, 2),  # 16x5x5
        ]
        self.conv = nn.Sequential(*layers)

        layers = [
            nn.Linear(16 * 5 * 5, 84),
            nn.Tanh(),
            nn.Linear(84, args.num_classes),
            # nn.Softmax()  # output logits (-inf, inf)
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        h = self.conv(x)
        y = self.fc(h.view(x.shape[0], -1))
        return y

这里的主要改动在于移除了网络输出层后的 Softmax,进而直接输出 logits。

未经允许,禁止转载,本文源站链接:https://iamazing.cn/