Distilling the Knowledge in a Neural Network
Geoffrey Hinton
Reference Blog
๋ชฉ์ ์ ๋ฏธ๋ฆฌ ์ ํ์ต๋ ํฐ ๋คํธ์ํฌ์ ์ง์์ ์ค์ ๋ก ์ฌ์ฉํ๊ณ ์ ํ๋ ์์ ๋คํธ์ํฌ์๊ฒ ์ ๋ฌํ๋ ๊ฒ์ ๋๋ค.
Ground Truth ์ ์์ ๋คํธ์ํฌ์์ ๋ถ๋ฅ ์ฐจ์ด์ ๋ํ
ํฌ๋ก์ค ์ํธ๋กํผ ์์คํจ์
์ํฐ ๋คํธ์ํฌ์ ์์ ๋คํธ์ํฌ ๋ถ๋ฅ ๊ฒฐ๊ณผ์ฐจ์ด
๋ฅผ ์์คํจ์ ํฌํจ์ํต๋๋ค.
Background
์ค๋ฒํผํ ์ ํผํ๊ธฐ ์ํด์ ANN์์ ์์๋ธ ๊ธฐ๋ฒ์ ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค.
์์๋ธ์ ๊ณ์ฐ์๊ฐ์ด ๋ง์ด ๊ฑธ๋ฆฐ๋ค๋ ๋จ์ ์ด ์์ต๋๋ค.
๋ฐ๋ผ์ ์์๋ธ๋งํผ์ ์ฑ๋ฅ๊ณผ ์ ์ ํ๋ผ๋ฏธํฐ ์๋ฅผ ๊ฐ์ง ๋ด๋ด๋ท ๋ชจ๋ธ์ด ํ์ํฉ๋๋ค.
Neural Net Distillation
Neural Net์ Class์ ํ๋ฅ ์ ์ํํธ๋งฅ์ค Output Layer์ ์ด์ฉํด์ ์์ธกํฉ๋๋ค. ์์์ ์๋์ ๊ฐ์๋ฐ q๋ ๊ฐ ํด๋์ค์ ํ๋ฅ , z๋ ์ layer์ weight sum, T๋ temperature๋ผ๋ ๊ฐ์ผ๋ก 1๋ก setting์ด ๋ฉ๋๋ค.
T๊ฐ 1์ด๋ฉด 0, 1์ binaryํ๋ ๊ฒฐ๊ณผ๊ฐ์ ์ป๋๋ฐ ์ด๋ ํ๋ฅ ๋ถํฌ๋ฅผ ์๊ธฐ ์ด๋ ต๊ฒ ๋ง๋ญ๋๋ค.
T๊ฐ ํด์๋ก ๋ softํ ํ๋ฅ ๋ถํฌ๊ฐ ํด๋์ค๋ง๋ค ๋ง๋ค์ด์ง๊ฒ ๋ฉ๋๋ค. softํ๋ค๋ ๊ฒ์ ๊ฒฐ๊ณผ๊ฐ์ด ์ฒ์ฒํ ์ฆ๊ฐํ๋ค๋ ๋ป์ ๋๋ค. ํ์ง๋ง T๊ฐ ๋๋ฌด ์ปค์ง๋ฉด ๋ชจ๋ ํด๋์ค์ ํ๋ฅ ์ด ๋น์ทํด์ง๋๋ค.

Matching logits is a special case of distillation - ์ฆ๋ช
๊ณผ์
ํฌ๋ก์ค์ํธ๋กํผ์ ๋ฐ๋ผ ๋ฏธ๋ถ์ ์๋ ์์ผ๋ก ์ ๋ฆฌ๊ฐ ๋ฉ๋๋ค. v๋ soft target ํ๋ฅ ์ ์์ฑํ๋ weight sum๊ฐ์ ๋๋ค.

์๋์ ๊ฐ์ด ์ ๋ณ๊ฒฝ์ด ๊ฐ๋ฅํ๋ค๊ณ ํฉ๋๋ค.

์๊ทธ๋ง z์ ์๊ทธ๋ง v๋ ํ๊ท ์ด 0์ด๊ธฐ ๋๋ฌธ์ด ์๋ ์๋ง ๋จ์ต๋๋ค.

์ฆ v์ z๊ฐ์ด ๋น์ทํด์ง๋ ๊ฒ๊ณผ distillation์ ๋์ผํฉ๋๋ค. z๋ soft max๋ฅผ ํ๊ธฐ ์ ์ weight sum, v๋ soft target ํ๋ฅ ์ ์์ฑํ๋ weight sum ๊ฐ์ ๋๋ค. ์ง๊ด์ ์ผ๋ก๋ student network์ ์ต์ข ์ถ๋ ฅ๊ฐ๊ณผ teacher network ์ต์ข ์ถ๋ ฅ๊ฐ์ด ๋น์ทํ๊ฒ ์ต์ ์ด๋ผ๊ณ ๋งํ๋ ๊ฒ์ฒ๋ผ ๋ณด์ ๋๋ค.
MNIST
Two hidden layer + Relu : 146 test errors
Two hidden layer + DropOut : 67 test errors
Two hidden layer + Relu + soft target : 74 test errors
3์ ์ ์ธํ๊ณ Train ์์ผ๋ ํ์ต ๊ฒฐ๊ณผ๊ฐ ์ข์์ต๋๋ค.
Experiments on speech recognition

JFT Dataset

๋นจ๋ผ์ง ํ์๊ฐ ์๋ค๊ณ ํฉ๋๋ค.

Code
์ฝ๋ ์ค๋ช : teacher_model.py ์์ teacher ๋ชจ๋ธ์ ํ์ตํ๊ณ , knowledge_distillation.py ์์ ํ์ต teacher ๋ชจ๋ธ๋ก๋ถํฐ student ๋ชจ๋ธ์ ํ์ต๋ค. baseline.py๋ knowledge distillation ์์ด student ๋ชจ๋ธ์ ํ์ตํฉ๋๋ค.
ํต์ฌ๋ถ๋ถ
teacher ๋ชจ๋ธ๋ก๋ถํฐ temparature ๊ณ์๋ฅผ ํตํด logit_T์ prob_T๋ฅผ ๊ตฌํฉ๋๋ค.
teacher_logits = self.teacher_model.layers[-1].output
teacher_logits_T = Lambda(lambda x: x / self.temperature)(teacher_logits)
teacher_probabilities_T = Activation('softmax', name='softmax1_')(teacher_logits_T)
student ๋ชจ๋ธ๋ก๋ถํฐ ๋๊ฐ์ด logit_T์ prob_T๋ฅผ ๊ตฌํฉ๋๋ค.
logits = Dense(num_classes, activation=None, name='dense2')(x)
output_softmax = Activation('softmax', name='output_softmax')(logits)
logits_T = Lambda(lambda x: x / self.temperature, name='logits')(logits)
probabilities_T = Activation('softmax', name='probabilities')(logits_T)
student์ ํฌ๋ก์ค ์ํธ๋กํผ์ teacher์ prob_T์ student์ prob_T์ ํฌ๋ก์ค์ํธ๋กํผ๋ฅผ ๋ํด์ค๋๋ค.
def knowledge_distillation_loss(input_distillation):
y_pred, y_true, y_soft, y_pred_soft = input_distillation
return (1 - args.lambda_const) * logloss(y_true, y_pred) + \
args.lambda_const * args.temperature * args.temperature * logloss(y_soft, y_pred_soft)
Result
Teacher Network (acc 90%)
Total params: 3,844,938
Trainable params: 3,841,610
Non-trainable params: 3,328
Student Network(acc 85%)
Total params: 770,378
Trainable params: 769,162
Non-trainable params: 1,216
Distillation(acc 85%)
Student Network๊ณผ Distllation์ acc ๊ฒฐ๊ณผ๊ฐ ๋น์ทํ๋ค๋ ์ ์ด ์ค๋ง์ค๋ฌ์ ์ต๋๋ค. ํ์ง๋ง Distillation์ ํ์ต ์๋๋ Student ๋ง ํ ๋๋ณด๋ค ๋๋ฆฌ์ง๋ง ํ์ต์ ์์ ์ฑ์ด ํจ์ฌ ๋์ ๋ชจ์ต์ ๋ณด์์ต๋๋ค. ๋ง์ ๋จน๊ณ teacher student distllation ๊ฒฐ๊ณผ๋ฅผ ๋น๊ตํ๋ฉด ํ์คํ distillation์ ๊ฒฐ๊ณผ๊ฐ ๋ ๋์ ๊ฒ ๊ฐ๊ธด ํฉ๋๋ค.
My mini - contribution + ๋๋์ .
๊ธฐ์ฌ : ๊ธฐ์กด ์ฝ๋๋ฅผ TF 2.0์ผ๋ก ๋ฐ๊ฟจ์ต๋๋ค.
๋๋์ : ์ค์ ์ด distillation์ ํ๋ผ๋ฏธํฐ์๋ฅผ ์ผ๋ง๋ ์ค์ผ ์ง ๋ชจ๋ฅผ ๋ ๋งค์ฐ ์ ์ฉํ ๊ฒ ๊ฐ์ต๋๋ค.
Last updated
Was this helpful?