Diff of /prompt.py [000000] .. [b160a6]

Switch to unified view

a b/prompt.py
1
# ----------- Task 1: Prompt engineering: reorganize the X-Ray report findings into predefined anatomical regions
2
3
# import packages
4
import openai
5
from openai import OpenAI
6
import json
7
import os
8
from tqdm import tqdm
9
import argparse
10
11
12
# Set up argument parsing
parser = argparse.ArgumentParser()
parser.add_argument('-i', "--input_path", type=str, help='Path to the input JSON file containing reports')
parser.add_argument('-o', "--output_path", type=str, help='Path to the output JSON file to save categorized reports')
# The key is optional on the command line: secrets passed as arguments end up
# in shell history and process listings, so the OPENAI_API_KEY environment
# variable is accepted as an alternative source.
parser.add_argument('-k', "--openAI_API_key", type=str, default=None,
                    help='Your OpenAI API key (defaults to the OPENAI_API_KEY environment variable)')
args = parser.parse_args()

# Set the input and output file paths.
# A missing or empty argument falls back to the default location.
input_file_path = args.input_path or './data/annotation_quiz_all.json'
output_file_path = args.output_path or './data/annotation.json'  # Update existing file with categorized reports in val set

# Resolve the API key: an explicit -k wins, otherwise use the environment.
api_key = args.openAI_API_key or os.environ.get('OPENAI_API_KEY')
if not api_key:
    # Exits with the usual argparse usage error, like a missing required argument.
    parser.error('an OpenAI API key is required: pass -k/--openAI_API_key or set OPENAI_API_KEY')

# Client used by every request made in this script
client = OpenAI(api_key=api_key)
34
35
36
# Function to prompt gpt-4o-mini to categorize the findings 
37
def categorize_findings(report):
    """Ask gpt-4o-mini to sort the sentences of one report by anatomical region.

    Args:
        report: Text of the report, sent as the user message.

    Returns:
        The content of the first choice returned by the chat completion call.
        The request sets ``response_format`` to ``json_object`` and the prompt
        asks for the keys bone, heart, lung, mediastinal and others.
    """
    # Instruction text is kept verbatim, including the indentation and the
    # trailing spaces inside the literal, because all of it is sent to the model.
    system_prompt = """Categorize the findings of a chest X-ray report into predefined anatomical regions: bone, heart, lung, and mediastinal. 
                    If a finding does not clearly belong to these categories, classify it under 'others'. Read each sentence carefully. Determine the main anatomical focus of each sentence:
                    - If a sentence discusses any findings related to bones, categorize it under 'bone'.
                    - If it pertains to the heart, categorize it under 'heart'. 
                    - If a sentence discusses any findings related to the lungs or associated structures, categorize it under 'lung'.
                    - If it mentions any findings related to the mediastinal area, categorize it under 'mediastinal'.
                    - If a sentence does not fit any of the above categories or is ambiguous, place it under 'others'.
                    Provide the output as a JSON object with each category listing relevant sentences from the report in **plain text** without extra double quotes around the sentences. 
                    The format should be: {"bone": "", "heart": "", "lung": "", "mediastinal": "", "others": ""}.
                    """

    # Instructions first, then the report itself as the user turn
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": report},
    ]

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=conversation,
        response_format={"type": "json_object"},
    )

    # Only the first choice is used
    first_choice = response.choices[0]
    return first_choice.message.content
64
65
66
# Smoke test: categorize a single hard-coded report before the full run
sample_report = (
    "The cardiomediastinal silhouette and pulmonary vasculature are within normal limits in size. "
    "The lungs are mildly hypoinflated but grossly clear of focal airspace disease, pneumothorax, or pleural effusion. "
    "There are mild degenerative endplate changes in the thoracic spine. "
    "There are no acute bony findings."
)
categorized_report = categorize_findings(sample_report)
print(categorized_report)

# Parse the returned JSON text into a Python object
result = json.loads(categorized_report)
72
#print(result)
73
#print(result['lung'])
74
75
76
# -----------  For all reports
77
78
# Get all reports
79
80
# Read the JSON file. It is decoded explicitly as UTF-8 (the standard JSON
# encoding) so that non-ASCII characters in the reports do not depend on the
# platform's default encoding.
with open(input_file_path, 'r', encoding='utf-8') as file:
    data = json.load(file)

# Reports stored under the 'val' key; an empty list when the key is absent
val_reports = data.get('val', [])

print("Num of Reports in Val set: ",len(val_reports))
86
87
88
89
# Categorize findings for all reports, saving a checkpoint after every batch

# Number of reports processed between two writes of the output file
batch_size = 10
categorized_results = []


def _save_categorized_results():
    """Write the loaded JSON to the output file with 'val' set to the results so far."""
    # `data` is the document already loaded from the input file; only its
    # 'val' entry is replaced, every other key is written back unchanged.
    data['val'] = categorized_results
    with open(output_file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, indent=4)


# Process the reports in batches
for i in tqdm(range(0, len(val_reports), batch_size), desc="Processing reports"):
    batch = val_reports[i:i + batch_size]  # Current slice of reports

    for report in batch:
        # Resolve the id up front so the error message below can never raise itself
        report_id = report.get('id', '<unknown id>') if isinstance(report, dict) else '<unknown id>'
        try:
            raw_response = categorize_findings(report['original_report'])  # JSON text from the model
            entry = {
                'id': report['id'],
                'report': json.loads(raw_response),  # Parsed categories
                'split': report['split'],
            }
        except Exception as e:
            # Best effort: a failed request, an unparsable answer or a record
            # with a missing field is reported and skipped so that one bad
            # report does not abort the whole run.
            print(f"Error processing report {report_id}: {e}")
            continue

        categorized_results.append(entry)

    # Checkpoint: persist everything categorized so far
    _save_categorized_results()

# With no reports the loop body never runs; still produce the output file
if not val_reports:
    _save_categorized_results()

# Printed once at the end so the message does not break up the progress bar
print("File updated successfully.")