본문 바로가기

논문리뷰

Test-Time Augmentation

리뷰할 논문은 Learning Loss for Test-Time Augmentation 입니다. 딥러닝을 모델링을 해오면서 학습 시에만 Augmentation을 적용해 왔는데, 테스트 할 때 Augmentation을 적용할 수도 있다는 내용을 접하게 되어 관련 논문을 리딩해보았고, 본 논문의 방법론을 완벽하게 이해하기 보다는 TTA에 대한 전반적인 이해를 하고자 했습니다.

 


Test-Time Augmentation 개념


  • Augmentation은 robust한 뉴럴 네트워크를 만들기 위해서 연구 되어옴
  • 대부분의 Augmentation은 학습단계에 적용하는 것
  • 그러나 TTA는 테스트 단계에서 Augmentation을 적용하는 것
  • TTA를 적용하면서 모델의 불확실성을 줄일 수 있음

https://towardsdatascience.com/test-time-augmentation-tta-and-how-to-perform-it-with-keras-4ac19b67fb4d

 

Test Time Augmentation (TTA) and how to perform it with Keras

Data Augmentation

towardsdatascience.com

 

 


Background & Proposal


  • 학습 시 Augmentation을 적용하여 robust 한 네트워크를 만들었음에도, 테스트 시 input image의 약간의 transformation 에도 성능이 저하되는 것을 관찰
  • 결국 테스트 단계에서의 성능개선의 여지가 있음
  • TTA 적용 시 기존에는 간단한 geometric transformation만 적용되어 왔음(horizontal/vertical flip, rotation)
  • 그러나 그것조차 성능을 검증하기 위해서는 데이터 수를 상당히 늘려야 하기 때문에 추론 시간에 대한 비용이 증가
  • 그래서 instance-aware test-time augmentation algorithm을 제안함
  • pre-trained 네트워크를 이용해서, input image에 따라서 동적으로 최적의 transformation 선택하는 것을 목표로 함
  • 이 방식은 효율적임. 왜냐하면 Loss 예측 모델은 컴팩트하고, 앙상블을 위한 Augmentation 수를 줄일 수가 있음

 


Concept


 

Figure 1.

 

  • (a)는 기존의 TTA 방식, (b) 제안된 방식
  • 기존의 방식은 transformation이 정해져 있다. 그러나 제안된 방식은 loss를 예측하고, loss가 낮은 transformation을 선택하기 때문에 input image에 따라서 서로 다른 transformation이 적용됨. 몇 개의 transformation을 선택할 지는 선택 가능하다.

 


 

Method


 

Figure 2.

  • Figure 2. 는 loss predictor의 학습과정을 나타냄
  • upper part는 relative loss 값을 얻는 과정이고, lower part는 loss 예측 네트워크임
  • input 이미지는 가능한 모든 transformation에 대해서 real loss 값을 얻기 위해서 target 네트워크로 평가되고 loss 값을 얻음
  • real loss 값을 softmax를 통해서 normalize 함
  • loss predictor는 relative loss값을 예측하기 위해서 multi-level feature를 집계한다.
  • loss predictor는 ranking loss로 학습 됨

 

Test-Time Augmenation Space

  • transformation 후보를 다양화하는 것은 학습에 대한 비용을 증가시킴. 왜냐하면 각 transformation에 대해서 ground truth loss 값을 계산해야 하기 때문
  • 반대로 transformation 후보를 제한하면, augmentation 효과가 줄어듬
  • 그래서 적당한 다양성을 제시함
  • geometric transforations(rotation, zooming), tow color transformation(color, contrast), image filter(sharpness)

 

Loss predictor Module: Architecture

  • predictor와 target 네트워크를 구분함. target 네트워크의 loss 값을 최종 ouput의 보조적인 수단으로써 활용할 수 있지만, 이 방식은 추가적인 inference 을 발생시키기 때문에 효율성이 떨어진다고 함. 그래서 target 네트워크를 작은 네트워크로 선택
  • predictorh로 예측한 output vector는 transformation 의 개수만큼의 size를 가짐
  • 논문작성 시점의 SOTA 모델인 Efficientnet-B0를 사용함
  • multi-level feature를 활용할 수 있도록 모델을 수정함. 왜냐하면 네트워크가 high-level의 feature만 학습하는 것이 아니라 low-level의 feature도 학습하길 원함

 

Loss Predictor Module: Training Method

  • target 네트워크는 train data 로 학습되고 validation data 로 평가 됨. (target 네트워크의 weight는 train data로 형성됨)
  • 그리고 target 네트워크를 freeze 하고, train data를 2가지로 나눔 (train-loss data 와 valid-loss data)
  • predictor(prediction module) 를 학습하기 위해서 각각 train, valid 로 활용됨
  • train-loss data는 loss prediction module에 사용 됨, valid-loss data는 평가용으로 사용됨
  • 이러한 방식이 근데 target 네트워크의 학습 데이터를 줄 일 수 있다고 하는데 이해가 잘 가지 않음
  • loss prediction module의 output을 생성하기 위해, 여러 transformation에 대해서 target 네트워크의 추론으로 얻은 ground truth loss 값을 모아 softmax 함수로 normalize 하여 relative loss 값을 얻는다. predicted loss 값도 마찬가지
  • 그리고 relative loss 간의 spaerman correlation을 직접 최적화 함[10]
  • 구체적으로는 relative loss 와 predicted loss 사이의 [공식3] 을 활용하여 상관관계를 근사화하는 recurrent 네트워크를 학습함
  • ranking loss와 relative loss를 활용하여 학습하는 것이 정확한 loss를 활용하는 것보다 학습이 더욱 안정적인 것을 관측함

 

'논문리뷰' 카테고리의 다른 글

Feature Pyramid Networks for Object Detection 리뷰  (0) 2023.01.08
Mask R-CNN 논문 리뷰 (1)  (1) 2023.01.01