a b/scripts/evaluate.py
1
import json
2
3
def parse_entities(record, key):
    """
    Parse entities from a record in the data.

    The value stored under `key` is a JSON-encoded object. Every value of
    that object is collected into one flat set: list values contribute each
    of their items, any other value is added as a single entity. The field
    names of the JSON object are discarded, so entities coming from
    different fields share the same set.

    Args:
    - record: A dictionary representing a single record in the data.
    - key: The key to extract data from ('output' or 'prediction').

    Returns:
    - A set containing the extracted entities.

    Raises:
    - KeyError: If `key` is not present in `record`.
    - json.JSONDecodeError: If `record[key]` is not valid JSON.
    - TypeError: If a collected value is unhashable (e.g. a nested object).
    """
    # Convert the JSON string into a dictionary
    entities = json.loads(record[key])

    # Initialize a set to store the entities
    flattened_entities = set()

    # Extract the entities
    for value in entities.values():
        if isinstance(value, list):
            # List-valued field (e.g. adverse events): add every item
            flattened_entities.update(value)
        else:
            # Scalar field (e.g. a drug name): add the value itself
            flattened_entities.add(value)

    return flattened_entities
30
31
def calculate_precision_recall(data):
    """
    Calculate precision and recall from the data.

    True positives, false positives and false negatives are counted per
    record by comparing the entity sets parsed from 'output' and
    'prediction', then summed over all records before the two ratios are
    computed (i.e. the counts are pooled, not averaged per record).

    Args:
    - data: A list of dictionaries, each containing 'output' and 'prediction'.

    Returns:
    - precision: The precision of the predictions (0.0 if nothing was predicted).
    - recall: The recall of the predictions (0.0 if there are no ground truths).
    """
    # Initialize variables
    true_positives = 0
    false_positives = 0
    false_negatives = 0

    # Parse all the samples in the test dataset
    for record in data:
        # Extract ground truths
        gt_entities = parse_entities(record, 'output')
        # Extract predictions
        pred_entities = parse_entities(record, 'prediction')

        # Calculate TP, FP, FN for each sample in test data
        true_positives += len(gt_entities & pred_entities)
        false_positives += len(pred_entities - gt_entities)
        false_negatives += len(gt_entities - pred_entities)

    # Calculate Precision; fall back to a float 0.0 so the return type does
    # not depend on whether the denominator is zero
    predicted_total = true_positives + false_positives
    precision = true_positives / predicted_total if predicted_total > 0 else 0.0

    # Calculate Recall, with the same zero-denominator guard
    actual_total = true_positives + false_negatives
    recall = true_positives / actual_total if actual_total > 0 else 0.0

    return precision, recall
65
66
if __name__ == '__main__':
    # Local import: only the script entry point needs path handling
    from pathlib import Path

    # Prediction files to evaluate, one per model / tuning method
    prediction_files = [
        'data/predictions-llama2-adapter.json',    # Llama-2 Adapter
        'data/predictions-stablelm-adapter.json',  # Stable-LM Adapter
        'data/predictions-llama2-lora.json',       # Llama-2 Lora
        'data/predictions-stablelm-lora.json',     # Stable-LM Lora
    ]

    for filename in prediction_files:
        # Get model name and tune type from the file name itself (not the
        # whole path), so a hyphen in a directory name cannot shift the
        # parts: 'predictions-llama2-adapter' -> 'llama2-adapter'
        label = Path(filename).stem.split('-', 1)[-1]

        # Load the predictions JSON data; the encoding is explicit so the
        # result does not depend on the platform's default locale
        with open(filename, 'r', encoding='utf-8') as file:
            data = json.load(file)

        precision, recall = calculate_precision_recall(data)
        print(f"[INFO] {label} ----> Precision: {round(precision,3)} Recall: {round(recall,3)}")