将手写的数字对齐到原始位置,自动编码器了解一下

点击上方关注,All in AI中国

作者——Barna Pásztor

自编码器广泛应用于神经网络的无监督学习以来,其最初的目的是寻找潜在的低维状态空间数据集,但它们也能够解决其他问题,如图像去噪、增强或着色(彩色化)。

在这篇文章中,我想通过训练的卷积自动编码器来分享我的一些想法,我训练编码器将MNIST数据集中随机旋转的手写数字对齐到原始位置。请注意,这篇文章的主要目的不是对自动编码器做介绍,而是对一个应用程序做出描述以及对不同解码架构做出比较。在接下来的部分中,我提出了一个案例,它强化了这样的共识:卷积层的上采样比反卷积层表现得更好。并表明卷积层和全连接层组合起来在简单的自动编码器和只有卷积层的卷积自动编码器上都具有优势。

如果您已经熟悉卷积自动编码器和上采样技术,请跳过下一节,如果不熟悉,我建议您阅读它和相关文章。

导言

自动编码器背后的主要思想是将输入减少到具有更小维度的潜在状态空间中,然后尝试从(现有的)表达中重构输入。第一部分是编码,第二部分是解码阶段。为了减少表示数据的变量,我们强迫模型学习如何只保留有意义的信息,这些信息中的输入是可重构的。它也可以被看作是一种压缩技术。

以下是我推荐的一个很好的介绍:

如果您也对实现感兴趣,那么下一篇文章是基于第一个链接编写的,但是它还包括TensorFlow中的详细实现:

这个实现的一个有趣部分是对上采样的使用。由于核的重叠,使用反转卷积层可能相当具有挑战性。大多数深度学习框架都包含反褶积层(有人称为转置卷积层),它只是一个反向卷积层。尽管这些层在重新输入时的确有直观上的意义,但是它们的缺点也不容忽视。为了克服这个问题,上采样和简单的卷积层应运而生。以下两篇文章都对这个问题做了详细的解释。

项目

正如我前面提到的,我使用了MNIST数据集,并以某种随机角度旋转了每幅图像。这些模型的任务是将它们重新调整到原来的状态。在实验中,我比较了以下几种架构;

  • 一个简单的带有三个隐藏层的自动编码器,我用它作为基准
  • 一种卷积式自动编码器,它仅由编码器中的卷积层和解码器中的转置卷积层组成
  • 另一个卷积模型在编码器部分使用卷积块和最大池,在解码器中使用卷积层进行上采样
  • 最后一个模型是卷积层和全连接层的组合

所有模型都是在同一数据集上训练的,具有相同的超参数,瓶颈层有256个变量。

基线模型

首先,让我们看看基线自动编码器,它有以下结构。

将手写的数字对齐到原始位置,自动编码器了解一下

在训练过程中,我使用了均方误差,在训练数据上,最好的模型达到0.0158,在验证数据上,最好的模型达到0.0208,而在测试数据上,就像预期的那样,达到了0.0214。下面的图片是测试数据集中的一个示例,该示例显示了模型如何设法重新对齐数字4。

将手写的数字对齐到原始位置,自动编码器了解一下

该模型实现了数字4在图像上的显示,并将其旋转回原来的位置。(然而下面的)预测仍然是数字4,但是边缘有点模糊,顶部的两条线之间的差距也几乎消失了,这大概是低维bottleneck的结果。总的来说,该模式能够完成任务并产生可接受的结果已经很好了。但是,它(对部分数字)的输入仍然有困难。以下图像对大部分模型而言都是一场挑战。

将手写的数字对齐到原始位置,自动编码器了解一下

中间的图像清楚地显示出数字4,然而,由于较长的水平线,旋转后的图像类似于数字5。这个模型很难识别出它,最终得到的是数字5而不是数字4。

转置卷积自动编码器

第二种是卷积自动编码器,它只包含卷积层和反卷积层。在编码器中,输入数据要通过12个卷积层,3×3的内核和从4增加到16的滤波器。由于没有填充卷积层,而且步长大小为1,因此瓶颈的大小为16x4x4,这意味着bottleneck中的变量与基线模型相匹配。解码器用转置卷积层来反映这个体系结构。只使用卷积层可能看起来不太寻常,但在这种情况下,我们的目标是比较技术,而不是取得出色的结果。

就度量而言,该体系结构不能接近基准模型。训练MSE损失为0.0412,验证损失为0.0409,测试损失为0.0407。我们要知道可训练参数越小,损失越大。而现在卷积层将基准中的参数从大约100万减少到只有23000个。

将手写的数字对齐到原始位置,自动编码器了解一下

在许多情况下,如上面的数字9所示,该模型能够解决这个问题,并能够预测出一个可识别的数字(该数字与期望的输出相匹配)。但总的来说,它的性能仍需改进。因为它生成的线条太过模糊,几乎无法识别。如下所示:

将手写的数字对齐到原始位置,自动编码器了解一下

这种效果可以归因于反卷积层,因为除了边缘之外的每个像素都是作为层叠滤波器的总和产生的。在上面的复杂示例中,这个模型的性能甚至更差,它不仅使输出变得模糊,而且还创建了一个与数字3相似的数字,而不是我们期望的数字4。

将手写的数字对齐到原始位置,自动编码器了解一下

上采样卷积自动编码器

我考虑的下一个架构是一个卷积自动编码器,它有卷积层、最大池和上采样层。

将手写的数字对齐到原始位置,自动编码器了解一下

具有上采样层的CAD体系结构

在训练方面,它的MSE值略大于基准模型;训练结果为0.0293,验证结果为0.0293,测试结果为0.0297。但与上一个模型类似,得分越低的模型反而越有优势。该模型的大小不到基准模型的三分之一。只有2.9万个可训练参数,而且它的性能仍然可以接受。此外,度量的窄扩展表明它在没有额外正则化的情况下推广得很好。

将手写的数字对齐到原始位置,自动编码器了解一下

在上面的图片中,我们可以看到,它仍然能够识别数字并重新对齐它。虽然边缘仍十分模糊,但比起以前的模型已经好很多了。在这种情况下,它甚至对一些不重要的部分进行了"取舍"。

将手写的数字对齐到原始位置,自动编码器了解一下

这个模型也无法识别前面讨论过的那个极具挑战性的示例。但是它没有以前的模型那么混乱。它的输出类似于数字5,这表明它在识别数字方面(而非重构方面)存在问题。

组合模型

在看到上采样提供了更好的结果和更精确的输出之后,我创建了另一个架构,它将上采样模型中的块和全连接层以下面的方式组合在一起。

将手写的数字对齐到原始位置,自动编码器了解一下

卷积层和池化层成功地取代了基准测试的第一个稠密层,并产生了迄今为止最好的模型。只有40万个可训练参数,明显少于基准测试时的100万个参数。训练数据损失为0.0151 MSE,验证数据损失为0.0174 MSE,测试数据损失为0.0173 MSE。生成的图也比之前的好。

将手写的数字对齐到原始位置,自动编码器了解一下

通过混合两种类型的层,模型能够生成更细的线条和较少模糊的图像,但重构仍然不完美,比如说下图。

将手写的数字对齐到原始位置,自动编码器了解一下

模型输出上存在的问题在这个例子中展现的更加明确。

将手写的数字对齐到原始位置,自动编码器了解一下

显然,模型认为它是数字5并试图将数字重新对齐。但正是这样的假设导致数字无法识别,模型难以处理。

评语

尽管我使用MSE作为性能评判指标,但需要强调的是,在某些情况下,高MSE并不意味着错误的输出。例如,下面的预测是由组合模型做出的。

将手写的数字对齐到原始位置,自动编码器了解一下

预测图像是一个很好绘制的0,这很符合输入图像的"期望",但是目标图像上的数字不是传统的书写方式,因此MSE预测的误差很高。自然,我们只能通过神经网络来了解这些信息。另外,我在之前就预计模型会混淆数字6和数字9,但输出的图像显示这种情况只是偶尔发生。例如,下面的图像显示了验证数据集中的上采样模型的一个输出,该模型已经被很好地识别,除了尾部模糊之外,它被重构为数字9。

将手写的数字对齐到原始位置,自动编码器了解一下

另一方面,我用来演示数据集中难以处理的样本问题似乎是一个更大的挑战,因为许多模型在这种情况下都难以将数字4旋转对齐。特别是,带有转置卷积层的模型在处理这个问题时生成了如下图所示的图像。

将手写的数字对齐到原始位置,自动编码器了解一下

但是,其他的模型表现也不算好。例如,基线模型犯了以下错误。

将手写的数字对齐到原始位置,自动编码器了解一下

令人惊讶的是,旋转数字4产生的图像会让很多人认为它是数字5。

结语

正如这个项目所展示的,仔细地选择模型不一定会提高它的性能,但是它有助于我们构建合适的模型,从而减少过度拟合的可能性。正如我们在训练度量中看到的那样。一般来说,使用一个简单的自动编码器似乎是一个恰当的选择,因为它以一种令人满意的方式解决了这个问题,但是在度量中可以观察到过度拟合的痕迹,而且它的尺寸比其他模型大得多。我们可以通过减少中间层的节点数量或者简单地忽略它们来简化这个模型,但是,如果我们仍然保持256个大bottleneck的话,那么可以实现的最小参数数量大约是40万个。另外,额外的正则化技术可能有助于泛化,但这似乎是不必要的,因为带有上采样的卷积自编码器能够在比网络小十倍以上的情况下获得几乎同样好的结果。这两种类型层的组合提供了具有合理规模和最佳性能的体系结构。

在这篇文章中,我只挑选了几幅图片来展示架构的性能。包含更多图像和更详细的评估可以在链接的Gizub存储库中找到,其中还包括Keras中的实现(https://github.com/pasztorb/Rotational_CAD)。

将手写的数字对齐到原始位置,自动编码器了解一下

相关推荐