BigGAN Code Analysis II

标签: 代码实现解读 GAN 发布于:2020-11-06 20:42:31 编辑于:2021-09-13 15:49:33 浏览量:2115

train_fns.py:模型训练函数,测试函数以及保存取样函数

函数 dummy_training_function()

用于 debug。

def dummy_training_function():
  def train(x, y):
    return {}
  return train

函数GAN_training_function()

参数解释:

  1. G:生成器模型。
  2. D:辨别器模型。
  3. GD:G 和 D 的联合模型。
  4. z_:取样器。
  5. y_:取样器。
  6. ema:一个 utils.ema 对象。
  7. state_dict:存储本次训练的元信息的字典。
  8. config:从命令行参数解析来的字典对象。

内部定义并返回了一个 train 函数:train(x, y):

  1. 首先把 G 和 D 的参数的梯度清零:

    G.optim.zero_grad()
    D.optim.zero_grad()
    
  2. 把输入样本及其标签分割成一些列 batch:

    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    
  3. 设置 counter=0,其记录了当前用到的 batch 序号。

  4. 如果设置了 toggle_grads 为 True(默认启用),则:

    utils.toggle_grad(D, True)
    utils.toggle_grad(G, False)
    
  5. 训练辨别器 num_D_steps 次:

    1. 清零 D 的参数梯度:D.optim.zero_grad()

    2. 累加 num_D_accumulations 次梯度:

      1. (我现在也不清楚为啥这个能 work)

        z_.sample_()
        y_.sample_()
        
      2. 传入 GD,拿到 D_fake (D 对生成样本的预测)以及 D_real(D 对真实样本的预测):

        D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']], 
                            x[counter], y[counter], train_G=False, 
                            split_D=config['split_D'])
        
      3. 传入辨别器的损失函数拿到 loss:

        D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
        
      4. 之后合并两个损失,并除以 num_D_accumulations:

        D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations'])
        
      5. 调用 backward 计算梯度:D_loss.backward()

      6. 更新 batch 计数器:counter += 1

    3. 如果 D_ortho > 0,则应用 Orthogonal Regularization:utils.ortho(D, config['D_ortho'])

    4. 调用优化器的 step 更新辨别器模型参数:D.optim.step()

  6. 如果设置了 toggle_grads 为 True(默认启用),则:

    utils.toggle_grad(D, False)
    utils.toggle_grad(G, True)
    
  7. 清零 G 的梯度(保险起见):G.optim.zero_grad()

  8. 累加 num_G_accumulations 次梯度:

    1. 取样:

      z_.sample_()
      y_.sample_()
      
    2. 传入 GD,拿到 D_fake:GD(z_, y_, train_G=True, split_D=config['split_D'])

    3. 传入生成器的损失函数拿到 loss:G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])

    4. 调用 backward 计算梯度:G_loss.backward()

  9. 如果 D_ortho > 0,则应用 Orthogonal Regularization:utils.ortho(D, config['D_ortho'])

  10. 调用优化器的 step 更新生成器模型参数:G.optim.step()

  11. 如果我们启用了 ema,我们需要更新:ema.update(state_dict['itr'])

  12. 输出是一个字典,其包含了辨别器分别在真实样本 & 生成样本上的损失以及生成器的损失:

    out = {'G_loss': float(G_loss.item()), 
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item())}
    

函数 save_and_sample()

  1. 调用函数 utils.save_weights 保存以下内容:
    1. 生成器参数。
    2. 生成器的优化器的参数。
    3. 辨别器的参数。
    4. 辨别器的优化器的参数。
    5. EMA 版本的生成器的参数。
    6. 训练的元信息。
  2. 生成固定 z 和 y 的图片供对比。
  3. 生成随机取样的图片。
  4. 生成三种设置的插值图片。

函数 test()

  1. 调用传入的 get_inception_metrics 函数计算 IS 和 FID。
  2. 如果测试时发现模型表现比之前最好的还要好,就保存此时的权重,注意保存多个历史版本,并更新当前的最佳值。
  3. 并输出相应的 log。

layers.py:定义了 BigGAN 中所需要的众多网络层

class SN(object)

  1. 使用 power_iteration 方法估算 w 的最大的奇异值。
  2. self.weight / svs[0]

class SNConv2d(nn.Conv2d, SN)

  1. 通过多继承的方式来构造我们的启用了谱归一化的二维卷积层。

  2. 其 forward 函数封装了 F.conv2d:

    def forward(self, x):
      return F.conv2d(x, self.W_(), self.bias, self.stride, 
                      self.padding, self.dilation, self.groups)
    

class SNLinear(nn.Linear, SN),class SNEmbedding(nn.Embedding, SN)

  1. 同上。
  2. 注意 nn.Embedding 就是一个简单的查找表。

class Attention(nn.Module)

  1. 初始化四个卷积层:

    self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
    self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
    self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
    self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
    
  2. 初始化一个可训练的超参数:self.gamma = P(torch.tensor(0.), requires_grad=True)

  3. 使用卷积层处理 feature map:

    theta = self.theta(x)
    phi = F.max_pool2d(self.phi(x), [2,2])
    g = F.max_pool2d(self.g(x), [2,2])
    
  4. 计算 attention map:beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)

  5. 计算 self-attention feature map:o = self.o(torch.bmm(g, beta.transpose(1,2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))

  6. 最后乘与 gamma 并加上原始的 feature map:return self.gamma * o + x

def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5)

Fused Batch Normalization

def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5)

TODO

class myBN(nn.Module)

作者自己实现的 Batch Normalization,支持统计信息。

class ccbn(nn.Module)

Class-conditional Batch Normalization。

class bn(nn.Module)

Normal, non-class-conditional Batch Normalization。

class GBlock(nn.Module)

生成器的 Block。

def forward(self, x, y):
  h = self.activation(self.bn1(x, y))
  if self.upsample:
    h = self.upsample(h)
    x = self.upsample(x)
  h = self.conv1(h)
  h = self.activation(self.bn2(h, y))
  h = self.conv2(h)
  if self.learnable_sc:       
    x = self.conv_sc(x)
  return h + x

class DBlock(nn.Module)

辨别器的残差块。

def shortcut(self, x):
  if self.preactivation:
    if self.learnable_sc:
      x = self.conv_sc(x)
    if self.downsample:
      x = self.downsample(x)
  else:
    if self.downsample:
      x = self.downsample(x)
    if self.learnable_sc:
      x = self.conv_sc(x)
  return x
def forward(self, x):
  if self.preactivation:
    # h = self.activation(x) # NOT TODAY SATAN
    # Andy's note: This line *must* be an out-of-place ReLU or it 
    #              will negatively affect the shortcut connection.
    h = F.relu(x)
  else:
    h = x    
  h = self.conv1(h)
  h = self.conv2(self.activation(h))
  if self.downsample:
    h = self.downsample(h)     
      
  return h + self.shortcut(x)

dataset.py:对数据集进行封装

def is_image_file(filename)

通过判断文件的后缀来判断其是否是图片(首先要转换为全小写字母)。

def find_classes(dir)

遍历目标文件夹中的所有子文件夹,并为其分配一个序号,最后构造一个 key 为子文件夹名字,value 为对应的序号的字典输出。

def make_dataset(dir, class_to_idx)

遍历文件夹中的子文件夹中的所有图片文件,拿到其路径,外加相应的类别标签,作为一个元组输出。

class ImageFolder(data.Dataset)

def __init__()

  1. 首先检查之前保存的 npz 文件是否存在:
    1. 存在,加载该文件中的 numpy 数组。
    2. 不存在,调用 make_dataset 函数,加载好图片数组后将其保存为 npz 文件。
  2. 如果启用了 load_in_mem,就把所有的图片加载到内存中(之前的数据里面只有 path),注意要使用 self.transform 函数进行转换。

def __getitem__()

  1. 检查是否启用了 load_in_mem:
    1. 是,直接获取对应的图片及其标签。
    2. 否,加载图片并应用 self.transform。
  2. 如果 target_transform 函数不是 None,则用该函数对标签进行处理。
  3. 最后返回一个元组:(图片,标签)。

def __len__()

直接返回 imgs 的长度。

class ILSVRC_HDF5(data.Dataset)

与前者类似,主要区别在于这里使用了 h5。

class CIFAR10(torchvision.datasets.CIFAR10)

较为类似,这里仅列出区别点:

  1. 如果 download 为 True,则调用基类的 download 函数。

  2. 调用基类的 _check_integrity() 函数检查完整性。

  3. 使用 pickle 从文件中加载 numpy 数组。

  4. 之后进行以下处理:

    self.data = self.data.reshape((10000, 3, 32, 32))
    self.data = self.data.transpose((0, 2, 3, 1)) 
    

class CIFAR100(CIFAR10)

继承了 CIFAR10,改了一些元信息,有:

  1. 文件夹名称。
  2. 下载路径。
  3. 下载文件名。
  4. 下载文件的 md5.
  5. 还有一个 train_list 以及 test_list,里面存有子目录的名称。

未经允许,禁止转载,本文源站链接:https://iamazing.cn/