🤖 AI/ML

VAE

Variational Autoencoder (변분 오토인코더)

확률적 잠재 공간을 학습하는 생성 모델. 이미지 생성, 표현 학습, 데이터 압축에 활용되는 딥러닝 아키텍처.

📖 상세 설명

VAE(Variational Autoencoder, 변분 오토인코더)는 데이터의 잠재 표현(latent representation)을 확률적으로 학습하는 생성 모델입니다. 2013년 Kingma와 Welling이 제안한 이 모델은 기존 오토인코더에 확률론적 접근을 결합하여, 단순한 압축을 넘어 새로운 데이터를 생성할 수 있는 능력을 갖추었습니다. VAE는 Encoder가 입력을 잠재 공간의 확률 분포(평균과 분산)로 인코딩하고, Decoder가 이 분포에서 샘플링하여 데이터를 재구성합니다.

VAE의 핵심은 손실 함수에 있습니다. 재구성 손실(Reconstruction Loss)과 KL 발산(Kullback-Leibler Divergence) 두 항으로 구성됩니다. 재구성 손실은 입력과 출력의 차이를 최소화하여 정보 보존을 보장하고, KL 발산은 학습된 잠재 분포가 표준 정규분포에 가까워지도록 정규화합니다. 이 정규화 덕분에 잠재 공간이 연속적이고 의미 있는 구조를 갖게 되어, 임의의 점에서 샘플링해도 유효한 데이터를 생성할 수 있습니다.

실무에서 VAE는 이미지 생성, 이상 탐지, 데이터 증강, 특성 학습 등에 활용됩니다. 특히 Stable Diffusion과 같은 최신 이미지 생성 모델의 핵심 구성 요소로 사용되며, 고해상도 이미지를 저차원 잠재 공간에서 효율적으로 처리할 수 있게 합니다. 의료 분야에서는 정상 데이터로 학습한 VAE의 재구성 오류를 활용하여 이상(질병) 패턴을 탐지합니다.

VAE의 변형으로는 Beta-VAE(해석 가능한 잠재 변수 학습), Conditional VAE(조건부 생성), VQ-VAE(이산 잠재 공간) 등이 있습니다. GAN(Generative Adversarial Network)과 비교하면 VAE는 학습이 안정적이고 잠재 공간의 해석이 용이하지만, 생성 이미지가 다소 흐릿한 경향이 있습니다. 최근에는 VAE와 Diffusion Model을 결합한 Latent Diffusion이 이미지 생성의 표준으로 자리잡았습니다.

💻 코드 예제

PyTorch로 구현한 VAE 모델과 학습 코드:

Python
# VAE (Variational Autoencoder) PyTorch 구현
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

class VAE(nn.Module):
    """Variational Autoencoder"""

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()

        self.latent_dim = latent_dim

        # Encoder: 입력 -> 잠재 분포의 평균과 분산
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )

        # 잠재 분포 파라미터
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)      # 평균
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)  # 로그 분산

        # Decoder: 잠재 벡터 -> 재구성된 입력
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # 0-1 범위로 출력
        )

    def encode(self, x):
        """입력을 잠재 분포의 파라미터로 인코딩"""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Reparameterization Trick
        z = mu + std * epsilon 형태로 샘플링하여 역전파 가능하게 함
        """
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)  # 표준 정규분포에서 샘플링
        z = mu + std * epsilon
        return z

    def decode(self, z):
        """잠재 벡터를 원본 공간으로 디코딩"""
        return self.decoder(z)

    def forward(self, x):
        """순전파: 인코딩 -> 샘플링 -> 디코딩"""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar

    def generate(self, num_samples, device='cpu'):
        """새로운 샘플 생성"""
        # 표준 정규분포에서 샘플링
        z = torch.randn(num_samples, self.latent_dim).to(device)
        samples = self.decode(z)
        return samples


def vae_loss(reconstructed, original, mu, logvar, beta=1.0):
    """
    VAE 손실 함수

    Args:
        reconstructed: 재구성된 출력
        original: 원본 입력
        mu: 잠재 분포의 평균
        logvar: 잠재 분포의 로그 분산
        beta: KL 발산 가중치 (Beta-VAE용)

    Returns:
        total_loss, recon_loss, kl_loss
    """
    # 1. 재구성 손실 (Binary Cross Entropy)
    recon_loss = F.binary_cross_entropy(reconstructed, original, reduction='sum')

    # 2. KL Divergence: 잠재 분포와 표준 정규분포 간의 거리
    # KL(q(z|x) || p(z)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    total_loss = recon_loss + beta * kl_loss
    return total_loss, recon_loss, kl_loss


def train_vae(model, train_loader, optimizer, device, epoch):
    """VAE 학습 함수"""
    model.train()
    total_loss = 0
    total_recon = 0
    total_kl = 0

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, 784).to(device)  # Flatten

        optimizer.zero_grad()
        reconstructed, mu, logvar = model(data)

        loss, recon_loss, kl_loss = vae_loss(reconstructed, data, mu, logvar)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_recon += recon_loss.item()
        total_kl += kl_loss.item()

    avg_loss = total_loss / len(train_loader.dataset)
    print(f'Epoch {epoch}: Loss={avg_loss:.4f}, '
          f'Recon={total_recon/len(train_loader.dataset):.4f}, '
          f'KL={total_kl/len(train_loader.dataset):.4f}')

    return avg_loss


def visualize_latent_space(model, test_loader, device):
    """잠재 공간 시각화 (2D 잠재 공간인 경우)"""
    model.eval()
    z_list = []
    label_list = []

    with torch.no_grad():
        for data, labels in test_loader:
            data = data.view(-1, 784).to(device)
            mu, _ = model.encode(data)
            z_list.append(mu.cpu())
            label_list.append(labels)

    z = torch.cat(z_list, dim=0).numpy()
    labels = torch.cat(label_list, dim=0).numpy()

    if z.shape[1] == 2:
        plt.figure(figsize=(10, 8))
        scatter = plt.scatter(z[:, 0], z[:, 1], c=labels, cmap='tab10', alpha=0.6)
        plt.colorbar(scatter)
        plt.xlabel('Latent Dimension 1')
        plt.ylabel('Latent Dimension 2')
        plt.title('VAE Latent Space Visualization')
        plt.savefig('vae_latent_space.png')
        print("Latent space visualization saved to 'vae_latent_space.png'")


# 사용 예제
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # MNIST 데이터셋 로드
    transform = transforms.ToTensor()
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

    # VAE 모델 초기화
    vae = VAE(input_dim=784, hidden_dim=400, latent_dim=20).to(device)
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

    # 학습
    print("Training VAE on MNIST...")
    for epoch in range(1, 11):
        train_vae(vae, train_loader, optimizer, device, epoch)

    # 새로운 샘플 생성
    print("\nGenerating new samples...")
    samples = vae.generate(16, device)
    print(f"Generated samples shape: {samples.shape}")

🗣️ 실무에서 이렇게 말해요

생성 모델 설계 회의에서
"이미지 생성에 VAE 쓰면 잠재 공간 해석이 쉬운데, 출력이 좀 블러리해요. 고품질 원하면 Diffusion이랑 같이 써야 할 것 같아요."
이상 탐지 시스템 구축에서
"정상 데이터로 VAE 학습시키고, 재구성 오류가 크면 이상으로 판단하는 방식이에요. 라벨 없이도 되니까 실용적이죠."
모델 아키텍처 논의에서
"Beta-VAE로 beta 값 높이면 disentanglement가 좋아져서 잠재 변수 각각이 의미 있는 특성을 담게 돼요."
Stable Diffusion 설명에서
"Stable Diffusion은 픽셀 공간이 아니라 VAE가 만든 잠재 공간에서 디퓨전을 돌려요. 그래서 빠르고 메모리 효율적인 거죠."

⚠️ 흔한 실수 & 주의사항

KL Collapse (KL 붕괴)
학습 초기에 KL 발산이 0으로 수렴하여 잠재 공간이 무의미해지는 현상입니다. KL annealing(초기에 beta를 작게 시작해 점진적으로 증가)이나 Free Bits 기법으로 해결할 수 있습니다.
재구성 손실 선택 오류
이진 데이터(0-1)에는 BCE, 연속 데이터에는 MSE를 사용해야 합니다. 이미지에 Sigmoid 출력과 BCE를 쓸 때 입력도 0-1로 정규화해야 하며, 그렇지 않으면 학습이 불안정합니다.
잠재 차원 설정
잠재 차원이 너무 크면 정규화 효과가 약해지고, 너무 작으면 정보 손실이 커집니다. 데이터 복잡도에 따라 적절한 차원을 실험적으로 결정해야 합니다. MNIST는 20-50, 복잡한 이미지는 더 큰 차원이 필요합니다.
평가 지표 단일 사용
재구성 손실만 보면 안 됩니다. 생성 품질(FID), 잠재 공간 구조(interpolation 품질), 다양성 등을 함께 평가해야 VAE가 제대로 학습되었는지 판단할 수 있습니다.

🔗 관련 용어

📚 더 배우기