BigGAN Code Analysis I
utils.py:各种工具函数
- 解析命令行参数。
- 准备 dataloader 以及配套的图片处理函数。
- EMA 的实现。
- ortho 的实现。
- 用于记录和输出训练信息的 logger。
- Adam 的实现。
- 其他工具函数:
- 切换 grad 状态。
- 保存权重。
- 加载权重。
- 自定义的进度条。
- 根据配置信息来给实验命名的函数。
- 设置所有的随机数种子。
train.py:训练网络的启动脚本
- 补充 config 的内容,例如根据数据库来设置类别个数。
- 设置随机数种子。
- 准备目录。
- 导入模型,并进行构建:G,D 以及 GD。
- 检查是否是从之前的存档恢复,如果是,加载相应的模型参数(通过传入的实验名称来找到相应的目录)。
- 如果是并行训练,则进行准备。
- 准备 logger。
- 写入实验的 meta 数据。
- 准备 data loader。
- 准备用于测量 IS 和 FID 的工具函数。
- 准备用于取样 z 和 y 的取样函数。
- 准备一组固定的 z 和 y 用于观察在其在训练过程中的变化。
- 准备训练函数:
train = train_fns.GAN_training_function()
。 - 准备用于给测量 IS 和 FID 的函数提供 z 和 y 的取样函数。
- 对于每一个 epoch:
- 准备用于实时显示训练进度的 progress bar。
- 对于其中的每一个 iteration:
- 从 data loader 中取出一个 batch 的训练数据(x,y),以及 batch 序号(i)。
- 更新训练状态字典中的 itr 字段(记录迭代次数)。
- 训练 G(
G.train()
)。 - 训练 D(
D.train()
)。 - 如果启用了 EMA,训练 G_ema(
G_ema.train()
)。 - 将训练数据加载到目标设备中。
- 调用之前准备的训练函数,拿到 metrics。
- 调用之前准备的 train logger,输出奇异值(调用
utils.get_SVs
拿到奇异值)。 - 检查本次迭代是否应该保存模型权重,如果是,调用
train_fns.save_and_sample()
。 - 检查本次迭代是否应该测试模型性能,如果是,将 G 切换到 eval 模式,调用
train_fns.test()
。
- 更新训练状态字典中的 epoch 字段(记录完成的 epoch 数目)。
BigGAN.py:模型定义
函数 G_arch()
用于返回 arch 字典,key 是 resolution,value 则是对应的具体的模型超参数,有:
- 生成器各个 block 的输入通道数:in_channels。
- 生成器各个 block 的输出通道数:out_channels。
- 每一层是否是上采样:upsample。
- 每一层的分辨率:resolution。
- 每一层是否使用注意力机制:attention。
类 Generator(nn.Module)
函数__init__()
设置通道宽度倍增算子:ch。
设置隐空间维度个数:dim_z。
设置最底层的通道尺寸:bottom_width。
设置生成器输出图片的分辨率:resolution(这里生成的图片只能是方形图片)。
设置卷积层的卷积核尺寸:kernel_size。
设置在哪几层应用注意力机制:attention。
基于之前设置的通道宽度倍增算子 ch 和决定在哪几层使用注意力机制的 attention 以及函数
G_arch()
来取得一个存有模型架构的字典:arch。设置在基于类别的图片生成中所需要的类别个数:n_classes。
设置是否启用分层的 latent space:hier
- 启用:设置 num_slots 为 arch 中的 in_channels 数目加一,再基于此计算每个 slot 分到的 z_chunk_size,注意这里要保证可以整除。
- 禁用:设置 num_slots 为 1,z_chunk_size 为 0。
TODO:设置是否启用跨样本的 Batch Normalization:cross_replica。
设置是否使用作者提供的 Batch Normalization:mybn。
设置残差块的激活函数:activation。
设置参数初始化方式:init。
设置 parameterization 方式:G_param
- 如果是 Spectral Normalization(SN):分别设置 which_conv 和 which_linear 为 layers.SNConv2d 和 layers.SNLinear。
- 否则:分别设置 which_conv 和 which_linear 为 nn.Conv2d (用上之前的 kernel_size,并设置 padding 为 1)和 nn.Linear。
设置卷积层的归一化方式:norm_style。
设置 Batch Normalization 的参数 ε:BN_eps。
设置 Spectral Normalization 的参数 ε:SN_eps。
设置是否使用 16 位的浮点数进行运算(速度更快但需要指定硬件):fp16。
设置共享 embedding 的维度:shared_dim。
设置是否启用共享 embedding:G_shared
self.which_embedding = nn.Embedding bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared else self.which_embedding) self.which_bn = functools.partial(layers.ccbn, which_linear=bn_linear, cross_replica=self.cross_replica, mybn=self.mybn, input_size=(self.shared_dim + self.z_chunk_size if self.G_shared else self.n_classes), norm_style=self.norm_style, eps=self.BN_eps) self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared else layers.identity())
设置模型开始处的 linear layer:
self.linear = self.which_linear(self.dim_z // self.num_slots, self.arch['in_channels'][0] * (self.bottom_width**2))
构造 blocks
self.blocks = [] for index in range(len(self.arch['out_channels'])): # 为什么这里要加上一个只含一个元素的列表?因为考虑到可能后面还要加 attention layer。 self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index], out_channels=self.arch['out_channels'][index], which_conv=self.which_conv, which_bn=self.which_bn, activation=self.activation, upsample=(functools.partial(F.interpolate, scale_factor=2) if self.arch['upsample'][index] else None))]] # If attention on this block, attach it to the end if self.arch['attention'][self.arch['resolution'][index]]: self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)] # Turn self.blocks into a ModuleList so that it's all properly registered. self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
构造输出层:
self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1], cross_replica=self.cross_replica, mybn=self.mybn), self.activation, self.which_conv(self.arch['out_channels'][-1], 3))
是否需要初始化参数(skip_init):
- 需要,调用
self.init_weights()
; - 不需要,不调用上述函数,注意 G_ema 不需要初始化参数。
- 需要,调用
是否需要需要设置优化器(no_optime):
需要,设置 Adam 需要的四个超参数:
- 学习率:lr。
- β1:B1。
- β2:B2。
- ε:adam_eps。
并检查是否 train with half-precision activations but fp32 params in G:
- 是,使用 utils.Adam16。
- 否,使用 torch.optim.Adam。
不需要,结束,注意 G_ema 不需要设置优化器。
函数init_weights()
def init_weights(self):
self.param_count = 0
for module in self.modules():
if (isinstance(module, nn.Conv2d)
or isinstance(module, nn.Linear)
or isinstance(module, nn.Embedding)):
if self.init == 'ortho':
init.orthogonal_(module.weight)
elif self.init == 'N02':
init.normal_(module.weight, 0, 0.02)
elif self.init in ['glorot', 'xavier']:
init.xavier_uniform_(module.weight)
else:
print('Init style not recognized...')
self.param_count += sum([p.data.nelement() for p in module.parameters()])
函数 forward()
判断是否使用了分层的 latent space(hier):
是:
zs = torch.split(z, self.z_chunk_size, 1) z = zs[0] ys = [torch.cat([y, item], 1) for item in zs[1:]]
否:
ys = [y] * len(self.blocks)
将 z 传入 the first linear layer:
h = self.linear(z)
,并 reshape:h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
。对 blocks 中每一个 block 进行遍历:
for index, blocklist in enumerate(self.blocks): # Second inner loop in case block has multiple layers for block in blocklist: h = block(h, ys[index])
将 h 传入到输出层,并对输出应用 torch.tanh,之后输出:
torch.tanh(self.output_layer(h))
。
函数 D_arch()
与 G_arch()
的区别仅在于 upsample 换成了 downsample。
类 Discriminator(nn.Module)
函数__init__()
设置通道宽度倍增算子:ch。
设置是否使用 wide D:D_wide。
设置生成器输出图片的分辨率:resolution。
设置卷积层的卷积核尺寸:kernel_size。
设置在哪几层应用注意力机制:attention。
设置类别个数:n_classes。
设置残差块的激活函数:activation。
取得存有模型架构的字典:arch。
设置参数初始化方式:init。
设置 Spectral Normalization 的参数 ε:SN_eps。
设置是否使用 16 位的浮点数进行运算(速度更快但需要指定硬件):fp16。
设置 parameterization 方式:D_param
if self.D_param == 'SN': self.which_conv = functools.partial(layers.SNConv2d, kernel_size=3, padding=1, num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps) self.which_linear = functools.partial(layers.SNLinear, num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps) self.which_embedding = functools.partial(layers.SNEmbedding, num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps) # 注意没有 else 哦,但是后面的代码又用到了 which_conv,因此可以推测作者代码没写完,我们这里目前必须使用 SN。
构造 blocks:
self.blocks = [] for index in range(len(self.arch['out_channels'])): self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index], out_channels=self.arch['out_channels'][index], which_conv=self.which_conv, wide=self.D_wide, activation=self.activation, preactivation=(index > 0), downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]] # If attention on this block, attach it to the end if self.arch['attention'][self.arch['resolution'][index]]: print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index]) self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)] # Turn self.blocks into a ModuleList so that it's all properly registered. self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
设置模型结束处的 linear layer:
self.which_linear(self.arch['out_channels'][-1], output_dim)
。设置用于 projection discrimination 的 embedding(embed):
self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
。是否需要初始化参数(skip_init):
- 需要,调用
self.init_weights()
; - 不需要,不调用上述函数。
- 需要,调用
设置优化器:与 Generator 中的一致。
函数init_weights()
与 Generator 中的完全一致。
函数 forward()
对 blocks 中的每一个 block 进行遍历:
for index, blocklist in enumerate(self.blocks): for block in blocklist: h = block(h)
应用 global sum pooling:
h = torch.sum(self.activation(h), [2, 3])
。传入到线性层:
out = self.linear(h)
。Get projection of final featureset onto class vectors and add to evidence:
out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
。
类 G_D(nn.Module)
Parallelized G_D to minimize cross-gpu communication
Without this, Generator outputs would get all-gathered and then rebroadcast.
函数 __init__()
- 设置 self.G = G。
- 设置 self.D = D。
函数 forward()
参数说明:
- z:生成器需要的噪声向量。
- gy:生成器需要的类别。
- x=None:真实样本。
- dy=None:真实样本的类别。
- train_G=False:是否对 G 进行训练。
- return_G_z=False:是否返回 G 生成的样本。
- split_D=False:是否是将真实数据和生成数据分开输入到辨别器中。
函数执行流程:
生成假样本:
with torch.set_grad_enabled(train_G): # Get Generator output given noise G_z = self.G(z, self.G.shared(gy)) # Cast as necessary if self.G.fp16 and not self.D.fp16: G_z = G_z.float() if self.D.fp16 and not self.G.fp16: G_z = G_z.half()
是否是将真实数据和生成数据分开输入到辨别器中(split_D):
是,首先传入 G 生成的数据:
D_fake = self.D(G_z, gy)
,之后检查是否提供了真实样本:- 是,
D_real = self.D(x, dy)
,之后return D_fake, D_real
。 - 否,检查是否需要将 G 生成的数据也返回:
- 是,
return D_fake, G_z
。 - 否,
return D_fake
。
- 是,
- 是,
否,将真实样本和生成样本合并之后传入 D 让其辨别:
D_input = torch.cat([G_z, x], 0) if x is not None else G_z D_class = torch.cat([gy, dy], 0) if dy is not None else gy # Get Discriminator output D_out = self.D(D_input, D_class)
之后检查是否提供了真实样本:
- 是,
D_real = self.D(x, dy)
,之后return torch.split(D_out, [G_z.shape[0], x.shape[0]])
(就是 D_fake, D_real)。 - 否,检查是否需要将 G 生成的数据也返回:
- 是,
return D_out, G_z
。 - 否,
return D_out
。
- 是,
- 是,
Links: BigGAN-Code-Analysis-I