🤖AI/ML

Sparse Attention

스파스 어텐션

어텐션 연산을 일부에만 적용하여 효율을 높이는 기법. O(N^2)를 O(N log N)으로 줄여 긴 컨텍스트 처리 가능.

📖 상세 설명

Sparse Attention은 표준 Self-Attention의 모든 위치 쌍(N^2)이 아닌, 선택된 위치 쌍에만 어텐션을 계산하는 효율적인 기법입니다. 4096 토큰 시퀀스에서 Full Attention은 약 1,670만 번(4096^2) 연산하지만, Sparse Attention은 윈도우 크기 256 기준 약 100만 번으로 줄여 긴 문서를 처리할 수 있게 합니다. 어텐션 행렬의 대부분이 0(sparse)인 것에서 이름이 유래했습니다.

2019년 OpenAI의 "Generating Long Sequences with Sparse Transformers" 논문에서 본격적으로 연구되었습니다. Full Attention의 O(N^2) 복잡도는 긴 시퀀스에서 메모리와 연산 비용이 폭발합니다. N=8192일 때 어텐션 행렬만 256GB 메모리가 필요합니다(FP32 기준). GPT-2의 1024 토큰, BERT의 512 토큰 제한도 이 때문이었습니다. Sparse Attention은 O(N log N) 또는 O(N sqrt(N))으로 복잡도를 낮춥니다.

대표적 패턴으로 Local Attention(주변 k개 토큰만 attend), Strided Attention(일정 간격으로 attend), Global Attention(특정 토큰이 전체와 연결)이 있습니다. Longformer는 Local+Global 조합으로 4096 토큰까지 지원하고, BigBird는 Local+Global+Random을 조합해 이론적 표현력을 유지합니다. Mistral 7B의 Sliding Window Attention(4096)도 효율적인 Sparse의 일종입니다. 최근에는 Ring Attention으로 여러 GPU에 시퀀스를 분산 처리하는 기법도 등장했습니다.

긴 문서 요약, 법률 계약서 분석, 의료 기록 처리, 전체 코드베이스 이해 등 긴 컨텍스트가 필요한 작업에 필수입니다. 단, 커스텀 CUDA 커널이 필요해 구현이 복잡하고 GPU 최적화가 어렵습니다. Flash Attention은 알고리즘적 Sparsity 대신 메모리 접근 패턴 최적화(IO-aware)로 효율을 높이는 보완적 접근이며, 두 기법을 함께 사용하면 시너지가 납니다.

💻 코드 예제

import torch
import torch.nn.functional as F

def create_local_attention_mask(seq_len, window_size):
    """Local Attention: 주변 window_size 토큰에만 attend"""
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        mask[i, start:end] = 1
    return mask

def create_strided_attention_mask(seq_len, stride):
    """Strided Attention: stride 간격으로 attend"""
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        # 같은 stride 블록 내 위치들
        block_start = (i // stride) * stride
        block_end = min(block_start + stride, seq_len)
        mask[i, block_start:block_end] = 1
        # 모든 stride 위치
        mask[i, ::stride] = 1
    return mask

def create_bigbird_mask(seq_len, window_size, num_global, num_random):
    """BigBird: Local + Global + Random"""
    mask = create_local_attention_mask(seq_len, window_size)
    # Global tokens (처음 몇 개 토큰이 모든 위치와 연결)
    mask[:num_global, :] = 1
    mask[:, :num_global] = 1
    # Random connections
    for i in range(seq_len):
        random_indices = torch.randperm(seq_len)[:num_random]
        mask[i, random_indices] = 1
    return mask

# 사용 예시
seq_len = 1024
window_size = 64
local_mask = create_local_attention_mask(seq_len, window_size)
print(f"Local Attention 연결 수: {local_mask.sum().item():,.0f}")  # ~65K
print(f"Full Attention 연결 수: {seq_len * seq_len:,}")  # 1,048,576

# Hugging Face Longformer 사용
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model = LongformerModel.from_pretrained('allenai/longformer-base-4096')

# 최대 4096 토큰까지 처리 가능 (BERT는 512)
text = "This is a very long document... " * 500
inputs = tokenizer(text, return_tensors='pt', max_length=4096, truncation=True)

# global_attention_mask: 1인 위치는 모든 토큰과 attend
global_attention_mask = torch.zeros_like(inputs['input_ids'])
global_attention_mask[:, 0] = 1  # [CLS] 토큰만 global

outputs = model(**inputs, global_attention_mask=global_attention_mask)

# === Sliding Window Attention (Mistral 스타일) ===
def create_sliding_window_mask(seq_len, window_size):
    """Sliding Window: 각 토큰이 앞의 window_size 토큰만 attend"""
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size + 1)
        mask[i, start:i+1] = 1  # causal + window
    return mask

# Mistral 7B: window_size=4096
# 128K 토큰도 4096 윈도우로 효율적 처리
sliding_mask = create_sliding_window_mask(8192, 4096)
print(f"Sliding Window 연결 수: {sliding_mask.sum().item():,.0f}")  # ~33M
print(f"Full Attention 연결 수: {8192 * 8192:,}")  # 67M

# === 복잡도 비교 ===
def attention_complexity_comparison(seq_len, window_size=256, stride=256):
    """각 Sparse 패턴의 연산 횟수 비교"""
    full = seq_len * seq_len
    local = seq_len * window_size
    strided = seq_len * (window_size + seq_len // stride)
    sliding = seq_len * min(window_size, seq_len)

    print(f"시퀀스 길이: {seq_len:,}")
    print(f"Full Attention: {full:,} ({full/1e6:.1f}M)")
    print(f"Local Attention: {local:,} ({local/1e6:.1f}M) - {local/full*100:.1f}%")
    print(f"Strided Attention: {strided:,} ({strided/1e6:.1f}M) - {strided/full*100:.1f}%")
    print(f"Sliding Window: {sliding:,} ({sliding/1e6:.1f}M) - {sliding/full*100:.1f}%")

attention_complexity_comparison(16384, window_size=512)

# === BigBird 패턴 (Local + Global + Random) ===
def create_bigbird_mask(seq_len, window_size=64, num_global=64, num_random=64):
    """BigBird: 이론적 표현력을 유지하면서 효율성 확보"""
    mask = torch.zeros(seq_len, seq_len)

    for i in range(seq_len):
        # Local attention (주변 토큰)
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2)
        mask[i, start:end] = 1

        # Global attention (처음 num_global 토큰이 전체와 연결)
        if i < num_global:
            mask[i, :] = 1
            mask[:, i] = 1

        # Random attention (무작위 연결로 이론적 표현력 유지)
        random_indices = torch.randint(0, seq_len, (num_random,))
        mask[i, random_indices] = 1

    return mask

# === Ring Attention (분산 처리) ===
# 여러 GPU에 시퀀스를 분산하여 처리
# 각 GPU가 자신의 청크를 처리하고 결과를 전달
# 1M 토큰 이상의 초장문 처리 가능

# === 실제 모델에서의 적용 ===
# GPT-4: 128K 컨텍스트 (내부 최적화)
# Claude: 200K 컨텍스트 (Ring Attention 추정)
# Mistral 7B: 32K (Sliding Window 4096)
# Longformer: 4096 (Local + Global)
# BigBird: 4096 (Local + Global + Random)

# === Attention Sink 현상 ===
# 첫 번째 토큰에 어텐션이 집중되는 현상
# StreamingLLM에서 발견: 첫 토큰을 유지하면 무한 생성 가능

# === 성능 벤치마크 ===
# 시퀀스 길이 16K에서:
# - Full Attention: 메모리 OOM (A100 40GB)
# - Sparse Attention (window=512): 8GB
# - Flash Attention: 12GB (더 정확, 더 빠름)
# - Sparse + Flash: 6GB (최적)

def memory_usage_comparison(seq_lengths):
    """다양한 어텐션 방식 메모리 사용량 비교"""
    for n in seq_lengths:
        full = n * n * 4 / 1e9  # FP32, GB
        sparse_512 = n * 512 * 4 / 1e9
        print(f"N={n}: Full={full:.1f}GB, Sparse={sparse_512:.3f}GB")

memory_usage_comparison([4096, 8192, 16384, 32768])

# === 구현 도전 과제 ===
# 1. 커스텀 CUDA 커널 필요 (PyTorch 기본 X)
# 2. Sparse 패턴이 GPU에서 비효율적일 수 있음
# 3. 컴파일러 최적화 어려움
# Flash Attention이 대안으로 부상한 이유!

# === 최신 트렌드 ===
# - Linear Attention: softmax 대신 커널 근사
# - Performer: FAVOR+ 메커니즘
# - Linformer: 저랭크 근사
# 모두 O(n^2) -> O(n) 목표

# === Dilated Attention (LongNet) ===
# 지수적으로 증가하는 dilation으로 1B 토큰 처리
# 가까운 토큰은 dense, 먼 토큰은 sparse하게 처리

# === 실전 체크리스트 ===
# 1. 시퀀스 길이가 2K 이하면 Full Attention으로 충분
# 2. 2K-16K: Sliding Window 또는 Flash Attention
# 3. 16K-100K: BigBird, Longformer + Flash Attention
# 4. 100K+: Ring Attention, LongNet 고려
# 5. GPU 메모리와 latency 요구사항에 맞춰 선택

# === 최신 모델별 Attention 전략 ===
# | 모델 | 컨텍스트 | 전략 |
# |-----|---------|-----|
# | GPT-4 Turbo | 128K | 내부 최적화 |
# | Claude 3 | 200K | Ring Attention |
# | Gemini 1.5 | 1M | 멀티모달 최적화 |

🗣️ 실무에서 이렇게 말하세요

💬 긴 문서 처리 논의에서
"법률 문서가 보통 3만 토큰이라 BERT로는 안 됩니다. Longformer 쓰면 4096 토큰까지 한 번에 처리 가능하고, 여러 청크로 나눠 처리하는 것보다 문맥 파악이 훨씬 좋아요. Local + Global Attention 덕분이죠."
💬 면접에서
"Sparse Attention은 Full Attention의 O(N^2) 문제를 해결합니다. 모든 쌍이 아니라 Local(주변), Global(핵심 토큰), Random 연결만 계산해서 O(N)이나 O(N log N)으로 줄이죠. BigBird, Longformer가 대표적이고, 최근 Mistral의 Sliding Window도 같은 아이디어입니다."
💬 모델 선택 회의에서
"128K 컨텍스트가 필요하면 Sparse Attention 모델이나 Flash Attention이 적용된 모델을 써야 해요. Claude나 GPT-4 Turbo는 내부적으로 최적화가 되어 있고, 오픈소스면 Mistral이 Sliding Window로 효율적입니다."

⚠️ 흔한 실수 & 주의사항

Window size를 너무 작게 설정

Local Attention 윈도우가 너무 작으면 멀리 떨어진 관계를 포착하지 못합니다. 문서 내 상호참조(예: "앞서 언급한 것처럼")가 깨질 수 있습니다.

Global token을 적절히 설정하지 않음

Longformer에서 [CLS]나 질문 토큰을 global로 설정하지 않으면 요약/QA 성능이 크게 떨어집니다. 태스크에 맞는 global 위치 설정이 중요합니다.

올바른 방법

Local window는 최소 256~512, 문서 분류는 [CLS]를 global로, QA는 질문 토큰 전체를 global로 설정하세요. 필요시 Flash Attention과 조합해 효율을 극대화합니다.

🔗 관련 용어

📚 더 배우기