[논문]
- Adaptive Latent Diffusion Model for 3D Medical Image to Image Translation: Multi-modal Magnetic Resonance Imaging Study
- WACA 2024
- Citations: 40
- https://arxiv.org/abs/2311.00265
Summary
[문제점]
- 3D MRI image analysis - 다양한 modality (T1, T2, FLAIR 등)이 필요
- but, 비용, 시간, 환자 안전성 등으로 모든 modality를 촬영하기 어려움
[ALDM]
- Patch cropping 없이 3D image-to-image translation 수행
- Patch cropping
- 큰 3D medical imaging을 작게 잘라서 batch 구성 후 training 진행
- sliding window를 통해 전체 structure를 inference
- ⇒ global structure의 정보 손실, 경계 artifact 발생, inference 속도 느림
- Patch cropping
- MS-SPADE block 도입 → 하나의 model에서 다양한 modality 간 변환 (one-to-many) 가능
- MS-SPADE: Multiple switchable spatially adaptive nomalization
- 3D LDM + conditiong
- condition: target modality
- MS-SPACE - source latent를 target style latent로 변형
[결과]
- 다양한 modality 간 translation 성능 뛰어남
- Datasets: BraTS2021, IXI
- 기존 one-to-one model들보다 좋은 성능
Related work
[Generative adversarial networks]
- Medical image-to-image translation 연구들은 대부분 GAN 기반
- Pix2Pix: paired data 기반 pixel-to-pixel 변환
- CycleGAN: unpaired data에서도 학습 가능하게 하는 구조
- NICEGAN: generator-discriminator 구조를 encoder 공유로 간결화
- RegGAN: anatomical 구조 보존을 위한 registration 포함
- Ea-GAN: 3D 기반, edge-aware 모듈로 구조 정보 보존
- ResViT: Vision Transformer 기반으로 context-aware translation 가능
⇒ 대부분 2D 기반 → slice 간 정보 불연속
⇒ one-to-one 방식 → modality 수가 많아질수록 model 수 증가
[Conditional normalization layers]
- style transfer 기존 방법들
- AdaIN: 스타일만 반영 (평균+표준편차로 정규화 후 style 값으로 재스케일)
- Apative instance normalization
- SPADE: semantic map 기반 공간적 style 주입 가능
- Spatially adaptive normalization
- ALDM - SPADE를 확장한 MS-SPADE (Multi-Switchable SPADE) 사용
- target modality에 따라 동적으로 normalization
- AdaIN: 스타일만 반영 (평균+표준편차로 정규화 후 style 값으로 재스케일)
[Diffusion probabilistic model]
- Image generation 분야의 SOTA model, GAN 대안
- DDPM: Gaussian noise를 target distribution으로 바꾸는 과정 학습
- Palette: source + target image를 concat하여 조건부 diffusion 수행
- LDM: latent 공간에서 diffusion 수행해 연산 효율 up
- 2단계 구조: (1) autoencoder로 latent 압축 → (2) latent에서 diffusion
- 3D patch cropping 없이 전체 영상 사용 가능
Method
- Image-to-image translation 2단계
- AutoEncoder - 3D medical image를 latent로 압축
- VQ-VAE/VQGAN 기반 autoencoder 사용
- latent - perceptual representation 보존
- Latent Diffusion Model (LDM) - source latent를 target latent로 잘 mapping되게 학습
- source latent(condition latent)를 MS-SPADE를 통해 style만 변경 ⇒ target-like latent
- target image latent에 foward diffusion process 진행 ⇒ noisy target latent
- 2개를 concat하여 diffusion model (UNet)에 input
- ⇒ Palette style 구조로 noise 예측하도록 학습
- AutoEncoder - 3D medical image를 latent로 압축
1. Image compression to compute latents
- VQ-VAE / VQGAN 기반 구조 사용
- Encoder $E(x)=z_{src}$: image를 latent로 압축
- Decoder $D(z)=\hat x$: latent에서 원래 image로 복원
- Latent - vector quantization을 거쳐 discrete latent space를 형성
[MS-SPADE Block]
- source latent $h$를 target modality style로 transfer하는 block
- $h\in\mathbb{R}^{N\times C\times H\times W\times D}$: SPADE block의 input
- N:batch & C: channel & H: height & W: width & D: depth
- $\mu_c,\sigma_c$: channel-wise mean & std
- $\gamma^{tar},\beta^{tar}$: target modality condition에 따라 학습하는 modulation parameters
[Loss]
- Reconstruction Loss $\mathcal{L}_{recon}$
- Goal: target style로 변환된 결과 $\hat I_{tar}$가 실제 target image $I_{tar}$과 유사하도록
- $\mathcal{L}_{recon}=||I_{tar}-\hat I_{tar}||$
- Quantization Loss $\mathcal{L}_{quant}$
- Goal: VQ encoder에서 나온 continuous latent $z$가 discrete codebook과 일치하도록
- $\mathcal{L}_{quant}=||sg[z]-e||^2_2+\beta||z-sg[e]||^2_2$
- e: codebook entry, sg: stop-gradient (역전파 차단)
- 첫 번째 term - codebook만 update ⇒ codebook vector가 encoder 출력에 비슷하도록
- 두 번째 term - encoder만 update ⇒ encoder 출력이 codebook vector에 비슷하도록
- $\beta$: weight
- Perceptual Loss $\mathcal{L}_{percept}$
- Goal: 시각적으로 더 유사한 image 생성하기 위해 pretrained network의 feature space에서 비교
- $\mathcal{L}_{percept}=\Sigma_{l}||\pi_l(\hat I_{tar})-\phi_l(I_{tar})||$
- $\phi_l$: pretrained network(VGG)의 $l^{th}$ layer features
- ⇒ VQGAN의 recon. loss
- ⇒ VQ-VAE와 다르게 recon. loss를 $L_2$가 아닌 VGG의 feature map 차이를 확인하는 Perceptual loss 사용
- Adversarial Loss $\mathcal{L}_{adv}$
- Goal: Generator가 진짜 같은 $\hat I_{tar}$을 생성하도록 유도
- Cycle Consistency Loss $\mathcal{L}_{cyc}$
- Goal: $I_{src}$→$\hat I_{tar}$→$\hat I_{src}$ 순환 경로에서 원래 input 복원 가능성 확보$L_{cyc}=||I_{src}-\hat I_{src}||$
⇒ 기존 4개의 loss와 다르게 저자가 추가한 loss
[Reference - VQGAN]
- VQ-VAE에서 patch GAN의 Discriminator 추가
- 1 stage
- VQloss
-
- ⇒ VQ-VAE의 VQ loss & Commitment loss
- $\beta$가 1보다 작으면 encoder가 codebook보다 빠르게 update
-
- Discriminator loss
- Vallina discriminator loss 와 동일
- 최종 VQGAN loss
- $\lambda$ - addaptive weight ⇒ 두 loss balance 조절
- $G_L$: decoder의 마지막 layer & $\delta$: $10^{-6}$
- VQloss
⇒ CNN Encoder, Decoder를 통해 codebook 학습
⇒ Goal: reconstruction
- 2 stage
- Transformer 사용하여 이전 patch들을 통해 auto-regressive 하게 다음 patch 예측
- Goal: 생성 model
2. Diffusion Model
- LDM 구조를 따라 latent space에서 diffusion을 수행함
- source - $y$ condition, target - $x$
- LDM input - 2가지
- $z^{tar}_{src}$: MS-SPADE로 변환된 target-style source latent
- $z_t^{tar}$: forward diffusion을 거친 noisy target latent at time step $t$
- Model process
- $z_{src}=E(I_{src})$
- $z^{tar}_{src}=SPADE(z_{src})$ → target-style
- $z_{tar}=E(I_{tar})$ → forward noise
- Diffusion model input: $\epsilon_\theta(z^{tar}_{src}, z_t^{tar},t)$
- Loss
3. Modality Conditioning
- image-to-image translation - 어떤 modality로 변경할지 model에 알려줘야함
- one-hot modality vector $y$를 condition으로 사용
- UNet의 attention - modality vector $y$추가
- $\phi_i(z_t)\in \mathbb{R}^{N\times d_e}$
- UNet 내부 $i^{th}$ layer의 중간 representation
- $y\in \mathbb{R}^{d_y}$
- $\phi_i(z_t)\in \mathbb{R}^{N\times d_e}$
- Cross-attention
- 기존 UNet은 self-attention이였음
- Loss - $y$ 추가
Implemental Details
[Model]
- Image compression: VQGAN
- Diffusion model: LDM
- Library: Pytorch + MONAI
[Hyperparameters]
- Optimizer: AdamW
- Learning rate: $2\times 10^{-6}$
- Timestpes: 1000
- Noise schedule: scaled linear 0.0015 → 0.0195
- Batch size: 1 per GPU
- GPU: A100 80GB x 4
[Appendix - Hyperparameters]
- VQ-GAN
- Input size: $192\times 192\times 144$ → 3D MRI image size
- dim $z$: 8192 → latent space size (Quantized vector 개수 총합)
- Channels: $[252,512,512]$ → Encoder/Decoder의 channels 수
- Embedding size: 3
- Batch size: 1
- Epochs: 500
- Model size: 749M
- Param size: 237M
- LDM
- Input size: $48\times 48\times 36\times 3$ → latent shape
- Condi: $[128, 256, 512]$ → condition input dim
- Batch size: 1
- Model size: 722M
- Param size: 658M
- Epochs: 800
Evaluation Metrics
- PSNR
- NMSE
- SSIM
- 모든 metric은 3D volume 단위로 계산
- train과 분리된 test set에서 진행
- 2D sliced model → 3D로 쌓은 뒤 평가 진행