|
a |
|
b/scripts/evaluate.py |
|
|
1 |
import json |
|
|
2 |
|
|
|
3 |
def parse_entities(record, key):
    """
    Extract a flat set of entities from one record.

    Args:
    - record: A dictionary representing a single record in the data.
    - key: The key to extract data from ('output' or 'prediction').

    Returns:
    - A set containing the extracted entities.
    """
    # The field is stored as a JSON-encoded dict; decode it first.
    parsed = json.loads(record[key])

    # Flatten every value of the decoded dict into one set:
    # list-valued entries (e.g. adverse events) contribute each element,
    # scalar entries (e.g. a drug name) contribute the value itself.
    collected = set()
    for item in parsed.values():
        if isinstance(item, list):
            collected.update(item)
        else:
            collected.add(item)

    return collected
|
|
30 |
|
|
|
31 |
def calculate_precision_recall(data):
    """
    Compute micro-averaged precision and recall over all records.

    Args:
    - data: A list of dictionaries, each containing 'output' and 'prediction'.

    Returns:
    - precision: The precision of the predictions.
    - recall: The recall of the predictions.
    """
    # Accumulate counts across the whole test set (micro-averaging).
    tp = fp = fn = 0

    for record in data:
        # Ground-truth vs. predicted entity sets for this sample.
        gt = parse_entities(record, 'output')
        pred = parse_entities(record, 'prediction')

        tp += len(gt & pred)    # correctly predicted entities
        fp += len(pred - gt)    # predicted but not in ground truth
        fn += len(gt - pred)    # in ground truth but missed

    # Guard both divisions against an empty denominator.
    predicted_total = tp + fp
    actual_total = tp + fn
    precision = tp / predicted_total if predicted_total > 0 else 0
    recall = tp / actual_total if actual_total > 0 else 0

    return precision, recall
|
|
65 |
|
|
|
66 |
if __name__ == '__main__':

    # Prediction dumps to score, one per (model, tuning-method) pair.
    prediction_files = [
        'data/predictions-llama2-adapter.json',  # Llama-2 Adapter
        'data/predictions-stablelm-adapter.json',  # Stable-LM Adapter
        'data/predictions-llama2-lora.json',  # Llama-2 Lora
        'data/predictions-stablelm-lora.json',  # Stable-LM Lora
    ]

    for path in prediction_files:
        # Filenames follow data/predictions-<model>-<tune>.json, so splitting
        # on '-' recovers the model name and the tuning method.
        parts = path.split('-')

        # Load the predictions JSON data.
        with open(path, 'r') as handle:
            records = json.load(handle)

        # Score this model/tune combination and report rounded metrics.
        precision, recall = calculate_precision_recall(records)
        print(f"[INFO] {parts[1]}-{parts[2].split('.')[0]} ----> Precision: {round(precision,3)} Recall: {round(recall,3)}")