Tutorial

How to Train an EHR Mortality Prediction Model with PhysioNet 2012

Difficulty: Advanced | Time: 25 min read

Introduction

In this cookbook, we will cover how to download the PhysioNet 2012 dataset, preprocess multivariate clinical time-series data, and train an LSTM to predict in-hospital mortality.

Prerequisites

  • Python 3.10+
  • Pandas, Scikit-learn, and PyTorch installed
  • Access to the PhysioNet 2012 dataset download page

Step 1: Download and Parse the Data

First, download the PhysioNet 2012 dataset. The data comes as raw text files, one for each patient, containing timestamped vital signs and lab results over the first 48 hours of an ICU stay.

wget https://physionet.org/files/challenge-2012/1.0.0/set-a.tar.gz
tar -xvzf set-a.tar.gz
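
Before moving on, it can help to look at what a single record contains. The sketch below parses a tiny inline sample in the same Time,Parameter,Value long format that the code in Step 2 expects. The values are made up for illustration and are not taken from the dataset. It also shows how a label might be looked up, assuming the outcomes are available as a table with RecordID and In-hospital_death columns; check those column names against your own download. The helper time_to_minutes is hypothetical.

```python
import io
import pandas as pd

# Made-up sample in the long format used by the per-patient files.
SAMPLE_RECORD = """Time,Parameter,Value
00:00,RecordID,1
00:00,Age,54
00:07,HR,73
00:07,Temp,35.1
01:37,HR,77
"""

# Made-up outcomes table; the column names are an assumption to verify.
SAMPLE_OUTCOMES = """RecordID,In-hospital_death
1,0
2,1
"""

def time_to_minutes(hhmm):
    """Convert an 'HH:MM' offset from ICU admission into minutes."""
    hours, minutes = hhmm.split(":")
    return int(hours) * 60 + int(minutes)

record = pd.read_csv(io.StringIO(SAMPLE_RECORD))
record["Minutes"] = record["Time"].map(time_to_minutes)

# The RecordID row identifies the patient; use it to look up the label.
record_id = int(record.loc[record["Parameter"] == "RecordID", "Value"].iloc[0])
outcomes = pd.read_csv(io.StringIO(SAMPLE_OUTCOMES), index_col="RecordID")
label = int(outcomes.loc[record_id, "In-hospital_death"])

print(record_id, label, record["Minutes"].tolist())
```

For real data, replace the io.StringIO objects with file paths from the extracted archive and the outcomes file.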

Step 2: Time-Series Imputation

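EHR data is notoriously sparse: each variable is measured on its own irregular schedule, so most combinations of time step and variable have no value. We impute by forward-filling each variable within a stay, then filling whatever is still missing (values before a variable's first measurement) with a constant such as 0 or the feature mean.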

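import pandas as pd
import numpy as np

# Static descriptors recorded at time 00:00; excluded from the time series.
# (Assumed list; adjust it to match the descriptors in your files.)
STATIC_PARAMS = ['RecordID', 'Age', 'Gender', 'Height', 'ICUType']

def parse_patient_file(filepath, feature_names=None):
    # Each file is in long format with the columns Time, Parameter, Value
    df = pd.read_csv(filepath, header=0)
    df = df[~df['Parameter'].isin(STATIC_PARAMS)]
    # pivot_table (rather than pivot) averages repeated measurements that
    # share a timestamp; pivot raises an error on duplicate entries
    df = df.pivot_table(index='Time', columns='Parameter', values='Value', aggfunc='mean')
    if feature_names is not None:
        # Give every patient the same columns in the same order
        df = df.reindex(columns=feature_names)
    # Forward fill then fill remaining NaNs with 0 (or feature means)
    df = df.ffill().fillna(0)
    return df.values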
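
The function above falls back to 0 wherever forward-filling has nothing to carry forward. A minimal sketch of the feature-mean alternative follows, assuming the means are computed from training patients only so that nothing leaks from the test split. The helper names compute_feature_means and impute_with_means are hypothetical, and the toy values are made up.

```python
import numpy as np
import pandas as pd

def compute_feature_means(frames):
    """Mean of each feature over all observed (non-NaN) training values."""
    stacked = pd.concat(frames, axis=0)
    return stacked.mean(skipna=True)

def impute_with_means(frame, feature_means):
    """Forward fill within a stay, then fill leading gaps with training means."""
    filled = frame.ffill()
    # fillna with a Series matches its index against the column names.
    filled = filled.fillna(feature_means)
    # A feature never observed in training has a NaN mean; fall back to 0.
    return filled.fillna(0.0)

# Toy frames: rows are time steps, columns are features (made-up values).
train_a = pd.DataFrame({"HR": [np.nan, 80.0, np.nan], "Temp": [37.0, np.nan, np.nan]})
train_b = pd.DataFrame({"HR": [100.0, np.nan, np.nan], "Temp": [np.nan, np.nan, 38.0]})

means = compute_feature_means([train_a, train_b])
imputed = impute_with_means(train_a, means)
print(imputed)
```

Here the first HR value of train_a has no earlier measurement to carry forward, so it receives the training mean, while later gaps repeat the last observed value.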

Step 3: Training an LSTM in PyTorch

Because each stay is a sequence of observations over time, recurrent neural networks such as LSTMs are a natural fit and a common baseline for this task.

import torch
import torch.nn as nn

class MortalityLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])  # LSTM output at the last time step (assumes no trailing padding)
        return self.sigmoid(out)
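
The model above returns probabilities, so it pairs with nn.BCELoss. A minimal training sketch on random placeholder data follows. The batch size, sequence lengths, feature count, learning rate, and epoch count are arbitrary assumptions, and a real run would substitute the imputed arrays from Step 2 and the in-hospital mortality labels. Because stays differ in length, the hypothetical helper pad_front left-pads each sequence with zeros so the last real observation stays at the final position, which is the one the model reads. The class is repeated so the block runs on its own.

```python
import numpy as np
import torch
import torch.nn as nn

class MortalityLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])  # output at the last time step
        return self.sigmoid(out)

def pad_front(sequences, max_len):
    """Left-pad (time, features) arrays with zeros to a common length."""
    n_features = sequences[0].shape[1]
    batch = np.zeros((len(sequences), max_len, n_features), dtype=np.float32)
    for i, seq in enumerate(sequences):
        seq = seq[-max_len:]  # keep the most recent steps if too long
        batch[i, max_len - len(seq):] = seq
    return torch.from_numpy(batch)

torch.manual_seed(0)
rng = np.random.default_rng(0)

# Placeholder data: 32 stays of varying length, 8 features each.
sequences = [rng.normal(size=(int(rng.integers(5, 20)), 8)) for _ in range(32)]
x = pad_front(sequences, max_len=20)
y = torch.from_numpy(rng.integers(0, 2, size=(32, 1)).astype(np.float32))

model = MortalityLSTM(input_dim=8, hidden_dim=16)
criterion = nn.BCELoss()  # expects probabilities in [0, 1]
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

losses = []
for epoch in range(5):
    optimizer.zero_grad()
    probs = model(x)          # shape (32, 1)
    loss = criterion(probs, y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    print(f"epoch {epoch}: loss {loss.item():.4f}")
```

This sketch trains on a single full batch for brevity. With the real dataset you would typically iterate over mini-batches with a DataLoader and hold out a validation split to monitor performance.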

Conclusion

Training an EHR predictive model requires significant data wrangling, but an LSTM trained on the PhysioNet 2012 dataset is a useful baseline for benchmarking clinical machine learning architectures.