Switch to unified view

a b/create_heatmaps_survival.py
1
from __future__ import print_function
2
3
import numpy as np
4
5
import argparse
6
7
import torch
8
import torch.nn as nn
9
import pdb
10
import os
11
import pandas as pd
12
from utils.utils import *
13
from math import floor
14
from utils.eval_utils_survival import initiate_model as initiate_model
15
from models.model_amil import AMIL
16
from models.resnet_custom import resnet50_baseline
17
from types import SimpleNamespace
18
from collections import namedtuple
19
import h5py
20
import yaml
21
from wsi_core.batch_process_utils import initialize_df
22
from vis_utils.heatmap_utils import initialize_wsi, drawHeatmap, compute_from_patches
23
from wsi_core.wsi_utils import sample_rois
24
from utils.file_utils import save_hdf5
25
26
# Command-line interface. Every option is optional; values left at None do
# not override the YAML configuration (see parse_config_dict).
parser = argparse.ArgumentParser(description='Heatmap inference script')
for _flag, _options in (
    ('--save_exp_code', dict(type=str, default=None, help='experiment code')),
    ('--overlap', dict(type=float, default=None)),
    ('--config_file', dict(type=str, default="heatmap_config_template.yaml")),
):
    parser.add_argument(_flag, **_options)
# Parsed at import time; the resulting namespace is read by the __main__ block.
args = parser.parse_args()
32
33
def infer_single_slide(model, features, label, reverse_label_dict, k=1):
    """Run the survival model once on a slide's bag of patch features.

    Args:
        model: Trained model; only ``AMIL`` instances are supported.
        features: Tensor of patch features; it is moved to the module-level
            ``device`` before inference.
        label: Unused; kept so existing callers keep working.
        reverse_label_dict: Unused; kept so existing callers keep working.
        k: Unused; kept so existing callers keep working.

    Returns:
        tuple: ``(risk, A)`` where ``risk`` is the Python scalar obtained via
        ``.item()`` from the model's first output and ``A`` is the model's
        second output reshaped to ``(-1, 1)`` as a NumPy array on the CPU.

    Raises:
        NotImplementedError: If ``model`` is not an ``AMIL`` instance.
    """
    if not isinstance(model, AMIL):
        raise NotImplementedError

    features = features.to(device)
    with torch.no_grad():
        # A single forward pass is enough: the previous implementation called
        # the model twice and threw the first result away.
        risk, A, _ = model(features)

    risk = risk.item()
    A = A.view(-1, 1).cpu().numpy()
    return risk, A
47
48
def load_params(df_entry, params):
    """Override default parameters with the values found in one dataframe row.

    For every key of ``params`` that is also present in ``df_entry.index``,
    the row value is cast to the type of the current default and stored in
    ``params``. Missing cells (NaN / None) and empty strings leave the default
    untouched.

    Args:
        df_entry: A pandas Series (one row of the process list).
        params: Dict of default values; it is updated in place.

    Returns:
        dict: The same ``params`` object, after the overrides were applied.
    """
    for key in params.keys():
        if key not in df_entry.index:
            continue

        val = df_entry[key]
        # Check for a missing cell *before* casting: int(NaN) raises
        # ValueError, str(NaN) would yield the bogus string 'nan', and
        # bool(NaN) is True. Keeping the default is the safe outcome and
        # replaces the pdb.set_trace() that used to stall batch runs here.
        if pd.isna(val):
            continue

        dtype = type(params[key])
        val = dtype(val)
        if isinstance(val, str):
            # An empty string means "not specified" for string parameters.
            if len(val) > 0:
                params[key] = val
        else:
            params[key] = val

    return params
63
64
def parse_config_dict(args, config_dict):
    """Apply command-line overrides on top of the loaded YAML configuration.

    Args:
        args: Parsed CLI namespace exposing ``save_exp_code`` and ``overlap``.
        config_dict: Nested configuration dict; it is modified in place.

    Returns:
        dict: The same ``config_dict`` object, with every CLI value that is
        not None written into its section.
    """
    # (config section, option name) pairs; the option name doubles as the
    # attribute name on the CLI namespace.
    cli_overrides = (
        ('exp_arguments', 'save_exp_code'),
        ('patching_arguments', 'overlap'),
    )
    for section, option in cli_overrides:
        cli_value = getattr(args, option)
        if cli_value is None:
            continue
        config_dict[section][option] = cli_value
    return config_dict
70
71
def _read_attention_h5(h5_file_path):
    """Return the 'attention_scores' and 'coords' datasets of an HDF5 file.

    The file is opened read-only and is closed before returning, including
    when one of the datasets is missing.
    """
    with h5py.File(h5_file_path, 'r') as h5_file:
        return h5_file['attention_scores'][:], h5_file['coords'][:]


if __name__ == '__main__':
    # Script entry point: load the YAML config, confirm with the user, then
    # for every slide marked for processing compute features, risk, attention
    # scores and the requested heatmap images.
    config_path = os.path.join('heatmaps/configs', args.config_file)
    # Context manager so the config file handle is closed deterministically.
    with open(config_path, 'r') as config_file:
        config_dict = yaml.safe_load(config_file)
    config_dict = parse_config_dict(args, config_dict)

    # Echo the effective configuration before asking for confirmation.
    for key, value in config_dict.items():
        if isinstance(value, dict):
            print('\n'+key)
            for value_key, value_value in value.items():
                print (value_key + " : " + str(value_value))
        else:
            print ('\n'+key + " : " + str(value))

    # Case-insensitive answer handling; previously a lower-case "no" fell
    # through to NotImplementedError because only 'No'/'NO' were listed.
    decision = input('Continue? Y/N ').strip().lower()
    if decision in ('y', 'yes'):
        pass
    elif decision in ('n', 'no'):
        raise SystemExit
    else:
        raise NotImplementedError

    patch_args = argparse.Namespace(**config_dict['patching_arguments'])
    data_args = argparse.Namespace(**config_dict['data_arguments'])
    # The update mutates config_dict['model_arguments'] in place, so n_classes
    # also ends up in the config.yaml written at the end of the run.
    model_args = config_dict['model_arguments']
    model_args.update({'n_classes': config_dict['exp_arguments']['n_classes']})
    model_args = argparse.Namespace(**model_args)
    exp_args = argparse.Namespace(**config_dict['exp_arguments'])
    heatmap_args = argparse.Namespace(**config_dict['heatmap_arguments'])
    sample_args = argparse.Namespace(**config_dict['sample_arguments'])

    patch_size = (patch_args.patch_size,) * 2
    step_size = tuple((np.array(patch_size) * (1-patch_args.overlap)).astype(int))
    print('patch_size: {} x {}, with {:.2f} overlap, step size is {} x {}'.format(patch_size[0], patch_size[1], patch_args.overlap, step_size[0], step_size[1]))

    preset = data_args.preset
    def_seg_params = {'seg_level': -1, 'sthresh': 15, 'mthresh': 11, 'close': 2, 'use_otsu': False,
                      'keep_ids': 'none', 'exclude_ids':'none'}
    def_filter_params = {'a_t':50.0, 'a_h': 8.0, 'max_n_holes':10}
    def_vis_params = {'vis_level': -1, 'line_thickness': 250}
    def_patch_params = {'use_padding': True, 'contour_fn': 'four_pt'}

    if preset is not None:
        # The first row of the preset CSV replaces every built-in default.
        preset_df = pd.read_csv(preset)
        for defaults in (def_seg_params, def_filter_params, def_vis_params, def_patch_params):
            for key in defaults:
                defaults[key] = preset_df.loc[0, key]

    if data_args.process_list is None:
        if isinstance(data_args.data_dir, list):
            slides = []
            for data_dir in data_args.data_dir:
                slides.extend(os.listdir(data_dir))
        else:
            slides = sorted(os.listdir(data_args.data_dir))
        slides = [slide for slide in slides if data_args.slide_ext in slide]
        df = initialize_df(slides, def_seg_params, def_filter_params, def_vis_params, def_patch_params, use_heatmap_args=False)

    else:
        df = pd.read_csv(os.path.join('heatmaps/process_lists', data_args.process_list))
        df = initialize_df(df, def_seg_params, def_filter_params, def_vis_params, def_patch_params, use_heatmap_args=False)

    # Only rows flagged with process == 1 are handled; the index is reset so
    # that .loc[i, ...] below can use positional integers.
    process_stack = df[df['process'] == 1].reset_index(drop=True)
    print('\nlist of slides to process: ')
    print(process_stack.head(len(process_stack)))

    print('\ninitializing model from checkpoint')
    ckpt_path = model_args.ckpt_path
    print('\nckpt path: {}'.format(ckpt_path))

    if model_args.initiate_fn == 'initiate_model':
        settings = {
            'drop_out': model_args.drop_out,
            'model_type': model_args.model_type,
            'model_size': model_args.model_size,
            'print_model_info': False
            }
        model = initiate_model(settings, ckpt_path)
    else:
        raise NotImplementedError

    feature_extractor = resnet50_baseline(pretrained=True)
    feature_extractor.eval()
    # `device` is read as a module-level global by infer_single_slide.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('Done!')

    # Map label encodings back to their names (used for output folder names).
    label_dict = data_args.label_dict
    reverse_label_dict = {encoding: name for name, encoding in label_dict.items()}

    if torch.cuda.device_count() > 1:
        device_ids = list(range(torch.cuda.device_count()))
        feature_extractor = nn.DataParallel(feature_extractor, device_ids=device_ids).to('cuda:0')
    else:
        feature_extractor = feature_extractor.to(device)

    os.makedirs(exp_args.production_save_dir, exist_ok=True)
    os.makedirs(exp_args.raw_save_dir, exist_ok=True)
    # Non-overlapping tiling (step == patch size) over the whole slide.
    blocky_wsi_kwargs = {'top_left': None, 'bot_right': None, 'patch_size': patch_size, 'step_size': patch_size,
    'custom_downsample':patch_args.custom_downsample, 'level': patch_args.patch_level, 'use_center_shift': heatmap_args.use_center_shift}

    for i in range(len(process_stack)):
        slide_name = process_stack.loc[i, 'slide_id']
        if data_args.slide_ext not in slide_name:
            slide_name+=data_args.slide_ext
        print('\nprocessing: ', slide_name)

        try:
            label = process_stack.loc[i, 'label']
        except KeyError:
            label = 'Unspecified'

        slide_id = slide_name.replace(data_args.slide_ext, '')

        if not isinstance(label, str):
            grouping = reverse_label_dict[label]
        else:
            grouping = label

        p_slide_save_dir = os.path.join(exp_args.production_save_dir, exp_args.save_exp_code, str(grouping))
        os.makedirs(p_slide_save_dir, exist_ok=True)

        r_slide_save_dir = os.path.join(exp_args.raw_save_dir, exp_args.save_exp_code, str(grouping),  slide_id)
        os.makedirs(r_slide_save_dir, exist_ok=True)

        if heatmap_args.use_roi:
            x1, x2 = process_stack.loc[i, 'x1'], process_stack.loc[i, 'x2']
            y1, y2 = process_stack.loc[i, 'y1'], process_stack.loc[i, 'y2']
            top_left = (int(x1), int(y1))
            bot_right = (int(x2), int(y2))
        else:
            top_left = None
            bot_right = None

        print('slide id: ', slide_id)
        print('top left: ', top_left, ' bot right: ', bot_right)

        # NOTE(review): a list-valued data_dir is accepted when listing slides
        # above but is not supported here — confirm whether that is intended.
        if isinstance(data_args.data_dir, str):
            slide_path = os.path.join(data_args.data_dir, slide_name)
        elif isinstance(data_args.data_dir, dict):
            data_dir_key = process_stack.loc[i, data_args.data_dir_key]
            slide_path = os.path.join(data_args.data_dir[data_dir_key], slide_name)
        else:
            raise NotImplementedError

        mask_file = os.path.join(r_slide_save_dir, slide_id+'_mask.pkl')

        # Load segmentation and filter parameters
        seg_params = def_seg_params.copy()
        filter_params = def_filter_params.copy()
        vis_params = def_vis_params.copy()

        seg_params = load_params(process_stack.loc[i], seg_params)
        filter_params = load_params(process_stack.loc[i], filter_params)
        vis_params = load_params(process_stack.loc[i], vis_params)

        # 'keep_ids' / 'exclude_ids' are comma-separated id strings; 'none' or
        # an empty string means no ids.
        keep_ids = str(seg_params['keep_ids'])
        if len(keep_ids) > 0 and keep_ids != 'none':
            seg_params['keep_ids'] = np.array(keep_ids.split(',')).astype(int)
        else:
            seg_params['keep_ids'] = []

        exclude_ids = str(seg_params['exclude_ids'])
        if len(exclude_ids) > 0 and exclude_ids != 'none':
            seg_params['exclude_ids'] = np.array(exclude_ids.split(',')).astype(int)
        else:
            seg_params['exclude_ids'] = []

        for param_group in (seg_params, filter_params, vis_params):
            for key, val in param_group.items():
                print('{}: {}'.format(key, val))

        print('Initializing WSI object')
        wsi_object = initialize_wsi(slide_path, seg_mask_path=mask_file, seg_params=seg_params, filter_params=filter_params)
        print('Done!')

        wsi_ref_downsample = wsi_object.level_downsamples[patch_args.patch_level]

        # the actual patch size for heatmap visualization should be the patch size * downsample factor * custom downsample factor
        vis_patch_size = tuple((np.array(patch_size) * np.array(wsi_ref_downsample) * patch_args.custom_downsample).astype(int))

        block_map_save_path = os.path.join(r_slide_save_dir, '{}_blockmap.h5'.format(slide_id))
        mask_path = os.path.join(r_slide_save_dir, '{}_mask.jpg'.format(slide_id))
        if vis_params['vis_level'] < 0:
            best_level = wsi_object.wsi.get_best_level_for_downsample(32)
            vis_params['vis_level'] = best_level
        seg_mask_image = wsi_object.visWSI(**vis_params, number_contours=True)
        seg_mask_image.save(mask_path)

        features_path = os.path.join(r_slide_save_dir, slide_id+'.pt')
        h5_path = os.path.join(r_slide_save_dir, slide_id+'.h5')

        ##### check if h5_features_file exists ######
        if not os.path.isfile(h5_path) :
            _, _, wsi_object = compute_from_patches(wsi_object=wsi_object,
                                            model=model,
                                            feature_extractor=feature_extractor,
                                            batch_size=exp_args.batch_size, **blocky_wsi_kwargs,
                                            attn_save_path=None, feat_save_path=h5_path,
                                            ref_scores=None)

        ##### check if pt_features_file exists ######
        if not os.path.isfile(features_path):
            with h5py.File(h5_path, "r") as h5_file:
                features = torch.tensor(h5_file['features'][:])
            torch.save(features, features_path)

        # load features
        features = torch.load(features_path)
        process_stack.loc[i, 'bag_size'] = len(features)

        wsi_object.saveSegmentation(mask_file)
        risk, A = infer_single_slide(model, features, label, reverse_label_dict, exp_args.n_classes)
        del features

        if not os.path.isfile(block_map_save_path):
            with h5py.File(h5_path, "r") as h5_file:
                coords = h5_file['coords'][:]
            asset_dict = {'attention_scores': A, 'coords': coords}
            # NOTE(review): relies on save_hdf5 returning the written path — confirm.
            block_map_save_path = save_hdf5(block_map_save_path, asset_dict, mode='w')

        # save predictions
        process_stack.loc[i, 'risk'] = risk

        # The results CSV is rewritten after every slide.
        os.makedirs('heatmaps/results/', exist_ok=True)
        if data_args.process_list is not None:
            process_stack.to_csv('heatmaps/results/{}.csv'.format(data_args.process_list.replace('.csv', '')), index=False)
        else:
            process_stack.to_csv('heatmaps/results/{}.csv'.format(exp_args.save_exp_code), index=False)

        scores, coords = _read_attention_h5(block_map_save_path)

        samples = sample_args.samples
        for sample in samples:
            if sample['sample']:
                tag = "label_{}_risk_{}".format(label, risk)
                sample_save_dir =  os.path.join(exp_args.production_save_dir, exp_args.save_exp_code, 'sampled_patches', str(tag), sample['name'])
                os.makedirs(sample_save_dir, exist_ok=True)
                print('sampling {}'.format(sample['name']))
                sample_results = sample_rois(scores, coords, k=sample['k'], mode=sample['mode'], seed=sample['seed'],
                    score_start=sample.get('score_start', 0), score_end=sample.get('score_end', 1))
                for idx, (s_coord, s_score) in enumerate(zip(sample_results['sampled_coords'], sample_results['sampled_scores'])):
                    print('coord: {} score: {:.3f}'.format(s_coord, s_score))
                    patch = wsi_object.wsi.read_region(tuple(s_coord), patch_args.patch_level, (patch_args.patch_size, patch_args.patch_size)).convert('RGB')
                    patch.save(os.path.join(sample_save_dir, '{}_{}_x_{}_y_{}_a_{:.3f}.png'.format(idx, slide_id, s_coord[0], s_coord[1], s_score)))

        wsi_kwargs = {'top_left': top_left, 'bot_right': bot_right, 'patch_size': patch_size, 'step_size': step_size,
        'custom_downsample':patch_args.custom_downsample, 'level': patch_args.patch_level, 'use_center_shift': heatmap_args.use_center_shift}

        # Check for the file name that is actually written. The old check
        # looked for '<slide>_blockmap.tiff' while saving '<slide>_blockmap.png',
        # so the blockmap was redrawn on every run.
        blockmap_image_path = os.path.join(r_slide_save_dir, '{}_blockmap.png'.format(slide_id))
        if not os.path.isfile(blockmap_image_path):
            heatmap = drawHeatmap(scores, coords, slide_path, wsi_object=wsi_object, cmap=heatmap_args.cmap, alpha=heatmap_args.alpha, use_holes=True, binarize=False, vis_level=-1, blank_canvas=False,
                            thresh=-1, patch_size = vis_patch_size, convert_to_percentiles=True)
            heatmap.save(blockmap_image_path)
            del heatmap

        save_path = os.path.join(r_slide_save_dir, '{}_{}_roi_{}.h5'.format(slide_id, patch_args.overlap, heatmap_args.use_roi))

        if heatmap_args.use_ref_scores:
            ref_scores = scores
        else:
            ref_scores = None

        if heatmap_args.calc_heatmap:
            compute_from_patches(wsi_object=wsi_object, clam_pred=None, model=model, feature_extractor=feature_extractor, batch_size=exp_args.batch_size, **wsi_kwargs,
                                attn_save_path=save_path,  ref_scores=ref_scores)

        if not os.path.isfile(save_path):
            print('heatmap {} not found'.format(save_path))
            if heatmap_args.use_roi:
                save_path_full = os.path.join(r_slide_save_dir, '{}_{}_roi_False.h5'.format(slide_id, patch_args.overlap))
                # Verify the whole-slide fallback exists before claiming it
                # was found; otherwise skip the slide instead of crashing.
                if not os.path.isfile(save_path_full):
                    print('heatmap {} not found'.format(save_path_full))
                    continue
                print('found heatmap for whole slide')
                save_path = save_path_full
            else:
                continue

        scores, coords = _read_attention_h5(save_path)

        heatmap_vis_args = {'convert_to_percentiles': True, 'vis_level': heatmap_args.vis_level, 'blur': heatmap_args.blur, 'custom_downsample': heatmap_args.custom_downsample}
        if heatmap_args.use_ref_scores:
            heatmap_vis_args['convert_to_percentiles'] = False

        # JPEG output is written at maximum quality; other formats use defaults.
        image_save_kwargs = {'quality': 100} if heatmap_args.save_ext == 'jpg' else {}

        heatmap_save_name = '{}_heatmap_risk{:.4f}_label{}.{}'.format(slide_id, risk, label, heatmap_args.save_ext)
        if not os.path.isfile(os.path.join(p_slide_save_dir, heatmap_save_name)):
            heatmap = drawHeatmap(scores, coords, slide_path, wsi_object=wsi_object,
                                  cmap=heatmap_args.cmap, alpha=heatmap_args.alpha, **heatmap_vis_args,
                                  binarize=heatmap_args.binarize,
                                  blank_canvas=heatmap_args.blank_canvas,
                                  thresh=heatmap_args.binary_thresh,  patch_size = vis_patch_size,
                                  overlap=patch_args.overlap,
                                  top_left=top_left, bot_right = bot_right)
            heatmap.save(os.path.join(p_slide_save_dir, heatmap_save_name), **image_save_kwargs)

        if heatmap_args.save_orig:
            if heatmap_args.vis_level >= 0:
                vis_level = heatmap_args.vis_level
            else:
                vis_level = vis_params['vis_level']
            heatmap_save_name = '{}_orig_{}.{}'.format(slide_id,int(vis_level), heatmap_args.save_ext)
            if not os.path.isfile(os.path.join(p_slide_save_dir, heatmap_save_name)):
                heatmap = wsi_object.visWSI(vis_level=vis_level, view_slide_only=True, custom_downsample=heatmap_args.custom_downsample)
                heatmap.save(os.path.join(p_slide_save_dir, heatmap_save_name), **image_save_kwargs)

    # The experiment directory is normally created inside the slide loop; make
    # sure it exists so the dump also works when no slide was processed.
    config_save_dir = os.path.join(exp_args.raw_save_dir, exp_args.save_exp_code)
    os.makedirs(config_save_dir, exist_ok=True)
    with open(os.path.join(config_save_dir, 'config.yaml'), 'w') as outfile:
        yaml.dump(config_dict, outfile, default_flow_style=False)
430
431