Distilling the Knowledge in a Neural Network

Geoffrey Hinton, Oriol Vinyals, Jeff Dean

Reference Blog

  • ๋ชฉ์ ์€ ๋ฏธ๋ฆฌ ์ž˜ ํ•™์Šต๋œ ํฐ ๋„คํŠธ์›Œํฌ์˜ ์ง€์‹์„ ์‹ค์ œ๋กœ ์‚ฌ์šฉํ•˜๊ณ ์ž ํ•˜๋Š” ์ž‘์€ ๋„คํŠธ์›Œํฌ์—๊ฒŒ ์ „๋‹ฌํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

  • To the cross-entropy loss between the ground truth and the small network's predictions, we add the difference between the large network's and the small network's classification outputs as an extra loss term; see the equation below.

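In equation form (this matches the weighting used in the code later in this post; H is cross-entropy, \(\sigma\) is the softmax, \(z_s\) and \(z_t\) are the student's and teacher's logits, \(\lambda\) is the weighting constant, and \(T\) is the temperature):

$$\mathcal{L} = (1-\lambda)\, H\big(y,\ \sigma(z_s)\big) + \lambda\, T^2\, H\big(\sigma(z_t/T),\ \sigma(z_s/T)\big)$$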
Background

  • To avoid overfitting, ensembles of models are commonly used with ANNs.

  • ์•™์ƒ๋ธ”์€ ๊ณ„์‚ฐ์‹œ๊ฐ„์ด ๋งŽ์ด ๊ฑธ๋ฆฐ๋‹ค๋Š” ๋‹จ์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค.

  • ๋”ฐ๋ผ์„œ ์•™์ƒ๋ธ”๋งŒํผ์˜ ์„ฑ๋Šฅ๊ณผ ์ ์€ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๋ฅผ ๊ฐ€์ง„ ๋‰ด๋Ÿด๋„ท ๋ชจ๋ธ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.

Neural Net Distillation

Neural Net์€ Class์˜ ํ™•๋ฅ ์„ ์†Œํ”„ํŠธ๋งฅ์Šค Output Layer์„ ์ด์šฉํ•ด์„œ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค. ์ˆ˜์‹์€ ์•„๋ž˜์™€ ๊ฐ™์€๋ฐ q๋Š” ๊ฐ ํด๋ž˜์Šค์˜ ํ™•๋ฅ , z๋Š” ์ „ layer์˜ weight sum, T๋Š” temperature๋ผ๋Š” ๊ฐ’์œผ๋กœ 1๋กœ setting์ด ๋ฉ๋‹ˆ๋‹ค.

T๊ฐ€ 1์ด๋ฉด 0, 1์˜ binaryํ™”๋œ ๊ฒฐ๊ณผ๊ฐ’์„ ์–ป๋Š”๋ฐ ์ด๋Š” ํ™•๋ฅ ๋ถ„ํฌ๋ฅผ ์•Œ๊ธฐ ์–ด๋ ต๊ฒŒ ๋งŒ๋“ญ๋‹ˆ๋‹ค.

T๊ฐ€ ํด์ˆ˜๋ก ๋” softํ•œ ํ™•๋ฅ  ๋ถ„ํฌ๊ฐ€ ํด๋ž˜์Šค๋งˆ๋‹ค ๋งŒ๋“ค์–ด์ง€๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. softํ•˜๋‹ค๋Š” ๊ฒƒ์€ ๊ฒฐ๊ณผ๊ฐ’์ด ์ฒœ์ฒœํžˆ ์ฆ๊ฐ€ํ•œ๋‹ค๋Š” ๋œป์ž…๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ T๊ฐ€ ๋„ˆ๋ฌด ์ปค์ง€๋ฉด ๋ชจ๋“  ํด๋ž˜์Šค์˜ ํ™•๋ฅ ์ด ๋น„์Šทํ•ด์ง‘๋‹ˆ๋‹ค.

Matching logits is a special case of distillation - proof sketch

Differentiating the cross-entropy gives the expression below. v is the teacher's logit, i.e. the weighted sum that produces the soft-target probabilities.
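$$\frac{\partial C}{\partial z_i} = \frac{1}{T}\,(q_i - p_i) = \frac{1}{T}\left(\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right)$$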

์•„๋ž˜์™€ ๊ฐ™์ด ์‹ ๋ณ€๊ฒฝ์ด ๊ฐ€๋Šฅํ•˜๋‹ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค.

If we additionally assume the logits are zero-mean for each example, so that \(\sum_j z_j = \sum_j v_j = 0\), only the expression below remains.
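$$\frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2}\,(z_i - v_i)$$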

In other words, in the high-temperature limit distillation is equivalent to making v and z close to each other, i.e. matching logits. z is the student's weighted sum before the softmax; v is the teacher's weighted sum that produces the soft-target probabilities. Intuitively, this says the optimum is for the student network's final (pre-softmax) outputs to match the teacher network's.

MNIST

  • Two hidden layers + ReLU: 146 test errors

  • Two hidden layers + dropout (the large teacher net): 67 test errors

  • Two hidden layers + ReLU + soft targets: 74 test errors

Even when the digit 3 was omitted from the transfer set, the distilled model still learned well (it classified most of the test 3s correctly despite never having seen one).

Experiments on speech recognition

JFT Dataset

๋นจ๋ผ์งˆ ํ•„์š”๊ฐ€ ์žˆ๋‹ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค.

Code

Code overview: teacher_model.py trains the teacher model, knowledge_distillation.py trains the student model from the trained teacher, and baseline.py trains the student model without knowledge distillation.

Key parts

  • teacher ๋ชจ๋ธ๋กœ๋ถ€ํ„ฐ temparature ๊ณ„์ˆ˜๋ฅผ ํ†ตํ•ด logit_T์™€ prob_T๋ฅผ ๊ตฌํ•ฉ๋‹ˆ๋‹ค.

# Take the teacher's raw logits, divide them by T, and re-apply softmax
# to obtain the softened teacher probabilities.
teacher_logits = self.teacher_model.layers[-1].output
teacher_logits_T = Lambda(lambda x: x / self.temperature)(teacher_logits)
teacher_probabilities_T = Activation('softmax', name='softmax1_')(teacher_logits_T)
  • student ๋ชจ๋ธ๋กœ๋ถ€ํ„ฐ ๋˜‘๊ฐ™์ด logit_T์™€ prob_T๋ฅผ ๊ตฌํ•ฉ๋‹ˆ๋‹ค.

# The student keeps its normal T = 1 softmax output (used at test time)
# and also produces temperature-softened probabilities for the distillation loss.
logits = Dense(num_classes, activation=None, name='dense2')(x)
output_softmax = Activation('softmax', name='output_softmax')(logits)
logits_T = Lambda(lambda x: x / self.temperature, name='logits')(logits)
probabilities_T = Activation('softmax', name='probabilities')(logits_T)
  • student์˜ ํฌ๋กœ์Šค ์—”ํŠธ๋กœํ”ผ์™€ teacher์˜ prob_T์™€ student์˜ prob_T์˜ ํฌ๋กœ์Šค์—”ํŠธ๋กœํ”ผ๋ฅผ ๋”ํ•ด์ค๋‹ˆ๋‹ค.

def knowledge_distillation_loss(input_distillation):
    # y_pred / y_true: the student's T = 1 output and the hard label;
    # y_soft / y_pred_soft: the teacher's and student's temperature-softened probabilities.
    y_pred, y_true, y_soft, y_pred_soft = input_distillation
    # The T^2 factor rescales the soft term so its gradients stay on the same
    # scale as the hard term's (as recommended in the paper).
    return (1 - args.lambda_const) * logloss(y_true, y_pred) + \
           args.lambda_const * args.temperature * args.temperature * logloss(y_soft, y_pred_soft)
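
For reference, here is the same combined loss as a standalone TF 2.x function, a minimal sketch of how it could look after the TF 2.0 port (the argument and variable names are illustrative, not taken from the repo):

import tensorflow as tf

def kd_loss(y_true, y_pred, y_soft, y_pred_soft, lambda_const, temperature):
    # Hard-label term: cross-entropy between ground truth and the student's T = 1 output.
    hard = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    # Soft-target term: cross-entropy between teacher and student softened probabilities.
    soft = tf.keras.losses.categorical_crossentropy(y_soft, y_pred_soft)
    # Same (1 - lambda) / lambda * T^2 weighting as above.
    return (1.0 - lambda_const) * hard + lambda_const * temperature ** 2 * soft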

Results

  • Teacher Network (acc 90%)

    • Total params: 3,844,938

    • Trainable params: 3,841,610

    • Non-trainable params: 3,328

  • Student Network (acc 85%)

    • Total params: 770,378

    • Trainable params: 769,162

    • Non-trainable params: 1,216

  • Distillation (acc 85%)

It was disappointing that the student network and distillation ended up with similar accuracy. On the other hand, although distillation trains more slowly than training the student alone, its training was noticeably more stable. If the teacher/student/distillation comparison were done in earnest, I do expect distillation's results would come out clearly ahead.

My mini-contribution + impressions

  • Contribution: I ported the existing code to TF 2.0.

  • Impression: in practice, distillation seems very useful when you do not know in advance how far you can shrink the parameter count.
