GAN
Generative Adversarial Network
생성자와 판별자가 경쟁하며 학습하는 생성 모델. 이미지 생성, 스타일 변환 등에 활용.
Generative Adversarial Network
생성자와 판별자가 경쟁하며 학습하는 생성 모델. 이미지 생성, 스타일 변환 등에 활용.
GAN(Generative Adversarial Network)은 2014년 Ian Goodfellow가 제안한 생성 모델로, 두 개의 신경망이 적대적으로 경쟁하며 학습하는 구조입니다. 생성자(Generator)는 노이즈로부터 실제 같은 데이터를 생성하고, 판별자(Discriminator)는 생성된 데이터와 실제 데이터를 구분하려 합니다.
GAN의 핵심 아이디어는 게임 이론에서 영감을 받았습니다. 생성자와 판별자가 min-max 게임을 진행하며, 생성자는 판별자를 속이려 하고 판별자는 속지 않으려 합니다. 이 경쟁 과정에서 생성자는 점점 더 현실적인 데이터를 만들어냅니다.
학습 과정에서 판별자의 손실 함수는 이진 크로스 엔트로피를 사용하며, 생성자는 판별자가 생성된 이미지를 진짜로 판단하도록 학습됩니다. Nash 균형에 도달하면 생성자는 실제 데이터 분포와 동일한 분포를 생성하게 됩니다.
실무에서 GAN은 이미지 생성(StyleGAN), 이미지-이미지 변환(Pix2Pix, CycleGAN), 초해상도(ESRGAN), 데이터 증강, 얼굴 생성 등에 활용됩니다. NVIDIA의 StyleGAN3는 1024x1024 해상도의 포토리얼리스틱 얼굴을 생성하며, 학습에 8개 V100 GPU로 약 25일이 소요됩니다.
import torch
import torch.nn as nn
# 생성자 네트워크
class Generator(nn.Module):
def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
super().__init__()
self.img_shape = img_shape
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
return img.view(img.size(0), *self.img_shape)
# 판별자 네트워크
class Discriminator(nn.Module):
def __init__(self, img_shape=(1, 28, 28)):
super().__init__()
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
return self.model(img)
# 학습 루프 예시
def train_gan(generator, discriminator, dataloader, epochs=100):
criterion = nn.BCELoss()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
for epoch in range(epochs):
for real_imgs, _ in dataloader:
batch_size = real_imgs.size(0)
# 판별자 학습
z = torch.randn(batch_size, 100)
fake_imgs = generator(z)
d_loss = criterion(discriminator(real_imgs), torch.ones(batch_size, 1)) + \
criterion(discriminator(fake_imgs.detach()), torch.zeros(batch_size, 1))
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# 생성자 학습
g_loss = criterion(discriminator(fake_imgs), torch.ones(batch_size, 1))
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
print(f"Epoch {epoch}: D_loss={d_loss.item():.4f}, G_loss={g_loss.item():.4f}")
"GAN의 mode collapse 문제를 해결하기 위해 어떤 방법을 사용하셨나요?" - "저희 프로젝트에서는 Wasserstein GAN을 적용해서 학습 안정성을 높였고, minibatch discrimination 기법을 추가해서 다양한 출력을 생성하도록 개선했습니다. 그 결과 FID 점수가 45에서 28로 향상되었습니다."
"이번 데이터 증강에 GAN을 도입하려고 하는데, 기존 augmentation 대비 장단점이 뭔가요?" - "GAN 기반 증강은 도메인 특화 변형을 학습할 수 있어서 의료 영상처럼 특수한 데이터에 효과적입니다. 다만 학습에 GPU 리소스가 많이 들고, 생성 품질 검증을 위해 FID나 IS 같은 메트릭을 함께 모니터링해야 합니다."
"생성자에 BatchNorm 대신 InstanceNorm을 쓴 이유가 있나요?" - "네, 이미지 스타일 변환 태스크에서는 InstanceNorm이 각 샘플의 스타일을 독립적으로 정규화해서 더 일관된 스타일 전이가 가능합니다. StyleGAN에서도 AdaIN을 사용하는 것과 같은 맥락입니다."
생성자가 다양성을 잃고 비슷한 이미지만 생성하는 현상입니다. Wasserstein loss, spectral normalization, progressive growing 등의 기법으로 완화할 수 있습니다.
판별자가 너무 강해지면 생성자가 학습하지 못하고, 반대면 생성 품질이 낮아집니다. 생성자와 판별자의 학습 비율 조절(예: 1:5)과 learning rate 튜닝이 필수입니다.
FID(Frechet Inception Distance)가 낮다고 항상 좋은 것은 아닙니다. 생성 다양성과 품질을 함께 평가해야 하며, 도메인별로 적합한 평가 방법이 다릅니다.