본 게시글에서는 얼마전 진행되었던 2020 CVPR에서 인상깊었던 논문인 "Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion"을 소개해드리고자 합니다. Training dataset에 대한 사전 정보 없이 network inversion 방식을 통해 noise 이미지에서 해당 클래스에 속하는 이미지를 생성하는 방법론에 대한 논문입니다. 논문 링크는 이곳을 클릭해주세요.
논문 포스팅에 앞서 Neural network inversion 방법론에 대해 간략히 소개합니다.
기존에 모델을 학습시켰던 training dataset에 대한 정보가 없을때, training dataset 이미지의 input space를 복원하는 방법론 중에는 Neural network inversion 방식이 있습니다. Neural network inversion은 기존에 pre-trained 된 모델의 weight를 고정시킨채로 noise한 이미지를 모델에 포워딩시키고, back propagation을 수행하며 weight가 아닌 node의 output 값을 update하는 방식입니다.
이러한 Neural network inversion방식은 크게 2가지 카테고리로 분류할 수 있습니다.
Optimization based inversion
loss 값을 back propagation하며 gradient를 기반으로 optimization기법을 수행하여 실제 이미지와 근사한 input space를 예측하는 방식
Training based inversion
image를 생성하는 sub network가 존재하며 sub network를 학습시키며 input space를 예측하는 방식
ex ) Generative Adversarial Network와 같이 image generator를 활용하는 경우
해당 논문은 첫번째 카테고리인 Optimization based inversion 방식으로, sub network를 활용하지 않고 optimization 기법을 통해 synthesized image를 생성해 내는 방식입니다. 본 논문에서는 기존의 DeepDream 방법론에서 사용하였던 loss function에 Regularization term을 추가하여 복원된 이미지를 실제 이미지와 비슷한 분포로 생성해 냈습니다.
DeepDream은 Optimization based inversion 기법중 하나로, 아래와 같은 loss function을 통해 noise에서 synthesized image로 optimize하는 방식입니다.
loss function의 첫번째 term은 실제 이미지의 target label인 y와 inversion 과정에서 생성된 synthesized image인 x hat과 의 classification loss 입니다. 두번째 term은 image의 regularization term으로 생성된 이미지의 total variance (R_tv)와 l2 norm (R_l2)로 구성되어있습니다. 이러한 regularization term은 실제 이미지에 생성된 이미지가 수렴할수 있도록 도움을 주는 term입니다.
본 논문에서 소개하는 DeepInversion은 DeepDream에서 사용하였던 loss function을 확장하여 synthesized image가 실제 이미지와 더 유사해지도록 만들었습니다.
DeepInversion은 실제 training dataset 이미지와 synthesized image를 CNN에 포워딩 시켰을 때 각 레이어의 feature map 거리를 최소화 시켜 원래 이미지의 데이터 분포를 따라가게 하는 term을 추가 했습니다. feature 통계정보로서 feature map에서의 평균과 분산을 활용하였고 이러한 정보는 생성된 이미지가 실제 이미지의 분포를 따를 수 있도록 정규화 합니다. 따라서 synthesized image가 실제 training dataset batch단위의 평균과 분산을 따라가게 하는 R_feature term은 아래와 같습니다.
실제 training dataset의 배치마다의 평균과 분산은 CNN 아키텍처에서 널리 사용되는 BatchNorm (BN) layer에 저장되어있습니다. 따라서 실제 training dataset 배치 x 평균의 기대값과 분산의 기대값은 아래와 같이 BN layer에 저장된 정보로 치환가능합니다.
따라서 DeepInversion의 regularization term은 R_prior와 R_feature를 추가함으로써 아래와 같이 정의할 수 있습니다.
추가로, 논문에서는 DeepInversion의 성능을 개선시키기 위해 Adaptive DeepInversion (ADI) 방법론을 소개하고 있습니다. DeepInversion 방식으로 생성된 이미지의 다양성을 위해 정규화 term을 확장시켰습니다. 본 논문에서 주장하는 이미지의 다양성이란, 동일 클래스에 속하는 이미지더라도 다양성을 추구하는 이미지를 생성하는 것을 의미합니다.
논문에서는 Jensen-Shannon divergence를 활용한 정규화 term을 추가하여 synthesized image가 student 모델과 teacher 모델의 output의 분포의 불일치를 유도합니다. Jensen-Shannon divergence는 KL divergence의 평균으로, 두 output간의 분포를 거리로 치환하여 거리를 최소화 시킴으로써 두 분포를 일치시킵니다. 하지만 본 논문에서는 이미지의 다양성을 추구해야 하므로 두 분포의 거리를 최대화 시키기 위해 Jensen-Shannon divergence 값을 1에서 뺍니다.
따라서 Adaptive DeepInversion의 regularization term은 아래와 같이 정의할 수 있습니다.
아래의 그림은 위에서 설명한 loss function을 활용하여 random noise를 신뢰도 있는 synthesized image로 최적화하는 플로우 입니다. loss 값을 backpropagation 하며 원래 training dataset에 대한 정보 없이 input 이미지를 업데이트하며 synthesized image를 생성하게 됩니다. 향후 생성된 이미지로 Knowledge distillation 기법을 통해 student network를 학습시키며 높은 accuracy를 달성합니다. (논문의 4.4 section 참고)
본 논문에서는 synthesized image가 얼마나 실제 training dataset의 분포와 비슷한지를 보여주고 있으며 다양한 application에 Data-free 기법으로 접근하여 높은 성능에 달성한 결과를 보여주고 있습니다. Pruning, Knowledge distillation, Continual learning 등에서 state-of-the-art 성능을 달성하며 자세한 실험 결과는 논문을 참조해주세요.
<그림2: Microsoft Azure IoT platform heavy vs light edge concept from azure.microsoft.com>
아파치스트리밍플랫폼들
끝으로, 위에서 언급한 아파치 플랫폼 들에 대해 간략히 정리하겠습니다. 먼저, Hadoop 및Spark에서의Map-reduce 모델은다량의데이터를노드에나누어적용할오퍼레이션을맵핑한다음, 처리된데이터를다시key, value 쌍으로스토리지에저장하는모델입니다. 네트워크및I/O 오버헤드를고려하면, 데이터가일정이상크기여야오버헤드를상쇄하고이득을볼수있습니다. 따라서, 데이터의batch를가지고처리합니다.
하지만, 실시간처리에서는단일이벤트처리관점에서의지연시간을봤을때적합하지않을수있습니다. 데이터Batch를가지고최적화된모델은throughput, 즉, 총처리량에최적화된모델이고, 실시간처리에서는이벤트당처리시간latency에최적화가이루어져야합니다.
실시간 처리에 최적화되어 가장 먼저 주목을 받은 것은 Apache Storm입니다. 하지만, 기존의 배치 프로세싱과 너무 다른 프로그래밍 인터페이스, 그리고 빅데이터의 가장 중요한 분야인 배치 프로세싱이 처음에 지원되지 않았습니다. 따라서, 다양한 경쟁 제품 들이 나오게 됩니다. Apache Flink는 기존의 Map-Reduce프레임 워크 들과 유사한 프로그래밍 인터페이스, 그리고 batch job을 통해 어느 정도 배치 프로세싱도 지원합니다.
실시간 스트리밍 처리의 인기에 힘입어, 가장 인기있는 빅데이터 플랫폼 중에 하나인 Apache Spark에도 스트리밍 프로세싱 기능이 추가되게 됩니다. 개념적으로 batch를 쪼갠 mini batch를 만들어서 이벤트당 지연시간을 최적화하는 개념으로 구현되었습니다. 그리고, Apache Edgent처럼 엣지 디바이스에 최적화된 스트리밍 플랫폼도 등장합니다. 가볍게 만들기위해, 단일 노드안에서 실행하는 경우에 필요치 않은 여러가지 부가 매커니즘들을 없애서 가볍게 만들고, 원격으로 재설정 및 기기 생존을 체크하는 heart beat 등의 기능과 함께 구현되어 있습니다.
Knowledge distillation은 기계 학습 모델의 인상적인 성능 향상을 달성한 방법 중 하나입니다. 2014년 Hinton 교수 (NIPS 2014 Workshop)가 재해석한 distillation에 의하면 representation capacity가 큰 teacher model을 smaller capacity를 가진 student model이 aligning 할 수 있도록 하면, student model이 vanila process보다 qualifed performance를 달성할 수 있다고 하였습니다. 이를 바탕으로, 매우 다양한 형태의 distillation 방법들이 소개되었는데요. 본 연구는 지금까지 소개된 distillation의 방법론이 3가지 측면에서 문제가 있다고 지적합니다. 첫 째, longer training process, 둘 째, 상당한 extra computational cost and memory usage, 셋 째, complex multi-phase training procedure. 따라서, 저자는 distillation procedure의 simplification이 필요하다고 말하며, 2018년 Hinton 교수 (ICLR 2018)가 언급한 online distillation을 motivation 삼아 one-phase distillation training을 제안하였습니다.
Summary
본 연구에서 언급하는 simplification의 주된 영감은 multi-branch neural network라고 말합니다. Multi-branch neural network는 아키텍쳐 관점에서 대표적으로 ResNet이 있으며 이 모델은 feedforward propagation에서 identity를 고려하는 대표적인 two-branch network인 반면, Convolution 연산 관점에서는 group convolution을 예로 생각할 수 있습니다. group convolution의 경우 depthwise convolution과 같이 convolution을 특정 chnannel만큼 분할하여 수행하는 연산으로 ShuffleNet이나 Xception 등에 활용되었습니다. 이러한 Branch 구조는 최근 deep neural network가 더욱 깊어짐에 따라, 비례적으로 증가하지 못한 capacity 확장성 문제를 극복하기 위해 사용되었고 매우 성공적인 연구 결과를 도출하였습니다.
이에 따라 본 논문은 성공적으로 모델 성능 효과를 달성했던 knowledge distillation을 multi-branch 컨셉에 접목하여 resource efficient distillation 알고리즘을 제안하였습니다. 아래의 그림은 제안하는 알고리즘의 overview인데요. Teacher model을 사용하던 과거의 컨셉과는 달리, network의 last level block에서 동일한 block을 병렬적으로 추가하여 ensemble 구조를 구현하였습니다. 이를 논문에서는 peer 혹은 auxiliary branch라고 표현하였으며, 이들의 logits 결과를 적절히 combination을 하면, teacher의 representation을 approximation 할 수 있다고 말합니다. 더욱 자세한 직관을 위해 그림을 설명드리면, 각 branch block은 동일한 아키텍쳐의 low-level layer를 공유하고 branch block마다 별도의 cross-entropy를 수행하도록 하였으며, 이 때 활용된 각각의 logits을 하나의 gate layer로부터 가중치를 얻어 summation을 하면 ensemble logits을 확보할 수 있는데, 이를 teacher logits으로 estimation 하고 ensemble cross-entropy를 처리하였습니다.
본 논문에서 제안하는 Total loss는 아래처럼 총 3가지 텀으로 정의할 수 있습니다. (1) 각각의 branch에서 발생하는 cross-entropy의 합 (2) teacher의 것으로 estimation된 cross-entropy (3) teacher의 logits과 각 branch의 logits과의 KLLoss (Hinton KD). 본 논문에서 Hinton KD를 사용하기 위해 활용된 T는 3으로 하였으며 branch 개수는 [1, 2, 3, 4, 5]에 대해 각각 결과를 보여주었습니다.
아래의 그림은 전체 알고리즘 동작에 관한 설명이며, Eq (1), Eq (3), Eq (4), Eq (6), Eq (7)에 관한 loss term은 논문에서 직접 확인 가능합니다. 본 논문의 data scale은 CIFAR10, CIFAR100, SVHN, ImageNet에서 하였으며, 모델은 ResNet에서 하였습니다.