# mimic_icd9_coding/coding_pipeline.py
#%%
import ast
import itertools
import os

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from torchnlp.utils import lengths_to_mask

from .utils.BERTRunner import run_BERT, convertBERT
from .utils.preprocessing import clean_data
#%%
class codingPipeline:
    """Base class for the MIMIC ICD-9 classification pipeline.

    Wraps two model families behind one interface:

    * ``model='BERT'`` -- training is delegated to ``run_BERT`` (see
      ``utils.BERTRunner``).
    * anything else -- a scikit-learn multi-label setup built from a
      TF-IDF vectorizer plus a ``MultiLabelBinarizer``.  ``'forest'``
      trains a ``RandomForestClassifier``; any other value is assumed to
      be a ready-to-fit sklearn-style estimator.

    Parameters
    ----------
    model : str or estimator
        ``'BERT'``, ``'forest'``, or an sklearn-compatible estimator.
    data_path : str
        Directory that holds ``mimic_full.csv``; the preprocessed cache
        ``mimic_prep.csv`` is written/read here.
    text_col, label_col : str
        Column names of the note text and the list-valued ICD-9 labels.
    verbose : bool
        Print a classification report after sklearn training.
    bert_fast_dev_run : bool
        Forwarded to the BERT runner (smoke-test / dev mode).
    run : bool
        If ``False``, set up paths and attributes without training.
    """

    def __init__(self, model='forest', data_path='../data/', text_col='TEXT',
                 label_col='TARGET', verbose=True, bert_fast_dev_run=False,
                 run=True):
        self.label_col = label_col
        self.text_col = text_col
        self.data_path = data_path
        # os.path.join tolerates data_path both with and without a trailing
        # separator (the previous ``+`` concatenation required one).
        self.prep_path = os.path.join(data_path, 'mimic_prep.csv')
        if model == 'BERT':
            self.model_type = 'BERT'
            if run:
                self.model = self.run_bert_model(bert_fast_dev_run)
        else:
            self.model_type = 'notBERT'
            # makedirs(exist_ok=True) creates missing parents and avoids the
            # exists()/mkdir race of the original pair of calls.
            os.makedirs(self.data_path, exist_ok=True)
            if not os.path.exists(self.prep_path):
                df = self.make_preproc_data()
            else:
                # ast.literal_eval (not eval) parses the stringified label
                # lists: it accepts only Python literals, so a crafted CSV
                # cannot execute arbitrary code.
                df = pd.read_csv(self.prep_path,
                                 converters={self.label_col: ast.literal_eval},
                                 index_col=0)
            # Previously self.df was set only on the preprocessing branch;
            # keep it consistent regardless of whether the cache existed.
            self.df = df
            if run:
                self.run_model(df, model, verbose=verbose)

    def load_data(self):
        """Load the preprocessed CSV into ``self.data`` (for testing)."""
        self.data = pd.read_csv(self.prep_path,
                                converters={self.label_col: ast.literal_eval},
                                index_col=0)

    def run_bert_model(self, bert_fast_dev_run):
        """Train the BERT model via ``run_BERT``; store and return it."""
        model = run_BERT(self.data_path, bert_fast_dev_run)
        self.model = model
        return model

    def predict(self, item):
        """Run the trained model on a datum (or a list of data).

        BERT path: returns the argmax class index per row (numpy array).
        sklearn path: returns the inverse-binarized label tuples.
        """
        if self.model_type == 'BERT':
            tokens, masked_item = convertBERT(item)
            import torch
            # Fall back to CPU so inference also works on GPU-less hosts
            # (the original unconditionally requested CUDA and crashed
            # when no GPU was present).
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            with torch.no_grad():
                preds = self.model(tokens.to(device), masked_item.to(device))
            preds = preds.detach().cpu().numpy()
            return np.argmax(preds, axis=1)
        if not isinstance(item, list):
            item = [item]
        item = self.vectorizer.transform(item)
        return self.mlBinarizer.inverse_transform(self.model.predict(item))

    def make_preproc_data(self):
        """Clean the text column of ``mimic_full.csv``, cache the result
        as ``mimic_prep.csv``, and return the DataFrame."""
        df = pd.read_csv(os.path.join(self.data_path, 'mimic_full.csv'),
                         converters={self.label_col: ast.literal_eval})
        df[self.text_col] = clean_data(df[self.text_col])
        df.to_csv(self.prep_path)
        return df

    def run_model(self, df, model, report=False, verbose=True):
        """Train and evaluate an sklearn-style multi-label model.

        ``report`` is accepted for backward compatibility but unused; the
        dict-form report is always stored on ``self.report``.
        """
        X = self.vectorize(df[self.text_col])
        y = self.mlb(df[self.label_col])
        if model == 'forest':
            self.model = RandomForestClassifier(n_jobs=-1)
        else:
            self.model = model
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.3, random_state=123)
        self.model.fit(X_train, y_train)
        # Evaluate on the held-out split.
        y_pred = self.model.predict(X_test)
        classification_report_test = classification_report(y_test, y_pred)
        # NOTE(review): AUROC computed on hard 0/1 predictions differs from
        # the usual probability-based AUROC; kept for behavioural parity.
        self.auroc = roc_auc_score(y_test, y_pred)
        self.report = classification_report(y_test, y_pred, output_dict=True)
        if verbose:
            print('\nClassification Report')
            print('======================================================')
            print('\n', classification_report_test)

    def vectorize(self, series):
        """Fit a TF-IDF vectorizer (capped at 4500 features), store it on
        ``self.vectorizer``, and return the transformed matrix."""
        td = TfidfVectorizer(max_features=4500)
        transformed = td.fit_transform(series)
        self.vectorizer = td
        return transformed

    def mlb(self, labels):
        """Encode list-valued labels with a ``MultiLabelBinarizer``.

        Stores the full label vocabulary on ``self.labels`` and the fitted
        binarizer on ``self.mlBinarizer``; returns the binary matrix.
        """
        all_labels = set(itertools.chain.from_iterable(labels))
        self.labels = all_labels
        mlBinarizer = MultiLabelBinarizer()
        mlBinarizer.fit([list(all_labels)])
        self.mlBinarizer = mlBinarizer
        return mlBinarizer.transform(labels)