教程|七个小贴士,顺利提升TensorFlow模型训练表现
选自deeplearningweekly
作者:Malte Baumann
机器之心编译
参与:吴攀、蒋思源、微胖
初来乍到的 TensorFlow 在 2015 年推出之后已经飞速成长为了 2016 年被用得最多的深度学习框架。我(Malte Baumann)在 TensorFlow 推出后几个月入了坑,在我努力完成硕士论文阶段开始了我的深度学习之旅。我用了一段时间才适应计算图(computation graph)和会话模型(session model)。
这篇短文并不是一篇 TensorFlow 介绍文章,而是介绍了一些有用的小提示,其中大多都是关于性能表现上的,这些提示揭示了一些常见的陷阱,能帮助你将你的模型和训练表现提升到新的层次。本文将从预处理和你的输入流程开始,然后介绍图构建,之后会谈到调试和性能优化。
预处理和输入流程:让预处理整洁精简
你是否疑惑过:为什么你的简单模型需要那么长的训练时间?一定是预处理做得不好!如果你的神经网络输入还需要做一些类似于转换数据这样的繁重的预处理工作,那么你的推理速度就会被大大减缓。我就遇到过这种情况,那是我正在创建了一个所谓的「距离地图(distance maps)」,其使用了一个自定义的 Python 函数而将「Deep Interactive Object Selection」所使用的灰度图像作为附加输入。我的训练速度顶多只达到了大约每秒 2.4 张图像——即使我已经换用了一个远远更加强大的 GTX 1080。然后我注意到了这个瓶颈,经过修复之后,我将训练速度提升到了每秒 50 张图像。
如果你也遇到了类似的瓶颈,通常你的第一直觉是应该优化一些代码。但减少训练流程的计算时间的一个更有效的方法是将预处理作为一个一次性的操作先完成——生成 TFRecord 文件。这样,你只需要执行一次繁重的预处理工作;有了预处理训练数据之后得到的 TFRecord 文件之后,你需要在训练阶段加载这些文件即可。即使你想引入某种形式的随机性来增强(augment)你的数据,你考虑的也应该是创建不同的变体,而不是用原始数据来填充你的训练流程。
注意你的队列
一种关注成本高昂的预处理流程的方法是使用 TensorBoard 中的队列图(queue graph)。如果你使用框架 QueueRunners,那么就会自动生成队列图并将总结存储在一个文件中。
该图能够显示你的机器是否能够保持队列充盈。如果你在该图中看到了负峰值,那就说明你的系统无法在你的机器想要处理一个批(batch)的时候生成新数据。这种情况的原因有很多,而根据我的经验,其中最常见的是 min_after_dequeue 的值太大。如果你的队列想要在内存中保存大量记录,那么这些记录会很快充满你的容量,这会导致 swapping 并显著减慢你的队列。其它原因还可能是硬件上的问题(比如磁盘速度太慢)或数据太大超过了系统处理能力。不管原因是什么,解决这些问题就能帮助你获得更高的训练速度。
图构建和训练:完成你的图
TensorFlow 独立的图构建和图计算模型在日常编程中十分少见,并且可能让初学者感到困惑。这种独立的构架可以应用于在第一次构建图时代码所出现的漏洞和错误信息,然后在实际评估时再一次运行,这个是和你过去代码只会评估一次这样的直觉相反的。
另一个问题是与训练回路(training loops)结合的图构建。因为这些回路一般是「标准」的 Python 循环,因此能改变图并向其添加新的操作。在连续不断地评估过程中改变图会造成重大的性能损失,但一开始却很难引起注意。幸好 TensorFlow 有一个简单的解决方案,仅仅在开始训练回路前调用 tf.getDefaultGraph().finalize() 完成你的图即可。这一段语句将会锁定你的图,并且任何想要添加新操作的尝试都将会报错。这正是我们想要达到的效果。
分析你的图的性能
TensorFlow 一个少有人了解的功能是性能分析(profiling)。这是一种记录图操作的运行时间和内存消耗的机制。如果你正在寻找系统瓶颈或想要弄清楚模型能不能在不 swapping 到硬件驱动的情况下训练,这种功能会十分有效。
为了生成性能分析数据,你需要在开启了 tracing 的情况下在你的图上执行一次运行。
之后,会有一个 timeline.json 文件会保存到当前文件夹,tracing 数据在 TensorBoard 中也就可用了。现在你很容易就看到每个操作用了多长时间和多少内存。仅仅只需要在 TensorBoard 中打开图视窗(graph view),并在左边选定你最后的运行,然后你就可以在右边看到性能的详细记录。一方面,你能根据这些记录调整你的模型,从而尽可能地利用你机器的运算资源。另一方面,这些记录可以帮助你在训练流程(pipeline)中找到瓶颈。如果你比较喜欢时间轴视窗,可以在谷歌 Chrome 的事件追踪性能分析工具(Trace Event Profiling Tool,https://www.chromium.org/developers/how-tos/trace-event-profiling-tool)中加载 timeline.json 文件。
另一个很赞的工具就是 tfprof(http://dwz.cn/5lRNeQ),将同样的功能用于内存和执行时间分析,但能提供更加便利的功能。额外的统计需要变换代码。
注意你的内存
就像前部分解释的,分析(profiling)能够让你跟踪特定运行的内存使用情况,不过,注意整个模型的内存消耗更重要些。需要始终确定没有超出机器内存,因为 swapping 肯定会让输入流程放慢,会让你的 GPU 开始坐等新数据。简单地 top,就像前文讲到的 TensorBoard 队列图就应当足够侦测到这样的行为。然后使用前文提过的 tracing,进行细节调查。
调试:print 会帮到你
我主要用 tf.Print 来调试诸如停滞损失(stagnating loss)或奇怪的输出等问题。由于神经网络天性的缘故,观察模型内部张量原始值(raw value)通常并没多大意义。没人能够解释清楚数以百万的浮点数并搞清楚哪儿有问题。不过,专门 print 出数据形状(shape)或均值就能发现重要见解。如果你正在试着实现一些既有模型,你就能比较自己模型值和论文或文章中的模型值,这有利于解决棘手问题或发现论文中的书写错误。
有了 TensorFlow 1.0,我们也有了新的调试工具(http://suo.im/4FtjRy)——这个看起来似乎还蛮有前途的。虽然我还没有用过,不过呢,肯定会在接下来的时间里尝试一下啦。
设定一个运算执行的超时时间
你已经实现了你的模型,载入了会话,但却没动静?这经常是有空列队(empty queues)造成的,但是如果你并不清哪个队列才是罪魁祸首,一个很简单的解决办法:在创造会话时,设定运行超时时间——当运行超过你设定的时限,脚本就会崩溃。
使用栈进行追踪,找出让你头疼的问题,解决问题然后继续训练。