[248dc9]: / notebooks / evaluate.ipynb

Download this file

168 lines (167 with data), 5.3 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "21d5a4b0-d594-48c3-ad3e-57f1f1b5c29d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "75b28dce-720a-454a-a130-17c391b20c70",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_entities(record, key):\n",
    "    \"\"\"\n",
    "    Parse the entities stored under `key` in a record into a flat set.\n",
    "    \n",
    "    Args:\n",
    "    - record: A dictionary representing a single record in the data.\n",
    "    - key: The key to extract data from ('output' or 'prediction'). The value\n",
    "      is normally a JSON object encoded as a string; an already-decoded dict\n",
    "      is accepted as well.\n",
    "\n",
    "    Returns:\n",
    "    - A set containing the extracted entities. Scalar values (e.g. the drug\n",
    "      name) and the items of list values (e.g. the adverse events) are merged\n",
    "      into one set. Null and blank values are skipped: they are not entities.\n",
    "\n",
    "    Raises:\n",
    "    - json.JSONDecodeError: if the stored string is not valid JSON.\n",
    "    \"\"\"\n",
    "    raw = record[key]\n",
    "    # Values are normally JSON strings; accept an already-parsed dict too\n",
    "    entities = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw\n",
    "    \n",
    "    flattened_entities = set()\n",
    "    for value in entities.values():\n",
    "        # A list holds several entities (adverse events); a scalar is one entity (drug name)\n",
    "        items = value if isinstance(value, list) else [value]\n",
    "        for item in items:\n",
    "            # Skip missing values so that e.g. a null drug name in both the\n",
    "            # ground truth and the prediction is not counted as a true positive\n",
    "            if item is None or (isinstance(item, str) and not item.strip()):\n",
    "                continue\n",
    "            flattened_entities.add(item)\n",
    "    \n",
    "    return flattened_entities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "be48635c-7e26-488e-a2bc-0da52e08e752",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_precision_recall(data):\n",
    "    \"\"\"\n",
    "    Compute micro-averaged precision and recall over all records in the data.\n",
    "\n",
    "    Ground-truth and predicted entities are compared per record as sets, and\n",
    "    the TP/FP/FN counts are summed across records before the ratios are taken.\n",
    "\n",
    "    Args:\n",
    "    - data: A list of dictionaries, each containing 'output' (ground truth)\n",
    "      and 'prediction'.\n",
    "\n",
    "    Returns:\n",
    "    - (precision, recall): each ratio is 0 when its denominator is 0.\n",
    "    \"\"\"\n",
    "    tp = fp = fn = 0\n",
    "\n",
    "    for record in data:\n",
    "        expected = parse_entities(record, 'output')\n",
    "        predicted = parse_entities(record, 'prediction')\n",
    "\n",
    "        # Set algebra: |predicted - expected| == |predicted| - |matched|\n",
    "        matched = len(expected & predicted)\n",
    "        tp += matched\n",
    "        fp += len(predicted) - matched\n",
    "        fn += len(expected) - matched\n",
    "\n",
    "    n_predicted = tp + fp  # predicted entities (deduplicated per record)\n",
    "    n_expected = tp + fn   # ground-truth entities (deduplicated per record)\n",
    "    precision = tp / n_predicted if n_predicted else 0\n",
    "    recall = tp / n_expected if n_expected else 0\n",
    "\n",
    "    return precision, recall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "a16f2c86-2cff-4c4f-893d-60a9f06023b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One predictions file per (fine-tuning method, base model) pair\n",
    "prediction_files = [\n",
    "    f'./data/predictions-{model}-{tune}.json'\n",
    "    for tune in ('adapter', 'lora')        # Adapter runs first, then LoRA\n",
    "    for model in ('llama2', 'stablelm')    # Llama-2, Stable-LM\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "c95cb283-e1ac-45f5-aee3-b29fd932aba5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] llama2-adapter ----> Precision: 0.886 Recall: 0.891\n",
      "[INFO] stablelm-adapter ----> Precision: 0.854 Recall: 0.839\n",
      "[INFO] llama2-lora ----> Precision: 0.871 Recall: 0.851\n",
      "[INFO] stablelm-lora ----> Precision: 0.818 Recall: 0.828\n"
     ]
    }
   ],
   "source": [
    "for filename in prediction_files:\n",
    "    # Label is the file stem without its 'predictions-' prefix, e.g. 'llama2-adapter'.\n",
    "    # Splitting the whole path on '-' would break on hyphenated directory or model names.\n",
    "    label = Path(filename).stem.removeprefix('predictions-')\n",
    "    \n",
    "    # Load the predictions JSON data (JSON text is UTF-8)\n",
    "    with open(filename, 'r', encoding='utf-8') as file:\n",
    "        data = json.load(file)\n",
    "\n",
    "    precision, recall = calculate_precision_recall(data)\n",
    "    print(f\"[INFO] {label} ----> Precision: {round(precision, 3)} Recall: {round(recall, 3)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ff32790-3e03-464e-bd03-4cb29a244f37",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "scrape",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}