이번 글에서는 CVPR 2021에 oral로 선정된 Representative Batch Normalization with Feature Calibration”이란 논문을 소개합니다. 본 논문은 기존의 Batch normalization의 문제를 분석하고 이를 보완하는 Representative Batch normalization을 제안합니다. 

 

Introduction

Batch Normalization의 효과를 감소시키는 요인

 Batch Normalization(BN)은 convolution 연산의 결과로 나온 feature들을 mini-batch의 통계 정보를 사용하여 normalize 된 분포로 제한하는 과정을 수행합니다. 이 과정은 학습의 어려움을 낮춰 CNN 모델의 성능을 가져왔습니다. 하지만, BN이 가정하는 "서로 다른 인스턴스 feature의 채널별 분포가 모두 같다"는 2가지 문제점이 존재합니다.

  1. Training 단계에서의 mini-batch statistics와 testing시 사용하는 running statistics 사이의 불일치.
    • training 단계에서 mini-batch 단위로 누적 계산되는 statistic은 추론 시에 사용할 running statistics에 완벽하게 적합하지 않습니다.
  2. testing dataset의 인스턴스가 항상 훈련 세트의 분포에 속하지는 않을 수 있음.

이러한 두 가지 불일치 문제는 BN의 효과를 감소하게 합니다.

 

기존의 연구들

몇몇 연구들에서는 이를 해결하기 위해 normalize 단계에서 mini-batch 단위의 statistic 정보 대신 instance 단위의 statistic 정보를 활용하였습니다. 하지만 이는 batch 정보의 부족으로 학습의 불안정을 발생시켜 많은 경우에서 기존의 BN보다 안 좋은 성능을 보였습니다. 또 다른 연구에서는 여러 개의 normalize 기법을 조합하거나 attention 메커니즘의 도입으로 해결하였으나 많은 overhead를 도입하여 실용성을 떨어트리게 됩니다. 따라서 본 논문은 적은 비용으로 feature의 normalize 과정에서 mini-batch 이점과 instance 별 representation을 향상시키는 방법을 제안합니다.

 

Statistic의 불일치로 인한 문제 

Batch Normalization(BN)은 feature standardization과 affine transformation 동작으로 이루어져 있습니다. 본 논문은 centering과 scaling 동작으로 이루어진 standardization에 집중하였으며, centering과 scaling 과정은 각각 mini-batch의 statistics(mean, variance)를 기반으로 feature들이 zero-mean property(평균 값이 0)와 unit-variance(분산 값이 1)를 가지도록 합니다. 이때 testing instance와 training 시에 계산된 running mean, variance의 불일치는 몇 가지 문제를 발생시킵니다.

  1. running mean의 불일치
    • running mean > instance mean: 유익한 representation의 손실.
    • running mean < instance mean: feature에 noise의 증가.
  2. running variance의 불일치
    • 특정 채널의 값의 scale을 너무 크게 혹은 작게 만들어 feature의 분포를 불안정하게 만듦.
    • (예를 들어, 특정 channel에서의 running variance의 불일치는 해당 채널의 scale을 다른 channel에 비해 크게 만들어 다음 convolution에서 연산 결과가 그 채널에 지배적(dominant)이게 만듭니다.)

제안하는 Representative Batch Normalization은 feature의 채널별로 3개의 weight를 추가하여 centering, scaling calibration 과정을 수행합니다. 이는 기존의 BN 보다 training loss와 test accuracy를 작게 만들며 기존 BN을 손쉽게 대체 할 수 있습니다.

BN과 RBN의 training loss/test error 그래프

Proposed Method

Revisiting Batch Normalization

제안하는 Representative Batch Normalization(RBN)에 앞서 기존의 Batch Normalization(BN)의 standardization, affine transformation 과정을 수식화하면 아래와 같습니다.

Centering : $ X_m =X-E(X) $

Scaling : $ X_s= \frac{X_m}{\sqrt{Var(X)+\epsilon}} $

Affine : $ Y = X_s\gamma + \beta $

 

$X\in \mathbb{R^{N\times C\times H\times W}}$는 input feature이며 $E(X), Var(X)$는 각각 running mean과 variance를 의미합니다. training 과정에서 running mean과 variance는 dataset으로 부터 mean과 variance를 누적하여 계산됩니다. 이 과정은 아래와 같습니다.

 

$\mu _B=\frac{1}{NHW}\sum_{n=1}^{N}\sum_{h=1}^{H}\sum_{w=1}^{W}X_{(n,c,h,w)}$

$\sigma^{2}_B=\frac{1}{NHW}\sum_{n=1}^{N}\sum_{h=1}^{H}\sum_{w=1}^{W}(X_{(n,c,h,w)}-\mu_B)^2$

 

$E(X)\Leftarrow mE(X)+(1-m)\mu_B$

$Var(X)\Leftarrow mVar(X)+(1-m)\sigma^2_B$

 

Representative Batch Normalization

  • Statistics for Calibration

제안하는 RBN에서 Batch Normalization(BN)의 running statistic을 보정하기 위해 사용한 instance-specific feature의 statistic은 channel 별 정보를 사용하였으며 이는 다음과 같습니다.

$\mu_c = \frac{1}{HW}\sum_{h=1}^{H} \sum_{w=1}^{W} X_{(n,c,h,w)}$

$\sigma ^2 _c = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W}(X_{(n,c,h,w)}-\mu _c)^2$

  • Centering Calibration

BN의 Centering 과정에서  feature 값 중 running mean보다 큰 부분은 유지되어 다음 레이어로 넘어가게 됩니다. 이때 잘못된 running mean은 feature의 유용한 정보를 지우거나 추가적인 noise를 추가하게 됩니다. 이를 해결하기 위해 centering calibration은 BN의 centering 과정 전에 feature의 추가적인 statistic($K_m$)을 채널 별 learnable weight($w_m$)을 통해 feature를 calibration 합니다. 수식으로 표현하면 다음과 같습니다. (논문의 실험에서 $K_m$은 $\mu_c$를 기본값으로 사용하였습니다.)

$X_{cm(n,c,h,w)}=X_{(n,c,h,w)} + w_m\odot K_m$ 

$\odot$: $w_m$과 $K_m$을 같은 shape로 만들고 dot product 수행.

 

Centering calibration이 적용된 X와 적용되지 않은 X의 centering operation 결과 차이를 수식화하면 아래와 같습니다. $w_m$ > 0 그리고 $K_m > E(X)$와 같은 조건에서 centering calibration은 feature의 유용한 정보를 살려주며 $w_m$ < 0 그리고 $K_m > E(X)$와 같은 조건에서는 feauture의 noise를 줄여주는 방식으로 동작하게 됩니다.

$X_{cal}-X_{no} = w_m\cdot (K_m-E(X))$

  • Scaling Calibration

유지할 특징을 조절하는 centering과 다르게 scaling 과정에서는 feature의 명암을 조절합니다. 이때 부정확한 running variance는 특정 channel의 강도를 크게 만들게 됩니다. 이를 해결하기 위해 scaling calibration은 scaling operation 과정 후에 instance statistic($K_s$)을 채널별 learnable weight($w_v,w_b$)를 사용하여 feature를 calibration합니다. 

$X_{cs(n,c,h,w)}=X_{s(n,c,h,w)}\cdot R(w_v\odot K_s+w_b)$

$R()$: out-of-distribution feature 억제용 restricted function($0<R()<1$)

 

Scaling calibration 후 feature의 variance는 아래와 같이 표현됩니다.

$Var(X_{cs})=Var(X\cdot R(w_v\cdot K_s+w_b))$

 

Scaling calibration은 feature의 variance를 작게 만들고 이를 통해 채널 간의 안정적인 distribution을 유지하게 합니다. $w_v$가 작을수록 $Var(X_{cs})$는 작아지며 $w_b$를 통해 restricted function의 제한 위치를 조정하게 됩니다.

 

Centering/Scaling Calibration을 통한 Feature 변화

Experiments

다른 방법들과의 비교 / scaling calibration 전후의 채널별 표준편차

제안하는 RBN은 분류에서의 다른 방식들에 비해 더 나은 성능을 보였으며 scalling calibration이 적용되었을 때 채널간의 분포를 보다 안정적이게 만들어주는 것을 확인 할수 있습니다.

Centering/Scaling Calibration 효과 시각화

Conclusion

Representation Batch Normalization(RBN)은 기존 BN의 장점인 Batch statistic 정보를 유지하며 적은 추가적인 파라미터를 통해 효과적으로 instance의 statistic을 적용합니다. centering calibration은 feature의 유용한 representation을 강화하고 noise를 감소시킵니다. Scaling calibration은 feature의 강도를 제한함으로써 채널간의 안정적인 feature 분포를 형성하게 합니다. RBN은 BN의 centering/scaling operation을 사용하는 기존의 방식들에도 쉽게 적용 가능하다는 장점 또한 가집니다.

 

이상으로 논문 리뷰를 마칩니다.

자세한 내용이나 궁금하신점은 full paper를 참고하시고나 댓글로 남겨주세요.

 

감사합니다.

Dongjin Kim