Distilling the Knowledge in a Neural Network
이 포스팅은 이 블로그에서 정리한 논문을 내가 읽으면서 정리한 내용이다.
01. Intro
딥러닝 모델은 개발 단계와 적용 단계에서 다음과 같은 차이가 있다.
- 개발 단계 : 최대한 학습 모델의 정확도와 일반화 성능을 높여야 한다. 이를 위해 대규모 데이터 및 리소스의 사용이 비교적 자유롭다.
- 적용 단계 : 모델의 정확도와 일반화 성능을 유지하면서, 실시간 데이터 처리와 제한적인 리소스 사용을 고려해야 한다.
이와 관련해서 구글링해볼 수 있는 키워드는 ‘Knowledge Distillation’이며, Geoffrey Hinton 교수님이 2015년에 발표한 논문이 자주 인용된다.
02. How to Distilling the Knowledge
검색되는 블로그들을 보다 보면 ‘teacher model’과 ‘student model’이라는 용어를 만난다. 논문에서는 ,
- ‘teacher model’은 개발 단계에서
(데이터로부터 배워서) 잘학습된 모델을 의미한다. - ‘student model’은 적용 단계에서
(선생님한테 배워서) 잘학습된 모델을 의미한다.
‘student model’의 학습은 ‘teacher model’로 부터 배운다는게 요지이다. 그러면 어떻게 선생님으로부터 학생을 가르칠 것인가? 논문에서는 Model Compression과 Specialist Networks 방법을 제안한다.
02-1. Model Compression
선생님이 배운 지식을 증류해서 학생에게 전달하기 위해서 사용되는 것이 Softmax function 이다.
softmax function은 ‘모델이 각각의 클래스에 대해 어떻게 예측하고 있는지’ 에 대한 정보를 가지고 있다. 보통 처음 CNN 분류 모델을 배우고 코딩하면 이 softmax function의 결과를 원-핫 인코딩처럼 최종 변환해서 클래스를 예측한다. 그래서 제일 “높은 값” == “모델이 예측한 값” 이라고 생각하는데, 이는 정확히 말하면 hard target 또는 hard prediction이다. 실제로 모델은 softmax function을 통해 그렇게 ‘딱딱한’ 결과를 내놓지 않고, [0.00001, 0.8999, 0.002] 이런 soft target 또는 soft prediction을 뱉어낸다. 여기서 이 ‘soft target’을 학생에게 지식을 전달하는데 효과적이라고 주장하는게 논문의 내용이다. soft target의 분포를 학생에게 학습시킨다는 것은 선생님이 학습한 지식을 그냥 말 그대로 하드하게 가르치는게 아니라 충분히 일반화해서 전달할 수 있다는 것이다.
이쯤 왔으면 힌튼교수가 loss function에서 말하는 T(temperatrue)를 이해할 수 있다. 이 T 파라미터를 softmax fuction에 다음과 같이 적용했다. (T=1이면 우리가 아는 그 softmax function이다.)
\[p_i=\frac{exp(\frac{z_i}{T})} {∑_jexp(\frac{z_j}{T})}\]T가 커지면 exponential의 값이 작아질 것이다. 따라서 큰 값(0.899 같은)이 더 커지는 현상이 줄어들고 작은 값은[0.00001] 기본 softmax function의 결과보다 커지게 되서 전체적인 결과의 분포가 좀 더 부드러워진다.
02-2. Specialist Networks
학생한테 가르치기 전에 선생님은 미리 배워야 한다. 근데 이 배움의 단계에서 아무리 병렬학습, 리소스 플랙스를 해도 많은 시간이 걸리는건 팩트… 논문에서는 specialist 모델을 만들어서 학습의 효율성을 높일 수 있다고 주장한다. 이 specialist 모델은 ‘혼동하기 쉬운 특별한 부분 집합’에 대해서만 학습한 모델이다. 즉, ‘버섯’, ‘스포츠카’ 같은 데이터만 보아서 학습한 모델이다. 근데 이렇게 학습하면 완전 오버피팅된 모델이지 않나? 그래서 학습할 때 일반적인 데이터셋과 ‘특정 클래스의 데이터셋’을 반반 섞어서 학습해야 한다. 학습한 후에는 ‘특정 클래스의 데이터셋’의 비율에 맞춰서 결과 값은 스케일링해준다. (근데 이 부분은 잘 이해가 되지 않는다.. 흠..)
03. 네트워크 구조와 손실함수 이해하기
먼저 네트워크 학습을 위한 loss function을 보자 .
\[(1-α)L_{CE}(σ(Z_s), \hat{y})\]왼쪽 항은 아래 그림의 student loss에 해당된다. 학생의 분류 성능을 개선하기 위한 손실함수 부분으로, 실제 값과 학생이 추정한 값의 결과를 Cross Entropy 손실함수로 계산한다. Hard label 사용
\[2αT^ 2L_{CE}(σ(\frac{Z_t}{T}), σ(\frac{Z_s}{T}), )\]오른쪽 항은 아래 그림의 distillation loss에 해당된다. 선생님과 학생의 분류 결과 차이를 손실함수에 포함시키겠다는 것이다. 이때 앞서 말했던 T파라미터를 사용해서 두 모델이 추정한 결과를 soften된 분포로 변환해서 Cross Entropy 손실함수를 계산한다. Soft label 사용
손실함수를 봤으니, 앞에서 구구절절 얘기했던 ‘지식 전달’을 다시 생각해보자. distillation loss를 보면 선생님과 학생 모델의 분류 결과를 비교하기 위해 soft label을 사용한다. 즉, Hard label일 경우에는 정보손실이 발생할 수 있지만, soft label을 사용함으로써 ‘지식전달’이 가능해진다는 것이 요지인 것 같다.
전체 구조는 요 블로그 가 잘 그려났다.
04. Knowledge Distillation 코드 구현 예시
상황 01 : 선생님한테서 학생이 배우는 상황
상황 02 : 클래스 3 을 학생한테 일부러 보여주지 않는 상황 (
알아서 배워 컨셉)