멀티헤드 어텐션
Multi-Head Attention
여러 어텐션 헤드를 병렬로 사용하여 다양한 관계를 학습. Transformer 구성요소.
Multi-Head Attention
여러 어텐션 헤드를 병렬로 사용하여 다양한 관계를 학습. Transformer 구성요소.
멀티헤드 어텐션(Multi-Head Attention)은 여러 개의 어텐션 연산을 병렬로 수행하여 입력 시퀀스의 다양한 표현 부분공간을 동시에 학습하는 메커니즘입니다. 단일 어텐션이 하나의 관점에서만 관계를 파악하는 것과 달리, 멀티헤드 어텐션은 문법, 의미, 위치 관계 등 다양한 패턴을 동시에 포착합니다.
2017년 "Attention Is All You Need" 논문에서 Transformer 아키텍처와 함께 제안되었습니다. 기존 RNN/LSTM의 순차 처리 한계를 극복하고, 완전 병렬화가 가능한 어텐션 기반 구조를 제시했습니다. 원 논문에서는 8개의 헤드를 사용했으며, 이후 GPT, BERT 등 모든 주요 언어 모델의 핵심 구성요소가 되었습니다.
멀티헤드 어텐션의 작동 원리는 다음과 같습니다. 입력을 Query(Q), Key(K), Value(V)로 변환한 후, 각 헤드별로 다른 가중치 행렬로 투영합니다. 각 헤드에서 Scaled Dot-Product Attention을 수행하고, 모든 헤드의 출력을 연결(concatenate)한 뒤 최종 선형 변환을 적용합니다. 수식은 MultiHead(Q,K,V) = Concat(head_1,...,head_h)W^O 입니다.
2024-2025년에는 효율적인 어텐션 변형이 활발히 연구됩니다. Flash Attention 2는 GPU 메모리 효율을 극대화하고, GQA(Grouped Query Attention)는 KV 캐시를 줄입니다. Llama 2/3, Mistral 등은 GQA를 채택하여 긴 컨텍스트에서 메모리를 절약합니다. MHA는 여전히 모델 품질의 핵심이지만, 추론 비용 최적화가 주요 과제입니다.
# 멀티헤드 어텐션 구현 (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 헤드당 차원
# Q, K, V 투영 가중치
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# 출력 투영
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def scaled_dot_product_attention(self, Q, K, V, mask=None):
"""Scaled Dot-Product Attention"""
# Q @ K^T / sqrt(d_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
return torch.matmul(attention_weights, V), attention_weights
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 1. 선형 투영 후 헤드 분할
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 2. 어텐션 계산
attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)
# 3. 헤드 연결 및 출력 투영
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)
output = self.W_o(attn_output)
return output, attn_weights
# 사용 예시
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512) # (batch, seq_len, d_model)
output, weights = mha(x, x, x)
print(f"출력 shape: {output.shape}") # [2, 10, 512]
print(f"어텐션 shape: {weights.shape}") # [2, 8, 10, 10]
"모델 추론 속도가 느린 건 KV 캐시 메모리 때문입니다. MHA를 GQA로 바꾸면 헤드 수는 유지하면서 KV 캐시를 1/8로 줄일 수 있어요. Llama 3도 이 방식 씁니다."
"멀티헤드 어텐션의 핵심은 다양한 표현 부분공간 학습입니다. 8개 헤드면 512차원을 64차원씩 8개로 나눠 각각 다른 관계를 학습합니다. 'sqrt(d_k)'로 나누는 건 softmax 기울기 소실을 방지하기 위함이고요."
"Flash Attention 2 적용하면 같은 MHA도 메모리 O(N^2)에서 O(N)으로 줄어요. 긴 컨텍스트 학습할 때 필수입니다. PyTorch 2.0부터 scaled_dot_product_attention이 Flash Attention을 자동 사용해요."
d_model=512, num_heads=12면 오류 발생. d_k = d_model // num_heads가 정수여야 합니다. 512에는 8, 16, 32 등을 사용하세요.
Causal mask(미래 토큰 마스킹)와 padding mask를 혼동하면 학습이 제대로 안 됩니다. 디코더는 상삼각 마스크, 패딩은 True/False 마스크입니다.
PyTorch의 nn.MultiheadAttention 또는 F.scaled_dot_product_attention 사용을 권장합니다. 자동으로 Flash Attention을 적용하고 마스크 처리도 검증되어 있습니다.