a b/src/preprocessing_3w.py
1
import numpy as np
2
import pandas as pd
3
import os
4
import click
5
import glob
6
import cv2
7
import pydicom
8
from tqdm import tqdm
9
from joblib import delayed, Parallel
10
import random
11
import pydicom
12
from scipy import ndimage
13
import pydicom
14
from skimage import exposure
15
16
17
def window_image(img, window_center, window_width, intercept, slope):
    """Rescale raw DICOM pixel values and clamp them to a display window.

    Values are first mapped with the linear rescale ``img * slope + intercept``
    and then clipped to ``[window_center - window_width // 2,
    window_center + window_width // 2]``.

    Args:
        img: array of stored pixel values (e.g. ``pixel_array``). Not modified.
        window_center: centre of the display window, in rescaled units.
        window_width: width of the display window, in rescaled units.
        intercept: rescale intercept added after the slope.
        slope: rescale slope multiplied into the stored values.

    Returns:
        A new float64 array of the same shape, clipped to the window.
    """
    # Work in float64: stored pixels are frequently unsigned (uint16), and
    # adding a negative intercept straight to an unsigned array raises under
    # NumPy >= 2 (NEP 50) or relies on integer promotion rules otherwise.
    img = np.asarray(img, dtype=np.float64) * slope + intercept
    # Floor-divided half width, kept so window bounds match earlier output.
    half_width = window_width // 2
    img_min = window_center - half_width
    img_max = window_center + half_width
    return np.clip(img, img_min, img_max)
24
25
26
def get_first_of_dicom_field_as_int(x):
    """Return a DICOM element value as an ``int``.

    Multi-valued elements (``pydicom.multival.MultiValue`` or any other
    non-string sequence such as a list or tuple) contribute their first entry;
    scalar values are converted directly.

    Args:
        x: the ``.value`` of a DICOM data element.

    Returns:
        ``int(x[0])`` for multi-valued input, otherwise ``int(x)``. Floats are
        truncated toward zero, as ``int()`` does.

    Raises:
        IndexError: if ``x`` is an empty sequence.
        ValueError, TypeError: if the value cannot be converted by ``int()``.
    """
    from collections.abc import Sequence

    # isinstance() instead of an exact type() comparison: pydicom's MultiValue
    # is a collections.abc.MutableSequence, so this also accepts subclasses and
    # plain lists/tuples. str/bytes are sequences too but denote one value.
    if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
        return int(x[0])
    return int(x)
32
33
34
def _first_dicom_value(x):
    """Return the first entry of a multi-valued DICOM element, else ``x``."""
    from collections.abc import Sequence

    # Strings/bytes are sequences too, but represent a single value.
    if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
        return x[0]
    return x


def get_windowing(data):
    """Read windowing and rescale parameters from a DICOM dataset.

    Args:
        data: a dataset indexable by ``(group, element)`` tag tuples whose
            items expose ``.value`` (e.g. the result of ``pydicom.dcmread``).

    Returns:
        ``[window_center, window_width, intercept, slope]``. Window centre and
        width are truncated to ``int``. Rescale Intercept and Rescale Slope are
        returned as ``float`` so non-integer values (e.g. a slope of 0.5) are
        no longer truncated, which would corrupt the rescaled pixel values.

    Raises:
        Whatever ``data[tag]`` raises for a missing tag (``KeyError`` for
        pydicom datasets and plain dicts).
    """
    tags = (
        ('0028', '1050'),  # window center
        ('0028', '1051'),  # window width
        ('0028', '1052'),  # rescale intercept
        ('0028', '1053'),  # rescale slope
    )
    center, width, intercept, slope = (_first_dicom_value(data[tag].value) for tag in tags)
    return [int(center), int(width), float(intercept), float(slope)]
40
41
42
@click.group()
def cli():
    # Root click command group; subcommands register with ``@cli.command()``.
    # Comments only: a docstring here would become the group's --help text.
    message = "CLI"
    print(message)
45
46
47
# Display windows rendered into the output image, in insertion order.
# Each value is [window_center, window_width]: convert_dicom_to_jpg passes
# v[0] and v[1] to window_image in that order.
# NOTE(review): "subdual" looks like a typo for "subdural"; the key is left
# unchanged in case other code looks it up by name.
windows_range = dict(
    brain=[40, 80],
    bone=[600, 2800],
    subdual=[75, 215],
)
52
53
54
def refine_label(label_mask):
    """Keep only the largest connected foreground region of a mask.

    Holes inside the foreground are filled first, then the single largest
    connected component (``ndimage.label`` default connectivity) is kept.

    Args:
        label_mask: array-like mask; any truthy element is foreground.

    Returns:
        An integer (``np.intp``) array of the same shape with 1 on the kept
        region and 0 elsewhere. An all-background input yields all zeros.
    """
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
    # supported dtype spelling.
    mask = np.asarray(label_mask).astype(bool)
    mask = ndimage.binary_fill_holes(mask)

    labeled, n_components = ndimage.label(mask)
    if n_components == 0:
        return np.zeros(mask.shape, dtype=np.intp)

    # Pixel count of components 1..n (label 0 is background, dropped).
    sizes = np.bincount(labeled.ravel(), minlength=n_components + 1)[1:]
    # argmax picks exactly one component even on ties. Keeping every tied
    # component would yield labels 1, 2, ... and callers that multiply by
    # the mask would scale pixels by more than 1.
    largest = int(np.argmax(sizes)) + 1
    return (labeled == largest).astype(np.intp)
67
68
69
def cut_edge(image, keep_margin):
    """Return the bounding box of the non-zero rows/columns of a 2-D array.

    A row (column) counts as non-zero when its sum is not 0. When
    ``keep_margin`` is non-zero the box is grown by that many pixels on each
    side, clamped to the image. For an all-zero image with no margin the box
    is empty.

    Args:
        image: 2-D array.
        keep_margin: extra pixels to keep around the detected box.

    Returns:
        ``(row_start, row_stop, col_start, col_stop)`` as ints, usable as
        ``image[row_start:row_stop, col_start:col_stop]``.
    """
    n_rows, n_cols = image.shape

    def first_hit(indices, line_at, default):
        # First index whose line sums to a non-zero value, else ``default``.
        return next((i for i in indices if line_at(i).sum() != 0), default)

    def row_at(i):
        return image[i, :]

    def col_at(j):
        return image[:, j]

    top = first_hit(range(n_rows), row_at, n_rows)
    # Scan from the far edge back toward ``top`` (exclusive); with no hit the
    # end stays at ``top``, or at the last row if ``top`` ran off the end.
    bottom = first_hit(range(n_rows - 1, top, -1), row_at, min(top, n_rows - 1))
    left = first_hit(range(n_cols), col_at, n_cols)
    right = first_hit(range(n_cols - 1, left, -1), col_at, min(left, n_cols - 1))

    if keep_margin != 0:
        top = max(0, top - keep_margin)
        bottom = min(n_rows - 1, bottom + keep_margin)
        left = max(0, left - keep_margin)
        right = min(n_cols - 1, right + keep_margin)

    return int(top), int(bottom) + 1, int(left), int(right) + 1
99
100
101
def pre_preocessing(image, pad_size=(512, 512)):
    """Mask a windowed slice to its largest foreground region and pad it.

    Negative values are set to 0, everything outside the largest connected
    positive region (via ``refine_label``) is zeroed, and the result is
    zero-padded symmetrically up to ``pad_size``. No intensity normalisation
    is applied. An axis already longer than ``pad_size`` is left as is
    (not cropped).

    Args:
        image: 2-D array. It is not modified.
        pad_size: ``(height, width)`` target size.

    Returns:
        A new 2-D array whose size on each axis is
        ``max(image_dim, pad_dim)``.
    """
    # Copy first: clamping negatives in place would mutate the caller's array.
    image = np.array(image, copy=True)
    image[image < 0] = 0

    # Zero everything outside the largest connected region of positive pixels.
    mask = refine_label(image > 0)
    image = image * mask

    # Centre inside a zero canvas of ``pad_size``; an odd leftover pixel of
    # padding goes to the bottom/right.
    height, width = image.shape
    target_h, target_w = pad_size[0], pad_size[1]
    pad_top = max((target_h - height) // 2, 0)
    pad_bottom = max(target_h - height - pad_top, 0)
    pad_left = max((target_w - width) // 2, 0)
    pad_right = max(target_w - width - pad_left, 0)
    return np.pad(image, [(pad_top, pad_bottom), (pad_left, pad_right)],
                  mode='constant', constant_values=0)
123
124
125
def convert_dicom_to_jpg(dicomfile, outputdir):
    """Render one DICOM slice as a multi-channel JPEG, one channel per window.

    Each entry of ``windows_range`` is applied with ``window_image`` and then
    ``pre_preocessing``. The resulting planes are stacked as channels in
    ``windows_range`` order and written to ``<outputdir>/<stem>.jpg``, where
    ``<stem>`` is the file name up to its first ``.``. (``cv2.imwrite``
    interprets a 3-channel array as BGR.)

    Failures are best-effort: the offending path is printed to stdout as
    before and the exception is reported on stderr, so one bad file does not
    abort a parallel batch.

    Args:
        dicomfile: path to a ``.dcm`` file.
        outputdir: existing directory to write the JPEG into.
    """
    import sys

    try:
        # ``dcmread`` is the supported reader; ``read_file`` was deprecated
        # and removed in pydicom 3.0.
        data = pydicom.dcmread(dicomfile)
        image = data.pixel_array
        _, _, intercept, slope = get_windowing(data)
        # basename() rather than split("/") so OS-specific separators work;
        # the stem is still cut at the first "." as before.
        image_id = os.path.basename(dicomfile).split(".")[0]

        planes = [
            pre_preocessing(window_image(image, center, width, intercept, slope),
                            pad_size=(512, 512))
            for center, width in windows_range.values()
        ]
        stacked = np.stack(planes, axis=-1)

        output_image = os.path.join(outputdir, image_id + ".jpg")
        # imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(output_image, stacked):
            raise OSError(f"cv2.imwrite could not write {output_image}")
    except Exception as err:  # best-effort per file; see docstring
        print(dicomfile)
        print(f"{dicomfile}: {err!r}", file=sys.stderr)
158
159
160
@cli.command()
@click.option('--inputdir', type=str, required=True,
              help='Directory containing the input .dcm files.')
@click.option('--outputdir', type=str, required=True,
              help='Directory to write the JPEGs into (created if missing).')
@click.option('--n-jobs', 'n_jobs', type=int, default=8, show_default=True,
              help='Number of parallel joblib workers.')
def extract_images(
        inputdir,
        outputdir,
        n_jobs=8,
):
    """Convert every *.dcm file directly inside --inputdir into a JPEG."""
    os.makedirs(outputdir, exist_ok=True)
    # glob.escape() so directory names containing glob metacharacters are
    # matched literally; sorted() gives a stable processing order.
    files = sorted(glob.glob(os.path.join(glob.escape(inputdir), "*.dcm")))
    Parallel(n_jobs=n_jobs)(
        delayed(convert_dicom_to_jpg)(file, outputdir)
        for file in tqdm(files, total=len(files))
    )
170
171
172
def split_by_patient(
        train_csv,
        train_meta_csv,
        n_folds,
        outdir
):
    """Prepare training metadata for a patient-level split (unfinished).

    Currently this creates ``outdir``, loads both CSVs, and shortens each
    metadata ``ID`` to its first two ``_``-separated fields before selecting
    the ``ID`` and ``PatientID`` columns. Nothing is written and ``None`` is
    returned.

    NOTE(review): looks unfinished -- ``n_folds`` is unused and the loaded
    frames are discarded. Behaviour is kept as-is pending the intended
    fold-splitting logic.
    """
    def _two_part_id(raw_id):
        # Keep the first two "_"-separated fields of the identifier.
        return "_".join(raw_id.split("_")[:2])

    os.makedirs(outdir, exist_ok=True)
    train_df = pd.read_csv(train_csv)  # noqa: F841 -- unused, see NOTE above
    meta = pd.read_csv(train_meta_csv)
    meta['ID'] = meta['ID'].apply(_two_part_id)
    meta = meta[['ID', 'PatientID']]
183
184
185
if __name__ == "__main__":
    # Script entry point: let the click group parse argv and dispatch.
    cli()