Distilling the Knowledge in a Neural Network

Geoffrey Hinton

Reference Blog

  • 목적은 미리 잘 학습된 큰 네트워크의 지식을 실제로 사용하고자 하는 작은 네트워크에게 전달하는 것입니다.

  • Ground Truth 와 작은 네트워크와의 분류 차이에 대한 크로스 엔트로피 손실함수큰 네트워크와 작은 네트워크 분류 결과차이를 손실함에 포함시킵니다.

Background

  • 오버피팅을 피하기 위해서 ANN에서 앙상블 기법을 사용하게 됩니다.

  • 앙상블은 계산시간이 많이 걸린다는 단점이 있습니다.

  • 따라서 앙상블만큼의 성능과 적은 파라미터 수를 가진 뉴럴넷 모델이 필요합니다.

Neural Net Distillation

Neural Net은 Class의 확률을 소프트맥스 Output Layer을 이용해서 예측합니다. 수식은 아래와 같은데 q는 각 클래스의 확률, z는 전 layer의 weight sum, T는 temperature라는 값으로 1로 setting이 됩니다.

T가 1이면 0, 1의 binary화된 결과값을 얻는데 이는 확률분포를 알기 어렵게 만듭니다.

T가 클수록 더 soft한 확률 분포가 클래스마다 만들어지게 됩니다. soft하다는 것은 결과값이 천천히 증가한다는 뜻입니다. 하지만 T가 너무 커지면 모든 클래스의 확률이 비슷해집니다.

Matching logits is a special case of distillation - 증명과정

크로스엔트로피에 따라 미분을 아래 식으로 정리가 됩니다. v는 soft target 확률을 생성하는 weight sum값입니다.

아래와 같이 식 변경이 가능하다고 합니다.

시그마 z와 시그마 v는 평균이 0이기 때문이 아래 식만 남습니다.

즉 v와 z값이 비슷해지는 것과 distillation은 동일합니다. z는 soft max를 하기 전의 weight sum, v는 soft target 확률을 생성하는 weight sum 값입니다. 직관적으로는 student network의 최종 출력값과 teacher network 최종 출력값이 비슷한게 최적이라고 말하는 것처럼 보입니다.

MNIST

  • Two hidden layer + Relu : 146 test errors

  • Two hidden layer + DropOut : 67 test errors

  • Two hidden layer + Relu + soft target : 74 test errors

3을 제외하고 Train 시켜도 학습 결과가 좋았습니다.

Experiments on speech recognition

JFT Dataset

빨라질 필요가 있다고 합니다.

Code

코드 설명 : teacher_model.py 에서 teacher 모델을 학습하고, knowledge_distillation.py 에서 학습 teacher 모델로부터 student 모델을 학습다. baseline.py는 knowledge distillation 없이 student 모델을 학습합니다.

핵심부분

  • teacher 모델로부터 temparature 계수를 통해 logit_T와 prob_T를 구합니다.

 teacher_logits = self.teacher_model.layers[-1].output
 teacher_logits_T = Lambda(lambda x: x / self.temperature)(teacher_logits)
 teacher_probabilities_T = Activation('softmax', name='softmax1_')(teacher_logits_T)
  • student 모델로부터 똑같이 logit_T와 prob_T를 구합니다.

logits = Dense(num_classes, activation=None, name='dense2')(x)
output_softmax = Activation('softmax', name='output_softmax')(logits)
logits_T = Lambda(lambda x: x / self.temperature, name='logits')(logits)
probabilities_T = Activation('softmax', name='probabilities')(logits_T)
  • student의 크로스 엔트로피와 teacher의 prob_T와 student의 prob_T의 크로스엔트로피를 더해줍니다.

def knowledge_distillation_loss(input_distillation):
    y_pred, y_true, y_soft, y_pred_soft = input_distillation
    return (1 - args.lambda_const) * logloss(y_true, y_pred) + \
           args.lambda_const * args.temperature * args.temperature * logloss(y_soft, y_pred_soft)

Result

  • Teacher Network (acc 90%)

    • Total params: 3,844,938

    • Trainable params: 3,841,610

    • Non-trainable params: 3,328

  • Student Network(acc 85%)

    • Total params: 770,378

    • Trainable params: 769,162

    • Non-trainable params: 1,216

  • Distillation(acc 85%)

Student Network과 Distllation의 acc 결과가 비슷했다는 점이 실망스러웠습니다. 하지만 Distillation은 학습 속도는 Student 만 할때보다 느리지만 학습의 안정성이 훨씬 높은 모습을 보였습니다. 마음 먹고 teacher student distllation 결과를 비교하면 확실히 distillation의 결과가 더 높을 것 같긴 합니다.

My mini - contribution + 느낀점.

  • 기여 : 기존 코드를 TF 2.0으로 바꿨습니다.

  • 느낀점 : 실제 이 distillation은 파라미터수를 얼마나 줄일 지 모를 때 매우 유용할 것 같습니다.

Last updated