리뷰할 논문은 Learning Loss for Test-Time Augmentation 입니다. 딥러닝을 모델링을 해오면서 학습 시에만 Augmentation을 적용해 왔는데, 테스트 할 때 Augmentation을 적용할 수도 있다는 내용을 접하게 되어 관련 논문을 리딩해보았고, 본 논문의 방법론을 완벽하게 이해하기 보다는 TTA에 대한 전반적인 이해를 하고자 했습니다.
Test-Time Augmentation 개념
- Augmentation은 robust한 뉴럴 네트워크를 만들기 위해서 연구 되어옴
- 대부분의 Augmentation은 학습단계에 적용하는 것
- 그러나 TTA는 테스트 단계에서 Augmentation을 적용하는 것
- TTA를 적용하면서 모델의 불확실성을 줄일 수 있음
Test Time Augmentation (TTA) and how to perform it with Keras
Data Augmentation
towardsdatascience.com
Background & Proposal
- 학습 시 Augmentation을 적용하여 robust 한 네트워크를 만들었음에도, 테스트 시 input image의 약간의 transformation 에도 성능이 저하되는 것을 관찰
- 결국 테스트 단계에서의 성능개선의 여지가 있음
- TTA 적용 시 기존에는 간단한 geometric transformation만 적용되어 왔음(horizontal/vertical flip, rotation)
- 그러나 그것조차 성능을 검증하기 위해서는 데이터 수를 상당히 늘려야 하기 때문에 추론 시간에 대한 비용이 증가
- 그래서 instance-aware test-time augmentation algorithm을 제안함
- pre-trained 네트워크를 이용해서, input image에 따라서 동적으로 최적의 transformation 선택하는 것을 목표로 함
- 이 방식은 효율적임. 왜냐하면 Loss 예측 모델은 컴팩트하고, 앙상블을 위한 Augmentation 수를 줄일 수가 있음
Concept
- (a)는 기존의 TTA 방식, (b) 제안된 방식
- 기존의 방식은 transformation이 정해져 있다. 그러나 제안된 방식은 loss를 예측하고, loss가 낮은 transformation을 선택하기 때문에 input image에 따라서 서로 다른 transformation이 적용됨. 몇 개의 transformation을 선택할 지는 선택 가능하다.
Method
- Figure 2. 는 loss predictor의 학습과정을 나타냄
- upper part는 relative loss 값을 얻는 과정이고, lower part는 loss 예측 네트워크임
- input 이미지는 가능한 모든 transformation에 대해서 real loss 값을 얻기 위해서 target 네트워크로 평가되고 loss 값을 얻음
- real loss 값을 softmax를 통해서 normalize 함
- loss predictor는 relative loss값을 예측하기 위해서 multi-level feature를 집계한다.
- loss predictor는 ranking loss로 학습 됨
Test-Time Augmenation Space
- transformation 후보를 다양화하는 것은 학습에 대한 비용을 증가시킴. 왜냐하면 각 transformation에 대해서 ground truth loss 값을 계산해야 하기 때문
- 반대로 transformation 후보를 제한하면, augmentation 효과가 줄어듬
- 그래서 적당한 다양성을 제시함
- geometric transforations(rotation, zooming), tow color transformation(color, contrast), image filter(sharpness)
Loss predictor Module: Architecture
- predictor와 target 네트워크를 구분함. target 네트워크의 loss 값을 최종 ouput의 보조적인 수단으로써 활용할 수 있지만, 이 방식은 추가적인 inference 을 발생시키기 때문에 효율성이 떨어진다고 함. 그래서 target 네트워크를 작은 네트워크로 선택
- predictorh로 예측한 output vector는 transformation 의 개수만큼의 size를 가짐
- 논문작성 시점의 SOTA 모델인 Efficientnet-B0를 사용함
- multi-level feature를 활용할 수 있도록 모델을 수정함. 왜냐하면 네트워크가 high-level의 feature만 학습하는 것이 아니라 low-level의 feature도 학습하길 원함
Loss Predictor Module: Training Method
- target 네트워크는 train data 로 학습되고 validation data 로 평가 됨. (target 네트워크의 weight는 train data로 형성됨)
- 그리고 target 네트워크를 freeze 하고, train data를 2가지로 나눔 (train-loss data 와 valid-loss data)
- predictor(prediction module) 를 학습하기 위해서 각각 train, valid 로 활용됨
- train-loss data는 loss prediction module에 사용 됨, valid-loss data는 평가용으로 사용됨
- 이러한 방식이 근데 target 네트워크의 학습 데이터를 줄 일 수 있다고 하는데 이해가 잘 가지 않음
- loss prediction module의 output을 생성하기 위해, 여러 transformation에 대해서 target 네트워크의 추론으로 얻은 ground truth loss 값을 모아 softmax 함수로 normalize 하여 relative loss 값을 얻는다. predicted loss 값도 마찬가지
- 그리고 relative loss 간의 spaerman correlation을 직접 최적화 함[10]
- 구체적으로는 relative loss 와 predicted loss 사이의 [공식3] 을 활용하여 상관관계를 근사화하는 recurrent 네트워크를 학습함
- ranking loss와 relative loss를 활용하여 학습하는 것이 정확한 loss를 활용하는 것보다 학습이 더욱 안정적인 것을 관측함
'논문리뷰' 카테고리의 다른 글
Feature Pyramid Networks for Object Detection 리뷰 (0) | 2023.01.08 |
---|---|
Mask R-CNN 논문 리뷰 (1) (1) | 2023.01.01 |