--- a +++ b/notebooks/evaluate.ipynb @@ -0,0 +1,167 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 18, + "id": "21d5a4b0-d594-48c3-ad3e-57f1f1b5c29d", + "metadata": {}, + "outputs": [], + "source": [ + "import json" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "75b28dce-720a-454a-a130-17c391b20c70", + "metadata": {}, + "outputs": [], + "source": [ + "def parse_entities(record, key):\n", + " \"\"\"\n", + " Parse entities from a record in the data.\n", + " \n", + " Args:\n", + " - record: A dictionary representing a single record in the data.\n", + " - key: The key to extract data from ('output' or 'prediction').\n", + "\n", + " Returns:\n", + " - A set containing the extracted entities.\n", + " \"\"\"\n", + " # Convert the string into a dictionary\n", + " entities = json.loads(record[key])\n", + " \n", + " # Initialize a set to store the entities\n", + " flattened_entities = set()\n", + " \n", + " # Extract the entities\n", + " for value in entities.values():\n", + " # Check if item is a list of adverse events\n", + " if isinstance(value, list):\n", + " flattened_entities.update(value)\n", + " # Parse drug names\n", + " else:\n", + " flattened_entities.add(value)\n", + " \n", + " return flattened_entities" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "be48635c-7e26-488e-a2bc-0da52e08e752", + "metadata": {}, + "outputs": [], + "source": [ + "def calculate_precision_recall(data):\n", + " \"\"\"\n", + " Calculate precision and recall from the data.\n", + "\n", + " Args:\n", + " - data: A list of dictionaries, each containing 'output' and 'prediction'.\n", + "\n", + " Returns:\n", + " - precision: The precision of the predictions.\n", + " - recall: The recall of the predictions.\n", + " \"\"\"\n", + " # Initialize variables\n", + " true_positives = 0\n", + " false_positives = 0\n", + " false_negatives = 0\n", + " \n", + " # parse all the samples in the test dataset\n", + " for record in data:\n", + " # Extract ground truths\n", + " gt_entities = parse_entities(record, 'output')\n", + " # Extract predictions\n", + " pred_entities = parse_entities(record, 'prediction')\n", + " \n", + " # Calculate TP, FP, FN for each sample in test data\n", + " true_positives += len(gt_entities & pred_entities)\n", + " false_positives += len(pred_entities - gt_entities)\n", + " false_negatives += len(gt_entities - pred_entities)\n", + " \n", + " # Calculate Precision\n", + " precision = true_positives / (true_positives + false_positives) if true_positives + false_positives > 0 else 0\n", + " # Calculate Recall\n", + " recall = true_positives / (true_positives + false_negatives) if true_positives + false_negatives > 0 else 0\n", + "\n", + " return precision, recall" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a16f2c86-2cff-4c4f-893d-60a9f06023b8", + "metadata": {}, + "outputs": [], + "source": [ + "prediction_files = [\n", + " './data/predictions-llama2-adapter.json', # Llama-2 Adapter\n", + " './data/predictions-stablelm-adapter.json', # Stable-LM Adapter\n", + " './data/predictions-llama2-lora.json', # Llama-2 Lora\n", + " './data/predictions-stablelm-lora.json', # Stable-LM Lora\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "c95cb283-e1ac-45f5-aee3-b29fd932aba5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] llama2-adapter ----> Precision: 0.886 Recall: 0.891\n", + "[INFO] stablelm-adapter ----> Precision: 0.854 Recall: 0.839\n", + "[INFO] llama2-lora ----> Precision: 0.871 Recall: 0.851\n", + "[INFO] stablelm-lora ----> Precision: 0.818 Recall: 0.828\n" + ] + } + ], + "source": [ + "for filename in prediction_files:\n", + " # Get model name and tune type\n", + " file_components = filename.split('-')\n", + " \n", + " # Load the predcitions JSON data\n", + " with open(filename, 'r') as file:\n", + " data = json.load(file)\n", + "\n", + " precision, recall = calculate_precision_recall(data)\n", + " print(f\"[INFO] {file_components[1]}-{file_components[2].split('.')[0]} ----> Precision: {round(precision,3)} Recall: {round(recall,3)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ff32790-3e03-464e-bd03-4cb29a244f37", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "scrape", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}