본문 바로가기
Paper Review/Medical Imaging

[논문 리뷰] ALDM: Adaptive Latent Diffusion Model for 3D Medical Image to Image Translation: Multi-modal Magnetic Resonance Imaging Study

by kongshin 2025. 4. 14.

[논문]

  • Adaptive Latent Diffusion Model for 3D Medical Image to Image Translation: Multi-modal Magnetic Resonance Imaging Study
  • WACA 2024
  • Citations: 40
  • https://arxiv.org/abs/2311.00265

Summary

[문제점]

  • 3D MRI image analysis - 다양한 modality (T1, T2, FLAIR 등)이 필요
    • but, 비용, 시간, 환자 안전성 등으로 모든 modality를 촬영하기 어려움

[ALDM]

  • Patch cropping 없이 3D image-to-image translation 수행
    • Patch cropping
      • 큰 3D medical imaging을 작게 잘라서 batch 구성 후 training 진행
      • sliding window를 통해 전체 structure를 inference
      • ⇒ global structure의 정보 손실, 경계 artifact 발생, inference 속도 느림
  • MS-SPADE block 도입 → 하나의 model에서 다양한 modality 간 변환 (one-to-many) 가능
    • MS-SPADE: Multiple switchable spatially adaptive nomalization
  • 3D LDM + conditiong
    • condition: target modality
  • MS-SPACE - source latent를 target style latent로 변형

[결과]

  • 다양한 modality 간 translation 성능 뛰어남
    • Datasets: BraTS2021, IXI
  • 기존 one-to-one model들보다 좋은 성능

Related work

[Generative adversarial networks]

  • Medical image-to-image translation 연구들은 대부분 GAN 기반
    • Pix2Pix: paired data 기반 pixel-to-pixel 변환
    • CycleGAN: unpaired data에서도 학습 가능하게 하는 구조
    • NICEGAN: generator-discriminator 구조를 encoder 공유로 간결화
    • RegGAN: anatomical 구조 보존을 위한 registration 포함
    • Ea-GAN: 3D 기반, edge-aware 모듈로 구조 정보 보존
    • ResViT: Vision Transformer 기반으로 context-aware translation 가능

⇒ 대부분 2D 기반 → slice 간 정보 불연속

⇒ one-to-one 방식 → modality 수가 많아질수록 model 수 증가

 

[Conditional normalization layers]

  • style transfer 기존 방법들
    • AdaIN: 스타일만 반영 (평균+표준편차로 정규화 후 style 값으로 재스케일)
      • Apative instance normalization
    • SPADE: semantic map 기반 공간적 style 주입 가능
      • Spatially adaptive normalization
      • ALDM - SPADE를 확장한 MS-SPADE (Multi-Switchable SPADE) 사용
      • target modality에 따라 동적으로 normalization

[Diffusion probabilistic model]

  • Image generation 분야의 SOTA model, GAN 대안
    • DDPM: Gaussian noise를 target distribution으로 바꾸는 과정 학습
    • Palette: source + target image를 concat하여 조건부 diffusion 수행
    • LDM: latent 공간에서 diffusion 수행해 연산 효율 up
      • 2단계 구조: (1) autoencoder로 latent 압축 → (2) latent에서 diffusion
      • 3D patch cropping 없이 전체 영상 사용 가능

 

Method

  • Image-to-image translation 2단계
    1. AutoEncoder - 3D medical image를 latent로 압축
      •   VQ-VAE/VQGAN 기반 autoencoder 사용
      •   latent - perceptual representation 보존
    2. Latent Diffusion Model (LDM) - source latent를 target latent로 잘 mapping되게 학습
      •   source latent(condition latent)를 MS-SPADE를 통해 style만 변경 ⇒ target-like latent
      •   target image latent에 foward diffusion process 진행 ⇒ noisy target latent
      •   2개를 concat하여 diffusion model (UNet)에 input
        •   ⇒ Palette style 구조로 noise 예측하도록 학습

 

1. Image compression to compute latents

  • VQ-VAE / VQGAN 기반 구조 사용
    • Encoder $E(x)=z_{src}$: image를 latent로 압축
    • Decoder $D(z)=\hat x$: latent에서 원래 image로 복원
    • Latent - vector quantization을 거쳐 discrete latent space를 형성

 

[MS-SPADE Block]

  • source latent $h$를 target modality style로 transfer하는 block
  • $h\in\mathbb{R}^{N\times C\times H\times W\times D}$: SPADE block의 input
    • N:batch & C: channel & H: height & W: width & D: depth

 

  • $\mu_c,\sigma_c$: channel-wise mean & std
  • $\gamma^{tar},\beta^{tar}$: target modality condition에 따라 학습하는 modulation parameters

 

[Loss]

  • Reconstruction Loss $\mathcal{L}_{recon}$
    • Goal: target style로 변환된 결과 $\hat I_{tar}$가 실제 target image $I_{tar}$과 유사하도록
    • $\mathcal{L}_{recon}=||I_{tar}-\hat I_{tar}||$

 

  • Quantization Loss $\mathcal{L}_{quant}$
    • Goal: VQ encoder에서 나온 continuous latent $z$가 discrete codebook과 일치하도록
    • $\mathcal{L}_{quant}=||sg[z]-e||^2_2+\beta||z-sg[e]||^2_2$
      • e: codebook entry, sg: stop-gradient (역전파 차단)
      • 첫 번째 term - codebook만 update ⇒ codebook vector가 encoder 출력에 비슷하도록
      • 두 번째 term - encoder만 update ⇒ encoder 출력이 codebook vector에 비슷하도록
      • $\beta$: weight

 

  • Perceptual Loss $\mathcal{L}_{percept}$
    • Goal: 시각적으로 더 유사한 image 생성하기 위해 pretrained network의 feature space에서 비교
    • $\mathcal{L}_{percept}=\Sigma_{l}||\pi_l(\hat I_{tar})-\phi_l(I_{tar})||$
      • $\phi_l$: pretrained network(VGG)의 $l^{th}$ layer features
    • VQGAN의 recon. loss
    • VQ-VAE와 다르게 recon. loss를 $L_2$가 아닌 VGG의 feature map 차이를 확인하는 Perceptual loss 사용

 

  • Adversarial Loss $\mathcal{L}_{adv}$
    • Goal: Generator가 진짜 같은 $\hat I_{tar}$을 생성하도록 유도

 

  • Cycle Consistency Loss $\mathcal{L}_{cyc}$
    • Goal: $I_{src}$→$\hat I_{tar}$→$\hat I_{src}$ 순환 경로에서 원래 input 복원 가능성 확보$L_{cyc}=||I_{src}-\hat I_{src}||$

⇒ 기존 4개의 loss와 다르게 저자가 추가한 loss

 


[Reference - VQGAN]

 

  • VQ-VAE에서 patch GAN의 Discriminator 추가
  • 1 stage
    • VQloss
      •  
      • VQ-VAEVQ loss & Commitment loss
      • $\beta$가 1보다 작으면 encoder가 codebook보다 빠르게 update
    • Discriminator loss
      • Vallina discriminator loss 와 동일
    • 최종 VQGAN loss
      • $\lambda$ - addaptive weight ⇒ 두 loss balance 조절
        • $G_L$: decoder의 마지막 layer & $\delta$: $10^{-6}$

⇒ CNN Encoder, Decoder를 통해 codebook 학습

⇒ Goal: reconstruction

 

  • 2 stage
    • Transformer 사용하여 이전 patch들을 통해 auto-regressive 하게 다음 patch 예측
    • Goal: 생성 model

2. Diffusion Model

  • LDM 구조를 따라 latent space에서 diffusion을 수행함
  • source - $y$ condition, target - $x$
  • LDM input - 2가지
    1. $z^{tar}_{src}$: MS-SPADE로 변환된 target-style source latent
    2. $z_t^{tar}$: forward diffusion을 거친 noisy target latent at time step $t$
  • Model process
    1. $z_{src}=E(I_{src})$
    2. $z^{tar}_{src}=SPADE(z_{src})$ → target-style
    3. $z_{tar}=E(I_{tar})$ → forward noise
    4. Diffusion model input: $\epsilon_\theta(z^{tar}_{src}, z_t^{tar},t)$
  • Loss

 

3. Modality Conditioning

  • image-to-image translation - 어떤 modality로 변경할지 model에 알려줘야함
  • one-hot modality vector $y$를 condition으로 사용
  • UNet의 attention - modality vector $y$추가
    • $\phi_i(z_t)\in \mathbb{R}^{N\times d_e}$
      • UNet 내부 $i^{th}$ layer의 중간 representation
    • $y\in \mathbb{R}^{d_y}$
  • Cross-attention
    • 기존 UNet은 self-attention이였음
  • Loss - $y$ 추가

 

Implemental Details

[Model]

  • Image compression: VQGAN
  • Diffusion model: LDM
  • Library: Pytorch + MONAI

[Hyperparameters]

  • Optimizer: AdamW
  • Learning rate: $2\times 10^{-6}$
  • Timestpes: 1000
  • Noise schedule: scaled linear 0.0015 → 0.0195
  • Batch size: 1 per GPU
  • GPU: A100 80GB x 4

[Appendix - Hyperparameters]

  • VQ-GAN
    • Input size: $192\times 192\times 144$ → 3D MRI image size
    • dim $z$: 8192 → latent space size (Quantized vector 개수 총합)
    • Channels: $[252,512,512]$ → Encoder/Decoder의 channels 수
    • Embedding size: 3
    • Batch size: 1
    • Epochs: 500
    • Model size: 749M
    • Param size: 237M
  • LDM
    • Input size: $48\times 48\times 36\times 3$ → latent shape
    • Condi: $[128, 256, 512]$ → condition input dim
    • Batch size: 1
    • Model size: 722M
    • Param size: 658M
    • Epochs: 800

 

Evaluation Metrics

  • PSNR
  • NMSE
  • SSIM
  • 모든 metric은 3D volume 단위로 계산
  • train과 분리된 test set에서 진행
  • 2D sliced model → 3D로 쌓은 뒤 평가 진행