pixel2style2pixel(pSp)代码实现解读

标签: 代码实现解读 图像转换 发布于:2021-09-10 11:20:08 编辑于:2021-09-13 15:48:52 浏览量:3839

概述

原本的代码实现:https://github.com/eladrich/pixel2style2pixel

我的修改版本:

  1. https://github.com/justsong-lab/pixel2style2pixel
  2. https://github.com/justsong-lab/pixel2style2pixel-v2

pSp 的实现

样式码输出的个数

self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2

这样对于 1024 的输入,样式码个数为 18,对于 256 的输入,样式码个数为 14。

GradualStyleEncoder 的实现

这个就是论文中用到的 encoder。

encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)

论文中说是 ResNet-IR,pretrained on face recognition,但我硬是找不到 ResNet-IR 的官方架构,淦。

代码中又给链接,实现是从这里拿的。

所谓的 se 是指的是在构建网络的基础块的主干分支末尾加一个 SEModule:

class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x

这里的关键点在于是怎么从不同 layer 拿出 latent code 的,直接看 forward 函数就好了:

def forward(self, x):
    x = self.input_layer(x)

    latents = []
    modulelist = list(self.body._modules.values())
    for i, l in enumerate(modulelist):
        x = l(x)
        if i == 6:
            c1 = x
        elif i == 20:
            c2 = x
        elif i == 23:
            c3 = x

    for j in range(self.coarse_ind):
        latents.append(self.styles[j](c3))

    p2 = self._upsample_add(c3, self.latlayer1(c2))
    for j in range(self.coarse_ind, self.middle_ind):
        latents.append(self.styles[j](p2))

    p1 = self._upsample_add(p2, self.latlayer2(c1))
    for j in range(self.middle_ind, self.style_count):
        latents.append(self.styles[j](p1))

    out = torch.stack(latents, dim=1)
    return out

大致流程就是,

  1. 首先从指定 layer (第 6,20 和 23 层)拿中间的特征图,
  2. 对于 1024 的输入,有 18 个样式码要输出,这 18 个样式码分别有一个对应的实例化的 GradualStyleBlock, 依据三个层次的 coarse 程度,越往后这个 GradualStyleBlock 中的卷积层个数就越多。 三个特征图,分别送到对应的 GradualStyleBlock 中,每一个都对应多个。之后我们就得到了要的 latent code。

BackboneEncoderUsingLastLayerIntoW & BackboneEncoderUsingLastLayerIntoWPlus 的实现

前者直接把主干网络的输出 resize 成 512 输入到 w 空间,后者加了一层全连接层,之后 resize 成 18 * 512 输入到 W+ 空间。

forward 函数的逻辑

首先把输入图片直接扔到 encoder 里拿到 latent code。

如果选择了 start_from_latent_avg,则 code 加上之前的平均 latent code。

如果要用 style mixing,这里则借助一个 latent_maskinject_latent mix 起来,

if alpha is not None:
    codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
else:
    codes[:, i] = inject_latent[:, i]

损失函数的实现

MSE 重构损失

约束 inversion 的结果要与目标尽量接近。

loss_l2 = F.mse_loss(y_hat, y)

具体使用时可选是否要 corp:

loss_l2_crop = F.mse_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220])

LPIPS 损失

类似感知损失,或者说就是感知损失?

self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()

我看了一下应该是从直接从这里把代码拷贝了过来。

具体使用时可选是否要 corp。

loss_lpips_crop = self.lpips_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220])

ID 损失

输入为:

  1. x:要 invert 的图像,使用 source_transform 进行预处理,即网络的输入。
  2. y:要 invert 的图像,使用 target_transform 进行预处理,即网络的目标输出。
  3. y_hat:invert 后的生成图像。

为什么要这样搞呢?实际上对于 GAN inversion 确实是多余的,但是对于超分辨、上色等任务,我们就需要处理前后的图像。

另外对于 SketchToImage 和 SegToImage 任务,input 和 target 的数据集是不同的,其他任务是一样的:

'ffhq_frontalize': {
    'transforms': transforms_config.FrontalizationTransforms,
    'train_source_root': dataset_paths['ffhq'],
    'train_target_root': dataset_paths['ffhq'],
    'test_source_root': dataset_paths['celeba_test'],
    'test_target_root': dataset_paths['celeba_test'],
},
'celebs_sketch_to_face': {
    'transforms': transforms_config.SketchToImageTransforms,
    'train_source_root': dataset_paths['celeba_train_sketch'],
    'train_target_root': dataset_paths['celeba_train'],
    'test_source_root': dataset_paths['celeba_test_sketch'],
    'test_target_root': dataset_paths['celeba_test'],
}

具体的 transform:

transforms_dict = {
    'transform_gt_train': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
    'transform_source': None,
    'transform_test': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
    'transform_inference': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
}

train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'],
                              target_root=dataset_args['train_target_root'],
                              source_transform=transforms_dict['transform_source'],
                              target_transform=transforms_dict['transform_gt_train'],
                              opts=self.opts)
test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'],
                             target_root=dataset_args['test_target_root'],
                             source_transform=transforms_dict['transform_source'],
                             target_transform=transforms_dict['transform_test'],
                             opts=self.opts)

注意啊,这里设为 None 之后不是说不转换了,而且直接将 input 图像和 target 图像设置为一样的,而不是将 transforms 设置为一样的。 这点非常重要,因为 transforms 中可能有随机因素。

可以看到无论输入的尺寸如何,编码器的输入一直为 256*256,这样不会丢细节吗?毕竟部分信息在降采样过程中丢失了。

流程:

  1. 首先裁剪我们感兴趣的区域:x = x[:, :, 35:223, 32:220],要注意,这里是硬编码!
  2. 用一个 AdaptiveAvgPool2d 来 resize 输入图像,
  3. 之后输入到预训练好的 ResNet ArcFace 之中,得到一个 512 维的向量。
  4. 之后使用这三个向量的点积构造 loss,具体如下:
    loss = 0
    sim_improvement = 0
    id_logs = []
    count = 0
    for i in range(n_samples):
        diff_target = y_hat_feats[i].dot(y_feats[i])
        diff_input = y_hat_feats[i].dot(x_feats[i])
        diff_views = y_feats[i].dot(x_feats[i])
        id_logs.append({'diff_target': float(diff_target),
                        'diff_input': float(diff_input),
                        'diff_views': float(diff_views)})
        loss += 1 - diff_target
        id_diff = float(diff_target) - float(diff_views)
        sim_improvement += id_diff
        count += 1
    
    return loss / count, sim_improvement / count, id_logs
    

moco_loss 和这个是类似的,主要区别在于换了其他特征提取器,针对物体而非人脸。

其他任务

对于其他任务,只需要更换相应的数据集,对应的 StyleGAN 生成器权重,以及对应的 transforms 即可。

这里重点说一下 transforms,

  1. 图像 encode 任务:input 图像和 target 图像相同,应用了 RandomHorizontalFlip(0.5) 来做数据增强。
  2. 人脸正面化任务:input 图像和 target 各自 RandomHorizontalFlip(0.5)
  3. SketchToImage 任务:input 图像没有做 Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 的处理,而目标图像做了。
  4. SegToImage 任务:这里有个 ToOneHot 的 transform 我没看懂,无关紧要,算了。
  5. 超分辨任务:这里对 input 图像做了随机倍数的下采样,之后 resize 回 256。
  6. 上色任务:这里是我自己加的,只需要对 input 图像做一下 transforms.Grayscale(num_output_channels=3) 就好。

吐槽

  1. 没有存优化器的 state_dict,导致没办法恢复训练。
  2. 数据集路径和权重路径硬编码,除了直接改代码外不能修改,搞笑呢。

芜湖,没了 :D

未经允许,禁止转载,本文源站链接:https://iamazing.cn/