BigGAN Code Analysis II
train_fns.py:模型训练函数,测试函数以及保存取样函数
函数 dummy_training_function()
用于 debug。
def dummy_training_function():
def train(x, y):
return {}
return train
函数GAN_training_function()
参数解释:
- G:生成器模型。
- D:辨别器模型。
- GD:G 和 D 的联合模型。
- z_:取样器。
- y_:取样器。
- ema:一个 utils.ema 对象。
- state_dict:存储本次训练的元信息的字典。
- config:从命令行参数解析来的字典对象。
内部定义并返回了一个 train 函数:train(x, y)
:
首先把 G 和 D 的参数的梯度清零:
G.optim.zero_grad() D.optim.zero_grad()
把输入样本及其标签分割成一些列 batch:
x = torch.split(x, config['batch_size']) y = torch.split(y, config['batch_size'])
设置 counter=0,其记录了当前用到的 batch 序号。
如果设置了 toggle_grads 为 True(默认启用),则:
utils.toggle_grad(D, True) utils.toggle_grad(G, False)
训练辨别器 num_D_steps 次:
清零 D 的参数梯度:
D.optim.zero_grad()
。累加 num_D_accumulations 次梯度:
(我现在也不清楚为啥这个能 work)
z_.sample_() y_.sample_()
传入 GD,拿到 D_fake (D 对生成样本的预测)以及 D_real(D 对真实样本的预测):
D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']], x[counter], y[counter], train_G=False, split_D=config['split_D'])
传入辨别器的损失函数拿到 loss:
D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
之后合并两个损失,并除以 num_D_accumulations:
D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations'])
调用 backward 计算梯度:
D_loss.backward()
。更新 batch 计数器:
counter += 1
。
如果 D_ortho > 0,则应用 Orthogonal Regularization:
utils.ortho(D, config['D_ortho'])
:调用优化器的 step 更新辨别器模型参数:
D.optim.step()
。
如果设置了 toggle_grads 为 True(默认启用),则:
utils.toggle_grad(D, False) utils.toggle_grad(G, True)
清零 G 的梯度(保险起见):
G.optim.zero_grad()
。累加 num_G_accumulations 次梯度:
取样:
z_.sample_() y_.sample_()
传入 GD,拿到 D_fake:
GD(z_, y_, train_G=True, split_D=config['split_D'])
。传入生成器的损失函数拿到 loss:
G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])
。调用 backward 计算梯度:
G_loss.backward()
。
如果 D_ortho > 0,则应用 Orthogonal Regularization:
utils.ortho(D, config['D_ortho'])
:调用优化器的 step 更新生成器模型参数:
G.optim.step()
。如果我们启用了 ema,我们需要更新:
ema.update(state_dict['itr'])
。输出是一个字典,其包含了辨别器分别在真实样本 & 生成样本上的损失以及生成器的损失:
out = {'G_loss': float(G_loss.item()), 'D_loss_real': float(D_loss_real.item()), 'D_loss_fake': float(D_loss_fake.item())}
函数 save_and_sample()
- 调用函数 utils.save_weights 保存以下内容:
- 生成器参数。
- 生成器的优化器的参数。
- 辨别器的参数。
- 辨别器的优化器的参数。
- EMA 版本的生成器的参数。
- 训练的元信息。
- 生成固定 z 和 y 的图片供对比。
- 生成随机取样的图片。
- 生成三种设置的插值图片。
函数 test()
- 调用传入的 get_inception_metrics 函数计算 IS 和 FID。
- 如果测试时发现模型表现比之前最好的还要好,就保存此时的权重,注意保存多个历史版本,并更新当前的最佳值。
- 并输出相应的 log。
layers.py:定义了 BigGAN 中所需要的众多网络层
class SN(object)
- 使用 power_iteration 方法估算 w 的最大的奇异值。
self.weight / svs[0]
class SNConv2d(nn.Conv2d, SN)
通过多继承的方式来构造我们的启用了谱归一化的二维卷积层。
其 forward 函数封装了 F.conv2d:
def forward(self, x): return F.conv2d(x, self.W_(), self.bias, self.stride, self.padding, self.dilation, self.groups)
class SNLinear(nn.Linear, SN),class SNEmbedding(nn.Embedding, SN)
- 同上。
- 注意 nn.Embedding 就是一个简单的查找表。
class Attention(nn.Module)
初始化四个卷积层:
self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False) self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
初始化一个可训练的超参数:
self.gamma = P(torch.tensor(0.), requires_grad=True)
。使用卷积层处理 feature map:
theta = self.theta(x) phi = F.max_pool2d(self.phi(x), [2,2]) g = F.max_pool2d(self.g(x), [2,2])
计算 attention map:
beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
。计算 self-attention feature map:
o = self.o(torch.bmm(g, beta.transpose(1,2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
。最后乘与 gamma 并加上原始的 feature map:
return self.gamma * o + x
。
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5)
Fused Batch Normalization
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5)
TODO
class myBN(nn.Module)
作者自己实现的 Batch Normalization,支持统计信息。
class ccbn(nn.Module)
Class-conditional Batch Normalization。
class bn(nn.Module)
Normal, non-class-conditional Batch Normalization。
class GBlock(nn.Module)
生成器的 Block。
def forward(self, x, y):
h = self.activation(self.bn1(x, y))
if self.upsample:
h = self.upsample(h)
x = self.upsample(x)
h = self.conv1(h)
h = self.activation(self.bn2(h, y))
h = self.conv2(h)
if self.learnable_sc:
x = self.conv_sc(x)
return h + x
class DBlock(nn.Module)
辨别器的残差块。
def shortcut(self, x):
if self.preactivation:
if self.learnable_sc:
x = self.conv_sc(x)
if self.downsample:
x = self.downsample(x)
else:
if self.downsample:
x = self.downsample(x)
if self.learnable_sc:
x = self.conv_sc(x)
return x
def forward(self, x):
if self.preactivation:
# h = self.activation(x) # NOT TODAY SATAN
# Andy's note: This line *must* be an out-of-place ReLU or it
# will negatively affect the shortcut connection.
h = F.relu(x)
else:
h = x
h = self.conv1(h)
h = self.conv2(self.activation(h))
if self.downsample:
h = self.downsample(h)
return h + self.shortcut(x)
dataset.py:对数据集进行封装
def is_image_file(filename)
通过判断文件的后缀来判断其是否是图片(首先要转换为全小写字母)。
def find_classes(dir)
遍历目标文件夹中的所有子文件夹,并为其分配一个序号,最后构造一个 key 为子文件夹名字,value 为对应的序号的字典输出。
def make_dataset(dir, class_to_idx)
遍历文件夹中的子文件夹中的所有图片文件,拿到其路径,外加相应的类别标签,作为一个元组输出。
class ImageFolder(data.Dataset)
def __init__()
- 首先检查之前保存的 npz 文件是否存在:
- 存在,加载该文件中的 numpy 数组。
- 不存在,调用 make_dataset 函数,加载好图片数组后将其保存为 npz 文件。
- 如果启用了 load_in_mem,就把所有的图片加载到内存中(之前的数据里面只有 path),注意要使用 self.transform 函数进行转换。
def __getitem__()
- 检查是否启用了 load_in_mem:
- 是,直接获取对应的图片及其标签。
- 否,加载图片并应用 self.transform。
- 如果 target_transform 函数不是 None,则用该函数对标签进行处理。
- 最后返回一个元组:(图片,标签)。
def __len__()
直接返回 imgs 的长度。
class ILSVRC_HDF5(data.Dataset)
与前者类似,主要区别在于这里使用了 h5。
class CIFAR10(torchvision.datasets.CIFAR10)
较为类似,这里仅列出区别点:
如果 download 为 True,则调用基类的 download 函数。
调用基类的 _check_integrity() 函数检查完整性。
使用 pickle 从文件中加载 numpy 数组。
之后进行以下处理:
self.data = self.data.reshape((10000, 3, 32, 32)) self.data = self.data.transpose((0, 2, 3, 1))
class CIFAR100(CIFAR10)
继承了 CIFAR10,改了一些元信息,有:
- 文件夹名称。
- 下载路径。
- 下载文件名。
- 下载文件的 md5.
- 还有一个 train_list 以及 test_list,里面存有子目录的名称。
Links: BigGAN-Code-Analysis-II