TAC-GAN Code Analysis

标签: 代码实现解读 GAN 发布于:2020-11-07 22:01:42 编辑于:2021-09-13 15:49:12 浏览量:1822

概述

我们这里分析的代码是基于 BigGAN 的版本,https://github.com/batmanlab/twin-auxiliary-classifiers-gan。

考虑到很多内容与原始的 BigGAN 的 PyTorch 实现一致,因此这里着重讲差别之处,相同的地方将一笔带过。

train.py

完全一致,除了多了几个没有用的 import 以及删掉了一个训练日志的输出。

BigGAN.py

生成器的架构与 BigGAN 中的完全一致。

class Discriminator(nn.Module)

def __init__()

  1. 多了一个 AC=False 的参数。

  2. 为分类器提供线性层:

    if self.AC:
        self.linear_mi = self.which_linear(self.arch['out_channels'][-1], n_classes)
    self.linear_c = self.which_linear(self.arch['out_channels'][-1], n_classes)
    

def choose_prob(self, prob, y)

TODO:暂时不清楚用途,L386

def forward(self, x, y=None)

原本为:

out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)

换为:

if self.AC:
    out_mi = self.linear_mi(h)
    out_c = self.linear_c(h)
    return out, out_mi, out_c
else:
    proj_c = self.linear_c(h)
    if y is not None:
        proj = self.choose_prob(proj_c, y)
        out += proj
        return out,proj_c,proj_c

注意这两个分类器仅仅是加了一层线性层构造而来。

TODO:下面的 proj_c 估计是 Projection cGAN 所用到的分类器。

class G_D(nn.Module)

相比 BigGAN 中的辨别器,现在我们的辨别器模型会同时输出两个分类器的预测值,因此这里做了一些适配的改动:

    if split_D:
      D_fake, mi_f, c_cls_f = self.D(G_z, gy)
      if x is not None:
        D_real,mi_r, c_cls_r = self.D(x, dy)
        return D_fake, D_real, torch.cat([mi_f,mi_r],dim=0), torch.cat([c_cls_f,c_cls_r],dim=0)
      else:
        if return_G_z:
          return D_fake, G_z
        else:
          return D_fake, mi_f, c_cls_f
    # If real data is provided, concatenate it with the Generator's output
    # along the batch dimension for improved efficiency.
    else:
      D_input = torch.cat([G_z, x], 0) if x is not None else G_z
      D_class = torch.cat([gy, dy], 0) if dy is not None else gy
      # Get Discriminator output
      D_out, mi, cls = self.D(D_input, D_class)
      if x is not None:
        return D_out[:G_z.shape[0]], D_out[G_z.shape[0]:],mi, cls # D_fake, D_real
      else:
        if return_G_z:
          return D_out, G_z, mi, cls
        else:
          return D_out, mi, cls

其余与原版一致。

train_fns.py

def hinge_multi(prob,y,hinge=True)

prob 是预测的图片类别,y 是真实的图片类别,该函数即合页损失函数,当 hinge=False 时将不再限制 loss 必须大于 0 。

def hinge_multi(prob,y,hinge=True):

    len = prob.size()[0]

    index_list = [[],[]]

    for i in range(len):
        index_list[0].append(i)
        index_list[1].append(np.asscalar(y[i].cpu().detach().numpy()))

    prob_choose = prob[index_list]
    prob_choose = (prob_choose.squeeze()).unsqueeze(dim=1)

    if hinge == True:
        loss = ((1-prob_choose+prob).clamp(min=0)).mean()
    else:
        loss = (1-prob_choose+prob).mean()

    return loss

def GAN_training_function()

该函数依据不同的参数选择不同的损失函数计算 loss,并返回以下 loss:

  1. G_loss
  2. D_loss_real
  3. D_loss_fake
  4. C_loss
  5. MI_loss

utils.py

  1. 在 prepare_parser 中加入了以下项:
    1. loss_type,默认值为 Twin_AC,可选项有:
      1. Twin_AC:使用交叉熵损失函数(对于 AC 和 TAC)。
      2. Twin_AC_M:使用 hinge 损失函数(对于 AC 和 TAC)。
      3. AC:使用交叉熵损失函数(仅对于 AC)。
    2. AC,默认值为 False:是否启用辅助的分类器。
    3. AC_weight,默认值为 1.0:这个用来调节分类器所带来的损失的比重。
    4. num_G_steps,默认值为 2:每次训练时 G 训练的次数。
    5. num_MI_steps,默认值为 2:并没有被用到。
    6. num_MI_accumulations,默认值为 1:并没有被用到。
  2. 添加了对数据集 VGGFACE 的支持。

未经允许,禁止转载,本文源站链接:https://iamazing.cn/