训练生成敌对网络的陷阱和技巧
本文针对的是刚开始使用GAN的深度学习爱好者。除非你非常幸运,否则第一次自行训练GAN可能是一个令人沮丧的过程,可能需要数小时才能正确完成。当然,随着时间的推移和经验,你会很好地训练GAN,但对于初学者来说,有几件事情可能会出错,而且你甚至可能无法确定哪里可以开始调试。我希望在第一次从头开始对GAN进行训练时,分享我的观察和经验教训,希望能够帮助开始几个小时的调试时间。
生成敌对网络
除非你在过去的一年中一直生活在一间小屋里,否则深度学习的每个人 - 甚至一些不参与深度学习的人 - 都听过并谈论过GAN。GANs或Generative Adversarial Networks是深度神经网络,它们是数据的生成模型。这意味着,给定一组训练数据,GAN可以学习估计数据的潜在概率分布。这是非常有用的,因为除了其他事情之外,我们现在可以从原始训练集中可能不存在的学习概率分布生成样本。
生成的对抗网络实际上是两个相互竞争的深层网络。给定一个训练集X(比如说几千个猫的图像),生成器网络G(X)作为输入一个随机向量,并试图产生与训练集相似的图像。一个鉴别器网络D(X)是一个二进制分类器,它试图根据训练集X和生成器生成的假猫图像来区分真实的猫图像。因此,生成器网络的工作是学习X中的数据的分布,这样它就能生成真实的猫图像,并确保鉴别器不能区分cat图像和来自生成器的cat图像。鉴别器需要不断学习与发生器持续不断地尝试新技巧,以产生伪造的cat图像并欺骗鉴别器。
最终,如果一切顺利,生成器(或多或少)了解了训练数据的真实分布,并非常擅长生成真实的cat图像。鉴别器不能区分训练集猫图像和生成的猫图像。
从这个意义上说,这两个网络一直在努力确保对方在他们的任务中做得不好。那么,这到底是怎么回事呢?
另一种查看GAN设置的方法是,鉴别器通过告诉生成器真实的cat图像是什么样子来引导生成器。最终,Generator 找到了它并开始生成真实的猫图像。训练GANs的方法类似于博弈论中的Minimax算法,这两个网络试图实现所谓的纳什均衡。
GAN训练面临的挑战
回到实际训练GANs。从一些简单的事情开始,我使用带有Tensorflow后端的Keras在MNIST数据集上训练了一个GAN(准确地说是DC-GAN)。这并不难,经过对Generator和Discriminator网络的一些小调整后,GAN能够生成MNIST数字的清晰图像。
黑白数字只是非常有趣。对象和人的彩色图像是所有酷玩家玩的。这是事情开始变得棘手的地方。MNIST之后,明显的下一步是生成CIFAR-10图像。经过几天和几天的调整超参数,改变网络架构,添加和删除图层,我终于能够生成类似于CIFAR-10的像样的图像。
我从一个相当深的网络开始(但是,大部分是不良的)网络,最终实现了一个更简单的网络。当我开始调整网络和训练过程时,15个迭代之后生成的图像看起来像这样,
对此,
最终这样做:
下面是我发现的一些错误,以及我一路上学到的东西。所以,如果你是GAN的新手,并且在训练中看不到很多成功的话,那么看看以下几个方面可能会有所帮助:
1.大内核和更多的过滤器
较大的内核覆盖上一层图像中的更多像素,因此可以查看更多信息。5x5内核在CIFAR-10中运行良好,在鉴别器中使用3x3内核导致鉴别器损失快速接近0.对于发生器,您需要在顶部卷积层上使用更大的内核来保持某种平滑度。在较低层,我没有看到改变内核大小的主要影响。
过滤器的数量可以大量增加参数的数量,但通常需要更多的过滤器。我几乎在所有的卷积层都使用了128个过滤器。使用较少的滤镜,特别是在发生器中,会使最终生成的图像过于模糊。因此,看起来像更多的过滤器有助于捕获额外的信息,最终可以为生成的图像增加清晰度。
2.Flip labels (Generated = True,Real = False)
虽然起初看起来很愚蠢,但对我而言,一个主要的窍门是改变标签分配。
如果您使用Real Images= 1和Generated Images= 0,则反之亦然。正如我们后面会看到的,这有助于在早期迭代中的梯度流动,并有助于实现事物的移动。
3.Soft and Noisy labels
训练鉴别器时,这非常重要。具有硬标签(1或0)几乎会在早期将所有学习都杀死,导致鉴别器非常迅速地接近0损失。我最终使用0到0.1之间的随机数来表示0个标签(实际图像)和一个介于0.9和1.0之间的随机数来表示1个标签(生成的图像)。训练generator时,这不是必需的。
另外,它有助于为训练标签添加一些噪音。对于供给鉴别器的图像的5%,标签被随机翻转,即真实被标记为生成并生成标记为真实。
4.批量规范有所帮助,但前提是你有其他的东西
批量标准化绝对有助于最终结果。添加批量规范导致明显更清晰的生成图像。但是,如果您错误地设置了内核或过滤器,或者鉴别器丢失迅速达到0,则添加批量标准可能无法真正帮助恢复。
5.One class at a time
为了使GAN更容易训练,确保输入数据具有相似的特征是有用的。例如,与其在所有10个类的cifar 10上训练一个GAN,不如选择一个类(比如汽车或青蛙),并训练一个GAN来生成该类的图像。还有其他DC-GAN变体可以更好地学习从多个类生成图像。例如,条件GAN将类标签作为输入并生成以类标签为条件的图像。但是,如果您刚开始使用简单的DC-GAN,则最好简单一些。
6.看看梯度
如果可能的话,尝试监控梯度以及网络的损失。这些可以帮助我们了解训练的进程如果情况不顺利的话甚至可以帮助进行调试。
理想地,generator应该在训练早期接收大的梯度,因为它需要学习怎样产生真实数据。另一方面,鉴别器在早期并不总是得到很大的梯度,因为它可以很容易地辨别真伪图像。一旦generator得到足够的训练,鉴别器就很难分辨真假图像。它会不断出错,并得到强梯度。
我在CIFAR-10汽车上的GAN的第一个版本,有许多卷积和批量标准层,没有标签翻转。除趋势之外,监测梯度的规模也很重要。如果发生器层的梯度太小,学习可能会缓慢或根本不会发生。这在GAN的这个版本中是可见的。
发生器最底层的梯度范围太小,无法进行任何学习。Discriminator梯度也始终保持一致,暗示Discriminator没有真正学到任何东西。现在,让我们比较一下具有上述所有变化的GAN的梯度,并生成良好的真实图像:
到达发生器底层的梯度比例明显高于以前的版本。此外,梯度随着训练的进行而发生,随着generator早期获得大的梯度,一旦generator已经足够训练,判别器在顶层获得一致的高梯度。
7.没有提前停止
我犯的一个愚蠢的错误 - 可能是因为我的急躁 - 在几百个小批次之后,当我看到损失没有取得任何明显进展,或者如果生成的样本保持嘈杂时,就会停止训练。很容易重新开始工作并节省时间,而不是等待训练结束,并最终意识到网络从未学过任何东西。GAN需要很长时间才能进行训练,并且初期的损失值和产生的样本几乎从不显示任何趋势或进展迹象。在杀死训练过程并在设置中调整某些内容之前,等待一段时间非常重要。
此规则的一个例外是,如果您看到鉴别器损失快速接近0.如果发生这种情况,几乎没有恢复的机会,并且重新启动训练可能更好,可能是在更改网络或训练过程中的某些内容后。
结束工作的最终GAN看起来像这样: