A Multimodal Ensemble Deep Learning Model for Functional Outcome Prognosis of Stroke Patients
Abstract
Background and Purpose
The accurate prediction of functional outcomes in patients with acute ischemic stroke (AIS) is crucial for informed clinical decision-making and optimal resource utilization. As such, this study aimed to construct an ensemble deep learning model that integrates multimodal imaging and clinical data to predict the 90-day functional outcomes after AIS.
Methods
We used data from the Korean Stroke Neuroimaging Initiative database, a prospective multicenter stroke registry, to construct an ensemble model that integrated individual 3D convolutional neural networks for diffusion-weighted imaging and fluid-attenuated inversion recovery (FLAIR), along with a deep neural network for clinical data, to predict 90-day functional dependence after AIS, defined as a modified Rankin Scale (mRS) score of 3–6. To evaluate the performance of the ensemble model, we compared the area under the curve (AUC) of the proposed method with that of individual models trained on each modality to identify patients with AIS with an mRS score of 3–6.
Results
Of the 2,606 patients with AIS, 993 (38.1%) had an mRS score of 3–6 at 90 days post-stroke. Our model achieved AUC values of 0.830 (standard cross-validation [CV]) and 0.779 (time-based CV), significantly outperforming the models relying on single modalities: b-value of 1,000 s/mm² (P<0.001), apparent diffusion coefficient map (P<0.001), FLAIR (P<0.001), and clinical data (P=0.004).
Conclusion
The integration of multimodal imaging and clinical data resulted in superior prediction of the 90-day functional outcomes in AIS patients compared to the use of a single data modality.
Introduction
Acute ischemic stroke (AIS) results in long-term functional disability, inflicting significant social and economic burdens [1]. Accurate prediction of stroke functional outcomes is important to achieve informed clinical decision-making and improve patients’ quality of life [2]. However, this prediction is difficult because of the heterogeneous nature of post-stroke disability [3]. Stroke functional outcome is influenced by various factors, encompassing clinical aspects such as age [3], patient characteristics [4], cognition [5], treatment [6], comorbidities [7], stroke severity [8], and even imaging biomarkers [9,10]. Widely used prognostic systems, including Ischemic Stroke Predictive Risk Score [11] or Acute Stroke Registry and Analysis of Lausanne [12], incorporate some of these clinical factors to predict stroke functional outcomes. However, the selective inclusion of clinical factors may lead to missing important information and failing to consider patient-specific details.
Machine learning (ML) and deep learning (DL) have achieved significant success in the field of medicine, offering a broad range of variables and algorithms. ML methods, such as support vector machines, decision trees, random forests, and deep neural networks, have been used to predict stroke functional outcomes, demonstrating improved performance compared to traditional risk scores [13,14]. However, these models mainly rely on clinical data and do not incorporate imaging data, which contain valuable information for predicting stroke outcomes, such as the extent of tissue damage, penumbra, and collateral circulation. Recent studies have demonstrated the feasibility of incorporating imaging data into ML/DL models to predict stroke outcomes; however, most focused on selective populations undergoing reperfusion treatment, thus limiting their generalizability [15-19]. Furthermore, studies of the general AIS population have predominantly relied on single-modality data from a single center [20,21].
This study aimed to predict long-term functional outcomes in AIS patients using a comprehensive model. The proposed model combines multiple magnetic resonance (MR) scans and clinical data from multiple centers. This ensemble model enhanced the performance, minimized biases, and reduced variations in the prediction results. Interpretability methods were employed to visualize the decision-making process based on unique patterns observed in the input data.
Methods
Data collection
Data were obtained from the Korean Stroke Neuroimaging Initiative (KOSNI) Registry, a prospective observational study conducted at 18 tertiary stroke centers in South Korea over an 8-year period (2011–2018). The study protocol was approved by the Institutional Review Board of Asan Medical Center (IRB number: 2013-0162), and informed consent was obtained from all participants. The inclusion criteria for enrollment in the KOSNI registry were as follows: (1) individuals aged >20 years and (2) those presenting with neurological symptoms indicative of stroke, including transient ischemic attacks (TIAs).
Of the 5,018 patients, 2,606 were eligible after applying the following exclusion criteria: (1) missing the 90-day modified Rankin Scale (mRS) [22], (2) presentation >24 hours after stroke onset, (3) absence of stroke lesions in baseline images, and (4) poor image quality or image preprocessing failure (Figure 1). Descriptive statistics comparing baseline characteristics between the study population and excluded participants are provided in Supplementary Table 1. The 90-day mRS was binarized, with a score >2 defining a poor outcome, that is, stroke patients who required assistance for daily activities due to functional limitations [23]. Among the patients, 1,613 (61.90%) had an mRS of 0–2, whereas 993 (38.10%) had an mRS of 3–6 at 90 days post-stroke.
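The outcome definition above can be illustrated with a minimal sketch. The function name `binarize_mrs` and the example scores are hypothetical and are not taken from the registry; only the class counts are those reported in the text.

```python
def binarize_mrs(mrs_score: int) -> int:
    """Return 1 for a poor outcome (mRS 3-6) and 0 for mRS 0-2."""
    if not 0 <= mrs_score <= 6:
        raise ValueError("mRS must be between 0 and 6")
    return int(mrs_score > 2)


# Hypothetical scores, one per patient (illustration only, not registry data).
example_scores = [0, 2, 3, 6, 1, 4]
labels = [binarize_mrs(s) for s in example_scores]

# Class counts reported in the text: 1,613 (mRS 0-2) and 993 (mRS 3-6).
n_good, n_poor = 1613, 993
poor_rate = n_poor / (n_good + n_poor)
print(labels, f"{100 * poor_rate:.2f}% poor outcomes")
```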
Image data preprocessing
Baseline MR imaging encompassed two subtypes of diffusion-weighted imaging (DWI), namely DWI with a b-value of 1,000 s/mm² (b1000) and the apparent diffusion coefficient (ADC) map, as well as fluid-attenuated inversion recovery (FLAIR).
Data pre-processing for raw MR scans involved the following steps: N4 bias field correction [24] was applied to each modality, followed by skull stripping using a brain mask derived by K-means clustering. The DWI images (b1000 and ADC map) were aligned to the Montreal Neurological Institute 152 (MNI 152) space, with 2-mm isotropic voxels using the ANTs SyN registration algorithm [25]. The FLAIR images were subjected to linear coregistration to align each volume with the DWI space within the subjects. The images were then aligned to a standard space using DWI deformation. The voxel intensity values in each image were normalized to ensure they fell within the range of 0–1.
Clinical data preprocessing
The clinical variables comprised 22 demographic and clinical features: age; sex; previous history of hypertension, diabetes, hyperlipidemia and stroke (including TIAs); current smoking status; body mass index (BMI); systolic blood pressure; diastolic blood pressure (DBP); hematocrit level; hemoglobin level; blood glucose level; creatinine level; total cholesterol level; high-density lipoprotein cholesterol (HDL-C); low-density lipoprotein cholesterol; total National Institutes of Health Stroke Scale (NIHSS) [8] at admission; duration between stroke onset and admission; reperfusion therapy status; risk status of cardiac embolic sources; and Trial of ORG 10172 in Acute Stroke Treatment (TOAST) [26] subtypes, including large-artery atherosclerosis, cardioembolism, small-vessel occlusion, other determined etiology, and undetermined etiology. Notably, each variable exhibited less than 5% missing data.
In the preprocessing of clinical variables, categorical features such as sex, past medical history, reperfusion therapy status, and TOAST subtypes were subjected to label encoding. Simple mode imputation was further applied to categorical variables with missing values. Conversely, all continuous variables were scaled according to the interquartile range (IQR), without any additional feature engineering, and imputed with the median value in cases where missing values were present.
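The clinical preprocessing steps can be sketched as follows, using only the Python standard library. This is an illustration under stated assumptions rather than the pipeline used in the study: centering on the median before dividing by the IQR, the "inclusive" quartile definition, and the helper names `impute_and_scale` and `impute_and_encode` are all assumptions.

```python
from statistics import median, mode, quantiles


def impute_and_scale(values):
    """Median-impute missing entries (None), then scale by the IQR.

    Centering on the median is an assumption; the text only states that
    continuous variables were scaled according to the IQR.
    """
    observed = [v for v in values if v is not None]
    med = median(observed)
    q1, _, q3 = quantiles(observed, n=4, method="inclusive")
    iqr = (q3 - q1) or 1.0  # guard against a zero IQR
    return [((med if v is None else v) - med) / iqr for v in values]


def impute_and_encode(values):
    """Mode-impute missing entries (None), then label-encode the categories."""
    observed = [v for v in values if v is not None]
    most_common = mode(observed)
    filled = [most_common if v is None else v for v in values]
    codes = {cat: i for i, cat in enumerate(sorted(set(filled)))}
    return [codes[v] for v in filled], codes


# Hypothetical toy columns (not registry data).
scaled = impute_and_scale([1, 2, None, 3, 4, 5])
encoded, mapping = impute_and_encode(["M", "F", None, "M"])
```

In practice, the median, IQR, mode, and label mapping would be estimated on the training folds only and then reused for validation and test data, to avoid information leakage.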
Proposed approach
Model architecture
The prognostic model framework presented in Figure 2A involved the training of four different models using distinct modalities: clinical data, b1000, ADC map, and FLAIR. Supplementary Table 2 presents the hyperparameter details for each model.
For the clinical data, we employed a simple, fully connected neural network (FCN) consisting of three layers with eight hidden units. This FCN was trained using the Adam optimizer [27] and dropout regularization was applied to prevent overfitting.
By contrast, we used the 3D implementation version of ResNeXt [28] to extract features from the entire MR image. To enhance the performance, we incorporated the Convolutional Block Attention Module (CBAM) [29] after each ResNeXt block (Supplementary Figure 1A). CBAM is a lightweight and versatile attention mechanism that enables the model to focus on both spatial and channel features in the output feature map. It comprises two sequential submodules: a channel attention module and spatial attention module (Supplementary Figure 1B).
The channel attention module filters important information by passing the input feature maps through max-pooling and average-pooling layers, followed by a fully connected layer. The sigmoid function was applied to obtain the channel attention map $M_C$ as follows:
$$M_C(F) = \sigma\big(\mathrm{MLP}(\mathrm{AvgPool}(F)) + \mathrm{MLP}(\mathrm{MaxPool}(F))\big) \qquad (1)$$
where $F$ denotes the input feature map; $\sigma$, the sigmoid function; and MLP, the multilayer perceptron in the channel attention module.
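As a minimal sketch of channel attention in the style of CBAM, the snippet below works on plain Python lists instead of 3D tensors. Each channel is represented as a flat list of voxel values, and the weight matrices `w1` and `w2` of the shared MLP (one hidden layer with ReLU, no biases) are hypothetical values chosen for illustration, not trained parameters from the study.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def shared_mlp(vec, w1, w2):
    """One-hidden-layer MLP with ReLU; w1 is hidden x channels, w2 is channels x hidden."""
    hidden = [max(0.0, sum(w * v for w, v in zip(row, vec))) for row in w1]
    return [sum(w * h for w, h in zip(row, hidden)) for row in w2]


def channel_attention(feature_map, w1, w2):
    """Return one attention weight in (0, 1) per channel.

    feature_map: list of channels, each a flat list of voxel values.
    """
    avg_pooled = [sum(ch) / len(ch) for ch in feature_map]
    max_pooled = [max(ch) for ch in feature_map]
    from_avg = shared_mlp(avg_pooled, w1, w2)
    from_max = shared_mlp(max_pooled, w1, w2)
    return [sigmoid(a + m) for a, m in zip(from_avg, from_max)]


# Hypothetical two-channel feature map and MLP weights (illustration only).
toy_features = [[0.0, 2.0], [1.0, 1.0]]
toy_w1 = [[1.0, 0.0]]    # 1 hidden unit, 2 input channels
toy_w2 = [[1.0], [0.0]]  # 2 output channels, 1 hidden unit
attention = channel_attention(toy_features, toy_w1, toy_w2)
```

The resulting weights would then be multiplied channel-wise with the feature map before it is passed to the spatial attention module.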
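The spatial attention module uses the output attention map of the channel attention module to identify locations of meaningful information. The input features sequentially undergo max pooling, average pooling, and convolutional layers to generate the spatial attention map $M_S$:
$$M_S(F) = \sigma\big(f^{7\times7}([\mathrm{AvgPool}(F);\ \mathrm{MaxPool}(F)])\big) \qquad (2)$$
where $f^{7\times7}$ denotes the convolutional operation in the spatial attention module, and $[\cdot\,;\cdot]$ denotes concatenation of the two pooled maps.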
Following ResNeXt-CBAM, the extracted features comprised 2,048 nodes for the final classification. Class probability was obtained using the sigmoid function for the dichotomized mRS.
To ensure fast convergence and robust training, we used the Rectified Adam [30] optimizer with cosine-annealing learning rate scheduling [31]. To prevent overfitting, additional strategies were employed, including early stopping and RandAugment [32], a stochastic automated data augmentation method that applies several transformation methods (Supplementary Figure 2).
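For illustration, a single cycle of cosine-annealing learning rate decay can be written as below. The function name and the values of `lr_max` and `lr_min` are placeholders (the actual hyperparameters are listed in Supplementary Table 2), and warm restarts are not included in this sketch.

```python
import math


def cosine_annealing_lr(step, total_steps, lr_max=1e-3, lr_min=0.0):
    """Learning rate at `step`, decaying from lr_max to lr_min along a half cosine."""
    progress = min(max(step, 0), total_steps) / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


# Learning rate at the start, midpoint, and end of a hypothetical 100-step run.
schedule = [cosine_annealing_lr(s, 100) for s in (0, 50, 100)]
```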
The dataset used in this study exhibited a class imbalance, which could introduce bias toward the majority class during training. To mitigate this issue, we used focal loss [33] in each single-modality model; this is a modified version of the cross-entropy loss that down-weights the loss assigned to well-classified examples.
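The effect of focal loss can be sketched for a single binary prediction as follows. The focusing parameter `gamma=2.0` and class weight `alpha=0.25` are commonly used defaults and are assumptions here, not the values used in this study.

```python
import math


def binary_focal_loss(y_true, p_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    """Focal loss for one example; y_true is 0 or 1, p_pred is P(y = 1)."""
    p = min(max(p_pred, eps), 1.0 - eps)          # avoid log(0)
    p_t = p if y_true == 1 else 1.0 - p           # probability of the true class
    alpha_t = alpha if y_true == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# A well-classified positive contributes far less loss than a poorly classified one.
easy = binary_focal_loss(1, 0.95)
hard = binary_focal_loss(1, 0.30)
```

Setting `gamma=0` removes the modulating factor, so the expression reduces to class-weighted cross-entropy.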
Data fusion between baseline models
To improve model performance, we employed a data-fusion technique using a weighted average method. This approach combines probability vectors obtained from each model. The weights for the fusion were determined using the differential evolution method by maximizing the F1 score of the ensemble model. Equation (3) illustrates the computation of the fused probability distribution $P$ from the output probability distributions $p_i$ of the single-modality models using the fusion weights $w_i$, where $n$ denotes the number of models:
$$P = \sum_{i=1}^{n} w_i\, p_i \qquad (3)$$
Output classes can be effectively determined by generating a hybrid probability distribution with optimized weights (Supplementary Table 3). The final mRS prediction was derived by applying a threshold of 0.45 to the predicted probabilities to account for the imbalanced dataset; this threshold was determined by maximizing the F1 score over the results of 5-fold cross-validation.
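A minimal sketch of the fusion and thresholding steps is given below. The probabilities and weights are hypothetical, the weights are normalized to sum to 1 (an assumption), and the differential evolution search for the weights is not shown; only the 0.45 threshold comes from the text.

```python
def fuse_probabilities(model_probs, weights):
    """Weighted average of per-model probabilities, one value per patient.

    model_probs: list of per-model lists, each holding P(mRS 3-6) per patient.
    weights: one non-negative weight per model (normalized here, an assumption).
    """
    total = sum(weights)
    n_patients = len(model_probs[0])
    return [
        sum(w * probs[i] for w, probs in zip(weights, model_probs)) / total
        for i in range(n_patients)
    ]


def predict(probs, threshold=0.45):
    """Label a patient as mRS 3-6 when the fused probability reaches the threshold."""
    return [int(p >= threshold) for p in probs]


def f1_score(y_true, y_pred):
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


# Two hypothetical models, two patients; the second model is weighted 3x.
fused = fuse_probabilities([[0.2, 0.8], [0.4, 0.6]], weights=[1.0, 3.0])
classes = predict(fused)
```

An optimizer such as differential evolution would call `f1_score` on the thresholded fused predictions as its objective when searching for the weights.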
Evaluation
We randomly selected 20% of the entire dataset as the test set, while ensuring that its distribution of output classes was identical to that of the remaining data. We employed two distinct approaches for model training and evaluation: standard k-fold cross-validation (CV) and time-based k-fold CV (Figure 2B). In the standard approach, a stratified 5-fold CV was used to maintain consistent proportions of output classes in each fold. In contrast, time-based k-fold CV adopts sliding window CV, a resampling technique used to manage time-series data. After sorting all data by admission date, the training data were split into multiple training and validation subsets. The window size was set to 1,000 instances in each round, and each window was divided into training and validation sets with an 80:20 split. In particular, the training set consistently preceded the validation set.
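The sliding window scheme can be sketched as an index generator over data already sorted by admission date. The window size of 1,000 and the 80:20 split follow the text; the even spacing of window start positions, the number of windows, and the default ordering of the two parts (`train_first=True`) are assumptions that can be changed through the arguments.

```python
def sliding_window_splits(n_samples, window_size=1000, n_splits=5,
                          train_fraction=0.8, train_first=True):
    """Yield (train_indices, validation_indices) for time-sorted samples."""
    if window_size > n_samples:
        raise ValueError("window_size cannot exceed n_samples")
    max_start = n_samples - window_size
    if n_splits > 1:
        # Evenly spaced window start positions (an assumption).
        starts = [round(k * max_start / (n_splits - 1)) for k in range(n_splits)]
    else:
        starts = [0]
    n_train = int(window_size * train_fraction)
    for start in starts:
        window = list(range(start, start + window_size))
        if train_first:
            yield window[:n_train], window[n_train:]
        else:
            n_val = window_size - n_train
            yield window[n_val:], window[:n_val]


# Hypothetical training pool of 2,085 time-sorted patients.
splits = list(sliding_window_splits(2085))
```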
We assessed the performance of the models by measuring sensitivity, specificity, positive predictive value (PPV), negative predictive value, F1 score, and area under the curve (AUC). The AUC was the primary metric. To compare our model with baseline models, we conducted a DeLong test to identify any statistically significant differences based on the models trained on the entire training dataset. The significance level for statistical tests was set at P<0.05. We further calculated 95% confidence intervals (CI) using 200 bootstrap samples.
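As an illustration of the bootstrap confidence interval for the AUC, the sketch below resamples patients with replacement 200 times and takes percentiles of the resulting AUC values. The percentile method, the fixed random seed, and the function names are assumptions; the text only specifies the number of bootstrap samples. The DeLong test is not shown.

```python
import random


def auc_score(y_true, scores):
    """AUC as the probability that a positive outranks a negative (ties count half)."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def bootstrap_auc_ci(y_true, scores, n_boot=200, level=0.95, seed=0):
    """Percentile bootstrap confidence interval for the AUC."""
    rng = random.Random(seed)
    n = len(y_true)
    aucs = []
    while len(aucs) < n_boot:
        idx = [rng.randrange(n) for _ in range(n)]
        ys = [y_true[i] for i in idx]
        if 0 < sum(ys) < n:  # a resample needs both classes for the AUC
            aucs.append(auc_score(ys, [scores[i] for i in idx]))
    aucs.sort()
    lower = aucs[int((1.0 - level) / 2.0 * n_boot)]
    upper = aucs[min(n_boot - 1, int((1.0 + level) / 2.0 * n_boot))]
    return lower, upper


# Hypothetical labels and predicted probabilities (illustration only).
toy_labels = [0, 0, 1, 1, 0, 1, 0, 1, 1, 0]
toy_scores = [0.1, 0.4, 0.35, 0.8, 0.2, 0.7, 0.5, 0.9, 0.6, 0.3]
ci_low, ci_high = bootstrap_auc_ci(toy_labels, toy_scores)
```

The pairwise AUC computation is quadratic in the number of patients, which is acceptable for a sketch; a rank-based implementation would be preferable for a test set of several hundred patients.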
To gain insight into the decision-making processes of each model, we used two explainable AI methods. For the clinical data model, we used the kernel Shapley Additive Explanation (SHAP) [34] to estimate the contribution of each input feature. SHAP calculates global feature importance by averaging the contributions through sample permutation. For the imaging data, we used Grad-CAM [35] to visualize the brain regions that were significant for classifying mRS outcomes. Grad-CAM determines the weights of the feature maps based on model information using the global average of the gradient. We obtained voxel-wise average heat maps using Grad-CAM for all AIS patient samples in the test data, and defined the region of interest (ROI) by applying a 50% threshold of voxel intensities. To locate the concentrated areas within the ROI mask, we used automated anatomical labeling to identify specific brain regions.
Experimental setup
All experiments were conducted on a Linux Ubuntu 20.04 LTS workstation with an Intel Core i9-9940X 3.30 GHz CPU, two NVIDIA GeForce RTX 2080 Ti graphics cards, and 64 GB of RAM. The DL models were implemented and trained in Python 3.8.10 using TensorFlow [36] 2.9.0. For image processing, OpenCV [37] 4.7.0 and scikit-image [38] 0.19.3 were used. The scikit-learn [39] 1.1.3 package was used for model evaluation and training. The interpretability of the clinical data model was visualized using the SHAP [34] 0.41.0 package. Grad-CAM-derived ROI-to-brain anatomy mapping was analyzed using the AtlasQuery tool of the FMRIB Software Library [40] based on the MNI structural atlas, Harvard–Oxford cortical structural atlas, and Harvard–Oxford subcortical structural atlas.
Results
Subjects
The study population comprised 2,606 patients selected from the total registry. Comparison of the baseline characteristics between the study population and excluded participants showed statistically significant differences in nine features: age, history of diabetes and hyperlipidemia, DBP, hematocrit, HDL-C, admission NIHSS, reperfusion therapy status, and TOAST subtypes (Supplementary Table 1).
Supplementary Table 4 presents the clinical and demographic characteristics of patients. The median age and baseline NIHSS score were 70 years (IQR, 61–76) and 5 (IQR, 2–10), respectively. After 90 days, 993 (38.1%) patients had poor functional outcomes (mRS score, 3–6), whereas 1,613 (61.9%) did not. Of those with poor functional outcomes, 795 belonged to the training group and 198 to the test group. No clinical inputs exhibited significant differences between the training and test groups (all P>0.05).
Prediction performance
Table 1 and Supplementary Table 5 present the average results of the standard and time-based 5-fold CV for the evaluation of the performance of the proposed multimodal model in the prediction of functional outcomes. Compared to models trained with single modalities, our model consistently achieved the highest performance, with an AUC of 0.830 in standard CV and 0.779 in time-based CV (95% confidence interval [CI]: 0.740, 0.844). All baseline models based on a single MR scan exhibited lower AUC values than our ensemble model.
The receiver operating characteristic (ROC) curves illustrated that the ensemble model outperformed the models trained using a single modality. In the DeLong test comparing the ROC curves, the proposed model showed a statistically significant improvement over clinical data (P=0.004), b1000 (P<0.001), ADC map (P<0.001), and FLAIR (P<0.001) (Figure 3).
Interpretable model analysis
In the clinical data model, SHAP values quantified the contribution of each feature (Figure 4) to the model results. Analysis using SHAP values revealed that age and baseline NIHSS score were the most influential features, whereas other clinical features had relatively minor impacts.
The image models generated average ROI heat maps using Grad-CAM, focusing on the infarcted area in the left hemisphere (Figure 5). Analysis of the different classification groups (true positive [TP], true negative [TN], false positive [FP], and false negative [FN]) revealed a consistent ROI, but the intensity in the ROI for TN, FP, and FN was weaker than that for TP (Supplementary Figure 3). Anatomically, the ROI was located in the left cerebellum and temporo–occipital regions.
Discussion
In this study, we developed an ensemble DL model combining routinely collected multimodal imaging and clinical data to predict the functional outcomes of patients with AIS. Our approach involves the use of 3D CNN models to extract low-level features directly from high-dimensional input images combined with clinical model outputs. This integration led to improved performance compared to the models trained on each single modality. Techniques such as image augmentation and focal loss were employed to minimize bias derived from data imbalance. Consequently, the final ensemble model achieved an AUC of 0.830 (standard CV) and 0.779 (time-based CV), outperforming the single-modality models.
The key strength of our study was our use of data collected from a multicenter registry, which provided a diverse and representative sample of patients with AIS. This broad coverage enabled training on various stroke types and locations, thus contributing to an improved model performance. Furthermore, the inclusion of diverse imaging protocols from multiple centers adds robustness to the prediction output of the model.
While the ensemble model achieved the highest average performance, the clinical-only model also showed good performance. However, it remains important to acknowledge the vulnerability of clinical data to data drift, as indicated by the wide 95% CI for the AUC (from 0.662 to 0.830), despite the utilization of a multicenter data source. In contrast, our ensemble model exhibited robust performance, providing valuable mitigation against performance fluctuations caused by out-of-distribution data between the training and test sets. This stability makes the ensemble model particularly advantageous in real-world clinical settings with variations across different hospitals or imaging facilities.
To explore the relationship between the input data and poor functional outcomes, we conducted visual analyses using SHAP and Grad-CAM for each clinical and imaging input. The SHAP plot demonstrated a significant influence of age and NIHSS score on the prediction of functional outcomes based on clinical data, consistent with previous studies [15]. Grad-CAM was used to show which brain regions the model focused on. Interestingly, the average Grad-CAM maps highlighted distinct brain areas, rather than entire lesions. These findings suggest that poor functional outcomes are correlated with the left cerebellum and temporo–occipital region. Stroke functional outcomes were evaluated using the mRS, which measures the level of disability or dependency in daily activities influenced by motor function, balance, and visual function. However, the NIHSS assigns fewer points to ataxia and visual function (two and three points, respectively, out of 42 NIHSS points) than to motor function (19 points, including four points for each limb and three points for facial function). Consequently, with age and the NIHSS score already captured as the significant clinical factors, the imaging models may have associated the cerebellum, which is responsible for balance control, and the temporo–occipital cortex, which houses the optic pathway, with unfavorable outcomes.
Despite these promising results, this study had several limitations. First, our ensemble model was dependent on MR imaging, making it unsuitable for use in many institutions that primarily use CT-based imaging for stroke diagnosis. Thus, the application of our model may be limited to facilities with MRI capabilities. Furthermore, our data may have been subject to several biases. Our study population comprised patients who visited the stroke center within 24 hours of stroke onset, exhibiting higher baseline severity, with an average admission NIHSS score of 5, compared to 3 in the excluded group. This severity mismatch can affect the distribution of each clinical feature between the study population and excluded patients, which could potentially limit the model generalizability [41]. The lack of full lesion growth may also have negatively affected the performance of the model, as most images in the dataset were early baseline images. Similar research has reported improved results using day 1 follow-up images, in which the lesion sizes were clearly seen [21]. Moreover, our model’s performance should be further validated in real-world settings as real-world data often show data drift due to variations in distribution over time or across different data sources [42]. Future studies should investigate the effects of these biases to provide more comprehensive insights.
Model training also encountered certain challenges. Training DL models from scratch is inherently difficult as they require extensive data, while small datasets can lead to overfitting. Although we attempted to mitigate overfitting using techniques such as RandAugment and early stopping, the size of the dataset remained limited. To address this, future studies should explore a transfer learning approach with a model pretrained on a larger external dataset, potentially improving the performance and mitigating overfitting. Another challenge is the use of the entire image as an input, which may introduce noise during model training. To overcome this issue, we employed an attention mechanism to enable the model to focus more on the ROI and extract features accurately. Despite these efforts, some cases still exhibited noise owing to individual differences in lesion size and location. Future research should include strategies to filter out this noise to improve the model performance.
Conclusions
In this study, we constructed a comprehensive model for predicting the 90-day functional outcomes using multiple MR modalities and clinical metadata from a multicenter registry. This model was superior to other prediction models that rely on a single modality.
Supplementary materials
Supplementary materials related to this article can be found online at https://doi.org/10.5853/jos.2023.03426.
Notes
Funding statement
This research was supported by a grant from the Korea Health Technology R&D Project through the Korea Health Industry Development Institute (KHIDI), funded by the Ministry of Health and Welfare, Republic of Korea (HR18C0016), and a National IT Industry Promotion Agency (NIPA) grant funded by the Korean government (MSIT) (No. S0252-21-1001; Seoul, Republic of Korea).
Conflicts of interest
The authors have no financial conflicts of interest.
Author contribution
Conceptualization: DWK, HSJ. Study design: HSJ, DWK. Methodology: HSJ. Data collection: all authors. Investigation: HSJ. Statistical analysis: HSJ. Writing—original draft: HSJ. Writing—review & editing: all authors. Funding acquisition: DWK. Approval of final manuscript: all authors.