Classifying Sepsis Patients with Reinforcement Learning

We designed a reinforcement learning environment and model to classify patients with sepsis at each hour.

A video presentation of our project can be found here, with the accompanying slide deck here.

The step-by-step results of our project can be found in our notebook here.

Project Contributors: Zachary Barnes & Mundy Reimer

This was our final project for Introduction to Reinforcement Learning in the MSDS program at the University of San Francisco.

Introduction

Sepsis is a life-threatening condition that arises when the body's response to infection injures its own tissues and organs. It is the most common cause of death among hospitalized patients and accounts for $15.4 billion in annual costs in the US. Early detection and treatment are essential: a one-hour delay in antibiotic treatment can lead to a 4% increase in hospital mortality. Because our data form a multivariate time series of patient vital signs, hour-by-hour classification is a natural problem to approach with reinforcement learning methods.

Data


We used a public data set from the PhysioNet Computing Challenge, which can be downloaded here.

An explanation by the PhysioNet Challenge is given below:

Data used in the competition is sourced from ICU patients in three separate hospital systems.

The data repository contains one file per subject (e.g., training/p00101.psv). Each training data file provides a table with measurements over time. Each column of the table provides a sequence of measurements over time (e.g., heart rate over several hours), where the header of the column describes the measurement. Each row of the table provides a collection of measurements taken at the same time (e.g., heart rate and oxygen level). The table is formatted in the following way:

[Figure: sample patient data table from the PhysioNet Challenge]

There are 40 time-dependent variables (HR, O2Sat, Temp, ..., HospAdmTime), which are described here. The final column, SepsisLabel, indicates the onset of sepsis according to the Sepsis-3 definition, where 1 indicates sepsis and 0 indicates no sepsis. Entries of NaN (not a number) indicate that there was no recorded measurement of a variable at that time interval.

[Figure: example patient vital-sign time series]
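
Since each record is a plain pipe-delimited text file, it can be loaded directly with pandas. Below is a minimal sketch, assuming the example file name from above; the forward-fill at the end is just one common way to handle the NaN gaps, not necessarily the preprocessing used in this project.

```python
import pandas as pd

# Each patient record is a pipe-delimited table with one row per hour in the ICU.
# The path below is the example file name mentioned above; adjust it to your copy.
df = pd.read_csv("training/p00101.psv", sep="|")

print(df.shape)                                       # (hours recorded, 41 columns)
print(df[["HR", "O2Sat", "Temp", "SepsisLabel"]].head())

# Missing measurements are stored as NaN. Forward-filling the last observed value
# is one simple way to handle them (an illustration, not the project's pipeline).
df_filled = df.ffill()
```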

RL Framework

Our reinforcement learning environment is built with OpenAI's Gym.

For step-by-step instructions on setting up your environment, see the Setup section below.

To create this environment, we referenced:

  • How to create a custom gym environment with RL training code here.
  • Creating RL algorithms using the Stable Baselines package here.
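
As a rough illustration of the first reference, a custom environment is registered with Gym and then created like any built-in one. The module path and class name below are placeholders, not this repository's actual layout:

```python
import gym
from gym.envs.registration import register

# Placeholder module path and class name, for illustration only.
register(
    id="Sepsis-v0",                      # any unused "<Name>-v<N>" id works
    entry_point="sepsis_env:SepsisEnv",  # "<module>:<class>" of the custom env
)

env = gym.make("Sepsis-v0")
```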

We can briefly frame our reinforcement learning problem as follows (a minimal environment sketch appears after this list):

  • Environment - SepsisEnv modeled using OpenAI Gym, where we have a sequential multivariate timeseries of patients' vital signs
  • Agent - A binary classifier that predicts whether patients have sepsis or not
  • States - Each timestep that contains multiple patient vital signs taken at the same time
  • Actions - Binary prediction of whether a patient has sepsis (1) or does not (0)
  • Rewards - A score between -2 and 1 given by the utility function, determined by whether each hourly prediction is a true/false positive or a true/false negative and by its timing relative to sepsis onset
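
A minimal sketch of this framing as a custom Gym environment is given below. It is an illustration rather than the repository's actual SepsisEnv implementation; in particular, the reward column names reward_0 and reward_1 are assumptions about how the precomputed rewards are stored.

```python
import gym
import numpy as np
from gym import spaces


class SepsisEnv(gym.Env):
    """Sketch of the framing above: one episode walks through one patient's
    hourly records, the action is a binary sepsis prediction, and the reward
    comes from precomputed utility columns (column names here are assumed)."""

    def __init__(self, patient_df, n_features=40):
        super().__init__()
        self.df = patient_df.reset_index(drop=True)
        self.n_features = n_features
        self.action_space = spaces.Discrete(2)  # 0 = no sepsis, 1 = sepsis
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(n_features,), dtype=np.float32
        )
        self.t = 0

    def reset(self):
        self.t = 0
        return self._obs()

    def step(self, action):
        # Reward for this hour given the chosen prediction; "reward_0"/"reward_1"
        # are assumed names for the two precomputed utility columns.
        reward = self.df.loc[self.t, "reward_1" if action == 1 else "reward_0"]
        self.t += 1
        done = self.t >= len(self.df)
        obs = np.zeros(self.n_features, dtype=np.float32) if done else self._obs()
        return obs, float(reward), done, {}

    def _obs(self):
        # Assumes the first n_features columns hold the hourly vital signs.
        return self.df.iloc[self.t, : self.n_features].to_numpy(dtype=np.float32)
```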

The algorithm will be evaluated by its performance as a binary classifier using a utility function created by the PhysioNet Challenge. This utility function rewards classifiers for early predictions of sepsis and penalizes them for late predictions and for predictions of sepsis in non-sepsis patients.

The PhysioNet Challenge defines a score U(s, t) for each prediction, i.e., for each line of each data file, where s indexes the patient and t the time interval:

[Figure: piecewise definition of the utility score U(s, t)]

The following figure shows the utility function for a sepsis patient with t_sepsis = 48 as an example (figure from PhysioNet Challenge):

[Figure: utility function over time for a sepsis patient with t_sepsis = 48]
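
As a hedged sketch of how those per-prediction scores roll up into a single number: the U(s, t) values are summed over all patients and hours and, in the challenge's published scoring, normalized so that a classifier that never predicts sepsis scores 0 and an optimal classifier scores 1. The helper names below are illustrative.

```python
import numpy as np


def total_utility(per_hour_utilities):
    """Sum U(s, t) over every hour t of every patient s.

    `per_hour_utilities` is a list with one array of realized U(s, t) values
    per patient (a hypothetical representation used only for this sketch).
    """
    return float(sum(np.sum(u) for u in per_hour_utilities))


def normalized_utility(u_model, u_optimal, u_inaction):
    """Scale the model's total utility so that a 'never predict sepsis'
    classifier maps to 0 and optimally timed predictions map to 1."""
    return (u_model - u_inaction) / (u_optimal - u_inaction)
```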

Evaluation

We then compared performance across multiple algorithms. You can check out our notebook here for more.

In total, we compare the following combinations (a training sketch follows the list):

  • Proximal Policy Optimization (PPO) + Multi-Layer Perceptron
  • Proximal Policy Optimization (PPO) + Multi-Layer Perceptron with Long Short-Term Memory
  • Proximal Policy Optimization (PPO) + Multi-Layer Perceptron with Long Short-Term Memory and Layer Normalization
  • Synchronous, deterministic variant of Asynchronous Advantage Actor-Critic (A2C) + Multi-Layer Perceptron
  • Synchronous, deterministic variant of Asynchronous Advantage Actor-Critic (A2C) + Multi-Layer Perceptron with Long Short-Term Memory
  • Deep Q-Network (DQN) + Multi-Layer Perceptron
  • Deep Q-Network (DQN) + Multi-Layer Perceptron with Long Short-Term Memory
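
A hedged sketch of training these combinations with Stable Baselines (the TensorFlow 1.x package referenced earlier) is shown below. SepsisEnv and patient_df are the placeholders from the environment sketch above, and the timestep budget is illustrative rather than the project's actual setting.

```python
from stable_baselines import A2C, DQN, PPO2
from stable_baselines.common.vec_env import DummyVecEnv

# SepsisEnv and patient_df are placeholders from the sketch above.
vec_env = DummyVecEnv([lambda: SepsisEnv(patient_df)])

candidates = {
    "PPO2 + MLP": PPO2("MlpPolicy", vec_env, verbose=0),
    "PPO2 + MLP-LSTM": PPO2("MlpLstmPolicy", vec_env, nminibatches=1, verbose=0),
    "A2C + MLP": A2C("MlpPolicy", vec_env, verbose=0),
    "A2C + MLP-LSTM": A2C("MlpLstmPolicy", vec_env, verbose=0),
    "DQN + MLP": DQN("MlpPolicy", SepsisEnv(patient_df), verbose=0),  # DQN takes a single env
}

for name, model in candidates.items():
    model.learn(total_timesteps=50_000)  # illustrative training budget
    print("finished training", name)
```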

The plot below summarizes our results: both Deep Q-Network variants with Multi-Layer Perceptrons perform best, every A2C and PPO combination performs worse than the DQNs, and our random baseline performs worst, as expected:

[Animation: comparison of rewards across the model and policy combinations above]

It makes sense that both of our Deep Q-Networks perform best, since they combine Q-learning with the representational power of deep neural networks, which lets RL work in complex, high-dimensional environments like our multivariate space of patient vital signs. Because of our large data set of over 20,000 patients, each with roughly 50 records of more than 40 time-dependent variables, we can see the agents' learning stabilize over the long term. And because the action space is simple (a binary prediction), the agent is able to learn this classification problem. Furthermore, the rewards of the different models diverge after running only roughly 1,300 patients, which is a good sign that the agents are learning and that the combinations tested above genuinely differ. Finally, we tuned the hyperparameters of the best-performing model (DQN with an MLP policy) using Bayesian optimization, which can be found in this notebook here.
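
The linked notebook contains the actual tuning; as a hedged illustration of the idea, a Gaussian-process optimizer such as scikit-optimize's gp_minimize can search over DQN hyperparameters. The search space, training budget, and the evaluate_mean_reward helper below are hypothetical.

```python
from skopt import gp_minimize
from skopt.space import Integer, Real
from stable_baselines import DQN

# Hypothetical search space; the notebook's actual ranges may differ.
search_space = [
    Real(1e-5, 1e-2, prior="log-uniform", name="learning_rate"),
    Real(0.90, 0.999, name="gamma"),
    Integer(1_000, 50_000, name="buffer_size"),
]


def objective(params):
    learning_rate, gamma, buffer_size = params
    model = DQN("MlpPolicy", SepsisEnv(patient_df),        # placeholders from the sketch above
                learning_rate=learning_rate, gamma=gamma,
                buffer_size=int(buffer_size), verbose=0)
    model.learn(total_timesteps=20_000)                    # illustrative budget
    # evaluate_mean_reward is a hypothetical helper that rolls out the trained
    # model and returns its mean utility; gp_minimize minimizes, so negate it.
    return -evaluate_mean_reward(model, SepsisEnv(patient_df))


result = gp_minimize(objective, search_space, n_calls=25, random_state=0)
print("best hyperparameters:", result.x)
```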

Next Steps

As future work, we will attempt to tease apart the differences between the Deep Q-Network variants and explore the potential benefits or pitfalls of adding layer normalization. Since this is medical data in which each feature is interpretable, the problem also lends itself well to feature engineering guided by the domain expertise of a medical professional. We could also run our approach on additional data sets to check whether it generalizes beyond these three hospital systems.


Setup

1) Install dependencies

If using conda, create an environment with python 3.7: conda create -n rl_sepsis python=3.7

Activate the environment: conda activate rl_sepsis

Then, install the necessary packages: pip install -r requirements.txt

2) Clean data

We have uploaded training set A from the PhysioNet competition into the repo. To load and clean the data, run:

make load_data

This will take about 10 minutes, with progress displayed via a tqdm progress bar. It creates a cache/ directory (via the cache_em_all package) where the cleaned data is stored.

Now, in a notebook or .py file, you can load the data with

```python
from load_data import load_data
df = load_data()
```

where df is a pandas data frame.

Alternatively, once you clone this repo you can open up Load_Data.ipynb and run all the cells. If no error is thrown, then you have loaded the data successfully.

3) Add Rewards

Using the utility function provided by the competition, we have added two columns giving the reward received at each hour for predicting a zero or a one, respectively.

To create the reward columns, run: make add_reward

This should take only 10-15 seconds and will add the file "training_setA_rewards" under the cache/ directory.
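
As a hedged illustration of what the cached rewards provide, the snippet below reads the file and looks up the reward for each possible prediction at one hour. The file extension and the column names reward_0/reward_1 are assumptions; inspect the cache/ directory after running make add_reward for the actual ones.

```python
import pandas as pd

# Assumed file extension and column names; check cache/ for the real ones.
rewards_df = pd.read_csv("cache/training_setA_rewards.csv")

t = 0                                                 # one row, i.e., one patient-hour
reward_if_predict_0 = rewards_df.loc[t, "reward_0"]   # utility for predicting "no sepsis"
reward_if_predict_1 = rewards_df.loc[t, "reward_1"]   # utility for predicting "sepsis"
```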

4) Train Model

To train the RL models, simply run make train_model. This will train 9 models for roughly 1,300 patients. A progress bar showing the model, policy, and rewards will be displayed.

5) Results

To see graphical results of performance for the different baseline models, see our visualization notebook.
