Diff of /notebooks/evaluate.ipynb [000000] .. [248dc9]

Switch to side-by-side view

--- a
+++ b/notebooks/evaluate.ipynb
@@ -0,0 +1,167 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "21d5a4b0-d594-48c3-ad3e-57f1f1b5c29d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "75b28dce-720a-454a-a130-17c391b20c70",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def parse_entities(record, key):\n",
+    "    \"\"\"\n",
+    "    Parse entities from a record in the data.\n",
+    "    \n",
+    "    Args:\n",
+    "    - record: A dictionary representing a single record in the data.\n",
+    "    - key: The key to extract data from ('output' or 'prediction').\n",
+    "\n",
+    "    Returns:\n",
+    "    - A set containing the extracted entities.\n",
+    "    \"\"\"\n",
+    "    # Convert the string into a dictionary\n",
+    "    entities = json.loads(record[key])\n",
+    "    \n",
+    "    # Initialize a set to store the entities\n",
+    "    flattened_entities = set()\n",
+    "    \n",
+    "    # Extract the entities\n",
+    "    for value in entities.values():\n",
+    "        # Check if item is a list of adverse events\n",
+    "        if isinstance(value, list):\n",
+    "            flattened_entities.update(value)\n",
+    "        # Parse drug names\n",
+    "        else:\n",
+    "            flattened_entities.add(value)\n",
+    "    \n",
+    "    return flattened_entities"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "be48635c-7e26-488e-a2bc-0da52e08e752",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def calculate_precision_recall(data):\n",
+    "    \"\"\"\n",
+    "    Calculate precision and recall from the data.\n",
+    "\n",
+    "    Args:\n",
+    "    - data: A list of dictionaries, each containing 'output' and 'prediction'.\n",
+    "\n",
+    "    Returns:\n",
+    "    - precision: The precision of the predictions.\n",
+    "    - recall: The recall of the predictions.\n",
+    "    \"\"\"\n",
+    "    # Initialize variables\n",
+    "    true_positives = 0\n",
+    "    false_positives = 0\n",
+    "    false_negatives = 0\n",
+    "    \n",
+    "    # parse all the samples in the test dataset\n",
+    "    for record in data:\n",
+    "        # Extract ground truths\n",
+    "        gt_entities = parse_entities(record, 'output')\n",
+    "        # Extract predictions\n",
+    "        pred_entities = parse_entities(record, 'prediction')\n",
+    "    \n",
+    "        # Calculate TP, FP, FN for each sample in test data\n",
+    "        true_positives += len(gt_entities & pred_entities)\n",
+    "        false_positives += len(pred_entities - gt_entities)\n",
+    "        false_negatives += len(gt_entities - pred_entities)\n",
+    "    \n",
+    "    # Calculate Precision\n",
+    "    precision = true_positives / (true_positives + false_positives) if true_positives + false_positives > 0 else 0\n",
+    "    # Calculate Recall\n",
+    "    recall = true_positives / (true_positives + false_negatives) if true_positives + false_negatives > 0 else 0\n",
+    "\n",
+    "    return precision, recall"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "a16f2c86-2cff-4c4f-893d-60a9f06023b8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prediction_files = [\n",
+    "                './data/predictions-llama2-adapter.json',   # Llama-2 Adapter\n",
+    "                './data/predictions-stablelm-adapter.json', # Stable-LM Adapter\n",
+    "                './data/predictions-llama2-lora.json',      # Llama-2 Lora\n",
+    "                './data/predictions-stablelm-lora.json',    # Stable-LM Lora\n",
+    "                ]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "c95cb283-e1ac-45f5-aee3-b29fd932aba5",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[INFO] llama2-adapter ----> Precision: 0.886 Recall: 0.891\n",
+      "[INFO] stablelm-adapter ----> Precision: 0.854 Recall: 0.839\n",
+      "[INFO] llama2-lora ----> Precision: 0.871 Recall: 0.851\n",
+      "[INFO] stablelm-lora ----> Precision: 0.818 Recall: 0.828\n"
+     ]
+    }
+   ],
+   "source": [
+    "for filename in prediction_files:\n",
+    "    # Get model name and tune type\n",
+    "    file_components = filename.split('-')\n",
+    "    \n",
+    "    # Load the predcitions JSON data\n",
+    "    with open(filename, 'r') as file:\n",
+    "        data = json.load(file)\n",
+    "\n",
+    "    precision, recall = calculate_precision_recall(data)\n",
+    "    print(f\"[INFO] {file_components[1]}-{file_components[2].split('.')[0]} ----> Precision: {round(precision,3)} Recall: {round(recall,3)}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7ff32790-3e03-464e-bd03-4cb29a244f37",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "scrape",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}