"""Entry point for training and evaluating automatic medical-code assignment models.

Dispatches on ``--model_name`` to one of: BERT, GRU, LSTM, character CNN,
a CNN/RNN hybrid, or a one-vs-rest TF-IDF classical ML baseline.
"""
import argparse
import logging
from ast import literal_eval

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from gensim.models import Word2Vec
from sklearn.feature_extraction.text import TfidfVectorizer

from src.bert.bert_model import BERTclassifier
from src.bert.bert_dataset import BERTdataset
from src.bert.bert_train import bert_fit
from src.bert.bert_utils import bert_test_results

from src.rnn.rnn_utils import count_vocab_index, get_emb_matrix
from src.rnn.rnn_dataset import rnndataset
from src.rnn.lstm import LSTMw2vmodel
from src.rnn.gru import GRUw2vmodel

from src.cnn.cnn_dataset import cnndataset
from src.cnn.cnn import character_cnn

from src.hybrid.hybrid_dataset import hybriddataset
from src.hybrid.hybrid import hybrid
from src.hybrid.hybrid_fit import hybrid_fit
from src.hybrid.hybrid_test_results import hybrid_test_results

from src.ovr.mlmodel_data import mlmodel_data
from src.ovr.mlmodel_result import mlmodel_result
from src.ovr.MLmodels import train_classifier

from src.fit import fit
from src.test_results import test_results

from src.utils import dataloader

# Label columns stored in the CSVs as stringified Python lists.
_LIST_COLUMNS = ('ICD9_CODE', 'ICD9_CATEGORY', 'ICD10', 'ICD10_CATEGORY')


def data(args):
    """Load the train/test CSVs and parse the stringified label-list columns.

    Args:
        args: parsed CLI namespace; reads ``args.train_path`` and ``args.test_path``.

    Returns:
        (train_diagnosis, test_diagnosis): two ``pd.DataFrame``s whose
        ICD label columns have been converted from strings to Python lists.
    """
    train_diagnosis = pd.read_csv(args.train_path)
    test_diagnosis = pd.read_csv(args.test_path)

    # literal_eval turns "['428.0', '401.9']" back into a real list.
    for frame in (train_diagnosis, test_diagnosis):
        for col in _LIST_COLUMNS:
            frame[col] = frame[col].apply(literal_eval)

    return train_diagnosis, test_diagnosis


def run(args):
    """Train and evaluate the model selected by ``args.model_name``.

    Seeds torch for reproducibility, logs the run configuration to
    ``train.log``, then builds the datasets/loaders and runs the chosen
    model's fit + test routines. Side effects: writes ``train.log`` and
    whatever the fit/test helpers emit.
    """
    train_diagnosis, test_diagnosis = data(args)

    # Reproducibility: fixed seed and deterministic cuDNN kernels.
    SEED = 2021
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logging.basicConfig(filename='train.log', filemode='w', level=logging.DEBUG)
    logging.info("Model Name: %s", args.model_name.upper())
    logging.info("Device: %s", device)
    logging.info("Batch Size: %d", args.batch_size)
    logging.info("Learning Rate: %f", args.learning_rate)

    # Shared hyperparameters for every torch-based branch.
    learning_rate = args.learning_rate
    loss_fn = nn.BCELoss()
    opt_fn = torch.optim.Adam

    if args.model_name == "bert":
        bert_train_dataset = BERTdataset(train_diagnosis)
        bert_test_dataset = BERTdataset(test_diagnosis)

        bert_train_loader, bert_val_loader, bert_test_loader = dataloader(
            bert_train_dataset, bert_test_dataset, args.batch_size, args.val_split)

        model = BERTclassifier().to(device)

        bert_fit(args.epochs, model, bert_train_loader, bert_val_loader,
                 args.icd_type, opt_fn, loss_fn, learning_rate, device)
        bert_test_results(model, bert_test_loader, args.icd_type, device)

    elif args.model_name in ('gru', 'lstm'):
        # GRU and LSTM share identical data prep; only the model class differs.
        counts, vocab2index = count_vocab_index(train_diagnosis, test_diagnosis)
        rnn_train_dataset = rnndataset(train_diagnosis, vocab2index)
        # FIX: the test dataset was previously built from train_diagnosis,
        # which evaluated the model on its own training data.
        rnn_test_dataset = rnndataset(test_diagnosis, vocab2index)

        rnn_train_loader, rnn_val_loader, rnn_test_loader = dataloader(
            rnn_train_dataset, rnn_test_dataset, args.batch_size, args.val_split)

        w2vmodel = Word2Vec.load(args.w2vmodel)
        weights = get_emb_matrix(w2vmodel, counts)

        model_cls = GRUw2vmodel if args.model_name == 'gru' else LSTMw2vmodel
        model = model_cls(weights_matrix=weights, hidden_size=256,
                          num_layers=2, device=device).to(device)

        fit(args.epochs, model, rnn_train_loader, rnn_val_loader,
            args.icd_type, opt_fn, loss_fn, learning_rate, device)
        test_results(model, rnn_test_loader, args.icd_type, device)

    elif args.model_name == "cnn":
        cnn_train_dataset = cnndataset(train_diagnosis)
        cnn_test_dataset = cnndataset(test_diagnosis)

        cnn_train_loader, cnn_val_loader, cnn_test_loader = dataloader(
            cnn_train_dataset, cnn_test_dataset, args.batch_size, args.val_split)

        model = character_cnn(cnn_train_dataset.vocabulary,
                              cnn_train_dataset.sequence_length).to(device)

        fit(args.epochs, model, cnn_train_loader, cnn_val_loader,
            args.icd_type, opt_fn, loss_fn, learning_rate, device)
        test_results(model, cnn_test_loader, args.icd_type, device)

    elif args.model_name == 'hybrid':
        counts, vocab2index = count_vocab_index(train_diagnosis, test_diagnosis)

        hybrid_train_dataset = hybriddataset(train_diagnosis, vocab2index)
        # FIX: test dataset was previously built from train_diagnosis (same
        # copy-paste bug as the RNN branches).
        hybrid_test_dataset = hybriddataset(test_diagnosis, vocab2index)

        hybrid_train_loader, hybrid_val_loader, hybrid_test_loader = dataloader(
            hybrid_train_dataset, hybrid_test_dataset, args.batch_size, args.val_split)

        w2vmodel = Word2Vec.load(args.w2vmodel)
        weights = get_emb_matrix(w2vmodel, counts)

        model = hybrid(hybrid_train_dataset.vocabulary,
                       hybrid_train_dataset.sequence_length,
                       weights_matrix=weights, hidden_size=256,
                       num_layers=2).to(device)

        hybrid_fit(args.epochs, model, hybrid_train_loader, hybrid_val_loader,
                   args.icd_type, opt_fn, loss_fn, learning_rate, device)
        hybrid_test_results(model, hybrid_test_loader, args.icd_type, device)

    elif args.model_name == 'ovr':
        # Classical one-vs-rest baseline: TF-IDF features + sklearn classifier.
        X_train, y_train = mlmodel_data(train_diagnosis, args.icd_type)
        X_test, y_test = mlmodel_data(test_diagnosis, args.icd_type)

        tfidf_vectorizer = TfidfVectorizer(max_df=0.8)
        X_train = tfidf_vectorizer.fit_transform(X_train)
        # transform (not fit_transform): test must use the train vocabulary.
        X_test = tfidf_vectorizer.transform(X_test)

        ml_model = train_classifier(X_train, y_train)
        y_predict = ml_model.predict(X_test)

        print('-' * 20 + args.icd_type + '-' * 20)
        mlmodel_result(y_test, y_predict)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Automatic Assignment of Medical Codes")

    parser.add_argument("--train_path", type=str, default='./data/train.csv')
    parser.add_argument("--test_path", type=str, default='./data/test.csv')

    parser.add_argument("--model_name", type=str,
                        choices=['bert', 'hybrid', 'gru', 'lstm', 'cnn', 'ovr'],
                        default="bert")
    parser.add_argument("--icd_type", type=str,
                        choices=['icd9cat', 'icd9code', 'icd10cat', 'icd10code'],
                        default='icd9cat')

    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--val_split", type=float, default=2 / 7)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--epochs", type=int, default=4)

    parser.add_argument("--w2vmodel", type=str, default="w2vmodel.model")

    args = parser.parse_args()
    run(args)