Paper Info



Field Content
Title QKFormer: Hierarchical Spiking Transformer using Q-K Attention
Authors Chenlin Zhou et al.
Venue NeurIPS 2024
Year 2024
Link arxiv



Summary



QK-Attention을 Spikeformer에 적용한 논문



Problem Statement



Space complexity of spikformer

 SNN의 성능을 극대화하기 위해선 Transformer 아키택처가 필수적이다. 그러나 공간 복잡도 측면에서 두 가지 어려움이 존재한다.

  1. Spiking Self Attention(SSA)에서의 계산 복잡도는 토큰 수에 대해서 제곱으로 증가한다.
  2. SSN은 시간 영역에 걸처 데이터를 처리하기에 높은 수준의 계산 및 메모리 리소스가 요구된다.

 이러한 한계점을 보완하기 위해 본 논문에서는 Q-K attention, hierarchical architecture, novel patch embedding with deformed shortcut을 적용한 QKFormer를 제안한다. 이 설계는 에너지 소비와 공간 요구 사항을 낮추며 모델 성능도 준수하다.



Key Idea



Q-K Attention

VSA, SSA

Vanilla Self Attention

 Transformer의 기존 셀프 어텐션은 학습 가능한 선형 행렬과 입력 $X$에 의해 계산되는 $Q$, $K$, $V$의 행렬곱 즉 부동 소수점 연산으로 계산된다.

\[QF, KF, VF = X(W_Q, W_K, W_V) \\ VSA(QF, KF, VF) = \mathrm{Softmax}\left(Q_FK_F^T\over\sqrt{d_k}\right)V_F\]

이때 $F$는 부동 소수점 형태를 나타내는데, 이처럼 부동 소수점을 사용하는 행렬 곱셈, 지수 계산과 나눗셈을 포함하는 소프트멕스 연산 등은 모두 SNN의 속성과 일치하지 않는다.


Spiking Self Attention

 반면 SSA는 SNN에 맞게 VSA를 변형한 형태이다.

\[I = \text{SN}_I(\text{BN}_I(X(W_I))),\quad I\in \{Q,K,V\} \\ \text{SSA}'(Q, K, V) = \text{SN}(QK^T*s)V\]

여기서 $Q,K,V\in\mathbb{R}^{{T\times N\times D}}$이고 $Q$, $K$, $V$는 학습 가능한 선형 레이어를 통해 계산된 스파이크 형태의 텐서이다. $s$는 스케일링 펙터로 VAS에서의 $1/\sqrt{d}$를 대신한다. SSA는 소프트맥스 연산과 부동소수점 행렬 곱셈을 피해서 SNN의 속성을 만족시킨다.


Q-K Attention


 VSA와 SSA는 $Q$, $K$, $V$라는 세 가지 핵심 구성 요소를 모두 연산에 사용하며, 이로 인해 $O(N^2D)$ 혹은 $O(ND^2)$의 계산 복잡도를 가진다. 반면 Q-K attention은 선형 복잡도를 가지며, $Q$와 $K$만을 사용한다.

\[Q=\text{SN}_Q(\text{BN}(XW_Q)),\quad K=\text{SN}_K(\text{BN}(XW_K))\]

 이때 $X$는 입력 스파이킹 맵이다. Q-K Attention은 Tocken Q-K Attention(QKTA)과 Channel Attention(QKCA)으로 나눌 수 있다.


Q-K Token Attention

 수학적 설명을 위해 T=1, 단일 헤드 attention을 가정하면 QKTA는 다음과 같이 공식화될 수 있다.

\[{A}_t = \operatorname{SN}\left(\sum_{i=0}^{D} {Q}_{i,j}\right), \quad {X}' = {A}_t \otimes {K},\]

 여기서 $A_t$는 다양한 토큰의 중요도를 모델링하는 N*1의 토큰 attention 벡터이다. $\otimes$는 스파이크 텐서 간의 요소별 곱(hadamard product)으로 특정 값만 통과시키는 마스크 연산과 동일하다.


Q-K Channel Attention

 QKCA의 계산과정은 QKTA와 유사하며, 다음과 같이 공식화될 수 있다.

\[{A}_c = \operatorname{SN}\left(\sum_{j=0}^{N} {Q}_{i,j}\right), \quad {X}' = {A}_c \otimes {K},\]

 이때 $A_c$는 다양한 채널의 중요도를 모델링하는 1*D의 채널 attention 벡터이다. 역시 $\otimes$는 마스크 연산과 동일한 역할을 수행한다.

\[{X}'' = \operatorname{SN}(\operatorname{BN}(\operatorname{Linear}({X}'))).\]

 위 식에서 보다시피 $X’$은 최종적으로 선형 계층을 지나 $ {X}’’ $으로 변환된다. 추가적으로 multi-head A-K Attention에서는 채널 차원이 $D/h$이며, time step $T$는 스파이킹 뉴런 계층에 대한 독립적인 차원이다. 본 연구에선 스파이킹 뉴런으로 LIF 모델을 사용하며, QKTA을 활용한다.


Structural efficiency & numerical stability

No scaling factor

 VSA에서 $Q$와 $K$를 평균 0, 분산 1을 따르는 독립적인 확률 변수라고 했을 때 둘의 행렬곱의 각 요소는 평균 0, 분산 $d$를 가진다. 이 경우 softmax 함수를 통과하면서 기울기 소실 문제로 이어질 수 있으므로 정규화를 위해 $1/\sqrt{d}$를 곱해주게 된다. SNN의 경우 softmax 함수를 사용하지는 않지만, 분산이 커지면 성능에 영향을 주거나 수렴이 잘 되지 않기에 스케일링 펙터 $s$를 곱해주어서 문제를 해결한다.  반면 Q-K attention의 분산은 SSA에 비해 훨씬 작다. Q-K attention의 최대 이론적 분산은 SSA의 약 1/200 정도에 불과하기 때문이다. (이유에 대한 이해 및 내용 보충)


Linear computational conplixity of Q-K Attention

 표에서 볼 수 있듯이 Q-K attention의 시간 및 공간 복잡도는 구현 방식에 따라 달라진다. 브로드캐스트 방식으로 요소별 곱 $\otimes$을 활용할 경우 시간 복잡도는 최대 $O(ND)$에 달할 수 있다. 이때 마스크 연산을 적용하게 되면 시간 복잡도는 최대 $O(D)$ 혹은 $O(N)$까지 떨어질 수 있다. 이 경우 공간복잡도는 $A_t$ 혹은 $A_c$ 벡터를 저장할 공간인 1*D 또는 N*1 정도만 요구된다.


Hight Energe Efficiency

 Q-K attention은 스파이크 기반 attention 모듈이므로 선형 곱셈은 희소 덧셈으로 변환된다. 마스크 연산은 뉴로모픽 칩에서는 주소 지정 알고리즘 혹은 AND 논리 연산을 통해 전력 소비를 줄일 수 있다. 같은 스파이크 기반 attention인 SSA와 비교했을 때 Q-K attention은 다음과 같은 이유로 훨씬 더 에너지 효율적이다.

  • $V$가 없이 두 개읜 구성 요소만 채택하므로 시냅스 연산량이 적음
  • 시간, 공간 복잡도가 $O(N)$ 혹은 $O(D)$의 선형 복잡도로 행렬 연산이 적음
  • SSA의 스케일 연산이 없기에 전력 소비를 아낌


Hierarchical Architecture

 기존 SSA를 사용한 spikformer의 경우 토큰 또는 채널에 대해 제곱으로 복잡도가 증가하므로 계층적 구조를 적용했을 때 메모리 폭발이 쉽게 발생하는 문제가 있었다. Q-K attention을 기반으로 한 QKFormer의 경우, 토큰 또는 채널에 대해 선형 계산 복잡도를 가지기에 이런 문제가 없으며, 이로 인해 계층적 특징 맵을 구성할 수 있다.
 이에 QKFormer에선 블록을 지나면서 토큰의 수는 줄어들고, 채널의 수는 늘어나는 계층적 구조를 적용했다. 깊은 층으로 갈수록 더 다양한 표현이 가능해져 성능이 증가했다.


Spiking Patch Embedding with Deformed Shortcut(SPEDS)

 SNN에서 잔차 연결은 항등 사상(identity mapping)을 구현할 수 있으며, 이는 정보 손실을 줄여 네트워크의 깊어도 잘 작동할 수 있도록 한다. 기존 Spikformer에선 잔차 연결을 사용해 항등 사상을 달성했지만, 다운샘플링 블록을 가로지르는 patch embedding에서는 사용하지 않았다. QKFormer에서는 잔차 연결에 경랑 선형 투영 $W_d$를 수행해 채널과 토큰 수를 일치시며 다운샘플링 블록을 가로지르는 항등 사상을 실현할 수 있다.
 본 논문에서 논의한 잔차 연결 방식은 두 가지 유형이다. 하나는 활성화 후 덧셈 방식이고, 다른 하나는 사전 활성화 방식이다.

Activation Befoer Addition

 활성화 후 덧셈 방식(ABA)은 스파이킹 레이어를 통과한 뒤 이를 잔차 연결으로 더하는 방식이다.

\[Y = F(X, \{W_i\})+\text{SN}(W_dX)\]

 여기서 사용된 선형 투영 $W_d$는 1$\times$1커널과 stride > 1을 갖는 경량 Conv 레이어로, patch embedding의 채널 및 토큰 수를 충족시킨다. 본 연구서는 함수 $F$를 다음으로 설정했으며, 다른 변형도 가능하다.

  • Conv2D - BN - MaxPooling - SN - Conv2D - BN - SN 또는
  • Conv2D - BN - SN Conv2D - BN - MaxPooling - SN


Pre-Activation

 반면 사전 활성화 방식(PA)은 다음과 같이 공식으로 나타낼 수 있다.

\[Y = \text{SN}(G(X, \{W_j\})+W_dX)\]

여기서 함수 $G$는 다음으로 설정했다.

  • Conv2D - BN - MaxPooling - SN - Conv2D - BN 또는
  • Conv2D - BN - SN - Conv2D - BN - MaxPooling

 논문에서는 활성화 후 덧셈 방식을 기준점으로 잡아서 실험을 진행하였다.



Methods



Overall Architecture

 위에서 언급한 QK-Attention, Hierarchical Architecture, SPEDS 등을 적용한 최종적인 구조는 다음과 같이 시각화할 수 있다.


 그림에서 보이는 입력 형태인 $T_0\times H\times W\times n$은 이미지 데이터셋, 뉴로모픽 데이터셋을 일반화한 것이다. 입력 시 패치 크기를 4$\times$4로 하기에 SPEDS-1을 거친 뒤에는 $T_0\times {H\over 4}\times {W\over 4}\times C$가 되고 이후엔 다운샘플링을 할 때마다 토큰은 절반으로, 채널은 두 배로 증가한다. 모든 블록을 통과한 뒤에는 FC layer가 분류기 역할을 한다.

 특징적인 것은 마지막 블록에서 Q-K Attention이 아니라 SSA가 사용된다는 점이다. 계층적 구조를 사용하기 때문에 후반 블록은 채널 수가 많고 토큰 수가 적은데 이때 QKTA를 사용하게 되면 이 풍부한 채널 간의 상호작용을 파악하기 힘들다. 이에 후반 블록은 SSA나 QKCA를 사용하는 것이 성능에 유리하며, 본 논문에선 기준값으로 SSA를 사용했다.


Key Equations

 QKFormer의 블록은 기존 transformer의 블록과 유사한 형태로, Q-K Attention 모듈과 Spiking MLP(SMLP)로 구성되며, 다음과 같이 공식화할 수 있다.

\[X_l'=QKTA(X_{l-1})+X_{l-1} \\ X_l=SMLP(X_l')+X_l'\]



Results





Discussion





Insights





References



-