encoder4editing(e4e)代码实现解读
概述
官方代码实现:https://github.com/omertov/encoder4editing
我自己的修改版:
其代码基于 pSp,具体可见之前的博客,这里仅解读不一样的地方。
Encoder4Editing 编码器的实现
我拿他的代码和 pSp 的 diff 了一下,可以看到网络架构基本上是一样的,区别主要在于 forward 函数。
# Infer main W and duplicate it
w0 = self.styles[0](c3)
w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
stage = self.progressive_stage.value
features = c3
for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
if i == self.coarse_ind:
p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
features = p2
elif i == self.middle_ind:
p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
features = p1
delta_i = self.styles[i](features)
w[:, i] += delta_i
return w
可以看到,拿到 18 个 w 向量的逻辑和 pSp 的是一模一样的,区别仅在于后面 17 个 w 向量要加上第一个 w 向量。
改动的地方非常之少。
另外注意 for 循环的范围,是随着训练进度改变的。
具体而言,
# Progressive training
self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None,
help="The training steps of training new deltas. steps[i] starts the delta_i training")
self.parser.add_argument('--progressive_start', type=int, default=20000,
help="The training step to start training the deltas, overrides progressive_steps")
self.parser.add_argument('--progressive_step_every', type=int, default=2000,
help="Amount of training steps for each progressive step")
简单来说就是,从第 20k 次迭代开始,每过 2k 次迭代 stage 加一。
损失函数的实现
隐码的对抗损失
首先看其辨别器的实现:
class LatentCodesDiscriminator(nn.Module):
def __init__(self, style_dim, n_mlp):
super().__init__()
self.style_dim = style_dim
layers = []
for i in range(n_mlp-1):
layers.append(
nn.Linear(style_dim, style_dim)
)
layers.append(nn.LeakyReLU(0.2))
layers.append(nn.Linear(512, 1))
self.mlp = nn.Sequential(*layers)
def forward(self, w):
return self.mlp(w)
就是简单地堆了几层全连接层。
具体的辨别损失:
for i in dims_to_discriminate:
w = latent[:, i, :]
fake_pred = self.discriminator(w)
loss_disc += F.softplus(-fake_pred).mean()
loss_disc /= len(dims_to_discriminate)
Delta Regularization Loss
delta = latent[:, curr_dim, :] - first_w
delta_loss = torch.norm(delta, self.opts.delta_norm, dim=1).mean()
其中 self.opts.delta_norm
默认值为 2,即 L2 范数。
Links: encoder4editing-code-notes