Distilling the Knowledge in a Neural Network

Geoffrey Hinton

Reference Blog

  • λͺ©μ μ€ 미리 잘 ν•™μŠ΅λœ 큰 λ„€νŠΈμ›Œν¬μ˜ 지식을 μ‹€μ œλ‘œ μ‚¬μš©ν•˜κ³ μž ν•˜λŠ” μž‘μ€ λ„€νŠΈμ›Œν¬μ—κ²Œ μ „λ‹¬ν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€.

  • On top of the cross-entropy loss between the ground truth and the small network's predictions, a term measuring the difference between the large network's and the small network's outputs is added to the loss, as written out below.
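Concretely, the combined objective (the same form that appears in the Code section below) can be written as follows; here $H$ is the cross-entropy, $\sigma$ the softmax, $z_s$ and $z_t$ the student and teacher logits, $T$ the temperature, and $\lambda$ the weighting constant (these symbols are my own notation for this post):

$$
L = (1-\lambda)\, H\big(y,\ \sigma(z_s)\big) + \lambda\, T^{2}\, H\big(\sigma(z_t/T),\ \sigma(z_s/T)\big)
$$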

Background

  • μ˜€λ²„ν”ΌνŒ…μ„ ν”Όν•˜κΈ° μœ„ν•΄μ„œ ANNμ—μ„œ 앙상블 기법을 μ‚¬μš©ν•˜κ²Œ λ©λ‹ˆλ‹€.

  • Ensembles have the drawback that they take a long time to compute.

  • λ”°λΌμ„œ μ•™μƒλΈ”λ§ŒνΌμ˜ μ„±λŠ₯κ³Ό 적은 νŒŒλΌλ―Έν„° 수λ₯Ό κ°€μ§„ λ‰΄λŸ΄λ„· λͺ¨λΈμ΄ ν•„μš”ν•©λ‹ˆλ‹€.

Neural Net Distillation

A neural net predicts class probabilities using a softmax output layer. The formula is given below: q is the probability of each class, z is the logit (the weighted sum from the previous layer), and T is a value called the temperature, which is normally set to 1.
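This is the softmax with temperature, Eq. (1) of the paper:

$$
q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}
$$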

With T = 1 (the ordinary softmax), a confident network produces an almost one-hot output, which makes the full probability distribution over the incorrect classes hard to see.

The larger T is, the softer the probability distribution over the classes becomes; "soft" here means that the probability mass is spread more evenly instead of being concentrated on a single class (see the numeric example below). However, if T is too large, the probabilities of all classes become nearly identical.
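A quick numerical illustration (my own example, not from the post) of how the same logits spread out as T grows:

```python
import numpy as np

def softmax_T(z, T):
    # softmax with temperature: exp(z_i / T) / sum_j exp(z_j / T)
    e = np.exp(z / T)
    return e / e.sum()

z = np.array([5.0, 2.0, 0.5])   # logits for three classes
for T in (1, 5, 20):
    print(T, np.round(softmax_T(z, T), 3))
# T = 1  -> about [0.943, 0.047, 0.010]  (peaked, almost one-hot)
# T = 5  -> about [0.511, 0.281, 0.208]  (softer: relative preferences still visible)
# T = 20 -> about [0.376, 0.324, 0.300]  (nearly uniform when T is too large)
```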

Matching logits is a special case of distillation - derivation

ν¬λ‘œμŠ€μ—”νŠΈλ‘œν”Όμ— 따라 미뢄을 μ•„λž˜ μ‹μœΌλ‘œ 정리가 λ©λ‹ˆλ‹€. vλŠ” soft target ν™•λ₯ μ„ μƒμ„±ν•˜λŠ” weight sumκ°’μž…λ‹ˆλ‹€.

When the temperature is high compared with the magnitude of the logits, exp(x) can be approximated by 1 + x, so the expression can be rewritten as below.
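This gives Eq. (3) of the paper, where $N$ is the number of classes:

$$
\frac{\partial C}{\partial z_i} \approx \frac{1}{T}\left( \frac{1 + z_i/T}{N + \sum_j z_j/T} - \frac{1 + v_i/T}{N + \sum_j v_j/T} \right)
$$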

If the logits z and v are assumed to be zero-mean for each example, i.e. their sums over the classes are 0, only the expression below remains.
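With $\sum_j z_j = \sum_j v_j = 0$, Eq. (3) simplifies to Eq. (4) of the paper:

$$
\frac{\partial C}{\partial z_i} \approx \frac{1}{N T^{2}}\,(z_i - v_i)
$$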

In other words, in this high-temperature limit, distillation is equivalent to making v and z similar: z is the student's weighted sum (logit) before the softmax, and v is the teacher logit that produces the soft target probabilities. Intuitively, this says that the optimum is for the student network's final outputs to match the teacher network's final outputs. A quick numerical check of this limit is sketched below.
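A small numerical check (my own sketch, not part of the original post): at a high temperature, the exact gradient of the soft cross-entropy with respect to the student logits comes out close to $(z_i - v_i)/(N T^2)$.

```python
import numpy as np

def softmax_T(x, T):
    e = np.exp((x - x.max()) / T)   # shift for numerical stability
    return e / e.sum()

rng = np.random.default_rng(0)
N, T = 10, 100.0
z = rng.normal(size=N); z -= z.mean()   # student logits, zero-meaned
v = rng.normal(size=N); v -= v.mean()   # teacher logits, zero-meaned

p = softmax_T(v, T)                 # soft targets from the teacher
q = softmax_T(z, T)                 # student's softened probabilities
grad = (q - p) / T                  # exact gradient of the cross-entropy w.r.t. z_i
approx = (z - v) / (N * T**2)       # Eq. (4): the matching-logits gradient
print(np.allclose(grad, approx, atol=1e-6))   # True at this temperature
```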

MNIST

  • Two hidden layers + ReLU (the small net): 146 test errors

  • Two hidden layers + dropout (the large net): 67 test errors

  • Two hidden layers + ReLU + soft targets: 74 test errors

Even when the digit 3 was excluded from the transfer training set, the distilled model still learned to recognize it well.

Experiments on speech recognition

JFT Dataset

The point is that training on a dataset this large needs to be made much faster.

Code

Code description: teacher_model.py trains the teacher model, knowledge_distillation.py trains the student model from the trained teacher, and baseline.py trains the student model without knowledge distillation.

Key parts

  • teacher λͺ¨λΈλ‘œλΆ€ν„° temparature κ³„μˆ˜λ₯Ό 톡해 logit_T와 prob_Tλ₯Ό κ΅¬ν•©λ‹ˆλ‹€.

# Teacher: take the final-layer logits, scale them by 1/temperature, then apply softmax
teacher_logits = self.teacher_model.layers[-1].output
teacher_logits_T = Lambda(lambda x: x / self.temperature)(teacher_logits)
teacher_probabilities_T = Activation('softmax', name='softmax1_')(teacher_logits_T)

  • From the student model, logits_T and prob_T are computed in the same way.

# Student: the same temperature scaling is applied to the student's logits
logits = Dense(num_classes, activation=None, name='dense2')(x)
output_softmax = Activation('softmax', name='output_softmax')(logits)
logits_T = Lambda(lambda x: x / self.temperature, name='logits')(logits)
probabilities_T = Activation('softmax', name='probabilities')(logits_T)

  • The student's cross-entropy against the ground truth is added to the cross-entropy between the teacher's prob_T and the student's prob_T:

def knowledge_distillation_loss(input_distillation):
    # y_pred / y_true: student prediction and hard label (T = 1);
    # y_soft / y_pred_soft: teacher and student probabilities at temperature T
    y_pred, y_true, y_soft, y_pred_soft = input_distillation
    # the T^2 factor compensates for the 1/T^2 scale of the soft-target gradients
    return (1 - args.lambda_const) * logloss(y_true, y_pred) + \
           args.lambda_const * args.temperature * args.temperature * logloss(y_soft, y_pred_soft)
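For reference, here is a minimal TF 2.x sketch of how these pieces fit together in a single training step. This is my own rewrite under assumed names (teacher, student, temperature, lambda_const), not the repo's exact code, and both models are assumed to output raw logits.

```python
import tensorflow as tf

cce = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

@tf.function
def distillation_step(teacher, student, x, y_true, temperature=5.0, lambda_const=0.9):
    teacher_logits = teacher(x, training=False)            # the teacher stays frozen
    with tf.GradientTape() as tape:
        student_logits = student(x, training=True)
        # hard loss: ordinary cross-entropy against the ground-truth labels (T = 1)
        hard_loss = cce(y_true, tf.nn.softmax(student_logits))
        # soft loss: cross-entropy between softened teacher and student distributions
        soft_teacher = tf.nn.softmax(teacher_logits / temperature)
        soft_student = tf.nn.softmax(student_logits / temperature)
        soft_loss = cce(soft_teacher, soft_student)
        # T^2 keeps the soft-target gradients on the same scale as the hard ones
        loss = (1.0 - lambda_const) * hard_loss + lambda_const * temperature ** 2 * soft_loss
    grads = tape.gradient(loss, student.trainable_variables)
    optimizer.apply_gradients(zip(grads, student.trainable_variables))
    return loss
```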

Result

  • Teacher Network (acc 90%)

    • Total params: 3,844,938

    • Trainable params: 3,841,610

    • Non-trainable params: 3,328

  • Student Network (acc 85%)

    • Total params: 770,378

    • Trainable params: 769,162

    • Non-trainable params: 1,216

  • Distillation (acc 85%)

It was disappointing that the Student Network and the Distillation runs ended up with similar accuracy. However, although distillation trains more slowly than training the student alone, it was much more stable during training. With a more careful, deliberate comparison of the teacher / student / distillation results, distillation would most likely come out clearly ahead.

My mini-contribution + impressions

  • Contribution: I ported the original code to TF 2.0.

  • λŠλ‚€μ  : μ‹€μ œ 이 distillation은 νŒŒλΌλ―Έν„°μˆ˜λ₯Ό μ–Όλ§ˆλ‚˜ 쀄일 μ§€ λͺ¨λ₯Ό λ•Œ 맀우 μœ μš©ν•  것 κ°™μŠ΅λ‹ˆλ‹€.
