# Source: process_mimic.py (diff 000000..bab239)
1
# This script processes the MIMIC-III dataset and builds a binary matrix or a count matrix depending on your input.
2
# The output matrix is a Numpy matrix of type float32, and suitable for training medGAN.
3
# Written by Edward Choi (mp2893@gatech.edu)
4
# Usage: Put this script in the folder where the MIMIC-III CSV files are located, then run the command below.
5
# python process_mimic.py ADMISSIONS.csv DIAGNOSES_ICD.csv <output file> <"binary"|"count">
6
# Note that the last argument "binary/count" determines whether you want to create a binary matrix or a count matrix.
7
8
# Output files
9
# <output file>.pids: cPickled Python list of unique Patient IDs. Used for intermediate processing
10
# <output file>.matrix: Numpy float32 matrix. Each row corresponds to a patient. Each column corresponds to an ICD9 diagnosis code.
11
# <output file>.types: cPickled Python dictionary that maps string diagnosis codes to integer diagnosis codes.
12
13
import sys
14
import _pickle as pickle
15
import numpy as np
16
from datetime import datetime
17
18
def convert_to_icd9(dxStr):
    """Insert the decimal point into a compact (dot-less) diagnosis code.

    Codes starting with 'E' keep four characters before the point; all
    other codes keep three. A code that is not longer than its prefix is
    returned unchanged (no trailing point is added).

    Args:
        dxStr: Diagnosis code string without a decimal point, e.g. '40301'.

    Returns:
        The code with a '.' inserted after the prefix, e.g. '403.01', or
        the input itself when there is nothing after the prefix.
    """
    # Position of the decimal point depends only on the leading 'E'.
    splitAt = 4 if dxStr.startswith('E') else 3
    if len(dxStr) > splitAt:
        return dxStr[:splitAt] + '.' + dxStr[splitAt:]
    return dxStr
25
    
26
def convert_to_3digit_icd9(dxStr):
    """Truncate a compact diagnosis code to its leading category prefix.

    Codes starting with 'E' are cut to their first four characters; all
    other codes are cut to their first three. Codes already that short
    (including the empty string) are returned unchanged.

    Args:
        dxStr: Diagnosis code string without a decimal point, e.g. '40301'.

    Returns:
        The truncated code, e.g. '403' for '40301' or 'E850' for 'E8500'.
    """
    prefixLen = 4 if dxStr.startswith('E') else 3
    # Slicing past the end of a short string simply yields the whole string,
    # so no explicit length check is needed.
    return dxStr[:prefixLen]
33
34
if __name__ == '__main__':
    # Script entry point. Reads an admissions CSV and a diagnoses CSV, groups
    # diagnosis codes per patient, and writes three pickles:
    #   <outFile>.pids   - list of patient ids (row order of the matrix)
    #   <outFile>.matrix - float32 patient-by-code matrix (binary or counts)
    #   <outFile>.types  - dict mapping code string -> matrix column index

    # Fail with a usage message (and a non-zero status) instead of an
    # IndexError traceback when arguments are missing. Extra arguments are
    # ignored, as before.
    if len(sys.argv) < 5:
        print('Usage: python process_mimic.py ADMISSIONS.csv DIAGNOSES_ICD.csv <output file> <"binary"|"count">')
        sys.exit(1)
    admissionFile, diagnosisFile, outFile, binary_count = sys.argv[1:5]

    if binary_count not in ('binary', 'count'):
        print('You must choose either binary or count.')
        # Non-zero status so calling shells/pipelines can detect the failure.
        sys.exit(1)

    print('Building pid-admission mapping, admission-date mapping')
    pidAdmMap = {}   # patient id -> list of admission ids, in file order
    admDateMap = {}  # admission id -> admission datetime
    # 'with' guarantees the file is closed even if a row fails to parse.
    with open(admissionFile, 'r') as infd:
        infd.readline()  # skip the header row
        for line in infd:
            # Fields 1, 2 and 3 are read as patient id, admission id and
            # admission time (presumably SUBJECT_ID, HADM_ID, ADMITTIME --
            # verify against the CSV header).
            tokens = line.strip().split(',')
            pid = int(tokens[1])
            admId = int(tokens[2])
            admDateMap[admId] = datetime.strptime(tokens[3], '%Y-%m-%d %H:%M:%S')
            pidAdmMap.setdefault(pid, []).append(admId)

    print('Building admission-dxList mapping')
    admDxMap = {}  # admission id -> list of 'D_'-prefixed diagnosis codes
    with open(diagnosisFile, 'r') as infd:
        infd.readline()  # skip the header row
        for line in infd:
            tokens = line.strip().split(',')
            admId = int(tokens[2])
            # tokens[4][1:-1] drops the first and last character of the code
            # field (presumably the surrounding double quotes -- verify).
            # To keep the entire ICD9 digits instead of the truncated code,
            # replace the line below with:
            #dxStr = 'D_' + convert_to_icd9(tokens[4][1:-1])
            dxStr = 'D_' + convert_to_3digit_icd9(tokens[4][1:-1])
            admDxMap.setdefault(admId, []).append(dxStr)

    print('Building pid-sortedVisits mapping')
    pidSeqMap = {}  # patient id -> [(admission datetime, code list), ...] sorted
    for pid, admIdList in pidAdmMap.items():
        # Optional filter kept from the original, disabled:
        #if len(admIdList) < 2: continue
        # Sorting the (datetime, codes) tuples orders visits chronologically;
        # ties on the datetime fall back to comparing the code lists.
        # Raises KeyError if an admission has no row in the diagnosis file.
        pidSeqMap[pid] = sorted((admDateMap[admId], admDxMap[admId]) for admId in admIdList)

    print('Building pids, dates, strSeqs')
    # The visit dates are only needed for the ordering above and are not
    # written out, so only the ids and the code sequences are kept.
    pids = []
    seqs = []
    for pid, visits in pidSeqMap.items():
        pids.append(pid)
        seqs.append([visit[1] for visit in visits])

    print('Converting strSeqs to intSeqs, and making types')
    types = {}  # code string -> integer index, assigned in first-seen order
    newSeqs = []
    for patient in seqs:
        newPatient = []
        for visit in patient:
            # setdefault stores len(types) (the next free index) the first
            # time a code is seen and returns the existing index otherwise.
            newPatient.append([types.setdefault(code, len(types)) for code in visit])
        newSeqs.append(newPatient)

    print('Constructing the matrix')
    numPatients = len(newSeqs)
    numCodes = len(types)
    matrix = np.zeros((numPatients, numCodes), dtype='float32')
    isBinary = binary_count == 'binary'  # hoisted: invariant across the loops
    for i, patient in enumerate(newSeqs):
        for visit in patient:
            for code in visit:
                if isBinary:
                    matrix[i, code] = 1.
                else:
                    matrix[i, code] += 1.

    # Write each output through a context manager so the handle is flushed
    # and closed; -1 selects the highest available pickle protocol.
    for suffix, obj in (('.pids', pids), ('.matrix', matrix), ('.types', types)):
        with open(outFile + suffix, 'wb') as outfd:
            pickle.dump(obj, outfd, -1)