encoder4editing(e4e)代码实现解读

代码实现解读 图像转换 2021-09-13 14:59:24 2021-09-13 15:48:58 63 次浏览

概述

官方代码实现:https://github.com/omertov/encoder4editing

我自己的修改版:

  1. https://github.com/justsong-lab/encoder4editing
  2. https://github.com/justsong-lab/encoder4editing-v2

其代码基于 pSp,具体可见之前的博客,这里仅解读不一样的地方。

Encoder4Editing 编码器的实现

我拿他的代码和 pSp 的 diff 了一下,可以看到网络架构基本上是一样的,区别主要在于 forward 函数。

# Infer main W and duplicate it
w0 = self.styles[0](c3)
w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
stage = self.progressive_stage.value
features = c3
for i in range(1, min(stage + 1, self.style_count)):  # Infer additional deltas
    if i == self.coarse_ind:
        p2 = _upsample_add(c3, self.latlayer1(c2))  # FPN's middle features
        features = p2
    elif i == self.middle_ind:
        p1 = _upsample_add(p2, self.latlayer2(c1))  # FPN's fine features
        features = p1
    delta_i = self.styles[i](features)
    w[:, i] += delta_i
return w

可以看到,拿到 18 个 w 向量的逻辑和 pSp 的是一模一样的,区别仅在于后面 17 个 w 向量要加上第一个 w 向量。

改动的地方非常之少。

另外注意 for 循环的范围,是随着训练进度改变的。

具体而言,

# Progressive training
self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None,
                         help="The training steps of training new deltas. steps[i] starts the delta_i training")
self.parser.add_argument('--progressive_start', type=int, default=20000,
                         help="The training step to start training the deltas, overrides progressive_steps")
self.parser.add_argument('--progressive_step_every', type=int, default=2000,
                         help="Amount of training steps for each progressive step")

简单来说就是,从第 20k 次迭代开始,每过 2k 次迭代 stage 加一。

损失函数的实现

隐码的对抗损失

首先看其辨别器的实现:

class LatentCodesDiscriminator(nn.Module):
    def __init__(self, style_dim, n_mlp):
        super().__init__()

        self.style_dim = style_dim

        layers = []
        for i in range(n_mlp-1):
            layers.append(
                nn.Linear(style_dim, style_dim)
            )
            layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Linear(512, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, w):
        return self.mlp(w)

就是简单地堆了几层全连接层。

具体的辨别损失:

for i in dims_to_discriminate:
    w = latent[:, i, :]
    fake_pred = self.discriminator(w)
    loss_disc += F.softplus(-fake_pred).mean()
loss_disc /= len(dims_to_discriminate)

Delta Regularization Loss

delta = latent[:, curr_dim, :] - first_w
delta_loss = torch.norm(delta, self.opts.delta_norm, dim=1).mean()

其中 self.opts.delta_norm 默认值为 2,即 L2 范数。

未经允许,禁止转载,本文源站链接:https://iamazing.cn/