๋ชฉ์ ์ ๋ฏธ๋ฆฌ ์ ํ์ต๋ ํฐ ๋คํธ์ํฌ์ ์ง์์ ์ค์ ๋ก ์ฌ์ฉํ๊ณ ์ ํ๋ ์์ ๋คํธ์ํฌ์๊ฒ ์ ๋ฌํ๋ ๊ฒ์
๋๋ค.
Ground Truth ์ ์์ ๋คํธ์ํฌ์์ ๋ถ๋ฅ ์ฐจ์ด์ ๋ํ ํฌ๋ก์ค ์ํธ๋กํผ ์์คํจ์์ ํฐ ๋คํธ์ํฌ์ ์์ ๋คํธ์ํฌ ๋ถ๋ฅ ๊ฒฐ๊ณผ์ฐจ์ด๋ฅผ ์์คํจ์ ํฌํจ์ํต๋๋ค.
์ค๋ฒํผํ
์ ํผํ๊ธฐ ์ํด์ ANN์์ ์์๋ธ ๊ธฐ๋ฒ์ ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค.
์์๋ธ์ ๊ณ์ฐ์๊ฐ์ด ๋ง์ด ๊ฑธ๋ฆฐ๋ค๋ ๋จ์ ์ด ์์ต๋๋ค.
๋ฐ๋ผ์ ์์๋ธ๋งํผ์ ์ฑ๋ฅ๊ณผ ์ ์ ํ๋ผ๋ฏธํฐ ์๋ฅผ ๊ฐ์ง ๋ด๋ด๋ท ๋ชจ๋ธ์ด ํ์ํฉ๋๋ค.
Neural Net Distillation
Neural Net์ Class์ ํ๋ฅ ์ ์ํํธ๋งฅ์ค Output Layer์ ์ด์ฉํด์ ์์ธกํฉ๋๋ค. ์์์ ์๋์ ๊ฐ์๋ฐ q๋ ๊ฐ ํด๋์ค์ ํ๋ฅ , z๋ ์ layer์ weight sum, T๋ temperature๋ผ๋ ๊ฐ์ผ๋ก 1๋ก setting์ด ๋ฉ๋๋ค.
T๊ฐ 1์ด๋ฉด 0, 1์ binaryํ๋ ๊ฒฐ๊ณผ๊ฐ์ ์ป๋๋ฐ ์ด๋ ํ๋ฅ ๋ถํฌ๋ฅผ ์๊ธฐ ์ด๋ ต๊ฒ ๋ง๋ญ๋๋ค.
T๊ฐ ํด์๋ก ๋ softํ ํ๋ฅ ๋ถํฌ๊ฐ ํด๋์ค๋ง๋ค ๋ง๋ค์ด์ง๊ฒ ๋ฉ๋๋ค. softํ๋ค๋ ๊ฒ์ ๊ฒฐ๊ณผ๊ฐ์ด ์ฒ์ฒํ ์ฆ๊ฐํ๋ค๋ ๋ป์
๋๋ค. ํ์ง๋ง T๊ฐ ๋๋ฌด ์ปค์ง๋ฉด ๋ชจ๋ ํด๋์ค์ ํ๋ฅ ์ด ๋น์ทํด์ง๋๋ค.
Matching logits is a special case of distillation - ์ฆ๋ช
๊ณผ์
ํฌ๋ก์ค์ํธ๋กํผ์ ๋ฐ๋ผ ๋ฏธ๋ถ์ ์๋ ์์ผ๋ก ์ ๋ฆฌ๊ฐ ๋ฉ๋๋ค. v๋ soft target ํ๋ฅ ์ ์์ฑํ๋ weight sum๊ฐ์
๋๋ค.
์๋์ ๊ฐ์ด ์ ๋ณ๊ฒฝ์ด ๊ฐ๋ฅํ๋ค๊ณ ํฉ๋๋ค.
์๊ทธ๋ง z์ ์๊ทธ๋ง v๋ ํ๊ท ์ด 0์ด๊ธฐ ๋๋ฌธ์ด ์๋ ์๋ง ๋จ์ต๋๋ค.
์ฆ v์ z๊ฐ์ด ๋น์ทํด์ง๋ ๊ฒ๊ณผ distillation์ ๋์ผํฉ๋๋ค. z๋ soft max๋ฅผ ํ๊ธฐ ์ ์ weight sum, v๋ soft target ํ๋ฅ ์ ์์ฑํ๋ weight sum ๊ฐ์
๋๋ค. ์ง๊ด์ ์ผ๋ก๋ student network์ ์ต์ข
์ถ๋ ฅ๊ฐ๊ณผ teacher network ์ต์ข
์ถ๋ ฅ๊ฐ์ด ๋น์ทํ๊ฒ ์ต์ ์ด๋ผ๊ณ ๋งํ๋ ๊ฒ์ฒ๋ผ ๋ณด์
๋๋ค.
Two hidden layer + Relu : 146 test errors
Two hidden layer + DropOut : 67 test errors
Two hidden layer + Relu + soft target : 74 test errors
3์ ์ ์ธํ๊ณ Train ์์ผ๋ ํ์ต ๊ฒฐ๊ณผ๊ฐ ์ข์์ต๋๋ค.
Experiments on speech recognition
๋นจ๋ผ์ง ํ์๊ฐ ์๋ค๊ณ ํฉ๋๋ค.
์ฝ๋ ์ค๋ช
: teacher_model.py ์์ teacher ๋ชจ๋ธ์ ํ์ตํ๊ณ , knowledge_distillation.py ์์ ํ์ต teacher ๋ชจ๋ธ๋ก๋ถํฐ student ๋ชจ๋ธ์ ํ์ต๋ค. baseline.py๋ knowledge distillation ์์ด student ๋ชจ๋ธ์ ํ์ตํฉ๋๋ค.
teacher ๋ชจ๋ธ๋ก๋ถํฐ temparature ๊ณ์๋ฅผ ํตํด logit_T์ prob_T๋ฅผ ๊ตฌํฉ๋๋ค.
student ๋ชจ๋ธ๋ก๋ถํฐ ๋๊ฐ์ด logit_T์ prob_T๋ฅผ ๊ตฌํฉ๋๋ค.
student์ ํฌ๋ก์ค ์ํธ๋กํผ์ teacher์ prob_T์ student์ prob_T์ ํฌ๋ก์ค์ํธ๋กํผ๋ฅผ ๋ํด์ค๋๋ค.
Teacher Network (acc 90%)
Trainable params: 3,841,610
Non-trainable params: 3,328
Student Network(acc 85%)
Trainable params: 769,162
Non-trainable params: 1,216
Student Network๊ณผ Distllation์ acc ๊ฒฐ๊ณผ๊ฐ ๋น์ทํ๋ค๋ ์ ์ด ์ค๋ง์ค๋ฌ์ ์ต๋๋ค. ํ์ง๋ง Distillation์ ํ์ต ์๋๋ Student ๋ง ํ ๋๋ณด๋ค ๋๋ฆฌ์ง๋ง ํ์ต์ ์์ ์ฑ์ด ํจ์ฌ ๋์ ๋ชจ์ต์ ๋ณด์์ต๋๋ค. ๋ง์ ๋จน๊ณ teacher student distllation ๊ฒฐ๊ณผ๋ฅผ ๋น๊ตํ๋ฉด ํ์คํ distillation์ ๊ฒฐ๊ณผ๊ฐ ๋ ๋์ ๊ฒ ๊ฐ๊ธด ํฉ๋๋ค.
My mini - contribution + ๋๋์ .
๊ธฐ์ฌ : ๊ธฐ์กด ์ฝ๋๋ฅผ TF 2.0์ผ๋ก ๋ฐ๊ฟจ์ต๋๋ค.
๋๋์ : ์ค์ ์ด distillation์ ํ๋ผ๋ฏธํฐ์๋ฅผ ์ผ๋ง๋ ์ค์ผ ์ง ๋ชจ๋ฅผ ๋ ๋งค์ฐ ์ ์ฉํ ๊ฒ ๊ฐ์ต๋๋ค.