🤖AI/ML

SFT

Supervised Fine-Tuning

라벨링된 데이터로 모델을 지도학습 방식으로 파인튜닝. RLHF 전 단계. ChatGPT 같은 챗봇의 핵심 학습 방법.

📖 상세 설명

SFT(Supervised Fine-Tuning)는 사전 학습된 LLM을 고품질의 (입력, 출력) 쌍 데이터로 지도학습하는 과정입니다. "사용자가 이렇게 질문하면 이렇게 답해야 한다"는 것을 직접 보여주는 방식으로, Instruction Tuning이라고도 불립니다.

GPT-3.5/4, Claude, Llama 2 Chat 등 모든 대화형 LLM은 SFT를 거칩니다. Base 모델(GPT-3)은 단순히 다음 토큰을 예측하지만, SFT를 거친 모델(ChatGPT)은 사용자 질문에 적절히 응답합니다. OpenAI는 이 과정에 수천 명의 인간 레이블러를 동원했습니다.

일반적인 LLM 정렬 파이프라인은 Pretraining -> SFT -> RLHF 순서입니다. SFT로 기본적인 대화 능력과 지시 따르기를 학습하고, RLHF로 더 섬세한 선호도를 반영합니다. 최근에는 DPO(Direct Preference Optimization)로 RLHF를 대체하는 추세입니다.

Alpaca, Vicuna 등 오픈소스 모델들이 GPT-3.5 출력을 SFT 데이터로 활용해 성공을 거두면서, 합성 데이터 기반 SFT가 주목받고 있습니다. 1천~1만 개의 고품질 데이터만으로도 상당한 성능 향상이 가능합니다.

💻 코드 예제

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
from datasets import load_dataset

# 모델과 토크나이저 로드
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# SFT 데이터셋 로드 (Instruction-Output 형식)
dataset = load_dataset("tatsu-lab/alpaca", split="train")

# 데이터 포맷팅 함수
def format_instruction(example):
    """Alpaca 형식의 프롬프트 생성"""
    if example.get("input"):
        return f"""### Instruction:
{example['instruction']}

### Input:
{example['input']}

### Response:
{example['output']}"""
    else:
        return f"""### Instruction:
{example['instruction']}

### Response:
{example['output']}"""

# 학습 설정
training_args = TrainingArguments(
    output_dir="./llama-2-7b-sft",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    warmup_ratio=0.03,
    logging_steps=10,
    save_strategy="epoch",
    bf16=True,
)

# SFT Trainer 사용
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    formatting_func=format_instruction,
    max_seq_length=2048,
    tokenizer=tokenizer,
)

# 학습 실행
trainer.train()
trainer.save_model("./llama-2-7b-sft-final")

# A100 80GB 기준: 7B 모델 SFT ~4시간, 52K 샘플

# === LoRA + SFT (메모리 효율적) ===
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,                          # LoRA rank
    lora_alpha=32,                 # 스케일링 팩터
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model_lora = get_peft_model(model, lora_config)
model_lora.print_trainable_parameters()
# trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.0622

# LoRA SFT는 RTX 4090 24GB로도 7B 모델 학습 가능!

# === ChatML 형식 데이터 ===
def format_chatml(example):
    """ChatML 형식 (OpenAI, Mistral 호환)"""
    return f"""<|im_start|>system
You are a helpful AI assistant.<|im_end|>
<|im_start|>user
{example['instruction']}<|im_end|>
<|im_start|>assistant
{example['output']}<|im_end|>"""

# === 고품질 데이터 생성 (GPT-4 활용) ===
from openai import OpenAI
import json

client = OpenAI()

def generate_sft_data(topic: str, num_samples: int = 10):
    """GPT-4로 SFT 데이터 합성"""
    samples = []
    for _ in range(num_samples):
        response = client.chat.completions.create(
            model="gpt-4-turbo",
            messages=[{
                "role": "user",
                "content": f"""다음 주제에 대한 Q&A 데이터를 생성하세요.
주제: {topic}

JSON 형식:
{{"instruction": "사용자 질문", "output": "상세하고 정확한 답변"}}"""
            }],
            response_format={"type": "json_object"}
        )
        samples.append(json.loads(response.choices[0].message.content))
    return samples

# === 평가 메트릭 ===
def evaluate_sft_model(model, tokenizer, test_prompts):
    """SFT 모델 품질 평가"""
    results = []
    for prompt in test_prompts:
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=200)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        results.append({"prompt": prompt, "response": response})
    return results

# GPT-4로 응답 품질 평가 (LLM-as-Judge)
def evaluate_with_gpt4(responses):
    """GPT-4로 응답 품질 점수 매기기"""
    scores = []
    for r in responses:
        response = client.chat.completions.create(
            model="gpt-4-turbo",
            messages=[{
                "role": "user",
                "content": f"""다음 응답의 품질을 1-10점으로 평가하세요.
질문: {r['prompt']}
응답: {r['response']}
점수 (숫자만):"""
            }]
        )
        scores.append(int(response.choices[0].message.content.strip()))
    return sum(scores) / len(scores)

🗣️ 실무에서 이렇게 말하세요

💬 LLM 커스터마이징 미팅에서
"Base Llama 모델은 대화를 못 해요. 우리 도메인에 맞는 Q&A 데이터 1만 개 정도 만들어서 SFT 돌리면 됩니다. LoRA 쓰면 A100 1대로 하루면 충분해요. RLHF까지 갈 필요 없이 SFT만으로도 충분히 쓸만합니다."
💬 면접에서
"SFT는 LLM이 지시를 따르도록 학습하는 과정입니다. Pretraining된 모델은 단순 텍스트 완성만 하지만, SFT를 거치면 '요약해줘', '번역해줘' 같은 명령을 이해합니다. ChatGPT는 SFT + RLHF로 만들어졌고, 최근엔 DPO가 더 효율적인 대안으로 떠오르고 있어요."
💬 데이터 품질 논의에서
"SFT 성공의 80%는 데이터 품질입니다. LIMA 논문 보면 1,000개의 고품질 데이터가 52,000개 Alpaca 데이터보다 성능이 좋아요. 양보다 질이에요. GPT-4로 합성 데이터 만들고 사람이 검수하는 게 가성비 최고입니다."

⚠️ 흔한 실수 & 주의사항

저품질 데이터로 대량 SFT

노이즈가 많은 데이터로 SFT하면 모델 성능이 오히려 하락합니다. 1만 개 고품질 > 100만 개 저품질입니다.

Catastrophic Forgetting

특정 도메인에만 SFT하면 일반 능력이 사라집니다. 다양한 태스크를 혼합하거나 LoRA로 원본 지식을 보존하세요.

올바른 방법

프롬프트 형식을 일관되게 유지하고, 도메인 데이터와 일반 데이터를 8:2로 혼합하세요. Learning rate은 base보다 10배 낮게 시작합니다.

🔗 관련 용어

📚 더 배우기