--- a
+++ b/CheXbert/src/utils.py
@@ -0,0 +1,380 @@
+import copy
+import torch
+import torch.nn as nn
+import pandas as pd
+import numpy as np
+import json
+from models.bert_labeler import bert_labeler
+from bert_tokenizer import tokenize
+from sklearn.metrics import f1_score, confusion_matrix
+from statsmodels.stats.inter_rater import cohens_kappa
+from transformers import BertTokenizer
+from constants import *
+
+def get_weighted_f1_weights(train_path_or_csv):
+    """Compute weights used to obtain the weighted average of
+       negation, uncertain and positive f1 scores.
+    @param train_path_or_csv: A path to the csv file or a dataframe
+
+    @return weight_dict (dictionary): maps conditions to a list of weights, the order
+                                      in the lists is negation, uncertain, positive
+    """
+    if isinstance(train_path_or_csv, str):
+        df = pd.read_csv(train_path_or_csv)
+    else:
+        df = train_path_or_csv
+    df.replace(0, 2, inplace=True)
+    df.replace(-1, 3, inplace=True)
+    df.fillna(0, inplace=True)
+
+    weight_dict = {}
+    for cond in CONDITIONS:
+        weights = []
+        col = df[cond]
+
+        mask = col == 2
+        weights.append(mask.sum())
+
+        mask = col == 3
+        weights.append(mask.sum())
+
+        mask = col == 1
+        weights.append(mask.sum())
+
+        if np.sum(weights) > 0:
+            weights = np.array(weights)/np.sum(weights)
+        weight_dict[cond] = weights
+    return weight_dict
+
+def weighted_avg(scores, weights):
+    """Compute weighted average of scores
+    @param scores (List): the task scores
+    @param weights (List): corresponding normalized weights
+
+    @return (float): the weighted average of task scores
+    """
+    return np.sum(np.array(scores) * np.array(weights))
+
+def compute_train_weights(train_path):
+    """Compute class weights for rebalancing rare classes
+    @param train_path (str): A path to the training csv file
+
+    @returns weight_arr (torch.Tensor): Tensor of shape (train_set_size), containing
+                                        the weight assigned to each training example
+    """
+    df = pd.read_csv(train_path)
+    cond_weights = {}
+    for cond in CONDITIONS:
+        col = df[cond]
+        val_counts = col.value_counts()
+        if cond != 'No Finding':
+            weights = {}
+            weights['0.0'] = len(df) / val_counts[0]
+            weights['-1.0'] = len(df) / val_counts[-1]
+            weights['1.0'] = len(df) / val_counts[1]
+            weights['nan'] = len(df) / (len(df) - val_counts.sum())
+        else:
+            weights = {}
+            weights['1.0'] = len(df) / val_counts[1]
+            weights['nan'] = len(df) / (len(df) - val_counts.sum())
+
+        cond_weights[cond] = weights
+
+    weight_arr = torch.zeros(len(df))
+    for i in range(len(df)):        #loop over training set
+        for cond in CONDITIONS:     #loop over all conditions
+            label = str(df[cond].iloc[i])
+            weight_arr[i] += cond_weights[cond][label] #add weight for given class' label
+
+    return weight_arr
+
+def generate_attention_masks(batch, source_lengths, device):
+    """Generate masks for padded batches to avoid self-attention over pad tokens
+    @param batch (Tensor): tensor of token indices of shape (batch_size, max_len)
+                           where max_len is length of longest sequence in the batch
+    @param source_lengths (List[Int]): List of actual lengths for each of the
+                           sequences in the batch
+    @param device (torch.device): device on which data should be
+
+    @returns masks (Tensor): Tensor of masks of shape (batch_size, max_len)
+    """
+    masks = torch.ones(batch.size(0), batch.size(1), dtype=torch.float)
+    for idx, src_len in enumerate(source_lengths):
+        masks[idx, src_len:] = 0
+    return masks.to(device)
+
+def compute_mention_f1(y_true, y_pred):
+    """Compute the mention F1 score as in CheXpert paper
+    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
+    @param y_pred (list): Same as y_true but for model predictions
+
+    @returns res (list): List of 14 scalars
+    """
+    for j in range(len(y_true)):
+        y_true[j][y_true[j] == 2] = 1
+        y_true[j][y_true[j] == 3] = 1
+        y_pred[j][y_pred[j] == 2] = 1
+        y_pred[j][y_pred[j] == 3] = 1
+
+    res = []
+    for j in range(len(y_true)):
+        res.append(f1_score(y_true[j], y_pred[j], pos_label=1))
+
+    return res
+
+def compute_blank_f1(y_true, y_pred):
+    """Compute the blank F1 score
+    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
+    @param y_pred (list): Same as y_true but for model predictions
+
+    @returns res (list): List of 14 scalars
+    """
+    for j in range(len(y_true)):
+        y_true[j][y_true[j] == 2] = 1
+        y_true[j][y_true[j] == 3] = 1
+        y_pred[j][y_pred[j] == 2] = 1
+        y_pred[j][y_pred[j] == 3] = 1
+
+    res = []
+    for j in range(len(y_true)):
+        res.append(f1_score(y_true[j], y_pred[j], pos_label=0))
+
+    return res
+
+def compute_negation_f1(y_true, y_pred):
+    """Compute the negation F1 score as in CheXpert paper
+    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
+    @param y_pred (list): Same as y_true but for model predictions
+
+    @returns res (list): List of 14 scalars
+    """
+    for j in range(len(y_true)):
+        y_true[j][y_true[j] == 3] = 0
+        y_true[j][y_true[j] == 1] = 0
+        y_pred[j][y_pred[j] == 3] = 0
+        y_pred[j][y_pred[j] == 1] = 0
+
+    res = []
+    for j in range(len(y_true)-1):
+        res.append(f1_score(y_true[j], y_pred[j], pos_label=2))
+
+    res.append(0) #No Finding gets score of zero
+    return res
+
+def compute_positive_f1(y_true, y_pred):
+    """Compute the positive F1 score
+    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
+    @param y_pred (list): Same as y_true but for model predictions
+
+    @returns res (list): List of 14 scalars
+    """
+    for j in range(len(y_true)):
+        y_true[j][y_true[j] == 3] = 0
+        y_true[j][y_true[j] == 2] = 0
+        y_pred[j][y_pred[j] == 3] = 0
+        y_pred[j][y_pred[j] == 2] = 0
+
+    res = []
+    for j in range(len(y_true)):
+        res.append(f1_score(y_true[j], y_pred[j], pos_label=1))
+
+    return res
+
+def compute_uncertain_f1(y_true, y_pred):
+    """Compute the uncertain F1 score as in CheXpert paper
+    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
+    @param y_pred (list): Same as y_true but for model predictions
+
+    @returns res (list): List of 14 scalars
+    """
+    for j in range(len(y_true)):
+        y_true[j][y_true[j] == 2] = 0
+        y_true[j][y_true[j] == 1] = 0
+        y_pred[j][y_pred[j] == 2] = 0
+        y_pred[j][y_pred[j] == 1] = 0
+
+    res = []
+    for j in range(len(y_true)-1):
+        res.append(f1_score(y_true[j], y_pred[j], pos_label=3))
+
+    res.append(0) #No Finding gets a score of zero
+    return res
+
+def evaluate(model, dev_loader, device, f1_weights, return_pred=False):
+    """ Function to evaluate the current model weights
+    @param model (nn.Module): the labeler module
+    @param dev_loader (torch.utils.data.DataLoader): dataloader for dev set
+    @param device (torch.device): device on which data should be
+    @param f1_weights (dictionary): dictionary mapping conditions to f1
+                                    task weights
+    @param return_pred (bool): whether to return predictions or not
+
+    @returns res_dict (dictionary): dictionary with keys 'blank', 'mention', 'negation',
+                                    'uncertain', 'positive', 'weighted' and 'kappa', with
+                                    values being lists of length 14 with each element in
+                                    the lists as a scalar. If return_pred is true then a
+                                    tuple is returned with the aforementioned dictionary
+                                    as the first item, a list of predictions as the
+                                    second item, and a list of ground truth as the
+                                    third item
+    """
+
+    was_training = model.training
+    model.eval()
+    y_pred = [[] for _ in range(len(CONDITIONS))]
+    y_true = [[] for _ in range(len(CONDITIONS))]
+
+    with torch.no_grad():
+        for i, data in enumerate(dev_loader, 0):
+            batch = data['imp'] #(batch_size, max_len)
+            batch = batch.to(device)
+            label = data['label'] #(batch_size, 14)
+            label = label.permute(1, 0).to(device)
+            src_len = data['len']
+            batch_size = batch.shape[0]
+            attn_mask = generate_attention_masks(batch, src_len, device)
+
+            out = model(batch, attn_mask)
+
+            for j in range(len(out)):
+                out[j] = out[j].to('cpu') #move to cpu for sklearn
+                curr_y_pred = out[j].argmax(dim=1) #shape is (batch_size)
+                y_pred[j].append(curr_y_pred)
+                y_true[j].append(label[j].to('cpu'))
+
+            if (i+1) % 200 == 0:
+                print('Evaluation batch no: ', i+1)
+
+    for j in range(len(y_true)):
+        y_true[j] = torch.cat(y_true[j], dim=0)
+        y_pred[j] = torch.cat(y_pred[j], dim=0)
+
+    if was_training:
+        model.train()
+
+    mention_f1 = compute_mention_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
+    negation_f1 = compute_negation_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
+    uncertain_f1 = compute_uncertain_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
+    positive_f1 = compute_positive_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
+    blank_f1 = compute_blank_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
+
+    weighted = []
+    kappas = []
+    for j in range(len(y_pred)):
+        cond = CONDITIONS[j]
+        avg = weighted_avg([negation_f1[j], uncertain_f1[j], positive_f1[j]], f1_weights[cond])
+        weighted.append(avg)
+
+        mat = confusion_matrix(y_true[j], y_pred[j])
+        kappas.append(cohens_kappa(mat, return_results=False))
+
+    res_dict = {'mention': mention_f1,
+                'blank': blank_f1,
+                'negation': negation_f1,
+                'uncertain': uncertain_f1,
+                'positive': positive_f1,
+                'weighted': weighted,
+                'kappa': kappas}
+
+    if return_pred:
+        return res_dict, y_pred, y_true
+    else:
+        return res_dict
+
+def test(model, checkpoint_path, test_ld, f1_weights):
+    """Evaluate model on test set.
+    @param model (nn.Module): labeler module
+    @param checkpoint_path (string): location of saved model checkpoint
+    @param test_ld (dataloader): dataloader for test set
+    @param f1_weights (dictionary): maps conditions to f1 task weights
+    """
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    if torch.cuda.device_count() > 1:
+        print("Using", torch.cuda.device_count(), "GPUs!")
+        model = nn.DataParallel(model) #to utilize multiple GPU's
+    model = model.to(device)
+
+    checkpoint = torch.load(checkpoint_path)
+    model.load_state_dict(checkpoint['model_state_dict'])
+
+    print("Doing evaluation on test set\n")
+    metrics = evaluate(model, test_ld, device, f1_weights)
+    weighted = metrics['weighted']
+    kappas = metrics['kappa']
+
+    for j in range(len(CONDITIONS)):
+        print('%s kappa: %.3f' % (CONDITIONS[j], kappas[j]))
+    print('average: %.3f' % np.mean(kappas))
+
+    print()
+    for j in range(len(CONDITIONS)):
+        print('%s weighted_f1: %.3f' % (CONDITIONS[j], weighted[j]))
+    print('average of weighted_f1: %.3f' % (np.mean(weighted)))
+
+    print()
+    for j in range(len(CONDITIONS)):
+        print('%s blank_f1: %.3f, negation_f1: %.3f, uncertain_f1: %.3f, positive_f1: %.3f' % (CONDITIONS[j],
+                                                                                                metrics['blank'][j],
+                                                                                                metrics['negation'][j],
+                                                                                                metrics['uncertain'][j],
+                                                                                                metrics['positive'][j]))
+
+    men_macro_avg = np.mean(metrics['mention'])
+    neg_macro_avg = np.mean(metrics['negation'][:-1]) #No Finding has no negations
+    unc_macro_avg = np.mean(metrics['uncertain'][:-2]) #No Finding, Support Devices have no uncertain labels in test set
+    pos_macro_avg = np.mean(metrics['positive'])
+    blank_macro_avg = np.mean(metrics['blank'])
+
+    print("blank macro avg: %.3f, negation macro avg: %.3f, uncertain macro avg: %.3f, positive macro avg: %.3f" % (blank_macro_avg,
+                                                                                                                     neg_macro_avg,
+                                                                                                                     unc_macro_avg,
+                                                                                                                     pos_macro_avg))
+    print()
+    for j in range(len(CONDITIONS)):
+        print('%s mention_f1: %.3f' % (CONDITIONS[j], metrics['mention'][j]))
+    print('mention macro avg: %.3f' % men_macro_avg)
+
+
+def label_report_list(checkpoint_path, report_list):
+    """ Evaluate model on list of reports.
+    @param checkpoint_path (string): location of saved model checkpoint
+    @param report_list (list): list of report impressions (string)
+    """
+    imp = pd.Series(report_list)
+    imp = imp.str.strip()
+    imp = imp.replace('\n', ' ', regex=True)
+    imp = imp.replace('[0-9]\.', '', regex=True)
+    imp = imp.replace('\s+', ' ', regex=True)
+    imp = imp.str.strip()
+
+    model = bert_labeler()
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    if torch.cuda.device_count() > 1:
+        print("Using", torch.cuda.device_count(), "GPUs!")
+        model = nn.DataParallel(model) #to utilize multiple GPU's
+    model = model.to(device)
+    checkpoint = torch.load(checkpoint_path)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    model.eval()
+
+    y_pred = []
+    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+    new_imps = tokenize(imp, tokenizer)
+    with torch.no_grad():
+        for imp in new_imps:
+            # run forward prop
+            imp = torch.LongTensor(imp)
+            source = imp.view(1, len(imp))
+
+            attention = torch.ones(len(imp))
+            attention = attention.view(1, len(imp))
+            out = model(source.to(device), attention.to(device))
+
+            # get predictions
+            result = {}
+            for j in range(len(out)):
+                curr_y_pred = out[j].argmax(dim=1) #shape is (1)
+                result[CONDITIONS[j]] = CLASS_MAPPING[curr_y_pred.item()]
+            y_pred.append(result)
+    return y_pred
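
Usage sketch (appended for reference, not part of the patch): a minimal example of how the helpers added in this file might be called, assuming the code is run from CheXbert/src with a trained checkpoint on disk. The two file paths below are placeholders I am assuming for illustration, not values defined by this diff; get_weighted_f1_weights, label_report_list and CONDITIONS are the names introduced in the patch.

    # Hedged usage sketch -- the file paths are placeholders, not supplied by this patch.
    from utils import get_weighted_f1_weights, label_report_list

    # Per-condition weights for the weighted-F1 metric, derived from the training labels.
    f1_weights = get_weighted_f1_weights('path/to/train.csv')  # placeholder path

    # Label a small list of report impressions with a saved checkpoint; each element of
    # the returned list maps every condition in CONDITIONS to its predicted class string.
    preds = label_report_list('path/to/chexbert_checkpoint.pth',  # placeholder path
                              ['No acute cardiopulmonary abnormality.',
                               'Left lower lobe opacity, possibly pneumonia.'])
    print(preds[1]['Pneumonia'])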