生产级深度学习的开发经验分享:数据集的构建和提升是关键
深度学习的研究和生产之间存在较大差异,在学术研究中,人们一般更重视模型架构的设计,并使用较小规模的数据集。本文从生产层面强调了深度学习项目开发中需要更加重视数据集的构建,并以作者本人的亲身开发经验为例子,分享了几个简单实用的建议,涉及了数据集特性、迁移学习、指标以及可视化分析等层面。无论是对于研究者还是开发者,这些建议都有一定的参考价值。
本文还得到了 Andrej Karpathy 的转发:
作者简介:Pete Warden 是 Jetpac Inc 的 CTO,著有《The Public Data Handbook》和《The Big Data Glossary》两本 O'Reilly 出版的书,并参与建立了多个开源项目,例如 OpenHeatMap 和 Data Science Toolkit 等。
图片来源:Lisha Li
Andrej Karpathy 在 Train AI(https://www.figure-eight.com/train-ai/)进行演讲时展示了这张幻灯片,我非常喜欢它!它完美地展现了深度学习的研究与实际的生产之间的差异。学术论文大多仅仅使用公开数据中的一小部分作为数据集而关注创造和改进模型。然而据我所知,当人们开始在实际的应用中使用机器学习时,对于训练数据的担忧占据了他们的大部分时间。
有很多很好的理由可以用来解释为什么研究人员如此关注于模型的架构,但这也确实意味着,对那些专注于将机器学习应用于生产环境中的人员来说,他们可以获取到的相关资源是很少的。为了解决这个问题,我在会议上进行了关于「the unreasonable effectiveness of training data」的演讲,而在这篇博客中,我想进一步阐述为什么数据如此重要以及改进它的一些实用技巧。
作为我工作的一部分,我与很多研究人员还有产品团队之间进行了密切的合作。我看到当他们专注于模型构建这一角度时可以获得很好的效果,而这也让我笃信于改进数据的威力。将深度学习应用到大多数应用中的最大障碍是如何在现实世界中获得足够高的准确率,而据我所知,提高准确度的最快途径就是改进训练集。即使你在其他限制(如延迟或存储空间)上遇到了阻碍,在特定的模型上提高准确率也可以帮助你通过使用规模较小的架构来对这些性能指标做出权衡。
语音数据集
我无法将我对于生产性系统的大部分观察分享给大家,但我有一个开源的例子可以用来阐释相同的模式。去年,我为 TensorFlow 创建了一个简单的语音识别示例(https://www.tensorflow.org/tutorials/audio_recognition),结果表明在现有的数据集中,没有哪一个是可以很容易地被用作训练数据的。多亏了由 AIY 团队帮助我开发的开放式语音记录站点(https://aiyprojects.withgoogle.com/open_speech_recording)我才得以在很多志愿者的慷慨帮助下,收集到了 6 万个记录了人们说短单词的一秒钟音频片段。在这一数据训练下的模型虽然可以使用,但仍然没有达到我想要的准确度。为了了解我设计模型时可能存在的局限性,我用相同的数据集发起了一个 Kaggle 比赛(https://www.kaggle.com/c/tensorflow-speech-recognition-challenge/)。参赛者的表现比我的简单模型要好得多,但即使有很多不同的方法,多个团队的精确度最终都仅仅达到了 91%左右。对我而言,这意味着数据本身存在着根本性的问题,而实际上参赛者们也的确发现了很多问题,比如不正确的标签或被截断过的音频。这些都激励着我去解决他们发现的问题并且增加这个数据集的样本数量。
我查看了错误度量标准,以了解模型最常遇到的问题,结果发现「其他」类别(当语音被识别出来,但这些单词不在模型有限的词汇表内时)更容易发生错误。为了解决这个问题,我增加了我们捕获的不同单词的数量,以提供更加多样的训练数据。
由于 Kaggle 参赛者报告了标签错误,我通过众包的形式增加了一个额外的验证过程:要求人们倾听每个片段并确保其与预期标签相符。由于 Kaggle 竞赛中还发现了一些几乎无声或被截断的文件,我还编写了一个实用的程序来进行一些简单的音频分析(https://github.com/petewarden/extract_loudest_section),并自动清除特别糟糕的样本。最后,尽管删除了错误的文件,但由于更多志愿者和一些付费的众包服务人员的努力,我们最终获得了超过 10 万的发言样本。
为了帮助他人使用数据集(并从我的错误中吸取教训!)我将所有相关内容以及最新的结果写入了一篇 arXiv 论文(https://arxiv.org/abs/1804.03209)。其中最重要的结论是,在不改变模型或测试数据的情况下,(通过改进数据)我们可以将 top-1 准确率从 85.4% 提高到 89.7%。这是一个巨大的提升(超过了 4%),并且当人们在安卓或树莓派的样例程序中使用该模型时,获得了更好的效果。尽管目前我使用的远非最优的模型,但我确信如果我将这些时间花费在调整模型上,我将无法获得这样的性能提升。
在生产的配置过程中,我多次见证了上述这样的性能提升。当你想要做同样的事情的时候,可能很难知道应该从哪里开始。你可以从我处理语音数据的技巧中得到一些灵感,但在接下来的内容中,我将为你介绍一些我认为有用的具体的方法。
首先,观察你的数据
这看起来显而易见,但你首先最应该做的是随机浏览你将要使用的训练数据。将一些文件复制到本地计算机上,然后花几个小时来预览它们。如果您正在处理图片,使用类似于 MacOS 的取景器的功能滚动浏览缩略图,将可以让你快速地浏览数千个图片。对于音频,你可以使用取景器播放预览,或者将文本随机片段转储到终端。正因为我没有花费足够的时间来对第一版语音命令进行上述处理,Kaggle 参赛者们才会在开始处理数据时发现了很多问题。
我总是觉得这个过程有点愚蠢,但我从未后悔过。每当我完成这些工作时,我都可以发现一些对数据来说非常重要的事情,比如不同类别之间样本数量的失衡、数据乱码(例如扩展名标识为 JPG 的 PNG 文件)、错误的标签,或者仅仅是令人惊讶的组合。Tom White 在对 ImageNet 的检查中获得了许多惊人的发现,比如:标签「太阳镜」,实际上是指一种古老的用来放大阳光的设备。Andrej 对 ImageNet 进行手动分类的工作(http://karpathy.github.io/2014/09/02/what-i-learned-from-competing-against-a-convnet-on-imagenet/)同样教会了我很多与这个数据集相关的知识,包括如何分辨所有不同的犬种,甚至是人。
你将要采取的行动取决于你的发现,但是在你做任何其他数据清理工作之前,你都应该先进行这种检查,因为对数据集内容的直观了解有助于你在其余步骤中做出更好的决定。
快速地选择一个模型
不要在选择模型上花费太多时间。如果你正在进行图像分类任务,请查看 AutoML,或查看 TensorFlow 的模型存储库(https://github.com/tensorflow/models/)或 Fast.AI 收集的样例(http://www.fast.ai/)来找到你产品中面对的类似问题的模型。重要的是尽可能快地开始迭代,这样你就可以尽早且经常性地让实际用户来试用你的模型。你随时都可以上线改进的模型,并且可能会看到更好的结果,但你必须首先对数据进行合适的处理。深度学习仍然遵循「输入决定输出」的基本计算规律,所以即使是最好的模型也会受到训练集中数据缺陷的限制。通过选择模型并对其进行测试,你将能够理解这些缺陷从而开始改进数据。
为了进一步加快模型的迭代速度,你可以尝试从一个已经在大型现有数据集上预训练过的模型开始,使用迁移学习来利用你收集到的(可能小得多的)一组数据对它进行微调。这通常比仅在较小的数据集上进行训练的结果要好得多,而且速度更快,这样一来你就可以快速地了解到应该如何调整数据收集策略。最重要的是,你可以根据结果中的反馈调整数据收集(和处理)流程,以便适应你的学习策略,而不是仅仅在训练之前将数据收集作为单独的阶段进行。
在做到之前先假装做到(人工标注数据)
建立研究和生产模型最大的不同之处在于,研究通常在开始时就有了明确的问题定义,而实际应用的需求潜藏在用户的头脑中,并且只能随着时间的推移而逐渐获知。例如,对于 Jetpac,我们希望找到好的照片并展示在城市的自动旅行指南中。刚开始我们要求评分者给他们认为好的照片打上标签,但我们最终却得到了很多张笑脸,因为这就是他们对这个问题的理解。我们将这些内容放入产品的展示模型中,来测试用户的反应,结果发现这并没有给他们留下什么深刻的印象。为了解决这个问题,我们将问题修改为「这张照片是否让你想要前往它所展示的地方?」。这很大程度上提高了我们结果的质量,然而事实表明,来自东南亚的工作人员,更倾向于认为充满了在大型酒店中穿西装的人和酒杯的会议照片看起来令人惊叹。这种不匹配是对我们生活的泡沫的一个提醒,但它同时也是一个实际问题,因为我们产品的目标受众是美国人,他们看到会议照片会感到压抑和沮丧。最终,我们六个 Jetpac 团队的成员自己为超过 200 万张照片进行了评分,因为我们比任何可以被训练去做这件事的人都更清楚标准。
这是一个极端的例子,但它表明标注过程在很大程度上依赖于应用程序的要求。对于大多数生产用例来说,找出模型正确问题的正确答案需要花费很长的一段时间,而这对于正确地解决问题至关重要。如果你正在试图让模型回答错误的问题,那么将永远无法在这个不可靠的基础上建立可靠的用户体验。
图片来自 Thomas Hawk
我发现能够判断你所问的问题是否正确的唯一方法是对你的应用程序进行模拟,而不是使用有人参与迭代的机器学习模型。因为在背后有人类的参与,这种方法有时被称为「Wizard-of-Oz-ing」。在 Jetpac 的案例中,我们让人们为一些旅行指南样例手动选择照片,而不是训练一个通过测试用户的反馈来调整挑选图片的标准的模型。一旦我们可以很可靠地从测试中获得正面反馈,我们接下来就可以将我们设计的照片选择规则转化为标注指导手册,以便用这样的方法获得数百万个图像用作训练集。然后,我们使用这些数据训练出了能够预测数十亿张照片质量的模型,但它的 DNA 来自我们设计的原始的人工规则。
在真实数据上进行训练
在 Jetpac 案例中,我们用于训练模型的图像和我们希望应用模型的图像来源相同(主要是 Facebook 和 Instagram),但是我发现的一个常见问题是,训练数据集与模型最终输入数据的一些关键差异最终会体现在生产中。例如,我经常会看到基于 ImageNet 训练的模型在被尝试应用到无人机或机器人中时会遇到问题。这是因为 ImageNet 大多为人们拍摄的照片,而这些照片存在着很多共性,比如:用手机或照相机拍摄,使用中性镜头,大致在头部高度,在白天或人造光线下拍摄,标记的物体居中并位于前景中等等。而机器人和无人机使用视频摄像机,通常配有高视野镜头,拍摄位置要么是在地面上要么是在高空中,同时缺乏光照条件,并且由于没有对于物体轮廓的智能判定,通常只能进行裁剪。这些差异意味着,如果你只是在 ImageNet 上训练模型并将其部署到某一台设备上,那么将无法获得较好的准确率。
训练数据和最终模型输入数据的差异还可能体现在很多细微的地方。想象一下,你正在使用世界各地的动物数据集来训练一个识别野生动物的相机。如果你只打算将它部署在婆罗洲的丛林中,那么企鹅标签被选中的概率会特别低。如果训练数据中包含有南极的照片,那么模型将会有很大的机会将其他动物误认为是企鹅,因而模型整体的准确率会远比你不使用这部分训练数据时低。
有许多方法可以根据已知的先验知识(例如,在丛林环境中大幅度降低企鹅的概率)来校准结果,但使用能够反映产品真实场景的训练集会更加方便和有效。我发现最好的方法是始终使用从实际应用程序中直接捕获到的数据,这与我上面提到的「Wizard of Oz」方法之间存在很好的联系。这样一来,在训练过程中使用人来进行反馈的部分可以被数据的预先标注所替代,即使收集到的标签数量非常少,它们也可以反映真实的使用情况,并且也基本足够被用于进行迁移学习的一些初始实验了。
混淆矩阵
当我研究语音指令的例子时,我看到的最常见的报告之一是训练期间的混淆矩阵。这是一个显示在控制台中的例子:
[[258 0 0 0 0 0 0 0 0 0 0 0] [ 7 6 26 94 7 49 1 15 40 2 0 11] [ 10 1 107 80 13 22 0 13 10 1 0 4] [ 1 3 16 163 6 48 0 5 10 1 0 17] [ 15 1 17 114 55 13 0 9 22 5 0 9] [ 1 1 6 97 3 87 1 12 46 0 0 10] [ 8 6 86 84 13 24 1 9 9 1 0 6] [ 9 3 32 112 9 26 1 36 19 0 0 9] [ 8 2 12 94 9 52 0 6 72 0 0 2] [ 16 1 39 74 29 42 0 6 37 9 0 3] [ 15 6 17 71 50 37 0 6 32 2 1 9] [ 11 1 6 151 5 42 0 8 16 0 0 20]]
这可能看起来很吓人,但它实际上只是一个表格,显示网络出错的详细信息。这里有一个更加美观的带标签版本:
表中的每一行代表一组与真实标签相同的样本,每列显示标签预测结果的数量。例如,高亮显示的行表示所有无声的音频样本,如果你从左至右阅读,则可以发现标签预测的结果是正确的,因为每个标签都落在」Silence」一栏中。这表明,该模型可以很好地识无声的音频片段,不存在任何一个误判的情况。从列的角度来看,第一列显示有多少音频片段被预测为无声,我们可以看到一些实际上是单词的音频片段被误认为是无声的,这其中有很多误判。这些知识对我来说非常有用,因为它让我更加仔细地观察那些被误认为是无声的音频片段,而这些片段事实上并不总是安静的。这帮助我通过删除音量较低的音频片段来提高数据的质量,而如果没有混淆矩阵的线索,我将无从下手。
几乎所有对结果的总结都可能是有用的,但是我发现混淆矩阵是一个很好的折衷方案,它提供的信息比单个的准确率更多,同时也不会涵盖太多我无法处理的细节。在训练过程中观察数字变化也很有用,因为它可以告诉你模型正在努力学习什么类别,并可以让你在清理和扩充数据集时专注于某些方面。
可视化模型
可视化聚类是我最喜欢的用来理解我的网络如何解读训练数据的方式之一。TensorBoard 为这种探索提供了很好的支持,尽管它经常被用于查看词嵌入,但我发现它几乎适用于与任何嵌入有类似的工作方式的网络层。例如,图像分类网络在最后的全连接或 softmax 单元之前通常具有的倒数第二层,可以被用作嵌入(这就是简单的迁移学习示例的工作原理,如 TensorFlow for Poets(https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0))。严格意义上来说,这些并不是嵌入,因为我们并没有在训练过程中努力确保在真正的嵌入具有希望的空间属性,但对它们的向量进行聚类确实会产生一些有趣的结果。
举例来说,之前一个同我合作过的团队对图像分类模型中某些动物的高错误率感到困惑。他们使用聚类可视化来查看他们的训练数据是如何分布到各种类别的,当他们看到「捷豹」时,他们清楚地发现数据被分成两个彼此之间存在一定间隔的不同的组。
图片来源: djblock99 和 Dave Adams
他们所看到的的图表如上所示。一旦我们将每个聚类的图片展示出来,结果就变得显而易见:很多捷豹品牌车辆被错误地标记为捷豹(动物)。一旦他们知道了这一问题,他们就能够检查标注过程,并意识到工作人员的指导和用户界面是令人困惑的。有了这些信息,他们就能够改善标注者的训练过程并修复工具中存在的问题,从而将所有汽车图像从捷豹类别中删除,进而让模型在该类别上获得更好的效果。
通过深入了解训练集中的内容,聚类提供了与仅仅观察数据相同的好处,但网络实际上是通过根据自己的学习理解将输入分组来指导你的探索。作为人类,我们非常善于从视觉上发现异常情况,所以将我们的直觉和计算机处理大量输入的能力相结合为追踪数据集质量问题提供了一个高可扩展的解决方案。关于如何使用 TensorBoard 来完成这样的工作的完整教程超出了本文的范围,但如果你真的想要提高结果,我强烈建议你熟悉这个工具。
持续收集数据
我从来没有见过收集了更多的数据,但最终没有提高模型准确性的情况,事实证明,有很多研究都支持我的这一经验。
该图来自「重新审视数据的不合理的有效性(https://ai.googleblog.com/2017/07/revisiting-unreasonable-effectiveness.html)」,并展示了在训练集的规模增长到数以亿计的情况下图像分类的模型准确率如何持续增加。Facebook 最近进行了更进一步的探索,使用数十亿有标注 Instagram 图像在 ImageNet 图像分类任务上取得了最优的准确率(https://www.theverge.com/2018/5/2/17311808/facebook-instagram-ai-training-hashtag-images)。这表明,即使对于已有大型、高质量数据集的任务来说,增加训练集的大小仍然可以提高模型效果。
这意味着,只要任何用户可以从更高的模型准确率中受益,你就需要一个可以持续改进数据集的策略。如果可以的话,找到创造性的方法利用微弱的信号来获取更大的数据集(是一个可以尝试的方向)。Facebook 使用 Instagram 标签就是一个很好的例子。另一种方法是提高标注过程的智能化程度,例如通过将模型的初始版本的标签预测结果提供给标注人员,以便他们可以做出更快的决策。这种方法的风险是可能在标注早期造成某种程度的偏见,但在实践中,所获得的好处往往超过这种风险。此外,通过聘请更多的人来标记新的训练数据来解决这个问题,通常也是一项物有所值的投资,但是对这类支出没有预算传统的组织可能会遇到阻碍。如果你是非营利性组织,让你的支持者通过某种公共工具更方便地自愿提供数据,这可能是在不增加开支的情况下增加数据集大小的可取方法。
当然,对于任何组织来说,最优的解决方案都是应该有一种产品,它可以在使用时自然生成更多的有标注数据。虽然我不会太在意这个想法,它在很多真实的场景中都不适用,因为人们只是想尽快得到答案,而不希望参与到复杂的标注过程中来。而对于初创公司来说,这是一个很好的投资热点,因为它就像是一个改进模型的永动机,当然,在清理或增强数据时总是无法避免产生一些单位成本,所以经济学家最终经常会选择一种比真正免费的方案看起来更加便宜一点的版本。
潜在的风险
几乎所有的模型错误对应用程序用户造成的影响都远大于损失函数可以捕获的影响。你应该提前考虑可能的最糟糕的结果,并尝试设计模型的后盾以避免它们发生。这可能只是一个因为误报的成本太高而不想让模型去预测的类别的黑名单,或者你可能有一套简单的算法规则,以确保所采取的行动不会超过某些已经设定好的边界参数。例如,你可能会维持一个你不希望文本生成器输出的脏话词表,即便它们存在于训练集中。因为它们对于你的产品来说是很不合适的。
究竟会发生什么样的不好结果在事前总是不那么明显的,所以从现实世界中的错误中吸取教训至关重要。最简单的方法之一就是在一旦你有一个半成品的时候使用错误报告。当人们使用你的应用程序时,你需要让他们可以很容易地报告不满意的结果。要尽可能获得模型的完整输入,但当它们是敏感数据时,仅仅知道不良输出是什么同样有助于指导你的调查。这些类别可被用于选择收集更多数据的来源,以及你应该去了解其哪些类别的当前标签质量。一旦对模型进行了新的调整,除了正常的测试集之外,还应该对之前产生不良结果的输入进行单独的测试。考虑到单个指标永远无法完全捕捉到人们关心的所有内容,这个错例图片库有点像回归测试,并且为你提供了一种可以用来跟踪你改进用户体验程度的方式。通过查看一小部分在过去引发强烈反应的例子,你可以得到一些独立的证据来表明你实际上正在为你的用户提供更好的服务。如果因为过于敏感而无法获取模型的输入数据,请使用内部测试或内部实验来确定哪些输入可以产生这些错误,然后替换回归数据集中的那些输入。
在这篇文章中,我希望设法说服你在数据上花费更多时间,并给你提供一些关于如何改进它的想法。目前这个领域还没有得到足够的关注,我甚至觉得我在这里的建议是在抛砖引玉,所以我感谢每一个与我分享他们的策略的人,并且我希望未来我可以从更多的人那里了解到更多有成效的方法。我认为会有越来越多的组织分配工程师团队专门用于数据集的改进,而不是让机器学习研究人员来推动这一方向的进展。我期待看到整个领域能够得益于在数据改进上的工作。我总是为即使在训练数据存在严重缺陷的情况下模型也可以良好运作而感到惊叹,所以我迫不及待地想看到在改进数据以后我们可以取得的效果!