b/chexbert/src/utils.py

import copy
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import json
from models.bert_labeler import bert_labeler
from bert_tokenizer import tokenize
from sklearn.metrics import f1_score, confusion_matrix
from statsmodels.stats.inter_rater import cohens_kappa
from transformers import BertTokenizer
from constants import *

def get_weighted_f1_weights(train_path_or_csv):
    """Compute weights used to obtain the weighted average of
       mention, negation and uncertain f1 scores.
    @param train_path_or_csv: A path to the csv file or a dataframe

    @return weight_dict (dictionary): maps conditions to a list of weights, the order
                                      in the lists is negation, uncertain, positive
    """
    if isinstance(train_path_or_csv, str):
        df = pd.read_csv(train_path_or_csv)
    else:
        df = train_path_or_csv
    df.replace(0, 2, inplace=True)
    df.replace(-1, 3, inplace=True)
    df.fillna(0, inplace=True)

    weight_dict = {}
    for cond in CONDITIONS:
        weights = []
        col = df[cond]

        mask = col == 2
        weights.append(mask.sum())

        mask = col == 3
        weights.append(mask.sum())

        mask = col == 1
        weights.append(mask.sum())

        if np.sum(weights) > 0:
            weights = np.array(weights)/np.sum(weights)
        weight_dict[cond] = weights
    return weight_dict

def weighted_avg(scores, weights):
    """Compute weighted average of scores
    @param scores (List): the task scores
    @param weights (List): corresponding normalized weights

    @return (float): the weighted average of task scores
    """
    return np.sum(np.array(scores) * np.array(weights))

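# Usage sketch for the two helpers above (the csv path and the 'Edema' key are
# illustrative; CONDITIONS in constants.py defines the real keys):
#
#   f1_weights = get_weighted_f1_weights('path/to/train.csv')
#   # combine the negation/uncertain/positive f1 scores for one condition
#   score = weighted_avg([0.85, 0.40, 0.90], f1_weights['Edema'])
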
def compute_train_weights(train_path):
    """Compute class weights for rebalancing rare classes
    @param train_path (str): A path to the training csv file

    @returns weight_arr (torch.Tensor): Tensor of shape (train_set_size), containing
                                        the weight assigned to each training example
    """
    df = pd.read_csv(train_path)
    cond_weights = {}
    for cond in CONDITIONS:
        col = df[cond]
        val_counts = col.value_counts()
        if cond != 'No Finding':
            weights = {}
            weights['0.0'] = len(df) / val_counts[0]
            weights['-1.0'] = len(df) / val_counts[-1]
            weights['1.0'] = len(df) / val_counts[1]
            weights['nan'] = len(df) / (len(df) - val_counts.sum())
        else:
            weights = {}
            weights['1.0'] = len(df) / val_counts[1]
            weights['nan'] = len(df) / (len(df) - val_counts.sum())

        cond_weights[cond] = weights

    weight_arr = torch.zeros(len(df))
    for i in range(len(df)): #loop over training set
        for cond in CONDITIONS: #loop over all conditions
            label = str(df[cond].iloc[i])
            weight_arr[i] += cond_weights[cond][label] #add weight for given class' label

    return weight_arr

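# Usage sketch: the returned per-example weights can drive a sampler such as
# torch.utils.data.WeightedRandomSampler (one plausible consumer, not
# necessarily the one used elsewhere in this repo; the path is illustrative):
#
#   weights = compute_train_weights('path/to/train.csv')
#   sampler = torch.utils.data.WeightedRandomSampler(weights, len(weights))
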
def generate_attention_masks(batch, source_lengths, device):
    """Generate masks for padded batches to avoid self-attention over pad tokens
    @param batch (Tensor): tensor of token indices of shape (batch_size, max_len)
                           where max_len is length of longest sequence in the batch
    @param source_lengths (List[Int]): List of actual lengths for each of the
                                       sequences in the batch
    @param device (torch.device): device on which data should be placed

    @returns masks (Tensor): Tensor of masks of shape (batch_size, max_len)
    """
    masks = torch.ones(batch.size(0), batch.size(1), dtype=torch.float)
    for idx, src_len in enumerate(source_lengths):
        masks[idx, src_len:] = 0
    return masks.to(device)

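# Minimal sketch of the masking behavior (values are illustrative):
#
#   batch = torch.zeros(2, 5, dtype=torch.long)  # two padded sequences
#   masks = generate_attention_masks(batch, [3, 5], torch.device('cpu'))
#   # masks -> [[1., 1., 1., 0., 0.],
#   #           [1., 1., 1., 1., 1.]]
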
def compute_mention_f1(y_true, y_pred):
    """Compute the mention F1 score as in CheXpert paper
    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
    @param y_pred (list): Same as y_true but for model predictions

    @returns res (list): List of 14 scalars
    """
    for j in range(len(y_true)):
        y_true[j][y_true[j] == 2] = 1
        y_true[j][y_true[j] == 3] = 1
        y_pred[j][y_pred[j] == 2] = 1
        y_pred[j][y_pred[j] == 3] = 1

    res = []
    for j in range(len(y_true)):
        res.append(f1_score(y_true[j], y_pred[j], pos_label=1))

    return res

def compute_blank_f1(y_true, y_pred):
    """Compute the blank F1 score
    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
    @param y_pred (list): Same as y_true but for model predictions

    @returns res (list): List of 14 scalars
    """
    for j in range(len(y_true)):
        y_true[j][y_true[j] == 2] = 1
        y_true[j][y_true[j] == 3] = 1
        y_pred[j][y_pred[j] == 2] = 1
        y_pred[j][y_pred[j] == 3] = 1

    res = []
    for j in range(len(y_true)):
        res.append(f1_score(y_true[j], y_pred[j], pos_label=0))

    return res

def compute_negation_f1(y_true, y_pred):
    """Compute the negation F1 score as in CheXpert paper
    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
    @param y_pred (list): Same as y_true but for model predictions

    @returns res (list): List of 14 scalars
    """
    for j in range(len(y_true)):
        y_true[j][y_true[j] == 3] = 0
        y_true[j][y_true[j] == 1] = 0
        y_pred[j][y_pred[j] == 3] = 0
        y_pred[j][y_pred[j] == 1] = 0

    res = []
    for j in range(len(y_true)-1):
        res.append(f1_score(y_true[j], y_pred[j], pos_label=2))

    res.append(0) #No Finding gets score of zero
    return res

def compute_positive_f1(y_true, y_pred):
    """Compute the positive F1 score
    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
    @param y_pred (list): Same as y_true but for model predictions

    @returns res (list): List of 14 scalars
    """
    for j in range(len(y_true)):
        y_true[j][y_true[j] == 3] = 0
        y_true[j][y_true[j] == 2] = 0
        y_pred[j][y_pred[j] == 3] = 0
        y_pred[j][y_pred[j] == 2] = 0

    res = []
    for j in range(len(y_true)):
        res.append(f1_score(y_true[j], y_pred[j], pos_label=1))

    return res

def compute_uncertain_f1(y_true, y_pred):
    """Compute the uncertain F1 score as in CheXpert paper
    @param y_true (list): List of 14 tensors each of shape (dev_set_size)
    @param y_pred (list): Same as y_true but for model predictions

    @returns res (list): List of 14 scalars
    """
    for j in range(len(y_true)):
        y_true[j][y_true[j] == 2] = 0
        y_true[j][y_true[j] == 1] = 0
        y_pred[j][y_pred[j] == 2] = 0
        y_pred[j][y_pred[j] == 1] = 0

    res = []
    for j in range(len(y_true)-1):
        res.append(f1_score(y_true[j], y_pred[j], pos_label=3))

    res.append(0) #No Finding gets a score of zero
    return res

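# The compute_*_f1 helpers above share one pattern: collapse the four-way
# labels (0=blank, 1=positive, 2=negation, 3=uncertain, per the encoding set
# up in get_weighted_f1_weights) into a binary problem, then score the class
# of interest with sklearn's f1_score. Illustrative tensors:
#
#   y_true = [torch.tensor([0, 1, 2, 3]) for _ in range(14)]
#   y_pred = [torch.tensor([0, 1, 2, 1]) for _ in range(14)]
#   mention = compute_mention_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
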
def evaluate(model, dev_loader, device, f1_weights, return_pred=False):
    """ Function to evaluate the current model weights
    @param model (nn.Module): the labeler module
    @param dev_loader (torch.utils.data.DataLoader): dataloader for dev set
    @param device (torch.device): device on which data should be placed
    @param f1_weights (dictionary): dictionary mapping conditions to f1
                                    task weights
    @param return_pred (bool): whether to return predictions or not

    @returns res_dict (dictionary): dictionary with keys 'blank', 'mention', 'negation',
                                    'uncertain', 'positive' and 'weighted', with values
                                    being lists of length 14 with each element in the
                                    lists as a scalar. If return_pred is true then a
                                    tuple is returned with the aforementioned dictionary
                                    as the first item, a list of predictions as the
                                    second item, and a list of ground truth as the
                                    third item
    """

    was_training = model.training
    model.eval()
    y_pred = [[] for _ in range(len(CONDITIONS))]
    y_true = [[] for _ in range(len(CONDITIONS))]

    with torch.no_grad():
        for i, data in enumerate(dev_loader, 0):
            batch = data['imp'] #(batch_size, max_len)
            batch = batch.to(device)
            label = data['label'] #(batch_size, 14)
            label = label.permute(1, 0).to(device)
            src_len = data['len']
            batch_size = batch.shape[0]
            attn_mask = generate_attention_masks(batch, src_len, device)

            out = model(batch, attn_mask)

            for j in range(len(out)):
                out[j] = out[j].to('cpu') #move to cpu for sklearn
                curr_y_pred = out[j].argmax(dim=1) #shape is (batch_size)
                y_pred[j].append(curr_y_pred)
                y_true[j].append(label[j].to('cpu'))

            if (i+1) % 200 == 0:
                print('Evaluation batch no: ', i+1)

    for j in range(len(y_true)):
        y_true[j] = torch.cat(y_true[j], dim=0)
        y_pred[j] = torch.cat(y_pred[j], dim=0)

    if was_training:
        model.train()

    mention_f1 = compute_mention_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
    negation_f1 = compute_negation_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
    uncertain_f1 = compute_uncertain_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
    positive_f1 = compute_positive_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))
    blank_f1 = compute_blank_f1(copy.deepcopy(y_true), copy.deepcopy(y_pred))

    weighted = []
    kappas = []
    for j in range(len(y_pred)):
        cond = CONDITIONS[j]
        avg = weighted_avg([negation_f1[j], uncertain_f1[j], positive_f1[j]], f1_weights[cond])
        weighted.append(avg)

        mat = confusion_matrix(y_true[j], y_pred[j])
        kappas.append(cohens_kappa(mat, return_results=False))

    res_dict = {'mention': mention_f1,
                'blank': blank_f1,
                'negation': negation_f1,
                'uncertain': uncertain_f1,
                'positive': positive_f1,
                'weighted': weighted,
                'kappa': kappas}

    if return_pred:
        return res_dict, y_pred, y_true
    else:
        return res_dict

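# Typical call (names are illustrative; dev_loader must yield dicts with
# 'imp', 'label' and 'len' keys, as consumed above):
#
#   f1_weights = get_weighted_f1_weights('path/to/train.csv')
#   metrics = evaluate(model, dev_loader, device, f1_weights)
#   print(np.mean(metrics['weighted']))
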
def test(model, checkpoint_path, test_ld, f1_weights):
    """Evaluate model on test set.
    @param model (nn.Module): labeler module
    @param checkpoint_path (string): location of saved model checkpoint
    @param test_ld (dataloader): dataloader for test set
    @param f1_weights (dictionary): maps conditions to f1 task weights
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) #to utilize multiple GPUs
    model = model.to(device)

    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])

    print("Doing evaluation on test set\n")
    metrics = evaluate(model, test_ld, device, f1_weights)
    weighted = metrics['weighted']
    kappas = metrics['kappa']

    for j in range(len(CONDITIONS)):
        print('%s kappa: %.3f' % (CONDITIONS[j], kappas[j]))
    print('average: %.3f' % np.mean(kappas))

    print()
    for j in range(len(CONDITIONS)):
        print('%s weighted_f1: %.3f' % (CONDITIONS[j], weighted[j]))
    print('average of weighted_f1: %.3f' % (np.mean(weighted)))

    print()
    for j in range(len(CONDITIONS)):
        print('%s blank_f1: %.3f, negation_f1: %.3f, uncertain_f1: %.3f, positive_f1: %.3f' % (CONDITIONS[j],
                                                                                                metrics['blank'][j],
                                                                                                metrics['negation'][j],
                                                                                                metrics['uncertain'][j],
                                                                                                metrics['positive'][j]))

    men_macro_avg = np.mean(metrics['mention'])
    neg_macro_avg = np.mean(metrics['negation'][:-1]) #No Finding has no negations
    unc_macro_avg = np.mean(metrics['uncertain'][:-2]) #No Finding, Support Devices have no uncertain labels in test set
    pos_macro_avg = np.mean(metrics['positive'])
    blank_macro_avg = np.mean(metrics['blank'])

    print("blank macro avg: %.3f, negation macro avg: %.3f, uncertain macro avg: %.3f, positive macro avg: %.3f" % (blank_macro_avg,
                                                                                                                    neg_macro_avg,
                                                                                                                    unc_macro_avg,
                                                                                                                    pos_macro_avg))
    print()
    for j in range(len(CONDITIONS)):
        print('%s mention_f1: %.3f' % (CONDITIONS[j], metrics['mention'][j]))
    print('mention macro avg: %.3f' % men_macro_avg)

def label_report_list(checkpoint_path, report_list):
    """ Evaluate model on list of reports.
    @param checkpoint_path (string): location of saved model checkpoint
    @param report_list (list): list of report impressions (string)
    """
    imp = pd.Series(report_list)
    imp = imp.str.strip()
    imp = imp.replace('\n', ' ', regex=True)
    imp = imp.replace(r'[0-9]\.', '', regex=True)
    imp = imp.replace(r'\s+', ' ', regex=True)
    imp = imp.str.strip()

    model = bert_labeler()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) #to utilize multiple GPUs
    model = model.to(device)
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    y_pred = []
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    new_imps = tokenize(imp, tokenizer)
    with torch.no_grad():
        for imp in new_imps:
            # run forward prop
            imp = torch.LongTensor(imp)
            source = imp.view(1, len(imp))

            attention = torch.ones(len(imp))
            attention = attention.view(1, len(imp))
            out = model(source.to(device), attention.to(device))

            # get predictions
            result = {}
            for j in range(len(out)):
                curr_y_pred = out[j].argmax(dim=1) #shape is (1)
                result[CONDITIONS[j]] = CLASS_MAPPING[curr_y_pred.item()]
            y_pred.append(result)
    return y_pred

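# Usage sketch (the checkpoint path is illustrative):
#
#   preds = label_report_list('path/to/checkpoint.pth',
#                             ['No acute cardiopulmonary process.'])
#   # preds[0] maps each condition in CONDITIONS to a CLASS_MAPPING value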