Diff of /data_preparation.py [000000] .. [afa31e]

Switch to unified view

a b/data_preparation.py
1
"""
2
# ===============Part_1: Original Data Format Conversion===============
3
4
To do: <Strategy>
5
    1. init data [.nii (NIfTI file) -> .png]:
6
            CT data: 3D -> 2D
7
            save: [1] .png (dir: data_preprocessing/images)
8
    2. mask data [.nii -> .json / .txt]: 
9
            CT mask data: 3D -> 2D 
10
                          Filter the masks(contains 'pneumonia') -> JSON & YOLO_TXT 
11
            save: [1] .png (dir: data_preprocessing/masks)
12
                  [2] .json (dir: data_preprocessing/labelsJSON)
13
                            (recognizable by Labelme.exe)
14
                  [3] .txt(dir: data_preprocessing/labels)
15
                            (for YOLO input: 
16
                                YOLO format annotation data file -> describe label's classID & relative position) 
17
    3. data augmentation:
18
            [1] CT (images) - optimize:
19
                [1.1] reduce the CT range and to increase the contrast (pkg: windowing -> transform window width & center)
20
                [1.2] Filter all CT & masks which contains 'pneumonia'
21
            [2] Masks - optimize: (for YOLOv8)
22
                [2.1] binarize masks (0 & 1)
23
                [2.2] padded & filled mask_tmp's connected area
24
                [2.3] close all contours in mask
25
                [2.4] contours -> polygons
26
                [2.5] discard the polygons(contours) if it contains few points
27
            [3] CT & Masks - optimize: -> create a train_data_generator to perform various transformations (for UNET)
28
                [3.1] define an image_generator -> keras.preprocessing.Image.ImageDataGenerator()
29
                [3.2] image data augmentation -> flow_from_directory()
30
                [3.3] image normalization (??? false -> binarize CT in real)
31
                    recommended: 
32
                        in U-Net : CT -> normalization is better than binarization
33
"""
34
35
import os
36
import SimpleITK as sitk
37
import matplotlib.pyplot as plt
38
import math
39
import numpy as np
40
import io
41
import base64
42
from skimage import measure
43
from scipy.ndimage import (
44
    binary_dilation,
45
    binary_fill_holes,
46
)  # dilation & filling for mask
47
from PIL import Image  # Pillow Imaging Library
48
import json
49
import tensorflow as tf
50
import cv2
51
import numpy as np
52
import pandas as pd
53
import matplotlib.pyplot as plt
54
55
56
# ========================1. Read Data========================
57
# Read .nii file
58
# def read_nifit(nii_path):
59
#     path_parts = nii_path.split("\\")
60
#     nii_name = path_parts[-1]
61
#     nii_type = path_parts[-2]
62
63
#     # read nii
64
#     img_slices = sitk.ReadImage(nii_path)
65
#     wincenter = img_slices
66
#     data_img_slices = sitk.GetArrayFromImage(img_slices)  # convert sitk image to numpy array
67
68
#     # write in txt
69
#     with open(f"{nii_type}_{nii_name.split('.')[0]}.txt", "w", encoding="utf-8") as f:
70
#         f.write(f"{nii_path}: \n {img_slices}\n")
71
#         f.write("===================================\n")
72
#         f.write(f"Shape of the data: {data_img_slices.shape}\n")
73
#         f.write(f"data_img_slices: \n")
74
#         for i in data_img_slices:
75
#             for j in i:
76
#                 f.write(f"{j}\n")
77
#             f.write("===================================\n")
78
79
#     return data_img_slices, nii_type, nii_name
80
81
82
# Make plt_savefig Name
83
def make_fig_name(nii_type, nii_name):
    """
    To do: build the base name used when saving a figure

    :param nii_type -> prefix describing the data (e.g. "image", "mask")
    :param nii_name -> nii file name; everything from the first "." on is dropped
    :return "<nii_type>_<nii_name up to the first '.'>"
    """
    stem = nii_name.split(".")[0]  # "study_01.nii.gz" -> "study_01"
    return f"{nii_type}_{stem}"
85
86
87
# Draw NIfTI CT slices
88
def draw_image(data_img_slices, idx1, idx2, savefig=None):
    """
    To do: draw the slices idx1..idx2 (1-based, inclusive) on a grid of subplots

    :param data_img_slices -> indexable collection of 2D slices
    :param idx1            -> number of the first slice to draw (1-based)
    :param idx2            -> number of the last slice to draw (1-based)
    :param savefig         -> optional str; when given, the figure is also
                              written to "CT_<savefig>.png"
    """
    count = idx2 + 1 - idx1

    # near-square grid that is large enough to hold `count` subplots
    row = math.floor(math.sqrt(count))
    col = math.ceil(math.sqrt(count))
    if row * col < count:
        row += 1

    for position, i in enumerate(range(idx1 - 1, idx2), start=1):
        plt.subplot(row, col, position)
        plt.imshow(data_img_slices[i], cmap="gray")
        plt.axis("off")

    # save BEFORE show: once show() returns the figure may already be
    # closed, and savefig() would then write an empty canvas
    if savefig and isinstance(savefig, str):
        plt.savefig(f"CT_{savefig}.png")
    plt.show()
102
103
104
# Draw CT histogram
105
def draw_histogram(data_img_slices, idx, savefig=None):
    """
    To do: draw the histogram of the values of one slice

    :param data_img_slices -> indexable collection of slices (arrays)
    :param idx             -> number of the slice to inspect (1-based)
    :param savefig         -> optional str; when given, the figure is also
                              written to "CT_histogram_<savefig>.png"
    """
    ct_values = data_img_slices[idx - 1].flatten()  # one slice -> 1D

    plt.figure(figsize=(10, 6))
    plt.hist(ct_values, bins=100, color="blue", alpha=0.7)
    plt.title("Histogram of CT Values")
    plt.xlabel("CT Value")
    plt.ylabel("Frequency")
    plt.grid(True)

    # save BEFORE show: once show() returns the figure may already be
    # closed, and savefig() would then write an empty canvas
    if savefig and isinstance(savefig, str):
        plt.savefig(f"CT_histogram_{savefig}.png")
    plt.show()
116
117
118
# Create the output folders for nii's metadata
119
def create_metadata_dir(data_path):
    """
    To do: create the metadata folder & its sub_dir

           <data_path>/metadata_tmp/images
           <data_path>/metadata_tmp/masks

    Existing folders are left untouched (exist_ok=True).

    :param data_path -> root folder of the dataset
    :return img_meta_path, mask_meta_path
    """
    meta_path = os.path.join(data_path, "metadata_tmp")
    img_meta_path = os.path.join(meta_path, "images")
    mask_meta_path = os.path.join(meta_path, "masks")

    # ordered tuple (parent first) instead of a set: deterministic creation order
    for path in (meta_path, img_meta_path, mask_meta_path):
        os.makedirs(path, exist_ok=True)

    return img_meta_path, mask_meta_path
128
129
130
# Output nii's metadata
131
def write_metadata_in_txt(nii_path, sitk_nii, metadata, save_dir_path):
    """
    To do: dump a nii file's header & voxel data into a text file

    The output file is "<save_dir_path>/<file stem>_metadata.txt", where the
    file stem is the base name of nii_path up to its first ".".

    :param nii_path      -> path of the nii file (its base name gives the output name)
    :param sitk_nii      -> object whose str() is written as the header
                            (the caller passes the SimpleITK image)
    :param metadata      -> array with a .shape; iterated along its first axis,
                            each item iterated again row by row
    :param save_dir_path -> existing folder that receives the txt file
    """
    # derive the name from the argument; the previous code read a global
    # `nii_name` that is not a parameter of this function
    nii_stem = os.path.basename(nii_path).split(".")[0]
    meta_path = os.path.join(save_dir_path, f"{nii_stem}_metadata.txt")

    separator = "===================================\n"
    with open(meta_path, "w", encoding="utf-8") as f:
        f.write(f"{nii_path}: \n {sitk_nii}\n")
        f.write(separator)
        f.write(f"Shape of the data: {metadata.shape}\n")
        f.write("data_img_slices: \n")
        for data_slice in metadata:
            f.write("\n".join(map(str, data_slice)))
            f.write("\n")  # keep the separator on a line of its own
            f.write(separator)
141
142
143
# ========================2. Data Augmentation========================
144
# img
145
def window_transform(sitkImage, winwidth=250, wincenter=80):
    """
    To do: transform window width & center
           to reduce the CT range and to increase the contrast (pkg: windowing)

    The window covers [wincenter - winwidth / 2, wincenter + winwidth / 2];
    both borders are truncated to int before they are given to the filter.
    The filter's output range is not set here, so SimpleITK's defaults apply
    (0 ~ 255 according to the SimpleITK docs -- TODO confirm).

    :param sitkImage(input) -> SimpleITK object
    :param winwidth -> CT values range --> winwidth: smaller, contrast: stronger
    :param wincenter -> all displayed CT values' center position --> wincenter: smaller, brightness: brighter
    :return sitkImage(output) -> sitkImage after transform
    """
    # define borders (named so that the builtins min/max are not shadowed)
    window_min = int(wincenter - winwidth / 2.0)
    window_max = int(wincenter + winwidth / 2.0)

    # define SimpleITK object -> transform window width & center
    intensityWindow = sitk.IntensityWindowingImageFilter()
    intensityWindow.SetWindowMinimum(window_min)
    intensityWindow.SetWindowMaximum(window_max)

    return intensityWindow.Execute(sitkImage)
169
170
# img
171
def get_winwidth_wincenter(img_nii_path, mask_nii_path):
    """
    To do: According to CT's Mask flexibly winwidth && wincenter

    For every .nii / .nii.gz file found in mask_nii_path, the CT with the same
    file name is read from img_nii_path. Per slice (first array axis), the
    min & max CT value of the voxels whose mask value is 1 are collected.
    The global min / max are widened to the nearest multiple of 100 and
    turned into a window width and a window center.

    :param img_nii_path -> folder with the CT nii files
    :param mask_nii_path -> folder with the mask nii files (same file names)
    :return winwidth, wincenter
    :raises ValueError -> no voxel with mask value 1 was found in any file

    Note: for Lung CT: default winwidth: 1000 ~ 1600HU, default wincenter:-600 ~ -800HU
    """
    all_min_max_hu = []
    for nii_name in os.listdir(mask_nii_path):
        if not nii_name.endswith((".nii.gz", ".nii")):
            continue

        # read nii file
        imgs_nii = sitk.ReadImage(os.path.join(img_nii_path, nii_name))
        masks_nii = sitk.ReadImage(os.path.join(mask_nii_path, nii_name))

        # nii -> array[]
        imgs_init = sitk.GetArrayFromImage(imgs_nii)
        masks = sitk.GetArrayFromImage(masks_nii)

        # keep the CT value where mask == 1, NaN everywhere else
        masks_hu_init = np.where(masks == 1, imgs_init, np.nan)

        for slice_hu in masks_hu_init:
            if not np.all(np.isnan(slice_hu)):  # skip slices without label
                all_min_max_hu.extend([np.nanmin(slice_hu), np.nanmax(slice_hu)])

    if not all_min_max_hu:
        # without this guard np.nanmin([]) fails with an opaque
        # "zero-size array" ValueError
        raise ValueError(
            f"no voxel with mask value 1 found in the nii files of {mask_nii_path!r}"
        )

    # find nearest bounds (multiples of 100)
    lower_bound_min = math.floor(np.nanmin(all_min_max_hu) / 100) * 100
    upper_bound_max = math.ceil(np.nanmax(all_min_max_hu) / 100) * 100

    winwidth = upper_bound_max - lower_bound_min
    wincenter = np.around(1 / 2 * (upper_bound_max + lower_bound_min))

    print(f"winwidth: {winwidth}, wincenter: {wincenter}")

    return winwidth, wincenter
213
214
215
# img
216
def img_to_base64str(img_pil):
    """
    To do:  Image conversion: PIL Image -> Base64 string

    :img_pil    : PIL Image Object(gray); only its save(fp, format="PNG") is used
    :return     : (Base64)      str(utf-8)

    PIL: Python Imaging Library

    (Binary) byte str ==> use in Image/Video/...
    (Base64) byte str (ASCII character) ==> use in Print/Mail/URL/...
    (Base64)      str (utf-8: Unicode character (text str)) ==> use in JSON/XML/...

    This allows image data to be transmitted over the network as text strings
    or stored in text files without being affected by the compatibility or
    transmission issues that binary data can bring.
    """
    ENCODING = "utf-8"

    img_byte = io.BytesIO()  # binary data stream in memory
    img_pil.save(img_byte, format="PNG")  # PNG-encode the image into the stream

    # (Binary) byte str -> (Base64) byte str -> (Base64) str
    return base64.b64encode(img_byte.getvalue()).decode(ENCODING)
242
243
244
# mask
245
def close_contour(contour):
    """
    To do: close contour [handle image shapes and borders]
           make sure that a 2D contour is connected end-to-end to form a closed contour

    :contour : array of points, one point per row
    :return  : the same array when it is empty or already closed,
               otherwise a new array with the first point appended at the end
    """
    if len(contour) == 0:  # nothing to close (contour[0] would raise)
        return contour
    if not np.array_equal(contour[0], contour[-1]):  # contour: not closed
        contour = np.vstack((contour, contour[0]))  # add contour[0] at the end
    return contour
253
254
255
# mask
256
def binary_mask_to_polygon(binary_mask, tolerance=0):
    """
    To do: binary_mask -> polygon(COCO dataset format)
    From project "pycococreator-master"

    :binary_mask : 2D array[] (0 / 1 values)
    :tolerance=0 : control the fineness of the polygon approximation process
    :return      : polygons (list)

    contours: {[c_1][c_2]...[c_N]} -> N contours in mask
    contour : [[y1,x1]...[yn,xn]]  -> n points in this contour
              -> shape: (n,2)
              -> y: row coordinate, x: col coordinate
    polygons: [[x1, y1, ..., xn, yn], ...] -> one flat list per kept contour
    """
    # padding: to avoid losing the boundary info when checking the contours
    padded_binary_mask = np.pad(
        binary_mask, pad_width=1, mode="constant", constant_values=0
    )

    polygons = []
    # threshold: 0.5 (for binary img)
    for contour in measure.find_contours(padded_binary_mask, 0.5):
        contour = np.subtract(contour, 1)  # restore: original = contour - padding(1)
        contour = close_contour(contour)  # make sure contour closed
        contour = measure.approximate_polygon(contour, tolerance)
        if len(contour) < 3:  # filter the 'contour' whose points < 3
            continue

        contour = np.flip(contour, axis=1)  # flip the axis: img(y,x) -> COCOdata(x,y)
        # contour[points] -> 1D list; removing the padding can yield
        # coordinates below 0, which are clamped to 0
        segmentation = [0 if value < 0 else value for value in contour.ravel().tolist()]
        polygons.append(segmentation)

    return polygons
299
300
301
# img & mask
302
def img_mask_to_json_txt(img, mask, class_names, class_mapping):
    """
    To do: img + mask -> json (Labelme) & txt rows (YOLO)

    json: for img annotation tool (eg : Labelme)
    txt : YOLO polygon format, one polygon(object) <-> one row
          "classID x1_nor y1_nor x2_nor y2_nor ..."  (nor -> normalized by
          image width / height, rounded to 6 decimals)

    :img           : PIL Image Object (gray)
    :mask          : PIL Image Object (gray)
    :class_names   : ['_background_', 'pneumonia'] -> indexed by the mask value
    :class_mapping : {label: classID} used for the YOLO rows
    :return        : success, JSON_OUTPUT, TXT_OUTPUT
                     success     -> True when at least one polygon was kept
                     JSON_OUTPUT -> dict; "imagePath" is left for the caller to set
                     TXT_OUTPUT  -> list of str, one YOLO row per kept polygon
    """
    # Image -> array ('uint-8')
    binary_mask = np.asarray(mask).astype(np.uint8)
    image_height = binary_mask.shape[0]
    image_width = binary_mask.shape[1]

    JSON_OUTPUT = {
        "version": "5.4.1",  # Labelme version
        "flags": {},  # Additional flag info
        "shapes": [],  # [{label, points, group_id, shape_type, flags}...{}]
        "imagePath": {},  # set externally
        "imageData": img_to_base64str(img),  # Image data (Base64 encoded)
        "imageHeight": image_height,
        "imageWidth": image_width,
    }
    TXT_OUTPUT = []
    success = False

    for i in np.unique(binary_mask):  # loop through each value found in the mask
        if i == 0:  # 0 : typically background
            continue

        # binarization: pixels of this class -> 1, rest -> 0
        mask_tmp = np.where(binary_mask == i, 1, 0)

        # padded & filled mask_tmp's connected area
        mask_tmp = binary_dilation(
            mask_tmp, structure=np.ones((3, 3)), iterations=1
        )  # padding
        mask_tmp = binary_fill_holes(mask_tmp, structure=np.ones((5, 5)))  # filling

        # polygon extraction
        for polygon in binary_mask_to_polygon(mask_tmp, tolerance=2):
            if len(polygon) <= 10:  # discard the contour if it contains few points
                continue

            label = class_names[i]  # class_names[1] : 'pneumonia'

            # polygon is flat: [x1, y1, x2, y2, ...]
            points = [[x, y] for x, y in zip(polygon[0::2], polygon[1::2])]

            # YOLO txt_row -> classID followed by the normalized points
            row_fields = [str(class_mapping[label])]
            for x, y in points:
                row_fields.append(str(round(x / image_width, 6)))
                row_fields.append(str(round(y / image_height, 6)))

            JSON_OUTPUT["shapes"].append(
                {
                    "label": label,
                    "points": points,
                    "group_id": None,
                    "shape_type": "polygon",
                    "flags": {},
                }
            )
            TXT_OUTPUT.append(" ".join(row_fields))
            success = True

    return success, JSON_OUTPUT, TXT_OUTPUT
428
429
430
# Test Code
431
432
# Script: convert every labelled nii pair into PNG / Labelme JSON / YOLO TXT.
# For each slice whose mask is not empty, the windowed CT slice and the
# original mask are saved as PNG, together with the JSON & TXT annotations.
data_path = "./data/mosmed"
img_nii_path = os.path.join(data_path, "dataNii")
mask_nii_path = os.path.join(data_path, "maskNii")

# create data_preprocessing folder & sub_dir
data_pre_path = os.path.join(data_path, "data_preprocessing")
img_png_path = os.path.join(data_pre_path, "images")
mask_png_path = os.path.join(data_pre_path, "masks")
json_path = os.path.join(data_pre_path, "labelsJSON")
txt_path = os.path.join(data_pre_path, "labels")

PATH = {
    data_pre_path,  # preprocessing_tmp1 --> data augmentation for mask
    img_png_path,
    mask_png_path,
    json_path,
    txt_path,
}
for path in PATH:
    os.makedirs(path, exist_ok=True)

class_names = ["_background_", "pneumonia"]
class_mapping = {"pneumonia": 0}

print("Starting: Nii -> PNG/JSON/TXT ...")

# one window for the whole dataset, derived from the CT values under the masks
winwidth, wincenter = get_winwidth_wincenter(img_nii_path, mask_nii_path)

for nii_name in os.listdir(mask_nii_path):
    if not nii_name.endswith((".nii.gz", ".nii")):
        continue

    # read nii file (CT & mask share the same file name)
    imgs_nii = sitk.ReadImage(os.path.join(img_nii_path, nii_name))
    masks_nii = sitk.ReadImage(os.path.join(mask_nii_path, nii_name))

    # imgs_init = sitk.GetArrayFromImage(imgs_nii)

    # windowing
    imgs_nii = window_transform(imgs_nii, winwidth=winwidth, wincenter=wincenter)

    # nii -> array[]
    imgs = sitk.GetArrayFromImage(imgs_nii)
    masks = sitk.GetArrayFromImage(masks_nii).astype(np.uint8)  # uint8 ???

    # Disabled debug step: write nii's metadata in txt && Draw
    # img_meta_path, mask_meta_path = create_metadata_dir(data_path)
    # write_metadata_in_txt(f"{os.path.join(img_nii_path, nii_name)}", imgs_nii, imgs_init, img_meta_path)
    # write_metadata_in_txt(f"{os.path.join(mask_nii_path, nii_name)}", masks_nii, masks, mask_meta_path)
    # draw_image(imgs, 1, len(imgs), make_fig_name("image", nii_name))
    # draw_histogram(imgs_init, 17, make_fig_name("image_init", nii_name, 17))

    for idx in range(masks.shape[0]):  # masks.shape: (depth, height, width)
        # get img/mask whose mask value > 0 -> which has pneumonia in CT
        if not masks[idx, :, :].any():
            continue

        # array -> PIL Image object, converted to gray ('L')
        img_png = Image.fromarray(imgs[idx, :, :]).convert("L")
        mask_png = Image.fromarray(masks[idx, :, :]).convert("L")

        # mask -> json & txt
        sub_name = nii_name.split(".")[0] + "_" + str(idx)
        img_mask_json_path = os.path.join(json_path, sub_name + ".json")
        img_mask_txt_path = os.path.join(txt_path, sub_name + ".txt")
        success, JSON_OUTPUT, TXT_OUTPUT = img_mask_to_json_txt(
            img_png, mask_png, class_names, class_mapping
        )
        if not success:  # no usable polygon in this slice -> nothing is saved
            continue

        # 1.1 & 2.1 save PNG
        img_png.save(os.path.join(img_png_path, sub_name + ".png"))
        # Save the original mask, not the mask after padding & filling & close_contours
        mask_png.save(os.path.join(mask_png_path, sub_name + ".png"))

        # 2.2 write in JSON
        JSON_OUTPUT["imagePath"] = sub_name + ".png"
        with open(img_mask_json_path, "w", encoding="utf-8") as json_output:
            json.dump(JSON_OUTPUT, json_output, indent=4)  # Serialize python_data as a JSON file

        # 2.3 write in TXT: one row per polygon, no newline after the last row
        # (a single join also avoids reusing the slice index `idx` as a
        # row counter, as the previous inner loop did)
        with open(img_mask_txt_path, "w", encoding="utf-8") as txt_output:
            txt_output.write("\n".join(TXT_OUTPUT))

print("The conversion was completed successfully.")