深度学习self-attention流程详解(qkv)

深度学习self-attention流程详解(qkv)

一.从InputEmbedding和PositionalEnocding说起
1.将原文的所有单词汇总统计频率,删除低频词汇(比如出现次数小于20次的统一
定义为’<UNK>’);此时总共选出了假设10000个单词,则用数字编号为0~9999,一一对应,定义该对应表为word2num;然后用xaviers方法生成随机矩阵Matrix :10000行N列(10000行是确定的,对应10000个单词,N列自定义,常用N= 512,但训练会非常耗资源,亲测128足够了),我们定义为矩阵matX
深度学习self-attention流程详解(qkv)
2.这样,我们针对InputEmbedding,每句话就是一个对应的矩阵,该矩阵指定长度,例如‘中国人有中国梦’,对应矩阵!(这里定义矩阵行数为10,100可以理解为结束符,不足的在后面补0)图片描述
3.PositionEncoding
这里的PositionEncoding主要是为了保留句子的位置信息。其矩阵shape和Inputembedding一样。对于矩阵matPosition的每一行,第0,2,4,6,...等偶数列上的值用sin()函数激 活,第1,3,5,。。。等奇数列的值用cos()函数激活,将此矩阵定义为mapX。
深度学习self-attention流程详解(qkv)
4.这里,将两个矩阵相加,得到matEnc=matP+matX。然后matEnc进入模型编码部分的循环,即Figure1中左边红色框内部分,每个循环单元又分为4个小部分:multi-head attention, add&norm, feedForward, add&norm;
二.Encoder
深度学习self-attention流程详解(qkv)

1.Multi-head attention
(1)由三个输入,分别为V,K,Q,此处V=K=Q=matEnc(后面会经过变化变的不一样)
(2)首先分别对V,K,Q三者分别进行线性变换,即将三者分别输入到三个单层神经网络层,激活函数选择relu,输出新的V,K,Q(三者shape都和原来shape相同,即经过线性变换时输出维度和输入维度相同);
(3)然后将Q在最后一维上进行切分为num_heads(假设为8,必须可以被matENC整除)段,然后对切分完的矩阵在axis=0维上进行concat链接起来;对V和K都进行和Q一样的操作;操作后的矩阵记为Q_,K_,V_;如图深度学习self-attention流程详解(qkv)
(4)之后将Q_,K_.T进行想乘和Scale,得到的output为[8.10,10],执行output = softmax(output),然后将更新后的output想乘V_,得到再次更新后的output矩阵[8,10,64],然后将得到的output在0维上切分为8段,在2维上合并为[10,512]原始shape样式。
2.add&norm
add实际上是为了避免梯度消失,也就是曾经的残差网络解决办法:output=output+Q;
norm是标准化矫正一次,在output对最后一维计算均值和方差,用output减去均值除以方差+spsilon得值更新为output,然后变量gamma*output+变量beta

3.feed forward
(1)对output进行两次卷积,第一次卷积荷11,数目为词对应向量的维度。第二次卷积也是11,数目为N。
(2)两次卷积后得到的output和matEnc 的shape相同,更新matEnc = output,进行上述循环,循环自定义次数,进入解码部分。
三.decoder
1.InputEmbedding和Positionembedding相同。
2.进入解码循环,这里的Masked multi-head attention: 和编码部分的multi-head attention类似,但是多了一 次masked,因为在解码部分,解码的时候时从左到右依次解码的,当解出第一个字的时候,第一个字只能与第一个字计算相关性,当解出第二个字的时候,只能计算出第二个字与第一个字和第二个字的相关性,。。。;所以需要linalg.LinearOperatorLowerTriangular进行一次mask。
深度学习self-attention流程详解(qkv)
3.在解码中,add&norm,Feed forward和编码相同,其中multi-head attention:同编码部分,但是Q和K,V不再相同,Q=outputs,K=V=matEnc。
4.多次更新
5.Linear: 将最新的outputs,输入到单层神经网络中,输出层维度为“译文”有效单词总数;更新outputs

备注:借鉴出处https://zhuanlan.zhihu.com/p/...

相关推荐