Switch to side-by-side view

--- a
+++ b/src/evaluation/explainability/bert.py
@@ -0,0 +1,129 @@
+# Base Dependencies
+# -----------------
+from pathlib import Path
+from os.path import join
+
+# Local Dependencies
+# ------------------
+from ml_models.bert import ClinicalBERTTokenizer
+
+# 3rd-Party Dependencies
+# ----------------------
+import torch
+from datasets import load_from_disk
+
+from transformers import BertForSequenceClassification, AutoConfig
+from bertviz import head_view, model_view
+
+# Constants
+# ---------
+from constants import CHECKPOINTS_CACHE_DIR, DDI_HF_TEST_PATH
+
+N_CHANGES = 50
+
+
+def run():
+    """This scripts stores the attention maps of the Clinical BERT model for the first 50 changes in the DDI test set."""
+    init_model_path = Path(
+        join(CHECKPOINTS_CACHE_DIR, "al", "bert", "ddi", "model_5.ck")
+    )
+    end_model_path = Path(
+        join(CHECKPOINTS_CACHE_DIR, "al", "bert", "ddi", "model_6.ck")
+    )
+
+    head_views_output_folder = Path(
+        join("results", "ddi", "bert", "interpretability", "head_views")
+    )
+    model_views_output_folder = Path(
+        join("results", "ddi", "bert", "interpretability", "model_views")
+    )
+
+    # load dataset and tokenize
+    tokenizer = ClinicalBERTTokenizer()
+    test_dataset = load_from_disk(Path(join(DDI_HF_TEST_PATH, "bert")))
+
+    sentences = test_dataset["sentence"]
+    labels = test_dataset["label"]
+
+    # load BERT models
+    init_config = AutoConfig.from_pretrained(
+        pretrained_model_name_or_path=init_model_path
+    )
+    init_config.output_attentions = True
+
+    end_config = AutoConfig.from_pretrained(
+        pretrained_model_name_or_path=end_model_path
+    )
+    end_config.output_attentions = True
+
+    init_model = BertForSequenceClassification.from_pretrained(
+        pretrained_model_name_or_path=init_model_path, config=init_config
+    )
+    end_model = BertForSequenceClassification.from_pretrained(
+        pretrained_model_name_or_path=end_model_path, config=end_config
+    )
+
+    changes = []
+
+    for index, (sentence, label) in enumerate(zip(sentences, labels)):
+        if label > 0:
+            inputs = tokenizer.encode(sentence, return_tensors="pt")
+            init_outputs = init_model(inputs)
+            end_outputs = end_model(inputs)
+
+            init_y_pred = torch.argmax(init_outputs["logits"])
+            end_y_pred = torch.argmax(end_outputs["logits"])
+
+            if end_y_pred == label and init_y_pred != label:
+
+                tokens = tokenizer.convert_ids_to_tokens(inputs[0])
+                init_head_view = head_view(
+                    init_outputs["attentions"], tokens, html_action="return"
+                )
+                init_model_view = model_view(
+                    init_outputs["attentions"], tokens, html_action="return"
+                )
+                end_head_view = head_view(
+                    end_outputs["attentions"], tokens, html_action="return"
+                )
+                end_model_view = model_view(
+                    end_outputs["attentions"], tokens, html_action="return"
+                )
+
+                # Save the HTMLs object to file
+                file_path = Path(
+                    join(head_views_output_folder, str(index) + "_init.html")
+                )
+                with open(file_path, "w") as f:
+                    f.write(init_head_view.data)
+
+                file_path = Path(
+                    join(head_views_output_folder, str(index) + "_end.html")
+                )
+                with open(file_path, "w") as f:
+                    f.write(end_head_view.data)
+
+                file_path = Path(
+                    join(model_views_output_folder, str(index) + "_init.html")
+                )
+                with open(file_path, "w") as f:
+                    f.write(init_model_view.data)
+
+                file_path = Path(
+                    join(model_views_output_folder, str(index) + "_end.html")
+                )
+                with open(file_path, "w") as f:
+                    f.write(end_model_view.data)
+
+                changes.append(
+                    f"Index: {str(index)} Initial prediction: {init_y_pred} Final Prediction: {end_y_pred}"
+                )
+                if len(changes) == N_CHANGES:
+                    break
+
+    # save list of changes to file
+    with open(
+        join("results", "ddi", "bert", "interpretability", "changes.txt"), "w"
+    ) as f:
+        for item in changes:
+            f.write(f"{item}\n")