|
a |
|
b/src/evaluation/explainability/bert.py |
|
|
1 |
# Base Dependencies |
|
|
2 |
# ----------------- |
|
|
3 |
from pathlib import Path |
|
|
4 |
from os.path import join |
|
|
5 |
|
|
|
6 |
# Local Dependencies |
|
|
7 |
# ------------------ |
|
|
8 |
from ml_models.bert import ClinicalBERTTokenizer |
|
|
9 |
|
|
|
10 |
# 3rd-Party Dependencies |
|
|
11 |
# ---------------------- |
|
|
12 |
import torch |
|
|
13 |
from datasets import load_from_disk |
|
|
14 |
|
|
|
15 |
from transformers import BertForSequenceClassification, AutoConfig |
|
|
16 |
from bertviz import head_view, model_view |
|
|
17 |
|
|
|
18 |
# Constants |
|
|
19 |
# --------- |
|
|
20 |
from constants import CHECKPOINTS_CACHE_DIR, DDI_HF_TEST_PATH |
|
|
21 |
|
|
|
22 |
# Maximum number of prediction changes exported by `run` before it stops
# scanning the test set.
N_CHANGES = 50
|
|
23 |
|
|
|
24 |
|
|
|
25 |
def run():
    """Store BERT attention maps for test sentences whose prediction was corrected.

    Two fine-tuned checkpoints (``model_5.ck`` and ``model_6.ck`` under
    ``CHECKPOINTS_CACHE_DIR/al/bert/ddi``) are run over the DDI test set.
    For every sentence with a label greater than 0 that the first checkpoint
    misclassifies and the second one classifies correctly, the BertViz head
    view and model view of both checkpoints are written as HTML files to
    ``results/ddi/bert/interpretability/{head_views,model_views}`` using the
    names ``<index>_init.html`` and ``<index>_end.html``.

    Scanning stops after ``N_CHANGES`` such sentences. A one-line summary per
    exported sentence is written to
    ``results/ddi/bert/interpretability/changes.txt``.

    Raises:
        OSError: if the output folders or files cannot be created or written.
    """
    checkpoints_folder = join(CHECKPOINTS_CACHE_DIR, "al", "bert", "ddi")
    init_model_path = Path(join(checkpoints_folder, "model_5.ck"))
    end_model_path = Path(join(checkpoints_folder, "model_6.ck"))

    output_folder = Path(join("results", "ddi", "bert", "interpretability"))
    head_views_output_folder = Path(join(output_folder, "head_views"))
    model_views_output_folder = Path(join(output_folder, "model_views"))

    # Create the output tree up front: open(..., "w") does not create missing
    # parent folders and would otherwise fail on the first export.
    head_views_output_folder.mkdir(parents=True, exist_ok=True)
    model_views_output_folder.mkdir(parents=True, exist_ok=True)

    # load dataset and tokenizer
    tokenizer = ClinicalBERTTokenizer()
    test_dataset = load_from_disk(Path(join(DDI_HF_TEST_PATH, "bert")))

    sentences = test_dataset["sentence"]
    labels = test_dataset["label"]

    # load BERT models, configured to return their attention weights
    init_config = AutoConfig.from_pretrained(
        pretrained_model_name_or_path=init_model_path
    )
    init_config.output_attentions = True

    end_config = AutoConfig.from_pretrained(
        pretrained_model_name_or_path=end_model_path
    )
    end_config.output_attentions = True

    init_model = BertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=init_model_path, config=init_config
    )
    end_model = BertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=end_model_path, config=end_config
    )

    changes = []

    for index, (sentence, label) in enumerate(zip(sentences, labels)):
        # Only sentences with a label above 0 are inspected.
        # NOTE(review): presumably label 0 is the negative class - confirm.
        if label <= 0:
            continue

        # NOTE(review): no truncation is requested here; confirm that the
        # tokenizer never yields more tokens than the models accept.
        inputs = tokenizer.encode(sentence, return_tensors="pt")

        # Inference only: skip autograd bookkeeping for both forward passes.
        with torch.no_grad():
            init_outputs = init_model(inputs)
            end_outputs = end_model(inputs)

        # Convert to plain ints so the report shows "3" rather than
        # "tensor(3)". One sentence is processed at a time, so the argmax
        # over all logits is the predicted class index.
        init_y_pred = int(torch.argmax(init_outputs["logits"]))
        end_y_pred = int(torch.argmax(end_outputs["logits"]))

        # Keep only predictions that went from wrong to right.
        if end_y_pred != label or init_y_pred == label:
            continue

        tokens = tokenizer.convert_ids_to_tokens(inputs[0])

        # html_action="return" makes BertViz hand back an HTML object
        # (its markup is in `.data`) instead of displaying it.
        init_head_view = head_view(
            init_outputs["attentions"], tokens, html_action="return"
        )
        init_model_view = model_view(
            init_outputs["attentions"], tokens, html_action="return"
        )
        end_head_view = head_view(
            end_outputs["attentions"], tokens, html_action="return"
        )
        end_model_view = model_view(
            end_outputs["attentions"], tokens, html_action="return"
        )

        # Save the HTML objects to file
        exports = (
            (head_views_output_folder, "_init.html", init_head_view),
            (head_views_output_folder, "_end.html", end_head_view),
            (model_views_output_folder, "_init.html", init_model_view),
            (model_views_output_folder, "_end.html", end_model_view),
        )
        for folder, suffix, view in exports:
            _write_text_file(Path(join(folder, str(index) + suffix)), view.data)

        changes.append(
            f"Index: {str(index)} Initial prediction: {init_y_pred} Final Prediction: {end_y_pred}"
        )
        if len(changes) == N_CHANGES:
            break

    # save list of changes to file
    _write_text_file(
        Path(join(output_folder, "changes.txt")),
        "".join(f"{item}\n" for item in changes),
    )


def _write_text_file(file_path, text):
    """Write `text` to `file_path` as UTF-8, replacing any existing file."""
    # Explicit UTF-8 so non-ASCII tokens do not depend on the locale encoding.
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(text)