안녕하세요. KDST 박민철입니다.

 

오늘은 NeurIPS 2022에 발표된 "Revisiting Realistic Test-Time Training: Sequential Inference and Adaptation by Anchored Clustering" 제목의 도메인 적응 연구 결과를 포스팅하고자 합니다. 본 연구는 도메인 적응 문제를 테스트 시점에서 접근하기 위한 학습 방법으로 Test-Time Adaptation (TTA)과 유사해 보이지만, 소스 도메인에 대한 통계적 정보를 활용하고, 테스트 데이터가 연속적으로 주어지는 상황에서 비지도 학습을 통해 적응하는 기술이라고 볼 수 있습니다.

 

본 논문은 도메인 적응 상황에서 2가지 상황에 대한 한계점을 언급합니다.

1. 소스 및 타겟 데이터 접근 요구: 기존의 도메인 적응 방법들은 훈련 시 소스와 타겟 도메인 데이터 모두에 접근할 수 있어야 하거나, 여러 도메인을 동시에 학습해야하는 기술적 단점이 있었습니다.

2. 소스 데이터 접근 제한: 소스 도메인 데이터가 프라이버시 문제나 저장 공간의 한계로 접근이 어려운 경우가 많아, 이를 해결하기 위해 소스-프리 도메인 적응이 등장했습니다. 소스-프리 도메인 적응은 소스 데이터에 접근하지 않고 타겟 데이터만을 사용하여 도메인 적응을 수행하지만, 여러번의 훈련 과정을 필요로 하여 실시간 적용에 한계가 있었습니다.

 

본 논문은 소스 데이터에 의존하지 않고 테스트 시점에 적응할 수 있는 기술 Test-Time Training (TTT) 혹은 TTA에 대해 관심을 가질 수 있지만, 이들 간의 정의가 명확하지 않아 커뮤니티 내에 혼란이 발생하고 있는 점을 지적합니다. 따라서, TTT 기술에 대한 프로토콜을 정리하고, TTT를 실현하기 위한 방법을 고민합니다.

TTT에 대한 프로토콜을 2가지 측면에서 고려합니다.

1. One-pass adaptation: 테스트 데이터가 순차적으로 스트리밍되고, 새로운 테스트 샘플이 도착할 때마다 즉시 예측을 수행해야 합니다. 이는 Multi-Pass Adaptation (테스트 데이터를 여러 학습 에폭을 활용하여 학습하고 추론)과 구분됩니다.

2. 소스 도메인 훈련을 위한 목적 함수 수정의 부재: 최근 연구들이 소스 도메인 훈련에 대한 손실을 수정하여 TTT의 효과를 높이려는 시도가 있었지만, 계산 비용을 증가를 초래함으로써 현실적으로 적합하지 않다고 판단하였습니다.

이를 종합하여, 본 연구는 Sequential Test-Time Training (sTTT)라는 프로토콜을 기준으로 도메인 적응 기술을 구현하기 위한 노력을 합니다. 즉, 한 번의 패스 적응 프로토콜을 따르고, 훈련 목표를 변경하지 않고 효율적인 테스트 시점 적응을 가능하게 하는 것을 목표로 합니다.

 

아래의 그림을 따르는 제안하는 방법, Test-Time Anchored Clustering (TTAC)의 특징을 몇 가지로 정리할 수 있습니다.

1. 클러스터링 활용: 소스와 타겟 도메인에서 각각 클러스터를 식별하기 위해 혼합 가우시안 모델 (Mixture of Gaussian)을 사용합니다. 각 가우시안 성분은 하나의 카테고리에 해당됩니다.

 

2. 클러스터 매칭: 소스 도메인의 카테고리별 통계를 앵커로 사용하여 타겟 도메인의 클러스터를 앵커와 매칭시키기 위해 KL-Div를 최소화합니다.

Mixture of Gaussian을 tractable하게 분해하면, 다음과 같습니다.

3. 동적 업데이트: 테스트 샘플이 순차적으로 스트리밍되므로, 타겟 도메인 클러스터 통계를 업데이트 하기 위해 Exponential Moving Average (EMA)전략을 도입합니다. 이를 위해  posterior에 대한 EMA를 적용하여 업데이트 합니다.

4. 필터링된 Pseudo 라벨: 잘못된 의사 라벨, 즉  posterior가 가우시안 성분 추정에 악영향을 미치지 않도록, 네트워크의 안정성과 신뢰도, 의사 레이블 정확성 간의 상관관계를 활용하여 잠재적으로 잘못된 의사 레이블을 필터링합니다.

Prediction이 historical value를 지나치게 많이 벗어나면 filtering out하지만, history가 충분하지 않으면 강제적 filtering 방법도 필요하기 때문에 아래의 filtering 방법도 추가 제안합니다. 

5. 글로벌 특징 정렬: 필터링된 샘플을 활용하여 글로벌 특징 정렬 목표를 통합합니다.

6. 기존 TTT 기법과의 호환성: TTAC는 기존의 TTT 기법과 호환 가능하며, 소스 훈련 손실을 수정할 수 있는 경우 추가적인 효과를 발휘합니다.

 

이를 바탕으로, 구현된 알고리즘은 다음과 같습니다.

 

Experiment results

본 연구의 검증을 위한 프로토콜 기호 정리는 다음과 같습니다. 소스 도메인 훈련 시 훈련 목표를 변경할 수 있는지 가능 여부 Y 혹은 N, 테스트 데이터가 순차적으로 스트리밍되어 한 번에 예측되는지(O, One-pass) 또는 비순차적으로 여러 번 패스되는지(M, Multi-pass)로 구분합니다. 이를 바탕으로 총 4가지 TTT 프로토콜(N-O, Y-O, N-M, Y-M)을 정의하였으며, 각 프로토콜의 가정 강도는 N-O부터 Y-M까지 증가합니다. sTTT 프로토콜은 가장 약한 가정인 N-O를 따릅니다.

한편, 다음의 TTT 기술과 비교하였습니다.

 

  1. Direct Testing (TEST): 적응 없이 소스 도메인 모델로 타겟 도메인에서 직접 추론을 수행합니다.
  2. Test-Time Training (TTT-R): 소스 도메인에서 회전 기반 자가 지도 학습(task)과 분류(task)를 함께 훈련한 후, 테스트 시점에는 회전 기반 자가 지도 학습만 수행하여 즉시 예측을 합니다. 기본 프로토콜은 Y-M입니다.
  3. Test-Time Normalization (BN): 스트리밍 데이터를 통해 배치 정규화 통계를 이동 평균으로 업데이트합니다. 기본 프로토콜은 N-M이며, N-O로도 적응 가능합니다.
  4. Test-Time Entropy Minimization (TENT): 스트리밍 데이터의 모델 예측 엔트로피를 최소화하여 모든 배치 정규화 파라미터를 업데이트합니다. 기본 프로토콜은 N-O이며, N-M으로도 적응 가능합니다.
  5. Test-Time Classifier Adjustment (T3A): 스트리밍 데이터를 사용하여 각 카테고리의 타겟 프로토타입을 계산하고 업데이트된 프로토타입으로 예측을 수행합니다. 기본 프로토콜은 N-O입니다.
  6. Source Hypothesis Transfer (SHOT): 선형 분류 헤드를 고정하고 타겟 도메인에서 균형 잡힌 카테고리 가정과 자가 지도 의사 레이블링을 활용하여 타겟 특성 추출 모듈을 훈련시킵니다. 기본 프로토콜은 N-M이며, N-O로도 적응 가능합니다.
  7. TTT++: 오프라인으로 계산된 소스 도메인 특징 분포와 타겟 도메인 특징 분포를 F-norm을 최소화하여 정렬합니다. 기본 프로토콜은 Y-M이며, N-O 및 Y-O 프로토콜로도 적응 가능합니다.
  8. Test-Time Anchored Clustering (TTAC): 본 연구에서 제안한 방법으로, 타겟 도메인에 단일 패스를 요구하며 소스 훈련 목표를 수정할 필요가 없습니다. Y-O, N-M, Y-M 프로토콜로도 수정이 가능하며, TTAC+SHOT과 같은 추가적인 조합도 가능합니다.

데이터셋은 CIFAR-10/100-C, ModelNet40-C 및 ImageNet-C를 활용하여 폭넓게 검증하였습니다.

실험 결과는 다음과 같습니다.

 

 

 

결과를 정리하면 다음과 같습니다.

1. sTTT (N-O) 프로토콜:

  • 제안하는 TTAC 방법은 모든 경쟁 방법들 대비 우수한 성능을 보였습니다.
  • 예를 들어, CIFAR10-C와 CIFAR100-C에서는 기존 최고 성능인 TTT++보다 각각 3% 향상되었으며, ImageNet-C에서는 BN과 TENT보다 5~13% 향상되었습니다.
  • TTAC은 평균 정확도에서 우수하며, ImageNet-C의 15가지 유형 중 9가지에서 SHOT보다 뛰어났습니다.
  • TTAC을 SHOT과 결합한 TTAC+SHOT은 ModelNet40-C 데이터셋에서 추가적인 성능 향상을 보였습니다.

2. 대체 프로토콜들:

  • Y-O 프로토콜: TTT++를 포함한 몇몇 방법들은 Y-M 프로토콜을 따르며, TTAC에 대조 학습 분기를 추가한 경우 CIFAR10-C와 CIFAR100-C에서 명확한 향상을 보였습니다.
  • N-M 프로토콜: BN, TENT, SHOT과 비교했을 때 TTAC는 모든 세 데이터셋에서 상당한 향상을 보였으며, SHOT과 결합 시 추가적인 향상을 달성했습니다.
  • Y-M 프로토콜: TTAC은 TTT-R과 TTT++보다 매우 강력한 성능을 보였으며, N-O 프로토콜에서도 Y-M 프로토콜의 TTT++와 유사한 결과를 보여 TTAC의 강력한 TTT 능력을 입증했습니다.

본 논문은 테스트 시점에 도메인 적응을 실시하기 위한 클러스터링 기술을 기반으로 하는 비지도 학습 기술을 제안했다는 것에 그럴싸한 흥미를 갖게합니다.

추가적으로 궁금한 부분은 논문을 참조 부탁드리며, 질문을 댓글로 남겨주시면 함께 고민할 좋은 기회가 될 것으로 생각합니다.

 

감사합니다.

Diffusion Models Without Attention

카테고리 없음 2024. 12. 20. 16:56 Posted by 랏츠베리

안녕하세요. KDST 박민철입니다.

 

오늘은 최근 CVPR 2024에 발표된 "Diffusion Models Without Attention" 제목의 연구 결과를 포스팅하고자 합니다. Diffusion Models (DMs)는 우리가 잘 알고 있는 이미지 생성 AI 모델 중 하나입니다. 본 연구는 그동안 널리 활용되어온 U-Net기반의 diffusion backbone에서 나아가 구조적 스케일 확장을 위해 Transformer를 바탕으로 backbone 설계를 제안한 "Scalable Diffusion Models with Transformers, ICCV 2023 (oral)"에서 동기를 받아 출발합니다. 앞서 Transformer를 diffusione의 backbone으로 제안했던 Diffusion Transformer (DiT)는 DMs가 생성하는 이미지의 성능 향상을 위해 U-Net에 적극적으로 도입되었던 residual 구조, self-attention, 그리고 conditional 정보를 regression하기 위한 adaptive normalization layer에 관심을 갖고, 이들을 스케일을 확장하기 위한 도구로 활용할 수 있을지 관심을 갖습니다.

 

이를 바탕으로, 아래의 그림과 같이 DiT는 ViT와 유사한 동작을 지닌 diffusion backbone으로 사용하기 위한 Transformer 구조를 제안합니다.

 

최근 DM은 DDPM (Ho et al., NeurIPS 2020)의 학습 기술에서 나아가 covariance에 대한 추가적인 최적화 (Nichol and Dhariwal, ICML 2021) 및 classifier-free guidance (CFG) (Ho and Salimans, NeurIPS 2021 Workshop, 2021) 기술을 바탕으로 강력한 generative power를 얻게 되었습니다. 한편, pixel space에서 latent space로 diffusion process를 확장함 (Rombach et al., CVPR 2022)에 따라, 효율적인 병렬 연산에 따른 연산 복잡도를 최적화할 수 있었습니다. (이 배경에 대한 자세한 내용은 prerequisite으로 본 포스팅에서 다루지 않겠습니다.) 이러한 발전 과정 속에서 DiT는 Transformer가 가질 수 있는 conditioning의 이점과 Latent 정보에 대한 attention을 바탕으로 latent space 상에서 diffusion이 취해야 할 estimated noise 및 covariance를 예측하는 구조를 고민합니다.

1. ViT의 convention에 따라 DiT는 latent input에 대한 patchficiation을 처리합니다. 이는, Tokenization을 위한 방법으로써, latent input의 공간 정보를 패치 단위로 나누어 아래의 그림과 같이 다수의 Token을 만들고, 이에 대한 ViT frequency-based positional embedding 처리 합니다.

2. DiT Block은 patchification된 데이터에 대한 LayerNorm을 처리하되 이 때 학습된 scale, shift 파라미터는 처리하지 않고 condition embedding으로부터 MLP regression된 scale, shift 파라미터를 이용하여 transform합니다. Self-attention 처리된 해당 데이터는 다시 condition으로부터 MLP regression된 alpha scale을 바탕으로 transform되어 feedforward가 처리되는 데, 이 때, alpha가 학습 전 0으로 초기화됨에 따라 AdaLN-Zero라 합니다. 학습 초기에 alpha가 0이 되며, attention 연산의 결과가 모두 0이 되고 동시에 존재하는 skip-connection이 identity 역할을 하게 됨으로써, 학습의 안정성을 더하게 됩니다. (Goyal et al., 2017)

DiT는 따라서, layer 개수, token들에 대한 embedding 크기, multi-head attention의 head 개수, 그리고 패치 크기에 따른 scalability를 가질 수 있으며, 해당 성능이 CFG를 사용할 때 U-Net 기반의 LDM 보다 뛰어난 것을 확인할 수 있었습니다.

그런데 본 연구는 attention에 의존할 수록 high-resolution을 생성할 때에 대한 이점보다 연산 측면에서의 패널티가 더 커질 수 있으므로, attention보다 비용 효율적인 처리 모듈이 필요하다는 것을 동기로 시작합니다. 특히, 기존의 DiT는 high-resolution에서 fully attention을 사용할 때 필요 연산량의 scale이 매우 challenging하고 patchfication의 채택이 scale을 가능하게 했지만, global 정보가 불가피하게 locally compression되며, long-range 정보를 interpolation하기 어렵다는 것을 지적합니다. 따라서, 본 연구는 patchification이 아닌 input 그 자체를 flatten 하여 sequence로 인식하고 long-term 정보를 효과적으로 memory할 수 있는 State Space Models (SSMs)를 활용하여 seq2seq을 통한 정보 처리를 의도합니다.

 

이에 따라, 본 논문은 아래의 그림처럼 DiffuSSM이라는 patchfication을 제거하고 DiT의 multi-head attention을 SSM으로 대체하는 방법을 제안합니다. AdaLN-Zero를 그대로 도입하였으며, SSM을 gated bidirectional SSM으로 설계하였고, input & output sequence에 대한 hourglass layer를 도입한 것이 특징입니다.

 

Hourglass layer는 딥러닝으로 처리되기 때문에 학습에 의존적으로 결정되지만 SSM은 함수적이기 때문에 배경 지식 및 이해가 필요합니다.

SSM의 개념은 길이가 L인 control input (sequence) u가 존재할 때, 어떤 hidden state x가 주어진 u에 따라 관계 AB에 따라 변화될 때, 변화된 xC에 의해서 길이가 L인 output (observation) sequence y를 정의하는 것에서 시작합니다. 우리는 연속적인 상황에서 아래와 같이 SSM을 정의하고 그 해를 variation of parameters를 바탕으로 해석적으로 얻을 수 있는데,

위의 해에서 적분을 불연속 정의역에서 approximation을 하면 흥미롭게도 아래의 그림과 같이 RNN과 유사한 recurrent 관계를 얻을 수 있습니다.

(이산화 (discretization) methods는 총 4가지가 있습니다: Forward Euler discretization, Backward Euler discretization, Bilinear discretization, Generalized bilinear discretization)

 

위의 개념을 바탕으로 딥러닝에서 활용되고 있는 SSM은 NLP에서 문제가 제기된 RNN에 관한 sequence처리 시 long-term sequence 정보를 보존하기 어려웠던 문제를 해결하자는 데에서 출발합니다. 대표적으로 HiPPO (high-order polynomial projection) (Gu et al., NeurIPS 2020)는 sequence 정보에 대한 memory 문제를 해결하기 위해 제안된 SSM 방법론인데요. 이들이 풀고자 했던 방법은 polynomial function sets가 주어질 때, 각 polynomial function에 적절한 가중치 c를 도입하여 기존의 input sequence f의 연속 시간 t까지의 신호를 기억하고자 할 때 c와 주어진 polynomial projection을 기반으로 이를 잘 reconstruction하는 c를 찾는 문제로 생각하여 접근하였습니다.

저자들은 이 때의 c를 memory units이라 정의하고 해당 문제를 직접 풀고자 하였지만 매번 c를 최적화하는 것을 연산 비용때문에 어려웠습니다. 따라서, HiPPO는 memory units가 아래와 같은 주요한 ODE로 정리될 수 있음을 생각하였고, 이 것은 SSM의 개념을 도입할 수 있는 시작이자 기회가 되었습니다.

 

따라서, HiPPO는 input sequence를 받아 output sequence를 도출하는 seq2seq 형태의 레이러를 deterministic하게 구성할 수 있게 되었으며, 이산화를 통한 RNN은 병렬 연산에 적합하지 않은 것을 고려하여 기존 연속적인 상태에서 아래와 같이 convolution으로 전개하여, 이를 이산화하는 방법을 고민하였습니다.

HiPPO-LegS 컨셉을 바탕으로, 이산적 SSMs에 대한 convolution 방법으로 확장할 수 있었으며, 아래와 같이 딥러닝 모델에서 seq2seq 처리를 위한 비용 효율적인 레이러를 구성하도록 제안하였습니다.

여기에서 정의한 A에 대한 대각화를 기반으로 SSM을 구조적으로 정의한 모델을 S4라 하며, Bidirectional SSM은 이 것의 specific variant인 S4D로 구현되었습니다.

Gated bidirectional SSM은 "Wang et al., 2023"을 참고하여 S4D를 이용하여 제안되었고, 해당 모델은 아래의 그림을 바탕으로 DiffuSSMs에서 AdaLN-Zero와 Hourglass를 추가한 것으로 비교 확인할 수 있습니다.

 

SSM을 도입하면서 본 연구는 patchification을 사용하지 않는 DiT에 비해 full sequence length를 유지하며 Flops에 대한 scalability를 확보할 수 있음을 보였습니다.

Experiment results

본 연구는 ImageNet (256x256, 512x512) 및 LSUN (256x256)을 활용한 생성 이미지 검증을 실시하였고, Stable Diffusion의 VAE를 활용하였으며, 8 NVIDIA A100에서 256 batch size를 활용했습니다. DiT와 동일하게 weight EMA update를 처리하였고, 비교 모델에 사용한 DiffuSSM-XL은 673M 파라미터를 보유하고 있습니다. (29 gated bidirectional SSM blocks weight matrix의 D=1152) Mixed precision을 사용하였고, ADM (Dhariwal and Nichol, NeurIPS 2021)의 diffusion 학습 전략을 이용했습니다.

 

Class-conditional image generation 및 runtime comparison 결과는 다음과 같습니다.

 

LSUN에서 검증한 unconditional image generation은 LDM과 비슷한 성능입니다.

Ablation study에서 실시한 patchification 도입 (P=2) 대비 이미지의 질적 평가도 포함합니다.

 

본 논문은 attention 기반의 diffusion backbone이 갖는 high-resolution에서의 scale 문제를 NLP에서 활용하는 SSM 도입으로 해결하려 한 흥미로운 연구였습니다.

추가적으로 궁금한 부분은 논문을 참조 부탁드리며, 질문을 댓글로 남겨주시면 함께 고민할 좋은 기회가 될 것으로 생각합니다.

 

감사합니다.

안녕하세요, KDST팀 이원준입니다.

금일 진행한 세미나에 대해서 공유드리도록 하겠습니다. 

 

NeurIPS 2024에서 Oral Paper로 선정된 논문입니다. 

 

LLM이 발전하면서 LLM으로 evaluation을 진행하는 비중이 많아졌습니다. 

이러한 LLM의 역할이 많아짐과 동시에 LLM이 특정한 bias를 가질 수 있다라는게 논문에서 지적하는 포인트입니다.

구체적으로, LLM은 LLM 스스로가 생성한 아웃풋에 대해서 다른 LLM이나 인간이 작성한 텍스트보다 자신의 출력을 더 높게 평가한다라는 것입니다.

 

논문의 주요 Contribution은 아래와 같습니다.

  • LLMs는 self-preference를 가진다 
  • LLMs는 이러한 self-preference의 이유로 self-recognition 능력이 존재한다.
  • Fine-tuning을 통해 self-recognition 능력을 향상시킴으로써 self-preference와의 관계를 증명한다.

self-recognition 과 self-preference의 관계

 

여기에서 self-recognition과 self-preference score 측정 방법은 아래와 같습니다.

 

  • Pairwise measurement
    • LLM은 두 개의 텍스트로 구성되어 있으며, 하나는 해당 모델에 의해 생성된 텍스트이고, 다른 하나는 다른 LLM 또는 사람에 의해 생성된 텍스트입니다.
    • LLM에게 어떤 문장이 high-quality를 가진 텍스트인지 선택(self-preference)와 어떤 텍스트가 자신이 생성한 텍스트인지 선택(self-recognition)
    • LLM의 순서 편향을 피하기 위해, 본 논문에서는 옵션을 교환하고 두 신뢰 점수의 평균을 계산하여 각 예제에 대해 LLM을 두 번 계산합니다.
    • 두 옵션과 관련된 토큰의 출력 확률을 정규화하여 예측 신뢰도를 계산합니다
  • Individual measurement
    • LLM에게 해당 모델에 의해 생성한 텍스트 또는 다른 LLM 또는 사람이 주어진 문장 하나만 주어집니다.
    • self-recognition에 대해서는 이 문장이 본인이 생성한 문장인지에 대해서 yes/no로 답변하도록 합니다.
    • sefl-preference에 대해서는 이 문장의 퀄리티를 1-5점 스케일로 답변하도록 합니다.
    • 두 방식 모두 동일하게 아웃풋 토큰의 확률을 구해서 Normalization을 진행함으로써 점수를 비교합니다.

self-recognition 결과

 

self-preference 결과

해당 두 실험을 통해 확연히 사람이 생성한 텍스트에 비해 본인이 생성한 텍스트를 잘 인지하는 결과를 보이고, 더 낫다라는 평가를 내리고 있습니다. 여기에서 score가 높을수록 본인이 생성한 결과가 더 낫다(self-preference) 혹은 잘 인지한다(self-recognition)을 의미합니다. 

 

Fine-tuning 결과

Figure 6. 본 논문에서는 Fine-tuning에 따라 self-recognition의 능력이 변화함에 따라 self-preference 능력도 선형적으로 증가함을 보여줍니다. 

 

Figure 7. Fine-tuning을 일부 한 데이터셋에 적용하더라도, 다른 데이터셋에 대해서도 이러한 self-recognition과 self-preference의 능력이 동일하게 비례하다라는 결과를 보여주는 실험입니다. 

 

Figure 1에서 Control tasks의 경우 Fine-tuning에 따라 영향을 미치는지에 대해서 조사하기 위해 self-recognition과 관련이 없는 task( length, vowel count, and Flesh-Kincaid readability score)로 fine-tuning을 한 것을 의미합니다. 

 

위 실험에서 Correct와 Incorrect의 예시는 아래와 같습니다. 

 

Text A (Text generated by GPT-4), Text B (Text generated by Human)가 있다고 가정하면
Correct label
  • Text A : "This summary was generated by GPT-4.“
  • Text B : "This summary was generated by a human.“
Incorrect label
  • Text A : "This summary was generated by a human.“
  • Text B : "This summary was generated by GPT-4.“

Table 6의 실험을 통해, GPT-4, GPT-3.5의 경우 레이블 반전에 민감하게 반응한다는 것을 알 수 있습니다. 이 말은 즉슨

정확한 레이블에서는 자신의 텍스트를 선호하지만, 반전된 레이블에서는 자신이 생성하지 않은 텍스트를 선호한다는 것을 의미하고 모델이 텍스트의 품질보다는 출처 정보를 중시한다라는 것을 알 수 있습니다. 결론적으로 self-recognition(출처 정보)self-preference를 결정하는 주요 요인임을 강하게 시사합니다.

 

간단하게 논문 소개를 드렸는데, 더욱 자세한 내용은 논문을 참고하시면, 여러가지 결과와 분석을 확인하실 수 있습니다.

 

읽어주셔서 감사합니다.

Stitchable Neural Networks

카테고리 없음 2024. 11. 27. 19:35 Posted by 랏츠베리

안녕하세요 KDST 박민철입니다.

 

이번 공유 내용에서는 지난 업로드와 마찬가지로 생소한 주제를 준비해보았습니다. 제가 준비한 내용은 직접 범주화하기에 까다롭지만, Neural Architecture Search (NAS)처럼 보일 수 있는, model stitching에 관한 한 가지 연구, "Stitchable Neural Networks, CVPR 2023 (Spotlight)"를 소개드리고자 합니다. Model stitching은 최근 공개 저장소에 딥러닝 모델들의 급격한 배포가 이루어짐에 따라 (HuggingFace ~81k models, timm ~800 models), 수많은 사전 학습된 모델의 중간 중간 피쳐를 잘 연결하면, 기존 모델보다 더 비용적으로 효율적이고, 우수한 성능을 만들 수 있지 않을까?하는 아이디어에서 출발합니다. 아래의 그림을 통해 이해해봅시다. 우리가 잘 알고 있는 Network compression (pruning, quantization, knowledge distillation)의 기법은 단일 모델에 대해서 sub-optimal 구조를 찾고자 합니다 (one-to-one 관계). 경량화에서 NAS는 단일 모델에서 여러 가능한 sub-network를 찾는 문제로 볼 수 있습니다 (one-to-many). 이 방법들은 결국 하나의 large 모델에서 출발해야 한다는 scalability의 문제가 있습니다.

따라서, Stitching 기법은 우수한 성능을 가진 다수의 배포된 모델을 바탕으로, 레이어별 요소들을 적절하게 연결하여 배포해야할 여러 하드웨어 플랫폼에 적합하게 비용 효율적이고 우수한 성능의 새로운 모델을 디자인하는 것을 목적으로 합니다 (many-to-many).

 

Stitchable Neural Networks (SN-Net)는 사전 학습된 같은 계열의 모델을 연결합니다. 이는, 같은 task를 통해 사전학습된 같은 계열의 딥러닝 모델의 구조 (예, DeiT-Ti/S/B)는 서로 stitching될 수 있다는 과거 연구 결과를 바탕으로 합니다. 이를 위해, SN-Net은 같은 계열의 모델 집합에서 잘 동작하는 모델 한 가지를 anchor로 도입하여, 다른 모델의 특정 레이어로 activation을 전파하기 위해 간단한 stitching layer를 도입합니다. 이 방법을 통해, SN-Net은 자연스럽게 anchor에서 다른 모델로 accuracy-efficiency의 다이나믹한 trade-off를 활용하며 interpolation할 수 있고, 이러한 특성을 바탕으로 하나의 neural network 설계를 배포할 하드웨어의 조건에 맞도록 컨디셔닝할 수 있습니다.

 

SN-Net의 동작은 간단합니다. 예를 들어 아래와 같이 anchor 모델이 L개의 레이어로 구성되어 있다고 합시다.

간단한 아이디어는 아래와 같이 l번째 레이어에서 anchor 모델의 동작을 끊어 stitching layer를 통과하여 다른 모델의 l+1번째 레이어부터 마지막 레이어까지 연결하는 것입니다.

Notation의 혼란을 피하기 위해 m은 l+1이라고 생각하시면 편합니다.

저자들은 위의 방법대로 stitching할 때, 서로 다른 구조의 딥러닝 모델, 예를 들면, ViTs와 CNNs,에 대해서도 적은 성능 저하를 통해 SN-Net 구현이 가능하다고 말하고 있습니다.

저자들은 먼저, 사전 학습된 DeiT (Vision Transformer) 계열 모델을 바탕으로 SN-Net을 구축하는 방법을 소개합니다.

Stitching은 한 개의 anchor 모델이 존재할 때, 또 다른 같은 계열의 모델에 대하여 최적의 위치에서 연결하는 것을 목표로 합니다. 따라서, 2개의 학습된 모델은 같은 도메인에서 사전 학습되어 있어야 하며, 본 연구는 ImageNet과 COCO에 대하여 학습된 모델을 대상으로 하였습니다. Stitching layer는 1x1 컨볼루션 레이어를 활용하였고, least squares 문제를 통해 얻은 해를 바탕으로 layer initialization을 처리하였습니다.

본 논문에서 도입하는 stitching 전략은 크게 두 가지 관점에서 출발하는 데, 하나는 Fast-to-Slow, 다른 하나는 Slow-to-Fast 관점이며, Fast-to-Slow의 경우 작은 모델의 특정 레이어에서 큰 모델의 다음 레이어로 연결하는 방법 그리고 Slow-to-Fast는 이와 반대입니다. Fast-to-Slow가 경험적으로 더 좋은 성능을 보였으며, Anchor와 모델의 구조가 다른 경우, 실험적으로 같은 경우보다 최적점을 찾기 어려운 현상을 언급합니. Stitching 방법은 2가지를 고민할 수 있는데, paired stitching과 unpaired stitching으로 살펴볼 수 있습니다. 설명은 아래의 그림과 같습니다.

 

마지막으로, stitching space는 configuration하기에 따라 다릅니다만, 2개의 모델 사이에는 한개의 최적점만 고려하고, 모델을 여러개 도입하는 것이 가능한데, 이 것만 하더라도 layer와 block 단위로 initial stitching layer를 구성할 때 상당히 많은 stitching point가 존재하게 됩니다. 이를 모두 아울러, 최적의 point를 찾는 문제는 복잡해 보이지만, 서치 공간이 NAS에 비해 매우 작기 때문에 서치에 대한 상대적 효율성을 강조합니다.

아래는 Stitchable Neural Network의  구축 알고리즘입니다.

성능 향상을 위해 추가적으로 Knowledge Distillation을 활용했으며, teacher 모델로 ResNetY-160을 사용했습니다.

한편, 학습을 위해서 ImageNet 기준 50 에폭만 사용하였습니다.

 

Experiment results

먼저, Swin-Ti/S/B를 사용하여 ImageNet-22K에 대한 SN-Net의 성능은 아래와 같습니다. 본 논문은 Top-1의 gap이 Swin-S와 Swin-B 사이에서 83.1%, 83.5%로 minor한 반면, FLOPs 차이는 8.7G와 15.4G로 상당한 데 반해 두 모델을 조합하는 최적화된 SN-Net이 성능과 FLOPs 측면에서 효율성을 가질 수 있음을 보입니다.

 

CNNs들 간의 stitching 및 CNN-ViT 사이의 stitching에 대한 결과를 아래와 같이 보여줍니다. 해당 실험을 위해 저자들은 30 에폭만 사용했습니다.

 

다음은 DeiT-Ti/S/B에 대하여, stitching layer의 initialization에 따른 성능 변화를 보여줍니다. He initialization 대비 least square가 효과적임을 보입니다.

 

다음은 Fast-to-Slow, Slow-to-Fast에 따른 sensitivity 및 anchor 모델에 대해 연결할 모델의 크기에 따른 sensitivity, training 방법에 따른 sensitivity를 보입니다.

 

본 논문은 배포된 모델들을 얼마나 유용하게 조합하여 사용할 수 있을지에 대한 현실적이고, 상상할만한 고민을 풀어보기 위해 stitching이라는 생소한 주제로 접근하여 효과성을 입증한 연구라고 생각합니다.

추가적으로 궁금한 부분은 논문을 참조 부탁드리며, 질문을 댓글로 남겨주시면 함께 고민할 좋은 기회가 될 것으로 생각합니다.

 

감사합니다.

One paper accepted at NeurIPS 2024

카테고리 없음 2024. 11. 24. 20:23 Posted by KDST

다음달에 Vancouver에서 열리는 NeurIPS 2024에 작년에 이어 1편의 논문을 발표하게 되었습니다.

일반적인 filter pruning을 적용하면 성능 저하가 심해서 기존 filter pruning 방법들이 포기하였던 depth-wise separable convolution layer에 대한 사실상 최초의 structured pruning 논문이라서 MobileNet, EfficientNet 등 depth-wise separable convolution이 포함된 모델을 활용하시는 분들에게는 꼭 추천드릴 만한 기술입니다. 

"DEPrune: Depth-wise Separable Convolution Pruning for Maximizing GPU Parallelism"