🤖 AI/ML

DiT

Diffusion Transformer

Transformer 기반 디퓨전 모델. Sora의 기반 아키텍처.

📖 상세 설명

DiT(Diffusion Transformer)는 이미지 생성을 위한 디퓨전 모델에서 기존 U-Net 백본을 Transformer 아키텍처로 대체한 혁신적인 모델입니다. 2022년 Meta AI의 William Peebles와 Saining Xie가 발표했으며, Transformer의 확장성(scalability)을 디퓨전 모델에 적용하여 모델 크기 증가에 따른 성능 향상을 입증했습니다.

DiT는 OpenAI Sora의 핵심 아키텍처로 채택되면서 주목받았습니다. Sora는 DiT를 기반으로 시공간(spatiotemporal) 패치를 처리하여 최대 1분 길이의 고품질 비디오를 생성합니다. 이는 기존 U-Net 기반 비디오 모델 대비 월등한 시간적 일관성과 화질을 보여줍니다.

핵심 원리는 이미지를 패치(patch) 단위로 분할하고, 각 패치를 토큰으로 취급하여 Transformer 블록에 통과시키는 것입니다. 노이즈 스텝과 클래스 레이블은 adaLN(Adaptive Layer Normalization)을 통해 조건부로 주입되며, 이를 통해 조건부 이미지 생성이 가능합니다. ImageNet 256x256 벤치마크에서 DiT-XL/2 모델은 FID 2.27을 달성하여 당시 SOTA를 기록했습니다.

실무에서 DiT는 대규모 이미지/비디오 생성 서비스의 기반 모델로 활용됩니다. Transformer 특성상 병렬화가 용이하여 대규모 GPU 클러스터에서 효율적으로 학습 가능하며, LoRA 등 파인튜닝 기법 적용이 직관적입니다. Stable Diffusion 3.0도 DiT 기반 아키텍처(MM-DiT)를 채택하여 텍스트 이해력과 이미지 품질을 크게 개선했습니다.

💻 코드 예제

# DiT 모델 구조 간소화 예제 (PyTorch)
import torch
import torch.nn as nn

class DiTBlock(nn.Module):
    """DiT의 기본 Transformer 블록 (adaLN 포함)"""
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )
        # adaLN 모듈레이션 파라미터 (6개: scale1, shift1, gate1, scale2, shift2, gate2)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size)
        )

    def forward(self, x, c):
        # c: 조건 임베딩 (timestep + class label)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c).chunk(6, dim=-1)

        # Modulated self-attention
        h = self.norm1(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.attn(h, h, h)[0]
        x = x + gate_msa.unsqueeze(1) * h

        # Modulated MLP
        h = self.norm2(x) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(h)
        return x

# 사용 예제
block = DiTBlock(hidden_size=768, num_heads=12)
x = torch.randn(2, 256, 768)  # (batch, num_patches, hidden)
c = torch.randn(2, 768)        # condition embedding
output = block(x, c)
print(f"Output shape: {output.shape}")  # (2, 256, 768)

🗣️ 실무에서 이렇게 말하세요

AI 연구팀 아키텍처 논의에서

"이미지 생성 모델 업그레이드할 때 DiT 아키텍처로 전환하는 걸 검토해봤습니다. U-Net 대비 Transformer가 스케일링 법칙을 더 잘 따라서, 모델 크기 키우면 성능이 예측 가능하게 올라갑니다. 다만 패치 크기와 hidden dimension 조합을 실험해봐야 할 것 같습니다."

기술 면접에서

"DiT의 핵심 혁신은 adaLN(Adaptive Layer Normalization)입니다. 기존 cross-attention으로 조건을 주입하던 방식 대신, LayerNorm의 scale/shift 파라미터를 조건에 따라 동적으로 조절합니다. 이게 연산량도 줄이고 학습 안정성도 높여줍니다."

제품 로드맵 회의에서

"Sora가 DiT 기반인 것처럼, 우리 비디오 생성 파이프라인도 DiT로 마이그레이션하면 프레임 간 일관성이 크게 개선될 겁니다. Meta가 공개한 DiT 체크포인트를 베이스로 도메인 특화 파인튜닝 하는 방향으로 진행하면 어떨까요?"

⚠️ 흔한 실수 & 주의사항

⚠️
높은 VRAM 요구량

DiT-XL/2 모델 학습에는 A100 80GB 급 GPU가 필요합니다. 추론 시에도 고해상도 이미지 생성을 위해선 최소 24GB VRAM을 확보해야 합니다. 패치 크기를 키우면 메모리는 줄지만 디테일이 손실됩니다.

⚠️
U-Net 대비 느린 추론 속도

Transformer 특성상 self-attention의 O(n^2) 복잡도로 인해 고해상도에서 추론 속도가 저하됩니다. Flash Attention, xFormers 등 최적화 라이브러리 적용이 필수입니다.

⚠️
패치 크기 선택의 중요성

패치 크기(2, 4, 8 등)에 따라 시퀀스 길이와 디테일 표현력이 트레이드오프 됩니다. 패치 크기 2는 최고 품질이지만 연산량이 16배 증가하므로 용도에 맞게 선택해야 합니다.

🔗 관련 용어

📚 더 배우기