๐Ÿฅ Stroke Prediction

Using Artificial Neural Networks, Ensemble Models, and Explainable AI

Team members (in alphabetical order):

  • Eirik Berge
  • Camilla Idina Jensen Elvebakken
  • Martin Ludvigsen

To install the needed package dependencies, simply run pip install -r requirements.txt

Table of Contents

  1. TL;DR: Predicting Stroke with Advanced Statistical Methods
  2. About Stroke and the Dataset
  3. Data Exploration
  4. Models Developed
  5. Explainable Artificial Intelligence
  6. Conclusion

TL;DR: Predicting Stroke with Advanced Statistical Methods

We analyze a stroke dataset and formulate various statistical models for predicting whether a patient has had a stroke based on measurable predictors. The goal is to predict whether a person will suffer a stroke using several easily measurable predictors such as smoking, hypertension, and age. Since the data is heavily skewed (about 96% of the patients have never suffered a stroke), accuracy alone is a misleading measure: a model that always predicts "no stroke" would already be about 96% accurate. We therefore report the accuracy, recall, and precision of each method. For the full technical report, see

🚀 Full Technical Report 🚀
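To make the point about skewed data concrete, here is a minimal plain-Python sketch that computes the three reported metrics for a trivial classifier that always predicts "no stroke". The labels are hypothetical and only mirror the roughly 96%/4% class split; they are not taken from the dataset, and `classification_metrics` is an illustrative helper, not part of the project code.

```python
def classification_metrics(y_true, y_pred):
    """Return (accuracy, precision, recall) for binary labels where 1 = stroke."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    accuracy = (tp + tn) / len(y_true)
    # Precision is undefined when nothing is predicted positive; report 0.0 by convention.
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    return accuracy, precision, recall


# Hypothetical labels: 96 patients without stroke, 4 with stroke.
y_true = [0] * 96 + [1] * 4
# A "model" that always predicts the majority class (no stroke).
y_always_no = [0] * 100

acc, prec, rec = classification_metrics(y_true, y_always_no)
print(f"accuracy={acc:.2f} precision={prec:.2f} recall={rec:.2f}")
# accuracy=0.96 precision=0.00 recall=0.00
```

The high accuracy hides the fact that not a single stroke is detected, which is exactly what the precision and recall figures expose.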

About Stroke and the Dataset

A stroke is a condition in which the blood flow to the brain is reduced, causing brain cells to die. Strokes can roughly be classified into two main types: ischemic stroke, caused by a lack of blood flow, and hemorrhagic stroke, caused by bleeding. Both types prevent the brain from functioning properly. As strokes are one of the leading causes of death, it is vital to understand the condition and to be able to predict it in advance, so that preventive measures can be taken to reduce the risk. If you suspect that someone is having a stroke (e.g. they struggle to say simple complete sentences or to smile), call your local emergency number (in Norway: 113) immediately. For more information about the condition (in Norwegian), see

Helsenorge - Stroke (Hjerneslag)

The dataset comes from Kaggle - Stroke Prediction and records several details about over 5000 patients, along with whether they have experienced a stroke. The complete list of recorded variables is:

  • id - A unique identifier for the patient.
  • gender - The gender of the patient (Male, Female, Other).
  • age - The age of the patient.
  • hypertension - Records if the patient has hypertension or not (0, 1).
  • heart_disease - Records if the patient has a heart disease or not (0, 1).
  • ever_married - Records if the patient has ever been married (No, Yes).
  • work_type - What kind of work the patient has (Children, Govt_job, Never_worked, Private, Self-employed).
  • residence_type - What area the patient lives in (Rural, Urban).
  • avg_glucose_level - Records the average glucose level in the patient's blood.
  • bmi - Records the Body Mass Index (BMI) of the patient.
  • smoking_status - Records the smoking status of the patient (formerly smoked, never smoked, smokes, Unknown).
  • stroke - Records if the patient has had a stroke or not (0, 1). This is the response variable we try to predict.

Unfortunately, the origin of the data is confidential, so we have no context for the data beyond the variables listed above. In particular, we do not know the patients' country of origin, nor why the patients filled out the information we have been given. If, for example, the patients already had a severe medical history and a physician asked them to fill out these details, this selection could heavily bias the data. With so little information about how the data was collected, the models we develop should only be used for illustrative and educational purposes. For further development of the project, the focus should be on better data quality rather than more advanced models.

Data Exploration

In the data, there are 201 patients whose BMI is missing from the variable bmi. Since bmi is potentially relevant and these patients make up only about 4% of the total, we have chosen to remove them. The variable gender has three options: Male, Female, and Other. Since only 1 patient is registered with the gender Other, we unfortunately have to discard this patient, as a single observation cannot be used in a statistically meaningful way. The variable smoking_status has the options never smoked, formerly smoked, smokes, and Unknown. Since a significant number of patients are registered with Unknown smoking status, we have chosen to keep these patients in the study rather than discard a large part of the data.
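The three cleaning rules above can be sketched with the standard library alone. This is a minimal illustration, not the project's preprocessing code: the CSV rows are made up, only a subset of the columns is included, and missing BMI values are assumed to appear as the string "N/A" (adjust the check if the file uses a different marker).

```python
import csv
import io

# A few hypothetical rows using a subset of the dataset's columns (values are made up).
RAW_CSV = """id,gender,age,bmi,smoking_status,stroke
1,Male,67,36.6,formerly smoked,1
2,Female,61,N/A,never smoked,1
3,Other,26,22.4,never smoked,0
4,Female,49,34.4,Unknown,0
5,Male,81,29.0,smokes,0
"""


def clean_rows(rows):
    """Apply the three cleaning rules described above to an iterable of dict rows."""
    cleaned = []
    for row in rows:
        # Rule 1: drop patients whose BMI is missing (assumed to be encoded as "N/A").
        if row["bmi"] in ("N/A", ""):
            continue
        # Rule 2: drop the single patient registered with gender "Other".
        if row["gender"] == "Other":
            continue
        # Rule 3: keep patients with "Unknown" smoking status (no filtering needed).
        cleaned.append(row)
    return cleaned


rows = list(csv.DictReader(io.StringIO(RAW_CSV)))
kept = clean_rows(rows)
print([r["id"] for r in kept])  # ['1', '4', '5']
```

The stdlib version only keeps the sketch dependency-free; with a DataFrame library the same rules become a pair of boolean filters.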

The following histogram shows the age distribution of the patients that have experienced a stroke:

We see that strokes are more common among older patients than among younger ones, while the dataset appears to cover all age groups well. Age is therefore likely to be an important predictor of stroke. We end this section by showing a heatmap of the correlation between the different variables:

We see from the heatmap above that the response stroke is not strongly correlated with any single predictor. Since correlation coefficients only capture linear (or at best monotonic) relationships, this motivates non-linear models such as ensembles and neural networks, which can pick up more complex dependencies. Care is needed with methods that assume independent features, though: many of the features are highly correlated with each other, for example age, ever_married, and the children category of work_type.
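Why a weak correlation does not rule out a strong relationship can be shown with a small example. Assuming the heatmap shows Pearson correlations (the usual default), the sketch below computes the Pearson coefficient by hand and applies it to a variable that depends on x in a purely quadratic way. The data is synthetic and unrelated to the stroke dataset.

```python
import math


def pearson(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


# Synthetic data: y is a deterministic but non-linear function of x.
x = [-2, -1, 0, 1, 2]
y = [v * v for v in x]
print(pearson(x, y))                       # 0.0: no linear correlation at all
print(pearson(x, [2 * v + 1 for v in x]))  # 1.0: exact linear relationship
```

Here y is fully determined by x, yet the correlation is exactly zero, which is why a flat heatmap row for stroke says little about what a non-linear model can extract.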

Models Developed

We develop several models to predict the binary variable stroke based on the other variables. The models we develop are:

  • Logistic Regression (with Ridge Penalty)
  • A Simple Deep Neural Network
  • Random Forests (standard, weighted, and balanced)
  • Boosting (specifically XGBoost)
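One common way to build a "weighted" random forest is to give each class a weight inversely proportional to its frequency, so that the rare stroke class is not drowned out. The sketch below implements the n_samples / (n_classes * count) heuristic that scikit-learn documents for `class_weight="balanced"`; whether the project's weighted random forest used exactly this scheme is an assumption, and the labels are hypothetical.

```python
from collections import Counter


def balanced_class_weights(y):
    """Weights inversely proportional to class frequency: n_samples / (n_classes * count)."""
    counts = Counter(y)
    n_samples = len(y)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


# Hypothetical labels with a 96/4 split between "no stroke" (0) and "stroke" (1).
y = [0] * 96 + [1] * 4
weights = balanced_class_weights(y)
print(weights)
```

Each stroke case gets a weight of 12.5 while each non-stroke case gets about 0.52, so both classes carry the same total weight during training.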

For most of the models, we plot ROC curves as well as precision-recall curves to illustrate their performance graphically. For the balanced random forest, the curves are shown below:
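A ROC curve can be condensed into a single number, the ROC-AUC: the probability that a randomly chosen stroke patient receives a higher predicted score than a randomly chosen non-stroke patient. The sketch below computes that definition directly on hypothetical labels and scores; in practice a library routine such as scikit-learn's `roc_auc_score` would be used instead. Note that with rare positives, precision-recall curves are often considered more informative than ROC curves, which is one reason to plot both.

```python
def roc_auc(y_true, scores):
    """ROC-AUC as the probability that a random positive outscores a random negative.

    Ties count as one half. This O(n_pos * n_neg) version is for illustration only.
    """
    pos = [s for s, t in zip(scores, y_true) if t == 1]
    neg = [s for s, t in zip(scores, y_true) if t == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical labels (1 = stroke) and predicted stroke probabilities.
y_true = [0, 0, 1, 1, 0, 1]
scores = [0.1, 0.4, 0.35, 0.8, 0.2, 0.9]
print(roc_auc(y_true, scores))
```

Eight of the nine positive/negative pairs are ranked correctly, giving an AUC of 8/9; a model that ranks at random would land near 0.5.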

For several of the models, performance depends strongly on the choice of hyperparameters. These can be tuned in many ways; we illustrate both grid search and Bayesian search for finding good hyperparameters for our models.
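Grid search simply evaluates every combination of candidate hyperparameter values and keeps the best one. The sketch below shows that loop in plain Python. The grid and `toy_score` are hypothetical stand-ins: in a real run the scoring function would train a model and return a cross-validated metric, which is what scikit-learn's `GridSearchCV` does internally.

```python
from itertools import product


def grid_search(param_grid, score_fn):
    """Evaluate every combination in param_grid and return (best_params, best_score).

    score_fn maps a dict of hyperparameters to a validation score (higher is better).
    """
    names = sorted(param_grid)
    best_params, best_score = None, float("-inf")
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        score = score_fn(params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Hypothetical grid for a random forest.
grid = {"max_depth": [3, 5, 10], "n_estimators": [100, 300]}


def toy_score(params):
    # Hypothetical stand-in for cross-validation: peaks at max_depth=5, prefers more trees.
    return -abs(params["max_depth"] - 5) + params["n_estimators"] / 1000


best = grid_search(grid, toy_score)
print(best)
```

The number of evaluations is the product of the grid sizes (here 3 × 2 = 6) and grows quickly with every added hyperparameter. Bayesian search instead fits a surrogate model of the score and uses it to pick the next candidate, which tends to need fewer evaluations on large search spaces.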

The methods and their performance are listed below. As the table shows, some methods perform better than others depending on the metric. When choosing a model, one should therefore carefully consider whether high accuracy, precision, or recall is the most important property.

| Model | Tuning | Precision | Recall | Accuracy | ROC-AUC |
| --- | --- | --- | --- | --- | --- |
| Logistic Regression | Built-in CV | 0.113 | 0.837 | 0.730 | 0.854 |
| Deep Neural Network | - | 0.113 | 0.816 | 0.738 | - |
| Balanced Decision Tree | GridSearchCV | 0.074 | 0.959 | 0.522 | 0.732 |
| Random Forest | GridSearchCV | 0.077 | 0.041 | 0.942 | 0.598 |
| Weighted Random Forest | GridSearchCV | 0.119 | 0.735 | 0.773 | 0.818 |
| Balanced Random Forest | GridSearchCV | 0.085 | 0.878 | 0.619 | 0.791 |
| Balanced Random Forest | BayesianSearchCV | 0.086 | 0.939 | 0.597 | 0.826 |
| XGBoost | GridSearchCV | 0.100 | 0.837 | 0.694 | 0.848 |
| XGBoost | BayesianSearchCV | 0.099 | 0.837 | 0.688 | 0.850 |
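One way to make the precision/recall trade-off explicit is the F-beta score, a weighted harmonic mean of precision and recall in which beta > 1 puts more weight on recall. It is not reported in the original results; the sketch below only recombines precision and recall values from the table above, and the choice of beta = 2 (a missed stroke treated as costlier than a false alarm) is an assumption.

```python
def f_beta(precision, recall, beta=1.0):
    """Weighted harmonic mean of precision and recall; beta > 1 favours recall."""
    if precision == 0 and recall == 0:
        return 0.0
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


# (precision, recall) pairs copied from the results table.
results = {
    "Logistic Regression": (0.113, 0.837),
    "Weighted Random Forest": (0.119, 0.735),
    "Balanced Random Forest (Bayesian)": (0.086, 0.939),
    "XGBoost (Bayesian)": (0.099, 0.837),
    "Random Forest": (0.077, 0.041),
}

# Sort by F2 (beta = 2 is an assumed preference for recall) and print F1 alongside.
for name, (p, r) in sorted(results.items(), key=lambda kv: -f_beta(*kv[1], beta=2)):
    print(f"{name:35s} F1={f_beta(p, r):.3f} F2={f_beta(p, r, beta=2):.3f}")
```

For these rows the weighted random forest has the highest F1 while logistic regression has the highest F2, so the preferred model can change with how much recall is valued.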

Explainable Artificial Intelligence

Since the predictions in this classification problem can be life-changing for the patient, the model needs to be interpretable. Understanding in which cases the model predicts a stroke can help us identify which features, or combinations of features, are important for early detection and prevention. We now use methods from explainable AI to interpret our tree-based models. Specifically, we interpret the XGBoost model with hyperparameters found by Bayesian search, as it produced some of the best results.

The following plot shows which features are most important when building the trees in the XGBoost model:

We see that the most important features appear to be age, avg_glucose_level, and hypertension. Interestingly, the model does not rank smoking_status or bmi as important, even though smoking and a high BMI are generally regarded as risk factors for stroke.
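The ranking described above is based on how the trees are built (XGBoost reports split-based scores such as weight or gain). A model-agnostic alternative, swapped in here for illustration and not necessarily used in the report, is permutation importance: shuffle one feature and measure how much a metric drops. In the sketch, `toy_predict` is a hypothetical stand-in for a fitted model that only looks at the first feature, and the data is made up.

```python
import random


def permutation_importance(predict, X, y, metric, n_repeats=5, seed=0):
    """Mean drop in `metric` when each feature column is shuffled independently.

    predict: callable mapping a list of feature rows to predictions.
    metric:  callable (y_true, y_pred) -> score, higher is better.
    """
    rng = random.Random(seed)
    baseline = metric(y, predict(X))
    importances = []
    for j in range(len(X[0])):
        drops = []
        for _ in range(n_repeats):
            column = [row[j] for row in X]
            rng.shuffle(column)
            # Copy each row and replace only feature j with the shuffled values.
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, column)]
            drops.append(baseline - metric(y, predict(X_perm)))
        importances.append(sum(drops) / n_repeats)
    return importances


def toy_predict(X):
    # Hypothetical model: predicts stroke whenever feature 0 (think age) exceeds 60.
    return [1 if row[0] > 60 else 0 for row in X]


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical rows of [age, some_unused_flag].
X = [[45, 1], [70, 0], [82, 1], [30, 0], [65, 1], [50, 0]]
y = toy_predict(X)  # labels the toy model reproduces perfectly
imp = permutation_importance(toy_predict, X, y, accuracy)
print(imp)
```

Shuffling the unused second feature never changes the predictions, so its importance is exactly zero, while shuffling the first feature can only lower the accuracy.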

For the age variable, we plot below a partial dependence (PD) plot to see the marginal effect of age on the predictions of the XGBoost model. A PD plot shows the average predicted stroke risk for each specific value of a given feature, with the other features kept at their observed values and averaged over.
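The PD computation itself is short: for each grid value, overwrite the chosen feature for every patient, predict, and average. The sketch below shows this in plain Python. `toy_predict` is a hypothetical stand-in for the fitted XGBoost model and the patient records are made up; for fitted scikit-learn-compatible estimators, the `sklearn.inspection` module provides `partial_dependence` and `PartialDependenceDisplay`, though whether the project used them is not stated here.

```python
def partial_dependence(predict, X, feature, grid):
    """Average prediction when `feature` is set to each grid value for every row.

    predict: callable mapping a list of feature rows (dicts) to predicted probabilities.
    """
    curve = []
    for value in grid:
        # Overwrite the chosen feature for all patients, keep the others as observed.
        X_mod = [{**row, feature: value} for row in X]
        preds = predict(X_mod)
        curve.append(sum(preds) / len(preds))
    return curve


def toy_predict(X):
    # Hypothetical probability model: risk rises with age, plus a bump for hypertension.
    return [min(1.0, 0.005 * row["age"] + 0.1 * row["hypertension"]) for row in X]


# Hypothetical patients.
X = [{"age": 30, "hypertension": 0}, {"age": 60, "hypertension": 1},
     {"age": 75, "hypertension": 0}, {"age": 50, "hypertension": 1}]
curve = partial_dependence(toy_predict, X, "age", [20, 40, 60, 80])
print(curve)
```

Because half of these toy patients have hypertension, each point on the curve is the age effect plus an average hypertension bump of 0.05.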

Conclusion

We have developed various models for predicting future strokes in patients based on a small collection of easily measurable variables. The models perform differently depending on the metric, and their computation time also varies, ranging from a few seconds to several minutes. For more information about the models developed, we recommend reading the technical report:

🚀 Full Technical Report 🚀
