ESRGAN-(ECCV2018)基于GAN的增强版超分辨率算法 代码笔记

AI基础  收藏
0 / 476

这是一个应用上最近被使用的很多的算法,大名鼎鼎的老北京复原视频也用到了它。自己动手玩的时候,发现一些有趣的点,譬如两个网络插值得到第三个网络,譬如判断图像相对真实性等等。大部分基于GAN的SR算法 都会存在奇奇怪怪的伪影,在动漫类型的图中就表现特别明显。就连cvpr2020中的一些也不能免俗。但这篇文章 ---它不太一样。

论文原文翻译:

[论文翻译] https://aiqianji.com/blog/article/1

论文思想

基于2017年SRGAN算法改进而来,相比于SRGAN它在三个方面进行了改进: 网络架构,对抗性损失和感知损失
1)在没有使用BN的情况下引入RRDB Residual-in-Residual Dense Block作为基本网络构建单元。 并使用残差缩放和更小的初始化来促进训练非常深的网络。

  1. 借用RaGAN的思想来让判别器预测图像的相对真实性而不是图像的绝对真实性。它学会判断一个图像比另一个图像更真实,而不是“一个图像是真实的还是假的”。
  2. 使用激活前的VGG特征来改善感知损失,而不是像SRGAN中激活后使用VGG特征,这可以提供对亮度一致性和纹理恢复更强的监督力。

关键点详解

网络结果 RRDB,对residual blocks的改进:

具体代码参见:RRDBNet_arch.py 主要为生成器G的结构做了两个改进:
1)去除掉所有的BN层。
2)提出用残差密集块(RRDB)代替原始基础块,其结合了多层残差网络和密集连接,如图4所示。

左:我们去除SRGAN中残余块中的BN层。 右:在我们的深层模型中使用RRDB块,β是残差缩放参数。

当训练和测试数据集的统计数据差异很大时,BN层往往引入不适的伪影,限制了泛化能力。我们以经验观察到,BN层有可能当网络深和在GAN网络下训练时带来伪影。这些伪影偶尔出现在迭代和不同设置之间,违反了稳定性能超过训练的需求。因此,我们为了训练稳定和一致性去除了BN层。此外,去除BN层有助于提高泛化能力,减少计算复杂度和内存使用。

作者受denseNet的启发,设计了RRDB,首先看RRDB中的基本模块,由5个卷积层构建的densenet结构:

class ResidualDenseBlock_5C(nn.Module): def __init__(self, nf=64, gc=32, bias=True): super(ResidualDenseBlock_5C, self).__init__() # gc: growth channel, i.e. intermediate channels self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) # initialization # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) def forward(self, x): x1 = self.lrelu(self.conv1(x)) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 * 0.2 + x

然后三个Denseblock 构建了一个RRDB结构: class RRDB(nn.Module): '''Residual in Residual Dense Block'''

def __init__(self, nf, gc=32): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C(nf, gc) self.RDB2 = ResidualDenseBlock_5C(nf, gc) self.RDB3 = ResidualDenseBlock_5C(nf, gc) def forward(self, x): out = self.RDB1(x) out = self.RDB2(out) out = self.RDB3(out) return out * 0.2 + x

然后根据RRDB结构构建对应的网络:
默认构建一个23个RRDB块的深层模型(nb=23)
model = arch.RRDBNet(3, 3, 64, 23, gc=32)

class RRDBNet(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, gc=32): super(RRDBNet, self).__init__() RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.RRDB_trunk = make_layer(RRDB_block_f, nb) self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) #### upsampling self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.conv_first(x) trunk = self.trunk_conv(self.RRDB_trunk(fea)) fea = fea + trunk fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) out = self.conv_last(self.lrelu(self.HRconv(fea))) return out

对损失函数的改进

除了改进生成器的结构,作者还基于RaGAN增强了判别器。代码参见 BasicSR项目 esrgan_model.py srgan_model.py

不同于SRGAN标准的判别器D,估算一个输入图像x真实和自然的可能性,如图所示,相对判别器试图预测真实图像x_{r}比假图像x_{f}更真实的概率。

标准鉴别器和相对判别器之间的区别

具体而言,作者把标准的判别器换成Relativistic average Discriminator(RaD)

标准判别器在SRGAN可以表示为

D\left ( x \right )=\sigma \left ( C\left ( x \right ) \right )

其中\sigma是sigmoid函数,C\left ( x \right )是非变换判别器输出。

RaD可以用公式表示为

D_{Ra}(x_{r},x_{f})=\sigma(C(x_{r})-\mathbb{E}{x{f}}[C(x_{f})]) D_{Ra}(x_{r},x_{f})=\sigma(C(x_{r})-\mathbb{E}{x{f}}[C(x_{f})])

其中 IE_{x_{f}}IE_{x_{f}} $表示在mini批处理中对所有假数据取平均值的操作。

生成器的对抗损失包含了 x_{r}x_{r} 和 x_{f},x_{f}, 所以生成器受益于对抗训练中的生成数据和实际数据的梯度,这种调整会使得网络学习到更尖锐的边缘和更细节的纹理。

我们可以看SRGAN的loss:

l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)

而ESRGAN的loss:

l_d_real = self.cri_gan( real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred), False, is_disc=True)

对感知损失的改进

SRGAN是用一个训练好的VGG16来给出超分辨率复原所需要的特征,与常规相反,我们提出在激活层之前使用特性,这将克服原始设计的两个缺点。

第一,被激活的特征是非常稀疏的,特别是在非常深的网络之后,如图6所示。

图6:激活图像'狒狒'之前和之后的代表性特征图。

随着网络的深入,激活后的大多数功能变为非活动状态,而激活前的功能包含更多信息。例如,图像“baboon”激活神经元的平均百分率在VGG19-54层后仅为11.17%。稀疏激活提供弱的监督,从而导致性能较差。

第二,使用激活后的特征也会造成重建后的图像亮度与真实图像不一致。在激活之前使用特征可以导致重建图像的更准确的亮度。为了消除纹理和颜色的影响,我们使用高斯核过滤图像并绘制其灰度对应的直方图。图9显示了每个亮度值的分布。

图9:激活前和激活后的比较

网络插值

为了去掉在GAN-based方法中令人不愉快的噪声,同时保证好的感知质量,作者提出一个灵活有效的策略,即网络插值。训练一个PSNR-oriented网络 G_{PSNR}G_{PSNR} ,然后通过微调获得GAN-based网络 G_{GAN}G_{GAN} 。对这两个网络的所有相应参数进行插值,得到一个插值模型 G_{INTERP}G_{INTERP} ,其参数可以表示为:

\theta_{G}^{\text{INTERP}}=\alpha\ \theta_{G}^{\text{PSNR}}+(1-\alpha)\ \theta% {G}^{\text{GAN}} \theta_{G}^{\text{INTERP}}=\alpha\ \theta_{G}^{\text{PSNR}}+(1-\alpha)\ \theta% {G}^{\text{GAN}}

网络插值有两个优点。

第一,首先,插值模型能够在不引入伪影的情况下对任何可行的\alpha产生有意义的结果。

第二,我们可以在不重新训练模型的情况下,持续地平衡感知质量和感觉。

它的具体实现就是:参见net_interp.py

for k, v_PSNR in net_PSNR.items(): v_ESRGAN = net_ESRGAN[k] net_interp[k] = (1 - alpha) * v_PSNR + alpha * v_ESRGAN

实验效果

为了直观比较效果,选择了一张算法生成的动漫图像(算法生成的图片,具有一定瑕疵,对结果展示更加直观),而不是训练集中的照片。

原图是一张200*200的图,直接显示如图:

通过GAN-based算法输出的图:

通过网络插值模型输出的图:

可以看到 GAN-based算法 女孩眼睛部分效果很不好,而插值模型输出的图,眼睛部分 分外完美。