# preprocess_data.py
1
import argparse
import json
import os
import uuid
from typing import Any, Dict

from PIL import Image
8
# ---- Command-line arguments -------------------------------------------------
# The split name selects which subset of the annotations to process.
parser = argparse.ArgumentParser()
parser.add_argument('-s', "--split_data", type=str, required=True, help='Specify the split name in argument')
args = parser.parse_args()

# Assign the split name (expected values: train, test, val)
split_name = args.split_data

# Define paths to the annotation file and the images folder
annotations_path = './data/annotation.json'
images_folder = './data/images'
output_folder = f'./dataset_{split_name}'  # e.g. ./dataset_train

# Create the output folder if needed.
# exist_ok=True avoids the race between the existence check and makedirs
# that the original if-not-exists guard had.
os.makedirs(output_folder, exist_ok=True)
24
25
26
# Define the function to convert the JSON object into a token sequence string
def json2token(obj: Any, sort_json_key: bool = True) -> str:
    """
    Convert the JSON object into a token sequence string.

    Dicts become one ``<s_key>...</s_key>`` span per key, lists are joined
    with ``<sep/>``, and anything else is stringified.

    Args:
        obj (Any): The JSON object to convert (dict, list, or scalar).
        sort_json_key (bool): Whether to sort dict keys. Default is True.
            NOTE: sorting is reverse-alphabetical, preserving the original
            behavior of this script.

    Returns:
        str: A string representing the token sequence extracted from obj.
    """
    if isinstance(obj, dict):
        # A dict holding only "text_sequence" is already a leaf string.
        if len(obj) == 1 and "text_sequence" in obj:
            return obj["text_sequence"]
        keys = sorted(obj.keys(), reverse=True) if sort_json_key else obj.keys()
        # join() instead of repeated += avoids quadratic string building.
        return "".join(
            fr"<s_{k}>" + json2token(obj[k], sort_json_key) + fr"</s_{k}>"
            for k in keys
        )
    if isinstance(obj, list):
        return r"<sep/>".join(json2token(item, sort_json_key) for item in obj)
    return str(obj)
62
63
64
# Load the annotation file (annotation.json) from annotations_path.
# NOTE(review): the same file is loaded again near the bottom of the script
# into data_annotations; this first copy appears unused — confirm before removing.
with open(annotations_path) as annotations_file:
    annotations = json.load(annotations_file)
67
68
69
70
# Need to convert the token back to JSON later using "llava-hf/llava-v1.6-mistral-7b-hf" processor
71
# Need this to process outputs laters
72
#from transformers import AutoProcessor
73
#MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
74
#processor = AutoProcessor.from_pretrained(MODEL_ID)
75
76
77
# Convert token sequence string to JSON object
import re


def token2json(tokens, is_inner_value=False, added_vocab=None):
    """
    Convert a (generated) token sequence into an ordered JSON format.

    Args:
        tokens (str): Token sequence, e.g. "<s_key>value</s_key>".
        is_inner_value (bool): True on recursive calls; the result is then a
            list so <sep/>-separated siblings can be concatenated.
        added_vocab (dict | None): Tokenizer added-vocab mapping used to
            recognize categorical special tokens. Defaults to an empty dict.

    Returns:
        dict | list: The parsed structure, or {"text_sequence": tokens}
        when no <s_...> tags are found.
    """
    if added_vocab is None:
        # BUG FIX: the original called processor.tokenizer.get_added_vocab()
        # here, but `processor` is commented out above, so the default path
        # raised NameError. Fall back to an empty vocab; pass added_vocab
        # explicitly to restore the special-token handling.
        added_vocab = {}

    output = {}

    while tokens:
        start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_token is None:
            break
        key = start_token.group(1)
        key_escaped = re.escape(key)

        end_token = re.search(rf"</s_{key_escaped}>", tokens, re.IGNORECASE)
        start_token = start_token.group()
        if end_token is None:
            # Unterminated tag: drop it and keep scanning.
            tokens = tokens.replace(start_token, "")
        else:
            end_token = end_token.group()
            start_token_escaped = re.escape(start_token)
            end_token_escaped = re.escape(end_token)
            content = re.search(
                f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE | re.DOTALL
            )
            if content is not None:
                content = content.group(1).strip()
                if r"<s_" in content and r"</s_" in content:  # non-leaf node
                    value = token2json(content, is_inner_value=True, added_vocab=added_vocab)
                    if value:
                        if len(value) == 1:
                            value = value[0]
                        output[key] = value
                else:  # leaf nodes
                    output[key] = []
                    for leaf in content.split(r"<sep/>"):
                        leaf = leaf.strip()
                        if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
                            leaf = leaf[1:-2]  # for categorical special tokens
                        output[key].append(leaf)
                    # Collapse single-element lists to a scalar.
                    if len(output[key]) == 1:
                        output[key] = output[key][0]

            # Advance past the closing tag; a leading <sep/> means siblings follow.
            tokens = tokens[tokens.find(end_token) + len(end_token):].strip()
            if tokens[:6] == r"<sep/>":  # non-leaf nodes
                return [output] + token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)

    if len(output):
        return [output] if is_inner_value else output
    return [] if is_inner_value else {"text_sequence": tokens}
132
        
133
134
135
# Generate dataset.json file and images folder from the annotations.json
def process_and_save(data_annotations, images_folder, output_folder, split=split_name):
    """
    Build a LLaVA-format dataset for one split.

    For every item in data_annotations[split], copy the first image
    (<patient_id>/0.png) into <output_folder>/images under a fresh UUID
    file name, serialize the item's report with json2token, and write all
    records to <output_folder>/<split>/<split>_dataset.json.

    Args:
        data_annotations (dict): Parsed annotation.json; must contain `split`.
        images_folder (str): Root folder holding <patient_id>/0.png images.
        output_folder (str): Destination root for images and the JSON file.
        split (str): Split key (train/test/val). Defaults to the module-level
            split_name parsed from the command line.
    """
    # Output subfolder for the processed images.
    # exist_ok=True avoids the check-then-create race of the original guard.
    new_image_folder = os.path.join(output_folder, 'images')
    os.makedirs(new_image_folder, exist_ok=True)

    # Collected LLaVA-format records.
    json_data_list = []

    for item in data_annotations[split]:
        patient_id = item['id']
        # Only the first image (0.png) of each patient is used.
        image_path = os.path.join(images_folder, patient_id, '0.png')

        if not os.path.exists(image_path):
            continue  # Skip if the expected image is not found

        # Unique ID that names both the copied image and the record.
        unique_id = str(uuid.uuid4())
        new_image_path = os.path.join(new_image_folder, f"{unique_id}.png")

        # BUG FIX: the original never closed the PIL image handle, leaking a
        # file descriptor per image; the context manager closes it after save.
        with Image.open(image_path) as image:
            image.save(new_image_path)

        # Serialize the report dict into the token-sequence target string.
        report_dict = item['report']
        report_json = json2token(report_dict, sort_json_key=False)

        # Structure the JSON data in the LLaVA format
        json_data_list.append({
            "id": unique_id,
            "image": f"{unique_id}.png",
            "conversations": [
                {
                    "from": "human",
                    "value": "Please describe the findings in the X-ray."
                },
                {
                    "from": "gpt",
                    "value": report_json  # Using the report as the GPT's response
                }
            ]
        })

    # Save the JSON data list to <output_folder>/<split>/<split>_dataset.json
    split_folder = os.path.join(output_folder, split)
    os.makedirs(split_folder, exist_ok=True)
    json_output_path = os.path.join(split_folder, f'{split}_dataset.json')
    with open(json_output_path, 'w') as json_file:
        json.dump(json_data_list, json_file, indent=4)
200
201
202
203
# Load the annotations from disk, then build the dataset for the chosen split.
with open(annotations_path, 'r') as annotation_file:
    data_annotations = json.load(annotation_file)

process_and_save(data_annotations, images_folder, output_folder, split_name)  # run once