Knowledge Distillation
큰 모델의 지식을 작은 모델로 전이하는 기법
큰 모델의 지식을 작은 모델로 전이하는 기법
Knowledge Distillation(지식 증류)은 크고 복잡한 Teacher 모델의 지식을 작고 가벼운 Student 모델로 전이하는 모델 압축 기법입니다. 핵심은 Teacher 모델의 최종 예측이 아닌 "소프트 타겟(soft target)" - 즉 확률 분포 전체를 학습하는 것입니다.
2015년 Hinton이 제안한 이 기법은 모바일/엣지 디바이스 배포의 핵심 기술로 자리잡았습니다. GPT-4 급 모델을 직접 배포할 수 없는 상황에서, 증류된 작은 모델로 비슷한 성능을 내는 것이 실무 목표입니다.
작동 원리: Teacher 모델의 출력에 Temperature(T)를 적용해 확률 분포를 부드럽게 만들고, Student는 이 "dark knowledge"를 학습합니다. T=1이면 일반 softmax, T가 높을수록 분포가 평탄해져 클래스 간 관계 정보가 더 많이 전달됩니다.
실무에서 Knowledge Distillation은 BERT를 DistilBERT(40% 작음, 97% 성능 유지)로, GPT를 경량 버전으로 압축하는 데 사용됩니다. 추론 비용을 10배 이상 줄이면서 성능 저하를 5% 미만으로 유지하는 것이 일반적 목표입니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
"""Knowledge Distillation Loss 구현"""
def __init__(self, temperature=4.0, alpha=0.7):
super().__init__()
self.T = temperature
self.alpha = alpha # distillation loss 가중치
def forward(self, student_logits, teacher_logits, labels):
# 1. Hard Target Loss (실제 라벨과의 CE)
hard_loss = F.cross_entropy(student_logits, labels)
# 2. Soft Target Loss (Teacher 분포 학습)
soft_student = F.log_softmax(student_logits / self.T, dim=1)
soft_teacher = F.softmax(teacher_logits / self.T, dim=1)
soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
soft_loss = soft_loss * (self.T ** 2) # gradient 스케일 보정
# 3. 최종 Loss = alpha * soft + (1-alpha) * hard
return self.alpha * soft_loss + (1 - self.alpha) * hard_loss
# 사용 예시
teacher = BigModel().eval() # 고정
student = SmallModel()
criterion = DistillationLoss(temperature=4.0, alpha=0.7)
for images, labels in dataloader:
with torch.no_grad():
teacher_logits = teacher(images)
student_logits = student(images)
loss = criterion(student_logits, teacher_logits, labels)
loss.backward()
# Student만 1.2GB, Teacher 45GB -> 배포 시 Student만 사용
"GPT-4 API 비용이 월 200만원 나오는데, Knowledge Distillation으로 7B 모델 만들어서 온프레미스에 배포하면 90% 절감 가능합니다. Teacher-Student 학습에 2주, 성능은 원본의 92% 목표로 잡겠습니다."
"DistilBERT 프로젝트에서 Temperature 4.0, alpha 0.7로 설정해 BERT-base를 40% 압축했습니다. Soft target으로 클래스 간 관계 정보까지 전달되어, 단순 fine-tuning 대비 F1이 3% 높았습니다."
"Response-based distillation만으로 한계가 있어서, Feature-based로 중간 레이어 representation도 맞추고 있어요. Attention map alignment까지 추가하면 Student가 Teacher의 reasoning 패턴도 학습합니다."
T가 낮으면 확률 분포가 sharp해져서 dark knowledge 전달이 안 됩니다. T=3~6 범위에서 실험하세요.
Teacher의 1/10 미만으로 줄이면 capacity 부족으로 지식 수용 불가. 1/3~1/5가 적정선입니다.
Teacher를 반드시 eval() 모드로 고정하고, Student만 학습. alpha 값은 0.5~0.9에서 validation으로 튜닝하세요.