Distilling the Knowledge in a Neural Network
Geoffrey Hinton
Reference Blog
λͺ©μ μ 미리 μ νμ΅λ ν° λ€νΈμν¬μ μ§μμ μ€μ λ‘ μ¬μ©νκ³ μ νλ μμ λ€νΈμν¬μκ² μ λ¬νλ κ²μ λλ€.
Ground Truth μ μμ λ€νΈμν¬μμ λΆλ₯ μ°¨μ΄μ λν
ν¬λ‘μ€ μνΈλ‘νΌ μμ€ν¨μ
μν° λ€νΈμν¬μ μμ λ€νΈμν¬ λΆλ₯ κ²°κ³Όμ°¨μ΄
λ₯Ό μμ€ν¨μ ν¬ν¨μν΅λλ€.
Background
μ€λ²νΌν μ νΌνκΈ° μν΄μ ANNμμ μμλΈ κΈ°λ²μ μ¬μ©νκ² λ©λλ€.
μμλΈμ κ³μ°μκ°μ΄ λ§μ΄ κ±Έλ¦°λ€λ λ¨μ μ΄ μμ΅λλ€.
λ°λΌμ μμλΈλ§νΌμ μ±λ₯κ³Ό μ μ νλΌλ―Έν° μλ₯Ό κ°μ§ λ΄λ΄λ· λͺ¨λΈμ΄ νμν©λλ€.
Neural Net Distillation
Neural Netμ Classμ νλ₯ μ μννΈλ§₯μ€ Output Layerμ μ΄μ©ν΄μ μμΈ‘ν©λλ€. μμμ μλμ κ°μλ° qλ κ° ν΄λμ€μ νλ₯ , zλ μ layerμ weight sum, Tλ temperatureλΌλ κ°μΌλ‘ 1λ‘ settingμ΄ λ©λλ€.
Tκ° 1μ΄λ©΄ 0, 1μ binaryνλ κ²°κ³Όκ°μ μ»λλ° μ΄λ νλ₯ λΆν¬λ₯Ό μκΈ° μ΄λ ΅κ² λ§λλλ€.
Tκ° ν΄μλ‘ λ softν νλ₯ λΆν¬κ° ν΄λμ€λ§λ€ λ§λ€μ΄μ§κ² λ©λλ€. softνλ€λ κ²μ κ²°κ³Όκ°μ΄ μ²μ²ν μ¦κ°νλ€λ λ»μ λλ€. νμ§λ§ Tκ° λ무 컀μ§λ©΄ λͺ¨λ ν΄λμ€μ νλ₯ μ΄ λΉμ·ν΄μ§λλ€.

Matching logits is a special case of distillation - μ¦λͺ
κ³Όμ
ν¬λ‘μ€μνΈλ‘νΌμ λ°λΌ λ―ΈλΆμ μλ μμΌλ‘ μ λ¦¬κ° λ©λλ€. vλ soft target νλ₯ μ μμ±νλ weight sumκ°μ λλ€.

μλμ κ°μ΄ μ λ³κ²½μ΄ κ°λ₯νλ€κ³ ν©λλ€.

μκ·Έλ§ zμ μκ·Έλ§ vλ νκ· μ΄ 0μ΄κΈ° λλ¬Έμ΄ μλ μλ§ λ¨μ΅λλ€.

μ¦ vμ zκ°μ΄ λΉμ·ν΄μ§λ κ²κ³Ό distillationμ λμΌν©λλ€. zλ soft maxλ₯Ό νκΈ° μ μ weight sum, vλ soft target νλ₯ μ μμ±νλ weight sum κ°μ λλ€. μ§κ΄μ μΌλ‘λ student networkμ μ΅μ’ μΆλ ₯κ°κ³Ό teacher network μ΅μ’ μΆλ ₯κ°μ΄ λΉμ·νκ² μ΅μ μ΄λΌκ³ λ§νλ κ²μ²λΌ 보μ λλ€.
MNIST
Two hidden layer + Relu : 146 test errors
Two hidden layer + DropOut : 67 test errors
Two hidden layer + Relu + soft target : 74 test errors
3μ μ μΈνκ³ Train μμΌλ νμ΅ κ²°κ³Όκ° μ’μμ΅λλ€.
Experiments on speech recognition

JFT Dataset

λΉ¨λΌμ§ νμκ° μλ€κ³ ν©λλ€.

Code
μ½λ μ€λͺ : teacher_model.py μμ teacher λͺ¨λΈμ νμ΅νκ³ , knowledge_distillation.py μμ νμ΅ teacher λͺ¨λΈλ‘λΆν° student λͺ¨λΈμ νμ΅λ€. baseline.pyλ knowledge distillation μμ΄ student λͺ¨λΈμ νμ΅ν©λλ€.
ν΅μ¬λΆλΆ
teacher λͺ¨λΈλ‘λΆν° temparature κ³μλ₯Ό ν΅ν΄ logit_Tμ prob_Tλ₯Ό ꡬν©λλ€.
teacher_logits = self.teacher_model.layers[-1].output
teacher_logits_T = Lambda(lambda x: x / self.temperature)(teacher_logits)
teacher_probabilities_T = Activation('softmax', name='softmax1_')(teacher_logits_T)
student λͺ¨λΈλ‘λΆν° λκ°μ΄ logit_Tμ prob_Tλ₯Ό ꡬν©λλ€.
logits = Dense(num_classes, activation=None, name='dense2')(x)
output_softmax = Activation('softmax', name='output_softmax')(logits)
logits_T = Lambda(lambda x: x / self.temperature, name='logits')(logits)
probabilities_T = Activation('softmax', name='probabilities')(logits_T)
studentμ ν¬λ‘μ€ μνΈλ‘νΌμ teacherμ prob_Tμ studentμ prob_Tμ ν¬λ‘μ€μνΈλ‘νΌλ₯Ό λν΄μ€λλ€.
def knowledge_distillation_loss(input_distillation):
y_pred, y_true, y_soft, y_pred_soft = input_distillation
return (1 - args.lambda_const) * logloss(y_true, y_pred) + \
args.lambda_const * args.temperature * args.temperature * logloss(y_soft, y_pred_soft)
Result
Teacher Network (acc 90%)
Total params: 3,844,938
Trainable params: 3,841,610
Non-trainable params: 3,328
Student Network(acc 85%)
Total params: 770,378
Trainable params: 769,162
Non-trainable params: 1,216
Distillation(acc 85%)
Student Networkκ³Ό Distllationμ acc κ²°κ³Όκ° λΉμ·νλ€λ μ μ΄ μ€λ§μ€λ¬μ μ΅λλ€. νμ§λ§ Distillationμ νμ΅ μλλ Student λ§ ν λλ³΄λ€ λ리μ§λ§ νμ΅μ μμ μ±μ΄ ν¨μ¬ λμ λͺ¨μ΅μ 보μμ΅λλ€. λ§μ λ¨Ήκ³ teacher student distllation κ²°κ³Όλ₯Ό λΉκ΅νλ©΄ νμ€ν distillationμ κ²°κ³Όκ° λ λμ κ² κ°κΈ΄ ν©λλ€.
My mini - contribution + λλμ .
κΈ°μ¬ : κΈ°μ‘΄ μ½λλ₯Ό TF 2.0μΌλ‘ λ°κΏ¨μ΅λλ€.
λλμ : μ€μ μ΄ distillationμ νλΌλ―Έν°μλ₯Ό μΌλ§λ μ€μΌ μ§ λͺ¨λ₯Ό λ λ§€μ° μ μ©ν κ² κ°μ΅λλ€.
Last updated
Was this helpful?