# internal imports
from wsi_core.WholeSlideImage import WholeSlideImage
from wsi_core.wsi_utils import StitchCoords
from wsi_core.batch_process_utils import initialize_df
# other imports
import os
import numpy as np
import time
import argparse
import pdb
import pandas as pd
def stitching(file_path, wsi_object, downscale = 64):
start = time.time()
heatmap = StitchCoords(file_path, wsi_object, downscale=downscale, bg_color=(0,0,0), alpha=-1, draw_grid=False)
total_time = time.time() - start
return heatmap, total_time
def segment(WSI_object, seg_params = None, filter_params = None, mask_file = None):
### Start Seg Timer
start_time = time.time()
# Use segmentation file
if mask_file is not None:
WSI_object.initSegmentation(mask_file)
# Segment
else:
WSI_object.segmentTissue(**seg_params, filter_params=filter_params)
### Stop Seg Timers
seg_time_elapsed = time.time() - start_time
return WSI_object, seg_time_elapsed
def patching(WSI_object, **kwargs):
### Start Patch Timer
start_time = time.time()
# Patch
file_path = WSI_object.process_contours(**kwargs)
### Stop Patch Timer
patch_time_elapsed = time.time() - start_time
return file_path, patch_time_elapsed
def seg_and_patch(source, save_dir, patch_save_dir, mask_save_dir, stitch_save_dir,
patch_size = 256, step_size = 256,
seg_params = {'seg_level': -1, 'sthresh': 8, 'mthresh': 7, 'close': 4, 'use_otsu': False,
'keep_ids': 'none', 'exclude_ids': 'none'},
filter_params = {'a_t':100, 'a_h': 16, 'max_n_holes':8},
vis_params = {'vis_level': -1, 'line_thickness': 500},
patch_params = {'use_padding': True, 'contour_fn': 'four_pt'},
patch_level = 0,
use_default_params = False,
seg = False, save_mask = True,
stitch= False,
patch = False, auto_skip=True, process_list = None):
slides = sorted(os.listdir(source))
slides = [slide for slide in slides if os.path.isfile(os.path.join(source, slide))]
if process_list is None:
df = initialize_df(slides, seg_params, filter_params, vis_params, patch_params)
else:
df = pd.read_csv(process_list)
df = initialize_df(df, seg_params, filter_params, vis_params, patch_params)
mask = df['process'] == 1
process_stack = df[mask]
total = len(process_stack)
legacy_support = 'a' in df.keys()
if legacy_support:
print('detected legacy segmentation csv file, legacy support enabled')
df = df.assign(**{'a_t': np.full((len(df)), int(filter_params['a_t']), dtype=np.uint32),
'a_h': np.full((len(df)), int(filter_params['a_h']), dtype=np.uint32),
'max_n_holes': np.full((len(df)), int(filter_params['max_n_holes']), dtype=np.uint32),
'line_thickness': np.full((len(df)), int(vis_params['line_thickness']), dtype=np.uint32),
'contour_fn': np.full((len(df)), patch_params['contour_fn'])})
seg_times = 0.
patch_times = 0.
stitch_times = 0.
for i in range(total):
df.to_csv(os.path.join(save_dir, 'process_list_autogen.csv'), index=False)
idx = process_stack.index[i]
slide = process_stack.loc[idx, 'slide_id']
print("\n\nprogress: {:.2f}, {}/{}".format(i/total, i, total))
print('processing {}'.format(slide))
df.loc[idx, 'process'] = 0
slide_id, _ = os.path.splitext(slide)
if auto_skip and os.path.isfile(os.path.join(patch_save_dir, slide_id + '.h5')):
print('{} already exist in destination location, skipped'.format(slide_id))
df.loc[idx, 'status'] = 'already_exist'
continue
# Inialize WSI
full_path = os.path.join(source, slide)
WSI_object = WholeSlideImage(full_path)
if use_default_params:
current_vis_params = vis_params.copy()
current_filter_params = filter_params.copy()
current_seg_params = seg_params.copy()
current_patch_params = patch_params.copy()
else:
current_vis_params = {}
current_filter_params = {}
current_seg_params = {}
current_patch_params = {}
for key in vis_params.keys():
if legacy_support and key == 'vis_level':
df.loc[idx, key] = -1
current_vis_params.update({key: df.loc[idx, key]})
for key in filter_params.keys():
if legacy_support and key == 'a_t':
old_area = df.loc[idx, 'a']
seg_level = df.loc[idx, 'seg_level']
scale = WSI_object.level_downsamples[seg_level]
adjusted_area = int(old_area * (scale[0] * scale[1]) / (512 * 512))
current_filter_params.update({key: adjusted_area})
df.loc[idx, key] = adjusted_area
current_filter_params.update({key: df.loc[idx, key]})
for key in seg_params.keys():
if legacy_support and key == 'seg_level':
df.loc[idx, key] = -1
current_seg_params.update({key: df.loc[idx, key]})
for key in patch_params.keys():
current_patch_params.update({key: df.loc[idx, key]})
if current_vis_params['vis_level'] < 0:
if len(WSI_object.level_dim) == 1:
current_vis_params['vis_level'] = 0
else:
wsi = WSI_object.getOpenSlide()
best_level = wsi.get_best_level_for_downsample(64)
current_vis_params['vis_level'] = best_level
if current_seg_params['seg_level'] < 0:
if len(WSI_object.level_dim) == 1:
current_seg_params['seg_level'] = 0
else:
wsi = WSI_object.getOpenSlide()
best_level = wsi.get_best_level_for_downsample(64)
current_seg_params['seg_level'] = best_level
keep_ids = str(current_seg_params['keep_ids'])
if keep_ids != 'none' and len(keep_ids) > 0:
str_ids = current_seg_params['keep_ids']
current_seg_params['keep_ids'] = np.array(str_ids.split(',')).astype(int)
else:
current_seg_params['keep_ids'] = []
exclude_ids = str(current_seg_params['exclude_ids'])
if exclude_ids != 'none' and len(exclude_ids) > 0:
str_ids = current_seg_params['exclude_ids']
current_seg_params['exclude_ids'] = np.array(str_ids.split(',')).astype(int)
else:
current_seg_params['exclude_ids'] = []
w, h = WSI_object.level_dim[current_seg_params['seg_level']]
if w * h > 1e8:
print('level_dim {} x {} is likely too large for successful segmentation, aborting'.format(w, h))
df.loc[idx, 'status'] = 'failed_seg'
continue
df.loc[idx, 'vis_level'] = current_vis_params['vis_level']
df.loc[idx, 'seg_level'] = current_seg_params['seg_level']
seg_time_elapsed = -1
if seg:
WSI_object, seg_time_elapsed = segment(WSI_object, current_seg_params, current_filter_params)
if save_mask:
mask = WSI_object.visWSI(**current_vis_params)
mask_path = os.path.join(mask_save_dir, slide_id+'.jpg')
mask.save(mask_path)
patch_time_elapsed = -1 # Default time
if patch:
current_patch_params.update({'patch_level': patch_level, 'patch_size': patch_size, 'step_size': step_size,
'save_path': patch_save_dir})
file_path, patch_time_elapsed = patching(WSI_object = WSI_object, **current_patch_params,)
stitch_time_elapsed = -1
if stitch:
file_path = os.path.join(patch_save_dir, slide_id+'.h5')
if os.path.isfile(file_path):
heatmap, stitch_time_elapsed = stitching(file_path, WSI_object, downscale=64)
stitch_path = os.path.join(stitch_save_dir, slide_id+'.jpg')
heatmap.save(stitch_path)
print("segmentation took {} seconds".format(seg_time_elapsed))
print("patching took {} seconds".format(patch_time_elapsed))
print("stitching took {} seconds".format(stitch_time_elapsed))
df.loc[idx, 'status'] = 'processed'
seg_times += seg_time_elapsed
patch_times += patch_time_elapsed
stitch_times += stitch_time_elapsed
seg_times /= total
patch_times /= total
stitch_times /= total
df.to_csv(os.path.join(save_dir, 'process_list_autogen.csv'), index=False)
print("average segmentation time in s per slide: {}".format(seg_times))
print("average patching time in s per slide: {}".format(patch_times))
print("average stiching time in s per slide: {}".format(stitch_times))
return seg_times, patch_times
parser = argparse.ArgumentParser(description='seg and patch')
parser.add_argument('--source', type = str,
help='path to folder containing raw wsi image files')
parser.add_argument('--step_size', type = int, default=256,
help='step_size')
parser.add_argument('--patch_size', type = int, default=256,
help='patch_size')
parser.add_argument('--patch', default=False, action='store_true')
parser.add_argument('--seg', default=False, action='store_true')
parser.add_argument('--stitch', default=False, action='store_true')
parser.add_argument('--no_auto_skip', default=True, action='store_false')
parser.add_argument('--save_dir', type = str,
help='directory to save processed data')
parser.add_argument('--preset', default=None, type=str,
help='predefined profile of default segmentation and filter parameters (.csv)')
parser.add_argument('--patch_level', type=int, default=0,
help='downsample level at which to patch')
parser.add_argument('--process_list', type = str, default=None,
help='name of list of images to process with parameters (.csv)')
if __name__ == '__main__':
args = parser.parse_args()
patch_save_dir = os.path.join(args.save_dir, 'patches')
mask_save_dir = os.path.join(args.save_dir, 'masks')
stitch_save_dir = os.path.join(args.save_dir, 'stitches')
if args.process_list:
process_list = os.path.join(args.save_dir, args.process_list)
else:
process_list = None
print('source: ', args.source)
print('patch_save_dir: ', patch_save_dir)
print('mask_save_dir: ', mask_save_dir)
print('stitch_save_dir: ', stitch_save_dir)
directories = {'source': args.source,
'save_dir': args.save_dir,
'patch_save_dir': patch_save_dir,
'mask_save_dir' : mask_save_dir,
'stitch_save_dir': stitch_save_dir}
for key, val in directories.items():
print("{} : {}".format(key, val))
if key not in ['source']:
os.makedirs(val, exist_ok=True)
seg_params = {'seg_level': -1, 'sthresh': 8, 'mthresh': 7, 'close': 4, 'use_otsu': False,
'keep_ids': 'none', 'exclude_ids': 'none'}
filter_params = {'a_t':100, 'a_h': 16, 'max_n_holes':8}
vis_params = {'vis_level': -1, 'line_thickness': 250}
patch_params = {'use_padding': True, 'contour_fn': 'four_pt'}
if args.preset:
preset_df = pd.read_csv(os.path.join('presets', args.preset))
for key in seg_params.keys():
seg_params[key] = preset_df.loc[0, key]
for key in filter_params.keys():
filter_params[key] = preset_df.loc[0, key]
for key in vis_params.keys():
vis_params[key] = preset_df.loc[0, key]
for key in patch_params.keys():
patch_params[key] = preset_df.loc[0, key]
parameters = {'seg_params': seg_params,
'filter_params': filter_params,
'patch_params': patch_params,
'vis_params': vis_params}
print(parameters)
seg_times, patch_times = seg_and_patch(**directories, **parameters,
patch_size = args.patch_size, step_size=args.step_size,
seg = args.seg, use_default_params=False, save_mask = True,
stitch= args.stitch,
patch_level=args.patch_level, patch = args.patch,
process_list = process_list, auto_skip=args.no_auto_skip)