4D ASR: Integration of CTC/attention/RNN-T/Mask-CTC (2022–2023)

4D ASR: Joint modeling of CTC, Attention, Transducer, and Mask-Predict decoders

End-to-end (E2E) automatic speech recognition (ASR) models can be classified into several architectures, including connectionist temporal classification (CTC), recurrent neural network transducer (RNN-T), attention-based encoder-decoder, and mask-predict models.

There are pros and cons to each of these architectures, and thus practitioners may switch between these different models depending on application requirements.

Instead of building separate models, we propose a joint modeling scheme where four different decoders (CTC, attention, RNN-T, mask-predict) share an encoder -- we refer to this as 4D modeling.
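As a rough illustration of this architecture, the PyTorch-style sketch below shows a single shared encoder feeding four decoder branches whose losses are combined into one weighted training objective. All class and module names (FourDModel, ctc_head, att_decoder, rnnt_decoder, mlm_decoder) and their call signatures are illustrative placeholders, not the actual implementation.

```python
# Rough sketch (not the authors' implementation): one shared encoder feeding
# four decoder branches, trained with a weighted sum of their losses.
# All module names and call signatures here are illustrative placeholders.
import torch.nn as nn

class FourDModel(nn.Module):
    def __init__(self, encoder, ctc_head, att_decoder, rnnt_decoder, mlm_decoder):
        super().__init__()
        self.encoder = encoder            # shared by all four decoders
        self.ctc_head = ctc_head          # CTC projection + CTC loss
        self.att_decoder = att_decoder    # attention-based (label-synchronous) decoder
        self.rnnt_decoder = rnnt_decoder  # RNN-T prediction network + joint network
        self.mlm_decoder = mlm_decoder    # mask-predict (Mask-CTC-style) decoder

    def forward(self, speech, speech_lens, text, text_lens, weights):
        # A single encoder pass is reused by every decoder branch.
        enc, enc_lens = self.encoder(speech, speech_lens)
        losses = {
            "ctc": self.ctc_head(enc, enc_lens, text, text_lens),
            "att": self.att_decoder(enc, enc_lens, text, text_lens),
            "rnnt": self.rnnt_decoder(enc, enc_lens, text, text_lens),
            "mlm": self.mlm_decoder(enc, enc_lens, text, text_lens),
        }
        # Weighted multitask objective, e.g. weights = {"ctc": 0.25, "att": 0.25,
        # "rnnt": 0.25, "mlm": 0.25} in the first training stage.
        total = sum(weights[k] * losses[k] for k in losses)
        return total, losses
```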

Additionally, we propose to 1) train 4D models using a two-stage strategy which stabilizes multitask learning and 2) decode 4D models using a novel time-synchronous one-pass beam search.

We demonstrate that jointly trained 4D models improve the performance of each individual decoder. Further, we show that our joint CTC/RNN-T/attention decoding surpasses the previously proposed CTC/attention decoding.

1) Training weights are usually determined experimentally or based on meta-learning.  In this work, with four weights, experimenting with all possible combinations would be overly time-consuming.  To address this issue, we used a two-stage optimization strategy to determine the multitask weights. 

In the first stage, all four training weights were set to be equal, i.e., (0.25, 0.25, 0.25, 0.25). Then, in the second stage, the training weights were set roughly proportional to the epoch in the first stage at which each validation loss reached its minimum value. This strategy is based on the premise that losses requiring more epochs to converge should be given higher weights.
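A minimal sketch of this second-stage weight rule, assuming only that we know the epoch at which each validation loss reached its minimum during the equal-weight first stage; the epoch numbers in the example are hypothetical, not values reported in the paper.

```python
# Each weight is made proportional to the epoch at which that decoder's
# validation loss reached its minimum during the equal-weight first stage.
def second_stage_weights(best_epoch):
    """best_epoch: dict mapping decoder name -> epoch of minimum validation loss."""
    total = sum(best_epoch.values())
    return {name: epoch / total for name, epoch in best_epoch.items()}

# Hypothetical example: the MLM loss converges last, so it gets the largest weight.
weights = second_stage_weights({"ctc": 30, "att": 35, "rnnt": 40, "mlm": 55})
print(weights)  # {'ctc': 0.1875, 'att': 0.21875, 'rnnt': 0.25, 'mlm': 0.34375}
```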

2) Another key contribution of this paper is two one-pass joint CTC/RNN-T/attention decoding algorithms based on time-synchronous beam search, including an RNN-T-driven algorithm. The RNN-T-driven method uses RNN-T as the primary decoder to generate hypotheses, which are then scored by combining the CTC, attention, and RNN-T decoders.
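The following is a highly simplified sketch of the RNN-T-driven joint scoring idea, not the exact beam-search algorithm of the paper; rnnt_expand and the scorer callables for CTC, attention, and RNN-T are assumed placeholder interfaces standing in for the real decoders.

```python
# Simplified sketch: RNN-T proposes hypothesis expansions frame by frame, and
# surviving hypotheses are selected with a combined CTC/attention/RNN-T score.
def joint_score(hyp, enc, scorers, lams):
    # Weighted combination of log-probability scores from the three decoders.
    return (lams["rnnt"] * scorers["rnnt"](hyp, enc)
            + lams["ctc"] * scorers["ctc"](hyp, enc)
            + lams["att"] * scorers["att"](hyp, enc))

def rnnt_driven_beam_search(enc, scorers, lams, beam_size, rnnt_expand):
    hyps = [()]  # start from the empty hypothesis (tuple of token ids)
    for t in range(len(enc)):  # time-synchronous loop over encoder frames
        candidates = set()
        for hyp in hyps:
            # The RNN-T decoder proposes expansions of `hyp` at frame t
            # (including the "no label emitted" case).
            candidates.update(rnnt_expand(hyp, enc, t, beam_size))
        # Prune with the combined CTC/attention/RNN-T score.
        hyps = sorted(candidates,
                      key=lambda h: joint_score(h, enc, scorers, lams),
                      reverse=True)[:beam_size]
    return hyps[0]
```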

The left figure shows the normalized validation losses in the first and second stages. 

The MLM loss took more epochs to converge than the other three losses in the first stage, indicating that the model had not converged sufficiently for the MLM decoder, or that the other decoders had overfitted. The difference in convergence speed among the four losses, on the other hand, was smaller in the second stage, indicating that all four losses converged relatively adequately.


The table shows the performance of each decoder on the LibriSpeech 100 h test-other set, without joint training and with joint training using the first- and second-stage weights.

Even the first stage outperformed the model trained without multitask learning, and the performance of all four decoders improved further in the second stage.

Using the proposed two-stage approach, the four weights were efficiently determined with only two experimental trials.


The right figure shows the relationship between real-time factor (RTF) using a GPU (NVIDIA RTX 3090) and WER on the LibriSpeech 100 h test-other set.

The red dots denote the baselines, the blue dots denote 4D joint training without joint decoding, and the green dots denote 4D joint training with joint decoding. Comparing the red and blue dots, the proposed 4D model reduced WER for all decoders without increasing RTF; that is, the 4D model adds no computational cost as long as a single decoder is used at inference time. The two proposed joint decoding methods had larger RTFs than the single decoders due to their increased complexity, but achieved the smallest WERs.
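For reference, RTF here is the decoding wall-clock time divided by the duration of the decoded audio; a minimal measurement sketch is shown below, where decode_fn is a placeholder for any of the decoders above.

```python
# Minimal sketch of how RTF is computed: wall-clock decoding time divided by
# the duration of the decoded audio. `decode_fn` is a placeholder decoder.
import time

def measure_rtf(decode_fn, speech, audio_duration_sec):
    start = time.perf_counter()
    hyp = decode_fn(speech)
    elapsed = time.perf_counter() - start
    return hyp, elapsed / audio_duration_sec  # RTF < 1.0: faster than real time
```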

International conference / Peer-reviewed conference paper