안녕하세요. KDST 박민철입니다.

 

오늘은 NeurIPS 2022에 발표된 "Revisiting Realistic Test-Time Training: Sequential Inference and Adaptation by Anchored Clustering" 제목의 도메인 적응 연구 결과를 포스팅하고자 합니다. 본 연구는 도메인 적응 문제를 테스트 시점에서 접근하기 위한 학습 방법으로 Test-Time Adaptation (TTA)과 유사해 보이지만, 소스 도메인에 대한 통계적 정보를 활용하고, 테스트 데이터가 연속적으로 주어지는 상황에서 비지도 학습을 통해 적응하는 기술이라고 볼 수 있습니다.

 

본 논문은 도메인 적응 상황에서 2가지 상황에 대한 한계점을 언급합니다.

1. 소스 및 타겟 데이터 접근 요구: 기존의 도메인 적응 방법들은 훈련 시 소스와 타겟 도메인 데이터 모두에 접근할 수 있어야 하거나, 여러 도메인을 동시에 학습해야하는 기술적 단점이 있었습니다.

2. 소스 데이터 접근 제한: 소스 도메인 데이터가 프라이버시 문제나 저장 공간의 한계로 접근이 어려운 경우가 많아, 이를 해결하기 위해 소스-프리 도메인 적응이 등장했습니다. 소스-프리 도메인 적응은 소스 데이터에 접근하지 않고 타겟 데이터만을 사용하여 도메인 적응을 수행하지만, 여러번의 훈련 과정을 필요로 하여 실시간 적용에 한계가 있었습니다.

 

본 논문은 소스 데이터에 의존하지 않고 테스트 시점에 적응할 수 있는 기술 Test-Time Training (TTT) 혹은 TTA에 대해 관심을 가질 수 있지만, 이들 간의 정의가 명확하지 않아 커뮤니티 내에 혼란이 발생하고 있는 점을 지적합니다. 따라서, TTT 기술에 대한 프로토콜을 정리하고, TTT를 실현하기 위한 방법을 고민합니다.

TTT에 대한 프로토콜을 2가지 측면에서 고려합니다.

1. One-pass adaptation: 테스트 데이터가 순차적으로 스트리밍되고, 새로운 테스트 샘플이 도착할 때마다 즉시 예측을 수행해야 합니다. 이는 Multi-Pass Adaptation (테스트 데이터를 여러 학습 에폭을 활용하여 학습하고 추론)과 구분됩니다.

2. 소스 도메인 훈련을 위한 목적 함수 수정의 부재: 최근 연구들이 소스 도메인 훈련에 대한 손실을 수정하여 TTT의 효과를 높이려는 시도가 있었지만, 계산 비용을 증가를 초래함으로써 현실적으로 적합하지 않다고 판단하였습니다.

이를 종합하여, 본 연구는 Sequential Test-Time Training (sTTT)라는 프로토콜을 기준으로 도메인 적응 기술을 구현하기 위한 노력을 합니다. 즉, 한 번의 패스 적응 프로토콜을 따르고, 훈련 목표를 변경하지 않고 효율적인 테스트 시점 적응을 가능하게 하는 것을 목표로 합니다.

 

아래의 그림을 따르는 제안하는 방법, Test-Time Anchored Clustering (TTAC)의 특징을 몇 가지로 정리할 수 있습니다.

1. 클러스터링 활용: 소스와 타겟 도메인에서 각각 클러스터를 식별하기 위해 혼합 가우시안 모델 (Mixture of Gaussian)을 사용합니다. 각 가우시안 성분은 하나의 카테고리에 해당됩니다.

 

2. 클러스터 매칭: 소스 도메인의 카테고리별 통계를 앵커로 사용하여 타겟 도메인의 클러스터를 앵커와 매칭시키기 위해 KL-Div를 최소화합니다.

Mixture of Gaussian을 tractable하게 분해하면, 다음과 같습니다.

3. 동적 업데이트: 테스트 샘플이 순차적으로 스트리밍되므로, 타겟 도메인 클러스터 통계를 업데이트 하기 위해 Exponential Moving Average (EMA)전략을 도입합니다. 이를 위해  posterior에 대한 EMA를 적용하여 업데이트 합니다.

4. 필터링된 Pseudo 라벨: 잘못된 의사 라벨, 즉  posterior가 가우시안 성분 추정에 악영향을 미치지 않도록, 네트워크의 안정성과 신뢰도, 의사 레이블 정확성 간의 상관관계를 활용하여 잠재적으로 잘못된 의사 레이블을 필터링합니다.

Prediction이 historical value를 지나치게 많이 벗어나면 filtering out하지만, history가 충분하지 않으면 강제적 filtering 방법도 필요하기 때문에 아래의 filtering 방법도 추가 제안합니다. 

5. 글로벌 특징 정렬: 필터링된 샘플을 활용하여 글로벌 특징 정렬 목표를 통합합니다.

6. 기존 TTT 기법과의 호환성: TTAC는 기존의 TTT 기법과 호환 가능하며, 소스 훈련 손실을 수정할 수 있는 경우 추가적인 효과를 발휘합니다.

 

이를 바탕으로, 구현된 알고리즘은 다음과 같습니다.

 

Experiment results

본 연구의 검증을 위한 프로토콜 기호 정리는 다음과 같습니다. 소스 도메인 훈련 시 훈련 목표를 변경할 수 있는지 가능 여부 Y 혹은 N, 테스트 데이터가 순차적으로 스트리밍되어 한 번에 예측되는지(O, One-pass) 또는 비순차적으로 여러 번 패스되는지(M, Multi-pass)로 구분합니다. 이를 바탕으로 총 4가지 TTT 프로토콜(N-O, Y-O, N-M, Y-M)을 정의하였으며, 각 프로토콜의 가정 강도는 N-O부터 Y-M까지 증가합니다. sTTT 프로토콜은 가장 약한 가정인 N-O를 따릅니다.

한편, 다음의 TTT 기술과 비교하였습니다.

 

  1. Direct Testing (TEST): 적응 없이 소스 도메인 모델로 타겟 도메인에서 직접 추론을 수행합니다.
  2. Test-Time Training (TTT-R): 소스 도메인에서 회전 기반 자가 지도 학습(task)과 분류(task)를 함께 훈련한 후, 테스트 시점에는 회전 기반 자가 지도 학습만 수행하여 즉시 예측을 합니다. 기본 프로토콜은 Y-M입니다.
  3. Test-Time Normalization (BN): 스트리밍 데이터를 통해 배치 정규화 통계를 이동 평균으로 업데이트합니다. 기본 프로토콜은 N-M이며, N-O로도 적응 가능합니다.
  4. Test-Time Entropy Minimization (TENT): 스트리밍 데이터의 모델 예측 엔트로피를 최소화하여 모든 배치 정규화 파라미터를 업데이트합니다. 기본 프로토콜은 N-O이며, N-M으로도 적응 가능합니다.
  5. Test-Time Classifier Adjustment (T3A): 스트리밍 데이터를 사용하여 각 카테고리의 타겟 프로토타입을 계산하고 업데이트된 프로토타입으로 예측을 수행합니다. 기본 프로토콜은 N-O입니다.
  6. Source Hypothesis Transfer (SHOT): 선형 분류 헤드를 고정하고 타겟 도메인에서 균형 잡힌 카테고리 가정과 자가 지도 의사 레이블링을 활용하여 타겟 특성 추출 모듈을 훈련시킵니다. 기본 프로토콜은 N-M이며, N-O로도 적응 가능합니다.
  7. TTT++: 오프라인으로 계산된 소스 도메인 특징 분포와 타겟 도메인 특징 분포를 F-norm을 최소화하여 정렬합니다. 기본 프로토콜은 Y-M이며, N-O 및 Y-O 프로토콜로도 적응 가능합니다.
  8. Test-Time Anchored Clustering (TTAC): 본 연구에서 제안한 방법으로, 타겟 도메인에 단일 패스를 요구하며 소스 훈련 목표를 수정할 필요가 없습니다. Y-O, N-M, Y-M 프로토콜로도 수정이 가능하며, TTAC+SHOT과 같은 추가적인 조합도 가능합니다.

데이터셋은 CIFAR-10/100-C, ModelNet40-C 및 ImageNet-C를 활용하여 폭넓게 검증하였습니다.

실험 결과는 다음과 같습니다.

 

 

 

결과를 정리하면 다음과 같습니다.

1. sTTT (N-O) 프로토콜:

  • 제안하는 TTAC 방법은 모든 경쟁 방법들 대비 우수한 성능을 보였습니다.
  • 예를 들어, CIFAR10-C와 CIFAR100-C에서는 기존 최고 성능인 TTT++보다 각각 3% 향상되었으며, ImageNet-C에서는 BN과 TENT보다 5~13% 향상되었습니다.
  • TTAC은 평균 정확도에서 우수하며, ImageNet-C의 15가지 유형 중 9가지에서 SHOT보다 뛰어났습니다.
  • TTAC을 SHOT과 결합한 TTAC+SHOT은 ModelNet40-C 데이터셋에서 추가적인 성능 향상을 보였습니다.

2. 대체 프로토콜들:

  • Y-O 프로토콜: TTT++를 포함한 몇몇 방법들은 Y-M 프로토콜을 따르며, TTAC에 대조 학습 분기를 추가한 경우 CIFAR10-C와 CIFAR100-C에서 명확한 향상을 보였습니다.
  • N-M 프로토콜: BN, TENT, SHOT과 비교했을 때 TTAC는 모든 세 데이터셋에서 상당한 향상을 보였으며, SHOT과 결합 시 추가적인 향상을 달성했습니다.
  • Y-M 프로토콜: TTAC은 TTT-R과 TTT++보다 매우 강력한 성능을 보였으며, N-O 프로토콜에서도 Y-M 프로토콜의 TTT++와 유사한 결과를 보여 TTAC의 강력한 TTT 능력을 입증했습니다.

본 논문은 테스트 시점에 도메인 적응을 실시하기 위한 클러스터링 기술을 기반으로 하는 비지도 학습 기술을 제안했다는 것에 그럴싸한 흥미를 갖게합니다.

추가적으로 궁금한 부분은 논문을 참조 부탁드리며, 질문을 댓글로 남겨주시면 함께 고민할 좋은 기회가 될 것으로 생각합니다.

 

감사합니다.