TAC-GAN Code Analysis
概述
我们这里分析的代码是基于 BigGAN 的版本,https://github.com/batmanlab/twin-auxiliary-classifiers-gan。
考虑到很多内容与原始的 BigGAN 的 PyTorch 实现一致,因此这里着重讲差别之处,相同的地方将一笔带过。
train.py
完全一致,除了多了几个没有用的 import 以及删掉了一个训练日志的输出。
BigGAN.py
生成器的架构与 BigGAN 中的完全一致。
class Discriminator(nn.Module)
def __init__()
多了一个
AC=False
的参数。为分类器提供线性层:
if self.AC: self.linear_mi = self.which_linear(self.arch['out_channels'][-1], n_classes) self.linear_c = self.which_linear(self.arch['out_channels'][-1], n_classes)
def choose_prob(self, prob, y)
TODO:暂时不清楚用途,L386
def forward(self, x, y=None)
原本为:
out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
换为:
if self.AC:
out_mi = self.linear_mi(h)
out_c = self.linear_c(h)
return out, out_mi, out_c
else:
proj_c = self.linear_c(h)
if y is not None:
proj = self.choose_prob(proj_c, y)
out += proj
return out,proj_c,proj_c
注意这两个分类器仅仅是加了一层线性层构造而来。
TODO:下面的 proj_c 估计是 Projection cGAN 所用到的分类器。
class G_D(nn.Module)
相比 BigGAN 中的辨别器,现在我们的辨别器模型会同时输出两个分类器的预测值,因此这里做了一些适配的改动:
if split_D:
D_fake, mi_f, c_cls_f = self.D(G_z, gy)
if x is not None:
D_real,mi_r, c_cls_r = self.D(x, dy)
return D_fake, D_real, torch.cat([mi_f,mi_r],dim=0), torch.cat([c_cls_f,c_cls_r],dim=0)
else:
if return_G_z:
return D_fake, G_z
else:
return D_fake, mi_f, c_cls_f
# If real data is provided, concatenate it with the Generator's output
# along the batch dimension for improved efficiency.
else:
D_input = torch.cat([G_z, x], 0) if x is not None else G_z
D_class = torch.cat([gy, dy], 0) if dy is not None else gy
# Get Discriminator output
D_out, mi, cls = self.D(D_input, D_class)
if x is not None:
return D_out[:G_z.shape[0]], D_out[G_z.shape[0]:],mi, cls # D_fake, D_real
else:
if return_G_z:
return D_out, G_z, mi, cls
else:
return D_out, mi, cls
其余与原版一致。
train_fns.py
def hinge_multi(prob,y,hinge=True)
prob 是预测的图片类别,y 是真实的图片类别,该函数即合页损失函数,当 hinge=False 时将不再限制 loss 必须大于 0 。
def hinge_multi(prob,y,hinge=True):
len = prob.size()[0]
index_list = [[],[]]
for i in range(len):
index_list[0].append(i)
index_list[1].append(np.asscalar(y[i].cpu().detach().numpy()))
prob_choose = prob[index_list]
prob_choose = (prob_choose.squeeze()).unsqueeze(dim=1)
if hinge == True:
loss = ((1-prob_choose+prob).clamp(min=0)).mean()
else:
loss = (1-prob_choose+prob).mean()
return loss
def GAN_training_function()
该函数依据不同的参数选择不同的损失函数计算 loss,并返回以下 loss:
- G_loss
- D_loss_real
- D_loss_fake
- C_loss
- MI_loss
utils.py
- 在 prepare_parser 中加入了以下项:
- loss_type,默认值为 Twin_AC,可选项有:
- Twin_AC:使用交叉熵损失函数(对于 AC 和 TAC)。
- Twin_AC_M:使用 hinge 损失函数(对于 AC 和 TAC)。
- AC:使用交叉熵损失函数(仅对于 AC)。
- AC,默认值为 False:是否启用辅助的分类器。
- AC_weight,默认值为 1.0:这个用来调节分类器所带来的损失的比重。
- num_G_steps,默认值为 2:每次训练时 G 训练的次数。
- num_MI_steps,默认值为 2:并没有被用到。
- num_MI_accumulations,默认值为 1:并没有被用到。
- loss_type,默认值为 Twin_AC,可选项有:
- 添加了对数据集 VGGFACE 的支持。
Links: TAC-GAN-Code-Analysis