src/evaluation/explainability/bert.py
# Base Dependencies
# -----------------
from pathlib import Path
from os.path import join

# Local Dependencies
# ------------------
from ml_models.bert import ClinicalBERTTokenizer

# 3rd-Party Dependencies
# ----------------------
import torch
from datasets import load_from_disk

from transformers import BertForSequenceClassification, AutoConfig
from bertviz import head_view, model_view

# Constants
# ---------
from constants import CHECKPOINTS_CACHE_DIR, DDI_HF_TEST_PATH

# number of prediction changes to collect
N_CHANGES = 50


def run():
    """Stores the BertViz attention maps of the Clinical BERT model for the first
    N_CHANGES examples of the DDI test set whose prediction changes from incorrect
    (initial checkpoint) to correct (final checkpoint)."""
    init_model_path = Path(
        join(CHECKPOINTS_CACHE_DIR, "al", "bert", "ddi", "model_5.ck")
    )
    end_model_path = Path(
        join(CHECKPOINTS_CACHE_DIR, "al", "bert", "ddi", "model_6.ck")
    )

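    # output folders for the BertViz head-view and model-view HTML files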
    head_views_output_folder = Path(
        join("results", "ddi", "bert", "interpretability", "head_views")
    )
    model_views_output_folder = Path(
        join("results", "ddi", "bert", "interpretability", "model_views")
    )
    # create the output folders if they do not exist yet
    head_views_output_folder.mkdir(parents=True, exist_ok=True)
    model_views_output_folder.mkdir(parents=True, exist_ok=True)

    # load the tokenizer and the DDI test split
    tokenizer = ClinicalBERTTokenizer()
    test_dataset = load_from_disk(Path(join(DDI_HF_TEST_PATH, "bert")))

    sentences = test_dataset["sentence"]
    labels = test_dataset["label"]

    # load the initial and final BERT classifiers with attention outputs enabled
    init_config = AutoConfig.from_pretrained(
        pretrained_model_name_or_path=init_model_path
    )
    init_config.output_attentions = True

    end_config = AutoConfig.from_pretrained(
        pretrained_model_name_or_path=end_model_path
    )
    end_config.output_attentions = True

    init_model = BertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=init_model_path, config=init_config
    )
    end_model = BertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=end_model_path, config=end_config
    )

    # evaluation mode: disable dropout so the attention maps are deterministic
    init_model.eval()
    end_model.eval()

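    # textual summary of the examples whose prediction changed between checkpoints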
    changes = []

    for index, (sentence, label) in enumerate(zip(sentences, labels)):
        # consider only examples with a positive (non-zero) label
        if label > 0:
            inputs = tokenizer.encode(sentence, return_tensors="pt")
            with torch.no_grad():
                init_outputs = init_model(inputs)
                end_outputs = end_model(inputs)

            init_y_pred = torch.argmax(init_outputs["logits"]).item()
            end_y_pred = torch.argmax(end_outputs["logits"]).item()

            # keep examples the final model classifies correctly but the initial one did not
            if end_y_pred == label and init_y_pred != label:

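                # render the attention maps of both checkpoints with BertViz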
                tokens = tokenizer.convert_ids_to_tokens(inputs[0])
                init_head_view = head_view(
                    init_outputs["attentions"], tokens, html_action="return"
                )
                init_model_view = model_view(
                    init_outputs["attentions"], tokens, html_action="return"
                )
                end_head_view = head_view(
                    end_outputs["attentions"], tokens, html_action="return"
                )
                end_model_view = model_view(
                    end_outputs["attentions"], tokens, html_action="return"
                )

                # save the HTML objects to file
                file_path = Path(
                    join(head_views_output_folder, str(index) + "_init.html")
                )
                with open(file_path, "w") as f:
                    f.write(init_head_view.data)

                file_path = Path(
                    join(head_views_output_folder, str(index) + "_end.html")
                )
                with open(file_path, "w") as f:
                    f.write(end_head_view.data)

                file_path = Path(
                    join(model_views_output_folder, str(index) + "_init.html")
                )
                with open(file_path, "w") as f:
                    f.write(init_model_view.data)

                file_path = Path(
                    join(model_views_output_folder, str(index) + "_end.html")
                )
                with open(file_path, "w") as f:
                    f.write(end_model_view.data)

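                # record this prediction change in the summary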
                changes.append(
                    f"Index: {index} Initial prediction: {init_y_pred} Final prediction: {end_y_pred}"
                )
                # stop once N_CHANGES prediction changes have been collected
                if len(changes) == N_CHANGES:
                    break

    # save list of changes to file
    with open(
        join("results", "ddi", "bert", "interpretability", "changes.txt"), "w"
    ) as f:
        for item in changes:
            f.write(f"{item}\n")