# Base Dependencies
# -----------------
from pathlib import Path

# Local Dependencies
# ------------------
from ml_models.bert import ClinicalBERTTokenizer

# 3rd-Party Dependencies
# ----------------------
import torch
from datasets import load_from_disk

from transformers import BertForSequenceClassification, AutoConfig
from bertviz import head_view, model_view

# Constants
# ---------
from constants import CHECKPOINTS_CACHE_DIR, DDI_HF_TEST_PATH

# Number of wrong->right prediction changes to collect before stopping
N_CHANGES = 50


def _load_model_with_attentions(model_path: Path) -> BertForSequenceClassification:
    """Load a fine-tuned BERT classifier configured to return attention maps.

    Args:
        model_path: checkpoint directory of the fine-tuned model.

    Returns:
        The model in eval mode with `output_attentions` enabled, so every
        forward pass exposes an "attentions" entry for bertviz.
    """
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_path)
    config.output_attentions = True
    model = BertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=model_path, config=config
    )
    # eval(): make sure dropout is off so predictions/attentions are deterministic
    model.eval()
    return model


def _save_html(file_path: Path, html_obj) -> None:
    """Write a bertviz HTML object (returned with html_action="return") to disk."""
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(html_obj.data)


def run():
    """Store the attention maps of the Clinical BERT model for the first
    `N_CHANGES` positive DDI test-set examples whose prediction flipped from
    wrong (initial checkpoint, model_5) to correct (final checkpoint, model_6).

    For each such example, four HTML files are written (init/end head view,
    init/end model view), and a summary line is appended to `changes.txt`.
    """
    checkpoint_folder = Path(CHECKPOINTS_CACHE_DIR) / "al" / "bert" / "ddi"
    init_model_path = checkpoint_folder / "model_5.ck"
    end_model_path = checkpoint_folder / "model_6.ck"

    results_folder = Path("results") / "ddi" / "bert" / "interpretability"
    head_views_output_folder = results_folder / "head_views"
    model_views_output_folder = results_folder / "model_views"

    # BUGFIX: create the output folders up-front; the original crashed with
    # FileNotFoundError on the first open(..., "w") when they did not exist.
    head_views_output_folder.mkdir(parents=True, exist_ok=True)
    model_views_output_folder.mkdir(parents=True, exist_ok=True)

    # load dataset and tokenizer
    tokenizer = ClinicalBERTTokenizer()
    test_dataset = load_from_disk(Path(DDI_HF_TEST_PATH) / "bert")
    sentences = test_dataset["sentence"]
    labels = test_dataset["label"]

    # load the initial and final BERT checkpoints with attention outputs enabled
    init_model = _load_model_with_attentions(init_model_path)
    end_model = _load_model_with_attentions(end_model_path)

    changes = []

    for index, (sentence, label) in enumerate(zip(sentences, labels)):
        # only positive examples (label > 0) are of interest — skip the rest
        if label <= 0:
            continue

        inputs = tokenizer.encode(sentence, return_tensors="pt")
        # pure inference: no_grad avoids building autograd graphs
        with torch.no_grad():
            init_outputs = init_model(inputs)
            end_outputs = end_model(inputs)

        init_y_pred = torch.argmax(init_outputs["logits"])
        end_y_pred = torch.argmax(end_outputs["logits"])

        # keep only examples the final model fixed (initially wrong, now right)
        if end_y_pred == label and init_y_pred != label:
            tokens = tokenizer.convert_ids_to_tokens(inputs[0])

            # render and persist the four attention visualisations
            _save_html(
                head_views_output_folder / f"{index}_init.html",
                head_view(init_outputs["attentions"], tokens, html_action="return"),
            )
            _save_html(
                head_views_output_folder / f"{index}_end.html",
                head_view(end_outputs["attentions"], tokens, html_action="return"),
            )
            _save_html(
                model_views_output_folder / f"{index}_init.html",
                model_view(init_outputs["attentions"], tokens, html_action="return"),
            )
            _save_html(
                model_views_output_folder / f"{index}_end.html",
                model_view(end_outputs["attentions"], tokens, html_action="return"),
            )

            changes.append(
                f"Index: {str(index)} Initial prediction: {init_y_pred} Final Prediction: {end_y_pred}"
            )
            if len(changes) == N_CHANGES:
                break

    # save list of changes to file
    with open(results_folder / "changes.txt", "w", encoding="utf-8") as f:
        for item in changes:
            f.write(f"{item}\n")