# create_patches_fp.py
1
# internal imports
2
from wsi_core.WholeSlideImage import WholeSlideImage
3
from wsi_core.wsi_utils import StitchCoords
4
from wsi_core.batch_process_utils import initialize_df
5
# other imports
6
import os
7
import numpy as np
8
import time
9
import argparse
10
import pdb
11
import pandas as pd
12
13
def stitching(file_path, wsi_object, downscale = 64):
    """Run ``StitchCoords`` on a patch file and time the call.

    Args:
        file_path: Path handed to ``StitchCoords`` (the caller in this file
            passes the slide's ``.h5`` file from the patch directory).
        wsi_object: Slide object forwarded unchanged to ``StitchCoords``.
        downscale: Downscale factor forwarded to ``StitchCoords``.

    Returns:
        Tuple ``(heatmap, total_time)``: whatever ``StitchCoords`` returns and
        the wall-clock duration of the call in seconds.
    """
    start = time.time()
    # Fixed rendering options: black background, alpha=-1, no grid overlay.
    heatmap = StitchCoords(file_path, wsi_object, downscale=downscale,
                           bg_color=(0,0,0), alpha=-1, draw_grid=False)
    total_time = time.time() - start

    return heatmap, total_time
19
20
def segment(WSI_object, seg_params = None, filter_params = None, mask_file = None):
    """Segment a slide, either from a saved mask file or from scratch, and time it.

    Args:
        WSI_object: Object exposing ``initSegmentation(mask_file)`` and
            ``segmentTissue(**seg_params, filter_params=...)``.
        seg_params: Keyword arguments for ``segmentTissue``. ``None`` is
            treated as "no extra arguments".
        filter_params: Forwarded as ``filter_params`` to ``segmentTissue``.
        mask_file: If not ``None``, segmentation is loaded through
            ``initSegmentation`` and the two parameter sets are ignored.

    Returns:
        Tuple ``(WSI_object, seg_time_elapsed)`` where the second item is the
        wall-clock duration in seconds.
    """
    ### Start Seg Timer
    start_time = time.time()

    if mask_file is not None:
        # Use segmentation file
        WSI_object.initSegmentation(mask_file)
    else:
        # Segment. Unpacking None with ** raises TypeError, so fall back to
        # an empty mapping when the caller relied on the default.
        if seg_params is None:
            seg_params = {}
        WSI_object.segmentTissue(**seg_params, filter_params=filter_params)

    ### Stop Seg Timer
    seg_time_elapsed = time.time() - start_time
    return WSI_object, seg_time_elapsed
33
34
def patching(WSI_object, **kwargs):
    """Run ``WSI_object.process_contours`` and time the call.

    Args:
        WSI_object: Object exposing ``process_contours(**kwargs)``.
        **kwargs: Forwarded unchanged to ``process_contours``.

    Returns:
        Tuple ``(file_path, patch_time_elapsed)``: the value returned by
        ``process_contours`` and the wall-clock duration in seconds.
    """
    ### Start Patch Timer
    start_time = time.time()

    # Patch
    file_path = WSI_object.process_contours(**kwargs)

    ### Stop Patch Timer
    patch_time_elapsed = time.time() - start_time
    return file_path, patch_time_elapsed
45
46
47
def _resolve_auto_level(WSI_object, level, downsample = 64):
    """Return ``level`` unless it is negative, in which case pick one for the slide."""
    if not level < 0:
        return level
    if len(WSI_object.level_dim) == 1:
        # Only one pyramid level is available.
        return 0
    wsi = WSI_object.getOpenSlide()
    return wsi.get_best_level_for_downsample(downsample)


def _parse_contour_ids(raw_ids):
    """Turn a 'keep_ids' / 'exclude_ids' setting into ``[]`` or an int array."""
    # A missing value (None, or NaN from an empty CSV cell) means "no ids".
    if raw_ids is None or (isinstance(raw_ids, float) and np.isnan(raw_ids)):
        return []
    text = str(raw_ids)
    if text == 'none' or len(text) == 0:
        return []
    # Split the stringified value so a numeric CSV cell (e.g. a single id
    # read back as an int) is handled the same way as "1,2,3".
    return np.array(text.split(',')).astype(int)


def seg_and_patch(source, save_dir, patch_save_dir, mask_save_dir, stitch_save_dir, 
                  patch_size = 256, step_size = 256, 
                  seg_params = {'seg_level': -1, 'sthresh': 8, 'mthresh': 7, 'close': 4, 'use_otsu': False,
                  'keep_ids': 'none', 'exclude_ids': 'none'},
                  filter_params = {'a_t':100, 'a_h': 16, 'max_n_holes':8}, 
                  vis_params = {'vis_level': -1, 'line_thickness': 500},
                  patch_params = {'use_padding': True, 'contour_fn': 'four_pt'},
                  patch_level = 0,
                  use_default_params = False, 
                  seg = False, save_mask = True, 
                  stitch= False, 
                  patch = False, auto_skip=True, process_list = None):
    """Segment, patch and stitch every slide flagged for processing.

    The slide table is built by ``initialize_df`` from the files in ``source``
    or, when ``process_list`` is given, from that CSV. Rows whose ``process``
    column equals 1 are handled one by one. The table is written to
    ``<save_dir>/process_list_autogen.csv`` before each slide and once more at
    the end.

    Args:
        source: Directory containing the slide files.
        save_dir: Directory receiving ``process_list_autogen.csv``.
        patch_save_dir: Directory for per-slide ``.h5`` patch files; also
            checked for existing output when ``auto_skip`` is set.
        mask_save_dir: Directory for per-slide ``.jpg`` mask images.
        stitch_save_dir: Directory for per-slide ``.jpg`` stitched images.
        patch_size, step_size, patch_level: Forwarded to ``patching``.
        seg_params, filter_params, vis_params, patch_params: Default
            parameter sets; passed to ``initialize_df`` and used directly
            when ``use_default_params`` is true.
        use_default_params: Use the parameter sets as given instead of the
            per-slide values stored in the table.
        seg: Run segmentation.
        save_mask: Save the ``visWSI`` image of each slide.
        stitch: Stitch from the slide's ``.h5`` file if that file exists.
        patch: Run patching.
        auto_skip: Skip slides whose ``.h5`` file already exists in
            ``patch_save_dir``.
        process_list: Optional path of a CSV describing the slides.

    Returns:
        Tuple ``(seg_times, patch_times)``: totals divided by the number of
        rows flagged for processing (left at 0.0 when there are none).
    """
    # NOTE(review): the dict defaults are shared between calls; this function
    # only reads or copies them, but confirm initialize_df does not mutate them.
    slides = sorted(os.listdir(source))
    slides = [slide for slide in slides if os.path.isfile(os.path.join(source, slide))]
    if process_list is None:
        df = initialize_df(slides, seg_params, filter_params, vis_params, patch_params)
    else:
        df = pd.read_csv(process_list)
        df = initialize_df(df, seg_params, filter_params, vis_params, patch_params)

    to_process = df['process'] == 1
    process_stack = df[to_process]

    total = len(process_stack)

    # A column named 'a' marks a legacy segmentation CSV.
    legacy_support = 'a' in df.keys()
    if legacy_support:
        print('detected legacy segmentation csv file, legacy support enabled')
        df = df.assign(**{'a_t': np.full((len(df)), int(filter_params['a_t']), dtype=np.uint32),
        'a_h': np.full((len(df)), int(filter_params['a_h']), dtype=np.uint32),
        'max_n_holes': np.full((len(df)), int(filter_params['max_n_holes']), dtype=np.uint32),
        'line_thickness': np.full((len(df)), int(vis_params['line_thickness']), dtype=np.uint32),
        'contour_fn': np.full((len(df)), patch_params['contour_fn'])})

    autogen_csv = os.path.join(save_dir, 'process_list_autogen.csv')

    seg_times = 0.
    patch_times = 0.
    stitch_times = 0.

    for i, idx in enumerate(process_stack.index):
        # Persist progress before each slide so an interrupted run can resume.
        df.to_csv(autogen_csv, index=False)
        slide = process_stack.loc[idx, 'slide_id']
        print("\n\nprogress: {:.2f}, {}/{}".format(i/total, i, total))
        print('processing {}'.format(slide))

        df.loc[idx, 'process'] = 0
        slide_id, _ = os.path.splitext(slide)

        if auto_skip and os.path.isfile(os.path.join(patch_save_dir, slide_id + '.h5')):
            print('{} already exist in destination location, skipped'.format(slide_id))
            df.loc[idx, 'status'] = 'already_exist'
            continue

        # Initialize WSI
        full_path = os.path.join(source, slide)
        WSI_object = WholeSlideImage(full_path)

        if use_default_params:
            current_vis_params = vis_params.copy()
            current_filter_params = filter_params.copy()
            current_seg_params = seg_params.copy()
            current_patch_params = patch_params.copy()

        else:
            current_vis_params = {}
            current_filter_params = {}
            current_seg_params = {}
            current_patch_params = {}

            for key in vis_params.keys():
                if legacy_support and key == 'vis_level':
                    df.loc[idx, key] = -1
                current_vis_params[key] = df.loc[idx, key]

            for key in filter_params.keys():
                if legacy_support and key == 'a_t':
                    # Rescale the legacy area threshold 'a' using the
                    # downsample factors of the row's seg_level.
                    old_area = df.loc[idx, 'a']
                    seg_level = df.loc[idx, 'seg_level']
                    scale = WSI_object.level_downsamples[seg_level]
                    adjusted_area = int(old_area * (scale[0] * scale[1]) / (512 * 512))
                    df.loc[idx, key] = adjusted_area
                current_filter_params[key] = df.loc[idx, key]

            for key in seg_params.keys():
                if legacy_support and key == 'seg_level':
                    df.loc[idx, key] = -1
                current_seg_params[key] = df.loc[idx, key]

            for key in patch_params.keys():
                current_patch_params[key] = df.loc[idx, key]

        # Negative levels mean "choose automatically".
        current_vis_params['vis_level'] = _resolve_auto_level(
            WSI_object, current_vis_params['vis_level'])
        current_seg_params['seg_level'] = _resolve_auto_level(
            WSI_object, current_seg_params['seg_level'])

        current_seg_params['keep_ids'] = _parse_contour_ids(current_seg_params['keep_ids'])
        current_seg_params['exclude_ids'] = _parse_contour_ids(current_seg_params['exclude_ids'])

        w, h = WSI_object.level_dim[current_seg_params['seg_level']] 
        if w * h > 1e8:
            print('level_dim {} x {} is likely too large for successful segmentation, aborting'.format(w, h))
            df.loc[idx, 'status'] = 'failed_seg'
            continue

        df.loc[idx, 'vis_level'] = current_vis_params['vis_level']
        df.loc[idx, 'seg_level'] = current_seg_params['seg_level']

        # -1 marks a step that did not run for this slide.
        seg_time_elapsed = -1
        if seg:
            WSI_object, seg_time_elapsed = segment(WSI_object, current_seg_params, current_filter_params) 

        if save_mask:
            mask_img = WSI_object.visWSI(**current_vis_params)
            mask_path = os.path.join(mask_save_dir, slide_id+'.jpg')
            mask_img.save(mask_path)

        patch_time_elapsed = -1 # Default time
        if patch:
            current_patch_params.update({'patch_level': patch_level, 'patch_size': patch_size, 'step_size': step_size, 
                                         'save_path': patch_save_dir})
            _, patch_time_elapsed = patching(WSI_object = WSI_object,  **current_patch_params,)

        stitch_time_elapsed = -1
        if stitch:
            file_path = os.path.join(patch_save_dir, slide_id+'.h5')
            if os.path.isfile(file_path):
                heatmap, stitch_time_elapsed = stitching(file_path, WSI_object, downscale=64)
                stitch_path = os.path.join(stitch_save_dir, slide_id+'.jpg')
                heatmap.save(stitch_path)

        print("segmentation took {} seconds".format(seg_time_elapsed))
        print("patching took {} seconds".format(patch_time_elapsed))
        print("stitching took {} seconds".format(stitch_time_elapsed))
        df.loc[idx, 'status'] = 'processed'

        # NOTE(review): the -1 "not run" sentinels are summed as well, and
        # skipped slides still count towards `total`; kept for compatibility.
        seg_times += seg_time_elapsed
        patch_times += patch_time_elapsed
        stitch_times += stitch_time_elapsed

    # Guard against an empty work list, which used to raise ZeroDivisionError.
    if total > 0:
        seg_times /= total
        patch_times /= total
        stitch_times /= total

    df.to_csv(autogen_csv, index=False)
    print("average segmentation time in s per slide: {}".format(seg_times))
    print("average patching time in s per slide: {}".format(patch_times))
    print("average stiching time in s per slide: {}".format(stitch_times))

    return seg_times, patch_times
228
229
# Command-line interface; `parser` stays at module level so the entry point
# below (and any importer) can use it.
parser = argparse.ArgumentParser(description='seg and patch')
parser.add_argument('--source', type = str,
                    help='path to folder containing raw wsi image files')
parser.add_argument('--step_size', type = int, default=256,
                    help='step_size')
parser.add_argument('--patch_size', type = int, default=256,
                    help='patch_size')
parser.add_argument('--patch', default=False, action='store_true',
                    help='run the patching step')
parser.add_argument('--seg', default=False, action='store_true',
                    help='run the segmentation step')
parser.add_argument('--stitch', default=False, action='store_true',
                    help='run the stitching step')
# The stored value is True unless the flag is given; the entry point passes
# it on as `auto_skip`.
parser.add_argument('--no_auto_skip', default=True, action='store_false',
                    help='do not skip slides whose .h5 output already exists')
parser.add_argument('--save_dir', type = str,
                    help='directory to save processed data')
parser.add_argument('--preset', default=None, type=str,
                    help='predefined profile of default segmentation and filter parameters (.csv)')
parser.add_argument('--patch_level', type=int, default=0, 
                    help='downsample level at which to patch')
parser.add_argument('--process_list',  type = str, default=None,
                    help='name of list of images to process with parameters (.csv)')
248
249
if __name__ == '__main__':
    # Script entry point: parse the CLI, prepare output folders, assemble the
    # parameter sets (optionally overridden by a preset) and run the pipeline.
    args = parser.parse_args()

    # Both paths are joined/listed below; fail with a clear usage error
    # instead of a TypeError from os.path.join(None, ...).
    if args.source is None or args.save_dir is None:
        parser.error('--source and --save_dir are required')

    patch_save_dir = os.path.join(args.save_dir, 'patches')
    mask_save_dir = os.path.join(args.save_dir, 'masks')
    stitch_save_dir = os.path.join(args.save_dir, 'stitches')

    # The process list, when given, is looked up inside save_dir.
    if args.process_list:
        process_list = os.path.join(args.save_dir, args.process_list)
    else:
        process_list = None

    print('source: ', args.source)
    print('patch_save_dir: ', patch_save_dir)
    print('mask_save_dir: ', mask_save_dir)
    print('stitch_save_dir: ', stitch_save_dir)

    directories = {'source': args.source, 
                   'save_dir': args.save_dir,
                   'patch_save_dir': patch_save_dir, 
                   'mask_save_dir' : mask_save_dir, 
                   'stitch_save_dir': stitch_save_dir} 

    for key, val in directories.items():
        print("{} : {}".format(key, val))
        # Every directory except the input one is created on demand.
        if key not in ['source']:
            os.makedirs(val, exist_ok=True)

    seg_params = {'seg_level': -1, 'sthresh': 8, 'mthresh': 7, 'close': 4, 'use_otsu': False,
                  'keep_ids': 'none', 'exclude_ids': 'none'}
    filter_params = {'a_t':100, 'a_h': 16, 'max_n_holes':8}
    vis_params = {'vis_level': -1, 'line_thickness': 250}
    patch_params = {'use_padding': True, 'contour_fn': 'four_pt'}

    if args.preset:
        # Row 0 of the preset CSV overrides every default listed above.
        preset_df = pd.read_csv(os.path.join('presets', args.preset))
        for params in (seg_params, filter_params, vis_params, patch_params):
            for key in params:
                params[key] = preset_df.loc[0, key]

    parameters = {'seg_params': seg_params,
                  'filter_params': filter_params,
                  'patch_params': patch_params,
                  'vis_params': vis_params}

    print(parameters)

    seg_times, patch_times = seg_and_patch(**directories, **parameters,
                                            patch_size = args.patch_size, step_size=args.step_size, 
                                            seg = args.seg,  use_default_params=False, save_mask = True, 
                                            stitch= args.stitch,
                                            patch_level=args.patch_level, patch = args.patch,
                                            process_list = process_list, auto_skip=args.no_auto_skip)