--- a
+++ b/mimic_icd9_coding/utils/BERTRunner.py
@@ -0,0 +1,423 @@
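+"""BERT baseline for MIMIC ICD-9 coding.
+
+run_BERT() trains a small classification head on top of a frozen bert-base-uncased
+encoder, using BCEWithLogitsLoss over multi-hot ICD-9 label vectors, and
+convertBERT() tokenizes a single document into the (input_ids, attention_mask)
+tensors the model expects.
+"""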
+import ast
+import itertools
+
+import pandas as pd
+import torch
+import torch.nn as nn
+from sklearn.model_selection import train_test_split
+from transformers import AutoModel, BertTokenizerFast
+import numpy as np
+from sklearn.metrics import classification_report
+
+
+
+MAX_SEQ_LEN=500
+
+def run_BERT(mimic_path, bert_fast_dev_run=False):
+    "credit to https://colab.research.google.com/github/prateekjoshi565/Fine-Tuning-BERT/blob/master/Fine_Tuning_BERT_for_Spam_Classification.ipynb"
+    # -*- coding: utf-8 -*-
+
+    #%%
+    # use the GPU when available, otherwise fall back to the CPU
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    """# Load Dataset"""
+
+    #%%
+
+    # ast.literal_eval safely parses the stringified label lists in the TARGET column
+    train = pd.read_csv(mimic_path + 'train.csv', converters={'TARGET': ast.literal_eval})
+    test = pd.read_csv(mimic_path + 'test.csv', converters={'TARGET': ast.literal_eval})
+
+
+    if bert_fast_dev_run:
+        train = train.head(10)
+        test = test.head(10)
+
+    #%%
+    train.columns = ['text', 'label', 'HADM_ID']
+    test.columns = ['text', 'label', 'HADM_ID']
+
+    train_labels = set(itertools.chain.from_iterable(train.label))
+    test_labels = set(itertools.chain.from_iterable(test.label))
+    all_labels = train_labels.union(test_labels)
+
+    #%%
+    from sklearn.preprocessing import MultiLabelBinarizer
+
+    mlb = MultiLabelBinarizer()
+    mlb.fit([list(all_labels)])
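+    # fitting on a single row that contains every label just registers the full
+    # ICD-9 vocabulary; transform() then maps each admission's label list to a
+    # multi-hot vector with one column per code in all_labels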
+
+    train_text = train['text']
+    temp_text = test['text']
+    train_labels = train['label']
+    temp_labels = test['label']
+
+
+    """# Split train dataset into train, validation and test sets"""
+
+    # we will use temp_text and temp_labels to create validation and test set
+    val_text, test_text, val_labels, test_labels = train_test_split(temp_text, temp_labels, 
+                                                                    random_state=123, 
+                                                                    test_size=0.5)
+
+    #%%
+    # binarize the ICD-9 label lists into multi-hot vectors
+    train_labels = mlb.transform(train_labels)
+    val_labels = mlb.transform(val_labels)
+    test_labels = mlb.transform(test_labels)
+    #%%
+    """# Import BERT Model and BERT Tokenizer"""
+    # import BERT-base pretrained model
+    bert = AutoModel.from_pretrained('bert-base-uncased')
+
+    # Load the BERT tokenizer
+    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
+    #%%
+    """# Tokenization"""
+
+    # distribution of note lengths (in whitespace tokens) in the train set;
+    # the histogram is only visible when run interactively, e.g. in a notebook
+    seq_len = [len(i.split()) for i in train_text]
+    pd.Series(seq_len).hist(bins=30)
+
+    max_seq_len = MAX_SEQ_LEN
+
+    # tokenize and encode sequences in the training set
+    tokens_train = tokenizer.batch_encode_plus(
+        train_text.tolist(),
+        max_length=max_seq_len,
+        padding='max_length',
+        truncation=True,
+        return_token_type_ids=False
+    )
+
+    # tokenize and encode sequences in the validation set
+    tokens_val = tokenizer.batch_encode_plus(
+        val_text.tolist(),
+        max_length=max_seq_len,
+        padding='max_length',
+        truncation=True,
+        return_token_type_ids=False
+    )
+
+    # tokenize and encode sequences in the test set
+    tokens_test = tokenizer.batch_encode_plus(
+        test_text.tolist(),
+        max_length=max_seq_len,
+        padding='max_length',
+        truncation=True,
+        return_token_type_ids=False
+    )
+    #%%
+    """# Convert Integer Sequences to Tensors"""
+
+    # for train set
+    train_seq = torch.tensor(tokens_train['input_ids'])
+    train_mask = torch.tensor(tokens_train['attention_mask'])
+    train_y = torch.tensor(train_labels.tolist())
+
+    # for validation set
+    val_seq = torch.tensor(tokens_val['input_ids'])
+    val_mask = torch.tensor(tokens_val['attention_mask'])
+    val_y = torch.tensor(val_labels.tolist())
+
+    # for test set
+    test_seq = torch.tensor(tokens_test['input_ids'])
+    test_mask = torch.tensor(tokens_test['attention_mask'])
+    test_y = torch.tensor(test_labels.tolist())
+    #%%
+    """# Create DataLoaders"""
+
+    from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
+
+    #define a batch size
+    batch_size = 32
+
+    # wrap tensors
+    train_data = TensorDataset(train_seq, train_mask, train_y)
+
+    # sampler for sampling the data during training
+    train_sampler = RandomSampler(train_data)
+
+    # dataLoader for train set
+    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
+
+    # wrap tensors
+    val_data = TensorDataset(val_seq, val_mask, val_y)
+
+    # sampler for sampling the data during training
+    val_sampler = SequentialSampler(val_data)
+
+    # dataLoader for validation set
+    val_dataloader = DataLoader(val_data, sampler = val_sampler, batch_size=batch_size)
+    #%%
+    """# Freeze BERT Parameters"""
+
+    # freeze all the parameters
+    for param in bert.parameters():
+        param.requires_grad = False
+    #%%
+    """# Define Model Architecture"""
+
+    class BERT_Arch(nn.Module):
+
+        def __init__(self, bert):
+        
+            super(BERT_Arch, self).__init__()
+
+            self.bert = bert 
+            # self.bert = nn.DataParallel(self.bert)
+            
+            # dropout layer
+            self.dropout = nn.Dropout(0.1)
+            
+            # relu activation function
+            self.relu =  nn.ReLU()
+
+            # dense layer 1
+            self.fc1 = nn.Linear(768,512)
+            
+            # dense layer 2 (output layer): one raw logit per ICD-9 label
+            self.fc2 = nn.Linear(512, len(all_labels))
+
+        #define the forward pass
+        def forward(self, sent_id, mask):
+
+            #pass the inputs to the model  
+            _, cls_hs = self.bert(sent_id, attention_mask=mask, return_dict=False)
+            
+            x = self.fc1(cls_hs)
+
+            x = self.relu(x)
+
+            x = self.dropout(x)
+
+            # output layer: raw logits (no softmax; BCEWithLogitsLoss applies the sigmoid internally)
+            x = self.fc2(x)
+
+            return x
+
+    # pass the pre-trained BERT to our defined architecture
+    model = BERT_Arch(bert)
+
+    # push the model to GPU
+    model = model.to(device)
+
+    # optimizer (torch's AdamW; the AdamW shipped with transformers is deprecated)
+    from torch.optim import AdamW
+
+    # define the optimizer
+    optimizer = AdamW(model.parameters(), lr=2e-3)
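+    # with every BERT parameter frozen above, only the fc1/fc2 head receives gradient
+    # updates, so the relatively large learning rate of 2e-3 touches only the new head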
+    #%%
+    """# Find Class Weights"""
+
+    # from sklearn.utils.class_weight import compute_class_weight
+
+    # #compute the class weights
+    # class_wts = compute_class_weight('balanced', np.unique(train_labels), train_labels)
+
+    # print(class_wts)
+
+    # # convert class weights to tensor
+    # weights= torch.tensor(class_wts,dtype=torch.float)
+    # weights = weights.to(device)
+
+    # loss function: multi-label targets, so use BCEWithLogitsLoss,
+    # which expects raw logits and applies the sigmoid itself
+    # cross_entropy = nn.NLLLoss(weight=weights)
+    # cross_entropy = nn.NLLLoss()
+    cross_entropy = nn.BCEWithLogitsLoss()
+
+    # number of training epochs
+    epochs = 3
+    #%%
+    """# Fine-Tune BERT"""
+
+    # function to train the model
+    def train():
+        
+        model.train()
+
+        total_loss, total_accuracy = 0, 0
+        
+        # empty list to save model predictions
+        total_preds=[]
+        
+        # iterate over batches
+        for step,batch in enumerate(train_dataloader):
+            
+            # progress update after every 50 batches.
+            if step % 50 == 0 and not step == 0:
+                print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(train_dataloader)))
+
+            # push the batch to gpu
+            batch = [r.to(device) for r in batch]
+        
+            sent_id, mask, labels = batch
+
+            # clear previously calculated gradients 
+            model.zero_grad()        
+
+            # get model predictions for the current batch
+            preds = model(sent_id, mask)
+
+            # compute the loss between actual and predicted values
+            # TODO: confirm that labels.type_as(preds) does not create an unwanted copy
+            loss = cross_entropy(preds, labels.type_as(preds))
+
+            # add on to the total loss
+            total_loss = total_loss + loss.item()
+
+            # backward pass to calculate the gradients
+            loss.backward()
+
+            # clip the gradients to 1.0; this helps prevent the exploding-gradient problem
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+            # update parameters
+            optimizer.step()
+
+            # model predictions are stored on GPU. So, push it to CPU
+            preds=preds.detach().cpu().numpy()
+
+            # append the model predictions
+            total_preds.append(preds)
+
+        # compute the training loss of the epoch
+        avg_loss = total_loss / len(train_dataloader)
+        
+        # predictions are in the form of (no. of batches, size of batch, no. of classes).
+        # reshape the predictions in form of (number of samples, no. of classes)
+        total_preds  = np.concatenate(total_preds, axis=0)
+
+        #returns the loss and predictions
+        return avg_loss, total_preds
+
+    # function for evaluating the model
+    def evaluate():
+    
+        print("\nEvaluating...")
+        
+        # deactivate dropout layers
+        model.eval()
+
+        total_loss, total_accuracy = 0, 0
+        
+        # empty list to save the model predictions
+        total_preds = []
+
+        # iterate over batches
+        for step,batch in enumerate(val_dataloader):
+            
+            # progress update after every 50 batches
+            if step % 50 == 0 and not step == 0:
+                print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(val_dataloader)))
+
+            # push the batch to gpu
+            batch = [t.to(device) for t in batch]
+
+            sent_id, mask, labels = batch
+
+            # deactivate autograd
+            with torch.no_grad():
+                
+                # model predictions
+                preds = model(sent_id, mask)
+
+            # compute the validation loss between actual and predicted values
+            loss = cross_entropy(preds,labels.type_as(preds))
+
+            total_loss = total_loss + loss.item()
+
+            preds = preds.detach().cpu().numpy()
+
+            total_preds.append(preds)
+
+        # compute the validation loss of the epoch
+        avg_loss = total_loss / len(val_dataloader) 
+
+        # reshape the predictions in form of (number of samples, no. of classes)
+        total_preds  = np.concatenate(total_preds, axis=0)
+
+        return avg_loss, total_preds
+    #%%
+    """# Start Model Training"""
+
+    # set initial loss to infinite
+    best_valid_loss = float('inf')
+
+    # empty lists to store training and validation loss of each epoch
+    train_losses=[]
+    valid_losses=[]
+
+    #for each epoch
+    for epoch in range(epochs):
+        
+        print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))
+        
+        #train model
+        train_loss, _ = train()
+        
+        #evaluate model
+        valid_loss, _ = evaluate()
+        
+        #save the best model
+        if valid_loss < best_valid_loss:
+            best_valid_loss = valid_loss
+            torch.save(model.state_dict(), 'saved_weights.pt')
+        
+        # append training and validation loss
+        train_losses.append(train_loss)
+        valid_losses.append(valid_loss)
+        
+        print(f'\nTraining Loss: {train_loss:.3f}')
+        print(f'Validation Loss: {valid_loss:.3f}')
+    #%%
+    """# Load Saved Model"""
+
+    #load weights of best model
+    path = 'saved_weights.pt'
+    model.load_state_dict(torch.load(path))
+
+    # switch to eval mode so dropout is disabled during test-set inference
+    model.eval()
+
+    with torch.no_grad():
+        # note: the entire test set is scored in a single forward pass
+        preds = model(test_seq.to(device), test_mask.to(device))
+        preds = preds.detach().cpu().numpy()
+
+    # multi-label evaluation: sigmoid the logits, threshold at 0.5, and compare
+    # the predicted indicator matrix against the true multi-hot labels
+    preds = (1 / (1 + np.exp(-preds)) > 0.5).astype(int)
+    print(classification_report(test_y.numpy(), preds, zero_division=0))
+
+
+
+    return model
+    #%%
+
+
+def convertBERT(item):
+    """Tokenize a single document into (input_ids, attention_mask) tensors for the model."""
+
+    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
+
+    max_seq_len = MAX_SEQ_LEN
+
+    encoded_toks = tokenizer.batch_encode_plus(
+            [item],
+            max_length=max_seq_len,
+            padding='max_length',
+            truncation=True,
+            return_token_type_ids=False
+        )
+    tokens = torch.tensor(encoded_toks['input_ids'])
+    masked_item = torch.tensor(encoded_toks['attention_mask'])
+    return tokens, masked_item
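+
+# Illustrative usage sketch (the path and the text below are placeholders, not real data):
+#
+#     model = run_BERT('path/to/mimic_csvs/', bert_fast_dev_run=True)
+#     tokens, mask = convertBERT('discharge summary text ...')
+#     device = next(model.parameters()).device
+#     with torch.no_grad():
+#         logits = model(tokens.to(device), mask.to(device))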
+