How to Train an EHR Mortality Prediction Model with PhysioNet 2012
Introduction
In this cookbook, we download the PhysioNet/Computing in Cardiology Challenge 2012 dataset, preprocess its multivariate clinical time series, and train an LSTM in PyTorch to predict in-hospital mortality of ICU patients.
Prerequisites
- Python 3.10+
- pandas, scikit-learn, and PyTorch installed
- Network access to the PhysioNet 2012 Challenge project page (the data is openly available)
Step 1: Download and Parse the Data
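First, download set A of the PhysioNet 2012 Challenge data (4,000 ICU stays). Each stay is a separate comma-separated text file with Time, Parameter, and Value columns, holding the vital signs and lab results recorded over the first 48 hours of the ICU stay. Time is the elapsed time since ICU admission (hh:mm), and general descriptors such as age, gender, and ICU type appear at time 00:00. The in-hospital mortality labels are not in these files: for set A they are in a separate file, Outcomes-a.txt, on the same project page.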
```shell
wget https://physionet.org/files/challenge-2012/1.0.0/set-a.tar.gz
tar -xvzf set-a.tar.gz
```
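To pair each record with its label, a minimal sketch like the following lists the extracted files and reads the outcomes file. It assumes the archive extracted into a set-a/ directory, that each record file is named <RecordID>.txt, that Outcomes-a.txt is kept outside that directory, and that the outcomes file has RecordID and In-hospital_death columns. The helper names are hypothetical.

```python
from pathlib import Path

import pandas as pd


def list_record_files(data_dir):
    # One text file per ICU stay; assumed to be named <RecordID>.txt
    return sorted(Path(data_dir).glob("*.txt"))


def load_labels(outcomes_path):
    # One row per RecordID; In-hospital_death is 0 (survived) or 1 (died)
    outcomes = pd.read_csv(outcomes_path, index_col="RecordID")
    return outcomes["In-hospital_death"].astype(int)


def labels_for(files, labels):
    # Align labels with the file order by parsing the RecordID from each file name
    ids = [int(Path(f).stem) for f in files]
    return labels.loc[ids].to_numpy()
```

For example, files = list_record_files("set-a") followed by y = labels_for(files, load_labels("Outcomes-a.txt")) gives a label array in the same order as the record files.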
Step 2: Time-Series Imputation
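EHR time series are sparse and irregular: each variable is measured on its own schedule, so after pivoting to one row per timestamp most cells are empty. A simple baseline is to forward-fill each variable from its last observed value, then fill values that were never observed for that patient with a constant such as the training-set feature mean.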
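```python
import pandas as pd
import numpy as np

def parse_patient_file(filepath, features, fill_values=None):
    # Each record is a CSV with Time, Parameter, Value columns
    df = pd.read_csv(filepath)
    # One row per timestamp, one column per parameter; pivot_table averages
    # a parameter that is recorded more than once at the same timestamp
    df = df.pivot_table(index='Time', columns='Parameter', values='Value', aggfunc='mean')
    # features: fixed list of parameter names, so every patient has the same columns
    df = df.reindex(columns=features)
    # Forward fill, then fill never-observed values with feature means (or 0)
    df = df.ffill()
    df = df.fillna(fill_values if fill_values is not None else 0)
    return df.values
```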
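The fill constant and the sequence lengths still have to come from somewhere. The sketch below (hypothetical helper names) computes per-parameter means pooled over the raw training records, which can be used as the fill values instead of 0, and left-pads the per-patient arrays into a single (patients, max_len, features) array. Left-padding is a deliberate choice: the model in Step 3 reads the output at the last time step, so padding at the front keeps that step a real observation. The value of max_len is up to you. Compute the means on the training split only, so that test-set values do not leak into preprocessing.

```python
import numpy as np
import pandas as pd


def compute_feature_means(long_dfs, features):
    # long_dfs: raw record DataFrames with Time, Parameter, Value columns.
    # Pool every observed value per parameter; unseen parameters get 0.0.
    pooled = pd.concat(long_dfs, ignore_index=True)
    pooled = pooled[pooled["Parameter"].isin(features)]
    means = pooled.groupby("Parameter")["Value"].mean()
    return means.reindex(features).fillna(0.0)


def left_pad(sequences, max_len):
    # Stack variable-length (T_i, D) arrays into one (N, max_len, D) float32 array.
    # Longer sequences keep their last max_len rows; shorter ones get zeros
    # at the front, so the final time step is always real data.
    dim = sequences[0].shape[1]
    batch = np.zeros((len(sequences), max_len, dim), dtype=np.float32)
    for i, seq in enumerate(sequences):
        seq = seq[-max_len:]
        batch[i, max_len - len(seq):] = seq
    return batch
```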
Step 3: Training an LSTM in PyTorch
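LSTMs are recurrent neural networks that read a sequence one time step at a time while carrying a hidden state forward, which makes them a natural fit for per-patient matrices of shape (time steps, features).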
```python
import torch
import torch.nn as nn

class MortalityLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Input shape: (batch, time steps, features) because batch_first=True
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])  # LSTM output at the last time step
        # Probability of in-hospital death, shape (batch, 1); pair with nn.BCELoss
        return self.sigmoid(out)
```
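The class above only defines the model. A minimal training and evaluation sketch follows, assuming float32 inputs shaped (patients, time steps, features) and 0/1 labels. Because MortalityLSTM ends in a sigmoid, it uses nn.BCELoss on the predicted probabilities. The function names and default hyperparameters are illustrative, not tuned. Evaluation uses AUROC (scikit-learn's roc_auc_score), which does not depend on a decision threshold; that matters when deaths are the minority class.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, TensorDataset


def train_model(model, X, y, epochs=10, lr=1e-3, batch_size=64, seed=0):
    # X: (N, T, D) float array; y: (N,) 0/1 labels.
    # The model must return probabilities of shape (N, 1).
    torch.manual_seed(seed)
    dataset = TensorDataset(
        torch.as_tensor(X, dtype=torch.float32),
        torch.as_tensor(y, dtype=torch.float32).unsqueeze(1),
    )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    loss_fn = nn.BCELoss()  # expects probabilities, not logits
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []  # mean training loss per epoch
    model.train()
    for _ in range(epochs):
        total = 0.0
        for xb, yb in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()
            total += loss.item() * len(xb)
        history.append(total / len(dataset))
    return history


@torch.no_grad()
def evaluate_auroc(model, X, y):
    # Area under the ROC curve of the predicted death probabilities
    model.eval()
    probs = model(torch.as_tensor(X, dtype=torch.float32)).squeeze(1).numpy()
    return roc_auc_score(y, probs)
```

In practice, hold out a test set before training, for example with scikit-learn's train_test_split(X, y, stratify=y), and report AUROC on that held-out split rather than on the training data.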
Conclusion
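Most of the effort in training an EHR predictive model goes into data wrangling: pivoting the raw records, aligning features across patients, and imputing missing values. Once that is done, an LSTM on the PhysioNet 2012 data is a useful baseline for benchmarking other clinical machine learning architectures.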