Detecting Diabetic Retinopathy to Stop Blindness using Pretrained Deep Learning Models
Posted on Thu 26 September 2019 in posts • 20 min read
Introduction
Millions of people suffer from diabetic retinopathy, the leading cause of blindness among working-age adults. Aravind Eye Hospital in India hopes to detect and prevent this disease among people living in rural areas where medical screening is difficult to conduct.
Currently, Aravind technicians travel to these rural areas to capture images and then rely on highly trained doctors to review the images and provide a diagnosis. Their goal is to scale their efforts through technology: to gain the ability to automatically screen images for disease and provide information on how severe the condition may be.
In this project, we will try to build a deep learning model to classify thousands of eye images from Aravind Eye Hospital based on the severity of the diabetic retinopathy.
This project is inspired by the APTOS 2019 Blindness Detection competition on Kaggle.
For reference, you can find the Jupyter notebook in my GitHub repo. Any feedback is appreciated!
Setting up the Training Data
We will download the dataset from Kaggle. For further information about using the Kaggle API, check the Kaggle API's repository.
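As a rough illustration of the authentication step: the Kaggle command-line client typically reads an API token from a `kaggle.json` file inside a `.kaggle` folder in the home directory, and expects that file to be readable only by its owner. The sketch below copies a downloaded token into place. The function name `install_kaggle_token` and the demo paths are hypothetical, and the demo uses a throwaway directory with a fake token so nothing real is touched.

```python
import json
import os
import shutil
import tempfile
from pathlib import Path

def install_kaggle_token(token_path, config_dir=None):
    """Copy a Kaggle API token into the config directory with owner-only permissions."""
    # Default location is an assumption based on the Kaggle client's usual behaviour.
    config_dir = Path(config_dir) if config_dir else Path.home() / ".kaggle"
    config_dir.mkdir(parents=True, exist_ok=True)
    target = config_dir / "kaggle.json"
    shutil.copyfile(token_path, target)
    os.chmod(target, 0o600)  # readable and writable by the owner only
    return target

# Demo in a temporary directory with a made-up token.
with tempfile.TemporaryDirectory() as tmp:
    fake_token = Path(tmp) / "kaggle.json"
    fake_token.write_text(json.dumps({"username": "demo", "key": "not-a-real-key"}))
    installed = install_kaggle_token(fake_token, Path(tmp) / ".kaggle")
    installed_mode = installed.stat().st_mode & 0o777
    print(installed.name, oct(installed_mode))
```

In a Colab session, this would be called once with the path of the uploaded token, for example `install_kaggle_token("kaggle.json")`.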
Once our request is authenticated, we can download the data:
!kaggle competitions download -c aptos2019-blindness-detection
!ls
!unzip aptos2019-blindness-detection.zip
!unzip train_images.zip
Reading the Train Set
Let's start by reading the train set:
import tensorflow as tf
tf.__version__
import pandas as pd
import numpy as np
SEED = 42
np.random.seed(SEED)
train=pd.read_csv("/content/train.csv")
train.head()
len(train)
According to the data description on Kaggle, we are provided with a large set of retina images taken using fundus photography under a variety of imaging conditions.
A clinician has rated each image for the severity of diabetic retinopathy on a scale of 0 to 4:
0 - No DR
1 - Mild
2 - Moderate
3 - Severe
4 - Proliferative DR
Like any real-world data set, we will encounter noise in both the images and labels. Images may contain artifacts, be out of focus, underexposed, or overexposed. The images were gathered from multiple clinics using a variety of cameras over an extended period of time, which will introduce further variation. In the upcoming steps we will try to preprocess the images to better highlight some important features.
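Given this kind of noise, a quick sanity check before any preprocessing can be worthwhile. The sketch below lists ids that have no matching image file. It assumes the images are PNG files named after their `id_code`; the helper name `find_missing_images` and the demo ids are made up for illustration.

```python
import tempfile
from pathlib import Path

def find_missing_images(id_codes, image_dir, extension=".png"):
    """Return the ids for which no image file exists in image_dir."""
    image_dir = Path(image_dir)
    return [
        code for code in id_codes
        if not (image_dir / f"{code}{extension}").is_file()
    ]

# Demo with made-up ids in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "abc123.png").touch()  # only one of the two ids has a file
    missing = find_missing_images(["abc123", "def456"], tmp)
    print(missing)
```

On the real data this could be called as `find_missing_images(train["id_code"], ".")`, assuming the images were unzipped into the working directory.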
Let's read the test data into a Pandas dataframe:
test=pd.read_csv("/content/test.csv")
Next, let's create a dictionary that divides the image ids into 5 classes based on the diagnosis:
dict={0:"No DR",
      1:"Mild",
      2:"Moderate",
      3:"Severe",
      4:"Proliferative DR"}
dict
# one DataFrame of train rows per diagnosis class
diags={}
for k in dict.keys():
    diags[k]=train[train["diagnosis"]==k]
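The same per-class split can also be expressed with pandas' `groupby`. This is only an alternative sketch on a made-up toy frame, not the code used in the rest of the post. One difference worth noting: `groupby` yields only the classes that actually occur in the data, whereas the loop above creates an (empty) entry for every key.

```python
import pandas as pd

# Toy stand-in for the train DataFrame (ids and labels are invented).
toy_train = pd.DataFrame({
    "id_code": ["a", "b", "c", "d", "e"],
    "diagnosis": [0, 2, 0, 4, 2],
})

# One sub-frame per diagnosis value present in the data.
toy_diags = {int(label): group for label, group in toy_train.groupby("diagnosis")}

for label, group in toy_diags.items():
    print(label, list(group["id_code"]))
```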
Let's now check the distribution of each class in the train set:
import matplotlib.pyplot as plt
%matplotlib inline
cases=train["diagnosis"].value_counts(normalize=True,ascending=False)*100
ax=cases.plot(kind="bar")
ax.set_xticklabels(cases.index,rotation=0)
ax.set_title("Diagnosis Classes Distribution")
for i, v in enumerate(cases.values):
    ax.text(i-0.2, v+0.5, str(round(v,3)))
for sp in ax.spines.keys():
    ax.spines[sp].set_visible(False)
ax.set_yticks([])
plt.show()
We notice that the classes are not uniformly distributed: around 50% of the train data belongs to class 0, and around 27.28% belongs to class 2 (the second bar in the plot, since the bars are sorted by frequency). These two classes alone account for more than 75% of the training data, which increases the risk that the trained model will be biased toward them.
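The combined share of the most frequent classes can be computed directly rather than read off the plot. A minimal sketch on made-up labels follows; `top_k_share` is a hypothetical helper name.

```python
import pandas as pd

def top_k_share(labels: pd.Series, k: int = 2) -> float:
    """Fraction of samples that belong to the k most frequent classes."""
    shares = labels.value_counts(normalize=True)  # sorted from most to least frequent
    return float(shares.head(k).sum())

# Toy labels: 5 of class 0, 3 of class 2, 1 each of classes 1 and 4.
toy_labels = pd.Series([0] * 5 + [2] * 3 + [1] + [4])
toy_share = top_k_share(toy_labels, k=2)
print(round(toy_share * 100, 2))
```

On the real data the call would be `top_k_share(train["diagnosis"])`.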
One solution is to calculate a weight for each class that balances out the skew in the data. We will later feed these weights to the model as a separate parameter.
weights=1./(train["diagnosis"].value_counts(normalize=True))
weights=weights/weights[0]
weights
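To make the arithmetic concrete, here is the same inverse-frequency calculation on a small made-up label set, returned as a plain dictionary. The helper name `inverse_frequency_weights` is hypothetical, and the dictionary output is an assumption about how the weights might be passed along later.

```python
import pandas as pd

def inverse_frequency_weights(labels: pd.Series) -> dict:
    """Inverse class frequencies, scaled so that class 0 has weight 1."""
    freqs = labels.value_counts(normalize=True)
    weights = 1.0 / freqs
    weights = weights / weights.loc[0]  # assumes class 0 is present in the labels
    return {int(k): float(w) for k, w in weights.sort_index().items()}

# Toy labels: class 0 is half the data, class 2 is 30%, class 1 is 20%.
toy = pd.Series([0] * 5 + [1] * 2 + [2] * 3)
toy_weights = inverse_frequency_weights(toy)
print(toy_weights)
```

The rarest class ends up with the largest weight. A dictionary of this shape is what Keras' `model.fit` accepts through its `class_weight` argument, if that is how the weights end up being supplied.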
Exploratory Data Analysis
Let's examine some random images of each class:
import cv2
from google.colab.patches import cv2_imshow
fig=plt.figure(figsize=(15,25))
for i,k in enumerate(diags.keys()):
    for j,v in enumerate(np.random.choice(diags[k]["id_code"],size=5)):
        ax=fig.add_subplot(5,5,5*(i) +(j+1),xticks=[], yticks=[])
        # cv2.imread returns BGR; convert to RGB so matplotlib shows true colours
        image=cv2.imread(v+".png")
        image=cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
        image=cv2.resize(image,(150,150))
        ax.imshow(image)
        ax.set_title("label:"+v+"\n diagnosis:"+str(k))
plt.show()
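One detail of the sampling: `np.random.choice` draws with replacement by default, so the same image can occasionally show up twice in a row of the grid. A possible variant, sketched on a made-up toy frame, draws without replacement using a seeded generator; `sample_ids_per_class` is a hypothetical helper name.

```python
import numpy as np
import pandas as pd

def sample_ids_per_class(df, n=5, seed=42):
    """Pick up to n distinct id_codes per diagnosis class, reproducibly."""
    rng = np.random.default_rng(seed)
    samples = {}
    for label, group in df.groupby("diagnosis"):
        ids = group["id_code"].to_numpy()
        size = min(n, len(ids))  # small classes may have fewer than n images
        samples[int(label)] = list(rng.choice(ids, size=size, replace=False))
    return samples

# Toy frame with invented ids: 6 images of class 0, 4 of class 1, 2 of class 2.
toy = pd.DataFrame({
    "id_code": [f"img{i}" for i in range(12)],
    "diagnosis": [0] * 6 + [1] * 4 + [2] * 2,
})
picked = sample_ids_per_class(toy, n=5)
print({label: len(ids) for label, ids in picked.items()})
```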