Diff of /main.py [000000] .. [71ad2f]

Switch to side-by-side view

--- a
+++ b/main.py
@@ -0,0 +1,215 @@
+import pandas as pd
+import numpy as np 
+import argparse 
+from ast import literal_eval 
+import torch 
+import torch.nn as nn
+from gensim.models import Word2Vec
+import logging
+
+from src.bert.bert_model import BERTclassifier
+from src.bert.bert_dataset import BERTdataset
+from src.bert.bert_train import bert_fit
+from src.bert.bert_utils import bert_test_results
+
+from src.rnn.rnn_utils import count_vocab_index, get_emb_matrix
+from src.rnn.rnn_dataset import rnndataset
+from src.rnn.lstm import LSTMw2vmodel
+from src.rnn.gru import GRUw2vmodel
+
+from src.cnn.cnn_dataset import cnndataset
+from src.cnn.cnn import character_cnn
+
+from src.hybrid.hybrid_dataset import hybriddataset
+from src.hybrid.hybrid import hybrid
+from src.hybrid.hybrid_fit import hybrid_fit
+from src.hybrid.hybrid_test_results import hybrid_test_results
+
+from src.ovr.mlmodel_data import mlmodel_data
+from src.ovr.mlmodel_result import mlmodel_result
+from src.ovr.MLmodels import train_classifier
+from sklearn.feature_extraction.text import TfidfVectorizer
+
+from src.fit import fit
+from src.test_results import test_results
+
+from src.utils import dataloader
+
+def data(args):
+    train_diagnosis = pd.read_csv(args.train_path)
+    test_diagnosis = pd.read_csv(args.test_path)
+
+    train_diagnosis['ICD9_CODE'] = train_diagnosis['ICD9_CODE'].apply(literal_eval)
+    train_diagnosis['ICD9_CATEGORY'] = train_diagnosis['ICD9_CATEGORY'].apply(literal_eval)
+    train_diagnosis['ICD10'] = train_diagnosis['ICD10'].apply(literal_eval)
+    train_diagnosis['ICD10_CATEGORY'] = train_diagnosis['ICD10_CATEGORY'].apply(literal_eval)
+
+    test_diagnosis['ICD9_CODE'] = test_diagnosis['ICD9_CODE'].apply(literal_eval)
+    test_diagnosis['ICD9_CATEGORY'] = test_diagnosis['ICD9_CATEGORY'].apply(literal_eval)
+    test_diagnosis['ICD10'] = test_diagnosis['ICD10'].apply(literal_eval)
+    test_diagnosis['ICD10_CATEGORY'] = test_diagnosis['ICD10_CATEGORY'].apply(literal_eval)
+
+    return train_diagnosis, test_diagnosis
+
+def run(args):
+    
+    train_diagnosis,test_diagnosis = data(args)
+
+    SEED = 2021
+    torch.manual_seed(SEED)
+    torch.cuda.manual_seed_all(SEED)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    
+    logging.basicConfig(filename='train.log', filemode = 'w', level=logging.DEBUG)
+    logging.info("Model Name: %s", args.model_name.upper())
+    logging.info("Device: %s", device)
+    logging.info("Batch Size: %d", args.batch_size)
+    logging.info("Learning Rate: %f", args.learning_rate)
+    
+    if args.model_name == "bert":        
+
+        learning_rate = args.learning_rate
+        loss_fn = nn.BCELoss()
+        opt_fn = torch.optim.Adam
+
+        bert_train_dataset = BERTdataset(train_diagnosis)
+        bert_test_dataset = BERTdataset(test_diagnosis)
+
+        bert_train_loader, bert_val_loader, bert_test_loader = dataloader(bert_train_dataset, bert_test_dataset, args.batch_size, args.val_split)
+        
+        model = BERTclassifier().to(device)
+
+        bert_fit(args.epochs, model, bert_train_loader, bert_val_loader, args.icd_type, opt_fn, loss_fn, learning_rate, device)
+        bert_test_results(model, bert_test_loader, args.icd_type, device)
+    
+
+    elif args.model_name == 'gru':
+        learning_rate = args.learning_rate
+        loss_fn = nn.BCELoss()
+        opt_fn = torch.optim.Adam
+
+        counts, vocab2index = count_vocab_index(train_diagnosis, test_diagnosis)
+        rnn_train_dataset = rnndataset(train_diagnosis, vocab2index)
+        rnn_test_dataset = rnndataset(train_diagnosis, vocab2index)
+
+        rnn_train_loader, rnn_val_loader, rnn_test_loader = dataloader(rnn_train_dataset, rnn_test_dataset, args.batch_size, args.val_split)
+        
+
+        w2vmodel = Word2Vec.load(args.w2vmodel)
+        weights = get_emb_matrix(w2vmodel, counts)
+
+        gruw2vmodel = GRUw2vmodel(weights_matrix = weights, hidden_size = 256, num_layers = 2, device = device).to(device)
+        
+        fit(args.epochs, gruw2vmodel, rnn_train_loader, rnn_val_loader, args.icd_type, opt_fn, loss_fn, learning_rate, device)
+        test_results(gruw2vmodel, rnn_test_loader, args.icd_type, device)
+
+
+    elif args.model_name == 'lstm':
+        learning_rate = args.learning_rate
+        loss_fn = nn.BCELoss()
+        opt_fn = torch.optim.Adam
+
+        counts, vocab2index = count_vocab_index(train_diagnosis, test_diagnosis)
+        rnn_train_dataset = rnndataset(train_diagnosis, vocab2index)
+        rnn_test_dataset = rnndataset(train_diagnosis, vocab2index)
+
+        rnn_train_loader, rnn_val_loader, rnn_test_loader = dataloader(rnn_train_dataset, rnn_test_dataset, args.batch_size, args.val_split)
+        
+
+        w2vmodel = Word2Vec.load(args.w2vmodel)
+        weights = get_emb_matrix(w2vmodel, counts)
+
+        lstmw2vmodel = LSTMw2vmodel(weights_matrix = weights, hidden_size = 256, num_layers = 2, device = device).to(device)
+        
+        fit(args.epochs, lstmw2vmodel, rnn_train_loader, rnn_val_loader, args.icd_type, opt_fn, loss_fn, learning_rate, device)
+        test_results(lstmw2vmodel, rnn_test_loader, args.icd_type, device)
+
+
+    elif args.model_name == "cnn":
+        
+        learning_rate = args.learning_rate
+        loss_fn = nn.BCELoss()
+        opt_fn = torch.optim.Adam
+
+        cnn_train_dataset = cnndataset(train_diagnosis)
+        cnn_test_dataset = cnndataset(test_diagnosis)
+
+        cnn_train_loader, cnn_val_loader, cnn_test_loader = dataloader(cnn_train_dataset, cnn_test_dataset, args.batch_size, args.val_split)
+
+        model = character_cnn(cnn_train_dataset.vocabulary, cnn_train_dataset.sequence_length).to(device)
+
+        fit(args.epochs, model, cnn_train_loader, cnn_val_loader, args.icd_type, opt_fn, loss_fn, learning_rate, device)
+        test_results(model, cnn_test_loader, args.icd_type, device)
+
+
+    elif args.model_name == 'hybrid':
+        
+        learning_rate = args.learning_rate
+        loss_fn = nn.BCELoss()
+        opt_fn = torch.optim.Adam
+
+        counts, vocab2index = count_vocab_index(train_diagnosis, test_diagnosis)
+
+        hybrid_train_dataset = hybriddataset(train_diagnosis, vocab2index)
+        hybrid_test_dataset = hybriddataset(train_diagnosis, vocab2index)
+
+        hybrid_train_loader, hybrid_val_loader, hybrid_test_loader = dataloader(hybrid_train_dataset, hybrid_test_dataset, args.batch_size, args.val_split)
+          
+
+        w2vmodel = Word2Vec.load(args.w2vmodel)
+        weights = get_emb_matrix(w2vmodel, counts)
+
+        model = hybrid(hybrid_train_dataset.vocabulary, hybrid_train_dataset.sequence_length, weights_matrix = weights, hidden_size = 256, num_layers = 2).to(device)
+
+        hybrid_fit(args.epochs, model, hybrid_train_loader, hybrid_val_loader, args.icd_type, opt_fn, loss_fn, learning_rate, device)
+        hybrid_test_results(model, hybrid_test_loader, args.icd_type, device)
+
+    elif args.model_name == 'ovr':
+       
+        X_train, y_train = mlmodel_data(train_diagnosis, args.icd_type)
+        X_test, y_test = mlmodel_data(test_diagnosis, args.icd_type)
+
+        tfidf_vectorizer = TfidfVectorizer(max_df = 0.8)
+        X_train = tfidf_vectorizer.fit_transform(X_train)
+        X_test = tfidf_vectorizer.transform(X_test)
+
+        ml_model = train_classifier(X_train, y_train)
+        y_predict = ml_model.predict(X_test)
+
+        print('-'*20 + args.icd_type + '-'*20)
+        mlmodel_result(y_test, y_predict)
+
+
+
+
+
+
+
+        
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Automatic Assignment of Medical Codes")
+
+    parser.add_argument("--train_path", type = str, default = './data/train.csv')
+    parser.add_argument("--test_path", type = str, default = './data/test.csv')
+
+    parser.add_argument("--model_name", type = str, choices = ['bert', 'hybrid', 'gru', 'lstm', 'cnn', 'ovr'], default = "bert")
+    parser.add_argument("--icd_type", type = str, choices = ['icd9cat', 'icd9code', 'icd10cat', 'icd10code'], default = 'icd9cat')
+
+    parser.add_argument("--batch_size", type = int, default = 16)
+    parser.add_argument("--val_split", type = float, default = 2/7)
+    parser.add_argument("--learning_rate", type = float, default = 2e-5)
+    parser.add_argument("--epochs", type = int, default = 4)
+
+    parser.add_argument("--w2vmodel", type = str, default = "w2vmodel.model")
+
+    args = parser.parse_args()
+    run(args)
+
+
+
+
+
+