b/main.py
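# Entry point for "Automatic Assignment of Medical Codes": trains and evaluates
# one of several classifiers (BERT, GRU, LSTM, character CNN, hybrid, or
# one-vs-rest with TF-IDF features) on ICD-9/ICD-10 diagnosis data.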
import argparse
import logging
from ast import literal_eval

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from gensim.models import Word2Vec
from sklearn.feature_extraction.text import TfidfVectorizer

from src.bert.bert_model import BERTclassifier
from src.bert.bert_dataset import BERTdataset
from src.bert.bert_train import bert_fit
from src.bert.bert_utils import bert_test_results

from src.rnn.rnn_utils import count_vocab_index, get_emb_matrix
from src.rnn.rnn_dataset import rnndataset
from src.rnn.lstm import LSTMw2vmodel
from src.rnn.gru import GRUw2vmodel

from src.cnn.cnn_dataset import cnndataset
from src.cnn.cnn import character_cnn

from src.hybrid.hybrid_dataset import hybriddataset
from src.hybrid.hybrid import hybrid
from src.hybrid.hybrid_fit import hybrid_fit
from src.hybrid.hybrid_test_results import hybrid_test_results

from src.ovr.mlmodel_data import mlmodel_data
from src.ovr.mlmodel_result import mlmodel_result
from src.ovr.MLmodels import train_classifier

from src.fit import fit
from src.test_results import test_results
from src.utils import dataloader


def data(args):
    train_diagnosis = pd.read_csv(args.train_path)
    test_diagnosis = pd.read_csv(args.test_path)

    # The label columns are stored as stringified Python lists; parse them back.
    label_columns = ['ICD9_CODE', 'ICD9_CATEGORY', 'ICD10', 'ICD10_CATEGORY']
    for df in (train_diagnosis, test_diagnosis):
        for column in label_columns:
            df[column] = df[column].apply(literal_eval)

    return train_diagnosis, test_diagnosis


def run(args):
    train_diagnosis, test_diagnosis = data(args)

    # Fix random seeds so runs are reproducible.
    SEED = 2021
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logging.basicConfig(filename='train.log', filemode='w', level=logging.DEBUG)
    logging.info("Model Name: %s", args.model_name.upper())
    logging.info("Device: %s", device)
    logging.info("Batch Size: %d", args.batch_size)
    logging.info("Learning Rate: %f", args.learning_rate)

    if args.model_name == "bert":
        learning_rate = args.learning_rate
        loss_fn = nn.BCELoss()
        opt_fn = torch.optim.Adam

        bert_train_dataset = BERTdataset(train_diagnosis)
        bert_test_dataset = BERTdataset(test_diagnosis)

        bert_train_loader, bert_val_loader, bert_test_loader = dataloader(
            bert_train_dataset, bert_test_dataset, args.batch_size, args.val_split)

        model = BERTclassifier().to(device)

        bert_fit(args.epochs, model, bert_train_loader, bert_val_loader,
                 args.icd_type, opt_fn, loss_fn, learning_rate, device)
        bert_test_results(model, bert_test_loader, args.icd_type, device)

    elif args.model_name == 'gru':
        learning_rate = args.learning_rate
        loss_fn = nn.BCELoss()
        opt_fn = torch.optim.Adam

        counts, vocab2index = count_vocab_index(train_diagnosis, test_diagnosis)
        rnn_train_dataset = rnndataset(train_diagnosis, vocab2index)
        rnn_test_dataset = rnndataset(test_diagnosis, vocab2index)

        rnn_train_loader, rnn_val_loader, rnn_test_loader = dataloader(
            rnn_train_dataset, rnn_test_dataset, args.batch_size, args.val_split)

        # Initialise the embedding layer from the pretrained Word2Vec vectors.
        w2vmodel = Word2Vec.load(args.w2vmodel)
        weights = get_emb_matrix(w2vmodel, counts)

        gruw2vmodel = GRUw2vmodel(weights_matrix=weights, hidden_size=256,
                                  num_layers=2, device=device).to(device)

        fit(args.epochs, gruw2vmodel, rnn_train_loader, rnn_val_loader,
            args.icd_type, opt_fn, loss_fn, learning_rate, device)
        test_results(gruw2vmodel, rnn_test_loader, args.icd_type, device)

    elif args.model_name == 'lstm':
        learning_rate = args.learning_rate
        loss_fn = nn.BCELoss()
        opt_fn = torch.optim.Adam

        counts, vocab2index = count_vocab_index(train_diagnosis, test_diagnosis)
        rnn_train_dataset = rnndataset(train_diagnosis, vocab2index)
        rnn_test_dataset = rnndataset(test_diagnosis, vocab2index)

        rnn_train_loader, rnn_val_loader, rnn_test_loader = dataloader(
            rnn_train_dataset, rnn_test_dataset, args.batch_size, args.val_split)

        # Initialise the embedding layer from the pretrained Word2Vec vectors.
        w2vmodel = Word2Vec.load(args.w2vmodel)
        weights = get_emb_matrix(w2vmodel, counts)

        lstmw2vmodel = LSTMw2vmodel(weights_matrix=weights, hidden_size=256,
                                    num_layers=2, device=device).to(device)

        fit(args.epochs, lstmw2vmodel, rnn_train_loader, rnn_val_loader,
            args.icd_type, opt_fn, loss_fn, learning_rate, device)
        test_results(lstmw2vmodel, rnn_test_loader, args.icd_type, device)

    elif args.model_name == "cnn":
        learning_rate = args.learning_rate
        loss_fn = nn.BCELoss()
        opt_fn = torch.optim.Adam

        cnn_train_dataset = cnndataset(train_diagnosis)
        cnn_test_dataset = cnndataset(test_diagnosis)

        cnn_train_loader, cnn_val_loader, cnn_test_loader = dataloader(
            cnn_train_dataset, cnn_test_dataset, args.batch_size, args.val_split)

        model = character_cnn(cnn_train_dataset.vocabulary,
                              cnn_train_dataset.sequence_length).to(device)

        fit(args.epochs, model, cnn_train_loader, cnn_val_loader,
            args.icd_type, opt_fn, loss_fn, learning_rate, device)
        test_results(model, cnn_test_loader, args.icd_type, device)

    elif args.model_name == 'hybrid':
        learning_rate = args.learning_rate
        loss_fn = nn.BCELoss()
        opt_fn = torch.optim.Adam

        counts, vocab2index = count_vocab_index(train_diagnosis, test_diagnosis)

        hybrid_train_dataset = hybriddataset(train_diagnosis, vocab2index)
        hybrid_test_dataset = hybriddataset(test_diagnosis, vocab2index)

        hybrid_train_loader, hybrid_val_loader, hybrid_test_loader = dataloader(
            hybrid_train_dataset, hybrid_test_dataset, args.batch_size, args.val_split)

        # Initialise the embedding layer from the pretrained Word2Vec vectors.
        w2vmodel = Word2Vec.load(args.w2vmodel)
        weights = get_emb_matrix(w2vmodel, counts)

        model = hybrid(hybrid_train_dataset.vocabulary,
                       hybrid_train_dataset.sequence_length,
                       weights_matrix=weights, hidden_size=256,
                       num_layers=2).to(device)

        hybrid_fit(args.epochs, model, hybrid_train_loader, hybrid_val_loader,
                   args.icd_type, opt_fn, loss_fn, learning_rate, device)
        hybrid_test_results(model, hybrid_test_loader, args.icd_type, device)

    elif args.model_name == 'ovr':
        X_train, y_train = mlmodel_data(train_diagnosis, args.icd_type)
        X_test, y_test = mlmodel_data(test_diagnosis, args.icd_type)

        # Fit TF-IDF on the training text only, then apply the same vocabulary to the test set.
        tfidf_vectorizer = TfidfVectorizer(max_df=0.8)
        X_train = tfidf_vectorizer.fit_transform(X_train)
        X_test = tfidf_vectorizer.transform(X_test)

        ml_model = train_classifier(X_train, y_train)
        y_predict = ml_model.predict(X_test)

        print('-' * 20 + args.icd_type + '-' * 20)
        mlmodel_result(y_test, y_predict)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Automatic Assignment of Medical Codes")

    parser.add_argument("--train_path", type=str, default='./data/train.csv')
    parser.add_argument("--test_path", type=str, default='./data/test.csv')

    parser.add_argument("--model_name", type=str,
                        choices=['bert', 'hybrid', 'gru', 'lstm', 'cnn', 'ovr'],
                        default="bert")
    parser.add_argument("--icd_type", type=str,
                        choices=['icd9cat', 'icd9code', 'icd10cat', 'icd10code'],
                        default='icd9cat')

    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--val_split", type=float, default=2/7)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--epochs", type=int, default=4)

    parser.add_argument("--w2vmodel", type=str, default="w2vmodel.model")

    args = parser.parse_args()
    run(args)