PPO
Proximal Policy Optimization
안정적인 정책 업데이트를 위한 강화학습 알고리즘. RLHF에서 주로 사용.
Proximal Policy Optimization
안정적인 정책 업데이트를 위한 강화학습 알고리즘. RLHF에서 주로 사용.
PPO(Proximal Policy Optimization)는 강화학습(RL)에서 정책(policy)을 안정적으로 업데이트하는 알고리즘입니다. 특히 대형 언어 모델(LLM)의 RLHF(Reinforcement Learning from Human Feedback)에서 핵심 알고리즘으로 사용되며, ChatGPT, Claude 등의 학습에 활용됩니다.
PPO는 2017년 OpenAI의 John Schulman이 발표했습니다. 이전의 TRPO(Trust Region Policy Optimization)가 복잡한 2차 최적화를 필요로 했던 것과 달리, PPO는 간단한 클리핑(clipping) 기법으로 비슷한 성능을 달성하여 구현과 확장이 용이해졌습니다.
핵심 원리는 정책 업데이트 폭을 제한하는 것입니다. 새 정책과 이전 정책의 비율(ratio)이 특정 범위(예: 0.8~1.2)를 벗어나면 클리핑하여, 한 번에 너무 큰 변화가 일어나는 것을 방지합니다. 이로 인해 학습이 안정적으로 진행됩니다.
RLHF에서 PPO는 보상 모델(Reward Model)의 점수를 최대화하도록 LLM을 학습시킵니다. 인간 선호도를 반영한 보상 신호로 모델의 출력을 더 유용하고 안전하게 만드는 핵심 단계입니다.
# RLHF에서 PPO 학습 예제 (trl 라이브러리 사용)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer
import torch
# 1. 모델 및 토크나이저 로드
model_name = "gpt2" # 실제로는 더 큰 모델 사용
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# 2. PPO 설정
ppo_config = PPOConfig(
model_name=model_name,
learning_rate=1e-5,
batch_size=4,
mini_batch_size=2,
ppo_epochs=4, # PPO 업데이트 반복 횟수
clip_range=0.2, # 클리핑 범위 (핵심 하이퍼파라미터)
target_kl=0.01, # KL 발산 제한
)
# 3. PPO 트레이너 초기화
ppo_trainer = PPOTrainer(
config=ppo_config,
model=model,
ref_model=ref_model,
tokenizer=tokenizer,
)
# 4. 학습 루프 (간소화)
def compute_reward(response_text):
"""보상 모델에서 점수 계산 (실제로는 별도 Reward Model 사용)"""
# 예시: 응답 길이에 따른 단순 보상
return len(response_text) * 0.01
# 5. PPO 학습 스텝
query_tensors = tokenizer.encode("AI의 장점은", return_tensors="pt")
response_tensors = model.generate(query_tensors, max_new_tokens=50)
# 보상 계산
response_text = tokenizer.decode(response_tensors[0])
reward = torch.tensor([compute_reward(response_text)])
# PPO 업데이트
stats = ppo_trainer.step([query_tensors[0]], [response_tensors[0]], [reward])
print(f"PPO Loss: {stats['ppo/loss/total']:.4f}")
"RLHF 파이프라인에서 SFT 후 PPO로 보상 모델 점수를 최적화하고 있습니다. clip_range=0.2로 설정하여 정책이 급격히 변하지 않도록 제어하고 있습니다."
"PPO는 정책 그래디언트 방법의 일종으로, 클리핑을 통해 정책 업데이트 폭을 제한합니다. TRPO와 달리 1차 최적화만 사용하여 구현이 간단하고 확장성이 좋습니다."
"최근에는 DPO(Direct Preference Optimization)가 PPO를 대체하는 경우도 있습니다. DPO는 보상 모델 없이 직접 선호도 데이터로 학습하여 RLHF 파이프라인을 단순화합니다."
KL 발산 제약 없이 PPO를 적용하면 모델이 보상 해킹(reward hacking)에 빠져 이상한 출력을 생성할 수 있습니다.
clip_range가 너무 크면(예: 0.5) 정책이 급격히 변해 학습이 불안정해집니다. 보통 0.1~0.3 범위를 사용합니다.
참조 모델(ref_model)을 유지하여 KL 발산을 모니터링하고, 보상 모델의 품질을 충분히 검증한 후 PPO 학습을 시작하세요.