Model Pruning
모델 프루닝 / 가지치기
모델에서 불필요한 연결(가중치)을 제거하여 크기와 연산량을 줄이는 경량화 기법. 모바일/엣지 배포와 추론 속도 향상의 핵심 기술.
모델 프루닝 / 가지치기
모델에서 불필요한 연결(가중치)을 제거하여 크기와 연산량을 줄이는 경량화 기법. 모바일/엣지 배포와 추론 속도 향상의 핵심 기술.
Model Pruning(모델 프루닝)은 신경망에서 중요도가 낮은 가중치나 뉴런을 제거하여 모델을 경량화하는 기법입니다. 나무의 불필요한 가지를 쳐내듯, 성능에 큰 영향을 주지 않는 연결을 제거하여 모델 크기를 50-90% 줄이면서 정확도는 1-2% 이내로 유지할 수 있습니다.
프루닝 연구는 1990년대 "Optimal Brain Damage" 논문에서 시작되었으며, 딥러닝 시대에 들어 대규모 모델의 효율적 배포가 중요해지면서 다시 주목받고 있습니다. 특히 LLM과 Vision Transformer의 등장으로 모델 크기가 급증하면서 프루닝의 필요성이 더욱 커졌습니다.
프루닝 방식은 크게 Unstructured(비구조적)와 Structured(구조적)로 나뉩니다. Unstructured는 개별 가중치를 0으로 만들어 희소 행렬을 생성하고, Structured는 전체 필터, 채널, 어텐션 헤드 등 구조 단위로 제거합니다. Structured가 실제 속도 향상에 더 효과적이지만, Unstructured가 더 높은 압축률을 달성합니다.
실무에서 프루닝은 모바일/엣지 디바이스 배포, 추론 비용 절감, 실시간 응답 요구 서비스에 필수입니다. PyTorch의 torch.nn.utils.prune, TensorFlow Model Optimization Toolkit, NVIDIA의 ASP(Automatic Sparsity) 등이 프루닝을 지원하며, 최근에는 SparseGPT, Wanda 등 LLM 전용 프루닝 기법도 활발히 연구되고 있습니다.
| 프루닝 방식 | 압축률 | 속도 향상 | 장단점 |
|---|---|---|---|
| Unstructured | 90-95% | 제한적 (하드웨어 의존) | 높은 압축, 희소 행렬 연산 필요 |
| Structured (필터/채널) | 50-70% | 2-4배 | 범용 하드웨어에서 실제 속도 향상 |
| Semi-structured (2:4) | 50% | 2배 (A100+) | NVIDIA Ampere+ GPU 최적화 |
| LLM Pruning (SparseGPT) | 50-60% | 1.5-2배 | 원샷 프루닝, 재학습 불필요 |
# PyTorch 모델 프루닝 예제
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
# 1. 간단한 모델 정의
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
self.fc = nn.Linear(128 * 8 * 8, 10)
def forward(self, x):
x = torch.relu(self.conv1(x))
x = torch.max_pool2d(x, 2)
x = torch.relu(self.conv2(x))
x = torch.max_pool2d(x, 2)
x = x.view(x.size(0), -1)
return self.fc(x)
model = SimpleCNN()
# 2. Unstructured Pruning (L1 크기 기준)
# Conv 레이어의 30% 가중치를 0으로
prune.l1_unstructured(model.conv1, name='weight', amount=0.3)
prune.l1_unstructured(model.conv2, name='weight', amount=0.3)
# 프루닝 상태 확인
print(f"conv1 희소성: {100 * float(torch.sum(model.conv1.weight == 0)) / model.conv1.weight.nelement():.1f}%")
# 3. Structured Pruning (전체 필터 제거)
# 중요도 낮은 필터 20% 제거
prune.ln_structured(model.conv2, name='weight', amount=0.2, n=2, dim=0)
# 4. 글로벌 프루닝 (모델 전체에서 가장 작은 가중치 제거)
parameters_to_prune = [
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc, 'weight'),
]
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.4, # 전체 가중치의 40% 제거
)
# 5. 프루닝 영구 적용 (마스크를 실제 가중치에 반영)
for module, name in parameters_to_prune:
prune.remove(module, name)
# 6. 모델 크기 비교
def count_parameters(model):
return sum(p.numel() for p in model.parameters())
def count_nonzero(model):
return sum((p != 0).sum().item() for p in model.parameters())
total = count_parameters(model)
nonzero = count_nonzero(model)
print(f"총 파라미터: {total:,}")
print(f"0이 아닌 파라미터: {nonzero:,}")
print(f"희소성: {100 * (1 - nonzero/total):.1f}%")
# 7. Fine-tuning (프루닝 후 성능 복구)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# 간단한 학습 루프 (실제로는 데이터로더 사용)
# for epoch in range(5):
# for x, y in dataloader:
# optimizer.zero_grad()
# loss = criterion(model(x), y)
# loss.backward()
# optimizer.step()
# 8. 희소 모델 저장
torch.save(model.state_dict(), "pruned_model.pt")
print("프루닝된 모델 저장 완료!")
# ===== NVIDIA 2:4 Semi-structured Sparsity =====
# A100 GPU 이상에서 2배 속도 향상
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
# 2:4 패턴: 4개 원소 중 2개가 0
# NVIDIA Ampere+ GPU에서 하드웨어 가속 지원
# sparse_weight = to_sparse_semi_structured(dense_weight)
"모바일 배포를 위해 ResNet50을 프루닝해봤는데, Structured Pruning으로 필터 50% 제거하니 속도가 2배 빨라지고 정확도는 1.5%만 떨어졌어요. Fine-tuning 5 에포크면 원래 성능의 99%까지 복구됩니다."
"Pruning은 중요도가 낮은 가중치를 제거하는 경량화 기법입니다. Unstructured는 개별 가중치를 0으로 만들어 높은 압축률을 달성하지만 희소 행렬 연산이 필요하고, Structured는 전체 필터/채널을 제거해서 범용 하드웨어에서도 실제 속도 향상이 가능합니다. 프루닝 후 Fine-tuning으로 성능을 복구합니다."
"LLM 프루닝은 SparseGPT나 Wanda가 좋아요. 재학습 없이 원샷으로 50% 프루닝해도 perplexity 손실이 적고, 양자화와 결합하면 모델 크기를 10배 이상 줄일 수 있습니다. A100 쓰면 2:4 sparsity로 하드웨어 가속도 받을 수 있고요."
가중치를 0으로 만드는 것만으로는 실제 속도가 향상되지 않습니다. 희소 행렬 연산을 지원하는 하드웨어/라이브러리가 필요합니다.
50% 이상 프루닝하면 성능이 크게 떨어집니다. 반드시 Fine-tuning으로 성능을 복구하고, 점진적으로 프루닝 비율을 높이세요.
프루닝(50%) + 양자화(INT8/INT4)를 결합하면 모델 크기를 10-20배 줄일 수 있습니다. Iterative pruning으로 점진적 프루닝 → Fine-tuning을 반복하면 높은 압축률에서도 성능 유지가 가능합니다.