StyleGAN Code Analysis
Overview
I'm not yet very familiar with TensorFlow, so what I read here is the unofficial PyTorch implementation: https://github.com/rosinality/style-based-gan-pytorch
Data Processing
StyleGAN training proceeds in multiple stages, and different stages need images of different sizes, so preprocessing resizes the original images to 8, 16, 32, 64, 128, 256, 512, and 1024. To speed up reading, the results are stored in the lmdb binary format.
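Roughly what this preprocessing looks like (a minimal sketch with a hypothetical key scheme and helper names; the repo's prepare_data.py differs in its details):

import lmdb
from io import BytesIO
from PIL import Image

SIZES = (8, 16, 32, 64, 128, 256, 512, 1024)

def build_lmdb(paths, out_path):
    env = lmdb.open(out_path, map_size=1024 ** 4)
    with env.begin(write=True) as txn:
        for i, path in enumerate(paths):
            img = Image.open(path).convert('RGB')
            for size in SIZES:
                # One JPEG byte string per target resolution.
                buf = BytesIO()
                img.resize((size, size), Image.LANCZOS).save(buf, format='jpeg', quality=100)
                txn.put(f'{size}-{i:05d}'.encode(), buf.getvalue())  # key: "<size>-<index>"
        txn.put(b'length', str(len(paths)).encode())
    env.close()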
There is nothing fancy in the transforms:
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
]
)
Training Options
Some of the important training options are listed below (a sketch of how they might be declared follows the list):
- phase: number of samples used per training stage, default 600k;
- lr: learning rate, default 0.001;
- sched: whether to use the resolution-dependent learning-rate schedule; the README uses it for FFHQ but not for CelebA;
- init_size: initial image size, default 8;
- max_size: maximum image size, default 1024;
- mixing: whether to use mixing regularization, enabled by default;
- loss: type of adversarial loss, either wgan-gp or r1, default wgan-gp;
- batch_size: batch size, default 16;
- code_size: size of the style code, default 512;
- n_critic: number of discriminator updates per generator update, default 1;
- n_mlp: number of layers in the MLP that maps z to w, default 8.
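An illustrative declaration of these options (names and defaults taken from the list above, not copied from the repo's train.py):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--phase', type=int, default=600_000)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--sched', action='store_true')
parser.add_argument('--init_size', type=int, default=8)
parser.add_argument('--max_size', type=int, default=1024)
parser.add_argument('--mixing', action='store_true', default=True)
parser.add_argument('--loss', type=str, default='wgan-gp', choices=['wgan-gp', 'r1'])
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--code_size', type=int, default=512)
parser.add_argument('--n_critic', type=int, default=1)
parser.add_argument('--n_mlp', type=int, default=8)
args = parser.parse_args()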
Training Process
EMA Version of the Generator
First, an EMA (exponential moving average) version of the generator is constructed; the code is as follows:
g_running = StyledGenerator(code_size).cuda()
g_running.train(False)

# Note: initializing with decay=0 simply copies the parameters over.
accumulate(g_running, generator.module, 0)

# Later, during training:
accumulate(g_running, generator.module)

def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        # In-place EMA update: par1 = decay * par1 + (1 - decay) * par2.
        # (The repo uses the legacy add_(1 - decay, tensor) call; newer
        # PyTorch expects the alpha= keyword form used here.)
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
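A toy illustration of the two call patterns (not from the repo):

import torch.nn as nn

a, b = nn.Linear(4, 4), nn.Linear(4, 4)
accumulate(a, b, 0)      # decay=0: a's parameters become an exact copy of b's
accumulate(a, b, 0.999)  # decay=0.999: a moves 0.1% of the way toward b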
Optimizer Setup
Optimizer initialization:
g_optimizer = optim.Adam(
generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99)
)
g_optimizer.add_param_group(
{
'params': generator.module.style.parameters(),
'lr': args.lr * 0.01,
'mult': 0.01,
}
)
d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))
Note that the learning rate of the generator's style part (the MLP that maps z to w; an odd choice of name...) is set to 0.01× the learning rate from the options.
If the learning-rate schedule is enabled, each image resolution gets its own learning rate and batch size:
# First, lr and batch are set up as dicts:
if args.sched:
args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32}
else:
args.lr = {}
args.batch = {}
# Then update the optimizer learning rates; note that mult is applied:
adjust_lr(g_optimizer, args.lr.get(resolution, 0.001))
adjust_lr(d_optimizer, args.lr.get(resolution, 0.001))
def adjust_lr(optimizer, lr):
for group in optimizer.param_groups:
mult = group.get('mult', 1)
group['lr'] = lr * mult
About alpha
alpha = min(1, 1 / args.phase * (used_sample + 1))
if (resolution == args.init_size and args.ckpt is None) or final_progress:
alpha = 1
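alpha is the fade-in coefficient of progressive growing: it ramps linearly from near 0 to 1 over the phase samples of each stage, and is used below to blend the newly added higher-resolution block via (1 - alpha) * skip_rgb + alpha * out. For example, with the default phase of 600k:

# used_sample = 0        -> alpha ≈ 1.7e-6
# used_sample = 300_000  -> alpha ≈ 0.5
# used_sample >= 600_000 -> alpha = 1 (clamped by min)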
Conditional branch: if used_sample > args.phase * 2
TODO
Discriminator Loss
if args.loss == 'wgan-gp':
real_predict = discriminator(real_image, step=step, alpha=alpha)
real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
(-real_predict).backward()
elif args.loss == 'r1':
real_image.requires_grad = True
real_scores = discriminator(real_image, step=step, alpha=alpha)
real_predict = F.softplus(-real_scores).mean()
real_predict.backward(retain_graph=True)
grad_real = grad(
outputs=real_scores.sum(), inputs=real_image, create_graph=True
)[0]
grad_penalty = (
grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
).mean()
grad_penalty = 10 / 2 * grad_penalty
grad_penalty.backward()
fake_image = generator(gen_in1, step=step, alpha=alpha)
fake_predict = discriminator(fake_image, step=step, alpha=alpha)
if args.loss == 'wgan-gp':
fake_predict = fake_predict.mean()
fake_predict.backward()
eps = torch.rand(b_size, 1, 1, 1).cuda()
x_hat = eps * real_image.data + (1 - eps) * fake_image.data
x_hat.requires_grad = True
hat_predict = discriminator(x_hat, step=step, alpha=alpha)
grad_x_hat = grad(
outputs=hat_predict.sum(), inputs=x_hat, create_graph=True
)[0]
grad_penalty = (
(grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2
).mean()
grad_penalty = 10 * grad_penalty
grad_penalty.backward()
elif args.loss == 'r1':
fake_predict = F.softplus(fake_predict).mean()
fake_predict.backward()
d_optimizer.step()
Note that the individual loss terms are not summed and then backpropagated in one go; each term gets its own backward call. Since gradients accumulate in .grad, the result is mathematically the same; one plausible motivation is that backpropagating each term immediately frees that term's graph and lowers peak memory, though to me it mainly hurts readability. For reference, an equivalent summed version of the wgan-gp branch is sketched below.
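# Equivalent single-backward formulation (a sketch, not the repo's code;
# variable names as in the wgan-gp branch of the training loop above):
d_loss = -real_predict + fake_predict + grad_penalty
d_loss.backward()
d_optimizer.step()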
Generator Loss
if (i + 1) % n_critic == 0:
generator.zero_grad()
requires_grad(generator, True)
requires_grad(discriminator, False)
fake_image = generator(gen_in2, step=step, alpha=alpha)
predict = discriminator(fake_image, step=step, alpha=alpha)
if args.loss == 'wgan-gp':
loss = -predict.mean()
elif args.loss == 'r1':
loss = F.softplus(-predict).mean()
loss.backward()
g_optimizer.step()
accumulate(g_running, generator.module)
requires_grad(generator, False)
requires_grad(discriminator, True)
Nothing much to add here.
Model Definition
Network Layers
PixelNorm
class PixelNorm(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
Each position is divided by the root mean square of its values across the channel dimension (plus a small epsilon), normalizing every sample to unit RMS.
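A quick behavior check (illustrative):

import torch

x = torch.randn(4, 512)           # e.g. a batch of latent codes z
y = PixelNorm()(x)
print(torch.mean(y ** 2, dim=1))  # ≈ 1 for every sample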
Equalized Learning Rate
class EqualLR:
def __init__(self, name):
self.name = name
def compute_weight(self, module):
weight = getattr(module, self.name + '_orig')
fan_in = weight.data.size(1) * weight.data[0][0].numel()
return weight * sqrt(2 / fan_in)
@staticmethod
def apply(module, name):
fn = EqualLR(name)
weight = getattr(module, name)
del module._parameters[name]
module.register_parameter(name + '_orig', nn.Parameter(weight.data))
module.register_forward_pre_hook(fn)
return fn
def __call__(self, module, input):
weight = self.compute_weight(module)
setattr(module, self.name, weight)
def equal_lr(module, name='weight'):
EqualLR.apply(module, name)
return module
class EqualConv2d(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
conv = nn.Conv2d(*args, **kwargs)
conv.weight.data.normal_()
conv.bias.data.zero_()
self.conv = equal_lr(conv)
def forward(self, input):
return self.conv(input)
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
linear = nn.Linear(in_dim, out_dim)
linear.weight.data.normal_()
linear.bias.data.zero_()
self.linear = equal_lr(linear)
def forward(self, input):
return self.linear(input)
Applying equal_lr to a layer registers a forward pre-hook that rescales the layer's weight on every forward pass.
The equalized learning rate involves:
- Initializing all weights (linear and conv) from regular normal distribution, no fancy init
- Scaling all weights by the per-layer normalization constant from the Kaiming He initialization.
From: https://github.com/lucidrains/stylegan2-pytorch/issues/112#issuecomment-675204245
Note this comment further down the thread: "On my experience, equalized learning rate is incredibly important for high resolution training (256 and above)."
Note that the weights receive a custom initialization here (a plain normal distribution), so do not re-initialize them afterwards with some other scheme.
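A small check of what the hook does (illustrative; for a linear layer, fan_in is just in_dim):

import torch
from math import sqrt

linear = EqualLinear(8, 4)
w_orig = linear.linear.weight_orig         # the registered raw parameter
_ = linear(torch.randn(2, 8))              # a forward pass triggers the pre-hook
w_used = linear.linear.weight              # the rescaled weight actually used
print(torch.allclose(w_used, w_orig * sqrt(2 / 8)))  # True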
References:
- https://github.com/lucidrains/stylegan2-pytorch/issues/112
- https://samaelchen.github.io/pytorch-pggan/
FusedUpsample & FusedDownsample
class FusedUpsample(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, padding=0):
super().__init__()
weight = torch.randn(in_channel, out_channel, kernel_size, kernel_size)
bias = torch.zeros(out_channel)
fan_in = in_channel * kernel_size * kernel_size
self.multiplier = sqrt(2 / fan_in)
self.weight = nn.Parameter(weight)
self.bias = nn.Parameter(bias)
self.pad = padding
def forward(self, input):
weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
weight = (
weight[:, :, 1:, 1:]
+ weight[:, :, :-1, 1:]
+ weight[:, :, 1:, :-1]
+ weight[:, :, :-1, :-1]
) / 4
out = F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad)
return out
class FusedDownsample(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, padding=0):
super().__init__()
weight = torch.randn(out_channel, in_channel, kernel_size, kernel_size)
bias = torch.zeros(out_channel)
fan_in = in_channel * kernel_size * kernel_size
self.multiplier = sqrt(2 / fan_in)
self.weight = nn.Parameter(weight)
self.bias = nn.Parameter(bias)
self.pad = padding
def forward(self, input):
weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
weight = (
weight[:, :, 1:, 1:]
+ weight[:, :, :-1, 1:]
+ weight[:, :, 1:, :-1]
+ weight[:, :, :-1, :-1]
) / 4
out = F.conv2d(input, weight, self.bias, stride=2, padding=self.pad)
return out
Upsampling is implemented with a transposed convolution, F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad), while downsampling uses an ordinary strided convolution, F.conv2d(input, weight, self.bias, stride=2, padding=self.pad).
I'm still not entirely sure what the "fused" naming refers to; judging from the code, it seems to mean the 2× resampling is fused into a single stride-2 (transposed) convolution, instead of a separate nn.Upsample / nn.AvgPool2d step around a stride-1 convolution.
AdaptiveInstanceNorm
class AdaptiveInstanceNorm(nn.Module):
def __init__(self, in_channel, style_dim):
super().__init__()
self.norm = nn.InstanceNorm2d(in_channel)
self.style = EqualLinear(style_dim, in_channel * 2)
self.style.linear.bias.data[:in_channel] = 1
self.style.linear.bias.data[in_channel:] = 0
def forward(self, input, style):
style = self.style(style).unsqueeze(2).unsqueeze(3)
gamma, beta = style.chunk(2, 1)
out = self.norm(input)
out = gamma * out + beta
return out
The core piece. Processing flow (a shape walkthrough follows the list):
- first, an EqualLinear layer maps the style code to shape (B, in_channel * 2);
- unsqueeze to (B, in_channel * 2, 1, 1);
- then chunk along dim 1 into two halves, used as gamma and beta;
- normalize the input with a standard nn.InstanceNorm2d (affine defaults to False in PyTorch);
- finally, scale and shift the normalized result with the gamma and beta from the third step.
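The promised shape walkthrough (illustrative sizes):

adain = AdaptiveInstanceNorm(in_channel=512, style_dim=512)
x = torch.randn(4, 512, 8, 8)   # feature maps
w = torch.randn(4, 512)         # style code from the mapping MLP
out = adain(x, w)               # -> (4, 512, 8, 8), same shape as x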
This essentially matches the AdaIN implementation in StarGANv2; there the +1 on gamma is explicit, whereas here it is baked into the bias initialization of the style layer:
class AdaIN(nn.Module):
def __init__(self, style_dim, num_features):
super().__init__()
self.norm = nn.InstanceNorm2d(num_features, affine=False)
self.fc = nn.Linear(style_dim, num_features * 2)
def forward(self, x, s):
h = self.fc(s)
h = h.view(h.size(0), h.size(1), 1, 1)
gamma, beta = torch.chunk(h, chunks=2, dim=1)
return (1 + gamma) * self.norm(x) + beta
NoiseInjection
class NoiseInjection(nn.Module):
def __init__(self, channel):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
def forward(self, image, noise):
return image + self.weight * noise
Used to inject noise. Note that the layer's weight varies only along the channel dimension: for any image, every value in the n-th feature map is scaled by one and the same per-channel scalar.
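Broadcasting at work (illustrative sizes; the weight is (1, C, 1, 1) and the single-channel noise map (B, 1, H, W) is shared across channels):

inject = NoiseInjection(channel=512)
img = torch.randn(4, 512, 8, 8)
noise = torch.randn(4, 1, 8, 8)
out = inject(img, noise)        # -> (4, 512, 8, 8)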
ConvBlock
class ConvBlock(nn.Module):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
padding,
kernel_size2=None,
padding2=None,
downsample=False,
fused=False,
):
super().__init__()
pad1 = padding
pad2 = padding
if padding2 is not None:
pad2 = padding2
kernel1 = kernel_size
kernel2 = kernel_size
if kernel_size2 is not None:
kernel2 = kernel_size2
self.conv1 = nn.Sequential(
EqualConv2d(in_channel, out_channel, kernel1, padding=pad1),
nn.LeakyReLU(0.2),
)
if downsample:
if fused:
self.conv2 = nn.Sequential(
Blur(out_channel),
FusedDownsample(out_channel, out_channel, kernel2, padding=pad2),
nn.LeakyReLU(0.2),
)
else:
self.conv2 = nn.Sequential(
Blur(out_channel),
EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
nn.AvgPool2d(2),
nn.LeakyReLU(0.2),
)
else:
self.conv2 = nn.Sequential(
EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
nn.LeakyReLU(0.2),
)
def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)
return out
The convolution block used by the discriminator. The body is just two conv layers, both applying the equalized learning rate trick described above. Since this is the discriminator, downsampling is involved, and two variants are provided (a shape check follows the list):
- the fused variant: see above;
- the pooling variant (the default): an average-pooling layer, nn.AvgPool2d(2), is appended after the convolution.
When downsampling is needed, note that the second conv layer is the one responsible for it.
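An illustrative check that the default pooling path halves the resolution:

block = ConvBlock(256, 512, 3, 1, downsample=True)
x = torch.randn(4, 256, 32, 32)
print(block(x).shape)           # torch.Size([4, 512, 16, 16])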
StyledConvBlock
class StyledConvBlock(nn.Module):
def __init__(
self,
in_channel,
out_channel,
kernel_size=3,
padding=1,
style_dim=512,
initial=False,
upsample=False,
fused=False,
):
super().__init__()
if initial:
self.conv1 = ConstantInput(in_channel)
else:
if upsample:
if fused:
self.conv1 = nn.Sequential(
FusedUpsample(
in_channel, out_channel, kernel_size, padding=padding
),
Blur(out_channel),
)
else:
self.conv1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
EqualConv2d(
in_channel, out_channel, kernel_size, padding=padding
),
Blur(out_channel),
)
else:
self.conv1 = EqualConv2d(
in_channel, out_channel, kernel_size, padding=padding
)
self.noise1 = equal_lr(NoiseInjection(out_channel))
self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
self.lrelu1 = nn.LeakyReLU(0.2)
self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding)
self.noise2 = equal_lr(NoiseInjection(out_channel))
self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim)
self.lrelu2 = nn.LeakyReLU(0.2)
def forward(self, input, style, noise):
out = self.conv1(input)
out = self.noise1(out, noise)
out = self.lrelu1(out)
out = self.adain1(out, style)
out = self.conv2(out)
out = self.noise2(out, noise)
out = self.lrelu2(out)
out = self.adain2(out, style)
return out
Main structure:
- the first conv layer; if upsampling is needed, this layer handles it;
- the first noise-injection layer, adding randomness;
- LeakyReLU activation;
- the first adaptive instance norm layer, injecting the style;
- the second conv layer;
- the second noise-injection layer;
- LeakyReLU activation;
- the second adaptive instance norm layer.
Note the ordering: as I understand it, noise has to be injected on the feature maps produced by the convolution, then the activation is applied, and only then adaptive instance normalization.
Note that this differs slightly from the AdainResBlk implementation in StarGANv2:
x = self.norm1(x, s)
x = self.actv(x)
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='nearest')
x = self.conv1(x)
x = self.norm2(x, s)
x = self.actv(x)
x = self.conv2(x)
return x
That is, the order there is normalize first, then activate, and only after that the conv layer.
Generator
The implementation first:
class Generator(nn.Module):
def __init__(self, code_dim, fused=True):
super().__init__()
self.progression = nn.ModuleList(
[
StyledConvBlock(512, 512, 3, 1, initial=True), # 4
StyledConvBlock(512, 512, 3, 1, upsample=True), # 8
StyledConvBlock(512, 512, 3, 1, upsample=True), # 16
StyledConvBlock(512, 512, 3, 1, upsample=True), # 32
StyledConvBlock(512, 256, 3, 1, upsample=True), # 64
StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused), # 128
StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused), # 256
StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused), # 512
StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused), # 1024
]
)
self.to_rgb = nn.ModuleList(
[
EqualConv2d(512, 3, 1),
EqualConv2d(512, 3, 1),
EqualConv2d(512, 3, 1),
EqualConv2d(512, 3, 1),
EqualConv2d(256, 3, 1),
EqualConv2d(128, 3, 1),
EqualConv2d(64, 3, 1),
EqualConv2d(32, 3, 1),
EqualConv2d(16, 3, 1),
]
)
# self.blur = Blur()
def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1)):
out = noise[0]
if len(style) < 2:
inject_index = [len(self.progression) + 1]
else:
inject_index = sorted(random.sample(list(range(step)), len(style) - 1))
crossover = 0
for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):
if mixing_range == (-1, -1):
if crossover < len(inject_index) and i > inject_index[crossover]:
crossover = min(crossover + 1, len(style))
style_step = style[crossover]
else:
if mixing_range[0] <= i <= mixing_range[1]:
style_step = style[1]
else:
style_step = style[0]
if i > 0 and step > 0:
out_prev = out
out = conv(out, style_step, noise[i])
if i == step:
out = to_rgb(out)
if i > 0 and 0 <= alpha < 1:
skip_rgb = self.to_rgb[i - 1](out_prev)
skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest')
out = (1 - alpha) * skip_rgb + alpha * out
break
return out
The network body consists of 9 styled conv blocks (one per resolution from 2^2 up to 2^10), with the channel count capped at 512; at step i the output resolution is 4 * 2**i. The to_rgb layers provide one 1×1 conv per resolution; note that no activation follows them.
Generator Combined with the MLP
class StyledGenerator(nn.Module):
def __init__(self, code_dim=512, n_mlp=8):
super().__init__()
self.generator = Generator(code_dim)
layers = [PixelNorm()]
for i in range(n_mlp):
layers.append(EqualLinear(code_dim, code_dim))
layers.append(nn.LeakyReLU(0.2))
self.style = nn.Sequential(*layers)
def forward(
self,
input,
noise=None,
step=0,
alpha=-1,
mean_style=None,
style_weight=0,
mixing_range=(-1, -1),
):
styles = []
if type(input) not in (list, tuple):
input = [input]
for i in input:
styles.append(self.style(i))
batch = input[0].shape[0]
if noise is None:
noise = []
for i in range(step + 1):
size = 4 * 2 ** i
noise.append(torch.randn(batch, 1, size, size, device=input[0].device))
if mean_style is not None:
styles_norm = []
for style in styles:
styles_norm.append(mean_style + style_weight * (style - mean_style))
styles = styles_norm
return self.generator(styles, noise, step, alpha, mixing_range=mixing_range)
def mean_style(self, input):
style = self.style(input).mean(0, keepdim=True)
return style
First, note that the MLP applies PixelNorm to its input, followed by n_mlp = 8 EqualLinear layers, each paired with LeakyReLU(0.2). A usage sketch follows.
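A usage sketch (hypothetical shapes, assuming the repo's modules are imported; step selects the output resolution, 4 * 2**step):

g = StyledGenerator(code_dim=512, n_mlp=8)
z = torch.randn(16, 512)
img = g(z, step=3, alpha=1)     # 4 * 2**3 = 32x32
print(img.shape)                # torch.Size([16, 3, 32, 32])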
Beyond that, mean_style is worth a mention: it post-processes the generated style codes (this is the truncation trick in w space, with style_weight playing the role of the truncation factor):
# Building mean_style:
mean_style = None
for i in range(10):
style = generator.mean_style(torch.randn(1024, 512).to(device))
if mean_style is None:
mean_style = style
else:
mean_style += style
mean_style /= 10
# The generator.mean_style method:
def mean_style(self, input):
style = self.style(input).mean(0, keepdim=True)
return style
# Finally, at use time:
styles_norm.append(mean_style + style_weight * (style - mean_style))
Note that this mean_style path is actually used when generating samples.
Discriminator
class Discriminator(nn.Module):
def __init__(self, fused=True, from_rgb_activate=False):
super().__init__()
self.progression = nn.ModuleList(
[
ConvBlock(16, 32, 3, 1, downsample=True, fused=fused), # 512
ConvBlock(32, 64, 3, 1, downsample=True, fused=fused), # 256
ConvBlock(64, 128, 3, 1, downsample=True, fused=fused), # 128
ConvBlock(128, 256, 3, 1, downsample=True, fused=fused), # 64
ConvBlock(256, 512, 3, 1, downsample=True), # 32
ConvBlock(512, 512, 3, 1, downsample=True), # 16
ConvBlock(512, 512, 3, 1, downsample=True), # 8
ConvBlock(512, 512, 3, 1, downsample=True), # 4
ConvBlock(513, 512, 3, 1, 4, 0),
]
)
def make_from_rgb(out_channel):
if from_rgb_activate:
return nn.Sequential(EqualConv2d(3, out_channel, 1), nn.LeakyReLU(0.2))
else:
return EqualConv2d(3, out_channel, 1)
self.from_rgb = nn.ModuleList(
[
make_from_rgb(16),
make_from_rgb(32),
make_from_rgb(64),
make_from_rgb(128),
make_from_rgb(256),
make_from_rgb(512),
make_from_rgb(512),
make_from_rgb(512),
make_from_rgb(512),
]
)
# self.blur = Blur()
self.n_layer = len(self.progression)
self.linear = EqualLinear(512, 1)
def forward(self, input, step=0, alpha=-1):
for i in range(step, -1, -1):
index = self.n_layer - i - 1
if i == step:
out = self.from_rgb[index](input)
if i == 0:
out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
mean_std = out_std.mean()
mean_std = mean_std.expand(out.size(0), 1, 4, 4)
out = torch.cat([out, mean_std], 1)
out = self.progression[index](out)
if i > 0:
if i == step and 0 <= alpha < 1:
skip_rgb = F.avg_pool2d(input, 2)
skip_rgb = self.from_rgb[index + 1](skip_rgb)
out = (1 - alpha) * skip_rgb + alpha * out
out = out.squeeze(2).squeeze(2)
# print(input.size(), out.size(), step)
out = self.linear(out)
return out
The body is again 9 conv blocks, and from_rgb likewise provides one 1×1 conv per resolution; note that these are followed by a LeakyReLU when from_rgb_activate is set.
Also note that the last conv block of the body takes 513 input channels: before the final block, the minibatch standard deviation (the trick from ProGAN) is concatenated onto the feature maps as one extra channel:
out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)  # per-location std over the batch
mean_std = out_std.mean()                                 # averaged to a single scalar
mean_std = mean_std.expand(out.size(0), 1, 4, 4)          # broadcast to one 4x4 feature map
out = torch.cat([out, mean_std], 1)                       # 512 + 1 = 513 channels
At the very end, a linear layer produces the final output:
self.linear = EqualLinear(512, 1)
out = self.linear(out)
Gripes
- options are hard-coded in several places;
- inexplicable blank lines between if and else;
- weights are saved as a single file, which makes moving checkpoints around inconvenient;
- loss terms are backpropagated separately instead of being summed before a single backward, hurting readability.