notebooks/evaluate.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "21d5a4b0-d594-48c3-ad3e-57f1f1b5c29d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "75b28dce-720a-454a-a130-17c391b20c70",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_entities(record, key):\n",
    "    \"\"\"\n",
    "    Parse entities from a record in the data.\n",
    "    \n",
    "    Args:\n",
    "    - record: A dictionary representing a single record in the data.\n",
    "    - key: The key to extract data from ('output' or 'prediction').\n",
    "\n",
    "    Returns:\n",
    "    - A set containing the extracted entities.\n",
    "    \"\"\"\n",
    "    # Convert the string into a dictionary\n",
    "    entities = json.loads(record[key])\n",
    "    \n",
    "    # Initialize a set to store the entities\n",
    "    flattened_entities = set()\n",
    "    \n",
    "    # Extract the entities\n",
    "    for value in entities.values():\n",
    "        # Check if the value is a list of adverse events\n",
    "        if isinstance(value, list):\n",
    "            flattened_entities.update(value)\n",
    "        # Otherwise the value is a single drug name\n",
    "        else:\n",
    "            flattened_entities.add(value)\n",
    "    \n",
    "    return flattened_entities"
   ]
  },
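  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6a9d12-0c4e-4b7a-9e1d-5a2b8c7d4e10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sanity check for parse_entities (not part of the original evaluation).\n",
    "# The record below is a made-up example: it assumes the 'output' field holds a JSON\n",
    "# string whose scalar value is the drug name and whose list value is the adverse\n",
    "# events, which is the shape parse_entities expects.\n",
    "sample_record = {\n",
    "    'output': json.dumps({'drug_name': 'aspirin', 'adverse_events': ['nausea', 'dizziness']})\n",
    "}\n",
    "\n",
    "# Expected result: {'aspirin', 'nausea', 'dizziness'}\n",
    "parse_entities(sample_record, 'output')"
   ]
  },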
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "be48635c-7e26-488e-a2bc-0da52e08e752",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_precision_recall(data):\n",
    "    \"\"\"\n",
    "    Calculate precision and recall from the data.\n",
    "\n",
    "    Args:\n",
    "    - data: A list of dictionaries, each containing 'output' and 'prediction'.\n",
    "\n",
    "    Returns:\n",
    "    - precision: The precision of the predictions.\n",
    "    - recall: The recall of the predictions.\n",
    "    \"\"\"\n",
    "    # Initialize variables\n",
    "    true_positives = 0\n",
    "    false_positives = 0\n",
    "    false_negatives = 0\n",
    "    \n",
    "    # Parse all the samples in the test dataset\n",
    "    for record in data:\n",
    "        # Extract ground truths\n",
    "        gt_entities = parse_entities(record, 'output')\n",
    "        # Extract predictions\n",
    "        pred_entities = parse_entities(record, 'prediction')\n",
    "    \n",
    "        # Calculate TP, FP, FN for each sample in test data\n",
    "        true_positives += len(gt_entities & pred_entities)\n",
    "        false_positives += len(pred_entities - gt_entities)\n",
    "        false_negatives += len(gt_entities - pred_entities)\n",
    "    \n",
    "    # Calculate Precision\n",
    "    precision = true_positives / (true_positives + false_positives) if true_positives + false_positives > 0 else 0\n",
    "    # Calculate Recall\n",
    "    recall = true_positives / (true_positives + false_negatives) if true_positives + false_negatives > 0 else 0\n",
    "\n",
    "    return precision, recall"
   ]
  },
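  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b7e54c1-2d3f-4a68-b0c9-1e8f7a6d5c42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative worked example for calculate_precision_recall (made-up data, not the real test set).\n",
    "# Ground truth entities: {aspirin, nausea, headache}; predicted entities: {aspirin, nausea, fatigue}.\n",
    "# That gives TP = 2, FP = 1, FN = 1, so precision = recall = 2/3 (about 0.667).\n",
    "toy_data = [\n",
    "    {\n",
    "        'output': json.dumps({'drug_name': 'aspirin', 'adverse_events': ['nausea', 'headache']}),\n",
    "        'prediction': json.dumps({'drug_name': 'aspirin', 'adverse_events': ['nausea', 'fatigue']}),\n",
    "    }\n",
    "]\n",
    "\n",
    "calculate_precision_recall(toy_data)  # Expected: (0.666..., 0.666...)"
   ]
  },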
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "a16f2c86-2cff-4c4f-893d-60a9f06023b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction_files = [\n",
    "                './data/predictions-llama2-adapter.json',   # Llama-2 Adapter\n",
    "                './data/predictions-stablelm-adapter.json', # Stable-LM Adapter\n",
    "                './data/predictions-llama2-lora.json',      # Llama-2 LoRA\n",
    "                './data/predictions-stablelm-lora.json',    # Stable-LM LoRA\n",
    "                ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "c95cb283-e1ac-45f5-aee3-b29fd932aba5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] llama2-adapter ----> Precision: 0.886 Recall: 0.891\n",
      "[INFO] stablelm-adapter ----> Precision: 0.854 Recall: 0.839\n",
      "[INFO] llama2-lora ----> Precision: 0.871 Recall: 0.851\n",
      "[INFO] stablelm-lora ----> Precision: 0.818 Recall: 0.828\n"
     ]
    }
   ],
   "source": [
    "for filename in prediction_files:\n",
    "    # Get model name and tune type\n",
    "    file_components = filename.split('-')\n",
    "    \n",
    "    # Load the predictions JSON data\n",
    "    with open(filename, 'r') as file:\n",
    "        data = json.load(file)\n",
    "\n",
    "    precision, recall = calculate_precision_recall(data)\n",
    "    print(f\"[INFO] {file_components[1]}-{file_components[2].split('.')[0]} ----> Precision: {round(precision,3)} Recall: {round(recall,3)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ff32790-3e03-464e-bd03-4cb29a244f37",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "scrape",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}