Self-Attention
셀프 어텐션
시퀀스 내 모든 요소 간 관계를 계산하는 메커니즘. Transformer의 핵심. Query, Key, Value로 문맥을 파악.
셀프 어텐션
시퀀스 내 모든 요소 간 관계를 계산하는 메커니즘. Transformer의 핵심. Query, Key, Value로 문맥을 파악.
Self-Attention(자기 주의)은 시퀀스의 각 요소가 같은 시퀀스 내 다른 모든 요소와의 관계를 계산하는 메커니즘입니다. "The cat sat on the mat because it was tired"에서 "it"이 "cat"을 가리킨다는 것을 파악하려면, 문장 내 모든 단어의 관계를 봐야 합니다. Self-Attention이 바로 이 역할을 합니다.
2017년 "Attention Is All You Need" 논문에서 Transformer와 함께 제안되었습니다. 기존 RNN은 순차적으로 처리해야 해서 병렬화가 어렵고 긴 의존성 학습이 힘들었습니다. Self-Attention은 모든 위치를 한 번에 계산하므로 병렬화가 가능하고, 거리에 관계없이 관계를 파악합니다.
핵심은 Query(Q), Key(K), Value(V) 변환입니다. 입력 X에 가중치 행렬을 곱해 Q, K, V를 만들고, Attention(Q,K,V) = softmax(QK^T / sqrt(d_k)) * V로 계산합니다. Q와 K의 내적이 높으면 관련성이 높아 해당 V에 더 많은 가중치가 부여됩니다.
GPT, BERT, T5 등 모든 LLM의 기반이며, Vision Transformer(ViT)로 이미지에도 적용됩니다. Multi-Head Attention은 여러 개의 Self-Attention을 병렬로 실행해 다양한 관계 패턴을 학습합니다. 시퀀스 길이 N에 대해 O(N^2) 복잡도가 단점이라 Sparse Attention, Flash Attention 등 최적화 기법이 연구됩니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
"""Scaled Dot-Product Self-Attention"""
def __init__(self, embed_dim, num_heads=1):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Query, Key, Value 변환 행렬
self.W_q = nn.Linear(embed_dim, embed_dim)
self.W_k = nn.Linear(embed_dim, embed_dim)
self.W_v = nn.Linear(embed_dim, embed_dim)
self.W_o = nn.Linear(embed_dim, embed_dim)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# Q, K, V 계산
Q = self.W_q(x) # (batch, seq, embed)
K = self.W_k(x)
V = self.W_v(x)
# Attention Score: QK^T / sqrt(d_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
# scores: (batch, seq, seq)
# 마스킹 (옵션: causal mask for decoder)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax로 확률 변환
attention_weights = F.softmax(scores, dim=-1)
# Value에 가중치 적용
output = torch.matmul(attention_weights, V)
output = self.W_o(output)
return output, attention_weights
# 사용 예시
embed_dim = 512
seq_len = 10
batch_size = 2
x = torch.randn(batch_size, seq_len, embed_dim)
attention = SelfAttention(embed_dim)
output, weights = attention(x)
print(f"입력: {x.shape}") # (2, 10, 512)
print(f"출력: {output.shape}") # (2, 10, 512)
print(f"어텐션 가중치: {weights.shape}") # (2, 10, 10)
# Causal Mask (Decoder용 - 미래 토큰 차단)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
print(f"Causal Mask:\n{causal_mask}")
# 1 0 0 0 ...
# 1 1 0 0 ...
# 1 1 1 0 ...
output_masked, _ = attention(x, mask=causal_mask)
# === Multi-Head Attention 구현 ===
class MultiHeadAttention(nn.Module):
"""Multi-Head Self-Attention (여러 헤드 병렬 처리)"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim
self.W_q = nn.Linear(embed_dim, embed_dim)
self.W_k = nn.Linear(embed_dim, embed_dim)
self.W_v = nn.Linear(embed_dim, embed_dim)
self.W_o = nn.Linear(embed_dim, embed_dim)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# Q, K, V 계산 후 헤드별로 분리
Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# (batch, heads, seq, head_dim)
# 각 헤드에서 독립적으로 Attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
output = torch.matmul(weights, V)
# 헤드 결합
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
return self.W_o(output)
# === PyTorch 내장 사용 (권장) ===
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
output, attn_weights = mha(x, x, x) # Self-Attention: Q=K=V
# === Flash Attention (메모리 효율) ===
# pip install flash-attn
from flash_attn import flash_attn_func
# Flash Attention은 O(N) 메모리로 O(N^2) 연산을 수행
# 128K 컨텍스트도 가능하게 만든 핵심 기술
# === Attention 가중치 시각화 ===
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(attention_weights, tokens):
"""어텐션 히트맵 시각화"""
plt.figure(figsize=(10, 10))
sns.heatmap(
attention_weights.detach().numpy(),
xticklabels=tokens,
yticklabels=tokens,
cmap='viridis'
)
plt.title("Self-Attention Weights")
plt.show()
# === Cross-Attention vs Self-Attention ===
# Self-Attention: Q, K, V 모두 같은 입력에서 유래
# Cross-Attention: Q는 한 입력, K/V는 다른 입력에서
# 예: Encoder-Decoder에서 Decoder가 Encoder 출력 참조
class CrossAttention(nn.Module):
def __init__(self, embed_dim):
super().__init__()
self.W_q = nn.Linear(embed_dim, embed_dim)
self.W_k = nn.Linear(embed_dim, embed_dim)
self.W_v = nn.Linear(embed_dim, embed_dim)
def forward(self, query_input, key_value_input):
Q = self.W_q(query_input) # Decoder 상태
K = self.W_k(key_value_input) # Encoder 출력
V = self.W_v(key_value_input) # Encoder 출력
# 이하 동일한 attention 계산
"Transformer 레이어 하나에 Multi-Head Self-Attention이 들어가는데, 헤드 수가 8이면 8가지 다른 관계 패턴을 병렬로 학습해요. 어떤 헤드는 구문 관계, 어떤 헤드는 상호참조(coreference)를 배우더라고요."
"Self-Attention은 Q, K, V 세 가지 변환으로 이루어집니다. Q와 K의 내적으로 관련성 점수를 계산하고, softmax로 정규화한 뒤 V에 가중합을 적용합니다. sqrt(d_k)로 나누는 건 내적 값이 커져서 softmax가 극단적이 되는 걸 방지하기 위함입니다."
"Self-Attention이 O(N^2)이라 긴 컨텍스트에서 메모리 폭발이 일어나요. Flash Attention 쓰면 메모리를 90% 줄이면서 속도도 2-3배 빨라집니다. 128K 컨텍스트 모델에는 필수예요."
sqrt(d_k)로 나누지 않으면 d_k가 클 때 내적 값이 매우 커져서 softmax가 거의 one-hot이 됩니다. 그래디언트가 소실되어 학습이 안 됩니다.
언어 모델 학습 시 미래 토큰을 볼 수 있으면 정답을 미리 알게 됩니다. 반드시 하삼각 행렬 마스크로 미래를 차단하세요.
PyTorch의 nn.MultiheadAttention을 사용하면 스케일링, 마스킹이 자동 처리됩니다. 커스텀 구현 시 반드시 / sqrt(d_k)와 attn_mask 파라미터를 확인하세요.