Stroke Prediction

Constructing a prediction model for the risk of stroke

Samantha Lin
Geek Culture


Photo by Austrian National Library on Unsplash

1. Introduction

According to the World Health Organisation (WHO), stroke is the second leading cause of death and the third leading cause of disability globally. A stroke is the sudden death of brain cells due to lack of oxygen, which happens when blood flow to the brain is lost through blockage or rupture of an artery supplying the brain. Stroke is also a leading cause of dementia and depression.

Nearly 800,000 people in the United States suffer a stroke each year, and about three in four of these are first-time strokes. Around 80% of strokes can be prevented, so education on the signs of stroke is very important.

The objective of this study is to construct a model for predicting stroke and to assess its accuracy. We will explore seven different models to see which produces reliable and repeatable results: Decision Tree, Logistic Regression, Random Forest, Support Vector Machine, K Nearest Neighbour, Naive Bayes and KMeans Clustering. Based on the prediction outcomes, the best-performing model will then undergo cross-validation to evaluate its repeatability.

2. Data Source

A population of 5110 people is involved in this study: 2994 females, 2115 males and 1 person recorded as "Other". The dataset is taken from the Kaggle data repositories (https://www.kaggle.com/datasets) and is used to predict whether a patient is likely to have a stroke based on the following attributes:

1. id : unique identifier
2. gender : "Male", "Female" or "Other"
3. age : age of the patient
4. hypertension : 0 if the patient doesn't have hypertension, 1 if the patient has hypertension
5. heart_disease : 0 if the patient doesn't have any heart diseases, 1 if the patient has a heart disease
6. ever_married : "No" or "Yes"
7. work_type : "children", "Govt_job", "Never_worked", "Private" or "Self-employed"
8. Residence_type : "Rural" or "Urban"
9. avg_glucose_level : average glucose level in blood
10. bmi : body mass index
11. smoking_status : "formerly smoked", "never smoked", "smokes" or "Unknown"
12. stroke : 1 if the patient had a stroke, 0 if the patient did not have a stroke

3. Importing libraries and data

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
import warnings
warnings.filterwarnings(action='ignore')
data = pd.read_csv('healthcare-dataset-stroke-data.csv')
data.head(3)

4. Data Cleaning

data_row_count, data_column_count = data.shape
print('Row Count:', data_row_count)
print('Column Count:', data_column_count)
Row Count: 5110
Column Count: 12
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 12 columns):
 #   Column             Non-Null Count  Dtype
---  ------             --------------  -----
 0   id                 5110 non-null   int64
 1   gender             5110 non-null   object
 2   age                5110 non-null   float64
 3   hypertension       5110 non-null   int64
 4   heart_disease      5110 non-null   int64
 5   ever_married       5110 non-null   object
 6   work_type          5110 non-null   object
 7   Residence_type     5110 non-null   object
 8   avg_glucose_level  5110 non-null   float64
 9   bmi                4909 non-null   float64
 10  smoking_status     5110 non-null   object
 11  stroke             5110 non-null   int64
dtypes: float64(3), int64(4), object(5)
memory usage: 479.2+ KB
data.isna().sum()
id                     0
gender                 0
age                    0
hypertension           0
heart_disease          0
ever_married           0
work_type              0
Residence_type         0
avg_glucose_level      0
bmi                  201
smoking_status         0
stroke                 0
dtype: int64

There are 201 missing values in the bmi feature. A simple way to deal with missing values is to remove the rows that contain them; however, this also discards the non-null values in the other columns of those rows. Thus, we will substitute the missing values with the mean bmi and check that the imputation worked.

data['bmi'] = data['bmi'].fillna(data['bmi'].mean())
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 12 columns):
 #   Column             Non-Null Count  Dtype
---  ------             --------------  -----
 0   id                 5110 non-null   int64
 1   gender             5110 non-null   object
 2   age                5110 non-null   float64
 3   hypertension       5110 non-null   int64
 4   heart_disease      5110 non-null   int64
 5   ever_married       5110 non-null   object
 6   work_type          5110 non-null   object
 7   Residence_type     5110 non-null   object
 8   avg_glucose_level  5110 non-null   float64
 9   bmi                5110 non-null   float64
 10  smoking_status     5110 non-null   object
 11  stroke             5110 non-null   int64
dtypes: float64(3), int64(4), object(5)
memory usage: 479.2+ KB
data.describe()
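To make the trade-off concrete, here is a minimal sketch on a hypothetical four-row frame (toy values, not the stroke data) comparing row deletion with the mean imputation used above:

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame: two of the four rows are missing bmi
toy = pd.DataFrame({
    "age": [67.0, 61.0, 80.0, 49.0],
    "bmi": [36.6, np.nan, 32.5, np.nan],
})

# Option 1: drop rows with a missing value (their age values are lost too)
dropped = toy.dropna()

# Option 2: fill missing bmi with the column mean, keeping every row
imputed = toy.copy()
imputed["bmi"] = imputed["bmi"].fillna(imputed["bmi"].mean())

print(len(dropped), len(imputed))   # rows kept by each approach
print(imputed)
```

Row deletion halves this toy frame, while mean imputation keeps all four ages at the cost of giving both missing rows the same bmi value.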

4.1 ID

data.id.nunique()
5110

The number of unique ids equals the row count, so id only identifies rows and carries no predictive information. Thus, we will drop this column.

data = data.drop(columns='id')
data.shape
(5110, 11)

4.2 Gender

data.gender.value_counts()
Female    2994
Male      2115
Other        1
Name: gender, dtype: int64

We want gender to be a binary variable, and from an analysis perspective it is not worth creating a separate category for a single row ('Other'). Hence, we will replace this single value with the mode of the column, which is 'Female'.

data['gender'] = data['gender'].replace('Other', data['gender'].mode()[0])
data.gender.value_counts()
Female 2995
Male 2115
Name: gender, dtype: int64

5. Exploratory Data Analysis

5.1 Categorical Feature Analysis

df_cat = ['gender', 'hypertension', 'heart_disease', 'ever_married',
          'work_type', 'Residence_type', 'smoking_status', 'stroke']

fig, axs = plt.subplots(4, 2, figsize=(14, 20))
axs = axs.flatten()

# iterate through each column of df_cat and draw a count plot split by stroke status
for i, col_name in enumerate(df_cat):
    sns.countplot(x=col_name, data=data, ax=axs[i], hue='stroke', palette='flare')
    axs[i].set_title(f"Bar chart of {col_name}")
    axs[i].set_xlabel(col_name, weight='bold')
    axs[i].set_ylabel('Count', weight='bold')

plt.tight_layout()

From the above count plots, some observations can be drawn:

  • hypertension: Subjects previously diagnosed with hypertension appear to have a higher risk of stroke.
  • heart disease: Subjects previously diagnosed with heart disease appear to have a higher risk of stroke.
  • ever married: Subjects who have ever been married appear to have a higher risk of stroke.
  • work type: Subjects with work experience, including government jobs, appear to have a higher risk of stroke, while those who have never worked barely experienced stroke.
  • Residence type: No obvious relationship with the likelihood of experiencing stroke.
  • smoking status: Being a smoker or former smoker appears to increase the risk of stroke.
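A count plot shows absolute numbers, so a small group can show fewer strokes than a large one even when its stroke rate is higher. A minimal sketch of comparing rates instead, on hypothetical toy data, uses pd.crosstab for the counts and a groupby mean of the 0/1 stroke column for the rates:

```python
import pandas as pd

# Hypothetical toy data: 8 subjects without hypertension, 2 with it
toy = pd.DataFrame({
    "hypertension": [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
    "stroke":       [0, 0, 0, 0, 0, 0, 1, 1, 0, 1],
})

# Raw counts per group, which is what a count plot displays
counts = pd.crosstab(toy["hypertension"], toy["stroke"])

# Stroke rate per group: the mean of a 0/1 column is the share of 1s
rates = toy.groupby("hypertension")["stroke"].mean()

print(counts)
print(rates)
```

Here the group without hypertension has more strokes in absolute terms (2 vs 1) but a lower rate (0.25 vs 0.5). With the article's data the same check would be, for example, data.groupby('hypertension')['stroke'].mean().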

5.2 Numerical Feature Analysis

df_num = ['age', 'avg_glucose_level', 'bmi']

fig, axs = plt.subplots(1, 3, figsize=(16, 5))
axs = axs.flatten()

# iterate through each column in df_num and draw a boxplot split by stroke status
for i, col_name in enumerate(df_num):
    sns.boxplot(x='stroke', y=col_name, data=data, ax=axs[i], palette='Set1')
    axs[i].set_xlabel('Stroke', weight='bold')
    axs[i].set_ylabel(col_name, weight='bold')

From the above boxplots, some observations can be drawn:

  • age: Subjects with stroke tend to have a higher median age.
  • avg glucose level: Subjects with stroke tend to have a higher average glucose level.
  • bmi: BMI gives little indication of the likelihood of experiencing stroke. A BMI of 50 or above is classed as super obesity, so we treat values above 50 as outliers and cap them at this upper limit (50).

There are 79 records with a BMI above 50.

bmi_outliers=data.loc[data['bmi']>50]
bmi_outliers['bmi'].shape
(79,)

Replace bmi values greater than 50 with 50.

data["bmi"] = pd.to_numeric(data["bmi"])
data["bmi"] = data["bmi"].apply(lambda x: 50 if x>50 else x)
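The apply/lambda above caps the values one element at a time; pandas also provides Series.clip, which performs the same capping in a single vectorised call. A sketch on hypothetical bmi values:

```python
import numpy as np
import pandas as pd

# Hypothetical bmi values, two of them above the cap of 50
bmi = pd.Series([22.4, 48.0, 50.0, 57.3, 97.6])

capped_apply = bmi.apply(lambda x: 50 if x > 50 else x)  # element-wise, as above
capped_clip = bmi.clip(upper=50)                         # vectorised equivalent

print(capped_clip.tolist())
```

Both produce the same capped values; clip avoids a Python-level function call per row, which matters more on larger frames.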

We plot a boxplot of the bmi column to confirm that the change was made.

# confirm that bmi values are now capped at 50
sns.boxplot(data=data, x='bmi', color='green')
plt.title("Boxplot of BMI Distribution");

# share of subjects with and without stroke
plt.figure(figsize=(4, 4))
data['stroke'].value_counts().plot.pie(autopct='%1.1f%%', colors=['#66b3ff', '#99ff99'])
plt.title("Pie Chart of Stroke Status", fontdict={'fontsize': 14})
plt.tight_layout()

4.9% of the subjects in this dataset have had a stroke, so the two classes are highly imbalanced.

5.3 Multicollinearity Analysis

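Since the correlation check only accepts numerical variables, the categorical variables must first be converted to numbers. The binary features (gender, ever_married, Residence_type) are encoded as 0 or 1, while work_type and smoking_status, which have more than two categories, are encoded as integers from 0 to n-1. We use LabelEncoder from sklearn.preprocessing because it makes it easy to decode a label back later if required. Note that integer codes impose an arbitrary order on the multi-category features.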

from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()

data['gender'] = le.fit_transform(data['gender'])
data['ever_married'] = le.fit_transform(data['ever_married'])
data['work_type'] = le.fit_transform(data['work_type'])
data['Residence_type'] = le.fit_transform(data['Residence_type'])
data['smoking_status'] = le.fit_transform(data['smoking_status'])
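LabelEncoder assigns codes in sorted label order and can map codes back with inverse_transform. A minimal sketch on a hypothetical sample of smoking_status values (using the labels listed in Section 2) shows both directions:

```python
from sklearn.preprocessing import LabelEncoder

# Hypothetical sample of smoking_status values
status = ["never smoked", "smokes", "Unknown", "formerly smoked", "smokes"]

le = LabelEncoder()
codes = le.fit_transform(status)

# classes_ holds the sorted labels; each label's code is its position in classes_
print(list(le.classes_))   # uppercase 'Unknown' sorts before lowercase labels
print(codes.tolist())

# inverse_transform maps the integer codes back to the original strings
print(list(le.inverse_transform(codes)))
```

Because each call to fit_transform refits the encoder, reusing a single le for every column (as above) only keeps the mapping of the last column fitted, smoking_status. Decoding every column later would need one encoder per column, for example stored in a dict keyed by column name.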

df_en = data
df_en.head()
corr = df_en.corr().round(2)
plt.figure(figsize=(10,7))
sns.heatmap(corr, annot = True, cmap = 'RdYlGn');

From the above correlation matrix, we can see that some variables are correlated with each other. For instance, ever_married and age have a correlation of 0.68. Of these two attributes, age carries more information on whether one is susceptible to stroke. Thus, we will drop the ever_married column to reduce multicollinearity.

df_en = df_en.drop(['ever_married'], axis=1)
df_en.head(3)
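Reading a heatmap by eye works for a dozen features; a minimal sketch that instead lists every feature pair whose absolute Pearson correlation reaches a cutoff might look like this. The helper name correlated_pairs, the 0.6 cutoff and the toy values are hypothetical choices for illustration.

```python
from itertools import combinations

import pandas as pd


def correlated_pairs(df, threshold=0.6):
    """Return (col_a, col_b, r) for each column pair with |r| >= threshold."""
    corr = df.corr()
    pairs = []
    # combinations visits each unordered pair once and skips the diagonal
    for a, b in combinations(corr.columns, 2):
        r = corr.loc[a, b]
        if abs(r) >= threshold:
            pairs.append((a, b, round(float(r), 2)))
    return pairs


# Hypothetical toy frame with one strongly correlated pair
toy = pd.DataFrame({
    "age":          [20, 30, 40, 50, 60, 70],
    "ever_married": [0, 0, 1, 1, 1, 1],
    "bmi":          [25, 31, 22, 28, 24, 30],
})

print(correlated_pairs(toy))
```

Applied to the encoded frame before ever_married is dropped, this would surface the age and ever_married pair discussed above without scanning the heatmap.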

5.4 Final Preprocessing

Variables measured on different scales do not contribute equally to model fitting and might end up creating a bias. To deal with this potential problem, feature standardization (μ=0, σ=1) is usually applied prior to model fitting. We create a StandardScaler() object and apply its fit_transform() method to standardize the 'avg_glucose_level', 'bmi' and 'age' columns.

from sklearn.preprocessing import StandardScaler

s = StandardScaler()
columns = ['avg_glucose_level', 'bmi', 'age']

# standardize the three numerical columns to mean 0 and standard deviation 1
stand_scaled = s.fit_transform(df_en[columns])
stand_scaled = pd.DataFrame(stand_scaled, columns=columns, index=df_en.index)

# replace the original columns with their standardized versions
df_en = df_en.drop(columns=columns)
stand_scaled.head()
df = pd.concat([df_en, stand_scaled], axis=1)
df.head(3)
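Standardization computes z = (x - μ) / σ for each column. A small check on hypothetical values confirms that StandardScaler matches this formula when σ is the population standard deviation (ddof=0):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical values for one feature (for example age), as a single column
x = np.array([[3.0], [25.0], [47.0], [69.0], [81.0]])

scaled = StandardScaler().fit_transform(x)

# The same transform by hand: NumPy's std uses ddof=0 by default
manual = (x - x.mean(axis=0)) / x.std(axis=0)

print(np.allclose(scaled, manual))
print(scaled.mean(), scaled.std())   # approximately 0 and 1
```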

6. Modelling

All the predictor variables will be mapped to an array x and the target variable to an array y. The target variable is ‘stroke’ column.

x=df.drop(['stroke'], axis=1)
y=df['stroke']
# Models
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.cluster import KMeans

# Evaluation
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state= 124)
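With only about 4.9% positive cases, a plain random split can leave the training and test sets with slightly different stroke rates. train_test_split accepts a stratify argument that preserves the class proportions in both subsets; the split above does not use it. The sketch below uses synthetic labels with a similar positive rate, not the article's data:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic labels: 1000 samples, 49 positives (4.9%), plus random features
rng = np.random.default_rng(0)
y = np.array([1] * 49 + [0] * 951)
X = rng.normal(size=(1000, 3))

# stratify=y keeps the positive rate (almost) identical in both splits
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.3, random_state=124, stratify=y)

print(len(y_tr), len(y_te))
print(y_tr.mean(), y_te.mean())   # both close to 0.049
```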

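We build the models with their parameters and store them in a dictionary. We will explore 7 algorithms to see which produces reliable and repeatable results. Note that KMeans is an unsupervised clustering algorithm: it ignores y_train, and its cluster labels (0 and 1) are arbitrary, so they do not necessarily correspond to the stroke labels. The 7 algorithms are: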

  • Decision Tree
  • Logistic Regression
  • Random Forest
  • Support Vector Machine
  • K Nearest Neighbour
  • Naive Bayes
  • KMeans Clustering
models = dict()
models['Decision Tree'] = DecisionTreeClassifier()
models['Logreg'] = LogisticRegression()
models['Random Forest'] = RandomForestClassifier()
models['Support Vector Machine'] = SVC(kernel = 'sigmoid', gamma='scale')
models['kNN'] = KNeighborsClassifier()
models['Naive Bayes'] = GaussianNB()
models['KMeans'] = KMeans(n_clusters=2, n_init=10, random_state=42)
# fit every model on the training data (KMeans ignores y_train)
for name, model in models.items():
    model.fit(x_train, y_train)
    print(name + " model fitting completed.")
Decision Tree model fitting completed.
Logreg model fitting completed.
Random Forest model fitting completed.
Support Vector Machine model fitting completed.
kNN model fitting completed.
Naive Bayes model fitting completed.
KMeans model fitting completed.
print("Test Set Prediction:\n")

# iterate over (name, model) pairs; avoid reusing x, which holds the predictors
for name, model in models.items():
    print('-'*20 + name + '-'*20)
    y_pred = model.predict(x_test)
    arg_test = {'y_true': y_test, 'y_pred': y_pred}
    print(confusion_matrix(**arg_test))
    print(classification_report(**arg_test))
Test Set Prediction:

--------------------Decision Tree--------------------
[[1398   66]
 [  63    6]]
              precision    recall  f1-score   support

           0       0.96      0.95      0.96      1464
           1       0.08      0.09      0.09        69

    accuracy                           0.92      1533
   macro avg       0.52      0.52      0.52      1533
weighted avg       0.92      0.92      0.92      1533

--------------------Logreg--------------------
[[1464    0]
 [  69    0]]
              precision    recall  f1-score   support

           0       0.95      1.00      0.98      1464
           1       0.00      0.00      0.00        69

    accuracy                           0.95      1533
   macro avg       0.48      0.50      0.49      1533
weighted avg       0.91      0.95      0.93      1533

--------------------Random Forest--------------------
[[1463    1]
 [  69    0]]
              precision    recall  f1-score   support

           0       0.95      1.00      0.98      1464
           1       0.00      0.00      0.00        69

    accuracy                           0.95      1533
   macro avg       0.48      0.50      0.49      1533
weighted avg       0.91      0.95      0.93      1533

--------------------Support Vector Machine--------------------
[[1412   52]
 [  64    5]]
              precision    recall  f1-score   support

           0       0.96      0.96      0.96      1464
           1       0.09      0.07      0.08        69

    accuracy                           0.92      1533
   macro avg       0.52      0.52      0.52      1533
weighted avg       0.92      0.92      0.92      1533

--------------------kNN--------------------
[[1457    7]
 [  66    3]]
              precision    recall  f1-score   support

           0       0.96      1.00      0.98      1464
           1       0.30      0.04      0.08        69

    accuracy                           0.95      1533
   macro avg       0.63      0.52      0.53      1533
weighted avg       0.93      0.95      0.94      1533

--------------------Naive Bayes--------------------
[[1310  154]
 [  41   28]]
              precision    recall  f1-score   support

           0       0.97      0.89      0.93      1464
           1       0.15      0.41      0.22        69

    accuracy                           0.87      1533
   macro avg       0.56      0.65      0.58      1533
weighted avg       0.93      0.87      0.90      1533

--------------------KMeans--------------------
[[ 266 1198]
 [   2   67]]
              precision    recall  f1-score   support

           0       0.99      0.18      0.31      1464
           1       0.05      0.97      0.10        69

    accuracy                           0.22      1533
   macro avg       0.52      0.58      0.20      1533
weighted avg       0.95      0.22      0.30      1533

Note that recall can be thought of as a measure of a classifier's completeness: it is the fraction of actual stroke cases that the model correctly identifies. A low recall for stroke (1) indicates many false negatives.
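The recall and precision figures in the reports follow directly from the confusion matrices. As a worked example, the Naive Bayes matrix above (rows are actual classes and columns are predicted classes, the layout scikit-learn's confusion_matrix uses) gives:

```python
# Naive Bayes test-set confusion matrix from the report above
# rows = actual class [0, 1], columns = predicted class [0, 1]
tn, fp = 1310, 154
fn, tp = 41, 28

recall = tp / (tp + fn)      # share of actual strokes that were flagged: 28 / 69
precision = tp / (tp + fp)   # share of flagged subjects who had a stroke: 28 / 182

print(f"recall={recall:.2f}, precision={precision:.2f}")  # recall=0.41, precision=0.15
```

These match the 0.41 recall and 0.15 precision reported for class 1.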

print('Summary of Accuracy Score\n\n')
for name, model in models.items():
    print(name + ' Model: ', round(accuracy_score(y_test, model.predict(x_test)), 4))
Summary of Accuracy Score


Decision Tree Model: 0.9159
Logreg Model: 0.955
Random Forest Model: 0.9543
Support Vector Machine Model: 0.9243
kNN Model: 0.9524
Naive Bayes Model: 0.8728
KMeans Model: 0.2172

From the above accuracy summary, the Logistic Regression, Random Forest and kNN models all give a high accuracy score of about 0.95. However, it is also important to consider the error type and recall of each model. As the confusion matrices show, these high-accuracy models have many false negatives: Logistic Regression and Random Forest identify none of the 69 stroke cases, and kNN identifies only 3. A false negative is a type II error. For stroke prediction we want to avoid type II errors, because they mean failing to identify subjects who have had a stroke and deeming them stroke free instead. Based on the classification reports, the Naive Bayes model best fits our objective: it has the highest stroke recall (0.41) among the supervised classifiers, although its accuracy is 0.87.
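The accuracy of the models that miss most strokes can be put in context with a majority-class baseline. Using the test-set class counts read off the confusion matrices (1464 non-stroke and 69 stroke subjects), a trivial rule that always predicts "no stroke" would score:

```python
# Test-set class counts (row sums of any confusion matrix above)
n_no_stroke, n_stroke = 1464, 69
n_total = n_no_stroke + n_stroke   # 1533

# Always predicting 0 is correct for every non-stroke subject...
baseline_accuracy = n_no_stroke / n_total
# ...but never flags a single stroke
baseline_recall = 0 / n_stroke

print(round(baseline_accuracy, 4), baseline_recall)  # 0.955 0.0
```

This equals the Logistic Regression accuracy of 0.955, which is why accuracy alone is a poor guide on this imbalanced data. scikit-learn's DummyClassifier(strategy="most_frequent") implements the same baseline as a fitted estimator.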

7. Cross Validation

from sklearn.model_selection import cross_val_score

gnb = GaussianNB()

scores = cross_val_score(gnb, x_train, y_train, cv = 10, scoring='accuracy')

print('Cross-validation scores:{}'.format(scores))
Cross-validation scores:[0.87430168 0.84916201 0.88826816 0.87709497 0.89944134 0.88547486
0.86592179 0.86554622 0.86834734 0.85714286]
print('Average cross-validation score: {:.4f}'.format(scores.mean()))
Average cross-validation score: 0.8731
  • Using the mean cross-validation score, we expect the model to be around 87.31% accurate on average.
  • The 10 scores produced by 10-fold cross-validation show relatively small variance between folds, ranging from 84.92% to 89.94% accuracy. This suggests that the model's performance does not depend strongly on the particular folds used for training.
  • Our original test-set accuracy is 0.8728 and the mean cross-validation accuracy is 0.8731. Cross-validation does not change the model itself; the close agreement between the two figures indicates that the single test-set estimate is consistent with the cross-validated one.
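Since the study's objective is recall rather than accuracy, the same cross-validation can also be scored on recall. The sketch below is hypothetical: it uses a synthetic imbalanced dataset from make_classification in place of x_train and y_train, and an explicit shuffled StratifiedKFold (cross_val_score already uses stratified folds by default for classifiers when cv is an integer, but without shuffling).

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.naive_bayes import GaussianNB

# Synthetic stand-in for (x_train, y_train): roughly 5% positive class
X, y = make_classification(n_samples=2000, n_features=8, weights=[0.95],
                           random_state=0)

# Stratified folds keep the rare positive class present in every fold
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

acc = cross_val_score(GaussianNB(), X, y, cv=cv, scoring="accuracy")
rec = cross_val_score(GaussianNB(), X, y, cv=cv, scoring="recall")

print(f"accuracy: {acc.mean():.3f} (+/- {acc.std():.3f})")
print(f"recall:   {rec.mean():.3f} (+/- {rec.std():.3f})")
```

Reporting the spread of recall across folds alongside accuracy shows whether the model's ability to catch stroke cases is as stable as its overall accuracy.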

8. Conclusion

  1. Various models were used to predict whether a person is likely to have a stroke. The Naive Bayes model was selected because it has the highest recall for stroke cases (0.41) among the supervised classifiers, with an accuracy of 87.28%.
  2. Using the mean cross-validation score, we expect the model to be around 87.31% accurate on average.
  3. The 10 scores produced by 10-fold cross-validation vary relatively little between folds, suggesting that the model's performance does not depend strongly on the particular folds used for training.
  4. Our original model accuracy is 87.28% and the mean cross-validation accuracy is 87.31%. The two estimates agree closely, which supports the reliability of the test-set result; cross-validation itself does not improve the model.
  5. The Naive Bayes model can be further improved by tuning hyperparameters or by adjusting the probability threshold.
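Adjusting the probability threshold, as suggested in point 5, can be sketched as follows, assuming a synthetic imbalanced dataset in place of the stroke data. GaussianNB.predict_proba returns class probabilities, and lowering the cutoff for predicting a stroke below the default 0.5 trades precision for recall. The thresholds tried are arbitrary illustrations; in practice they would be chosen on validation data rather than on the test set.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Synthetic imbalanced data standing in for the stroke features and labels
X, y = make_classification(n_samples=2000, n_features=8, weights=[0.95],
                           random_state=1)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3,
                                          random_state=1, stratify=y)

model = GaussianNB().fit(X_tr, y_tr)
proba = model.predict_proba(X_te)[:, 1]   # estimated P(class 1) per subject

recalls = {}
for threshold in (0.5, 0.3, 0.1):
    # flag a subject as positive when its probability reaches the threshold
    y_pred = (proba >= threshold).astype(int)
    recalls[threshold] = recall_score(y_te, y_pred)
    print(f"threshold={threshold}: recall={recalls[threshold]:.2f}")
```

Lowering the threshold can only add positive predictions, so recall never decreases, while precision usually falls as more non-stroke subjects are flagged.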


A development engineer exploring in the field of data analytics and additive manufacturing