🤖 AI/ML

Knowledge Distillation

큰 모델의 지식을 작은 모델로 전이하는 기법

📖 상세 설명

Knowledge Distillation(지식 증류)은 크고 복잡한 Teacher 모델의 지식을 작고 가벼운 Student 모델로 전이하는 모델 압축 기법입니다. 핵심은 Teacher 모델의 최종 예측이 아닌 "소프트 타겟(soft target)" - 즉 확률 분포 전체를 학습하는 것입니다.

2015년 Hinton이 제안한 이 기법은 모바일/엣지 디바이스 배포의 핵심 기술로 자리잡았습니다. GPT-4 급 모델을 직접 배포할 수 없는 상황에서, 증류된 작은 모델로 비슷한 성능을 내는 것이 실무 목표입니다.

작동 원리: Teacher 모델의 출력에 Temperature(T)를 적용해 확률 분포를 부드럽게 만들고, Student는 이 "dark knowledge"를 학습합니다. T=1이면 일반 softmax, T가 높을수록 분포가 평탄해져 클래스 간 관계 정보가 더 많이 전달됩니다.

실무에서 Knowledge Distillation은 BERT를 DistilBERT(40% 작음, 97% 성능 유지)로, GPT를 경량 버전으로 압축하는 데 사용됩니다. 추론 비용을 10배 이상 줄이면서 성능 저하를 5% 미만으로 유지하는 것이 일반적 목표입니다.

💻 코드 예제

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """Knowledge Distillation Loss 구현"""
    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.T = temperature
        self.alpha = alpha  # distillation loss 가중치

    def forward(self, student_logits, teacher_logits, labels):
        # 1. Hard Target Loss (실제 라벨과의 CE)
        hard_loss = F.cross_entropy(student_logits, labels)

        # 2. Soft Target Loss (Teacher 분포 학습)
        soft_student = F.log_softmax(student_logits / self.T, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.T, dim=1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
        soft_loss = soft_loss * (self.T ** 2)  # gradient 스케일 보정

        # 3. 최종 Loss = alpha * soft + (1-alpha) * hard
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

# 사용 예시
teacher = BigModel().eval()  # 고정
student = SmallModel()
criterion = DistillationLoss(temperature=4.0, alpha=0.7)

for images, labels in dataloader:
    with torch.no_grad():
        teacher_logits = teacher(images)
    student_logits = student(images)
    loss = criterion(student_logits, teacher_logits, labels)
    loss.backward()
    # Student만 1.2GB, Teacher 45GB -> 배포 시 Student만 사용

🗣️ 실무에서 이렇게 말하세요

💬 회의에서
"GPT-4 API 비용이 월 200만원 나오는데, Knowledge Distillation으로 7B 모델 만들어서 온프레미스에 배포하면 90% 절감 가능합니다. Teacher-Student 학습에 2주, 성능은 원본의 92% 목표로 잡겠습니다."
💬 면접에서
"DistilBERT 프로젝트에서 Temperature 4.0, alpha 0.7로 설정해 BERT-base를 40% 압축했습니다. Soft target으로 클래스 간 관계 정보까지 전달되어, 단순 fine-tuning 대비 F1이 3% 높았습니다."
💬 기술 토론에서
"Response-based distillation만으로 한계가 있어서, Feature-based로 중간 레이어 representation도 맞추고 있어요. Attention map alignment까지 추가하면 Student가 Teacher의 reasoning 패턴도 학습합니다."

⚠️ 흔한 실수 & 주의사항

Temperature를 너무 낮게 설정 (T=1)

T가 낮으면 확률 분포가 sharp해져서 dark knowledge 전달이 안 됩니다. T=3~6 범위에서 실험하세요.

Student 용량을 너무 작게 설정

Teacher의 1/10 미만으로 줄이면 capacity 부족으로 지식 수용 불가. 1/3~1/5가 적정선입니다.

올바른 방법

Teacher를 반드시 eval() 모드로 고정하고, Student만 학습. alpha 값은 0.5~0.9에서 validation으로 튜닝하세요.

🔗 관련 용어

📚 더 배우기