# main.py — entry point for training and evaluating medical-code (ICD)
# assignment models: BERT, GRU, LSTM, character CNN, hybrid, and a
# one-vs-rest TF-IDF baseline.
1
import pandas as pd
2
import numpy as np 
3
import argparse 
4
from ast import literal_eval 
5
import torch 
6
import torch.nn as nn
7
from gensim.models import Word2Vec
8
import logging
9
10
from src.bert.bert_model import BERTclassifier
11
from src.bert.bert_dataset import BERTdataset
12
from src.bert.bert_train import bert_fit
13
from src.bert.bert_utils import bert_test_results
14
15
from src.rnn.rnn_utils import count_vocab_index, get_emb_matrix
16
from src.rnn.rnn_dataset import rnndataset
17
from src.rnn.lstm import LSTMw2vmodel
18
from src.rnn.gru import GRUw2vmodel
19
20
from src.cnn.cnn_dataset import cnndataset
21
from src.cnn.cnn import character_cnn
22
23
from src.hybrid.hybrid_dataset import hybriddataset
24
from src.hybrid.hybrid import hybrid
25
from src.hybrid.hybrid_fit import hybrid_fit
26
from src.hybrid.hybrid_test_results import hybrid_test_results
27
28
from src.ovr.mlmodel_data import mlmodel_data
29
from src.ovr.mlmodel_result import mlmodel_result
30
from src.ovr.MLmodels import train_classifier
31
from sklearn.feature_extraction.text import TfidfVectorizer
32
33
from src.fit import fit
34
from src.test_results import test_results
35
36
from src.utils import dataloader
37
38
def data(args):
    """Load the train/test diagnosis CSVs and parse their list-valued columns.

    The four ICD columns are stored in the CSVs as stringified Python lists
    (e.g. "['250', '401']"); ``ast.literal_eval`` converts them back into
    real lists.

    Args:
        args: parsed CLI namespace providing ``train_path`` and ``test_path``.

    Returns:
        Tuple of (train_diagnosis, test_diagnosis) DataFrames.
    """
    list_columns = ['ICD9_CODE', 'ICD9_CATEGORY', 'ICD10', 'ICD10_CATEGORY']

    frames = []
    for path in (args.train_path, args.test_path):
        frame = pd.read_csv(path)
        for column in list_columns:
            frame[column] = frame[column].apply(literal_eval)
        frames.append(frame)

    train_diagnosis, test_diagnosis = frames
    return train_diagnosis, test_diagnosis
53
54
def run(args):
    """Train and evaluate the model selected by ``args.model_name``.

    Loads the train/test DataFrames, seeds torch for reproducibility,
    configures file logging, then dispatches to one of the supported
    architectures: ``bert``, ``gru``, ``lstm``, ``cnn``, ``hybrid`` or
    ``ovr`` (one-vs-rest TF-IDF baseline).

    Args:
        args: parsed CLI namespace (see the argparse setup in ``__main__``).
    """
    train_diagnosis, test_diagnosis = data(args)

    # Seed every torch RNG and pin cuDNN to deterministic kernels so that
    # repeated runs with the same arguments are reproducible.
    SEED = 2021
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Overwrite (filemode='w') train.log on every run.
    logging.basicConfig(filename='train.log', filemode = 'w', level=logging.DEBUG)
    logging.info("Model Name: %s", args.model_name.upper())
    logging.info("Device: %s", device)
    logging.info("Batch Size: %d", args.batch_size)
    logging.info("Learning Rate: %f", args.learning_rate)

    if args.model_name == "bert":

        learning_rate = args.learning_rate
        loss_fn = nn.BCELoss()
        opt_fn = torch.optim.Adam

        bert_train_dataset = BERTdataset(train_diagnosis)
        bert_test_dataset = BERTdataset(test_diagnosis)

        bert_train_loader, bert_val_loader, bert_test_loader = dataloader(bert_train_dataset, bert_test_dataset, args.batch_size, args.val_split)

        model = BERTclassifier().to(device)

        bert_fit(args.epochs, model, bert_train_loader, bert_val_loader, args.icd_type, opt_fn, loss_fn, learning_rate, device)
        bert_test_results(model, bert_test_loader, args.icd_type, device)

    elif args.model_name == 'gru':
        learning_rate = args.learning_rate
        loss_fn = nn.BCELoss()
        opt_fn = torch.optim.Adam

        counts, vocab2index = count_vocab_index(train_diagnosis, test_diagnosis)
        rnn_train_dataset = rnndataset(train_diagnosis, vocab2index)
        # BUG FIX: the test dataset was previously built from train_diagnosis,
        # so test metrics were computed on training data.
        rnn_test_dataset = rnndataset(test_diagnosis, vocab2index)

        rnn_train_loader, rnn_val_loader, rnn_test_loader = dataloader(rnn_train_dataset, rnn_test_dataset, args.batch_size, args.val_split)

        # Pre-trained word2vec embeddings, aligned to the vocabulary counts.
        w2vmodel = Word2Vec.load(args.w2vmodel)
        weights = get_emb_matrix(w2vmodel, counts)

        gruw2vmodel = GRUw2vmodel(weights_matrix = weights, hidden_size = 256, num_layers = 2, device = device).to(device)

        fit(args.epochs, gruw2vmodel, rnn_train_loader, rnn_val_loader, args.icd_type, opt_fn, loss_fn, learning_rate, device)
        test_results(gruw2vmodel, rnn_test_loader, args.icd_type, device)

    elif args.model_name == 'lstm':
        learning_rate = args.learning_rate
        loss_fn = nn.BCELoss()
        opt_fn = torch.optim.Adam

        counts, vocab2index = count_vocab_index(train_diagnosis, test_diagnosis)
        rnn_train_dataset = rnndataset(train_diagnosis, vocab2index)
        # BUG FIX: the test dataset was previously built from train_diagnosis,
        # so test metrics were computed on training data.
        rnn_test_dataset = rnndataset(test_diagnosis, vocab2index)

        rnn_train_loader, rnn_val_loader, rnn_test_loader = dataloader(rnn_train_dataset, rnn_test_dataset, args.batch_size, args.val_split)

        # Pre-trained word2vec embeddings, aligned to the vocabulary counts.
        w2vmodel = Word2Vec.load(args.w2vmodel)
        weights = get_emb_matrix(w2vmodel, counts)

        lstmw2vmodel = LSTMw2vmodel(weights_matrix = weights, hidden_size = 256, num_layers = 2, device = device).to(device)

        fit(args.epochs, lstmw2vmodel, rnn_train_loader, rnn_val_loader, args.icd_type, opt_fn, loss_fn, learning_rate, device)
        test_results(lstmw2vmodel, rnn_test_loader, args.icd_type, device)

    elif args.model_name == "cnn":

        learning_rate = args.learning_rate
        loss_fn = nn.BCELoss()
        opt_fn = torch.optim.Adam

        cnn_train_dataset = cnndataset(train_diagnosis)
        cnn_test_dataset = cnndataset(test_diagnosis)

        cnn_train_loader, cnn_val_loader, cnn_test_loader = dataloader(cnn_train_dataset, cnn_test_dataset, args.batch_size, args.val_split)

        model = character_cnn(cnn_train_dataset.vocabulary, cnn_train_dataset.sequence_length).to(device)

        fit(args.epochs, model, cnn_train_loader, cnn_val_loader, args.icd_type, opt_fn, loss_fn, learning_rate, device)
        test_results(model, cnn_test_loader, args.icd_type, device)

    elif args.model_name == 'hybrid':

        learning_rate = args.learning_rate
        loss_fn = nn.BCELoss()
        opt_fn = torch.optim.Adam

        counts, vocab2index = count_vocab_index(train_diagnosis, test_diagnosis)

        hybrid_train_dataset = hybriddataset(train_diagnosis, vocab2index)
        # BUG FIX: the test dataset was previously built from train_diagnosis,
        # so test metrics were computed on training data.
        hybrid_test_dataset = hybriddataset(test_diagnosis, vocab2index)

        hybrid_train_loader, hybrid_val_loader, hybrid_test_loader = dataloader(hybrid_train_dataset, hybrid_test_dataset, args.batch_size, args.val_split)

        # Pre-trained word2vec embeddings, aligned to the vocabulary counts.
        w2vmodel = Word2Vec.load(args.w2vmodel)
        weights = get_emb_matrix(w2vmodel, counts)

        model = hybrid(hybrid_train_dataset.vocabulary, hybrid_train_dataset.sequence_length, weights_matrix = weights, hidden_size = 256, num_layers = 2).to(device)

        hybrid_fit(args.epochs, model, hybrid_train_loader, hybrid_val_loader, args.icd_type, opt_fn, loss_fn, learning_rate, device)
        hybrid_test_results(model, hybrid_test_loader, args.icd_type, device)

    elif args.model_name == 'ovr':

        X_train, y_train = mlmodel_data(train_diagnosis, args.icd_type)
        X_test, y_test = mlmodel_data(test_diagnosis, args.icd_type)

        # TF-IDF features; max_df=0.8 drops terms that appear in more than
        # 80% of documents. The vectorizer is fit on training text only.
        tfidf_vectorizer = TfidfVectorizer(max_df = 0.8)
        X_train = tfidf_vectorizer.fit_transform(X_train)
        X_test = tfidf_vectorizer.transform(X_test)

        ml_model = train_classifier(X_train, y_train)
        y_predict = ml_model.predict(X_test)

        print('-'*20 + args.icd_type + '-'*20)
        mlmodel_result(y_test, y_predict)
if __name__ == "__main__":
    # BUG FIX: the string was previously passed positionally, which sets
    # ArgumentParser's `prog` (replacing the program name in usage output)
    # rather than the intended help-text `description`.
    parser = argparse.ArgumentParser(description = "Automatic Assignment of Medical Codes")

    parser.add_argument("--train_path", type = str, default = './data/train.csv')
    parser.add_argument("--test_path", type = str, default = './data/test.csv')

    # Which architecture to train and which ICD labelling scheme to predict.
    parser.add_argument("--model_name", type = str, choices = ['bert', 'hybrid', 'gru', 'lstm', 'cnn', 'ovr'], default = "bert")
    parser.add_argument("--icd_type", type = str, choices = ['icd9cat', 'icd9code', 'icd10cat', 'icd10code'], default = 'icd9cat')

    parser.add_argument("--batch_size", type = int, default = 16)
    parser.add_argument("--val_split", type = float, default = 2/7)
    parser.add_argument("--learning_rate", type = float, default = 2e-5)
    parser.add_argument("--epochs", type = int, default = 4)

    # Path to a pre-trained gensim Word2Vec model (used by gru/lstm/hybrid).
    parser.add_argument("--w2vmodel", type = str, default = "w2vmodel.model")

    args = parser.parse_args()
    run(args)