pixel2style2pixel(pSp)代码实现解读
概述
原本的代码实现:https://github.com/eladrich/pixel2style2pixel
我的修改版本:
- https://github.com/justsong-lab/pixel2style2pixel
- https://github.com/justsong-lab/pixel2style2pixel-v2
pSp 的实现
样式码输出的个数
self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
这样对于 1024 的输入,样式码个数为 18,对于 256 的输入,样式码个数为 14。
GradualStyleEncoder 的实现
这个就是论文中用到的 encoder。
encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
论文中说是 ResNet-IR,pretrained on face recognition,但我硬是找不到 ResNet-IR 的官方架构,淦。
代码中又给链接,实现是从这里拿的。
所谓的 se 是指的是在构建网络的基础块的主干分支末尾加一个 SEModule:
class SEModule(Module):
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
self.sigmoid = Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
这里的关键点在于是怎么从不同 layer 拿出 latent code 的,直接看 forward 函数就好了:
def forward(self, x):
x = self.input_layer(x)
latents = []
modulelist = list(self.body._modules.values())
for i, l in enumerate(modulelist):
x = l(x)
if i == 6:
c1 = x
elif i == 20:
c2 = x
elif i == 23:
c3 = x
for j in range(self.coarse_ind):
latents.append(self.styles[j](c3))
p2 = self._upsample_add(c3, self.latlayer1(c2))
for j in range(self.coarse_ind, self.middle_ind):
latents.append(self.styles[j](p2))
p1 = self._upsample_add(p2, self.latlayer2(c1))
for j in range(self.middle_ind, self.style_count):
latents.append(self.styles[j](p1))
out = torch.stack(latents, dim=1)
return out
大致流程就是,
- 首先从指定 layer (第 6,20 和 23 层)拿中间的特征图,
- 对于 1024 的输入,有 18 个样式码要输出,这 18 个样式码分别有一个对应的实例化的 GradualStyleBlock, 依据三个层次的 coarse 程度,越往后这个 GradualStyleBlock 中的卷积层个数就越多。 三个特征图,分别送到对应的 GradualStyleBlock 中,每一个都对应多个。之后我们就得到了要的 latent code。
BackboneEncoderUsingLastLayerIntoW & BackboneEncoderUsingLastLayerIntoWPlus 的实现
前者直接把主干网络的输出 resize 成 512 输入到 w 空间,后者加了一层全连接层,之后 resize 成 18 * 512 输入到 W+ 空间。
forward 函数的逻辑
首先把输入图片直接扔到 encoder 里拿到 latent code。
如果选择了 start_from_latent_avg
,则 code 加上之前的平均 latent code。
如果要用 style mixing,这里则借助一个 latent_mask
和 inject_latent
mix 起来,
if alpha is not None:
codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
else:
codes[:, i] = inject_latent[:, i]
损失函数的实现
MSE 重构损失
约束 inversion 的结果要与目标尽量接近。
loss_l2 = F.mse_loss(y_hat, y)
具体使用时可选是否要 corp:
loss_l2_crop = F.mse_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220])
LPIPS 损失
类似感知损失,或者说就是感知损失?
self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
我看了一下应该是从直接从这里把代码拷贝了过来。
具体使用时可选是否要 corp。
loss_lpips_crop = self.lpips_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220])
ID 损失
输入为:
- x:要 invert 的图像,使用 source_transform 进行预处理,即网络的输入。
- y:要 invert 的图像,使用 target_transform 进行预处理,即网络的目标输出。
- y_hat:invert 后的生成图像。
为什么要这样搞呢?实际上对于 GAN inversion 确实是多余的,但是对于超分辨、上色等任务,我们就需要处理前后的图像。
另外对于 SketchToImage 和 SegToImage 任务,input 和 target 的数据集是不同的,其他任务是一样的:
'ffhq_frontalize': {
'transforms': transforms_config.FrontalizationTransforms,
'train_source_root': dataset_paths['ffhq'],
'train_target_root': dataset_paths['ffhq'],
'test_source_root': dataset_paths['celeba_test'],
'test_target_root': dataset_paths['celeba_test'],
},
'celebs_sketch_to_face': {
'transforms': transforms_config.SketchToImageTransforms,
'train_source_root': dataset_paths['celeba_train_sketch'],
'train_target_root': dataset_paths['celeba_train'],
'test_source_root': dataset_paths['celeba_test_sketch'],
'test_target_root': dataset_paths['celeba_test'],
}
具体的 transform:
transforms_dict = {
'transform_gt_train': transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_source': None,
'transform_test': transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
'transform_inference': transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
}
train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'],
target_root=dataset_args['train_target_root'],
source_transform=transforms_dict['transform_source'],
target_transform=transforms_dict['transform_gt_train'],
opts=self.opts)
test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'],
target_root=dataset_args['test_target_root'],
source_transform=transforms_dict['transform_source'],
target_transform=transforms_dict['transform_test'],
opts=self.opts)
注意啊,这里设为 None 之后不是说不转换了,而且直接将 input 图像和 target 图像设置为一样的,而不是将 transforms 设置为一样的。 这点非常重要,因为 transforms 中可能有随机因素。
可以看到无论输入的尺寸如何,编码器的输入一直为 256*256,这样不会丢细节吗?毕竟部分信息在降采样过程中丢失了。
流程:
- 首先裁剪我们感兴趣的区域:
x = x[:, :, 35:223, 32:220]
,要注意,这里是硬编码! - 用一个 AdaptiveAvgPool2d 来 resize 输入图像,
- 之后输入到预训练好的 ResNet ArcFace 之中,得到一个 512 维的向量。
- 之后使用这三个向量的点积构造 loss,具体如下:
loss = 0 sim_improvement = 0 id_logs = [] count = 0 for i in range(n_samples): diff_target = y_hat_feats[i].dot(y_feats[i]) diff_input = y_hat_feats[i].dot(x_feats[i]) diff_views = y_feats[i].dot(x_feats[i]) id_logs.append({'diff_target': float(diff_target), 'diff_input': float(diff_input), 'diff_views': float(diff_views)}) loss += 1 - diff_target id_diff = float(diff_target) - float(diff_views) sim_improvement += id_diff count += 1 return loss / count, sim_improvement / count, id_logs
moco_loss 和这个是类似的,主要区别在于换了其他特征提取器,针对物体而非人脸。
其他任务
对于其他任务,只需要更换相应的数据集,对应的 StyleGAN 生成器权重,以及对应的 transforms 即可。
这里重点说一下 transforms,
- 图像 encode 任务:input 图像和 target 图像相同,应用了
RandomHorizontalFlip(0.5)
来做数据增强。 - 人脸正面化任务:input 图像和 target 各自
RandomHorizontalFlip(0.5)
。 - SketchToImage 任务:input 图像没有做
Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
的处理,而目标图像做了。 - SegToImage 任务:这里有个 ToOneHot 的 transform 我没看懂,无关紧要,算了。
- 超分辨任务:这里对 input 图像做了随机倍数的下采样,之后 resize 回 256。
- 上色任务:这里是我自己加的,只需要对 input 图像做一下
transforms.Grayscale(num_output_channels=3)
就好。
吐槽
- 没有存优化器的 state_dict,导致没办法恢复训练。
- 数据集路径和权重路径硬编码,除了直接改代码外不能修改,搞笑呢。
芜湖,没了 :D
Links: pixel2style2pixel-code-notes