LoRA Fine Tuning
posted in 2024
(I) Objective: The primary objective is to address the high computational cost and limited hardware accessibility of full fine-tuning. LoRA is employed to update models and mitigate language drift under the following scenarios:
Temporal Drift: Language evolves over time.
Domain Drift: Language varies across different domains (e.g., political vs. technical).
Geographical Drift: Terms specific to one region may not be understood or applicable in another.
(II) Introduction to LoRA (Low-Rank Adaptation):
LoRA works by freezing the original model's weight matrix W and updating only two much smaller matrices obtained from a low-rank decomposition of the delta weight matrix, ΔW = B*A. (Fig. 1)
LoRA weights are initialized as follows:
Matrix A is initialized using Kaiming-uniform distribution (line 157 in Fig. 1).
Matrix B is initialized with zeros (line 162 in Fig. 1), so ΔW = B*A is zero at initialization and the adapted layer starts out identical to the original layer.
The input fed to W is also passed through A and then B, and the output of B*A is added to the output of the original matrix W. (line 582 in Fig. 1)
After training, the LoRA weights are merged into the base model (W ← W + ΔW) for deployment; a minimal sketch of the forward pass and the merge step follows Fig. 1.
reference code: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py
(Fig. 1) LoRA introduction and reference code.
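Below is a minimal sketch of a LoRA-wrapped linear layer, assuming PyTorch and following the initialization and forward-pass conventions described above. It is illustrative only and simplified relative to the PEFT reference code (no dropout, a single adapter, and an assumed α/r scaling of ΔW):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA wrapper around a frozen nn.Linear (illustrative sketch)."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)          # freeze the original W
        # A: (r, in_features) with Kaiming-uniform init; B: (out_features, r) with zeros
        self.lora_A = nn.Parameter(torch.empty(r, base.in_features))
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)
        self.scaling = alpha / r                        # assumed α/r scaling of ΔW

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # output = W·x + (B·A)·x · α/r ; ΔW = B·A is zero at initialization
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

    @torch.no_grad()
    def merge(self) -> None:
        # fold the learned update into the frozen weight for deployment: W ← W + α/r·B·A
        self.base.weight += self.scaling * (self.lora_B @ self.lora_A)

# usage: wrap an existing projection layer, train only lora_A / lora_B, then merge()
layer = LoRALinear(nn.Linear(768, 768))
y = layer(torch.randn(2, 768))
```

The rank r and alpha values above are placeholder hyperparameters for illustration, not the settings used in this write-up.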
(III) Introduction to QLoRA (Quantized Low-Rank Adaptation): (Fig. 2)
1. 4-bit NormalFloat (NF4): Conventional quantization places bin boundaries uniformly, so a few outliers stretch the quantization range and leave many bins nearly unused. NF4 instead chooses bin boundaries so that each quantization bin receives an equal (expected) number of values from the input tensor. This ensures all bins are utilized effectively, maximizing the information retained per bit.
2. Double Quantization: This technique involves quantizing the quantization constants of the model, providing additional memory savings.
3. Paged Optimizers: Use NVIDIA unified memory to handle automatic page-to-page data transfers between CPU and GPU, preventing out-of-memory crashes even when GPU memory spikes. Optimizer states are allocated in paged memory and automatically evicted to CPU RAM when GPU memory runs low, then paged back into GPU memory as needed during the optimizer update step, ensuring smooth and efficient memory management. (A configuration sketch covering all three techniques follows Fig. 2.)
reference: https://wandb.ai/sauravmaheshkar/QLoRA/reports/What-is-QLoRA---Vmlldzo2MTI2OTc5
(Fig. 2) QLoRA introduction.
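As a concrete illustration, the snippet below sketches how these three techniques are typically enabled with the Hugging Face transformers/peft/bitsandbytes stack. The checkpoint name, target modules, and hyperparameters are placeholder assumptions, not the exact settings behind the results in Table 1:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# (1) NF4 storage with (2) double quantization; compute runs in a higher-precision dtype
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # 4-bit NormalFloat
    bnb_4bit_use_double_quant=True,        # quantize the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-large-v3-turbo",       # placeholder checkpoint
    quantization_config=bnb_config,
)
model = prepare_model_for_kbit_training(model)

# attach LoRA adapters to the attention projections (placeholder hyperparameters)
model = get_peft_model(model, LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05, target_modules=["q_proj", "v_proj"],
))

# (3) paged optimizer: optimizer states live in paged memory and spill to CPU RAM
training_args = TrainingArguments(output_dir="whisper-qlora", optim="paged_adamw_8bit")
```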
(IV) Introduction to Quantization:
In absolute-max (absmax) quantization, each FP16 value is mapped to its corresponding INT8 value by scaling with 127 / max(|X|) and rounding; dividing by the same scale recovers an approximation of the original value. (Fig. 3a)
Matrix multiplication is computed as a three-step process (Fig. 3b; a sketch follows the figure captions):
Extract outliers (i.e., values exceeding a predefined threshold) column-wise from the input hidden states.
Perform matrix multiplication using FP16 for outliers and INT8 for non-outliers.
Dequantize the non-outlier results and sum them with the outlier results to obtain the final output in FP16.
(Fig. 3a) Absolute max quantization.
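A small numerical sketch of absolute-max quantization (scale, round, dequantize), assuming NumPy; the values are illustrative:

```python
import numpy as np

def absmax_quantize(x: np.ndarray):
    """Symmetric absolute-max quantization of an FP16/FP32 tensor to INT8."""
    scale = 127.0 / np.max(np.abs(x))            # map the largest magnitude to ±127
    q = np.round(x * scale).astype(np.int8)      # quantize
    return q, scale

def absmax_dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) / scale          # approximate reconstruction

x = np.array([1.2, -0.4, 3.0, -2.7], dtype=np.float32)
q, scale = absmax_quantize(x)                    # q = [ 51 -17 127 -114 ]
x_hat = absmax_dequantize(q, scale)              # ≈ [ 1.20 -0.40  3.00 -2.69 ]
```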
(Fig. 3b) INT8 matrix multiplication.
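The three-step mixed-precision multiplication can be sketched as follows, assuming NumPy and per-tensor absmax scaling for brevity (the underlying LLM.int8()-style kernel uses vector-wise scaling; the 6.0 threshold is its commonly cited default). This is an illustration, not the production kernel:

```python
import numpy as np

def mixed_int8_matmul(X: np.ndarray, W: np.ndarray, threshold: float = 6.0) -> np.ndarray:
    """Sketch of the three-step mixed-precision multiply X @ W."""
    # 1) extract outlier columns of the hidden states (any |value| above the threshold)
    outlier_cols = np.any(np.abs(X) > threshold, axis=0)

    # 2a) outlier part stays in floating point
    out_fp = X[:, outlier_cols].astype(np.float32) @ W[outlier_cols, :].astype(np.float32)

    # 2b) non-outlier part in INT8 with absmax scales and integer accumulation
    Xn, Wn = X[:, ~outlier_cols], W[~outlier_cols, :]
    sx = 127.0 / np.max(np.abs(Xn)) if Xn.size else 1.0
    sw = 127.0 / np.max(np.abs(Wn)) if Wn.size else 1.0
    Xq = np.round(Xn * sx).astype(np.int8)
    Wq = np.round(Wn * sw).astype(np.int8)
    out_int8 = Xq.astype(np.int32) @ Wq.astype(np.int32)

    # 3) dequantize the INT8 result and add the outlier result to get the FP output
    return out_int8.astype(np.float32) / (sx * sw) + out_fp

X = np.random.randn(4, 16).astype(np.float32)
X[:, 3] *= 10.0                                   # make one column an outlier column
W = np.random.randn(16, 8).astype(np.float32)
Y = mixed_int8_matmul(X, W)                       # ≈ X @ W
```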
(V) Result: (Table. 1)
Training datasets: A diverse set of publicly available speech corpora was used, including:
English Datasets: Common Voice, Fleurs, LibriSpeech, SPGISpeech, GigaSpeech.
Chinese Datasets: Common Voice, Aishell1, Aishell2, MAGICDATA Mandarin Chinese Read Speech Corpus, MagicData RAMC, Primewords Chinese Corpus Set 1, aidatatang_200zh, THCHS-30, TALCS, zhvoice, WeNetSpeech.
Evaluation metrics are presented in Table 1 (a short example of computing them follows the table), showing:
Character Error Rate (CER) for the Chinese testing datasets.
Word Error Rate (WER) for the English testing datasets.
Official Whisper models (tiny, base, small, medium, largeV2, largeV3) are highlighted with an orange background, while the QLoRA fine-tuned large-v3-turbo model is highlighted with a green background.
The QLoRA fine-tuned large-v3-turbo model achieves lower Character Error Rate (CER) and Word Error Rate (WER) than the official Whisper large-v3-turbo model on all testing datasets.
(Table. 1) Character Error Rate (CER) and Word Error Rate (WER) are metrics where lower values indicate better performance.
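For reference, both metrics can be computed with the Hugging Face evaluate library as sketched below; the transcripts shown are placeholders, not data from Table 1:

```python
import evaluate

# lower is better for both metrics (Table 1)
wer_metric = evaluate.load("wer")   # word error rate, reported for the English test sets
cer_metric = evaluate.load("cer")   # character error rate, reported for the Chinese test sets

predictions = ["the quick brown fox"]           # placeholder model transcripts
references  = ["the quick brown fox jumps"]     # placeholder ground-truth transcripts

print("WER:", wer_metric.compute(predictions=predictions, references=references))
print("CER:", cer_metric.compute(predictions=predictions, references=references))
```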