SFT
Supervised Fine-Tuning
라벨링된 데이터로 모델을 지도학습 방식으로 파인튜닝. RLHF 전 단계. ChatGPT 같은 챗봇의 핵심 학습 방법.
Supervised Fine-Tuning
라벨링된 데이터로 모델을 지도학습 방식으로 파인튜닝. RLHF 전 단계. ChatGPT 같은 챗봇의 핵심 학습 방법.
SFT(Supervised Fine-Tuning)는 사전 학습된 LLM을 고품질의 (입력, 출력) 쌍 데이터로 지도학습하는 과정입니다. "사용자가 이렇게 질문하면 이렇게 답해야 한다"는 것을 직접 보여주는 방식으로, Instruction Tuning이라고도 불립니다.
GPT-3.5/4, Claude, Llama 2 Chat 등 모든 대화형 LLM은 SFT를 거칩니다. Base 모델(GPT-3)은 단순히 다음 토큰을 예측하지만, SFT를 거친 모델(ChatGPT)은 사용자 질문에 적절히 응답합니다. OpenAI는 이 과정에 수천 명의 인간 레이블러를 동원했습니다.
일반적인 LLM 정렬 파이프라인은 Pretraining -> SFT -> RLHF 순서입니다. SFT로 기본적인 대화 능력과 지시 따르기를 학습하고, RLHF로 더 섬세한 선호도를 반영합니다. 최근에는 DPO(Direct Preference Optimization)로 RLHF를 대체하는 추세입니다.
Alpaca, Vicuna 등 오픈소스 모델들이 GPT-3.5 출력을 SFT 데이터로 활용해 성공을 거두면서, 합성 데이터 기반 SFT가 주목받고 있습니다. 1천~1만 개의 고품질 데이터만으로도 상당한 성능 향상이 가능합니다.
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
from datasets import load_dataset
# 모델과 토크나이저 로드
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# SFT 데이터셋 로드 (Instruction-Output 형식)
dataset = load_dataset("tatsu-lab/alpaca", split="train")
# 데이터 포맷팅 함수
def format_instruction(example):
"""Alpaca 형식의 프롬프트 생성"""
if example.get("input"):
return f"""### Instruction:
{example['instruction']}
### Input:
{example['input']}
### Response:
{example['output']}"""
else:
return f"""### Instruction:
{example['instruction']}
### Response:
{example['output']}"""
# 학습 설정
training_args = TrainingArguments(
output_dir="./llama-2-7b-sft",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-5,
warmup_ratio=0.03,
logging_steps=10,
save_strategy="epoch",
bf16=True,
)
# SFT Trainer 사용
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
formatting_func=format_instruction,
max_seq_length=2048,
tokenizer=tokenizer,
)
# 학습 실행
trainer.train()
trainer.save_model("./llama-2-7b-sft-final")
# A100 80GB 기준: 7B 모델 SFT ~4시간, 52K 샘플
# === LoRA + SFT (메모리 효율적) ===
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=16, # LoRA rank
lora_alpha=32, # 스케일링 팩터
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model_lora = get_peft_model(model, lora_config)
model_lora.print_trainable_parameters()
# trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.0622
# LoRA SFT는 RTX 4090 24GB로도 7B 모델 학습 가능!
# === ChatML 형식 데이터 ===
def format_chatml(example):
"""ChatML 형식 (OpenAI, Mistral 호환)"""
return f"""<|im_start|>system
You are a helpful AI assistant.<|im_end|>
<|im_start|>user
{example['instruction']}<|im_end|>
<|im_start|>assistant
{example['output']}<|im_end|>"""
# === 고품질 데이터 생성 (GPT-4 활용) ===
from openai import OpenAI
import json
client = OpenAI()
def generate_sft_data(topic: str, num_samples: int = 10):
"""GPT-4로 SFT 데이터 합성"""
samples = []
for _ in range(num_samples):
response = client.chat.completions.create(
model="gpt-4-turbo",
messages=[{
"role": "user",
"content": f"""다음 주제에 대한 Q&A 데이터를 생성하세요.
주제: {topic}
JSON 형식:
{{"instruction": "사용자 질문", "output": "상세하고 정확한 답변"}}"""
}],
response_format={"type": "json_object"}
)
samples.append(json.loads(response.choices[0].message.content))
return samples
# === 평가 메트릭 ===
def evaluate_sft_model(model, tokenizer, test_prompts):
"""SFT 모델 품질 평가"""
results = []
for prompt in test_prompts:
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=200)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
results.append({"prompt": prompt, "response": response})
return results
# GPT-4로 응답 품질 평가 (LLM-as-Judge)
def evaluate_with_gpt4(responses):
"""GPT-4로 응답 품질 점수 매기기"""
scores = []
for r in responses:
response = client.chat.completions.create(
model="gpt-4-turbo",
messages=[{
"role": "user",
"content": f"""다음 응답의 품질을 1-10점으로 평가하세요.
질문: {r['prompt']}
응답: {r['response']}
점수 (숫자만):"""
}]
)
scores.append(int(response.choices[0].message.content.strip()))
return sum(scores) / len(scores)
"Base Llama 모델은 대화를 못 해요. 우리 도메인에 맞는 Q&A 데이터 1만 개 정도 만들어서 SFT 돌리면 됩니다. LoRA 쓰면 A100 1대로 하루면 충분해요. RLHF까지 갈 필요 없이 SFT만으로도 충분히 쓸만합니다."
"SFT는 LLM이 지시를 따르도록 학습하는 과정입니다. Pretraining된 모델은 단순 텍스트 완성만 하지만, SFT를 거치면 '요약해줘', '번역해줘' 같은 명령을 이해합니다. ChatGPT는 SFT + RLHF로 만들어졌고, 최근엔 DPO가 더 효율적인 대안으로 떠오르고 있어요."
"SFT 성공의 80%는 데이터 품질입니다. LIMA 논문 보면 1,000개의 고품질 데이터가 52,000개 Alpaca 데이터보다 성능이 좋아요. 양보다 질이에요. GPT-4로 합성 데이터 만들고 사람이 검수하는 게 가성비 최고입니다."
노이즈가 많은 데이터로 SFT하면 모델 성능이 오히려 하락합니다. 1만 개 고품질 > 100만 개 저품질입니다.
특정 도메인에만 SFT하면 일반 능력이 사라집니다. 다양한 태스크를 혼합하거나 LoRA로 원본 지식을 보존하세요.
프롬프트 형식을 일관되게 유지하고, 도메인 데이터와 일반 데이터를 8:2로 혼합하세요. Learning rate은 base보다 10배 낮게 시작합니다.