Diff of /chexbert/src/utils.py [000000] .. [4abb48]

b/chexbert/src/utils.py
import copy
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import json
from models.bert_labeler import bert_labeler
from bert_tokenizer import tokenize
from sklearn.metrics import f1_score, confusion_matrix
from statsmodels.stats.inter_rater import cohens_kappa
from transformers import BertTokenizer
from constants import *

def get_weighted_f1_weights(train_path_or_csv):
    """Compute weights used to obtain the weighted average of
       mention, negation and uncertain f1 scores.
    @param train_path_or_csv: A path to the csv file or a dataframe

    @return weight_dict (dictionary): maps conditions to a list of weights, the order
                                      in the lists is negation, uncertain, positive
    """
    if isinstance(train_path_or_csv, str):
        df = pd.read_csv(train_path_or_csv)
    else:
        df = train_path_or_csv
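    # Remap the CheXpert labels to the class indices used throughout this file:
    # blank (NaN) -> 0, positive (1) -> 1, negative (0) -> 2, uncertain (-1) -> 3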
    df.replace(0, 2, inplace=True)
    df.replace(-1, 3, inplace=True)
    df.fillna(0, inplace=True)

    weight_dict = {}
    for cond in CONDITIONS:
        weights = []
        col = df[cond]

        mask = col == 2
        weights.append(mask.sum())

        mask = col == 3
        weights.append(mask.sum())

        mask = col == 1
        weights.append(mask.sum())

        if np.sum(weights) > 0:
            weights = np.array(weights)/np.sum(weights)
        weight_dict[cond] = weights
    return weight_dict

def weighted_avg(scores, weights):
    """Compute weighted average of scores
    @param scores (List): the task scores
    @param weights (List): corresponding normalized weights

    @return (float): the weighted average of task scores
    """
    return np.sum(np.array(scores) * np.array(weights))

def compute_train_weights(train_path):
    """Compute class weights for rebalancing rare classes
    @param train_path (str): A path to the training csv file

    @returns weight_arr (torch.Tensor): Tensor of shape (train_set_size), containing
                                        the weight assigned to each training example
    """
    df = pd.read_csv(train_path)
    cond_weights = {}
    for cond in CONDITIONS:
        col = df[cond]
        val_counts = col.value_counts()
        if cond != 'No Finding':
            weights = {}
            weights['0.0'] = len(df) / val_counts[0]
            weights['-1.0'] = len(df) / val_counts[-1]
            weights['1.0'] = len(df) / val_counts[1]
            weights['nan'] = len(df) / (len(df) - val_counts.sum())
        else:
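            # 'No Finding' takes only the positive (1) and blank (NaN) values in the
            # CheXpert training data, so only those two weights are needed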
            weights = {}
            weights['1.0'] = len(df) / val_counts[1]
            weights['nan'] = len(df) / (len(df) - val_counts.sum())

        cond_weights[cond] = weights

    weight_arr = torch.zeros(len(df))
    for i in range(len(df)):     #loop over training set
        for cond in CONDITIONS:  #loop over all conditions
            label = str(df[cond].iloc[i])
            weight_arr[i] += cond_weights[cond][label] #add weight for given class' label

    return weight_arr

def generate_attention_masks(batch, source_lengths, device):
    """Generate masks for padded batches to avoid self-attention over pad tokens
    @param batch (Tensor): tensor of token indices of shape (batch_size, max_len)
                           where max_len is length of longest sequence in the batch
    @param source_lengths (List[Int]): List of actual lengths for each of the
                           sequences in the batch
    @param device (torch.device): device on which data should be placed

    @returns masks (Tensor): Tensor of masks of shape (batch_size, max_len)
    """
    masks = torch.ones(batch.size(0), batch.size(1), dtype=torch.float)
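    # zero out the mask past each sequence's true length, i.e. over the pad tokens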
    for idx, src_len in enumerate(source_lengths):
        masks[idx, src_len:] = 0
    return masks.to(device)

def compute_mention_f1(y_true, y_pred):
    """Compute the mention F1 score as in CheXpert paper
    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
    @param y_pred (list): Same as y_true but for model predictions

    @returns res (list): List of 14 scalars
    """
    for j in range(len(y_true)):
        y_true[j][y_true[j] == 2] = 1
        y_true[j][y_true[j] == 3] = 1
        y_pred[j][y_pred[j] == 2] = 1
        y_pred[j][y_pred[j] == 3] = 1

    res = []
    for j in range(len(y_true)):
        res.append(f1_score(y_true[j], y_pred[j], pos_label=1))

    return res

def compute_blank_f1(y_true, y_pred):
    """Compute the blank F1 score
    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
    @param y_pred (list): Same as y_true but for model predictions

    @returns res (list): List of 14 scalars
    """
    for j in range(len(y_true)):
        y_true[j][y_true[j] == 2] = 1
        y_true[j][y_true[j] == 3] = 1
        y_pred[j][y_pred[j] == 2] = 1
        y_pred[j][y_pred[j] == 3] = 1

    res = []
    for j in range(len(y_true)):
        res.append(f1_score(y_true[j], y_pred[j], pos_label=0))

    return res

def compute_negation_f1(y_true, y_pred):
    """Compute the negation F1 score as in CheXpert paper
    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
    @param y_pred (list): Same as y_true but for model predictions

    @returns res (list): List of 14 scalars
    """
    for j in range(len(y_true)):
        y_true[j][y_true[j] == 3] = 0
        y_true[j][y_true[j] == 1] = 0
        y_pred[j][y_pred[j] == 3] = 0
        y_pred[j][y_pred[j] == 1] = 0

    res = []
    for j in range(len(y_true)-1):
        res.append(f1_score(y_true[j], y_pred[j], pos_label=2))

    res.append(0) #No Finding gets score of zero
    return res

def compute_positive_f1(y_true, y_pred):
    """Compute the positive F1 score
    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
    @param y_pred (list): Same as y_true but for model predictions

    @returns res (list): List of 14 scalars
    """
    for j in range(len(y_true)):
        y_true[j][y_true[j] == 3] = 0
        y_true[j][y_true[j] == 2] = 0
        y_pred[j][y_pred[j] == 3] = 0
        y_pred[j][y_pred[j] == 2] = 0

    res = []
    for j in range(len(y_true)):
        res.append(f1_score(y_true[j], y_pred[j], pos_label=1))

    return res

def compute_uncertain_f1(y_true, y_pred):
    """Compute the uncertain F1 score as in CheXpert paper
    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
    @param y_pred (list): Same as y_true but for model predictions

    @returns res (list): List of 14 scalars
    """
    for j in range(len(y_true)):
        y_true[j][y_true[j] == 2] = 0
        y_true[j][y_true[j] == 1] = 0
        y_pred[j][y_pred[j] == 2] = 0
        y_pred[j][y_pred[j] == 1] = 0

    res = []
    for j in range(len(y_true)-1):
        res.append(f1_score(y_true[j], y_pred[j], pos_label=3))

    res.append(0) #No Finding gets a score of zero
    return res

def evaluate(model, dev_loader, device, f1_weights, return_pred=False):
    """ Function to evaluate the current model weights
    @param model (nn.Module): the labeler module
    @param dev_loader (torch.utils.data.DataLoader): dataloader for dev set
    @param device (torch.device): device on which data should be placed
    @param f1_weights (dictionary): dictionary mapping conditions to f1
                                    task weights
    @param return_pred (bool): whether to return predictions or not

    @returns res_dict (dictionary): dictionary with keys 'blank', 'mention', 'negation',
                                    'uncertain', 'positive', 'weighted' and 'kappa', with
                                    values being lists of length 14 with each element in
                                    the lists a scalar. If return_pred is true then a
                                    tuple is returned with the aforementioned dictionary
                                    as the first item, a list of predictions as the
                                    second item, and a list of ground truth as the
                                    third item
    """
    was_training = model.training
    model.eval()
    y_pred = [[] for _ in range(len(CONDITIONS))]
    y_true = [[] for _ in range(len(CONDITIONS))]

    with torch.no_grad():
        for i, data in enumerate(dev_loader, 0):
            batch = data['imp'] #(batch_size, max_len)
            batch = batch.to(device)
            label = data['label'] #(batch_size, 14)
            label = label.permute(1, 0).to(device)
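            # label is now (14, batch_size), so label[j] lines up with the model's
            # per-condition output out[j] below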
            src_len = data['len']
            batch_size = batch.shape[0]
            attn_mask = generate_attention_masks(batch, src_len, device)

            out = model(batch, attn_mask)

            for j in range(len(out)):
                out[j] = out[j].to('cpu') #move to cpu for sklearn
                curr_y_pred = out[j].argmax(dim=1) #shape is (batch_size)
                y_pred[j].append(curr_y_pred)
                y_true[j].append(label[j].to('cpu'))

            if (i+1) % 200 == 0:
                print('Evaluation batch no: ', i+1)

    for j in range(len(y_true)):
        y_true[j] = torch.cat(y_true[j], dim=0)
        y_pred[j] = torch.cat(y_pred[j], dim=0)

    if was_training:
        model.train()
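
    # the compute_*_f1 helpers relabel the tensors in place, hence the deepcopies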
    mention_f1 = compute_mention_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
    negation_f1 = compute_negation_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
    uncertain_f1 = compute_uncertain_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
    positive_f1 = compute_positive_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
    blank_f1 = compute_blank_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))

    weighted = []
    kappas = []
    for j in range(len(y_pred)):
        cond = CONDITIONS[j]
        avg = weighted_avg([negation_f1[j], uncertain_f1[j], positive_f1[j]], f1_weights[cond])
        weighted.append(avg)

        mat = confusion_matrix(y_true[j], y_pred[j])
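        # with return_results=False, cohens_kappa returns just the kappa statistic as a float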
        kappas.append(cohens_kappa(mat, return_results=False))

    res_dict = {'mention': mention_f1,
                'blank': blank_f1,
                'negation': negation_f1,
                'uncertain': uncertain_f1,
                'positive': positive_f1,
                'weighted': weighted,
                'kappa': kappas}

    if return_pred:
        return res_dict, y_pred, y_true
    else:
        return res_dict

def test(model, checkpoint_path, test_ld, f1_weights):
    """Evaluate model on test set.
    @param model (nn.Module): labeler module
    @param checkpoint_path (string): location of saved model checkpoint
    @param test_ld (dataloader): dataloader for test set
    @param f1_weights (dictionary): maps conditions to f1 task weights
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) # to utilize multiple GPUs
    model = model.to(device)
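
    # if the checkpoint was saved from a DataParallel-wrapped model, its state_dict keys
    # carry a 'module.' prefix, so the wrap above must happen before load_state_dict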
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])

    print("Doing evaluation on test set\n")
    metrics = evaluate(model, test_ld, device, f1_weights)
    weighted = metrics['weighted']
    kappas = metrics['kappa']

    for j in range(len(CONDITIONS)):
        print('%s kappa: %.3f' % (CONDITIONS[j], kappas[j]))
    print('average: %.3f' % np.mean(kappas))

    print()
    for j in range(len(CONDITIONS)):
        print('%s weighted_f1: %.3f' % (CONDITIONS[j], weighted[j]))
    print('average of weighted_f1: %.3f' % (np.mean(weighted)))

    print()
    for j in range(len(CONDITIONS)):
        print('%s blank_f1:  %.3f, negation_f1: %.3f, uncertain_f1: %.3f, positive_f1: %.3f' % (CONDITIONS[j],
                                                                                                metrics['blank'][j],
                                                                                                metrics['negation'][j],
                                                                                                metrics['uncertain'][j],
                                                                                                metrics['positive'][j]))

    men_macro_avg = np.mean(metrics['mention'])
    neg_macro_avg = np.mean(metrics['negation'][:-1]) #No Finding has no negations
    unc_macro_avg = np.mean(metrics['uncertain'][:-2]) #No Finding, Support Devices have no uncertain labels in test set
    pos_macro_avg = np.mean(metrics['positive'])
    blank_macro_avg = np.mean(metrics['blank'])

    print("blank macro avg: %.3f, negation macro avg: %.3f, uncertain macro avg: %.3f, positive macro avg: %.3f" % (blank_macro_avg,
                                                                                                                    neg_macro_avg,
                                                                                                                    unc_macro_avg,
                                                                                                                    pos_macro_avg))
    print()
    for j in range(len(CONDITIONS)):
        print('%s mention_f1: %.3f' % (CONDITIONS[j], metrics['mention'][j]))
    print('mention macro avg: %.3f' % men_macro_avg)


def label_report_list(checkpoint_path, report_list):
    """ Evaluate model on list of reports.
    @param checkpoint_path (string): location of saved model checkpoint
    @param report_list (list): list of report impressions (string)
    """
    imp = pd.Series(report_list)
    imp = imp.str.strip()
    imp = imp.replace('\n', ' ', regex=True)
    imp = imp.replace(r'[0-9]\.', '', regex=True)
    imp = imp.replace(r'\s+', ' ', regex=True)
    imp = imp.str.strip()

    model = bert_labeler()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) # to utilize multiple GPUs
    model = model.to(device)
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    y_pred = []
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    new_imps = tokenize(imp, tokenizer)
    with torch.no_grad():
        for imp in new_imps:
            # run forward prop
            imp = torch.LongTensor(imp)
            source = imp.view(1, len(imp))

            attention = torch.ones(len(imp))
            attention = attention.view(1, len(imp))
            out = model(source.to(device), attention.to(device))

            # get predictions
            result = {}
            for j in range(len(out)):
                curr_y_pred = out[j].argmax(dim=1) #shape is (1)
                result[CONDITIONS[j]] = CLASS_MAPPING[curr_y_pred.item()]
            y_pred.append(result)
    return y_pred
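
# Example usage (illustrative only; the checkpoint path is a placeholder):
#   reports = ["No acute cardiopulmonary process.", "Small left pleural effusion."]
#   preds = label_report_list('path/to/chexbert_checkpoint.pth', reports)
#   preds[0] maps each condition in CONDITIONS to a CLASS_MAPPING value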