BigGAN Code Analysis I

标签: 代码实现解读 GAN 发布于:2020-10-31 11:49:06 编辑于:2021-09-13 15:49:22 浏览量:1943

utils.py:各种工具函数

  1. 解析命令行参数。
  2. 准备 dataloader 以及配套的图片处理函数。
  3. EMA 的实现。
  4. ortho 的实现。
  5. 用于记录和输出训练信息的 logger。
  6. Adam 的实现。
  7. 其他工具函数:
    1. 切换 grad 状态。
    2. 保存权重。
    3. 加载权重。
    4. 自定义的进度条。
    5. 根据配置信息来给实验命名的函数。
    6. 设置所有的随机数种子。

train.py:训练网络的启动脚本

  1. 补充 config 的内容,例如根据数据库来设置类别个数。
  2. 设置随机数种子。
  3. 准备目录。
  4. 导入模型,并进行构建:G,D 以及 GD。
  5. 检查是否是从之前的存档恢复,如果是,加载相应的模型参数(通过传入的实验名称来找到相应的目录)。
  6. 如果是并行训练,则进行准备。
  7. 准备 logger。
  8. 写入实验的 meta 数据。
  9. 准备 data loader。
  10. 准备用于测量 IS 和 FID 的工具函数。
  11. 准备用于取样 z 和 y 的取样函数。
  12. 准备一组固定的 z 和 y 用于观察在其在训练过程中的变化。
  13. 准备训练函数:train = train_fns.GAN_training_function()
  14. 准备用于给测量 IS 和 FID 的函数提供 z 和 y 的取样函数。
  15. 对于每一个 epoch:
    1. 准备用于实时显示训练进度的 progress bar。
    2. 对于其中的每一个 iteration:
      1. 从 data loader 中取出一个 batch 的训练数据(x,y),以及 batch 序号(i)。
      2. 更新训练状态字典中的 itr 字段(记录迭代次数)。
      3. 训练 G(G.train())。
      4. 训练 D(D.train())。
      5. 如果启用了 EMA,训练 G_ema(G_ema.train())。
      6. 将训练数据加载到目标设备中。
      7. 调用之前准备的训练函数,拿到 metrics。
      8. 调用之前准备的 train logger,输出奇异值(调用 utils.get_SVs 拿到奇异值)。
      9. 检查本次迭代是否应该保存模型权重,如果是,调用 train_fns.save_and_sample()
      10. 检查本次迭代是否应该测试模型性能,如果是,将 G 切换到 eval 模式,调用 train_fns.test()
    3. 更新训练状态字典中的 epoch 字段(记录完成的 epoch 数目)。

BigGAN.py:模型定义

函数 G_arch()

用于返回 arch 字典,key 是 resolution,value 则是对应的具体的模型超参数,有:

  1. 生成器各个 block 的输入通道数:in_channels。
  2. 生成器各个 block 的输出通道数:out_channels。
  3. 每一层是否是上采样:upsample。
  4. 每一层的分辨率:resolution。
  5. 每一层是否使用注意力机制:attention。

Generator(nn.Module)

函数__init__()

  1. 设置通道宽度倍增算子:ch。

  2. 设置隐空间维度个数:dim_z。

  3. 设置最底层的通道尺寸:bottom_width。

  4. 设置生成器输出图片的分辨率:resolution(这里生成的图片只能是方形图片)。

  5. 设置卷积层的卷积核尺寸:kernel_size。

  6. 设置在哪几层应用注意力机制:attention。

  7. 基于之前设置的通道宽度倍增算子 ch 和决定在哪几层使用注意力机制的 attention 以及函数 G_arch()来取得一个存有模型架构的字典:arch。

  8. 设置在基于类别的图片生成中所需要的类别个数:n_classes。

  9. 设置是否启用分层的 latent space:hier

    1. 启用:设置 num_slots 为 arch 中的 in_channels 数目加一,再基于此计算每个 slot 分到的 z_chunk_size,注意这里要保证可以整除。
    2. 禁用:设置 num_slots 为 1,z_chunk_size 为 0。
  10. TODO:设置是否启用跨样本的 Batch Normalization:cross_replica。

  11. 设置是否使用作者提供的 Batch Normalization:mybn。

  12. 设置残差块的激活函数:activation。

  13. 设置参数初始化方式:init。

  14. 设置 parameterization 方式:G_param

    1. 如果是 Spectral Normalization(SN):分别设置 which_conv 和 which_linear 为 layers.SNConv2d 和 layers.SNLinear。
    2. 否则:分别设置 which_conv 和 which_linear 为 nn.Conv2d (用上之前的 kernel_size,并设置 padding 为 1)和 nn.Linear。
  15. 设置卷积层的归一化方式:norm_style。

  16. 设置 Batch Normalization 的参数 ε:BN_eps。

  17. 设置 Spectral Normalization 的参数 ε:SN_eps。

  18. 设置是否使用 16 位的浮点数进行运算(速度更快但需要指定硬件):fp16。

  19. 设置共享 embedding 的维度:shared_dim。

  20. 设置是否启用共享 embedding:G_shared

    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared else self.which_embedding)
    self.which_bn = functools.partial(layers.ccbn,
                                      which_linear=bn_linear,
                                      cross_replica=self.cross_replica,
                                      mybn=self.mybn,
                                      input_size=(self.shared_dim + self.z_chunk_size if self.G_shared else self.n_classes),
                                      norm_style=self.norm_style,
                                      eps=self.BN_eps)
    self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared else layers.identity())	
    
  21. 设置模型开始处的 linear layer:

    self.linear = self.which_linear(self.dim_z // self.num_slots,                 				self.arch['in_channels'][0] * (self.bottom_width**2))
    
  22. 构造 blocks

    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      # 为什么这里要加上一个只含一个元素的列表?因为考虑到可能后面还要加 attention layer。
      self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                             out_channels=self.arch['out_channels'][index],
                             which_conv=self.which_conv,
                             which_bn=self.which_bn,
                             activation=self.activation,
                             upsample=(functools.partial(F.interpolate, scale_factor=2)
                                       if self.arch['upsample'][index] else None))]]
    
      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]
    
    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
    
  23. 构造输出层:

    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                    self.activation,
                                    self.which_conv(self.arch['out_channels'][-1], 3))
    
  24. 是否需要初始化参数(skip_init):

    1. 需要,调用 self.init_weights()
    2. 不需要,不调用上述函数,注意 G_ema 不需要初始化参数。
  25. 是否需要需要设置优化器(no_optime):

    1. 需要,设置 Adam 需要的四个超参数:

      1. 学习率:lr。
      2. β1:B1。
      3. β2:B2。
      4. ε:adam_eps。

      并检查是否 train with half-precision activations but fp32 params in G:

      1. 是,使用 utils.Adam16。
      2. 否,使用 torch.optim.Adam。
    2. 不需要,结束,注意 G_ema 不需要设置优化器。

函数init_weights()

def init_weights(self):
  self.param_count = 0
  for module in self.modules():
    if (isinstance(module, nn.Conv2d) 
        or isinstance(module, nn.Linear) 
        or isinstance(module, nn.Embedding)):
      if self.init == 'ortho':
        init.orthogonal_(module.weight)
      elif self.init == 'N02':
        init.normal_(module.weight, 0, 0.02)
      elif self.init in ['glorot', 'xavier']:
        init.xavier_uniform_(module.weight)
      else:
        print('Init style not recognized...')
      self.param_count += sum([p.data.nelement() for p in module.parameters()])

函数 forward()

  1. 判断是否使用了分层的 latent space(hier):

    1. 是:

      zs = torch.split(z, self.z_chunk_size, 1)
      z = zs[0]
      ys = [torch.cat([y, item], 1) for item in zs[1:]]
      
    2. 否:ys = [y] * len(self.blocks)

  2. 将 z 传入 the first linear layer:h = self.linear(z),并 reshape:h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)

  3. 对 blocks 中每一个 block 进行遍历:

    for index, blocklist in enumerate(self.blocks):
      # Second inner loop in case block has multiple layers
      for block in blocklist:
        h = block(h, ys[index])
    
  4. 将 h 传入到输出层,并对输出应用 torch.tanh,之后输出:torch.tanh(self.output_layer(h))

函数 D_arch()

G_arch() 的区别仅在于 upsample 换成了 downsample。

Discriminator(nn.Module)

函数__init__()

  1. 设置通道宽度倍增算子:ch。

  2. 设置是否使用 wide D:D_wide。

  3. 设置生成器输出图片的分辨率:resolution。

  4. 设置卷积层的卷积核尺寸:kernel_size。

  5. 设置在哪几层应用注意力机制:attention。

  6. 设置类别个数:n_classes。

  7. 设置残差块的激活函数:activation。

  8. 取得存有模型架构的字典:arch。

  9. 设置参数初始化方式:init。

  10. 设置 Spectral Normalization 的参数 ε:SN_eps。

  11. 设置是否使用 16 位的浮点数进行运算(速度更快但需要指定硬件):fp16。

  12. 设置 parameterization 方式:D_param

    if self.D_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                          kernel_size=3, padding=1,
                          num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                          num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                          eps=self.SN_eps)
      self.which_embedding = functools.partial(layers.SNEmbedding,
                              num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                              eps=self.SN_eps)
    # 注意没有 else 哦,但是后面的代码又用到了 which_conv,因此可以推测作者代码没写完,我们这里目前必须使用 SN。
    
  13. 构造 blocks:

    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
                       out_channels=self.arch['out_channels'][index],
                       which_conv=self.which_conv,
                       wide=self.D_wide,
                       activation=self.activation,
                       preactivation=(index > 0),
                       downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                             self.which_conv)]
    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
    
  14. 设置模型结束处的 linear layer:self.which_linear(self.arch['out_channels'][-1], output_dim)

  15. 设置用于 projection discrimination 的 embedding(embed):self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

  16. 是否需要初始化参数(skip_init):

    1. 需要,调用 self.init_weights()
    2. 不需要,不调用上述函数。
  17. 设置优化器:与 Generator 中的一致。

函数init_weights()

与 Generator 中的完全一致。

函数 forward()

  1. 对 blocks 中的每一个 block 进行遍历:

    for index, blocklist in enumerate(self.blocks):
      for block in blocklist:
        h = block(h)
    
  2. 应用 global sum pooling:h = torch.sum(self.activation(h), [2, 3])

  3. 传入到线性层:out = self.linear(h)

  4. Get projection of final featureset onto class vectors and add to evidence:out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)

G_D(nn.Module)

Parallelized G_D to minimize cross-gpu communication

Without this, Generator outputs would get all-gathered and then rebroadcast.

函数 __init__()

  1. 设置 self.G = G。
  2. 设置 self.D = D。

函数 forward()

参数说明:

  1. z:生成器需要的噪声向量。
  2. gy:生成器需要的类别。
  3. x=None:真实样本。
  4. dy=None:真实样本的类别。
  5. train_G=False:是否对 G 进行训练。
  6. return_G_z=False:是否返回 G 生成的样本。
  7. split_D=False:是否是将真实数据和生成数据分开输入到辨别器中。

函数执行流程:

  1. 生成假样本:

    with torch.set_grad_enabled(train_G):
      # Get Generator output given noise
      G_z = self.G(z, self.G.shared(gy))
      # Cast as necessary
      if self.G.fp16 and not self.D.fp16:
        G_z = G_z.float()
      if self.D.fp16 and not self.G.fp16:
        G_z = G_z.half()
    
  2. 是否是将真实数据和生成数据分开输入到辨别器中(split_D):

    1. 是,首先传入 G 生成的数据:D_fake = self.D(G_z, gy),之后检查是否提供了真实样本:

      1. 是,D_real = self.D(x, dy),之后 return D_fake, D_real
      2. 否,检查是否需要将 G 生成的数据也返回:
        1. 是,return D_fake, G_z
        2. 否,return D_fake
    2. 否,将真实样本和生成样本合并之后传入 D 让其辨别:

      D_input = torch.cat([G_z, x], 0) if x is not None else G_z
      D_class = torch.cat([gy, dy], 0) if dy is not None else gy
      # Get Discriminator output
      D_out = self.D(D_input, D_class)
      
    3. 之后检查是否提供了真实样本:

      1. 是,D_real = self.D(x, dy),之后 return torch.split(D_out, [G_z.shape[0], x.shape[0]])(就是 D_fake, D_real)。
      2. 否,检查是否需要将 G 生成的数据也返回:
        1. 是,return D_out, G_z
        2. 否,return D_out

未经允许,禁止转载,本文源站链接:https://iamazing.cn/