이번 글에서는 CVPR 2021에 accept된 Pruning 논문 중 하나인 Manifold Regularized Dynamic Network Pruning 을 리뷰하도록 하겠습니다. 

 

먼저 Dynamic pruning에 대해 알아보겠습니다. 기존의 channel pruning 방식들은 channel을 static하게 제거하기 때문에 모든 sample에 같은 구조의 network를 사용합니다. 하지만 실제로는 filter/channel의 중요도는 input에 따라 매우 다릅니다. 아래 그림은 Dynamic pruning을 처음 제안한 Dynamic channel pruning: Feature boosting and suppression 에서 가져왔습니다.

Pretrained ResNet-18로 imagenet sample을 추론해 보면, 특정 채널(114, 181)의 아웃풋은 이미지에 따라 많은 차이를 보입니다. 또한 (c)에서 볼 수 있듯이 20개 channel의 maximum activation의 분포를 그려보면, 채널 간에 큰 편차가 있는 것을 알 수 있습니다. 

 

기존 Dynamic pruning 방법들은 개별 instance마다 channel을 독립적으로 pruning 했기 때문에 instance간의 관계를 고려하지 않고, 또한 instance의 복잡도(complexity)에 상관없이 모두 같은 sparsity constraint를 사용한다는 단점이 있습니다. 저자들은 이러한 문제점을 극복하기 위해 임의의 instance에 상응하는 network redundancy를 극대화할 수 있는 dynamic pruning 방법론인 ManiDP를 제안합니다. 

 

이 논문의 핵심 아이디어는 각 input image의 manifold 정보를 학습 과정에서 활용하여 image간의 관계가 각각에 대응하는 sub-network 간에도 유지되도록 하는 것입니다. 학습 과정에서 두 가지 관점, complexity와 similarity 측면에서 manifold 정보를 활용합니다.

 

Manifold Regularized Dynamic Pruning

Preliminaries

논문의 수식을 이해하기 위해 먼저 기존 channel pruning과 dynamic pruning의 formulation을 소개하겠습니다. 기존 channel pruning의 optimization objective는 다음과 같습니다.

첫 번째 term은 각 sample의 cross entropy loss이고 두 번째 term은 convolution filter간의 channel-sparsity를 만들기 위한 식입니다. $\lambda$값을 증가시킬수록 더 sparse한 network가 생성됩니다.

 

Dynamic pruning에서는 output channel의 중요도가 어떤 input이 들어오는지에 따라 달라지게 됩니다. 따라서 개별 input(sample/instance)마다 다른 sub-network를 가집니다. input $x_i$에 대해 $l$번째 layer의 상황을 가정하겠습니다. input feature를 $F^{l-1}(x_i)$라 하고 channel saliency를 control하기 위한 모듈은 $G^l$라고 notation합니다. 그러면 channel saliency $\pi(x_i, w)\in \mathbb{R}^{c^l}$을 $G^l(F^{l-1}(x_i))$라고 할 수 있습니다. 

따라서 dynamic channel pruning의 optimization objective는 다음과 같습니다.

Complexity

직관적으로 생각해보면 간단한 sample을 분류하기 위해서는 작은 network로 충분할 것이고, 복잡한 sample을 분류하기 위해서는 더욱 복잡한 network가 필요할 것입니다. 따라서 이 논문에서는 instance와 sub-network의 complexity를 각각 측정하고 이를 정렬하기 위한 adaptive objective function을 제안했습니다.

 

Instance의 complexity를 측정하기 위해서는 Task-specific (i.e., Cross entropy) loss를 사용합니다. CE loss가 클수록 그 instance를 잘 학습하지 못했다는 뜻이고, 따라서 복잡한 instance로 볼 수 있습니다.  Sub-network의 sparsity를 측정하기 위해서는 channel saliencies($\pi$)의 sparsity를 사용합니다. 즉 network의 channel 중요도를 나타내는 벡터가 sparse할수록 적은 channel이 존재하므로 간단한 network라고 볼 수 있습니다.

 

앞에서 설명했다시피 기존 연구들에서는 모든 instance에 대해 같은 weight coefficient, 즉 $\lambda$가 사용되었습니다. 하지만 여기서는 CE loss가 높으면 더 낮은 $\lambda$값을 사용하고, CE loss가 낮으면 더 높은 $\lambda$값을 사용함으로써 각 instance마다 adaptive하게 sparsity penalty를 적용합니다. 또한 $\beta = \{\beta\}_{i=1}^{n} \in \{0,1\}^N$이라는 binary learnable parameter를 사용하는데 이는 $i$번째 instance에 대응하는 sub-network가 pruning 되어야 하는지 여부를 의미합니다. 최종적으로 optimization objective는 다음과 같이 정의됩니다. $C$는 instance complexity threshold로서 사전에 정의됩니다. 실험적으로 이전 epoch에서 구한 전체 dataset에 대한 cross entropy 평균을 사용한다고 합니다.

따라서 sparsity loss의 coefficient $\lambda(x_i)$는 다음과 같이 정의됩니다.

$\lambda(x_i)$의 값은 항상 $[0,\lambda']$의 범위 안에 있고, 값이 커질수록 $x_i$가 간단한 instance라는 것을 의미합니다.

 

Similarity

ManiDP의 두번째 목표는 instance간의 similarity를 sub-network간에도 잘 유지하는 것입니다. 즉 두 instance가 비슷하면 상응하는 sub-network도 비슷해야 합니다. 이를 반영하는 loss term을 만들기 위해 먼저 instance간, sub-network간의 similarity를 측정합니다.

 

Instance similarity는 intermediate features $F^l(x_i)$간의 similarity로 정의됩니다. intermediate features가 input sample $x_i$의 특징을 잘 대표한다고 할 수 있기 때문입니다. $F^l(x_i)$를 average pooling으로 flatten한다음 cosine similarity를 측정하여 similarity matrix $R^l$을 만듭니다. 이 matrix는 $N \times N$의 shape을 가집니다.

한편 Sub-network similarity는 어떻게 측정할 수 있을까요? $\pi(x_i)$는 특정 sub-network의 channel saliency, 즉 sub-network의 architecture를 표현합니다. 따라서 $\pi(x_i)$간의 cosine similarity를 sub-network similarity로 사용합니다. 이 matrix는 instance similarity와 마찬가지로 $N \times N$의 shape을 가지고 식은 다음과 같습니다. 

최종적으로 추가되는 loss term은 instance similarity $R^l[i,j]$와 sub-network similarity$T^l[i,j]$간의 Frobenius norm으로 정의됩니다.

Final Objective

최종 학습 objective는 다음과 같습니다. 2번째는 instance와 sub-network간의 complexity를 정렬하기 위한 term이고 위에서 소개한 $\lambda(x_i)$에 의해 통제됩니다. 세번째는 similarity를 유지하기 위한 term이고 hyperparameter $\gamma$에 의해 통제됩니다. 

추론 시에는 $\pi^{l}(x_i)$의 각 element값이 threshold보다 큰 경우에만 channel을 남깁니다. 각 instance마다 다른 $\pi^{l}(x_i)$ 값을 학습하기 때문에 모두 다른 sub-network를 갖게 됩니다. 

 

Experiments

ImageNet으로 학습했을 때 ResNet과 MobileNet 모델 성능

ManiDP-A와 ManiDP-B는 각각 다른 pruning rates를 가지는 모델입니다. ManiDP는 다른 방법론들과 비슷한 수준으로 FLOPs를 줄이면서도 가장 좋은 성능을 보입니다. 

 

CIFAR-10으로 학습했을 때 ResNet 모델 성능

 

Effectiveness of Manifold Information

또한 complexity와 similarity를 고려하는 것이 얼마나 효과적인지를 알아보기 위해 ablation study를 수행했습니다. 두 경우에서 complexity와 similarity를 모두 고려하는 것이 가장 좋은 성능을 보였습니다.

 

Visualization

이 논문에서 의도한 대로 간단한 instance에 가벼운 sub-network가 대응되는지를 확인하기 위해 visualization을 수행했습니다. 

FLOPs가 적은 sub-network일수록 그에 대응되는 이미지가 간단한 것을 볼 수 있습니다. FLOPs가 높아질수록 object가 명확하지 않고 어려운 이미지를 판별하게 됩니다.

 

이상으로 Manifold Regularized Dynamic Network Pruning 논문 리뷰를 마치겠습니다. 궁금한 점은 댓글로 달아주세요. 감사합니다.

 

Woojeong Kim