Sparse Attention
스파스 어텐션
어텐션 연산을 일부에만 적용하여 효율을 높이는 기법. O(N^2)를 O(N log N)으로 줄여 긴 컨텍스트 처리 가능.
스파스 어텐션
어텐션 연산을 일부에만 적용하여 효율을 높이는 기법. O(N^2)를 O(N log N)으로 줄여 긴 컨텍스트 처리 가능.
Sparse Attention은 표준 Self-Attention의 모든 위치 쌍(N^2)이 아닌, 선택된 위치 쌍에만 어텐션을 계산하는 효율적인 기법입니다. 4096 토큰 시퀀스에서 Full Attention은 약 1,670만 번(4096^2) 연산하지만, Sparse Attention은 윈도우 크기 256 기준 약 100만 번으로 줄여 긴 문서를 처리할 수 있게 합니다. 어텐션 행렬의 대부분이 0(sparse)인 것에서 이름이 유래했습니다.
2019년 OpenAI의 "Generating Long Sequences with Sparse Transformers" 논문에서 본격적으로 연구되었습니다. Full Attention의 O(N^2) 복잡도는 긴 시퀀스에서 메모리와 연산 비용이 폭발합니다. N=8192일 때 어텐션 행렬만 256GB 메모리가 필요합니다(FP32 기준). GPT-2의 1024 토큰, BERT의 512 토큰 제한도 이 때문이었습니다. Sparse Attention은 O(N log N) 또는 O(N sqrt(N))으로 복잡도를 낮춥니다.
대표적 패턴으로 Local Attention(주변 k개 토큰만 attend), Strided Attention(일정 간격으로 attend), Global Attention(특정 토큰이 전체와 연결)이 있습니다. Longformer는 Local+Global 조합으로 4096 토큰까지 지원하고, BigBird는 Local+Global+Random을 조합해 이론적 표현력을 유지합니다. Mistral 7B의 Sliding Window Attention(4096)도 효율적인 Sparse의 일종입니다. 최근에는 Ring Attention으로 여러 GPU에 시퀀스를 분산 처리하는 기법도 등장했습니다.
긴 문서 요약, 법률 계약서 분석, 의료 기록 처리, 전체 코드베이스 이해 등 긴 컨텍스트가 필요한 작업에 필수입니다. 단, 커스텀 CUDA 커널이 필요해 구현이 복잡하고 GPU 최적화가 어렵습니다. Flash Attention은 알고리즘적 Sparsity 대신 메모리 접근 패턴 최적화(IO-aware)로 효율을 높이는 보완적 접근이며, 두 기법을 함께 사용하면 시너지가 납니다.
import torch
import torch.nn.functional as F
def create_local_attention_mask(seq_len, window_size):
"""Local Attention: 주변 window_size 토큰에만 attend"""
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
start = max(0, i - window_size // 2)
end = min(seq_len, i + window_size // 2 + 1)
mask[i, start:end] = 1
return mask
def create_strided_attention_mask(seq_len, stride):
"""Strided Attention: stride 간격으로 attend"""
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
# 같은 stride 블록 내 위치들
block_start = (i // stride) * stride
block_end = min(block_start + stride, seq_len)
mask[i, block_start:block_end] = 1
# 모든 stride 위치
mask[i, ::stride] = 1
return mask
def create_bigbird_mask(seq_len, window_size, num_global, num_random):
"""BigBird: Local + Global + Random"""
mask = create_local_attention_mask(seq_len, window_size)
# Global tokens (처음 몇 개 토큰이 모든 위치와 연결)
mask[:num_global, :] = 1
mask[:, :num_global] = 1
# Random connections
for i in range(seq_len):
random_indices = torch.randperm(seq_len)[:num_random]
mask[i, random_indices] = 1
return mask
# 사용 예시
seq_len = 1024
window_size = 64
local_mask = create_local_attention_mask(seq_len, window_size)
print(f"Local Attention 연결 수: {local_mask.sum().item():,.0f}") # ~65K
print(f"Full Attention 연결 수: {seq_len * seq_len:,}") # 1,048,576
# Hugging Face Longformer 사용
from transformers import LongformerModel, LongformerTokenizer
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model = LongformerModel.from_pretrained('allenai/longformer-base-4096')
# 최대 4096 토큰까지 처리 가능 (BERT는 512)
text = "This is a very long document... " * 500
inputs = tokenizer(text, return_tensors='pt', max_length=4096, truncation=True)
# global_attention_mask: 1인 위치는 모든 토큰과 attend
global_attention_mask = torch.zeros_like(inputs['input_ids'])
global_attention_mask[:, 0] = 1 # [CLS] 토큰만 global
outputs = model(**inputs, global_attention_mask=global_attention_mask)
# === Sliding Window Attention (Mistral 스타일) ===
def create_sliding_window_mask(seq_len, window_size):
"""Sliding Window: 각 토큰이 앞의 window_size 토큰만 attend"""
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
start = max(0, i - window_size + 1)
mask[i, start:i+1] = 1 # causal + window
return mask
# Mistral 7B: window_size=4096
# 128K 토큰도 4096 윈도우로 효율적 처리
sliding_mask = create_sliding_window_mask(8192, 4096)
print(f"Sliding Window 연결 수: {sliding_mask.sum().item():,.0f}") # ~33M
print(f"Full Attention 연결 수: {8192 * 8192:,}") # 67M
# === 복잡도 비교 ===
def attention_complexity_comparison(seq_len, window_size=256, stride=256):
"""각 Sparse 패턴의 연산 횟수 비교"""
full = seq_len * seq_len
local = seq_len * window_size
strided = seq_len * (window_size + seq_len // stride)
sliding = seq_len * min(window_size, seq_len)
print(f"시퀀스 길이: {seq_len:,}")
print(f"Full Attention: {full:,} ({full/1e6:.1f}M)")
print(f"Local Attention: {local:,} ({local/1e6:.1f}M) - {local/full*100:.1f}%")
print(f"Strided Attention: {strided:,} ({strided/1e6:.1f}M) - {strided/full*100:.1f}%")
print(f"Sliding Window: {sliding:,} ({sliding/1e6:.1f}M) - {sliding/full*100:.1f}%")
attention_complexity_comparison(16384, window_size=512)
# === BigBird 패턴 (Local + Global + Random) ===
def create_bigbird_mask(seq_len, window_size=64, num_global=64, num_random=64):
"""BigBird: 이론적 표현력을 유지하면서 효율성 확보"""
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
# Local attention (주변 토큰)
start = max(0, i - window_size // 2)
end = min(seq_len, i + window_size // 2)
mask[i, start:end] = 1
# Global attention (처음 num_global 토큰이 전체와 연결)
if i < num_global:
mask[i, :] = 1
mask[:, i] = 1
# Random attention (무작위 연결로 이론적 표현력 유지)
random_indices = torch.randint(0, seq_len, (num_random,))
mask[i, random_indices] = 1
return mask
# === Ring Attention (분산 처리) ===
# 여러 GPU에 시퀀스를 분산하여 처리
# 각 GPU가 자신의 청크를 처리하고 결과를 전달
# 1M 토큰 이상의 초장문 처리 가능
# === 실제 모델에서의 적용 ===
# GPT-4: 128K 컨텍스트 (내부 최적화)
# Claude: 200K 컨텍스트 (Ring Attention 추정)
# Mistral 7B: 32K (Sliding Window 4096)
# Longformer: 4096 (Local + Global)
# BigBird: 4096 (Local + Global + Random)
# === Attention Sink 현상 ===
# 첫 번째 토큰에 어텐션이 집중되는 현상
# StreamingLLM에서 발견: 첫 토큰을 유지하면 무한 생성 가능
# === 성능 벤치마크 ===
# 시퀀스 길이 16K에서:
# - Full Attention: 메모리 OOM (A100 40GB)
# - Sparse Attention (window=512): 8GB
# - Flash Attention: 12GB (더 정확, 더 빠름)
# - Sparse + Flash: 6GB (최적)
def memory_usage_comparison(seq_lengths):
"""다양한 어텐션 방식 메모리 사용량 비교"""
for n in seq_lengths:
full = n * n * 4 / 1e9 # FP32, GB
sparse_512 = n * 512 * 4 / 1e9
print(f"N={n}: Full={full:.1f}GB, Sparse={sparse_512:.3f}GB")
memory_usage_comparison([4096, 8192, 16384, 32768])
# === 구현 도전 과제 ===
# 1. 커스텀 CUDA 커널 필요 (PyTorch 기본 X)
# 2. Sparse 패턴이 GPU에서 비효율적일 수 있음
# 3. 컴파일러 최적화 어려움
# Flash Attention이 대안으로 부상한 이유!
# === 최신 트렌드 ===
# - Linear Attention: softmax 대신 커널 근사
# - Performer: FAVOR+ 메커니즘
# - Linformer: 저랭크 근사
# 모두 O(n^2) -> O(n) 목표
# === Dilated Attention (LongNet) ===
# 지수적으로 증가하는 dilation으로 1B 토큰 처리
# 가까운 토큰은 dense, 먼 토큰은 sparse하게 처리
# === 실전 체크리스트 ===
# 1. 시퀀스 길이가 2K 이하면 Full Attention으로 충분
# 2. 2K-16K: Sliding Window 또는 Flash Attention
# 3. 16K-100K: BigBird, Longformer + Flash Attention
# 4. 100K+: Ring Attention, LongNet 고려
# 5. GPU 메모리와 latency 요구사항에 맞춰 선택
# === 최신 모델별 Attention 전략 ===
# | 모델 | 컨텍스트 | 전략 |
# |-----|---------|-----|
# | GPT-4 Turbo | 128K | 내부 최적화 |
# | Claude 3 | 200K | Ring Attention |
# | Gemini 1.5 | 1M | 멀티모달 최적화 |
"법률 문서가 보통 3만 토큰이라 BERT로는 안 됩니다. Longformer 쓰면 4096 토큰까지 한 번에 처리 가능하고, 여러 청크로 나눠 처리하는 것보다 문맥 파악이 훨씬 좋아요. Local + Global Attention 덕분이죠."
"Sparse Attention은 Full Attention의 O(N^2) 문제를 해결합니다. 모든 쌍이 아니라 Local(주변), Global(핵심 토큰), Random 연결만 계산해서 O(N)이나 O(N log N)으로 줄이죠. BigBird, Longformer가 대표적이고, 최근 Mistral의 Sliding Window도 같은 아이디어입니다."
"128K 컨텍스트가 필요하면 Sparse Attention 모델이나 Flash Attention이 적용된 모델을 써야 해요. Claude나 GPT-4 Turbo는 내부적으로 최적화가 되어 있고, 오픈소스면 Mistral이 Sliding Window로 효율적입니다."
Local Attention 윈도우가 너무 작으면 멀리 떨어진 관계를 포착하지 못합니다. 문서 내 상호참조(예: "앞서 언급한 것처럼")가 깨질 수 있습니다.
Longformer에서 [CLS]나 질문 토큰을 global로 설정하지 않으면 요약/QA 성능이 크게 떨어집니다. 태스크에 맞는 global 위치 설정이 중요합니다.
Local window는 최소 256~512, 문서 분류는 [CLS]를 global로, QA는 질문 토큰 전체를 global로 설정하세요. 필요시 Flash Attention과 조합해 효율을 극대화합니다.