안녕하세요. 오늘 소개드릴 논문은 AAAI'2021에 게재된 DegAug입니다. 논문 제목에서 알 수 있듯이 Feature representation을 나누고 semantic augmentation을 통해 Out-of-Distribution generalization을 달성한 논문입니다.


About Out-of-Distribution

  • 논문에 대한 내용에 앞서 IID와 OoD의 개념을 말씀드리겠습니다.
  • Independent, Identical Distribution (I.I.D)→예를 들면 CIFAR10의 train-set과 test-set이 나눠져 있지만 그 둘은 동일한 분포를 가지고 있습니다. 따라서 이 둘은 I.I.D입니다.
    • 어떤 랜덤 확률 변수 집합이 있을 때 각각의 랜덤 확률변수들은 독립적이면서 동일한 분포를 가지는 것을 의미합니다.
  • Out of Distribution (OoD)(ICML 2021, Stanford) WILDS: A Benchmark of in-the-Wild Distribution Shifts→ 당연하게도 실제 상황에선 OoD가 더 많이 존재합니다.
    • 예를 들어, Medical 응용의 경우 특정 병원에서 얻은 데이터로 학습한 뒤, 다른 병원에 배포되는 경우가 있습니다. 학습 때와 다른 분포를 테스트시 추론하게 되므로 이 경우 OoD입니다.
    • 학습 데이터의 분포를 따르지 않는 데이터를 따르지 않는 데이터를 이야기합니다.
  • 그렇다면 기존의 학습방법으로 OoD을 잘 추론할 수 있을까요? → 그렇지 않습니다!

    • 위 그림은 2019년 페이스북에서 나온 Invariant Risk Minimization 에 나온 예시입니다. 초록색 배경의 소🐮 와 모래색 배경의 낙타🐪 로 학습을 시킨 모델이 있다고 가정해봅시다.
    • 이 모델에 초록색 배경 낙타🐪를 추론시킬 경우 소🐮로 추론한다고 합니다.
    • 즉, 기존의 학습 방법은 데이터간의 가장 큰 공통 특징(배경색)을 가지고 학습 및 추론을 하기 때문에 학습 데이터와 다른 데이터(OoD)은 잘 추론하지 못합니다.
  • OoD는 언제 문제가 될까요?

(CVPR 2019, Facebook) Does Object Recognition Work for Everyone?

 

  • 위 그림은 페이스북에서 CVPR'2019에 발표한 논문에서 발췌한 그림입니다.
  • 왼쪽 사진은 상용 클라우드 플랫폼 (Azure, Clarifai, Google, Amazon)의 AI classification 모델에 서로 다른 국가에서 촬영한 비누를 추론시킨 결과입니다 네팔에서 촬영한 사진의 경우 모든 상용 솔루션이 틀리는 것을, 영국에서 촬영한 사진의 경우 대부분의 경우에서 맞는것을 알 수 있습니다.→ 상용 솔루션의 경우 대부분 미국에서 배포하고 있고, 미국을 포함한 수입이 높은 국가에서 수집한 데이터로 학습을 진행합니다. 따라서 데이터를 수집한 국가와 다를수록 즉 OoD의 경우 정확도가 낮아지는 것을 알 수 있습니다.
  • 오른쪽 사진은 한달 수입이 $x$ 인 국가에서 데이터를 수집한 뒤, 상용 솔루션 모델들에 추론시켰을 때 정확도를 나타낸 표입니다. 수입이 높은 국가에서 수집한 데이터인 경우 정확도가 높은 것을 알 수 있습니다.

DegAug


Category vs . Context

  • 이 논문에선 하나의 이미지가 category, context 2개의 정보를 담고 있다고 얘기하고 있습니다
    • context : 각 이미지가 가지고 있는 환경적인 특징을 의미합니다. 예를 들어 잔디밭 위에 양이 있다면 잔디밭이 context가 됩니다.
    • category : label을 의미합니다.
  • 따라서 하나의 이미지는 category-context의 쌍으로 이루어져 있습니다. 만약 train-set에 없는 category-context쌍을 기존의 학습 방법으로 학습한 모델의 추론시키면 잘 추론을 못하고, 이 연구는 이를 해결하고자 하는 논문입니다.

Proposed Method

Overall Architecture

  • 자 그럼 이 논문에서 OoD을 어떻게 해결했는지. 즉, OoD generalization을 어떻게 달성했는지 알아보도록 하겠습니다.

Overall Diagram of DegAug

  • 위 그림은 OoD 데이터를 학습하기 위한 이 논문의 구조를 도식화 한 그림입니다
    • input data를 backbone네트워크에 추론
    • 1에서 얻은 embedding vector를 category feature extractor와 context feature extractor에 각각 추론
    • context feature extractor에서 context label을 추론합니다.
    • category feature extractor에서 category label을 추론합니다.
    • 2에서 얻은 2개의 embedding vector를 합쳐(concat) 최종적으로 추론
  • 따라서 DegAug는 총 4개의 네트워크로 구성되어 있습니다. (backbone network, Category Feature Extractor, Context Feature Extractor, Final Classifier)


Method

  • 하나의 이미지에 대해 category label과 context label을 모두 가지고 있는 데이터셋이 있다고 가정해봅시다.
    • x는 이미지, y는 category(class) 레이블을 c는 context 레이블을 의미
    •  
  • D를 알고 있으므로 Cross Entropy Loss를 이용해 Context Feature Extractor를 학습할 수 있습니다. 
    • 이렇게 발생한 2개의 loss를 가지고 backbone 네트워크를 학습할 수 있습니다.

h_1 : category classifier h_2 : context classifier f_1 : category feature extractor f_2 : context feature extractor

 

  • category feature extractor와 context feature extract가 서로 다른 특성을 학습할 수 있도록 이 논문에선 𝑂𝑟𝑡ℎ𝑜𝑔𝑜𝑛𝑎𝑙 𝐿𝑜𝑠𝑠 를 제안합니다.
    • category loss의 gradient와 context loss의 gradien의 코사인 유사도를 loss로 쉬하고 있습니다.
    • gradient간의 코사인 유사도를 감소하는 방향으로 backbone 네트워크의 학습을 진행함으로써 backbone 네트워크가 category와 context를 모두 담은 embedding vector를 추출할 수 있도록 학습이 진행됩니다.

 

  • train-set에 없는 데이터를 잘 추론하기 위해 DecAug는 feature-level에서 augmentation을 진행합니다. (Semantic augmentation)

→ 위 수식을 보시면 context feature에 context loss gradient를 더해줌으로써 train-set에 없는 category-context 쌍의 학습을 진행할 수 있습니다.

 

  • 위의 방법을 통해 얻은 category feature와 context feature를 합쳐 최종적인 classification을 하게 됩니다.
  • 모든 loss function을 정리하면 다음과 같습니다.


Experiments

DegAug는 총 3가지 데이터셋에 대한 실험 결과를 보여주고 있습니다

  • 먼저 colored-MNIST에 대한 실험 결과 입니다.

  • ERM은 일반적인 학습 방법입니다. DegAug가 가장 높은 정확도를 보이는 것을 확인할 수 있습니다.

 

  • PACS 실험 결과입니다.
    • PACS는 Picture, Art, Cartoon, Scatch 4개의 도메인으로 이루어진 데이터셋으로 3개의 도메인을 train-set에, 나머지 하나의 도메인을 test-set으로 두고 정확도를 측정합니다.
  • NICO 실험 결과입니다.


Conclusion

Feature augmentation을 통해 OoD generalization을 달성한 DegAug에 대해 알아보았습니다. feature를 augmentation한다는 점에서 흥미로웠지만 OoD generalization을 달성하기 위해 context label이 필요하다는 점이 굉장히 크리티컬한 것 같아 아쉬움이 남습니다. self-supervised learning과 이 연구를 합치거나 StyleGAN과 같이 context를 변경해 DegAug를 진행한다면 더 재밌지 않을까.. 라는 생각을 논문을 읽으며 많이 했던 것 같습니다.