🤖AI/ML

Self-Attention

셀프 어텐션

시퀀스 내 모든 요소 간 관계를 계산하는 메커니즘. Transformer의 핵심. Query, Key, Value로 문맥을 파악.

📖 상세 설명

Self-Attention(자기 주의)은 시퀀스의 각 요소가 같은 시퀀스 내 다른 모든 요소와의 관계를 계산하는 메커니즘입니다. "The cat sat on the mat because it was tired"에서 "it"이 "cat"을 가리킨다는 것을 파악하려면, 문장 내 모든 단어의 관계를 봐야 합니다. Self-Attention이 바로 이 역할을 합니다.

2017년 "Attention Is All You Need" 논문에서 Transformer와 함께 제안되었습니다. 기존 RNN은 순차적으로 처리해야 해서 병렬화가 어렵고 긴 의존성 학습이 힘들었습니다. Self-Attention은 모든 위치를 한 번에 계산하므로 병렬화가 가능하고, 거리에 관계없이 관계를 파악합니다.

핵심은 Query(Q), Key(K), Value(V) 변환입니다. 입력 X에 가중치 행렬을 곱해 Q, K, V를 만들고, Attention(Q,K,V) = softmax(QK^T / sqrt(d_k)) * V로 계산합니다. Q와 K의 내적이 높으면 관련성이 높아 해당 V에 더 많은 가중치가 부여됩니다.

GPT, BERT, T5 등 모든 LLM의 기반이며, Vision Transformer(ViT)로 이미지에도 적용됩니다. Multi-Head Attention은 여러 개의 Self-Attention을 병렬로 실행해 다양한 관계 패턴을 학습합니다. 시퀀스 길이 N에 대해 O(N^2) 복잡도가 단점이라 Sparse Attention, Flash Attention 등 최적화 기법이 연구됩니다.

💻 코드 예제

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    """Scaled Dot-Product Self-Attention"""
    def __init__(self, embed_dim, num_heads=1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Query, Key, Value 변환 행렬
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Q, K, V 계산
        Q = self.W_q(x)  # (batch, seq, embed)
        K = self.W_k(x)
        V = self.W_v(x)

        # Attention Score: QK^T / sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # scores: (batch, seq, seq)

        # 마스킹 (옵션: causal mask for decoder)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax로 확률 변환
        attention_weights = F.softmax(scores, dim=-1)

        # Value에 가중치 적용
        output = torch.matmul(attention_weights, V)
        output = self.W_o(output)

        return output, attention_weights

# 사용 예시
embed_dim = 512
seq_len = 10
batch_size = 2

x = torch.randn(batch_size, seq_len, embed_dim)
attention = SelfAttention(embed_dim)
output, weights = attention(x)

print(f"입력: {x.shape}")          # (2, 10, 512)
print(f"출력: {output.shape}")     # (2, 10, 512)
print(f"어텐션 가중치: {weights.shape}")  # (2, 10, 10)

# Causal Mask (Decoder용 - 미래 토큰 차단)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
print(f"Causal Mask:\n{causal_mask}")
# 1 0 0 0 ...
# 1 1 0 0 ...
# 1 1 1 0 ...

output_masked, _ = attention(x, mask=causal_mask)

# === Multi-Head Attention 구현 ===
class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention (여러 헤드 병렬 처리)"""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim

        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Q, K, V 계산 후 헤드별로 분리
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # (batch, heads, seq, head_dim)

        # 각 헤드에서 독립적으로 Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        output = torch.matmul(weights, V)

        # 헤드 결합
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.W_o(output)

# === PyTorch 내장 사용 (권장) ===
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
output, attn_weights = mha(x, x, x)  # Self-Attention: Q=K=V

# === Flash Attention (메모리 효율) ===
# pip install flash-attn
from flash_attn import flash_attn_func

# Flash Attention은 O(N) 메모리로 O(N^2) 연산을 수행
# 128K 컨텍스트도 가능하게 만든 핵심 기술

# === Attention 가중치 시각화 ===
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens):
    """어텐션 히트맵 시각화"""
    plt.figure(figsize=(10, 10))
    sns.heatmap(
        attention_weights.detach().numpy(),
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='viridis'
    )
    plt.title("Self-Attention Weights")
    plt.show()

# === Cross-Attention vs Self-Attention ===
# Self-Attention: Q, K, V 모두 같은 입력에서 유래
# Cross-Attention: Q는 한 입력, K/V는 다른 입력에서
# 예: Encoder-Decoder에서 Decoder가 Encoder 출력 참조

class CrossAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, query_input, key_value_input):
        Q = self.W_q(query_input)      # Decoder 상태
        K = self.W_k(key_value_input)  # Encoder 출력
        V = self.W_v(key_value_input)  # Encoder 출력
        # 이하 동일한 attention 계산

🗣️ 실무에서 이렇게 말하세요

💬 모델 아키텍처 논의에서
"Transformer 레이어 하나에 Multi-Head Self-Attention이 들어가는데, 헤드 수가 8이면 8가지 다른 관계 패턴을 병렬로 학습해요. 어떤 헤드는 구문 관계, 어떤 헤드는 상호참조(coreference)를 배우더라고요."
💬 면접에서
"Self-Attention은 Q, K, V 세 가지 변환으로 이루어집니다. Q와 K의 내적으로 관련성 점수를 계산하고, softmax로 정규화한 뒤 V에 가중합을 적용합니다. sqrt(d_k)로 나누는 건 내적 값이 커져서 softmax가 극단적이 되는 걸 방지하기 위함입니다."
💬 성능 최적화 논의에서
"Self-Attention이 O(N^2)이라 긴 컨텍스트에서 메모리 폭발이 일어나요. Flash Attention 쓰면 메모리를 90% 줄이면서 속도도 2-3배 빨라집니다. 128K 컨텍스트 모델에는 필수예요."

⚠️ 흔한 실수 & 주의사항

스케일링 없이 내적 계산

sqrt(d_k)로 나누지 않으면 d_k가 클 때 내적 값이 매우 커져서 softmax가 거의 one-hot이 됩니다. 그래디언트가 소실되어 학습이 안 됩니다.

Decoder에서 causal mask 누락

언어 모델 학습 시 미래 토큰을 볼 수 있으면 정답을 미리 알게 됩니다. 반드시 하삼각 행렬 마스크로 미래를 차단하세요.

올바른 방법

PyTorch의 nn.MultiheadAttention을 사용하면 스케일링, 마스킹이 자동 처리됩니다. 커스텀 구현 시 반드시 / sqrt(d_k)와 attn_mask 파라미터를 확인하세요.

🔗 관련 용어

📚 더 배우기