본 포스팅은 NIPS2018에 발표된 Knowledge distllation 연구 결과를 리뷰하도록 하겠습니다. 포스팅에 앞서, 주제와 관련된 모든 연구 내용은 Knowledge Distillation by On-the-Fly Native Ensemble 참조했음을 먼저 밝힙니다.

 

Preliminary


Knowledge distillation은 기계 학습 모델의 인상적인 성능 향상을 달성한 방법 중 하나입니다. 2014년 Hinton 교수 (NIPS 2014 Workshop)가 재해석한 distillation에 의하면 representation capacity가 큰 teacher model을 smaller capacity를 가진 student model이 aligning 할 수 있도록 하면, student model이 vanila process보다 qualifed performance를 달성할 수 있다고 하였습니다. 이를 바탕으로, 매우 다양한 형태의 distillation 방법들이 소개되었는데요. 본 연구는 지금까지 소개된 distillation의 방법론이 3가지 측면에서 문제가 있다고 지적합니다. 첫 째, longer training process, 둘 째, 상당한 extra computational cost and memory usage, 셋 째, complex multi-phase training procedure. 따라서, 저자는 distillation procedure의 simplification이 필요하다고 말하며, 2018년 Hinton 교수 (ICLR 2018)가 언급한 online distillation을 motivation 삼아 one-phase distillation training을 제안하였습니다.

 

Summary


본 연구에서 언급하는 simplification의 주된 영감은 multi-branch neural network라고 말합니다. Multi-branch neural network는 아키텍쳐 관점에서 대표적으로 ResNet이 있으며 이 모델은 feedforward propagation에서 identity를 고려하는 대표적인 two-branch network인 반면, Convolution 연산 관점에서는 group convolution을 예로 생각할 수 있습니다. group convolution의 경우 depthwise convolution과 같이 convolution을 특정 chnannel만큼 분할하여 수행하는 연산으로 ShuffleNet이나 Xception 등에 활용되었습니다. 이러한 Branch 구조는 최근 deep neural network가 더욱 깊어짐에 따라, 비례적으로 증가하지 못한 capacity 확장성 문제를 극복하기 위해 사용되었고 매우 성공적인 연구 결과를 도출하였습니다.

  이에 따라 본 논문은 성공적으로 모델 성능 효과를 달성했던 knowledge distillation을 multi-branch 컨셉에 접목하여 resource efficient distillation 알고리즘을 제안하였습니다. 아래의 그림은 제안하는 알고리즘의 overview인데요. Teacher model을 사용하던 과거의 컨셉과는 달리, network의 last level block에서 동일한 block을 병렬적으로 추가하여 ensemble 구조를 구현하였습니다. 이를 논문에서는 peer 혹은 auxiliary branch라고 표현하였으며, 이들의 logits 결과를 적절히 combination을 하면, teacher의 representation을 approximation 할 수 있다고 말합니다. 더욱 자세한 직관을 위해 그림을 설명드리면, 각 branch block은 동일한 아키텍쳐의 low-level layer를 공유하고 branch block마다 별도의 cross-entropy를 수행하도록 하였으며, 이 때 활용된 각각의 logits을 하나의 gate layer로부터 가중치를 얻어 summation을 하면 ensemble logits을 확보할 수 있는데, 이를 teacher logits으로 estimation 하고 ensemble cross-entropy를 처리하였습니다.

Overview of online distillation by the proposed On-the-fly native Ensemble (ONE)

 

본 논문에서 제안하는 Total loss는 아래처럼 총 3가지 텀으로 정의할 수 있습니다. (1) 각각의 branch에서 발생하는 cross-entropy의 합 (2) teacher의 것으로 estimation된 cross-entropy (3) teacher의 logits과 각 branch의 logits과의 KLLoss (Hinton KD). 본 논문에서 Hinton KD를 사용하기 위해 활용된 T는 3으로 하였으며 branch 개수는 [1, 2, 3, 4, 5]에 대해 각각 결과를 보여주었습니다.

Overall loss function by ONE

 

아래의 그림은 전체 알고리즘 동작에 관한 설명이며, Eq (1), Eq (3), Eq (4), Eq (6), Eq (7)에 관한 loss term은 논문에서 직접 확인 가능합니다. 본 논문의 data scale은 CIFAR10, CIFAR100, SVHN, ImageNet에서 하였으며, 모델은 ResNet에서 하였습니다. 

ONE Algorithm


서술한 내용 외에 본 논문에 대한 궁금한 점은 댓글로 남겨주세요.

 

감사합니다.

Mincheol Park