{
"cells": [
{
"cell_type": "code",
"execution_count": 18,
"id": "21d5a4b0-d594-48c3-ad3e-57f1f1b5c29d",
"metadata": {},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "75b28dce-720a-454a-a130-17c391b20c70",
"metadata": {},
"outputs": [],
"source": [
"def parse_entities(record, key):\n",
" \"\"\"\n",
" Parse entities from a record in the data.\n",
" \n",
" Args:\n",
" - record: A dictionary representing a single record in the data.\n",
" - key: The key to extract data from ('output' or 'prediction').\n",
"\n",
" Returns:\n",
" - A set containing the extracted entities.\n",
" \"\"\"\n",
" # Convert the string into a dictionary\n",
" entities = json.loads(record[key])\n",
" \n",
" # Initialize a set to store the entities\n",
" flattened_entities = set()\n",
" \n",
" # Extract the entities\n",
" for value in entities.values():\n",
" # Check if item is a list of adverse events\n",
" if isinstance(value, list):\n",
" flattened_entities.update(value)\n",
" # Parse drug names\n",
" else:\n",
" flattened_entities.add(value)\n",
" \n",
" return flattened_entities"
]
},
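{
"cell_type": "markdown",
"id": "4c1f8a2e-6b3d-4f9a-8e21-7d5c9b0a1e34",
"metadata": {},
"source": [
"A minimal sketch of `parse_entities` on a hypothetical record. The `drug` / `adverse_events` field names below are illustrative assumptions, not necessarily the schema used in the prediction files; each record only needs its `output` and `prediction` values to be JSON strings."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d7e6f5a-2c4b-4a1d-b3e8-0f1a2b3c4d5e",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical record: both 'output' (ground truth) and 'prediction' (model output)\n",
"# are JSON strings; the field names inside are illustrative only.\n",
"sample_record = {\n",
"    'output': json.dumps({'drug': 'aspirin', 'adverse_events': ['nausea', 'headache']}),\n",
"    'prediction': json.dumps({'drug': 'aspirin', 'adverse_events': ['nausea']}),\n",
"}\n",
"\n",
"# Both calls return flat sets of entity strings, ready for set arithmetic.\n",
"print(parse_entities(sample_record, 'output'))      # {'aspirin', 'nausea', 'headache'} (set order may vary)\n",
"print(parse_entities(sample_record, 'prediction'))  # {'aspirin', 'nausea'}"
]
},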
{
"cell_type": "code",
"execution_count": 20,
"id": "be48635c-7e26-488e-a2bc-0da52e08e752",
"metadata": {},
"outputs": [],
"source": [
"def calculate_precision_recall(data):\n",
" \"\"\"\n",
" Calculate precision and recall from the data.\n",
"\n",
" Args:\n",
" - data: A list of dictionaries, each containing 'output' and 'prediction'.\n",
"\n",
" Returns:\n",
" - precision: The precision of the predictions.\n",
" - recall: The recall of the predictions.\n",
" \"\"\"\n",
" # Initialize variables\n",
" true_positives = 0\n",
" false_positives = 0\n",
" false_negatives = 0\n",
" \n",
" # parse all the samples in the test dataset\n",
" for record in data:\n",
" # Extract ground truths\n",
" gt_entities = parse_entities(record, 'output')\n",
" # Extract predictions\n",
" pred_entities = parse_entities(record, 'prediction')\n",
" \n",
" # Calculate TP, FP, FN for each sample in test data\n",
" true_positives += len(gt_entities & pred_entities)\n",
" false_positives += len(pred_entities - gt_entities)\n",
" false_negatives += len(gt_entities - pred_entities)\n",
" \n",
" # Calculate Precision\n",
" precision = true_positives / (true_positives + false_positives) if true_positives + false_positives > 0 else 0\n",
" # Calculate Recall\n",
" recall = true_positives / (true_positives + false_negatives) if true_positives + false_negatives > 0 else 0\n",
"\n",
" return precision, recall"
]
},
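{
"cell_type": "markdown",
"id": "1b2c3d4e-5f6a-4b7c-8d9e-0a1b2c3d4e5f",
"metadata": {},
"source": [
"A quick sanity check of `calculate_precision_recall` on two hypothetical records (same illustrative schema as above). With TP = 4, FP = 1, FN = 1 in total, both precision and recall come out to 4/5 = 0.8."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e5d4c3b-2a1f-4e9d-8c7b-a0b1c2d3e4f5",
"metadata": {},
"outputs": [],
"source": [
"# Two hypothetical records (field names are illustrative, not the real schema).\n",
"toy_data = [\n",
"    {\n",
"        'output': json.dumps({'drug': 'aspirin', 'adverse_events': ['nausea', 'headache']}),\n",
"        'prediction': json.dumps({'drug': 'aspirin', 'adverse_events': ['nausea']}),\n",
"    },  # TP=2, FP=0, FN=1\n",
"    {\n",
"        'output': json.dumps({'drug': 'ibuprofen', 'adverse_events': ['dizziness']}),\n",
"        'prediction': json.dumps({'drug': 'ibuprofen', 'adverse_events': ['dizziness', 'rash']}),\n",
"    },  # TP=2, FP=1, FN=0\n",
"]\n",
"\n",
"# Totals: TP=4, FP=1, FN=1 -> precision = 4/5 = 0.8, recall = 4/5 = 0.8\n",
"precision, recall = calculate_precision_recall(toy_data)\n",
"print(f\"Precision: {precision:.3f}  Recall: {recall:.3f}\")"
]
},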
{
"cell_type": "code",
"execution_count": 23,
"id": "a16f2c86-2cff-4c4f-893d-60a9f06023b8",
"metadata": {},
"outputs": [],
"source": [
"prediction_files = [\n",
" './data/predictions-llama2-adapter.json', # Llama-2 Adapter\n",
" './data/predictions-stablelm-adapter.json', # Stable-LM Adapter\n",
" './data/predictions-llama2-lora.json', # Llama-2 Lora\n",
" './data/predictions-stablelm-lora.json', # Stable-LM Lora\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "c95cb283-e1ac-45f5-aee3-b29fd932aba5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO] llama2-adapter ----> Precision: 0.886 Recall: 0.891\n",
"[INFO] stablelm-adapter ----> Precision: 0.854 Recall: 0.839\n",
"[INFO] llama2-lora ----> Precision: 0.871 Recall: 0.851\n",
"[INFO] stablelm-lora ----> Precision: 0.818 Recall: 0.828\n"
]
}
],
"source": [
"for filename in prediction_files:\n",
" # Get model name and tune type\n",
" file_components = filename.split('-')\n",
" \n",
" # Load the predcitions JSON data\n",
" with open(filename, 'r') as file:\n",
" data = json.load(file)\n",
"\n",
" precision, recall = calculate_precision_recall(data)\n",
" print(f\"[INFO] {file_components[1]}-{file_components[2].split('.')[0]} ----> Precision: {round(precision,3)} Recall: {round(recall,3)}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "scrape",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}