StyleGAN 代码分析

标签: 代码实现解读 GAN 发布于:2021-09-23 15:10:01 编辑于:2021-09-23 23:54:02 浏览量:4932

概述

我目前对 TensorFlow 还不是很熟,因此这里看的是非官方实现:https://github.com/rosinality/style-based-gan-pytorch

数据处理

StyleGAN 的训练是多阶段的,不同阶段需要不同尺寸的图片, 因此这里处理时将原始图片的大小调整为:8, 16, 32, 64, 128, 256, 512, 1024,另外为了加快速度读取,格式存为 lmdb 二进制格式。

transforms 里没啥花里胡哨的东西:

transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

训练选项

我们这里列出一些重要的训练选项:

  1. phase:不同训练阶段所用到的样本数量,默认 600k;
  2. lr:学习率,默认 0.001;
  3. sched:是否使用动态调整学习率,README 里对于 FFHQ 用到了这一点,对于 CelebA,则没有;
  4. init_size:初始图片大小,默认为 8;
  5. max_size:最大图片大小,默认为 1024;
  6. mixing:是否使用 mixing 正则化,默认启用;
  7. loss:对抗损失类型,这里可选的有 wgan-gp 和 r1,默认为 wgan-gp;
  8. batch_size:批量大小,默认为 16;
  9. code_size:样式码大小,默认为 512;
  10. n_critic:每更新多少次辨别器更新一次生成器,默认为 1;
  11. n_mlp:用以转换 z 到 w 的多层感知机的网络层数,默认为 8;

训练过程

EMA 版本的生成器

首先,这里构造了 ema 版本的生成器,代码如下:

g_running = StyledGenerator(code_size).cuda()
g_running.train(False)

# 注意,这里初始化时指定 decay 为 0,实际上就相当于把参数复制了过来。
accumulate(g_running, generator.module, 0)

# 后来训练时:
accumulate(g_running, generator.module)


def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(1 - decay, par2[k].data)

优化器设置

优化器初始化:

g_optimizer = optim.Adam(
        generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99)
    )
g_optimizer.add_param_group(
    {
        'params': generator.module.style.parameters(),
        'lr': args.lr * 0.01,
        'mult': 0.01,
    }
)
d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))

注意,生成器中 style 部分(就是那个用以转换 z 的 MLP,这名字起的。。。)的学习率是设置选项中学习率的 0.01 倍。

如果启用了动态调整学习率,不同图像尺寸对应不同的学习率和 batch size:

# 首先是将 lr 和 batch 设置为字典:
if args.sched:
    args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32}
else:
    args.lr = {}
    args.batch = {}

# 之后修改优化器学习率,注意要乘上 mult。
adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))

def adjust_lr(optimizer, lr):
    for group in optimizer.param_groups:
        mult = group.get('mult', 1)
        group['lr'] = lr * mult

alpha 相关

alpha = min(1, 1 / args.phase * (used_sample + 1))

if (resolution == args.init_size and args.ckpt is None) or final_progress:
    alpha = 1

条件分支:if used_sample > args.phase * 2

TODO

辨别器的损失函数

if args.loss == 'wgan-gp':
    real_predict = discriminator(real_image, step=step, alpha=alpha)
    real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
    (-real_predict).backward()
elif args.loss == 'r1':
    real_image.requires_grad = True
    real_scores = discriminator(real_image, step=step, alpha=alpha)
    real_predict = F.softplus(-real_scores).mean()
    real_predict.backward(retain_graph=True)
    
    grad_real = grad(
        outputs=real_scores.sum(), inputs=real_image, create_graph=True
    )[0]
    grad_penalty = (
        grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
    ).mean()
    grad_penalty = 10 / 2 * grad_penalty
    grad_penalty.backward()
fake_image = generator(gen_in1, step=step, alpha=alpha)
fake_predict = discriminator(fake_image, step=step, alpha=alpha)

if args.loss == 'wgan-gp':
    fake_predict = fake_predict.mean()
    fake_predict.backward()

    eps = torch.rand(b_size, 1, 1, 1).cuda()
    x_hat = eps * real_image.data + (1 - eps) * fake_image.data
    x_hat.requires_grad = True
    hat_predict = discriminator(x_hat, step=step, alpha=alpha)
    grad_x_hat = grad(
        outputs=hat_predict.sum(), inputs=x_hat, create_graph=True
    )[0]
    grad_penalty = (
        (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2
    ).mean()
    grad_penalty = 10 * grad_penalty
    grad_penalty.backward()

elif args.loss == 'r1':
    fake_predict = F.softplus(fake_predict).mean()
    fake_predict.backward()

d_optimizer.step()

这里注意各个损失不是加起来后 backward,而是各自 backward,为啥要这样写,不是没区别么,这样写可读性还差。

生成器的损失函数

if (i + 1) % n_critic == 0:
    generator.zero_grad()

    requires_grad(generator, True)
    requires_grad(discriminator, False)

    fake_image = generator(gen_in2, step=step, alpha=alpha)

    predict = discriminator(fake_image, step=step, alpha=alpha)

    if args.loss == 'wgan-gp':
        loss = -predict.mean()

    elif args.loss == 'r1':
        loss = F.softplus(-predict).mean()

    loss.backward()
    g_optimizer.step()
    accumulate(g_running, generator.module)

    requires_grad(generator, False)
    requires_grad(discriminator, True)

也没啥好说的。

模型定义

网络层

PixelNorm

class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)

在通道维度上取均值,用来归一化。

Equalized Learning Rate

class EqualLR:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.data.size(1) * weight.data[0][0].numel()

        return weight * sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)

def equal_lr(module, name='weight'):
    EqualLR.apply(module, name)

    return module
    

class EqualConv2d(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()

        linear = nn.Linear(in_dim, out_dim)
        linear.weight.data.normal_()
        linear.bias.data.zero_()

        self.linear = equal_lr(linear)

    def forward(self, input):
        return self.linear(input)

对网络层应用 equal_lr 后,会注册一个钩子,该钩子对网络层的权重做了一个缩放的操作。

The equalized learning rate involves:

  1. Initializing all weights (linear and conv) from regular normal distribution, no fancy init
  2. Scaling all weights by the per-layer normalization constant from the Kaiming He initialization.

From: https://github.com/lucidrains/stylegan2-pytorch/issues/112#issuecomment-675204245

注意,下面某人评论:On my experience, equalized learning rate is incredibly important for high resolution training (256 and above).

要注意,这里对权重进行了自定义的初始化,因此不要二次使用别的方式进行初始化。

参考:

  1. https://github.com/lucidrains/stylegan2-pytorch/issues/112
  2. https://samaelchen.github.io/pytorch-pggan/

FusedUpsample & FusedDownsample

class FusedUpsample(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, padding=0):
        super().__init__()

        weight = torch.randn(in_channel, out_channel, kernel_size, kernel_size)
        bias = torch.zeros(out_channel)

        fan_in = in_channel * kernel_size * kernel_size
        self.multiplier = sqrt(2 / fan_in)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)

        self.pad = padding

    def forward(self, input):
        weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
        weight = (
            weight[:, :, 1:, 1:]
            + weight[:, :, :-1, 1:]
            + weight[:, :, 1:, :-1]
            + weight[:, :, :-1, :-1]
        ) / 4

        out = F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad)

        return out


class FusedDownsample(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, padding=0):
        super().__init__()

        weight = torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        bias = torch.zeros(out_channel)

        fan_in = in_channel * kernel_size * kernel_size
        self.multiplier = sqrt(2 / fan_in)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)

        self.pad = padding

    def forward(self, input):
        weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
        weight = (
            weight[:, :, 1:, 1:]
            + weight[:, :, :-1, 1:]
            + weight[:, :, 1:, :-1]
            + weight[:, :, :-1, :-1]
        ) / 4

        out = F.conv2d(input, weight, self.bias, stride=2, padding=self.pad)

        return out

其中上采样是通过反卷积实现的:F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad), 下采样则是通过普通的卷积实现的:out = F.conv2d(input, weight, self.bias, stride=2, padding=self.pad)

具体这个所谓的 fused 究竟是干嘛的,我目前还不清楚。

AdaptiveInstanceNorm

class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, in_channel, style_dim):
        super().__init__()

        self.norm = nn.InstanceNorm2d(in_channel)
        self.style = EqualLinear(style_dim, in_channel * 2)

        self.style.linear.bias.data[:in_channel] = 1
        self.style.linear.bias.data[in_channel:] = 0

    def forward(self, input, style):
        style = self.style(style).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, 1)

        out = self.norm(input)
        out = gamma * out + beta

        return out

核心部分,处理流程:

  1. 首先使用一个 EqualLinear 线性层处理样式码,形状为 (B, in_channel * 2);
  2. unsqueeze 到 (B, in_channel * 2, 1, 1);
  3. 之后沿维度 1 分割成两块,分别作为 gamma 和 beta;
  4. 使用一个标准的 nn.InstanceNorm2d 对输入进行处理(默认 affine 为 False,详见此处);
  5. 之后再使用第三步得到 gamma 和 beta 对其进行缩放和偏置。

这里和 StarGANv2 的 AdaIN 实现是一致的:

class AdaIN(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1, 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta

NoiseInjection

class NoiseInjection(nn.Module):
    def __init__(self, channel):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        return image + self.weight * noise

拿来注入噪声,注意该层的权重,仅在通道维度上不同,也就说对于任意图片的第 n 个特征图上的所有值,其用到的权重是相同的一个值。

ConvBlock

class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        padding,
        kernel_size2=None,
        padding2=None,
        downsample=False,
        fused=False,
    ):
        super().__init__()

        pad1 = padding
        pad2 = padding
        if padding2 is not None:
            pad2 = padding2

        kernel1 = kernel_size
        kernel2 = kernel_size
        if kernel_size2 is not None:
            kernel2 = kernel_size2

        self.conv1 = nn.Sequential(
            EqualConv2d(in_channel, out_channel, kernel1, padding=pad1),
            nn.LeakyReLU(0.2),
        )

        if downsample:
            if fused:
                self.conv2 = nn.Sequential(
                    Blur(out_channel),
                    FusedDownsample(out_channel, out_channel, kernel2, padding=pad2),
                    nn.LeakyReLU(0.2),
                )

            else:
                self.conv2 = nn.Sequential(
                    Blur(out_channel),
                    EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
                    nn.AvgPool2d(2),
                    nn.LeakyReLU(0.2),
                )

        else:
            self.conv2 = nn.Sequential(
                EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
                nn.LeakyReLU(0.2),
            )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        return out

用于辨别器的卷积块,主体就是两个卷积层,卷积层应用了上面提到的 Equalized learning rate 技巧。 由于是辨别器,牵涉到降采样,这里提供了两种方式:

  1. fused 方式:详见上文;
  2. pooling 方式(默认方式):这种方式就是在卷积层后加一个 kernel size 为 2 的平均池化层:nn.AvgPool2d(2)

如果需要降采样的话,注意是第二个卷积层负责。

StyledConvBlock

class StyledConvBlock(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=3,
        padding=1,
        style_dim=512,
        initial=False,
        upsample=False,
        fused=False,
    ):
        super().__init__()

        if initial:
            self.conv1 = ConstantInput(in_channel)

        else:
            if upsample:
                if fused:
                    self.conv1 = nn.Sequential(
                        FusedUpsample(
                            in_channel, out_channel, kernel_size, padding=padding
                        ),
                        Blur(out_channel),
                    )

                else:
                    self.conv1 = nn.Sequential(
                        nn.Upsample(scale_factor=2, mode='nearest'),
                        EqualConv2d(
                            in_channel, out_channel, kernel_size, padding=padding
                        ),
                        Blur(out_channel),
                    )

            else:
                self.conv1 = EqualConv2d(
                    in_channel, out_channel, kernel_size, padding=padding
                )

        self.noise1 = equal_lr(NoiseInjection(out_channel))
        self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu1 = nn.LeakyReLU(0.2)

        self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding)
        self.noise2 = equal_lr(NoiseInjection(out_channel))
        self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu2 = nn.LeakyReLU(0.2)

    def forward(self, input, style, noise):
        out = self.conv1(input)
        out = self.noise1(out, noise)
        out = self.lrelu1(out)
        out = self.adain1(out, style)

        out = self.conv2(out)
        out = self.noise2(out, noise)
        out = self.lrelu2(out)
        out = self.adain2(out, style)

        return out

主要架构:

  1. 第一个卷积层,如果需要上采样的话,由该层负责;
  2. 第一个噪声注入层,拿来注入随机因素;
  3. LeakyReLU 激活;
  4. 第一个自适应实例归一化层,拿来注入样式;
  5. 第二个卷积层;
  6. 第二个噪声注入层;
  7. LeakyReLU 激活;
  8. 第二个自适应实例归一化层;

注意这个顺序,按我的理解是,噪声注入层需要在卷积层生成的特征图上进行噪声注入,之后进行激活, 再之后进行自适应实例归一化。

注意,这里和 StarGANv2 中的 AdainResBlk 的实现有点区别,其实现如下:

x = self.norm1(x, s)
x = self.actv(x)
if self.upsample:
    x = F.interpolate(x, scale_factor=2, mode='nearest')
x = self.conv1(x)
x = self.norm2(x, s)
x = self.actv(x)
x = self.conv2(x)
return x

也就是这里的顺序是先归一化,之后激活,再之后才是卷积层。

生成器

先放实现:

class Generator(nn.Module):
    def __init__(self, code_dim, fused=True):
        super().__init__()

        self.progression = nn.ModuleList(
            [
                StyledConvBlock(512, 512, 3, 1, initial=True),  # 4
                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 8
                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 16
                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 32
                StyledConvBlock(512, 256, 3, 1, upsample=True),  # 64
                StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused),  # 128
                StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused),  # 256
                StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused),  # 512
                StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused),  # 1024
            ]
        )

        self.to_rgb = nn.ModuleList(
            [
                EqualConv2d(512, 3, 1),
                EqualConv2d(512, 3, 1),
                EqualConv2d(512, 3, 1),
                EqualConv2d(512, 3, 1),
                EqualConv2d(256, 3, 1),
                EqualConv2d(128, 3, 1),
                EqualConv2d(64, 3, 1),
                EqualConv2d(32, 3, 1),
                EqualConv2d(16, 3, 1),
            ]
        )

        # self.blur = Blur()

    def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1)):
        out = noise[0]

        if len(style) < 2:
            inject_index = [len(self.progression) + 1]

        else:
            inject_index = sorted(random.sample(list(range(step)), len(style) - 1))

        crossover = 0

        for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):
            if mixing_range == (-1, -1):
                if crossover < len(inject_index) and i > inject_index[crossover]:
                    crossover = min(crossover + 1, len(style))

                style_step = style[crossover]

            else:
                if mixing_range[0] <= i <= mixing_range[1]:
                    style_step = style[1]

                else:
                    style_step = style[0]

            if i > 0 and step > 0:
                out_prev = out
                
            out = conv(out, style_step, noise[i])

            if i == step:
                out = to_rgb(out)

                if i > 0 and 0 <= alpha < 1:
                    skip_rgb = self.to_rgb[i - 1](out_prev)
                    skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest')
                    out = (1 - alpha) * skip_rgb + alpha * out

                break

        return out

网络主体是 9 个样式卷积块(分别对应 2^2 ~ 2^10,一共 9 个),通道数最大 512。 to_rgb 层则是针对每一个分辨率配一个卷积层,注意后面不接激活函数。

与 MLP 合并后的生成器

class StyledGenerator(nn.Module):
    def __init__(self, code_dim=512, n_mlp=8):
        super().__init__()

        self.generator = Generator(code_dim)

        layers = [PixelNorm()]
        for i in range(n_mlp):
            layers.append(EqualLinear(code_dim, code_dim))
            layers.append(nn.LeakyReLU(0.2))

        self.style = nn.Sequential(*layers)

    def forward(
        self,
        input,
        noise=None,
        step=0,
        alpha=-1,
        mean_style=None,
        style_weight=0,
        mixing_range=(-1, -1),
    ):
        styles = []
        if type(input) not in (list, tuple):
            input = [input]

        for i in input:
            styles.append(self.style(i))

        batch = input[0].shape[0]

        if noise is None:
            noise = []

            for i in range(step + 1):
                size = 4 * 2 ** i
                noise.append(torch.randn(batch, 1, size, size, device=input[0].device))

        if mean_style is not None:
            styles_norm = []

            for style in styles:
                styles_norm.append(mean_style + style_weight * (style - mean_style))

            styles = styles_norm

        return self.generator(styles, noise, step, alpha, mixing_range=mixing_range)

    def mean_style(self, input):
        style = self.style(input).mean(0, keepdim=True)

        return style

首先注意到这个 MLP 输入之前做了 PixelNorm,之后 6 层线性层配 LeakyReLU(0.2)。

除此之外就那个 mean_style 还值得说一下,就是对生成的样式码做了一下处理:

# 生成 mean_style:
mean_style = None
for i in range(10):
    style = generator.mean_style(torch.randn(1024, 512).to(device))

    if mean_style is None:
        mean_style = style

    else:
        mean_style += style
mean_style /= 10

# 其中 generator.mean_style 方法的定义:
def mean_style(self, input):
    style = self.style(input).mean(0, keepdim=True)
    return style

# 最后使用时:
styles_norm.append(mean_style + style_weight * (style - mean_style))

要注意这个 mean_style 是启用了的。

辨别器

class Discriminator(nn.Module):
    def __init__(self, fused=True, from_rgb_activate=False):
        super().__init__()

        self.progression = nn.ModuleList(
            [
                ConvBlock(16, 32, 3, 1, downsample=True, fused=fused),  # 512
                ConvBlock(32, 64, 3, 1, downsample=True, fused=fused),  # 256
                ConvBlock(64, 128, 3, 1, downsample=True, fused=fused),  # 128
                ConvBlock(128, 256, 3, 1, downsample=True, fused=fused),  # 64
                ConvBlock(256, 512, 3, 1, downsample=True),  # 32
                ConvBlock(512, 512, 3, 1, downsample=True),  # 16
                ConvBlock(512, 512, 3, 1, downsample=True),  # 8
                ConvBlock(512, 512, 3, 1, downsample=True),  # 4
                ConvBlock(513, 512, 3, 1, 4, 0),
            ]
        )

        def make_from_rgb(out_channel):
            if from_rgb_activate:
                return nn.Sequential(EqualConv2d(3, out_channel, 1), nn.LeakyReLU(0.2))

            else:
                return EqualConv2d(3, out_channel, 1)

        self.from_rgb = nn.ModuleList(
            [
                make_from_rgb(16),
                make_from_rgb(32),
                make_from_rgb(64),
                make_from_rgb(128),
                make_from_rgb(256),
                make_from_rgb(512),
                make_from_rgb(512),
                make_from_rgb(512),
                make_from_rgb(512),
            ]
        )

        # self.blur = Blur()

        self.n_layer = len(self.progression)

        self.linear = EqualLinear(512, 1)

    def forward(self, input, step=0, alpha=-1):
        for i in range(step, -1, -1):
            index = self.n_layer - i - 1

            if i == step:
                out = self.from_rgb[index](input)

            if i == 0:
                out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
                mean_std = out_std.mean()
                mean_std = mean_std.expand(out.size(0), 1, 4, 4)
                out = torch.cat([out, mean_std], 1)

            out = self.progression[index](out)

            if i > 0:
                if i == step and 0 <= alpha < 1:
                    skip_rgb = F.avg_pool2d(input, 2)
                    skip_rgb = self.from_rgb[index + 1](skip_rgb)

                    out = (1 - alpha) * skip_rgb + alpha * out

        out = out.squeeze(2).squeeze(2)
        # print(input.size(), out.size(), step)
        out = self.linear(out)

        return out

网络主体也是 9 个卷积块,然后 from_rgb 也是每个分辨率一个卷积层,注意后面要接激活。

要注意,网络主体的最后一个卷积块的输入维度是 513,这是因为最后一层输入时,并上了特征图的平均标准差:

out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
mean_std = out_std.mean()
mean_std = mean_std.expand(out.size(0), 1, 4, 4)
out = torch.cat([out, mean_std], 1)

在最后,用了一个线性层做最后的输出:

self.linear = EqualLinear(512, 1)

out = self.linear(out)

吐槽环节

  1. 多处硬编码 option;
  2. if 和 else 之间莫名其妙加空行。
  3. 保存权重时保存成单个文件,这样不利于权重文件的转移。
  4. 损失函数 backward 时没有先加到一起再 backward,可读性差。

未经允许,禁止转载,本文源站链接:https://iamazing.cn/