Token-Scaled Logit Distillation for Ternary Weight Generative Language Models
Minsoo Kim1, Sihwa Lee1, Janghwan Lee1, Sukjin Hong1,2, Du-Seong Chang2, Wonyong Sung3, Jungwook Choi †1
1 Hanyang University, 2KT, 3 Seoul National University, Republic of Korea
Abstract
Generative Language Models (GLMs) have shown impressive performance in tasks such as text generation, understanding, and reasoning. However, their large model size poses challenges for practical deployment. To address this, Quantization-Aware Training (QAT) has become increasingly popular, yet existing QAT methods for generative models incur a noticeable loss of accuracy. To counteract this issue, we propose a novel knowledge distillation method tailored to GLMs. Our method, token-scaled logit distillation (TSLD), prevents overfitting and enables superior learning from both the teacher model and the ground truth. This work presents the first evaluation of ternary weight quantization-aware training of large-scale GLMs with less than 1.0 perplexity degradation, along with improved accuracy on common-sense QA, arithmetic reasoning, and natural language understanding tasks. Our code is available at https://github.com/aiha-lab/TSLD.
Motivation: Accuracy Loss in Ternary Quantization
Ternary Weight Quantization for Generative Language Models:
Uses up to 16x less GPU memory than FP32, since a ternary weight needs only ~2 bits per value (see the sketch below)
Enables a multiplication-free MATMUL implementation: with weights in {-α, 0, +α}, matrix multiplication reduces to additions, subtractions, and a single scale
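The memory and compute claims above can be made concrete with a small sketch. The snippet below illustrates TWN-style ternary quantization that maps each weight to {-α, 0, +α}; the 0.7·mean(|W|) threshold and the per-tensor scale are common heuristics assumed here for illustration, not necessarily the exact quantizer used in this work.

```python
import torch

def ternarize(w: torch.Tensor, thr_ratio: float = 0.7):
    """TWN-style ternary quantization sketch: map weights to {-alpha, 0, +alpha}."""
    delta = thr_ratio * w.abs().mean()                        # threshold between 0 and +/-alpha (assumed heuristic)
    mask = (w.abs() > delta).float()                          # 1 where a weight stays nonzero
    alpha = (w.abs() * mask).sum() / mask.sum().clamp(min=1)  # scale of the nonzero weights
    w_q = alpha * torch.sign(w) * mask                        # ternary weight tensor
    return w_q, alpha

# Storage estimate: a ternary value needs ~2 bits vs. 32 bits for FP32,
# which is where the "up to 16x less GPU memory" figure comes from.
w = torch.randn(1024, 1024)
w_q, alpha = ternarize(w)
print(torch.unique(w_q).numel(), alpha)                       # at most 3 distinct values
```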
Limitations of Ternary GLMs under PTQ and QAT:
Significant accuracy loss with Post-Training Quantization (PTQ)
Little analysis of GLM-specific characteristics in QAT with Knowledge Distillation (KD)
Approach: Logit KD for Cumulative Quantization Error
KD Methods for Encoder and Decoder QAT:
Layer-to-Layer (L2L) KD: proposed for encoder QAT; matches intermediate-layer outputs with an MSE loss
Logit KD: conventional KD; matches final token-level logits with a soft cross-entropy loss (see the sketch below)
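For reference, a minimal sketch of the two distillation losses, assuming hidden states and logits are already gathered from both models; the temperature knob and the averaging choices are assumptions, and the paper's exact loss weighting may differ.

```python
import torch
import torch.nn.functional as F

def l2l_kd_loss(student_hiddens, teacher_hiddens):
    """Layer-to-layer (L2L) KD: MSE between intermediate hidden states of the
    quantized student and the full-precision teacher, summed over layers."""
    return sum(F.mse_loss(s, t) for s, t in zip(student_hiddens, teacher_hiddens))

def logit_kd_loss(student_logits, teacher_logits, temperature: float = 1.0):
    """Logit KD: soft cross-entropy between teacher and student output
    distributions at every token position (temperature is an assumed knob)."""
    log_p_s = F.log_softmax(student_logits / temperature, dim=-1)
    p_t = F.softmax(teacher_logits / temperature, dim=-1)
    soft_ce = -(p_t * log_p_s).sum(dim=-1)   # per-token soft cross-entropy
    return soft_ce.mean()                    # averaged over batch and sequence
```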
Cumulative Quantization Errors in Causal Attention:
Causal attention masks future tokens in self-attention, so each token's output depends only on preceding tokens
Quantization errors therefore accumulate toward later tokens in the sequence (illustrated in the sketch below)
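The accumulation effect can be illustrated with a toy experiment: run the same input through a full-precision causal-attention stack and a copy whose weights carry small perturbations (a stand-in for quantization error), then compare per-token outputs. The layer sizes, depth, and noise level below are arbitrary assumptions for illustration, not the paper's setup.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, D, LAYERS, NOISE = 64, 32, 4, 0.02       # toy sizes and noise level (assumptions)

def causal_attention(x, wq, wk, wv):
    """Single-head self-attention with a causal mask: token t attends only to tokens <= t."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = (q @ k.transpose(-1, -2)) / D ** 0.5
    mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
    scores = scores.masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

x = torch.randn(T, D)
w = [torch.randn(D, D) / D ** 0.5 for _ in range(3)]
w_noisy = [wi + NOISE * torch.randn_like(wi) for wi in w]   # stand-in for quantization error

y_fp, y_q = x, x
for _ in range(LAYERS):                     # small residual attention stack
    y_fp = y_fp + causal_attention(y_fp, *w)
    y_q = y_q + causal_attention(y_q, *w_noisy)

per_token_err = (y_fp - y_q).norm(dim=-1)   # per-token deviation of the perturbed stack
print(per_token_err[:4])                    # early tokens
print(per_token_err[-4:])                   # later tokens, where error tends to be larger
```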
Logit Distillation for Cumulative Quantization Error:
L2L KD forces the student to follow intermediate-layer outputs but fails to reproduce the final logit distribution
Logit KD lets intermediate outputs adjust freely while reproducing the final logit distribution
Logit KD thus aligns with the decoder's causal structure, counteracts cumulative quantization error, and outperforms L2L KD in QAT
Method: Token-Scaled Logit Distillation (TSLD) for Superior Learning in KD
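As described in the abstract, TSLD re-weights the logit distillation loss token by token to prevent overfitting while learning from both the teacher and the ground truth. The sketch below is one possible reading, assuming the per-token scale is derived from the teacher's confidence on the ground-truth token and normalized over the sequence; see the paper and https://github.com/aiha-lab/TSLD for the exact formulation.

```python
import torch
import torch.nn.functional as F

def tsld_loss(student_logits, teacher_logits, targets):
    """Sketch of token-scaled logit distillation (TSLD).

    Assumption: each token's soft cross-entropy term is re-weighted by a scale
    derived from the teacher's confidence on the ground-truth token, so tokens
    the teacher predicts poorly contribute less and overfitting to them is reduced.
    student_logits, teacher_logits: (T, V); targets: (T,) ground-truth token ids.
    """
    p_teacher = F.softmax(teacher_logits, dim=-1)                    # (T, V)
    conf = p_teacher.gather(-1, targets.unsqueeze(-1)).squeeze(-1)   # teacher confidence per token
    scale = conf / conf.sum()                                        # token-wise scale, sums to 1
    log_p_student = F.log_softmax(student_logits, dim=-1)
    soft_ce = -(p_teacher * log_p_student).sum(dim=-1)               # per-token logit KD term
    return (scale * soft_ce).sum()
```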
Experiments & Results: Best Performance in Language Modeling and Reasoning Accuracy