Investigating Different Remedy Methods for Unbalanced Cohort With Thoracic Surgery Data
challenges: unbalanced data, medical data
Tool: R
Much health and medical data is unbalanced. This creates a challenging problem when we use the data to predict the outcome of a disease, as many machine learning algorithms often fail to detect the true positive cases, resulting in very low sensitivity. In this project we examine different strategies for dealing with unbalanced data using the Thoracic Surgery Data set. Our results show that even stronger classifiers such as ensemble methods fail to correctly detect the minority class in test data. Experiments with different re-sampling methods show that SMOTE is the best strategy for re-balancing the cohort. Our conclusion is that using SMOTE with ensemble methods can significantly boost sensitivity with minimal loss of accuracy.
One data-driven approach in the medical health care field is to use patients' data to predict the outcome for an individual. However, much health and medical data is unbalanced, as the outcomes are not usually distributed equally. This creates a challenging problem when we use the data to predict the outcome of a disease, as many machine learning algorithms often fail to detect the true positive cases, resulting in very low sensitivity. This project examines different strategies for dealing with unbalanced data using a cleaned data set from the UCI Machine Learning Repository, the Thoracic Surgery Data set. We first outline baseline model results with no remedy methods applied, and then compare them with model performance under different strategies. Finally, we show that the improvement in model performance is widely applicable to other algorithms, and discuss the results as well as the limitations of the project.
Our data contain 470 records with 17 attributes, including 3 numeric variables and 14 categorical variables, with no missing values. Among the 14 categorical variables, 3 are multi-categorical and 11 are binary. A summary of the numeric and multi-categorical variables is shown in Table 1; a summary of the binary categorical variables is shown in Table 2 with a visualization. Our task is to predict post-operative life expectancy from the other 16 variables in the data set. Specifically, this is a binary classification problem, with True meaning the patient died within 1 year after the thoracic surgery and False meaning the patient survived longer than 1 year. There are 70 out of 470 people in this data set who died within 1 year after surgery, creating an unbalanced outcome cohort in our sample.
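As a hedged illustration, the data set can be loaded directly from the UCI repository. The URL, file name, and the Risk1Yr outcome column below are assumptions based on the public version of the data and may need adjusting:

```r
library(foreign)   # read.arff() reads ARFF files such as the UCI download

## Assumed URL and file name for the public UCI copy of the data;
## adjust the path if the file is stored locally instead.
url <- "https://archive.ics.uci.edu/ml/machine-learning-databases/00277/ThoraricSurgery.arff"
surgery <- read.arff(url)

str(surgery)             # 470 obs. of 17 variables, no missing values
table(surgery$Risk1Yr)   # outcome: 400 F (survived) vs. 70 T (died within 1 year)
```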
We first conduct exploratory data analysis to familiarize ourselves with the features. This is primarily done using density plots for numeric variables, a heat map for binary variables, and bar graphs for multi-categorical variables. We also identify important features using the Gini index. Although this project is not focused on feature selection, this step is crucial for checking model assumptions and building informative models. Our data are divided into 85% training and 15% testing data, with numeric variables standardized using the training data. To compare the performance of different remedy methods, we first create baseline models using various machine learning algorithms, including LDA, QDA, logistic regression, and Random Forest, on the existing unbalanced data. Then we apply various strategies to see how the model performance changes. Our primary performance measurements are sensitivity and model accuracy, as our interest is in detecting which patients are more likely to die. Our strategies fall into two parts: building stronger classifiers and re-constructing a balanced data set. We first test whether stronger classifiers such as Random Forest and XGBoost can deal with unbalanced data. Then we compare the baseline models with a Random Forest model trained on re-constructed balanced data generated by various re-sampling techniques, namely under-sampling, over-sampling, and the synthetic minority over-sampling technique (SMOTE). Finally, we extend our results to other machine learning algorithms to see how the re-sampling strategies improve different models' performance. Specifically, we run XGBoost on balanced data generated by SMOTE to see how the performance changes. This shows that the improvement in model performance is not specific to Random Forest but carries over to other algorithms as well.
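A minimal sketch of the split-and-standardize step, under the same naming assumptions as the loading sketch above. The 85/15 proportions follow the text; standardization uses training means and standard deviations only:

```r
set.seed(1)                                   # seed is arbitrary
n     <- nrow(surgery)
idx   <- sample(n, size = round(0.85 * n))    # 85% of rows for training
train <- surgery[idx, ]
test  <- surgery[-idx, ]

## Standardize numeric columns with training means/sds only,
## so no information leaks from the test set.
num_cols <- sapply(train, is.numeric)
mu  <- sapply(train[, num_cols], mean)
sdv <- sapply(train[, num_cols], sd)
train[, num_cols] <- scale(train[, num_cols], center = mu, scale = sdv)
test[, num_cols]  <- scale(test[, num_cols],  center = mu, scale = sdv)
```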
Our preliminary exploratory data analysis shows that AGE is approximately normally distributed, with the majority of patients above 40 years old (Figure 1). We also see that some attributes are highly unbalanced: for example, PRE19 (MI up to 6 months), PRE25 (peripheral arterial diseases), and PRE32 (asthma) each have fewer than 2 people in the True category, while PRE10 (cough before surgery) and PRE30 (smoking) have close to 80% of people in the True category. This is not surprising, given that symptoms such as coughing are prevalent among lung cancer patients while asthma and peripheral arterial diseases are rare in the population. We also compare how the binary attributes differ between the two outcomes (died within 1 year vs. survived longer than 1 year after surgery) using a heat map (Figure 2). Each number shows the proportion of the True category for the outcome specified in the column with the corresponding attribute. For example, for PRE11, 78.2% of patients who had weakness before surgery survived longer than 1 year, and 21.8% of patients who had weakness before surgery died within 1 year. The differences in these numbers visualize how the attributes' distributions differ between the two groups. Although the numbers do not vary significantly across attributes, we see that slightly more people who had dyspnea or Type 2 diabetes mellitus died within 1 year after surgery. Important features are ranked by mean decrease in accuracy and mean decrease in Gini index (Figure 3). We see that PRE4 (forced vital capacity), PRE5 (volume exhaled at the end of the first second of forced expiration), PRE14 (size of the original tumor), AGE (age at the time of surgery), and DGN (diagnosis) are among the most important features. Notice that PRE19, PRE25, and PRE32 have among the lowest importance, as there are fewer than 2 people in each True category, indicating very low predictive power. This observation corresponds with the visualization presented in Table 2. We ran different models using these important features only, but the model performance turned out to be worse, so we use all 16 variables in our subsequent modeling.
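The importance ranking in Figure 3 can be reproduced along these lines with the randomForest package; ntree = 1000 matches the parameter quoted later in the text, while everything else is an assumption:

```r
library(randomForest)

## importance = TRUE stores both permutation (mean decrease in accuracy)
## and node-impurity (mean decrease in Gini) rankings.
rf_imp <- randomForest(Risk1Yr ~ ., data = train,
                       ntree = 1000, importance = TRUE)
importance(rf_imp)   # MeanDecreaseAccuracy and MeanDecreaseGini columns
varImpPlot(rf_imp)   # side-by-side dot plots, as in Figure 3
```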
For our baseline models, we used Linear Discriminant Analysis (LDA), Quadratic Discriminant Analysis (QDA), logistic regression, and Random Forest. LDA assumes that the conditional distribution of the predictors for each outcome class k is multivariate normal, with a class-specific mean but a covariance matrix shared across classes.
This assumption is hard to check given the high-dimensional nature of our data set. QDA makes the same multivariate normal assumption but relaxes the restriction on covariance, allowing each class its own covariance matrix. Notice that, by these assumptions, both algorithms are derived for continuous variables. R allows categorical variables, both binary and non-binary, to be fed into LDA, while QDA accepts only binary categorical and numeric variables. We thus build our LDA model using all 16 variables, but the QDA model with numeric variables only. For logistic regression and Random Forest, we use all 16 variables. The model performance is summarized in Table 3.
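A sketch of the four baseline fits, under the same naming assumptions as above; QDA is restricted to the three numeric variables (PRE4, PRE5, AGE), as described:

```r
library(MASS)            # lda(), qda()
library(randomForest)

fit_lda <- lda(Risk1Yr ~ ., data = train)                   # all 16 variables
fit_qda <- qda(Risk1Yr ~ PRE4 + PRE5 + AGE, data = train)   # numeric variables only
fit_log <- glm(Risk1Yr ~ ., data = train, family = binomial)
fit_rf  <- randomForest(Risk1Yr ~ ., data = train, ntree = 1000)

## Confusion matrix on the held-out data for, e.g., the Random Forest
table(predicted = predict(fit_rf, test), actual = test$Risk1Yr)
```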
From the results, we see that LDA, QDA, and logistic regression fail to correctly identify the True Positive cases, although LDA and logistic regression give high accuracy. We originally thought logistic regression might perform better since it does not make distributional assumptions, but the results show that it performs similarly to LDA.
The results also show that non-linear classifiers seem to work better at detecting True Positives, although they sacrifice model accuracy. Random Forest performed better than QDA, while LDA and logistic regression performed about the same as each other. Next we experiment with stronger classifiers, such as ensemble methods that combine weak learners, on the unbalanced data. To our disappointment, they give either high accuracy with low sensitivity (XGBoost) or high sensitivity with low accuracy (Random Forest).
In the following experiment, we use a stronger non-linear classifier, Random Forest with the same parameter (ntree = 1000), to test how the model performs using balanced data generated by different re-sampling strategies. To make the comparisons valid, we use a ratio of 60% majority class to 40% minority class throughout. The first strategy is under-sampling, where we down-sample the majority class while using all minority class data; here we use 95 majority class records and 63 minority class records in our training data. The second strategy is over-sampling, where we over-sample the minority class with replacement while using all majority class data. To lose as little information as possible, in our implementation we keep every minority class record and re-sample the rest with replacement; we use 336 majority class records and 224 minority class records. Our third strategy is SMOTE, where we sample a subset of the minority class and synthesize new data points using k-nearest neighbors; here we use 189 majority class records and 126 minority class records. The summary results are shown in Table 4. We see that with balanced data the model in general performs much better in terms of sensitivity, with some loss in accuracy. Under-sampling performs worse at detecting Survival, as it down-samples the majority class and thus loses much of the information needed to detect it. Over-sampling performs worse at detecting Died, as it introduces many replicas of the minority class data and thus causes the training model to overfit. SMOTE, on the other hand, reconciles the two strategies by neither losing information nor causing overfitting. As a result, the model balances accuracy and sensitivity.
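To make the three schemes concrete, here is a rough sketch on the training data. The counts follow the text; in practice a packaged SMOTE implementation (for example from the DMwR or themis packages) would likely be used, so the base-R function below only illustrates the core interpolation idea on numeric predictors, with k = 5 assumed:

```r
set.seed(1)
minority <- train[train$Risk1Yr == "T", ]   # 63 records per the text
majority <- train[train$Risk1Yr == "F", ]

## (1) Under-sampling: all minority rows, majority down-sampled to 95
under <- rbind(minority, majority[sample(nrow(majority), 95), ])

## (2) Over-sampling: all majority rows, all minority rows, plus
##     minority rows re-sampled with replacement up to 224 in total
over <- rbind(majority, minority,
              minority[sample(nrow(minority), 224 - nrow(minority),
                              replace = TRUE), ])

## (3) SMOTE core idea: interpolate between a minority point and one of
##     its k nearest minority neighbors (numeric features only here)
smote_one <- function(X, k = 5) {
  i  <- sample(nrow(X), 1)              # random minority seed point
  d  <- as.matrix(dist(X))[i, ]         # distances to every minority row
  nn <- order(d)[2:(k + 1)]             # its k nearest neighbors (skip itself)
  j  <- sample(nn, 1)
  X[i, ] + runif(1) * (X[j, ] - X[i, ]) # random point on the segment i -> j
}
## e.g. one synthetic record from the numeric predictors:
## smote_one(minority[, c("PRE4", "PRE5", "AGE")])
```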
Lastly, we complete the comparison by running XGBoost on the balanced data. The model performance increases significantly compared with the baseline model trained on unbalanced data.
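A hedged sketch of this last step, training XGBoost on the SMOTE-balanced data; smote_train is a hypothetical name for the balanced data frame from the previous step, and the hyperparameters shown are illustrative assumptions:

```r
library(xgboost)

## One-hot encode predictors; assumes factor levels match across the two
## frames, and that smote_train holds the SMOTE-balanced training data.
to_xgb <- function(df) {
  xgb.DMatrix(model.matrix(~ . - 1, data = df[, names(df) != "Risk1Yr"]),
              label = as.numeric(df$Risk1Yr == "T"))
}

fit_xgb <- xgb.train(params = list(objective = "binary:logistic",
                                   max_depth = 4, eta = 0.1),  # illustrative values
                     data = to_xgb(smote_train), nrounds = 200)

pred <- predict(fit_xgb, to_xgb(test)) > 0.5   # probability cutoff at 0.5
table(predicted = pred, actual = test$Risk1Yr)
```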
The experiments show that even powerful ensemble methods can fail to detect True Positives with an unbalanced data set. The better remedy is to generate balanced data with re-sampling techniques: our models perform much better in terms of sensitivity, with a minor decrease in accuracy, when trained on balanced data. This performance pattern is not limited to Random Forest; other methods such as XGBoost exhibit a similar improvement in detecting True Positive cases. From our results, Random Forest with the SMOTE strategy works best, as SMOTE reconciles the under-sampling and over-sampling strategies, avoiding both overfitting and loss of information. There are a few limitations to our project. First, some of the results may only be applicable to the specific data set we are using; this is a single piece of empirical evidence. Second, the limited number of records can make drastic changes in sensitivity misleading. For example, in Table 4, the change in sensitivity from 0.71 to 0.43 reflects a change in only 1 predicted data point. With only 7 positive cases in our test data, the degree to which these model performances actually vary is not clearly shown. One remedy is to use cross-validation to obtain a better estimate of model performance. In our case, because we use ensemble methods, we can use the out-of-bag error to estimate the test error. However, because of the limited data, especially the small number of positive cases, the out-of-bag error is optimistically biased. Nevertheless, the experimental results strongly indicate how model performance changes when moving from the original unbalanced data to balanced data. For future work, one could experiment with feature engineering, and with neural networks to see whether they perform well even with unbalanced data. Our hope is that when readers encounter such unbalanced data in the future, they can try these re-sampling methods with ensemble models to improve their models.
I would like to thank Upasana Mukherjee for critical insights on this project.
Download R code here