본문으로 건너뛰기

Recurrent Neural Network

RNN(Recurrent Neural Network)

input=Dim(batch_size,timesteps,input_size)output=Dim(batch_size,timesteps,hidden_size)input = Dim(batch\_size, timesteps, input\_size) \\ output = Dim(batch\_size, timesteps, hidden\_size)
ht=tanh(Whxxt+Whhht1+bh)h_t = \tanh(W_{hx} x_t + W_{hh} h_{t-1} + b_h)

Vanishing/exploding gradient

긴 시퀀스를 처리하는 RNN은 깊은 네트워크가 되면서(WW가 반복적으로 곱해지면서) Vanishing/exploding gradient 문제가 발생하기 쉽습니다.

  • Relu와 같이 수렴하지 않는 activation을 사용하면 불안정해질 수 있습니다.
  • Exploding gradient 문제가 발견되면, gradient clipping을 사용하여 값을 제한 해볼 수 있습니다.

LSTM(Long Short-Term Memory)

긴 시퀀스를 처리하는 RNN은 순환이 반복되면서 상대적으로 앞쪽 값의 영향이 줄어들 수 있습니다. 이를 Long-Term Dependency라고 합니다. LSTM 셀을 사용하면, Long-Term Dependency 문제가 완화되며 훈련 시 빠르게 수렴합니다.

ft=σ(Wxfxt+bxf+Whfht1+bhf)it=σ(Wxixt+bxi+Whiht1+bhi)ct~=tanh(Wxc~xt+bxc~+Whc~ht1+bhc~)ot=σ(Wxoxt+bxo+Whoht1+bho)ct=ftct1+itct~ht=ottanh(ct)\begin{aligned} f_t &= \sigma(W_{xf} x_t + b_{xf} + W_{hf} h_{t-1} + b_{hf}) \\ i_t &= \sigma(W_{xi} x_t + b_{xi} + W_{hi} h_{t-1} + b_{hi}) \\ \widetilde{c_t} &= \tanh(W_{x\tilde{c}} x_t + b_{x\tilde{c}} + W_{h\tilde{c}} h_{t-1} + b_{h\tilde{c}}) \\ o_t &= \sigma(W_{xo} x_t + b_{xo} + W_{ho} h_{t-1} + b_{ho}) \\ c_t &= f_t \odot c_{t-1} + i_t \odot \widetilde{c_t} \\ h_t &= o_t \odot \tanh(c_t) \end{aligned}
  • ctc_tcell state입니다.
  • hth_toutput입니다.
  • ftf_tforget gate로 이전 cell state를 얼마나 잊어버릴 지 결정합니다.
  • iti_tinput gate로 입력 정보를 얼마나 cell state에 저장할 것일지 결정합니다.
  • oto_toutput gate로 업데이트 된 cell state에서 어떤 정보를 내보낼지 결정합니다.

GRU(Gated Recurrent Unit)

LSTM variants 중 하나입니다.

rt=σ(Wxrxt+bxr+Whrht1+bhr)zt=σ(Wxzxt+bxz+Whzht1+bhz)ht~=tanh(Wxh~xt+bxh~+rt(Whh~ht1+bhh~))ht=ztht1+(1zt)ht~\begin{aligned} r_t &= \sigma(W_{xr} x_t + b_{xr} + W_{hr} h_{t-1} + b_{hr}) \\ z_t &= \sigma(W_{xz} x_t + b_{xz} + W_{hz} h_{t-1} + b_{hz}) \\ \widetilde{h_t} &= \tanh(W_{x\tilde{h}} x_t + b_{x\tilde{h}} + r_t \odot (W_{h\tilde{h}} h_{t-1} + b_{h\tilde{h}})) \\ h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \widetilde{h_t}\\ \end{aligned}
  • hth_toutpu입니다.
  • rtr_treset gate로 이전 state를 얼마나 output에 포함시킬지 결정합니다.
  • ztz_tupdate gate로 값이 1에 가까울 수록 이전 state가 저장되고, 0에 가까울 수록 새로운 state가 저장됩니다.