從底層實作長短期記憶模型
從底層實作長短期記憶模型
本研究旨在脫離現成模型的依賴,自製長短期記憶模型(Long Short-Term Memory,LSTM)。在過程中,我們面臨過擬合與模型無法穩並收斂的問題,透過L2正則化、學習率衰減、早停機制、批次化等方法解決。最後與現有模型進行比較,「兩者在預測準確度相差0.07倍,但訓練時間差距1.5倍」,顯示自製模型在學習過程的價值,但仍需優化性能。
專題老師 : 林哲正 教授 專題學生 : 余信陞、黃柏誠、蕭傳原
架構與方法
長短期記憶模型LSTM 是一種改良的RNN,設計來解決時間序列數據中的梯度消失問題,和RNN相比多了一條長期記憶通道,因此在時間序列數據的預測上可以有較好的表現。
LSTM cell的架構如右圖所示,Ct、ht、Xt分別代表長期記憶、短期記憶、輸入,通過遺忘門和進入門將ht、Xt去更新長期記憶。最後進入輸出門將長期記憶加入短期記憶並繼續遞送給下個的cell。
LSTM 的訓練流程 為利用長短期記憶及當下時間點的輸入不斷向前計算產出一個預測數字,接著「比較預測和真實的數字得出梯度偏差,並利用公式回推weights和bias的偏差值再做調整」。右圖為LSTM訓練的流程。
問題與解決
a. 損失函數波動大(Loss Oscillation)
損失函數波動大是指在模型訓練過程中, 損失函數的數值頻繁且劇烈地上下變動,而不是穩定下降的情況。這種現象會影響模型的收斂效率, 甚至導致訓練失敗。
1.批次化(Batching): 是將資料分成一個個批次來進行訓練, 然後累積梯度後再一次全部更新權重, 從而提高訓練效率和計算穩定性。
2.學習率衰減 : 隨著訓練過程的進行, 逐漸減小學習率的策略。目的是在訓練後期減少模型參數的更新幅度, 從而避免在最小值附近振盪並幫助模型更好地收斂。
b. 過擬合(Overfitting)
模型在訓練資料上表現得非常好, 但在新的、未見過的資料上表現不佳的情況。這通常是因為模型過度學習了訓練資料中的噪音或偶然性特徵, 而不是學習到資料的真實模式或規律。
1. L2 正規化 ( Regularization ) : 是在更新權重時額外增加一個與權重值大小成正比的懲罰項,以提升模型的泛化能力並減少過擬合的風險。
2. 早停機制(Early Stopping ): 如果驗證集損失開始上升且已達耐心值門檻就會提前停止訓練並用最佳值時的權重作為模型權重。
研究結果
1.程式碼架構 :
2.損失函數圖表分析 :
與現有的框架效能比較:
準確度
約1 . 0 7 倍
訓練時間
約1 . 5 倍
開發工具