이번 글에서는 CVPR 2021에 accept된 Pruning 논문 중 하나인 Manifold Regularized Dynamic Network Pruning 을 리뷰하도록 하겠습니다. 

 

먼저 Dynamic pruning에 대해 알아보겠습니다. 기존의 channel pruning 방식들은 channel을 static하게 제거하기 때문에 모든 sample에 같은 구조의 network를 사용합니다. 하지만 실제로는 filter/channel의 중요도는 input에 따라 매우 다릅니다. 아래 그림은 Dynamic pruning을 처음 제안한 Dynamic channel pruning: Feature boosting and suppression 에서 가져왔습니다.

Pretrained ResNet-18로 imagenet sample을 추론해 보면, 특정 채널(114, 181)의 아웃풋은 이미지에 따라 많은 차이를 보입니다. 또한 (c)에서 볼 수 있듯이 20개 channel의 maximum activation의 분포를 그려보면, 채널 간에 큰 편차가 있는 것을 알 수 있습니다. 

 

기존 Dynamic pruning 방법들은 개별 instance마다 channel을 독립적으로 pruning 했기 때문에 instance간의 관계를 고려하지 않고, 또한 instance의 복잡도(complexity)에 상관없이 모두 같은 sparsity constraint를 사용한다는 단점이 있습니다. 저자들은 이러한 문제점을 극복하기 위해 임의의 instance에 상응하는 network redundancy를 극대화할 수 있는 dynamic pruning 방법론인 ManiDP를 제안합니다. 

 

이 논문의 핵심 아이디어는 각 input image의 manifold 정보를 학습 과정에서 활용하여 image간의 관계가 각각에 대응하는 sub-network 간에도 유지되도록 하는 것입니다. 학습 과정에서 두 가지 관점, complexity와 similarity 측면에서 manifold 정보를 활용합니다.

 

Manifold Regularized Dynamic Pruning

Preliminaries

논문의 수식을 이해하기 위해 먼저 기존 channel pruning과 dynamic pruning의 formulation을 소개하겠습니다. 기존 channel pruning의 optimization objective는 다음과 같습니다.

첫 번째 term은 각 sample의 cross entropy loss이고 두 번째 term은 convolution filter간의 channel-sparsity를 만들기 위한 식입니다. $\lambda$값을 증가시킬수록 더 sparse한 network가 생성됩니다.

 

Dynamic pruning에서는 output channel의 중요도가 어떤 input이 들어오는지에 따라 달라지게 됩니다. 따라서 개별 input(sample/instance)마다 다른 sub-network를 가집니다. input $x_i$에 대해 $l$번째 layer의 상황을 가정하겠습니다. input feature를 $F^{l-1}(x_i)$라 하고 channel saliency를 control하기 위한 모듈은 $G^l$라고 notation합니다. 그러면 channel saliency $\pi(x_i, w)\in \mathbb{R}^{c^l}$을 $G^l(F^{l-1}(x_i))$라고 할 수 있습니다. 

따라서 dynamic channel pruning의 optimization objective는 다음과 같습니다.

Complexity

직관적으로 생각해보면 간단한 sample을 분류하기 위해서는 작은 network로 충분할 것이고, 복잡한 sample을 분류하기 위해서는 더욱 복잡한 network가 필요할 것입니다. 따라서 이 논문에서는 instance와 sub-network의 complexity를 각각 측정하고 이를 정렬하기 위한 adaptive objective function을 제안했습니다.

 

Instance의 complexity를 측정하기 위해서는 Task-specific (i.e., Cross entropy) loss를 사용합니다. CE loss가 클수록 그 instance를 잘 학습하지 못했다는 뜻이고, 따라서 복잡한 instance로 볼 수 있습니다.  Sub-network의 sparsity를 측정하기 위해서는 channel saliencies($\pi$)의 sparsity를 사용합니다. 즉 network의 channel 중요도를 나타내는 벡터가 sparse할수록 적은 channel이 존재하므로 간단한 network라고 볼 수 있습니다.

 

앞에서 설명했다시피 기존 연구들에서는 모든 instance에 대해 같은 weight coefficient, 즉 $\lambda$가 사용되었습니다. 하지만 여기서는 CE loss가 높으면 더 낮은 $\lambda$값을 사용하고, CE loss가 낮으면 더 높은 $\lambda$값을 사용함으로써 각 instance마다 adaptive하게 sparsity penalty를 적용합니다. 또한 $\beta = \{\beta\}_{i=1}^{n} \in \{0,1\}^N$이라는 binary learnable parameter를 사용하는데 이는 $i$번째 instance에 대응하는 sub-network가 pruning 되어야 하는지 여부를 의미합니다. 최종적으로 optimization objective는 다음과 같이 정의됩니다. $C$는 instance complexity threshold로서 사전에 정의됩니다. 실험적으로 이전 epoch에서 구한 전체 dataset에 대한 cross entropy 평균을 사용한다고 합니다.

따라서 sparsity loss의 coefficient $\lambda(x_i)$는 다음과 같이 정의됩니다.

$\lambda(x_i)$의 값은 항상 $[0,\lambda']$의 범위 안에 있고, 값이 커질수록 $x_i$가 간단한 instance라는 것을 의미합니다.

 

Similarity

ManiDP의 두번째 목표는 instance간의 similarity를 sub-network간에도 잘 유지하는 것입니다. 즉 두 instance가 비슷하면 상응하는 sub-network도 비슷해야 합니다. 이를 반영하는 loss term을 만들기 위해 먼저 instance간, sub-network간의 similarity를 측정합니다.

 

Instance similarity는 intermediate features $F^l(x_i)$간의 similarity로 정의됩니다. intermediate features가 input sample $x_i$의 특징을 잘 대표한다고 할 수 있기 때문입니다. $F^l(x_i)$를 average pooling으로 flatten한다음 cosine similarity를 측정하여 similarity matrix $R^l$을 만듭니다. 이 matrix는 $N \times N$의 shape을 가집니다.

한편 Sub-network similarity는 어떻게 측정할 수 있을까요? $\pi(x_i)$는 특정 sub-network의 channel saliency, 즉 sub-network의 architecture를 표현합니다. 따라서 $\pi(x_i)$간의 cosine similarity를 sub-network similarity로 사용합니다. 이 matrix는 instance similarity와 마찬가지로 $N \times N$의 shape을 가지고 식은 다음과 같습니다. 

최종적으로 추가되는 loss term은 instance similarity $R^l[i,j]$와 sub-network similarity$T^l[i,j]$간의 Frobenius norm으로 정의됩니다.

Final Objective

최종 학습 objective는 다음과 같습니다. 2번째는 instance와 sub-network간의 complexity를 정렬하기 위한 term이고 위에서 소개한 $\lambda(x_i)$에 의해 통제됩니다. 세번째는 similarity를 유지하기 위한 term이고 hyperparameter $\gamma$에 의해 통제됩니다. 

추론 시에는 $\pi^{l}(x_i)$의 각 element값이 threshold보다 큰 경우에만 channel을 남깁니다. 각 instance마다 다른 $\pi^{l}(x_i)$ 값을 학습하기 때문에 모두 다른 sub-network를 갖게 됩니다. 

 

Experiments

ImageNet으로 학습했을 때 ResNet과 MobileNet 모델 성능

ManiDP-A와 ManiDP-B는 각각 다른 pruning rates를 가지는 모델입니다. ManiDP는 다른 방법론들과 비슷한 수준으로 FLOPs를 줄이면서도 가장 좋은 성능을 보입니다. 

 

CIFAR-10으로 학습했을 때 ResNet 모델 성능

 

Effectiveness of Manifold Information

또한 complexity와 similarity를 고려하는 것이 얼마나 효과적인지를 알아보기 위해 ablation study를 수행했습니다. 두 경우에서 complexity와 similarity를 모두 고려하는 것이 가장 좋은 성능을 보였습니다.

 

Visualization

이 논문에서 의도한 대로 간단한 instance에 가벼운 sub-network가 대응되는지를 확인하기 위해 visualization을 수행했습니다. 

FLOPs가 적은 sub-network일수록 그에 대응되는 이미지가 간단한 것을 볼 수 있습니다. FLOPs가 높아질수록 object가 명확하지 않고 어려운 이미지를 판별하게 됩니다.

 

이상으로 Manifold Regularized Dynamic Network Pruning 논문 리뷰를 마치겠습니다. 궁금한 점은 댓글로 달아주세요. 감사합니다.

 

Woojeong Kim 

안녕하세요, KDST에서 인턴 연구원으로 근무했던 김우정입니다.

 

이번에 김수현 박사님, 박민철 인턴 연구원님, 전근석 학연님과 함께 제출했던 논문이 NeurIPS (NIPS) 2020에 poster presentation으로 accept 되었습니다. 논문 제목은 "Neuron Merging: Compensating for Pruned Neurons" 입니다.

 

기존 structured network pruning 논문들에서는 신경망에서 덜 중요한 뉴런이나 필터를 제거하고 accuracy를 회복하기 위하여 fine-tuning을 합니다. 하지만 이때 제거된 뉴런에 상응하는 다음 layer의 차원 또한 제거됩니다. 따라서 pruning 직후에 다음 layer의 output feature map이 원래 모델과 매우 달라지게 되어 accuracy loss가 발생합니다.

본 논문에서는 제거된 neuron의 효과를 보상할 수 있는 "Neuron Merging" 컨셉을 제안했습니다. 또한 Neuron Merging 컨셉 안에서의 한가지 방법론으로, 추가적인 training이나 data 없이 pruning 직후 accuracy를 향상시키는 알고리즘을 제안했습니다. 

 

더욱 자세한 내용은 Conference 후에 다른 게시글로 소개하겠습니다. 감사합니다.

 

"Network Slimming" Review

카테고리 없음 2019. 10. 24. 16:27 Posted by 랏츠베리

본 포스팅은 Deep Convolutional Neural Network의 running time 최적화를 위해 channel level pruning을 도입한 "Learning Efficient Convolutional Networks through Network Slimming" (ICCV 2017)를 리뷰하도록 하겠습니다. 포스팅에 앞서, 주제와 관련된 모든 연구 내용은 Learning Efficient Convolutional Networks through Network Slimming 참조했음을 먼저 밝힙니다.

History


  Convolutional Neural Networks (CNNs)가 다양한 Computer Vision Task 처리에 중요한 솔루션으로 도입이 된 이후, 실제 응용된 어플리케이션에서 활용하려다 보니 Practical Issue가 광범위하게 발생하였습니다. 이들 중 실제 Light-weight device에서는 CNNs의 Computation Issue가 핵심적으로 보고되었는데요. 이를 해결하기 위해 Depp CNN Architecture의 발전은 Reinforcement Learning을 활용하는 수준까지 도달하였지만, 실제 응용되기엔 한계가 있었습니다. 이를 위해 다양한 경량화 기법들 즉, Weight Pruning이나 Filter or Channel Pruning과 같은 Model Reduction 효과를 극대화시키는 기술들이 소개되었고, 본 포스팅은 이 범주에서 Channel Pruning을 활용한 Reduction 효과에 대한 부분에 해당합니다.

 

Filter or Channel Pruning Effect


  Pruning 기법은 모델 내부의 Node 간의 Weighted Sum으로 구성된 커넥션들 중에서 불필요한 커넥션을 제거하는 기술입니다. 이를 통해, 중요하고, 필요하다고 고려되는 커넥션의 구성으로 최소화 함으로써, 모델 축소를 극대화하는 효과를 볼 수 있습니다. 하지만, 커넥션을 구성하는 Weight Parameter들은 대체로 Matrix 형태의 정형화된 데이터 구조로 관리되는데, 상기 기술로 인해 메모리 최적화와 같은 부수적인 경량화 효과를 얻기 위해서는 Sparse Matrix를 변형하여 관리할 데이터 구조가 요구되었습니다. 이에 따라, Weight Pruning은 아직까지 일반적으로 사용하는 경량화 디바이스에서 효율적으로 Running하지 못한다는 지적이 있었고, 기존에 활요하는 데이터 구조를 충분히 유지하며 딥러닝 모델의 경량화를 극대화하는 방향의 연구가 논의되었습니다.

  Filter 혹은 Channel Pruning은 Deep CNNs 구조에서 Computation이 가장 Dominant한 Convolutional Layer의 경량화를 목표로 하여, 이 Layer에 존재하는 Intermediate Feature Neurons의 개수를 최대한 줄이는 거시적인 관점에서의 Pruning을 의미합니다. 이 Pruning의 효과로는 기존의 Structured Data Structure를 그대로 유지하며, 이들의 개수를 제거하는 형태이기 때문에, 일반적인 디바이스에서도 경량화 효과를 충분히 느낄 수 있습니다.

 

Network Slimming


  본 포스팅에서 언급하는 Network Slimming이라는 연구는 Intermediate Feature Neurons의 개수를 최적화하는 방법론을 제안한 논문이며, Architecture의 한계에 따라 사용 가능 여부가 결정되지만, 저의 주관으로는 최적화 파이프라인 구현을 위해 상당히 간단하고 Fancy한 접근이었다는 면에서 좋은 점수를 받지 않았나 생각합니다.

  Network Slimming은 1) Deep CNNs의 학습 방식을 Stochastic Gradient Descent (SGD)를 활용하고 있고 이를 통해 Batch 단위 학습을 하고 있다는 점, 2) Layer 마다 Batch의 정규화가 요구되며 Batch Normalization Layer (BN)을 대부분의 네트워크 모델이 채택하고 있다는 점, 두 가지 측면을 Motivation으로 삼아 BN의 Trainable Variable인 Scaling Factor를 활용하여 Layer의 결과물인 Output Channel에서 불필요한 Channel을 찾아 제거하는 아이디어를 제안하였습니다. 이는 Scaling Factor가 학습 중인 시점에 Sparse 하게 되도록 L1-norm으로 Regularization을 하였는데요. 이 기법으로 인해 충분히 Channel 개수를 최적화하여 모델을 "Slim"하게 만드는 데 성공하였습니다.

 

Network Slimming

 

L1 Regularization on Scaling Factor in BN

 

Network Slimming


  Network Slimming의 Pruning 파이프라인은 전체 모델 구조에 대하여 하나의 Pruning Ratio를 활용하여 Dynamic Layer-Wise Pruning으로 처리하였고, 이를 여러 횟수 반복하여 Fine Tuning하였습니다.

Slimming Pruning Pipeline

 


  서술한 연구 내용 외에 추가적인 연구 결과나 이해가 필요한 부분이 있으면 직접 링크된 논문을 읽어보시고 댓글을 남겨주시면 답변드리겠습니다. 제가 작성한 PPT 자료를 업로드 해드리니 필요하시면 참고해주세요.