剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

一、引言

我发现网络上铺天盖地的人工智能相关信息中的绝大多数都可以分为两类:一是向外行人士解释进展情况,二是向其他研究者解释进展。我还没找到什么好资源能让有技术背景但对不了解更前沿的进展的人可以自己充电。我想要成为这中间的桥梁——通过为前沿研究提供(相对)简单易懂的详细解释来实现。首先,让我们从《Overcoming Catastrophic Forgetting in Neural Networks》这篇论文开始吧。

二、动机

实现通用人工智能的关键步骤是获得连续学习的能力,也就是说,一个代理(agent)必须能在不遗忘旧任务的执行方法的同时习得如何执行新任务。然而,这种看似简单的特性在历史上却一直未能实现。McCloskey 和 Cohen(1989)首先注意到了这种能力的缺失——他们首先训练一个神经网络学会了给一个数字加 1,然后又训练该神经网络学会了给数字加 2,但之后该网络就不会给数字加 1 了。他们将这个问题称为「灾难性遗忘(catastrophic forgetting)」,因为神经网络往往是通过快速覆写来学习新任务,而这样就会失去执行之前的任务所必需的参数。

克服灾难性遗忘方面的进展一直收效甚微。之前曾有两篇论文《Policy Distillation》和《Actor-Mimic: Deep Multitask and Transfer Reinforcement Learning》通过在训练过程中提供所有任务的数据而在混合任务上实现了很好的表现。但是,如果一个接一个地引入这些任务,那么这种多任务学习范式就必须维持一个用于记录和重放训练数据的情景记忆系统(episodic memory system)才能获得良好的表现。这种方法被称为系统级巩固(system-level consolidation),该方法受限于对记忆系统的需求,且该系统在规模上必须和被存储的总记忆量相当;而随着任务的增长,这种记忆的量也会增长。

然而,你可能也直觉地想到了带着一个巨大的记忆库来进行连续学习是错误的——毕竟,人类除了能学会走路,也能学会说话,而不需要维持关于学习如何走路的记忆。哺乳动物的大脑是如何实现这一能力的?Yang, Pan and Gan (2009) 说明学习是通过突触后树突棘(postsynaptic dendritic spines)随时间而进行的形成和消除而实现的。树突棘是指「神经元树突上的突起,通常从突触的单个轴突接收输入」,如下所示:

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

具体而言,这些研究者检查了小鼠学习过针对特定新任务的移动策略之后的大脑。当这些小鼠学会了最优的移动方式之后,研究者观察到树突棘形成出现了显著增多。为了消除运动可能导致树突棘形成的额外解释,这些研究者还设置了一个进行运动的对照组——这个组没有观察到树突棘形成。然后这些研究者还注意到,尽管大多数新形成的树突棘后面都消失了,但仍有一小部分保留了下来。2015 年的另一项研究表明当特定的树突棘被擦除时,其对应的技能也会随之消失。

Kirkpatrick et. 在论文《Overcoming catastrophic forgetting in neural networks》中提到:「这些实验发现……说明在新大脑皮层中的连续学习依赖于特定于任务的突触巩固(synaptic consolidation),其中知识是通过使一部分突触更少塑性而获得持久编码的,所以能长时间保持稳定。」我们能使用类似的方法(根据每个神经元的重要性来改变单个神经元的可塑性)来克服灾难性遗忘吗?

这篇论文的剩余部分推导并演示了他们的初步答案:可以。

三、直觉

假设有两个任务 A 和 B,我们想要一个神经网络按顺序学习它们。当我们谈到一个学习任务的神经网络时,它实际上意味着让神经网络调整权重和偏置(统称参数/θ)以使得神经网络在该任务上实现更好的表现。之前的研究表明对于大型网络而言,许多不同的参数配置可以实现类似的表现。通常,这意味着网络被过参数化了,但是我们可以利用这一点:过参数化(overparameterization)能使得任务 B 的配置可能接近于任务 A 的配置。作者提供了有用的图形:

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

在图中,θ∗A 代表在 A 任务中表现最好的 θ 的配置,还存在多种参数配置可以接近这个表现,灰色表示这一配置的集合;在这里使用椭圆来表示是因为有些参数的调整权重比其他参数更大。如果神经网络随后被设置为学习任务 B 而对记住任务 A 没有任何兴趣(即遵循任务 B 的误差梯度),则该网络将在蓝色箭头的方向上移动其参数。B 的最优解也具有类似的误差椭圆体,上面由白色椭圆表示。

然而,我们还想记住任务 A。如果我们只是简单使参数固化,就会按绿色箭头发展,则处理任务 A 和 B 的性能都将变得糟糕。最好的办法是根据参数对任务的重要程度来选择其固化的程度;如果这样的话,神经网络参数的变化方向将遵循红色箭头,它将试图找到同时能够很好执行任务 A 和 B 的配置。作者称这种算法「弹性权重巩固(EWC/Elastic Weight Consolidation)」。这个名称来自于突触巩固(synaptic consolidation),结合「弹性的」锚定参数(对先前解决方案的约束限制参数是二次的,因此是弹性的)。

四、数学解释

在这里存在两个问题。第一,为什么锚定函数是二次的?第二,如何判定哪个参数是「重要的」?

在回答这两个问题之前,我们先要明白从概率的角度来理解神经网络的训练意味着什么。假设我们有一些数据 D,我们希望找到最具可能性的参数,它被表示为 p(θ|D)。我们可以是用贝叶斯规则来计算这个条件概率。

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

如果我们应用对数变换,则方程可以被重写为:

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

假设数据 D 由两个独立的(independent)部分构成,用于任务 A 的数据 DA 和用于任务 B 的数据 DB。这个逻辑适用于多于两个任务,但在这里不用详述。使用独立性(independence)的定义,我们可以重写这个方程:

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

看看(3)右边的中间三个项。它们看起来很熟悉吗?它们应该。这三个项是方程(2)的右边,但是 D 被 DA 代替了。简单来说,这三个项等价于给定任务 A 数据的网络参数的条件概率的对数。这样,我们得到了下面这个方程:

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

让我们先解释一下方程(4)。左侧仍然告诉我们如何计算整个数据集的 p(θ| D),但是当求解任务 A 时学习的所有信息都包含在条件概率 p(θ| DA)中。这个条件概率可以告诉我们哪些参数在解决任务 A 中很重要。

下一步是不明确的:「真实的后验概率是难以处理的,因此,根据 Mackay (19) 对拉普拉斯近似的研究,我们将该后验近似为一个高斯分布,其带有由参数θ∗A 给定的均值和一个由 Fisher 信息矩阵 F 的对角线给出的对角精度。」

让我们详细解释一下。首先,为什么真正的后验概率难以处理?论文并没有解释,答案是:贝叶斯规则告诉我们

p(θ|DA) 取决于 p(DA)=∫p(DA|θ′)p(θ′)dθ′,其中θ′是参数空间中的参数的可能配置。通常,该积分没有封闭形式的解,留下数值近似以作为替代。数值近似的时间复杂性相对于参数的数量呈指数级增长,因此对于具有数亿或更多参数的深度神经网络,数值近似是不实际的。

然后,Mackay 关于拉普拉斯近似的工作是什么,跟这里的研究有什么关系?我们使用θ*A 作为平均值,而非数值近似后验分布,将其建模为多变量正态分布。方差呢?我们将把每个变量的方差指定为方差的倒数的精度。为了计算精度,我们将使用 Fisher 信息矩阵 F。Fisher 信息是「一种测量可观察随机变量 X 携带的关于 X 所依赖的概率的未知参数θ的信息的量的方法。」在我们的例子中,我们感兴趣的是测量来自 DA 的每个数据所携带的关于θ的信息的量。Fisher 信息矩阵比数值近似计算更可行,这使得它成为一个有用的工具。

因此,我们可以为我们的网络在任务 A 上训练后在任务 B 上再定义一个新的损失函数。让 LB(θ)仅作为任务 B 的损失。如果我们用 i 索引我们的参数,并且选择标量λ来影响任务 A 对任务 B 的重要性,则在 EWC 中最小化的函数 L 是:

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

作者声称,EWC 具有相对于网络参数的数量和训练示例的数量是线性的运行时间。

五、实验和结果

1. 随机模式

EWC 的第一个测试是简单地看它能否比梯度下降(GD)更长时间地记住简单的模式。这些研究者训练了一个可将随机二元模式和二元结果关联起来的神经网络。如果该网络看到了一个之前见过的二元模式,那么就通过观察其信噪比是否超过了一个阈值来评价其是否已经「记住」了该模式。使用这种简单测试的原因是其具有一个分析解决方案。随着模式数量的增加,EWC 和 GD 的表现都接近了它们的完美答案。但是 EWC 能够记忆比 GD 远远更多的模式,如下图所示:

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

2. MNIST

这些研究者为 EWC 所进行的第二个测试是一个修改过的 MNIST 版本。其没有使用被给出的数据,而是生成了三个随机的排列(permutation),并将每个排列都应用到该数据集中的每张图像上。任务 A 是对被第一个排列转换过的 MNIST 图像中的数字进行分类,任务 B 是对被第二个排列转换过的图像中的数字进行分类,任务 C 类推。这些研究者构建了一个全连接的深度神经网络并在任务 A、B 和 C 上对该网络进行了训练,同时在任务 A(在 A 上的训练完成后)、B(在 B 上的训练完成后)和 C(在 C 上的训练完成后)上测试了该网络的表现。训练是分别使用随机梯度下降(SGD)、使用 L2 正则化的均匀参数刚度(uniform parameter-rigidity using L2 regularization)、EWC 独立完成的。下面是它们的结果:

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

如预期的一样,SGD 出现了灾难性遗忘;在任务 B 上训练后在任务 A 上的表现出现了快速衰退,在任务 C 上训练后更是进一步衰退。使参数更刚性能维持在第一个任务上的表现,但却不能学习后续的任务。而 EWC 能在成功学习新任务的同时记住如何执行之前的任务。随着任务数量的增加,EWC 也能维持相对较好的表现,相对地,带有 dropout 正则化的 SGD 的表现会持续下降,如下所示:

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

3. Atari 2600

DeepMind 曾在更早的一篇论文中表明:当一次训练和测试一个游戏时,深度 Q 网络(DQN)能够在多种 Atari 2600 游戏上实现超人类的表现。为了了解 EWC 可以如何在这种更具挑战性的强化学习环境中实现连续学习,这些研究者对该 DQN 代理进行了修改,使其能够使用 EWC。但是,他们还需要做出一项额外的修改:在哺乳动物的连续学习中,为了确定一个代理当前正在学习的任务,必需要一个高层面的系统,但该 DQN 代理完全不能做出这样的确定。为了解决这个问题,研究者为其添加了一个基于 forget-me-not(FMN)过程的在线聚类算法,使得该 DQN 代理能够为每一个推断任务维持各自独立的短期记忆缓存。

这就得到了一个能够跨两个时间尺度进行学习的 DQN 代理。在短期内,DQN 代理可以使用 SGD 等优化器(本案例中研究者使用了 RMSProp)来从经历重放(experience replay)机制中学习。在长期内,该 DQN 代理使用 EWC 来巩固其从各种任务上学习到的知识。研究者从 DQN 实现了人类水平表现的 Atari 游戏(共 19 个)中随机选出了 10 个,然后该代理在每个单独的游戏上训练了一段时间,如下所示:

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

这些研究者对三个不同的 DQN 的代理进行了比较。蓝色代理没有使用 EWC,红色代理使用了 EWC 和 forget-me-not 任务标识符。褐色代理则使用 EWC,并被提供了真实任务标签。在一个任务上实现人类水平的表现被规范化为分数 1。如你所见,EWC 在这 10 个任务上都实现了接近人类水平的表现,而非 EWC 代理没能在一个以上的任务上做到这一点。该代理是否被给出了一个真实标签或是否必须对任务进行推导对结果的影响不大,但我认为这也表明了 FMN 过程的成果,而不仅仅是 EWC 的成功。

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

接下来的这个部分非常酷。如前所述,EWC 并没有在这 10 个任务上实现人类水平的表现。为什么会这样?一个可能的原因是 Fisher 信息矩阵可能对参数重要程度的估计不佳。为了实际验证这一点,这些研究者研究了在仅一个游戏上训练的代理的权重。不管这个游戏是什么游戏,它们都表现出了以下模式:如果权重受到了一个均匀随机扰动的影响,随着该扰动的增加,该代理的表现(规范化为 1)会下降;而如果权重受到的扰动得到了 Fisher 信息的对角线的逆的影响,那么该分数在面临更大的扰动时也能保持稳定。这说明 Fisher 信息在确定参数的真正重要性方面是很好的。

然后,研究者尝试在 null 空间中进行扰动。这本来应该是无效的,但实际上研究者观察到了与逆 Fisher 空间中的结果类似的结果。这说明使用 Fisher 信息矩阵会导致将一些重要参数标记为不重要的情况——「因此很有可能当前实现的主要限制是其低估了参数的不确定性。」

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

六、讨论

1. 贝叶斯说明

作者们对 EWC 给出了非常好的贝叶斯解读:「正式来说,当需要学习新的任务时,网络参数被先验所调节,也就是先前任务在给定数据参数上的后验分布。这能实现先验任务被约束的更快的参数学习率,并为重要的参数降低学习率。」

2. 重叠

在刚开始时,我提到神经网络的过参数化(overparameterization)让 EWC 能实现优异的表现。有一个很合理的问题就是:为了每个不同的任务将网络划分到特定部分,这些神经网络能否给出更好的表现?或者通过共享表征,这些网络是否能高效地使用其能力?为了解答这个问题,作者们测量了任务对在 Fisher 信息矩阵上的重叠情况(Fisher Overlap)。对高度类似的任务(例如只有一点不同的两个随机排列)而言,Fisher Overlap 相当高。即使不相似的任务,Fisher Overlap 也高于 0。随着网络深度的增加,Fisher Overlap 也会增加。下图演示了该结果:

剖析DeepMind神经网络记忆研究:模拟动物大脑实现连续学习

3. 突触可塑性的理论

研究者还讨论了 EWC 可能能为神经可塑性方面的研究提供信息。级联(Cascade)理论企图构建突触状态的模型来对可塑性和稳定性建模。尽管 EWC 不能随时间缓和参数,也因此不能遗忘先前的信息,但 EWC 和级联都能通过让突触更不可塑而延展记忆稳定性。最近的一项研究提出除了存储自身的实际权重之外,突触也存储当前权重的不确定性。EWC 是该思路的延展:在 EWC 中,每个突触存储三个值:权重、均值和方差。

相关推荐