Attention

2024. 9. 4. 00:11Machine Learning/[TIL] Naver Boost Camp

[딥러닝을 이용한 자연어 처리 입문 15-01 어텐션 메커니즘]을 정리한 내용입니다.

Attention 메커니즘

Attention(Q,K,V) = Attentionvalue

어텐션 함수는 주어진 Query에 대해 모든 Key의 유사도를 각각 구한다. 그리고 이 유사도를 key와 매핑되어있는 각각의 값(value)에 반영해준다. 유사도가 반영된 값(value)을 모두 더해서 리턴하고, attention value를 반환한다.

Dot-Product Attention

Seq2Seq에 Attention 기법을 적용한 예시인 바다나우 어텐션의 기본 형태.

 

Attention value a_t를 구하는 방법은 다음과 같다.Attention Score(e_t)를 구한다.


encoder의 시점(time step)을 각각 $1,2, … ,N$이라고 했을 때, encoder의 은닉 상태(hiddent state)를 각각 $h_1,h_2, … , h_N$이라고 하자.

encoder hidden state의 차원 = dencoder hidden state의 차원 (그림: 4차원)

 

(Attention 메커니즘을 사용하지 않는 경우) 시점 t에서 출력 단어를 예측하기 위해서 디코더의 셀은 두 개의 입력값을 필요로 하는데, 바로 이전 시점인 t-1의 은닉 상태와 이전 시점 t-1에 나온 출력 단어이다.

 

Attention 메커니즘에서는 Attention Value 값을 필요로 한다. t번째 단어를 예측하기 위한 어텐션 값을 a_t라고 정의한다.

 

Attention 메커니즘에서는 가장 먼저 Attention score를 구해야한다. Attention score란, 현재 decoder의 시점 t에서 단어를 예측하기 위해, encoder의 모든 hidden state 각각이 decoder의 현 시점의 hidden state s_t와 얼마나 유사한지를 판단하는 스코어이다.

 

dot-product attention은 이 스코어 값을 구하기 위해 s_t를 전치하고, 각 hidden state와 내적을 수행한다. 즉, 모든 어텐션 스코어 값을 스칼라이다.

 

$score(s_t,h_i) = s_t^{T}h_i$

$s_t$와 encoder의 모든 hidden statedml attention score의 모음값을 $e^t$라고 정의하면, 수식은 다음과 같다.

$e^t=[s_t^{T}h_1, ..., s_T^{T}h_N]$

softmax를 활용해 Attention Distribution을 구한다.

 

e_t에 소프트맥스 함수를 적용하여, 모든 값을 합하면 1이 되는 확률 분포를 얻을 수 있다. 이를 Attention Distribution라고 하며, 각각의 값은 Attention Weight(어텐션 가중치)라고 한다.

 

디코더의 시점 t에서의 어텐션 가중치의 모음값인 어텐션 분포를 α_t이라고 할 때, α_t을 식으로 정의하면 다음과 같다.

 

$ α^t = softmax(e^t) $

 

각 인코더의 어텐션 가중치와 은닉 상태를 가중합하여 어텐션 값(Attention Value)을 구한다.

 

 

attention의 최종 결과 값을 얻기 위해서 각 encoder의 hidden state와 attention weight(가중치 값)들을 곱하고, 최종적으로 모두 더한다.

어텐션 함수의 출력값인 Attention Value a_t에 대한 식은 다음과 같다.

$ a_t = \sum_{i=1}^{N} \alpha_i^t h_i $

 

어텐션 값과 디코더의 t 시점의 은닉 상태를 연결한다.(Concatenate)

어텐션 함수의 최종값인 어텐션 값 a_t을 구했다. 어텐션 값이 구해지면 어텐션 메커니즘은 a_ts_t와 결합(concatenate)하여 하나의 벡터로 만드는 작업을 수행한다. 이를 v_t라고 정의하면, v_t를 \hat{y} 예측 연산의 입력으로 사용하므로서 인코더로부터 얻은 정보를 활용하여 \hat{y}를 좀 더 잘 예측할 수 있게 된다.