长短时记忆网络
从前的RNN容易产生梯度消失的问题。长短时记忆网络可以解决这一问题。
原始的RNN的隐藏层只有一个状态h,它对于短期的输入非常敏感。假设再增加一个状态c,让其保存长期的状态,如下图:
新增的状态c称为单元状态(cell state)。将上图按时间维度展开:
对于时刻t,LSTM的输入有三个:当前时刻网络的输入值xt,上一时刻LSTM的输出值ht−1,上一时刻的单元状态ct−1。LSTM的输出有两个:当前时刻LSTM的输出值ht,当前时刻的单元状态ct。其中x,h,c都是向量。
LSTM的关键即为如何控制长期状态c。在这里,LSTM的思路是使⽤三个控制开关。第⼀个开关,负责控制继续保存⻓期状态c;第⼆个开关,负责控制将即时状态输⼊到⻓期状态c;第三个开关,负责控制是否把⻓期状态c作为当前的LSTM的输出。三个开关的作⽤如下图所示:
长短时记忆网络的前向计算
以上思路在算法中的实现需要门的概念。门实际上是一层全连接层,它的输入是一个向量,输出是一个0到1之间的实数向量。假设W是门的权重向量,b是偏置项,则门的表示为:
g(x)=σ(Wx+b)
在LSTM中,门是一种用于控制信息流动的机制,它通过一个全连接层加上Sigmoid函数实现,输出是一个取值在0到1之间的向量。门的作用是将这个输出与需要控制的信息向量按元素相乘,从而决定信息的保留程度:如果门输出为0,信息被完全屏蔽;如果输出为1,信息被完全保留;如果介于0和1,则表示信息部分通过。因此,门的状态可以看作是“半开半闭”的。
LSTM用两种门来控制单元状态c的内容,一个是遗忘门,它决定了上一时刻的单元状态ct−1有多少保留到当前时刻ct;另一个是输入门,它决定了当前时刻网络的输入xt有多少保存到单元状态ct。
LSTM用输出门来控制单元状态ct有多少输出到LSTM的当前输出值ht。
遗忘门
ft=σ(Wf⋅[ht−1,xt]+bf)
其中, Wf是遗忘门的权重矩阵,[ht−1,xt]表示将两个向量相连,bf是遗忘门的偏置项,σ是Sigmoid函数。如果输入维度是dx,隐藏层维度是dh,单元状态维度是dc(通常dc=dh),则遗忘门的权重矩阵Wf的维度是dc×(dh+dx)。事实上,权重矩阵Wf是由两个矩阵拼接而成的:一个是wfh,对应输入项ht−1,其维度是dc×dh;另一个是Wfx,对应输入项xt,其维度是dc×dx。即Wf=Wfhht−1+Wfxxt.
下面是神经网络与反向传播算法一章中Sigmoid函数的回顾:
Sigmoid函数定义如下:
Sigmoid(x)=1+e−x1
那么对于输出y:
y=1+e−ωT⋅x1
令y=Sigmoid(x),则y′=y(1−y)
Sigmoid函数的输出值就是在0和1之间的,起到控制信息保留程度的作用。
下图展示了遗忘门的计算。
输入门
it=σ(Wi⋅[ht−1,xt]+bi)
上式中,Wi是输入门的权重矩阵,bi是输入门的偏置项。
下图表示了输入门的计算。
下面计算用于描述当前输入的单元状态ct′, 它是根据上一次的输出和本次输入来计算的:
ct′=tanh(Wc⋅[ht−1,xt]+bc)
下图展示了ct′的计算。
现在计算当前时刻的单元状态ct。它是由上一次的单元状态ct−1按元素乘以遗忘门ft,再用当前输入的单元状态ct′按元素乘以输入门it,再将两个积加和产生的:
ct=ft∘ct−1+it∘ct′
下图展示了ct的计算。
下面是四个变量的对比表格:
符号 |
名称 |
含义 |
作用 |
计算公式 |
ft |
遗忘门(forget gate) |
控制旧记忆 ct−1 保留多少 |
保留多少旧信息 |
σ(Wf⋅[ht−1,xt]+bf) |
it |
输入门(input gate) |
控制当前输入写入记忆的比例 |
控制新信息写入量 |
σ(Wi⋅[ht−1,xt]+bi) |
ct′ |
候选单元状态(candidate cell state) |
由当前输入和上一个状态生成的新记忆提案 |
候选要写进记忆的值 |
tanh(Wc⋅[ht−1,xt]+bc) |
ct |
当前单元状态(cell state) |
当前时刻 LSTM 的最终记忆 |
真实记忆内容(保留旧的+写入新的) |
ft∘ct−1+it∘ct′ |
这样就把LSTM关于当前的记忆ct′和长期的记忆ct−1组合在一起,形成了新的单元状态ct。遗忘门的控制可以保留很久以前的信息,输入门的控制可以避免当前无关紧要的内容进入记忆。
输出门
ot=σ(Wo⋅[ht−1,xt]+bo)
下图展示了输出门的计算。
LSTM的最终输出是由输出门和单元状态共同决定的:
ht=ot∘tanh(ct)
下图展示了LSTM最终输出的计算。
至此为前向计算的所有步骤。
下面是两个变量的对比表格:ht 是每一步 LSTM 的输出,而 ct 是“控制”或“生成”这个输出的重要内部变量。
符号 |
名称 |
含义 |
特点 |
ct |
单元状态(cell state) |
长期记忆 |
可以跨很多时间步延续,信息“走得远” |
ht |
隐藏状态(hidden state) |
当前输出,也可理解为短期记忆 |
每个时刻都要输出(用于下游任务) |
长短时记忆网络的训练
LSTM训练算法框架
LSTM的训练算法仍然是反向传播算法。主要有三步:
- 前向计算每个神经元的输出值,即ft,it,ct,ot,ht五个向量的值
- 反向计算每个神经元的误差项δ。与普通RNN一样,LSTM误差项的反向传播也是包括两个方向:一个是沿时间的反向传播,即从当前时刻t开始,计算每个时刻的误差项;另一个是将误差项向上一层传播
- 根据相应的误差项,计算每个权重的梯度
公式和符号
在接下来的推导中,设定门的激活函数为Sigmoid函数,输出的激活函数为tanh函数。它们的导数分别为:
σ(z)σ′(z)tanh(z)tanh′(z)=y=1+e−z1=y(1−y)=y=ez+e−zez−e−z=1−y2
LSTM需要学习的参数共有8组,分别是遗忘门的权重矩阵Wf和偏置项bf,输入门的权重矩阵Wi和偏置项bi,输出门的权重矩阵Wo和偏置项bo,以及计算单元状态的权重矩阵Wc和偏置项bc。权重矩阵的两部分在反向传播中使用不同的公式,因此每个权重矩阵将以h,x角标区分。
在t时刻,LSTM的输出值为ht。定义t时刻的误差项δt为
δt=∂ht∂E
注意,和以往不同,这里假设误差项是损失函数对输出值的偏导数,而不是对加权输入nettl的导数
netf,tneti,tnetc′,tneto,tδf,tδi,tδc′,tδo,t=Wf[ht−1,xt]+bf=Wfhht−1+Wfxxt+bf=Wi[ht−1,xt]+bi=Wihht−1+Wixxt+bi=Wc[ht−1,xt]+bc=Wchht−1+Wcxxt+bc=Wo[ht−1,xt]+bo=Wohht−1+Woxxt+bo=∂netf,t∂E=∂neti,t∂E=∂netc′,t∂E=∂neto,t∂E
误差项沿时间的反向传递
沿时间反向传递误差项,就是要计算出t−1时刻的误差项δt−1。
δt−1T=∂ht−1∂E=∂ht∂E∂ht−1∂ht=δtT∂ht−1∂ht
∂ht−1∂ht是一个雅可比矩阵。如果隐藏层h的维度是N的话,那么它就是一个N×N矩阵。已知
htct=ot∘tanh(ct)=ft∘ct−1+it∘ct′
其中,ot,ft,it,ct′都是ht−1的函数。因此,
δtT∂ht−1∂ht=δo,tT∂ht−1∂neto,t+δf,tT∂ht−1∂netf,t+δi,tT∂ht−1∂neti,t+δc′,tT∂ht−1∂netc′,t
其中,
∂ot∂ht∂ct∂ht∂ft∂ct∂it∂ct∂c′t∂ct=diag[tanh(ct)]=diag[ot∘(1−tanh(ct)2)]=diag[ct−1]=diag[c′t]=diag[it]
因为
otneto,tftnetf,titneti,tct′netct′=σ(neto,t)=Wohht−1+Woxxt+bo=σ(netf,t)=Wfhht−1+Wfxxt+bf=σ(neti,t)=Wihht−1+Wixxt+bi=tanh(netct′)=Wchht−1+Wcxxt+bc
所以
∂neto,t∂ot∂ht−1∂neto,t∂netf,t∂ft∂ht−1∂netf,t∂neti,t∂it∂ht−1∂neti,t∂netct′∂ct′∂ht−1∂netct′=diag[ot∘(1−ot)]=Woh=diag[ft∘(1−ft)]=Wfh=diag[it∘(1−it)]=Wih=diag[1−ct′2]=Wch
将以上全部代入得:
δt−1=δo,tT∂ht−1∂neto,t+δf,tT∂ht−1∂netf,t+δi,tT∂ht−1∂neti,t+δct′T∂ht−1∂netct′=δo,tTWoh+δf,tTWfh+δi,tTWih+δct′TWch
根据δo,t,δf,t,δi,t,δc′,t的定义可知
δo,tTδf,tTδi,tTδct′T=δtT∘tanh(ct)∘ot∘(1−ot)=δtT∘ot∘(1−tanh(ct)2)∘ct−1∘ft∘(1−ft)=δtT∘ot∘(1−tanh(ct)2)∘ct′∘it∘(1−it)=δtT∘ot∘(1−tanh(ct)2)∘it∘(1−ct′2)
以δo,tT为例进行推导:
定义输出门的误差项为:
δo,t=∂neto,t∂E
Step 1:LSTM 输出定义
ht=ot∘tanh(ct)
Step 2:链式法则求导
设:
δt=∂ht∂E
根据链式法则:
∂ot∂E=∂ht∂E∘∂ot∂ht=δt∘tanh(ct)
Step 3:激活函数导数(sigmoid)
ot=σ(neto,t)⇒∂neto,t∂ot=ot∘(1−ot)
Step 4:合并得到最终结果
δo,t=δt∘tanh(ct)∘ot∘(1−ot)
上四式就是将误差沿时间反向传播一个时刻的公式。可以写出将误差项向前传递到任意时间k时刻的公式:
δkT=j=k∏t−1(δo,jTWoh+δf,jTWfh+δi,jTWih+δcj′TWch)
将误差项传递到上一层
假设当前层为l,定义l−1层的误差项是误差函数对l−1层的加权输入的导数,即:
δtl−1=∂nettl−1∂E
本次输入xt由下面公式计算:
xtl=fl−1(nettl−1)
其中,fl−1表示第l−1层的激活函数。
因为 netf,tl,neti,tl,netc′,tl,neto,tl 都是 xt 的函数,而 xt又是 nettl−1 的函数,因此要求出 E 对 nettl−1的导数,就需要使用全导数公式:
∂nettl−1∂E=∂netf,tl∂E⋅∂xtl∂netf,tl⋅∂nettl−1∂xtl+∂neti,tl∂E⋅∂xtl∂neti,tl⋅∂nettl−1∂xtl+∂netc′,tl∂E⋅∂xtl∂netc′,tl⋅∂nettl−1∂xtl+∂neto,tl∂E⋅∂xtl∂neto,tl⋅∂nettl−1∂xtl=δf,tTWfx∘f′(nettl−1)+δi,tTWix∘f′(nettl−1)+δc′,tTWcx∘f′(nettl−1)+δo,tTWox∘f′(nettl−1)=(δf,tTWfx+δi,tTWix+δc′,tTWcx+δo,tTWox)∘f′(nettl−1)
权重梯度的计算
Wfh,Wih,Wch,Woh的权重梯度是各个时刻梯度之和,首先求出它们在t时刻的梯度。
已经求得了误差项δo,t,δf,t,δi,t,δc′,t,则可以求出:
∂Woh,t∂E ∂Wfh,t∂E ∂Wih,t∂E ∂Wch,t∂E =∂neto,t∂E⋅∂Woh,t∂neto,t =δo,tht−1T=∂netf,t∂E⋅∂Wfh,t∂netf,t =δf,tht−1T=∂neti,t∂E⋅∂Wih,t∂neti,t =δi,tht−1T=∂netc′,t∂E⋅∂Wch,t∂netc′,t =δc′,tht−1T
那么将各个时刻的梯度加在一起,就可以得到最终的梯度:
∂Woh∂E∂Wfh∂E∂Wih∂E∂Wch∂E=j=1∑tδo,jhj−1T=j=1∑tδf,jhj−1T=j=1∑tδi,jhj−1T=j=1∑tδcj′hj−1T
偏置项bf,bi,bc,bo的梯度,也是将各个时刻的梯度加在一起:
∂bo∂E∂bi∂E∂bf∂E∂bc∂E=j=1∑tδo,j=j=1∑tδi,j=j=1∑tδf,j=j=1∑tδcj′
对于Wfx,Wix,Wcx,Wox的权重梯度:
∂Wox∂E∂Wfx∂E∂Wix∂E∂Wcx∂E=δo,txtT=δf,txtT=δi,txtT=δc′,txtT
以上。