-
Pseudo Labeling, TTA(Test Time Augmentation) 기법CV 2020. 12. 13. 21:17
안녕하세요. 요즘 이미지 관련 대회를 참여하면서 굉장히 많은 난관을 겪고 있습니다...
A를 공부하면 A'을 알아야하고, A'을 배우기 위해서는 A''을 알아야하고 무한의 굴레인 것 같습니다. 하지만 이와 같은 과정을 계속 반복하다보면 언젠가는 제 것이 될 것이라고 생각합니다 (화이팅!)
이번 글은 image, Tabular Data, voice, text 등의 다양한 도메인 데이터에서 활용할 수 있는 예측 성능을 높이기 위한 방법론인 Pseudo labeling, TTA 기법에 대해서 다뤄보겠습니다.
해당 방법은 보통은 성능을 쥐어짜내는? 부분에서 굉장한 효과를 나타내기도 하지만 특정 대회에서는 Data leakage로 인해서 사용을 아예 못할 수도 있습니다. 즉, 해당 방법론은 Test dataset을 활용하기 때문에 Data leakage문제가 발생하게 되는건데 명확하게 쓰면 안된다라는 정의가 좀 다를수도 있어 애매하기도 합니다.
1. Pseudo labeling
저희가 알고 있는 학습방법은 크게 아래의 세가지가 있습니다
- 지도학습(Supervised learning): 정답레이블을 가지고 진행하는 학습
- 비지도학습(Unsupervised learning): 정답레이블을 없는 데이터로 진행하는 학습
- 준지도학습(semi-supervised learning): 정답레이블을 알고 있는 데이터로(크기가 작음) 1차 학습 진행 후 정답레이블이 없는 즉, 여기서는 테스트 데이터로 2차 학습을 진행하는 학습
이 세가지 학습 방법중에 저희는 준지도학습인 Pseudo labeling를 알아보겠습니다.
먼저 큰 개념으로는 기존에 정답라벨이 있는 데이터에 정답라벨이 없는 데이터를 학습시켜 도출된 결과를 접목시켜서 기존 가지고 있는 정답 Label 데이터에 기반하여 확률적인(대략적인) 정답 라벨을 부여하는 기법입니다.
순서는 다음과 같은데요
- 정답라벨이 있는 데이터로 모델 학습 진행
- 1번에서 학습한 모델을 확용하여 라벨이 없는 데이터(Test data)를 예측하고 그 결과를 라벨로 사용하는 Pseudo Labeled data 생성
- 2번에서 만든 Pseudo-labeled data와 기존에 정답라벨이 있는 데이터를 모두 사용하여 다시 모델을 학습
여기까지 확인했을 때 굉장히 무언가 비슷한게 있는데 Stacking과 방법론적으로 유사한 것을 확인할 수 있습니다. 하지만 Stacking은 특징적으로 예측한 값을 변수로 그대로 이용하는 반면에 Pseudo labeling은 정답라벨이 없는 데이터로 예측하여 도출한 결과와 정답라벨이 없는 데이터를 가지고 새롭게 Pseudo labeling dataset을 만들어서 기존의 모델에 결합하여 이용하는 부분에서 다른 것을 확인할 수 있습니다.
위에서 글로 설명한 Pseudo labeling의 전체적인 flow를 확인해보겠습니다.
해당 과정을 거치면 성능이 무조건 좋아질 것만 같지만 만약, 정답이 없는 라벨로 예측한 값의 예측확률이 떨어지게 된다면 오히려 정상적인 라벨 데이터의 학습을 저해시켜 성능을 저조하게 도추할 수도 있습니다. 그렇기 때문에 어느정도 예측확률이 뒷받침 되었을 때 (Cut-off 설정을 잘해야 함) 즉, 정답라벨이 있는 데이터로 성능을 꽤 확보했을 때의 정답라벨이 없는 데이터의 fine-tuning할 때 사용해야합니다.
여기서 Pseudo labeling을 활용하여 학습을 진행할 때 Loss함수를 이용하는데 수식적인 부분은 생략하겠습니다. Loss함수를 사용함으로써 국소최적화에 빠지는 문제를 피할 수 있게 만들며, 이 과정을 최적화 시킬 수 있습니다.
마지막으로 Pseudo labeling가 어떻게 성능을 높일 수 있는지에 대해서 간략하게 알아보고 마무리하겠습니다.
첫번째로 Low-Density Separation between Classes입니다.
모델이 특정 클래스들을 분류할 때 결정경계(Decision Boundary)를 그리게 되는데 그 결정경계 중심으로 데이터들이 Low-Density를 띄고 있으면 모델이 변수들간의 미세한 차이점을 구별하여 클래스르 예측할 수 있는 성능을 높일 수 있게됩니다. 공부중이기는한데, 왜 Pseudo labeling이 해당 특성을 가지고 있는지는 명확하게 파악하지 못했습니다 ㅠㅠ
두번째로 Entropy Regularization입니다.
해당 부분은 정답라벨이 없는 데이터가 가지는 우리가 예측했던 클래스별 확률에 대해서 entropy를 최소화시킴으로써(Loss) 바로 위에서 언급한 결정경계 중심의 밀도가 낮도록 만드는 것입니다. 즉, 테스트 데이터(정답라벨이 없는 데이터)를 예측하여 Pseudo labeled data를 구축했을 때 실제로 예측해야 하는 것은 Labeled Data + pseudo labeled data이므로 결정경계 중심으로 데이터 밀도가 상대적으로 작아지게 되는 것입니다.
위의 두가지 특성을 flow상으로 이해해보면 entropy를 최소화 시키면서 클래스 결정경계 밀도를 낮추는 작업을 통해서 성능을 올릴 수 있는 것입니다.
이를 시각자료로 간단하게 확인해보면 다음과 같습니다
2. TTA(Test Time Augmentation)
TTA는 Data Augmentation 기법중 하나로써 부족한 데이터셋을 보완하고, 성능을 끌어올릴 수 있는 방법론입니다.
먼저 Augmentation을 살펴보면 CV에서 image detection, image classification을 수행할 때 풀어야하는 문제에 대비해서 데이터가 현저히 적을 경우가 많습니다. 또한 데이터가 특정한 모습을 가진 즉, 편향된 이미지만이 존재할 수도 있는데(예를 들어서 이미지의 조도가 어두운 것 밖에 없거나 반대로 밝은 것 밖에 없는 등) 그래서 Augmentation은 기존의 이러한 문제점들을 해결하고자 다양한 데이터 변환 및 생성 과정을 거치는 것을 말합니다. pytorch에서는 transform함수를 만들어서 Augmentation을 관리하며, 요즘은 주로 albumentations를 활용하여 data augmentation을 진행합니다.(다양한 augmentation기법도 위의 링크에서 참고해주세요)
다시 본론으로 돌아와서 위처럼 Train데이터 이미지에 다양한 Aug기법들을 적용하여 data create를 하는 것과 달리 테스트 이미지 데이터에 Augmentation을 적용하는 것이 Test Time Augmentation입니다.
아래의 그림을 보시겠습니다.
위의 그림을 단계별로 설명하자면
- 기존의 Test데이터가 가장 왼쪽에 있는 한장의 이미지였다면 해당 이미지에 대해서 Augmentation을 진행
- 1번을 진행한 뒤 각각의 이미지에 대해서 예측
- 각 클래스에 대한 예측 확률이 도출되면 그 결과값에 대해서 앙상블을 진행
- 최종 결과 리턴
예를들어서 위의 그림에서 Aug포함 7장의 사진을 대상으로 예측을 진행했을 때 예측확률을 리스트해보면(지금은 bird라는 클래스를 안다고 가정하겠습니다.)
- 1번: 0.68
- 2번: 0.33
- 3번: 0.51
- 4번: 0.44
- 5번: 0.73
- 6번: 0.39
- 7번: 0.78
위의 결과들을 해석해보면 1번, 3번, 5번, 7번을 제외하고 나머지 이미지에 대해서는 cut-off를 0.5로 두었을 때 새가 아니라는 예측결과가 나왔습니다. 만약에 단일 이미지로 결과값을 예측했다고 가정했을 때 단순하게 새로 예측하는 확률은 4/7(0.57)이 됩니다. 하지만 TTA를 통해서 이미지 1장으로 예측했을 때 발생하는 예측 오차율을 줄일 수 있습니다. 즉, 결과적으로 산술평균으로 일정한 가중치를 두어 앙상블을 진행했을 때 0.55로 새로 예측하는 것을 알 수 있습니다.
해당 부분을 공부했을 때 중심극한정리가 생각났습니다. 예측하려는 데이터 즉, 샘플이 커지면 커질수록 데이터의 분포는(예측 결과) 정규분포 곡선을 따른다는 것인데, 예측하려는 이미지 자체에 대해서 Augmentation를 거쳐 데이터 증강을 했을 때 생성되는 이미지가 N으로 갈수록 예측 결과값은 정답에 수렴할 수 있을 것입니다.(모델을 잘 만든다는 가정하에.. ㅎㅎ)
따라서 TTA는 불확실성이 높은 이미지 데이터에 대해서 오차율을 최소화하여 예측 성능을 개선하는데 효과적입니다. 하지만 TTA를 쓰면 좋지만, 이미 정답을 확실하게 분류할 수 있는 문제이거나, Augmentation을 진행했을 때 성능이 떨어지는 문제, 해결려는 문제의 도메인적으로 Augmentation을 하지 못하는 등의 여러 케이스가 존재하기 때문에 무조건적으로 정답은 아닙니다. 그렇기 때문에 적절한 도메인의 적절한 문제에서 도입하여 적용하는 것이 중요합니다.