Left ventricular ejection fraction (LVEF), the proportion of blood pumped out of the left ventricle with each contraction, is a cornerstone metric for assessing cardiac function and the risk of heart failure. It is typically measured through echocardiography, which, albeit accurate, is resource-intensive and time-consuming. One of our current thesis projects explores the possibility of providing a viable alternative for estimating LVEF via 12-lead electrocardiogram (EKG) data. EKG machines are much more widely accessible as screening tools of cardiac health. In primary care centers and emergency units, however, there are often no specialists to analyze the EKG tracings, spurring scientists to develop AI-aided models that automate the interpretation of EKG data.
Related works using EKG data in prediction models have been limited to binary outcome diagnoses, primarily employing convolutional neural networks (CNNs) trained on EKGs as image inputs. In our prior work, we developed a Machine Learning (ML) model that predicts one of four classes: normal, mildly, moderately, and severely reduced LVEF. We extracted a tabular representation of the data from real-world EKG time-series and achieved competitive performance with an XGBoost classifier that provides a much more granular, and therefore clinically more meaningful, estimation of the ejection fraction.
In our final project for this class, we explore the role of representation learning in analyzing EKG time-series data, framing it as a unifying theme across key dimensions:
Feature transformations vs. time-series embeddings: Traditional ML models rely on handcrafted tabular features, while DL models such as CNNs and transformers learn hierarchical representations directly from raw signals.
Lead selection as representation choice: Each of the 12 EKG leads offers a unique "view" of the heart’s electrical activity, effectively acting as a different representation of the same underlying physiological process.
Regarding the first point, our prior ML methodology focused on extracting temporal, statistical, and frequency-domain characteristics using the Python package tsfresh [7], casting the EKG time-series into a rich feature space suitable for gradient boosted models like XGBoost. For this study, we hypothesize that DL models, in particular CNN- and transformer-based, will outperform ML models on our dataset for 4-class classification of the LVEF. Indeed, leading studies [1]-[5] have achieved notable results using CNN/ResNet-based architectures trained on extensive EKG datasets. They primarily perform binary prediction of cardiovascular (CV) conditions, some readily assessable from EKGs (e.g., arrhythmias), others less so (e.g., heart failure). The neural net in [3] operates on the EKG time-series from a frequency domain perspective via Short-Time Fourier Transform, which we chose not to build upon due to the similarity to our tsfresh features in the ML approach. A more recent study adapts the generative pre-trained transformer (GPT) architecture to predict the next time-series point in one-dimensional, periodic physiological signals, resulting in two general-purpose pre-trained models, dubbed "ECG-PT" and "PPG-PT" respectively [6]. The authors claim that if these GPT models can learn a general understanding of the heart signals, then they will be straightforward to fine-tune for diagnosis of heart conditions.
Given our background primarily in ML, we adopt a more pedagogical approach in this final project with the aim of deepening our understanding and contrasting how DL architectures work in practice, rather than trying to surpass existing benchmarks. Building on the growing body of research applying DL to cardiac data, we investigate the efficacy of convolutional and residual networks and pre-trained models like ECG-PT on multi-class classification for a more nuanced assessment of CV conditions, specifically those not easily inferred from EKGs alone, such as LVEF. Because no prior study has published multi-class results for LVEF, we do not benchmark against the publications cited above.
The choice of leads to include in model training is another aspect of divergence among existing publications: [6] uses Lead I only, while [3] identified a subset of leads to be particularly impactful for the prediction of heart failure. This also informs a key question: can a single lead provide sufficient information for accurate classification? If yes, then this means that wearables such as Apple Watch, which measure Lead I, have the potential to be employed as diagnostic tools.
By centering representation learning as the unifying thread, we:
Compare ML (feature-based) against DL (embedding-based representation) models for 4-class LVEF classification, concretely our existing XGBoost model against a CNN and a fine-tuned ECG-PT model.
Investigate whether a single lead, mainly Lead I, offers sufficient representation of the heart's activity or whether multi-lead models offer critical advantages.
Utilize common interpretability tools for DL to connect model decisions to clinical insights, evaluating whether learned representations align with meaningful EKG features.
Continuing from our prior study, we collaborate with a team of cardiologists from Hartford Healthcare to gather our own patient cohort from the hospital system, targeting outpatients who had both an EKG and an echocardiogram within 2 weeks of each other. The echocardiogram provides the ground-truth LVEF target, which we aim to predict using the EKG recorded closest in time, with a granularity of 4 severity levels: normal (LVEF ≥50%), mildly (LVEF 40-49%), moderately (LVEF 30-39%), or severely (LVEF <30%) reduced LVEF. The true class distribution is heavily imbalanced: the normal class (class 3) accounts for 89.27%, the mildly reduced (class 2) for about 5.56%, the moderately reduced (class 1) for about 4.13%, and the severely reduced (class 0) for only 1.03% of the total 39,065 EKG samples. To account for this imbalance, our main performance metric is the per-class AUC (one-vs-rest), as other scores like accuracy and F1 depend heavily on the chosen probability thresholds.
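The severity bands above translate directly into a labelling function. The following is a minimal sketch: the name `lvef_to_class` is hypothetical, and treating the bands as half-open intervals (so that a fractional value such as 49.5% falls into class 2) is an assumption, since the text only lists integer ranges.

```python
def lvef_to_class(lvef_percent: float) -> int:
    """Map an echocardiographic LVEF (in percent) to one of the 4 severity classes.

    Class ids follow the text: 3 = normal, 2 = mildly, 1 = moderately,
    0 = severely reduced. Half-open bands are an assumption of this sketch.
    """
    if not 0.0 <= lvef_percent <= 100.0:
        raise ValueError(f"LVEF must be a percentage, got {lvef_percent}")
    if lvef_percent >= 50:
        return 3
    if lvef_percent >= 40:
        return 2
    if lvef_percent >= 30:
        return 1
    return 0


CLASS_NAMES = {
    3: "normal",
    2: "mildly reduced",
    1: "moderately reduced",
    0: "severely reduced",
}

for value in (62, 45, 35, 20):
    print(value, "->", CLASS_NAMES[lvef_to_class(value)])
```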
The 39,065 EKG records are initially loaded as 10-second sequences of 5000 samples (i.e., a sampling frequency of 500 Hz) and indexed by the 8 standard leads 'I', 'II', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', from which we derive the remaining 4 leads 'aVR', 'aVL', 'aVF', and 'III' via linear combinations. For the subset of 5 leads, we chose 3 lateral leads (I, V5, V6) and 2 anterior leads (V3, V4) [3]. Lead I, given its applicability in wearables, is selected for the single-lead experiments.
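The linear combinations mentioned above are the standard Einthoven and Goldberger relations, which need only leads I and II. The sketch below illustrates them on synthetic sine waves standing in for real tracings. The function name `derive_limb_leads` and the placeholder signals are assumptions, and our actual loading code may organize the arrays differently.

```python
import numpy as np


def derive_limb_leads(lead_i: np.ndarray, lead_ii: np.ndarray) -> dict:
    """Derive III, aVR, aVL and aVF from leads I and II."""
    return {
        "III": lead_ii - lead_i,           # Einthoven: II = I + III
        "aVR": -(lead_i + lead_ii) / 2.0,  # Goldberger augmented leads
        "aVL": lead_i - lead_ii / 2.0,
        "aVF": lead_ii - lead_i / 2.0,
    }


fs, duration_s = 500, 10                   # 500 Hz for 10 s -> 5000 samples
t = np.arange(fs * duration_s) / fs
lead_i = 0.5 * np.sin(2 * np.pi * 1.2 * t)          # synthetic placeholder signal
lead_ii = 0.8 * np.sin(2 * np.pi * 1.2 * t + 0.3)   # synthetic placeholder signal

derived = derive_limb_leads(lead_i, lead_ii)
print({name: signal.shape for name, signal in derived.items()})
```

A quick sanity check on any such derivation is that the three augmented leads sum to zero at every time step.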
To tune the model parameters and set aside unseen test data, we randomly split the EKGs into the same 85-7.5-7.5% train-validation-test cohorts for all models, grouping by patient identifier to avoid data leakage and stratifying along the class distribution.
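A patient-level, stratified split of this kind can be sketched as follows. This is an illustration rather than our exact procedure: the function `patient_level_split`, the toy records, and in particular the rule of stratifying each patient by their most severe label are assumptions made for the example.

```python
import random
from collections import defaultdict


def patient_level_split(records, fractions=(0.85, 0.075, 0.075), seed=0):
    """Split (record_id, patient_id, label) tuples into train/val/test by patient.

    All records of one patient land in the same split, which prevents leakage.
    Patients are stratified by their most severe label (lowest class id); this
    rule is an assumption, since one patient may have EKGs in several classes.
    """
    patient_label = {}
    for _, patient_id, label in records:
        patient_label[patient_id] = min(label, patient_label.get(patient_id, label))

    by_stratum = defaultdict(list)
    for patient_id, label in patient_label.items():
        by_stratum[label].append(patient_id)

    rng = random.Random(seed)
    assignment = {}
    for label in sorted(by_stratum):
        patients = sorted(by_stratum[label])
        rng.shuffle(patients)
        n_train = round(fractions[0] * len(patients))
        n_val = round(fractions[1] * len(patients))
        for i, patient_id in enumerate(patients):
            if i < n_train:
                assignment[patient_id] = "train"
            elif i < n_train + n_val:
                assignment[patient_id] = "val"
            else:
                assignment[patient_id] = "test"

    splits = {"train": [], "val": [], "test": []}
    for record_id, patient_id, _ in records:
        splits[assignment[patient_id]].append(record_id)
    return splits, assignment


# Toy cohort: 400 patients with 2 EKGs each, labels cycling through the 4 classes.
toy_records = [(f"ekg{p}_{k}", f"pat{p}", p % 4) for p in range(400) for k in range(2)]
splits, assignment = patient_level_split(toy_records)
print({name: len(ids) for name, ids in splits.items()})
```

Fixing the seed is what lets every model see the same three cohorts.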
Our CNN/ResNet model, inspired by [1], leverages the multi-dimensionality of our EKG time-series, stacking the individual leads into tensors of shape n_batches x n_leads x 5000.
Below is a brief overview of our CNN architecture (see also block diagram below):
Initial Conv1D layer with ReLU activation and batch normalization.
3 Residual Units (ResBlk):
Each Residual Unit includes 2 convolutional layers, batch normalization (BN), ReLU activation, and dropout.
Shortcut connections adaptively downsample inputs for dimensionality alignment.
Global Average Pooling (GAP) following the ResBlks to reduce feature maps.
Fully connected layer with SoftMax activation outputs probabilities for the 4 classes.
Cross-entropy loss function with Adam optimizer and class-specific weights to address class imbalance in the dataset. (This ensures that minority classes contribute more significantly to the loss calculation, encouraging more balanced model predictions.)
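To make the effect of the class weights in the last item concrete, the sketch below computes a weighted cross-entropy by hand. The inverse-frequency weighting scheme, the function names, and the toy logits are assumptions for illustration; the text does not state how our weights were derived. The class shares are those reported for our cohort.

```python
import math

# Class shares reported in the text for classes 0..3.
CLASS_SHARE = [0.0103, 0.0413, 0.0556, 0.8927]


def inverse_frequency_weights(shares):
    """Weights proportional to 1/share, rescaled to average 1 (assumed scheme)."""
    raw = [1.0 / s for s in shares]
    mean_raw = sum(raw) / len(raw)
    return [r / mean_raw for r in raw]


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_cross_entropy(batch_logits, targets, weights):
    """Weighted mean of the per-sample negative log-likelihoods."""
    numerator, denominator = 0.0, 0.0
    for logits, y in zip(batch_logits, targets):
        probs = softmax(logits)
        numerator += weights[y] * -math.log(probs[y])
        denominator += weights[y]
    return numerator / denominator


weights = inverse_frequency_weights(CLASS_SHARE)
logits = [[0.1, 0.2, 0.3, 2.0], [0.1, 0.2, 0.3, 2.0]]  # model favours class 3 twice

loss_all_normal = weighted_cross_entropy(logits, [3, 3], weights)
loss_missed_severe = weighted_cross_entropy(logits, [3, 0], weights)
print(f"{loss_all_normal:.3f} vs {loss_missed_severe:.3f}")
```

With these weights, a missed severe case dominates the batch loss, which is the behaviour the weighting is meant to encourage.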
Key hyper-parameters were tuned in a first round across subsets of 5000 EKG training samples for 1-lead, 5-lead, and 12-lead EKGs separately, to obtain the best learning rate (0.0001), batch size (32), dropout rate (0.2), and kernel size (5). Using these parameter settings, we refit the model with early stopping to prevent overfitting (triggered at the 9th epoch). The final model, based on validation accuracy, is trained on the full training and validation cohorts. Bootstrapping resamples the test data 100 times with replacement to construct the 95% confidence intervals for the AUC scores.
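The bootstrap procedure for the AUC confidence intervals can be sketched in plain Python as below. The helper names, the percentile method for the interval, and the choice to redraw resamples that contain only one class are assumptions of this sketch. The pairwise AUC is quadratic in the sample count, so a rank-based implementation would be preferable on the full test cohort.

```python
import random


def binary_auc(labels, scores):
    """AUC as the probability that a positive outranks a negative (ties count 0.5)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return None  # undefined when one of the two groups is missing
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def bootstrap_auc_ci(y_true, class_probs, target_class, n_boot=100, alpha=0.05, seed=0):
    """Point estimate and percentile CI of the one-vs-rest AUC for one class."""
    labels = [int(y == target_class) for y in y_true]
    point = binary_auc(labels, class_probs)
    if point is None:
        raise ValueError("need both the target class and at least one other class")
    rng = random.Random(seed)
    n = len(labels)
    aucs = []
    while len(aucs) < n_boot:
        idx = [rng.randrange(n) for _ in range(n)]  # sample with replacement
        auc = binary_auc([labels[i] for i in idx], [class_probs[i] for i in idx])
        if auc is not None:  # redraw resamples that miss a class
            aucs.append(auc)
    aucs.sort()
    lower = aucs[int((alpha / 2) * (n_boot - 1))]
    upper = aucs[int(round((1 - alpha / 2) * (n_boot - 1)))]
    return point, (lower, upper)


# Toy example: 20 severe cases (class 0) among 200 samples, noisy class-0 scores.
rng = random.Random(1)
y_true = [0] * 20 + [3] * 180
p_class0 = [rng.gauss(0.6, 0.2) if y == 0 else rng.gauss(0.4, 0.2) for y in y_true]
auc, (lo, hi) = bootstrap_auc_ci(y_true, p_class0, target_class=0)
print(f"class 0 AUC = {auc:.3f} [{lo:.3f}, {hi:.3f}]")
```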
ECG-PT
For the transformer-based approach, we extended the ECG-PT architecture in [6] from single-lead input data to multi-lead EKG time-series. Following the pre-processing steps given by the authors, the signals are first downsampled (500 to 100 Hz) and truncated to a fixed sequence length (block_size = 500). After initial attempts on all 12 leads proved too resource-intensive, we restricted training to the subset of 5 leads and to only 10% of the training data for faster iteration.
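The downsampling and truncation step can be illustrated with a few lines of Python. This is a simplified sketch: keeping every 5th sample without an anti-aliasing filter and keeping the first 500 downsampled points are both assumptions, and the function name `preprocess_lead` is hypothetical. The pre-processing in [6] may filter the signal or select the window differently.

```python
def preprocess_lead(signal, fs_in=500, fs_out=100, block_size=500):
    """Downsample one lead by an integer factor and truncate to block_size samples."""
    if fs_in % fs_out != 0:
        raise ValueError("this sketch only handles integer decimation factors")
    step = fs_in // fs_out
    downsampled = signal[::step]      # keep every 5th sample: 5000 -> 1000 points
    if len(downsampled) < block_size:
        raise ValueError("signal too short for the requested block size")
    return downsampled[:block_size]   # first 500 points = 5 s at 100 Hz


raw_lead = list(range(5000))          # placeholder for a 10 s lead at 500 Hz
tokens_in = preprocess_lead(raw_lead)
print(len(tokens_in), tokens_in[:3])
```

Note that with block_size = 500 at 100 Hz, the model sees only 5 of the 10 recorded seconds per lead.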
Below is a brief overview of our transformer-derived architecture:
Lead-wise embedding layers generate token embeddings for sequence representation and position embeddings for temporal context.
Multi-head attention layers enable the model to focus on relevant portions of the input signal across different heads.
We loaded the pre-trained ECG-PT model (originally trained for next-time-point prediction) and fine-tuned it on our downstream task by freezing the first 7 layers and unfreezing the last transformer block, the final normalization layer, and the classification head (see block diagram below). Many of the remaining configurations, such as the cross-entropy loss with class weights, are similar to the CNN design described earlier. Our final learning rate is 3e-4, the batch size is 16, and training runs for at most 20 epochs.
Dynamic Confidence-Based Ensembling (CNN + XGBoost)
To combine the strengths of our CNN and XGBoost models, we provide an ensembling strategy: for each sample, the confidence, here defined as the maximum predicted probability for the predicted class, is calculated for both models. The final dynamic ensemble prediction for each sample is then taken from the model with the higher confidence, ensuring that we select the most confident and reliable prediction.
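The selection rule described above amounts to a per-sample comparison of maximum probabilities. A minimal sketch follows; the function name, the tie-break in favour of the CNN, and the toy probability vectors are assumptions for illustration.

```python
def dynamic_ensemble(probs_cnn, probs_xgb):
    """For each sample, keep the output of whichever model is more confident.

    Confidence is the maximum predicted class probability. Ties go to the CNN,
    which is an arbitrary choice made for this sketch.
    """
    chosen_probs, sources = [], []
    for p_cnn, p_xgb in zip(probs_cnn, probs_xgb):
        if max(p_cnn) >= max(p_xgb):
            chosen_probs.append(p_cnn)
            sources.append("cnn")
        else:
            chosen_probs.append(p_xgb)
            sources.append("xgb")
    predictions = [max(range(len(p)), key=p.__getitem__) for p in chosen_probs]
    return predictions, chosen_probs, sources


# Two toy samples with class probabilities for classes 0..3.
probs_cnn = [[0.10, 0.10, 0.10, 0.70], [0.40, 0.30, 0.20, 0.10]]
probs_xgb = [[0.25, 0.25, 0.30, 0.20], [0.05, 0.05, 0.85, 0.05]]

predictions, chosen_probs, sources = dynamic_ensemble(probs_cnn, probs_xgb)
print(predictions, sources)
```

Returning the full probability vector of the selected model, rather than only the predicted class, keeps the ensemble compatible with AUC-based evaluation.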
Note that the ECG-PT model is not included in the ensembling process. This decision was made because the ECG-PT model was trained on a small fraction of the available data due to computational constraints, and thus it would have introduced inconsistencies in data representation and coverage when aggregated with the other models.
Block diagram of the CNN/ResNet architecture from [1], which we modified only slightly.
Block diagram of the ECG-PT proposed by [6] where we unfreeze the final transformer block and fine-tune to our downstream task.
Learning curves of ECG-PT with peaks in the validation loss and accuracy indicating likely overfitting of the model. Early stopping was triggered after 8 epochs as the validation loss did not improve by at least 1e-4 for 5 consecutive epochs. We increased the dropout to 0.3 and used the AdamW optimizer that incorporates weight decay for better regularization.
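The early-stopping rule in this caption (patience of 5 epochs, minimum improvement of 1e-4) can be written as a small function. The sketch below uses a made-up validation-loss curve chosen so that the rule fires after epoch 8, as it did in our run; the loss values themselves are illustrative and not our measurements.

```python
def early_stopping_epoch(val_losses, patience=5, min_delta=1e-4):
    """Return the 1-based epoch at which training stops, or None if never triggered."""
    best = float("inf")
    stale_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if best - loss >= min_delta:   # counts as an improvement
            best = loss
            stale_epochs = 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                return epoch
    return None


# Illustrative curve: best loss at epoch 3, no sufficient improvement afterwards.
toy_val_losses = [1.00, 0.80, 0.70, 0.71, 0.72, 0.70, 0.73, 0.75, 0.60]
stop_epoch = early_stopping_epoch(toy_val_losses)
print(stop_epoch)  # -> 8
```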
In addition to the loss function weights, we also tuned the class-specific thresholds for probability predictions. The confusion matrix shows one example set of adjusted thresholds (0.1063, 0.1361, 0.1046, 0.5879 for classes 0 to 3) and illustrates how the CNN model tends to misclassify the abnormal LVEF test samples as normal due to heavy class imbalance. Such thresholds are usually defined after deliberation with clinicians over the practical trade-offs in real-world implementation.
Our per-class AUC results on the test set are shown in the comparison table with 95% confidence intervals in square brackets and we note:
The best performance across the board is achieved by the dynamic confidence-based ensembling model of CNN and XGBoost.
The CNN models outperform the XGBoost models on average, although the confidence intervals overlap.
The CNN models seem to offer an advantage in robustness, as their results have smaller variance (narrower confidence intervals) than the XGBoost models.
Using all 12 leads clearly improves model performance, whether the model is embedding-based (CNN) or feature-based (XGBoost). For instance, the XGBoost trained on 12 leads has an average increase of 5.22% with respect to the one trained on only Lead I.
A remark on the difference in AUC scores between the classes: the models tend to misclassify class 2 and class 1 samples as class 3, the majority class, despite our mitigation efforts. The most probable reason lies in the similarity between the EKG signals of these classes, which makes them hard for the model to tell apart. We hope to understand such confounding patterns via saliency or heat maps down the road, but this is not the focus of the project at hand.
Regrettably, due to computational limits, we could not train the ECG-PT on the full training and validation sets. This is likely the main reason why the model underperforms on the test cohort (see bar chart on the right). As such, we will focus our upcoming discussion on the CNN model for fair comparison of the learning methodology.
(Another potential reason for weak performance is that inputting signals of different sampling frequencies into the model pre-trained on ECG sampled at 100 Hz "will have a negative impact on the quality of model inference" [6].)
Both the ECG-PT and the XGBoost in the bar plot were trained on the same 10% subset, while the CNN was trained on the full set. The XGBoost performs slightly better than ECG-PT.
Using our empirical results, we can draw some conclusions to respond to our initial hypotheses, keeping in mind the limitations of our study. In the following, we aim to deepen our understanding of why the models showcase such results when applied on time-series cardiac data.
A few conjectures:
Long sequence length: With 5000 time steps per lead, patterns like R-R intervals (heart rate) or ST-segment changes span varying time scales. Convolutional layers can focus on local features, while residual connections allow for long-range temporal dependencies. Note that LSTM architectures are a well-known alternative for capturing such dependencies in time-series data. However, both [1] and [5] have experimented with including LSTM into their DNN without notable value-add in performance, hence we abstained from implementing LSTM.
Granularity of heart signal patterns: the GAP module in the CNN helps ensure that both high-level patterns (e.g. the overall morphology of PQRST waves) and specific EKG segments contribute equally to the final prediction. The shortcut connections preserve the "identity mapping" of the input, such that the network can focus on refining patterns (e.g., R-peak amplitudes, ST-segment elevation) without losing raw signal information.
12-lead data complexity: From our prior discussions with clinical experts, we know that each lead captures a unique spatial projection of the heart’s electrical activity. Moreover, cardiac patterns like arrhythmias appear in different combinations across leads. We believe that residual units help extract multi-scale temporal features while preserving the hierarchical structure of the data.
The last point provides a natural segue to the next question:
Does one lead offer enough representation to predict LVEF?
Intuitively, just as a human cardiologist is trained to read all 12 tracings of the heart's electrical waves, an ML or DL model may gain valuable insights from all 12 leads. When stacked together, such an input structure may help the model identify hidden links and/or inter-lead correlations, as the leads are effectively different "views" of the heart (like a camera taking a picture of the heart from 12 different angles). From our empirical observations, training on Lead I alone indeed results in worse discriminatory power than training on more or all leads (see table and bar chart for CNN scores). Nevertheless, a single lead has its advantages: much faster pre-processing and training, and the ability to detect low LVEF using wearables alone. Given these practical considerations, we may be encouraged to further optimize our models for single-lead EKG data in follow-up work.
Finally, it should be intuitive that fewer leads could also facilitate the model's interpretability and explainability. This again leads us to address the third question:
We left off the question on embedding- versus feature-based representation without giving an intuition on the latter approach. In fact, it is not straightforward to assess whether feature extractors for time-series characteristics capture the same information as the neural net modules. Ideally, explicit feature engineering, e.g. via the tsfresh package in our prior work, should allow us to interpret the ML model's decisions readily through feature importance coefficients [8]. However, because we end up extracting about 880 features for each lead, the majority of features, such as Fourier and autocorrelation coefficients, turn out to be inaccessible to clinicians (who do not usually speak this engineering language). Rather, our collaborators at the hospital appreciate visualizations of the model's focus areas directly overlaid on familiar EKG plots. As an initial attempt, we generated the heat maps below using Grad-CAM, where higher-intensity regions indicate where our CNN model concentrated most to make predictions. Interestingly, we noticed that the colouring patterns are identical across the 12 leads (at least to the naked eye), which seems to contradict our previous statement that one lead alone is insufficient for accurate classification. Hence, we may need to investigate to what extent these heat maps truly explain the model's inner workings.
Example of a class 3 test sample (LVEF ≥50%) which the CNN model correctly classified. The leads are numbered in the following order: ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'].
Example of a class 0 test sample (LVEF <30%) which the CNN model misclassified as class 3 (normal).
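For reference, the Grad-CAM computation behind such heat maps can be sketched in NumPy for a 1D convolutional layer. Random arrays stand in for the feature maps and gradients that a deep learning framework would provide, and the layer size (64 channels, 625 time steps) and the function name are hypothetical. The sketch also suggests one possible explanation for the identical colouring across leads, offered as a conjecture: when the leads enter the network as input channels of a Conv1D layer, the resulting map depends on time only, so overlaying it on each lead draws the same pattern 12 times.

```python
import numpy as np


def grad_cam_1d(activations: np.ndarray, gradients: np.ndarray, signal_length: int) -> np.ndarray:
    """Grad-CAM heat map for one 1D convolutional layer.

    activations, gradients: shape (n_channels, n_steps), i.e. the layer's
    feature maps and the gradient of the target class score w.r.t. them.
    Returns a map of shape (signal_length,) scaled to [0, 1].
    """
    alpha = gradients.mean(axis=1)                             # importance per channel
    cam = (alpha[:, None] * activations).sum(axis=0)           # weighted sum over channels
    cam = np.maximum(cam, 0.0)                                 # ReLU: keep positive evidence
    if cam.max() > 0:
        cam = cam / cam.max()
    coarse_t = np.linspace(0, signal_length - 1, num=cam.shape[0])
    return np.interp(np.arange(signal_length), coarse_t, cam)  # upsample to signal length


rng = np.random.default_rng(0)
acts = rng.random((64, 625))           # placeholder feature maps
grads = rng.normal(size=(64, 625))     # placeholder gradients
heat = grad_cam_1d(acts, grads, signal_length=5000)
print(heat.shape)                      # one value per time step, no lead axis
```

Attribution methods that operate on the input tensor itself, such as input-gradient saliency, would retain a separate value per lead and might be better suited to comparing leads.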
Regarding the interpretability of our ECG-PT model, we caution that this is still work in progress. We extracted and aggregated the attention weights from the multi-head attention of the last transformer block. After several scaling steps to allow for enough contrast, we overlaid these weights on the ECG signal to create attention maps that highlight regions with high model focus. A PyQt-based GUI was developed to load and visualize an ECG signal from the training set, predict the LVEF class using ECG-PT, and simultaneously plot the attention map.
The preliminary result displayed below may help us infer some of the inner workings of the transformer model. The model attends heavily to each token's own position in the sequence, which is typical for many transformer-based models, resulting in the diagonal demarcation. We had to zoom the heat map into the first 100 positions to see the brighter cells in the left corner, suggesting that early tokens influence the model's understanding of the sequence more strongly. The overall pattern is consistent with what we outlined previously about how convolutional and residual units learn different representations of the signal: the ECG-PT model attends both to the relationships between adjacent and nearby positions in the sequence (diagonal cells) and to dependencies across longer temporal ranges (off-diagonal cells), which could be distant QRS complexes.
Our PyQt-based GUI, partially adapted from [6], with interactive buttons to plot a truncated sample of Lead I and a zoomed-in attention map.
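The aggregation of attention weights can be illustrated with a small NumPy sketch. It assumes a GPT-style causal mask and a plain average over heads; our actual scaling steps are not reproduced here, and the function names and toy dimensions are hypothetical. With uniform attention (all-zero queries and keys), the example shows that early positions receive more attention on average simply because every later token is allowed to look back at them. This may contribute to the brighter cells near the start of the sequence, independently of the signal content.

```python
import numpy as np


def causal_attention(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention weights with a causal mask.

    q, k: shape (n_heads, seq_len, head_dim). Returns (n_heads, seq_len, seq_len),
    where row i holds the weights that position i puts on positions 0..i.
    """
    _, seq_len, head_dim = q.shape
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)        # hide future positions
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


def aggregate_attention(weights: np.ndarray):
    """Average over heads, then average the attention each position receives."""
    head_mean = weights.mean(axis=0)                  # (seq_len, seq_len) map
    received = head_mean.mean(axis=0)                 # column means
    return head_mean, received


n_heads, seq_len, head_dim = 4, 100, 16               # toy sizes, not the real model's
q = np.zeros((n_heads, seq_len, head_dim))            # zeros give uniform attention
k = np.zeros((n_heads, seq_len, head_dim))
attn_map, received = aggregate_attention(causal_attention(q, k))
print(attn_map.shape, received[0] > received[-1])
```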
We extend the motivation of a fine-grained classification of the LVEF to Deep Learning frameworks, bypassing the need for explicit feature engineering, and assess the representation learning process and interpretability of these methods. Our CNN outperforms our prior XGBoost model but to gain an edge over published works, we may need to explore architectural modifications (e.g. Inception Modules which we have started incorporating between the residual blocks). Due to computational constraints, we could not draw fair conclusions on our fine-tuned ECG-PT model, but we believe that it offers some promising paths.
Furthermore, we open new avenues for ensemble models that harness the strengths of both ML and DL approaches, something that has not been proposed in the literature yet. That being said, from an implementation perspective, we believe that the excess computation time to develop and train ensemble models is likely not worth the minor increase in classification power.
In our ongoing research, we continue exploring both ML and DL modeling strategies as well as the integration of clinical metadata, which will ultimately lead us to train an end-to-end multimodal model for the diagnosis of a diverse range of heart conditions.
[1] Ribeiro, A.H., Ribeiro, M.H., Paixão, G.M.M. et al. Automatic diagnosis of the 12-lead ECG using a deep neural network. Nat Commun 11, 1760 (2020). https://doi.org/10.1038/s41467-020-15432-4
[2] Kalmady, S.V., Salimi, A., Sun, W. et al. Development and validation of machine learning algorithms based on electrocardiograms for cardiovascular diagnoses at the population level. npj Digit. Med. 7, 133 (2024). https://doi.org/10.1038/s41746-024-01130-8
[3] Cho, J., Lee, B., Kwon, J.-M. et al. Artificial Intelligence Algorithm for Screening Heart Failure with Reduced Ejection Fraction Using Electrocardiography. ASAIO Journal 67(3), 314-321 (2021). https://doi.org/10.1097/MAT.0000000000001218
[4] Attia, Z.I., Kapa, S., Lopez-Jimenez, F. et al. Screening for cardiac contractile dysfunction using an artificial intelligence–enabled electrocardiogram. Nat Med 25, 70–74 (2019). https://doi.org/10.1038/s41591-018-0240-2
[5] Raghunath, S., Ulloa Cerna, A.E., Jing, L. et al. Prediction of mortality from 12-lead electrocardiogram voltage data using a deep neural network. Nat Med 26, 886–891 (2020). https://doi.org/10.1038/s41591-020-0870-z
[6] Davies, H.J. Interpretable Pre-Trained Transformers for Heart Time-Series Data. https://www.arxiv.org/pdf/2407.20775. Pre-trained models available at https://github.com/harryjdavies/HeartGPT?tab=readme-ov-file
[7] https://tsfresh.readthedocs.io/en/latest/
[8] Bertsimas, Dimitris et al. “Machine Learning for Real-Time Heart Disease Prediction.” IEEE journal of biomedical and health informatics vol. 25,9 (2021): 3627-3637. doi:10.1109/JBHI.2021.3066347