基于元学习的自动化神经网络通道剪枝网络

基于元学习的自动化神经网络通道剪枝网络

转载自:学术头条(ID:SciTouTiao)

作者:Liyang

本文6427字,建议阅读17分钟

本文介绍一篇在神经网络压缩领域的通道剪枝方面有所创新的论文。

基于元学习的自动化神经网络通道剪枝网络

MetaPruning:Meta Learning for Automatic Neural Network Channel Pruning

论文作者:Zechun Liu , Haoyuan Mu ,Xiangyu Zhang ,Zichao Guo ,Xin Yang ,Tim Kwang-Ting Cheng , Jian Sun(香港科技大学,清华大学,旷视科技等)

论文地址:https://arxiv.org/pdf/1903.10258.pdf

开源地址:https://github.com/liuzechun/MetaPruning

前言

本文将对ICCV2019会议论文《MetaPruning:Meta Learning for Automatic Neural Network Channel Pruning》进行解读,这篇论文在神经网络压缩领域的通道剪枝方面有所创新。作者基于元学习(meta learning)方法,首先训练了一个PruningNet网络,可为给定目标网络的任何裁剪结构生成权重系数,然后采用进化过程通过不同约束条件来搜索性能好的剪枝网络,并且由于权重由PruningNet生成,不需要在搜索过程中进行任何微调。与当前最新的剪枝算法比较,作者所采用的算法在MobileNet V1 / V2和ResNet上有非常好的表现。

相关工作

通道剪枝(Channel Pruning)

通道剪枝是一种有效的神经网络压缩/加速方法,在业界已广泛使用剪枝方法包括三个阶段:训练大型的超参数化网络,修剪不太重要的权重或通道,微调或重新训练修剪的网络。其中,第二阶段是关键,实现逐层的迭代修剪和快速微调或权重重构,以保持精度【Jose M Alvarez and Mathieu Salzmann. Learning the number of neurons in deep networks. In Advances in Neural Information Processing Systems, pages 2270-2278, 2016】。

在权重修剪时, Han等修剪单个权重以压缩模型尺寸【Song Han, Huizi Mao, and William J Dally. Deep compression:Compressing deep neural networks with pruning】。但是,权重修剪会导致非结构化的稀疏滤波器,通用硬件几乎无法加速这种结构。近年来,Hu和Li等聚焦于CNN中的通道修剪,而不是单个权重,该方法删除了整个权重过滤器【Hengyuan Hu, Rui Peng, Yu-Wing Tai, and Chi-Keung Tang. Network trimming:A data-driven neuron pruning approach towards efficient deep architectures. arXiv preprint arXiv:1607.03250, 2016】。

传统的通道剪枝方法主要依靠数据驱动的稀疏度约束【Zehao Huang and Naiyan Wang. Data-driven sparse structure selection for deep neural networks. In Proceedings of the European Conference on Computer Vision (ECCV), pages 304-320, 2018】或人为设计的策略【Yihui He, Xiangyu Zhang, and Jian Sun. Channel pruning for accelerating very deep neural networks. In Proceedings of the IEEE International Conference on Computer Vision, pages 1389-1397, 2017】。在大多数传统的通道剪枝中,需要人为设置每一层的压缩率,比较耗时,且容易陷入次优解。

AutoML(Automated Machine Learning)

AutoML方法基于反馈循环【Jiahui Yu, Linjie Yang, Ning Xu, Jianchao Yang, and Thomas Huang. Slimmable neural networks. arXiv preprint arXiv:1812.08928, 2018】或强化学习【Yihui He, Ji Lin, Zhijian Liu, Hanrui Wang, Li-Jia Li, and Song Han. Amc:Automl for model compression and acceleration on mobile devices. In Proceedings of the European Conference on Computer Vision (ECCV), pages 784-800, 2018】,以迭代模式自动修剪通道。与传统的通道剪枝方法相比,AutoML方法节省人力,有助于减轻通道剪枝中的超参数所需的手动工作,并且可以直接优化如硬件latency(推理时延)之类的指标。

元学习(Meta Learning)

元学习是指通过观察不同的机器学习方法如何执行各种学习任务来学习。权重预测是指神经网络的权重是由另一个神经网络预测的,而非直接学习【David Ha, Andrew Dai, and Quoc V Le. Hypernetworks. arXiv preprint arXiv:1609.09106, 2016】。元学习可用于零样本学习(few/zero-shot learning)【Sachin Ravi and Hugo Larochelle. Optimization as a model for few-shot learning. 2016】和转移学习【Yu-Xiong Wang and Martial Hebert. Learning to learn:Model regression networks for easy small sample learning. In European Conference on Computer Vision, pages 616-634. Springer, 2016】。

神经网络架构搜索(Neural Architecture Search)

神经网络架构搜索通过强化学习【Barret Zoph and Quoc V Le. Neural architecture search with reinforcement learning. arXiv preprint arXiv:1611.01578,2016】,遗传算法【Lingxi Xie and Alan Yuille. Genetic cnn. In Proceedings of the IEEE International Conference on Computer Vision, pages 1379-1388, 2017.】或基于梯度【Bichen Wu, Xiaoliang Dai, Peizhao Zhang, Yanghan Wang, Fei Sun, YimingWu, Yuandong Tian, Peter Vajda, Yangqing Jia, and Kurt Keutzer. Fbnet:Hardware-aware efficient convnet design via differentiable neural architecture search. arXiv preprint arXiv:1812.03443, 2018.】的方法找到最佳的网络结构和超参数。作者将训练PruningNet通过权重预测来做连续的通道修剪。

整体架构

如下图所示,MetaPruning有两个阶段:(1)训练一个PruningNet:在每次迭代过程中,随机地生成网络编码向量(即每层中的信道数),剪枝网络也相应地构建。PruningNet将上述网络编码向量作为输入,来生成剪枝网络的权重。(2)搜索最佳剪枝网络:通过改变网络编码向量来构造许多剪枝网络,并利用PruningNet预测的权重对验证集的优劣进行了评估,搜索时无需微调或重新训练。

基于元学习的自动化神经网络通道剪枝网络

MetaPruning具有如下优点

  • 与传统剪枝方法相比,MetaPruning无需人工调超参数,并可直接给定优化指标。
  • 与其他AutoML方法相比,MetaPruning能方便搜索到所需结构,而无需手动调整强化学习超参数。
  • 采用元学习修剪类似于ResNet结构的“快捷通道”(shortcuts)。

方法

通道剪枝问题描述

通道剪枝问题用公式表示为:

基于元学习的自动化神经网络通道剪枝网络

在此公式中,A是剪枝前的网络,需尝试找到剪枝网络的第1层到第L层的通道宽度(c_1, c_2,····, c_l),使训练后的权重有最小的损失,同时C满足相应的约束(FLOPs或者latency)。为此,作者构建一个PruningNet元网络,可通过验证集来快速评估所有可能的剪枝网络结构。 然后采用搜索方法(作者采用进化算法)来搜索最佳的剪枝网络。

PruningNet训练

以前的剪枝方法将通道修剪问题分解为子问题,即逐层修剪不重要的通道或添加稀疏正则化。考虑整个被修剪的网络结构来执行通道修剪任务,有利于找到用于剪枝的最佳解决方案,并且可以解决“shortcuts”通道的修剪问题。而且研究表明剪枝后的权重与剪枝后的网络结构相比并不重要【Zhuang Liu, Mingjie Sun, Tinghui Zhou, Gao Huang, and Trevor Darrell. Rethinking the value of network pruning. arXiv preprint arXiv:1810.05270, 2018】。从这个角度,作者认为可以直接预测最佳剪枝网络,而无需迭代确定权重过滤器。PruningNet是一个元网络,将一个网络编码向量(c_1, c_2,····, c_l)作为输入,输出剪枝网络的权重,如下:

基于元学习的自动化神经网络通道剪枝网络

如下图所示,PruningNet块由两个全连接层组成,在前向传递中,PruningNet将网络编码向量作为输入,并生成权重矩阵。同时,构建剪枝网络,并使每层中的输出通道宽度等于网络编码向量中的元素。修剪生成的权重矩阵以匹配修剪网络中输入和输出通道的数量。如给定一批输入图像,就可以使用上述生成的权重来计算修剪网络的损失。

在后向传递过程中,并不更新剪枝网络的权重,而是计算PruningNet权重的梯度。由于PruningNet中全连接层的输出与PrunedNet中先前卷积层的输出之间的reshape以及convolution也是可微的,因此可以通过链式规则计算PruningNet中权重的梯度。

基于元学习的自动化神经网络通道剪枝网络

在每次迭代中,作者将网络编码向量随机化。PruningNet通过将向量作为输入来生成权重。在每次迭代中再调整网络编码向量,PruningNet可学会为各种剪枝的网络生成不同的权重。

基于元学习的自动化神经网络通道剪枝网络

如上图(a)所示为与Pruned Network连接的PruningNet的网络结构。 PruningNet和Pruned Network通过网络编码向量的输入以及图像的小批量进行联合训练。(b)对PruningNet块生成的权重矩阵进行reshape和crop操作。

剪枝网络搜索

在PruningNet训练后,可以将输入网络编码到PruningNet中,生成相应的权重,并在验证集上进行评估,来获取每个可能剪枝网络的准确度。由于网络编码向量数量巨大,为了找到约束条件下较高准确度的剪枝网络,作者使用进化搜索,同时兼顾任何软约束或硬约束。

每个剪枝网络被编码成一个包含每层通道数量的向量,称为“剪枝网络的基因”(genes of pruned networks)。在严格约束条件下,首先随机挑选多个基因,通过评估获得相应剪枝网络的准确度。然后选择准确度top k个基因,以产生具有变异(mutation)和交叉(crossover)的新基因。变异通过随机改变基因中元素的比例来实现,交叉通过随机重组两个亲本基因中的基因以产生后代。经过多次迭代上述过程,可获得满足约束条件的基因,同时获得最高的准确度。

实验

MobileNets、ResNet网络的处理

1.MobileNet V1

MobileNet V1没有shortcut结构。为此,作者构造了等于MobileNet v1中卷积层数的PruningNet块,并且每个PruningNet块都由两个串联的全连接层组成。通过修剪原始权重矩阵的左上部分,来匹配输入和输出通道,在不同的迭代中,生成了不同通道宽度的编码向量。

2.MobileNet V2

每个阶段都从匹配两个阶段之间的维度bottleneck block开始。如含有shortcut结构,则将输入特征图与输出特征图相加,为了修剪这种结构,会生成两个网络编码向量,一个编码整个阶段的输出通道以shortcut中的通道,另一个编码每个block的中间通道。在PruningNet中,先将网络编码向量解码为每个block的输入,输出和中间通道压缩比,然后生成block中相应的权重矩阵。

3.ResNet

ResNet与MobileNet v2结构相似,只是中间层的卷积类型、下采样块和每个阶段的块数不同,因此,处理过程相同。

FLOPs约束下的剪枝

如下表所示为MetaPruning的效果。MobileNet V1上,在0.25x时,MetaPruning提升了6.6%的准确度。MobileNet V2上,当约束到43M FLOPs时,MetaPruning提升了3.7%的准确度。

基于元学习的自动化神经网络通道剪枝网络

基于元学习的自动化神经网络通道剪枝网络

在ResNet-50上,MetaPruning也明显优于统一基准和其他传统修剪方法。

基于元学习的自动化神经网络通道剪枝网络

作者也与AutoML剪枝方法做了对比,结果如下表所示。与AMC【Yihui He, Ji Lin, Zhijian Liu, Hanrui Wang, Li-Jia Li, and Song Han. Amc:Automl for model compression and acceleration on mobile devices. In Proceedings of the European Conference on Computer Vision (ECCV), pages 784-800, 2018】相比,MetaPruning在MobileNet V1、MobileNet V2分别有0.1%、0.4%的提升,甚至FLOPs更低。此外,MetaPruning摆脱了手动调整强化学习超参数的麻烦。

基于元学习的自动化神经网络通道剪枝网络

latency约束下的剪枝

首先估计出在目标设备上执行不同输入和输出通道宽度的卷积层的latency,构建一个look-up表。然后从这个look-up表中计算得到构建网络的latency。如下表所示,在MobileNet V1和MobileNet V2上, MetaPruning剪枝网络在具有相同latency下具有更高的准确度。

基于元学习的自动化神经网络通道剪枝网络

基于元学习的自动化神经网络通道剪枝网络

剪枝网络的可视化

作者发现了如下图所示现象:(1)当向下采样方式为stride=2的深度卷积,需用更多通道来补偿特征图分辨率的下降,而MetaPruning自动学会了在下采样时保存更多的通道。(2)MetaPruning方法学会在最后阶段中修剪较少的shortcut通道数,如145M网络在最后阶段保留的通道数与300M网络的相似,而前阶段修剪了更多通道数。

基于元学习的自动化神经网络通道剪枝网络

基于元学习的自动化神经网络通道剪枝网络

消融实验

元学习中的权重预测机制可以有效地使不同剪枝结构的权重去相关,从而为PruningNet获得更高的准确度。

基于元学习的自动化神经网络通道剪枝网络

总结

作者介绍了用于通道剪枝的MetaPruning,它具有以下优点:1)与统一剪枝基准以及其他的通道剪枝方法(包括传统的和最新的AutoML)相比,其具有更好的性能;2)可以针对不同的约束条件灵活优化,而无需引入额外的超参数;3)可以有效处理类似ResNet的结构;4)整个过程非常高效

— 完 —

关注清华-青岛数据科学研究院官方微信公众平台“THU数据派”及姊妹号“数据派THU”获取更多讲座福利及优质内容。

相关推荐