Abstract:
Brain dynamics are highly complex and yet hold the key to understanding brain function and dysfunction. The dynamics captured by resting-state functional magnetic resonance imaging data are noisy, high-dimensional, and not readily interpretable. The typical approach of reducing this data to low-dimensional features and focusing on the most predictive features comes with strong assumptions and can miss essential aspects of the underlying dynamics. In contrast, introspection of discriminatively trained deep learning models may uncover disorder-relevant elements of the signal at the level of individual time points and spatial locations. Yet, the difficulty of reliable training on high-dimensional, low-sample-size datasets and the unclear relevance of the resulting predictive markers prevent the widespread use of deep learning in functional neuroimaging. In this work, we introduce a deep learning framework to learn from high-dimensional dynamical data while maintaining stable, ecologically valid interpretations. Our results demonstrate that the proposed framework enables learning the dynamics of resting-state fMRI directly from small data and capturing compact, stable interpretations of features predictive of function and dysfunction.
Figure 1: An overview of our approach to model interpretation. Steps A to D describe the process. E shows some distinctive spatio-temporal features responsible for the model's prediction.
Figure 2: End-to-end process of RAR evaluation. For each subject in the dataset, based on the whole MILC class prediction and model parameters, we estimated the feature importance vector e using an interpretability method g_i. We then validated these estimates against random feature attributions g_R using the RAR method and an SVM model. By comparing the SVM model’s performance when trained separately on the different feature sets, we show that the features estimated by the whole MILC model were highly predictive compared to a random selection of the same number of features. Empirically, we show that ξ(X_M | g_i) > ξ(X_M | g_R), where ξ is the performance evaluation function (e.g., area under the curve) and X_M refers to the modified dataset constructed from only the retained feature values.
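The retain step of RAR described in this caption can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes each subject's data is a components × time array, that saliency is ranked by absolute value, and that non-retained entries are set to zero. The function names (`retain_top_fraction`, `retain_random_fraction`, `fnc_features`) are hypothetical.

```python
import numpy as np


def retain_top_fraction(x, saliency, fraction=0.05):
    """Keep the `fraction` of entries of x with the largest |saliency|; zero the rest.

    Assumption: ranking by absolute attribution and zero-filling are
    illustrative choices, not confirmed details of the original method.
    """
    k = max(1, int(round(fraction * x.size)))
    flat = np.abs(saliency).ravel()
    keep = np.argpartition(flat, -k)[-k:]  # indices of the k largest values
    mask = np.zeros(x.size, dtype=bool)
    mask[keep] = True
    mask = mask.reshape(x.shape)
    return np.where(mask, x, 0.0), mask


def retain_random_fraction(x, fraction=0.05, seed=None):
    """Random baseline: keep the same number of entries, chosen uniformly."""
    rng = np.random.default_rng(seed)
    k = max(1, int(round(fraction * x.size)))
    keep = rng.choice(x.size, size=k, replace=False)
    mask = np.zeros(x.size, dtype=bool)
    mask[keep] = True
    mask = mask.reshape(x.shape)
    return np.where(mask, x, 0.0), mask


def fnc_features(x):
    """Upper-triangular Pearson correlations between components (rows of x)."""
    with np.errstate(invalid="ignore", divide="ignore"):
        corr = np.corrcoef(x)
    # A row that is constant after masking has an undefined correlation.
    corr = np.nan_to_num(corr, nan=0.0)
    rows, cols = np.triu_indices(x.shape[0], k=1)
    return corr[rows, cols]
```

The two feature sets produced this way (salient versus random) could then be passed to any off-the-shelf classifier and compared with a metric such as AUC, which is the comparison the caption expresses.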
We show the main results from the whole MILC architecture and its comparison with standard machine learning (SML) models below. In general, the whole MILC model can learn from the raw data, whereas traditional SML models fail to maintain their predictive capacity. Moreover, pretraining substantially improves the latent representations, as reflected in the higher accuracy of the whole MILC w/ pretraining compared to the whole MILC w/o pretraining. Specifically, in most small-data cases, the whole MILC w/ pretraining outperformed the whole MILC w/o pretraining across the datasets. However, as expected, when we gradually increased the number of subjects during training, the effect of pretraining on the classification performance diminished, and both configurations of whole MILC did equally well. We verified this trend over three datasets that correspond to autism spectrum disorder, schizophrenia, and Alzheimer’s disease.
Figure 3: Performance comparison of main classification results: traditional ML models vs. DL models
We show the results from the proposed RAR framework below. RAR employs an SVM to classify the FNCs computed from the top 5% most salient input data, as estimated from the whole MILC model’s predictions. We used integrated gradients (IG) and smoothgrad integrated gradients (SGIG) to compute feature attributions. When an independent classifier (SVM) was trained on each subject’s most salient 5% of the data, its predictive power was significantly higher than that of the same SVM trained on the same amount of randomly chosen data. In other words, the poor performance with randomly selected data indicates that other parts of the data were not as exclusively discriminative as the salient 5% estimated by the whole MILC model. We also noticed that sample masks with different percentages of data coverage gradually obscured the localization of the discriminative activity within the data. Although the SVM model gradually became more predictive with increased coverage of randomly selected data, as we show in the Supplementary Information, this improvement was due to the gradual improvement in functional connectivity estimation and is not attributable to disease-specific localized parts of the data. For every disorder (autism spectrum disorder, schizophrenia, and Alzheimer’s disease), the higher AUC at 5% coverage indicates stronger relevance of the salient data to the underlying disorder. Furthermore, the RAR results show that, in most cases where whole MILC was trained with limited data, the models w/ pretraining estimated feature attributions more accurately than the models w/o pretraining.
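For readers unfamiliar with the two attribution methods named above, the following is a minimal sketch of integrated gradients and its SmoothGrad-averaged variant. It uses a toy logistic model with an analytic gradient as a stand-in, since the whole MILC network is not reproduced here; the zero baseline, step count, noise level, and all function names are assumptions made for illustration only.

```python
import numpy as np


def toy_model(x, w):
    """Toy stand-in for a trained classifier: logistic output for input x."""
    return 1.0 / (1.0 + np.exp(-np.dot(w, x)))


def toy_model_grad(x, w):
    """Analytic gradient of the toy model's output with respect to x."""
    p = toy_model(x, w)
    return p * (1.0 - p) * w


def integrated_gradients(x, w, baseline=None, steps=200):
    """Midpoint-rule approximation of integrated gradients.

    IG_i = (x_i - b_i) * integral over alpha in [0, 1] of
           dF/dx_i evaluated at b + alpha * (x - b).
    """
    if baseline is None:
        baseline = np.zeros_like(x)  # assumption: all-zero baseline
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.stack(
        [toy_model_grad(baseline + a * (x - baseline), w) for a in alphas]
    )
    return (x - baseline) * grads.mean(axis=0)


def smoothgrad_ig(x, w, noise_std=0.1, n_samples=20, seed=0):
    """Average IG attributions over Gaussian-perturbed copies of the input."""
    rng = np.random.default_rng(seed)
    noisy = x + rng.normal(0.0, noise_std, size=(n_samples,) + x.shape)
    return np.mean([integrated_gradients(xn, w) for xn in noisy], axis=0)
```

A useful sanity check for any IG implementation is the completeness property: the attributions should sum, up to integration error, to the difference between the model output at the input and at the baseline.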
Figure 4: The RAR framework validates that the post hoc explanations are highly discriminative, i.e., the model learned and uses highly discriminative features for its predictions.
We show below the top 10% of FNC values for patients, computed using the most salient 5% of the data as thresholded by feature attribution maps (saliency maps), for different disorders. Apart from the high predictive capacity of the salient data, we observed some intriguing differences among these connectograms. Autism spectrum disorder exhibits the lowest between-domain FNC. However, the salient data in autism spectrum disorder highlight changes in specific cerebellum, sensorimotor, and subcortical domains. For schizophrenia, the model-identified salient data reflect the most widespread pattern, consistent with the literature showing cerebellum interaction across multiple domains and sensorimotor changes. The predictive features for Alzheimer’s disease mainly concentrate on visual and cognitive interactions.
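Selecting the strongest connections for a connectogram of this kind can be sketched as follows. The source does not state how "top 10%" is ranked, so this sketch assumes ranking by absolute correlation over the unique off-diagonal pairs of a symmetric FNC matrix; `top_fraction_edges` is a hypothetical helper name.

```python
import numpy as np


def top_fraction_edges(fnc, fraction=0.10):
    """Zero all but the strongest `fraction` of off-diagonal FNC entries.

    Assumption: "strongest" means largest absolute correlation, ranked over
    the upper triangle of a symmetric matrix. The diagonal is set to zero.
    """
    n = fnc.shape[0]
    rows, cols = np.triu_indices(n, k=1)
    values = fnc[rows, cols]
    k = max(1, int(round(fraction * values.size)))
    order = np.argsort(np.abs(values))[::-1][:k]  # indices of the k strongest edges
    out = np.zeros_like(fnc, dtype=float)
    out[rows[order], cols[order]] = values[order]
    return out + out.T  # restore symmetry
```

The nonzero entries of the returned matrix are the edges that would be drawn; signs are preserved so positive and negative connectivity can be displayed differently.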
Figure 5: Functional connectivity obtained using 5% salient data
Figure 6: A: Full FNC for patients computed using the most salient 5% of the data, selected based on feature attribution values, for different disorders. B: Static FNC (i.e., using 100% of the data) matrices for patients of different disorders. The FNC based on 5% salient data (A) does indeed convey the same focused dynamic information as currently assessed in FNC matrices based on 100% of the data (B). It is thus apparent that the proposed model can capture the focused information aligned with current domain knowledge. C: Pairwise differences of the FNC matrices based on 5% salient data. The difference FNC matrices based on focused data indicate that each disorder has a uniquely distinguishable association with brain dynamics.