Flash Attention
메모리 효율적인 어텐션 연산 알고리즘
메모리 효율적인 어텐션 연산 알고리즘
Flash Attention은 Transformer의 self-attention 연산을 메모리 효율적으로 수행하는 알고리즘입니다. 기존 attention은 시퀀스 길이의 제곱(O(n^2))에 비례하는 메모리를 사용하지만, Flash Attention은 tiling과 recomputation 기법으로 이를 선형(O(n))으로 줄입니다.
2022년 Stanford의 Tri Dao 등이 발표한 Flash Attention은 GPU 메모리 계층(HBM vs SRAM)을 고려한 IO-aware 알고리즘입니다. GPU의 고속 SRAM에서 계산을 수행하고, 느린 HBM 접근을 최소화하여 속도와 메모리 효율을 동시에 개선했습니다.
핵심 아이디어는 attention matrix 전체를 메모리에 저장하지 않고, 작은 블록(tile) 단위로 나누어 계산하는 것입니다. softmax의 분자와 분모를 온라인으로 누적 계산하고, 역전파 시 필요한 값은 재계산(recomputation)합니다. 수학적으로 정확한 결과를 보장하면서도 2~4배 빠릅니다.
현재 GPT-4, Llama, Claude 등 거의 모든 대형 LLM이 Flash Attention을 사용합니다. 시퀀스 길이 2K에서 10배, 4K에서 20배의 메모리 절약 효과가 있어, 128K 이상의 긴 context를 처리하는 핵심 기술이 되었습니다. Flash Attention 3은 H100 GPU에서 1.5~2배 추가 속도 향상을 달성했습니다.
import torch
from torch.nn.functional import scaled_dot_product_attention
# PyTorch 2.0+에서 Flash Attention 자동 사용
# CUDA 11.6+, Ampere GPU (A100, RTX 30xx) 이상 필요
# 입력 텐서 준비 (batch, heads, seq_len, head_dim)
batch_size = 2
num_heads = 8
seq_len = 4096 # 긴 시퀀스
head_dim = 64
query = torch.randn(batch_size, num_heads, seq_len, head_dim,
device='cuda', dtype=torch.float16)
key = torch.randn(batch_size, num_heads, seq_len, head_dim,
device='cuda', dtype=torch.float16)
value = torch.randn(batch_size, num_heads, seq_len, head_dim,
device='cuda', dtype=torch.float16)
# scaled_dot_product_attention은 자동으로 Flash Attention 사용
# (조건 충족 시: CUDA, fp16/bf16, 적절한 텐서 shape)
with torch.backends.cuda.sdp_kernel(
enable_flash=True, # Flash Attention 활성화
enable_math=False, # 기본 수학 연산 비활성화
enable_mem_efficient=False # Memory Efficient Attention 비활성화
):
output = scaled_dot_product_attention(query, key, value)
print(f"Input shape: {query.shape}")
print(f"Output shape: {output.shape}")
# HuggingFace Transformers에서 Flash Attention 2 사용
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2", # Flash Attention 2 명시
device_map="auto"
)
# 메모리 사용량 비교
print(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
"32K context를 처리하려면 Flash Attention이 필수입니다. 일반 attention으로는 A100 80GB에서도 OOM이 나는데, Flash Attention 2를 쓰면 배치 사이즈도 늘릴 수 있어요."
"Flash Attention은 attention matrix를 tile 단위로 나눠서 SRAM에서 계산하고, HBM 접근을 최소화하는 IO-aware 알고리즘입니다. softmax를 online으로 계산해서 메모리를 O(n)으로 줄입니다."
"HuggingFace에서 attn_implementation='flash_attention_2' 옵션 하나로 적용 가능해요. 다만 causal mask가 기본이라 encoder-decoder 모델에선 주의가 필요하고, attention weight 시각화가 안 되는 단점이 있습니다."
Flash Attention은 NVIDIA Ampere(A100, RTX 30xx) 이상에서만 지원됩니다. V100이나 RTX 20xx에서는 에러가 발생하거나 fallback됩니다.
Flash Attention은 FP16 또는 BF16에서만 동작합니다. FP32 텐서를 넣으면 자동으로 느린 기본 attention으로 fallback됩니다.
PyTorch 2.0+ 사용, torch.float16 또는 torch.bfloat16으로 캐스팅, Ampere+ GPU 확인. HuggingFace에서는 attn_implementation="flash_attention_2" 명시적 설정을 권장합니다.