안녕하세요. KDST 박민철입니다.
오늘은 최근 CVPR 2024에 발표된 "Diffusion Models Without Attention" 제목의 연구 결과를 포스팅하고자 합니다. Diffusion Models (DMs)는 우리가 잘 알고 있는 이미지 생성 AI 모델 중 하나입니다. 본 연구는 그동안 널리 활용되어온 U-Net기반의 diffusion backbone에서 나아가 구조적 스케일 확장을 위해 Transformer를 바탕으로 backbone 설계를 제안한 "Scalable Diffusion Models with Transformers, ICCV 2023 (oral)"에서 동기를 받아 출발합니다. 앞서 Transformer를 diffusione의 backbone으로 제안했던 Diffusion Transformer (DiT)는 DMs가 생성하는 이미지의 성능 향상을 위해 U-Net에 적극적으로 도입되었던 residual 구조, self-attention, 그리고 conditional 정보를 regression하기 위한 adaptive normalization layer에 관심을 갖고, 이들을 스케일을 확장하기 위한 도구로 활용할 수 있을지 관심을 갖습니다.
이를 바탕으로, 아래의 그림과 같이 DiT는 ViT와 유사한 동작을 지닌 diffusion backbone으로 사용하기 위한 Transformer 구조를 제안합니다.
최근 DM은 DDPM (Ho et al., NeurIPS 2020)의 학습 기술에서 나아가 covariance에 대한 추가적인 최적화 (Nichol and Dhariwal, ICML 2021) 및 classifier-free guidance (CFG) (Ho and Salimans, NeurIPS 2021 Workshop, 2021) 기술을 바탕으로 강력한 generative power를 얻게 되었습니다. 한편, pixel space에서 latent space로 diffusion process를 확장함 (Rombach et al., CVPR 2022)에 따라, 효율적인 병렬 연산에 따른 연산 복잡도를 최적화할 수 있었습니다. (이 배경에 대한 자세한 내용은 prerequisite으로 본 포스팅에서 다루지 않겠습니다.) 이러한 발전 과정 속에서 DiT는 Transformer가 가질 수 있는 conditioning의 이점과 Latent 정보에 대한 attention을 바탕으로 latent space 상에서 diffusion이 취해야 할 estimated noise 및 covariance를 예측하는 구조를 고민합니다.
1. ViT의 convention에 따라 DiT는 latent input에 대한 patchficiation을 처리합니다. 이는, Tokenization을 위한 방법으로써, latent input의 공간 정보를 패치 단위로 나누어 아래의 그림과 같이 다수의 Token을 만들고, 이에 대한 ViT frequency-based positional embedding 처리 합니다.
2. DiT Block은 patchification된 데이터에 대한 LayerNorm을 처리하되 이 때 학습된 scale, shift 파라미터는 처리하지 않고 condition embedding으로부터 MLP regression된 scale, shift 파라미터를 이용하여 transform합니다. Self-attention 처리된 해당 데이터는 다시 condition으로부터 MLP regression된 alpha scale을 바탕으로 transform되어 feedforward가 처리되는 데, 이 때, alpha가 학습 전 0으로 초기화됨에 따라 AdaLN-Zero라 합니다. 학습 초기에 alpha가 0이 되며, attention 연산의 결과가 모두 0이 되고 동시에 존재하는 skip-connection이 identity 역할을 하게 됨으로써, 학습의 안정성을 더하게 됩니다. (Goyal et al., 2017)
DiT는 따라서, layer 개수, token들에 대한 embedding 크기, multi-head attention의 head 개수, 그리고 패치 크기에 따른 scalability를 가질 수 있으며, 해당 성능이 CFG를 사용할 때 U-Net 기반의 LDM 보다 뛰어난 것을 확인할 수 있었습니다.
그런데 본 연구는 attention에 의존할 수록 high-resolution을 생성할 때에 대한 이점보다 연산 측면에서의 패널티가 더 커질 수 있으므로, attention보다 비용 효율적인 처리 모듈이 필요하다는 것을 동기로 시작합니다. 특히, 기존의 DiT는 high-resolution에서 fully attention을 사용할 때 필요 연산량의 scale이 매우 challenging하고 patchfication의 채택이 scale을 가능하게 했지만, global 정보가 불가피하게 locally compression되며, long-range 정보를 interpolation하기 어렵다는 것을 지적합니다. 따라서, 본 연구는 patchification이 아닌 input 그 자체를 flatten 하여 sequence로 인식하고 long-term 정보를 효과적으로 memory할 수 있는 State Space Models (SSMs)를 활용하여 seq2seq을 통한 정보 처리를 의도합니다.
이에 따라, 본 논문은 아래의 그림처럼 DiffuSSM이라는 patchfication을 제거하고 DiT의 multi-head attention을 SSM으로 대체하는 방법을 제안합니다. AdaLN-Zero를 그대로 도입하였으며, SSM을 gated bidirectional SSM으로 설계하였고, input & output sequence에 대한 hourglass layer를 도입한 것이 특징입니다.
Hourglass layer는 딥러닝으로 처리되기 때문에 학습에 의존적으로 결정되지만 SSM은 함수적이기 때문에 배경 지식 및 이해가 필요합니다.
SSM의 개념은 길이가 L인 control input (sequence) u가 존재할 때, 어떤 hidden state x가 주어진 u에 따라 관계 A와 B에 따라 변화될 때, 변화된 x가 C에 의해서 길이가 L인 output (observation) sequence y를 정의하는 것에서 시작합니다. 우리는 연속적인 상황에서 아래와 같이 정의된 SSM을 variation of parameters를 기반으로 해석적으로 얻을 수 있는데,
위의 해에서 적분을 불연속 정의역에서 approximation을 하면 흥미롭게도 아래의 그림과 같이 RNN과 유사한 recurrent 관계를 얻을 수 있습니다.
(이산화 (discretization) methods는 총 4가지가 있습니다: Forward Euler discretization, Backward Euler discretization, Bilinear discretization, Generalized bilinear discretization)
위의 개념을 바탕으로 딥러닝에서 활용되고 있는 SSM은 NLP에서 문제가 제기된 RNN에 관한 sequence처리 시 long-term sequence 정보를 보존하기 어려웠던 문제를 해결하자는 데에서 출발합니다. 대표적으로 HiPPO (high-order polynomial projection) (Gu et al., NeurIPS 2020)는 sequence 정보에 대한 memory 문제를 해결하기 위해 제안된 SSM 방법론인데요. 이들이 풀고자 했던 방법은 polynomial function sets가 주어질 때, 각 polynomial function에 적절한 가중치 c를 도입하여 기존의 input sequence f의 연속 시간 t까지의 신호를 기억하고자 할 때 c와 주어진 polynomial projection을 기반으로 이를 잘 reconstruction하는 c를 찾는 문제로 생각하여 접근하였습니다.
저자들은 이 때의 c를 memory units이라 정의하고 해당 문제를 직접 풀고자 하였지만 매번 c를 최적화하는 것을 연산 비용때문에 어려웠습니다. 따라서, HiPPO는 memory units가 아래와 같은 주요한 ODE로 정리될 수 있음을 생각하였고, 이 것은 SSM의 개념을 도입할 수 있는 시작이자 기회가 되었습니다.
따라서, HiPPO는 input sequence를 받아 output sequence를 도출하는 seq2seq 형태의 레이러를 deterministic하게 구성할 수 있게 되었으며, 이산화를 통한 RNN은 병렬 연산에 적합하지 않은 것을 고려하여 기존 연속적인 상태에서 아래와 같이 convolution으로 전개하여, 이를 이산화하는 방법을 고민하였습니다.
HiPPO-LegS 컨셉을 바탕으로, 이산적 SSMs에 대한 convolution 방법으로 확장할 수 있었으며, 아래와 같이 딥러닝 모델에서 seq2seq 처리를 위한 비용 효율적인 레이러를 구성하도록 제안하였습니다.
여기에서 정의한 A에 대한 대각화를 기반으로 SSM을 구조적으로 정의한 모델을 S4라 하며, Bidirectional SSM은 이 것의 specific variant인 S4D로 구현되었습니다.
Gated bidirectional SSM은 "Wang et al., 2023"을 참고하여 S4D를 이용하여 제안되었고, 해당 모델은 아래의 그림을 바탕으로 DiffuSSMs에 AdaLN-Zero와 Hourglass를 추가한 것으로 비교 확인할 수 있습니다.
SSM을 도입하면서 본 연구는 patchification을 사용하지 않는 DiT에 비해 full sequence length를 유지하며 Flops에 대한 scalability를 확보할 수 있음을 보였습니다.
Experiment results
본 연구는 ImageNet (256x256, 512x512) 및 LSUN (256x256)을 활용한 생성 이미지 검증을 실시하였고, Stable Diffusion의 VAE를 활용하였으며, 8 NVIDIA A100에서 256 batch size를 활용했습니다. DiT와 동일하게 weight EMA update를 처리하였고, 비교 모델에 사용한 DiffuSSM-XL은 673M 파라미터를 보유하고 있습니다. (29 gated bidirectional SSM blocks weight matrix의 D=1152) Mixed precision을 사용하였고, ADM (Dhariwal and Nichol, NeurIPS 2021)의 diffusion 학습 전략을 이용했습니다.
Class-conditional image generation 및 runtime comparison 결과는 다음과 같습니다.
LSUN에서 검증한 unconditional image generation은 LDM과 비슷한 성능입니다.
Ablation study에서 실시한 patchification 도입 (P=2) 대비 이미지의 질적 평가도 포함합니다.
본 논문은 attention 기반의 diffusion backbone이 갖는 high-resolution에서의 scale 문제를 NLP에서 활용하는 SSM 도입으로 해결하려 한 흥미로운 연구였습니다.
추가적으로 궁금한 부분은 논문을 참조 부탁드리며, 질문을 댓글로 남겨주시면 함께 고민할 좋은 기회가 될 것으로 생각합니다.
감사합니다.