Diff of /create_patches_fp.py [000000] .. [0fdc30]

Switch to side-by-side view

--- a
+++ b/create_patches_fp.py
@@ -0,0 +1,310 @@
+# internal imports
+from wsi_core.WholeSlideImage import WholeSlideImage
+from wsi_core.wsi_utils import StitchCoords
+from wsi_core.batch_process_utils import initialize_df
+# other imports
+import os
+import numpy as np
+import time
+import argparse
+import pdb
+import pandas as pd
+
+def stitching(file_path, wsi_object, downscale = 64):
+	start = time.time()
+	heatmap = StitchCoords(file_path, wsi_object, downscale=downscale, bg_color=(0,0,0), alpha=-1, draw_grid=False)
+	total_time = time.time() - start
+	
+	return heatmap, total_time
+
+def segment(WSI_object, seg_params = None, filter_params = None, mask_file = None):
+	### Start Seg Timer
+	start_time = time.time()
+	# Use segmentation file
+	if mask_file is not None:
+		WSI_object.initSegmentation(mask_file)
+	# Segment	
+	else:
+		WSI_object.segmentTissue(**seg_params, filter_params=filter_params)
+
+	### Stop Seg Timers
+	seg_time_elapsed = time.time() - start_time   
+	return WSI_object, seg_time_elapsed
+
+def patching(WSI_object, **kwargs):
+	### Start Patch Timer
+	start_time = time.time()
+
+	# Patch
+	file_path = WSI_object.process_contours(**kwargs)
+
+
+	### Stop Patch Timer
+	patch_time_elapsed = time.time() - start_time
+	return file_path, patch_time_elapsed
+
+
+def seg_and_patch(source, save_dir, patch_save_dir, mask_save_dir, stitch_save_dir, 
+				  patch_size = 256, step_size = 256, 
+				  seg_params = {'seg_level': -1, 'sthresh': 8, 'mthresh': 7, 'close': 4, 'use_otsu': False,
+				  'keep_ids': 'none', 'exclude_ids': 'none'},
+				  filter_params = {'a_t':100, 'a_h': 16, 'max_n_holes':8}, 
+				  vis_params = {'vis_level': -1, 'line_thickness': 500},
+				  patch_params = {'use_padding': True, 'contour_fn': 'four_pt'},
+				  patch_level = 0,
+				  use_default_params = False, 
+				  seg = False, save_mask = True, 
+				  stitch= False, 
+				  patch = False, auto_skip=True, process_list = None):
+	
+
+
+	slides = sorted(os.listdir(source))
+	slides = [slide for slide in slides if os.path.isfile(os.path.join(source, slide))]
+	if process_list is None:
+		df = initialize_df(slides, seg_params, filter_params, vis_params, patch_params)
+	
+	else:
+		df = pd.read_csv(process_list)
+		df = initialize_df(df, seg_params, filter_params, vis_params, patch_params)
+
+	mask = df['process'] == 1
+	process_stack = df[mask]
+
+	total = len(process_stack)
+
+	legacy_support = 'a' in df.keys()
+	if legacy_support:
+		print('detected legacy segmentation csv file, legacy support enabled')
+		df = df.assign(**{'a_t': np.full((len(df)), int(filter_params['a_t']), dtype=np.uint32),
+		'a_h': np.full((len(df)), int(filter_params['a_h']), dtype=np.uint32),
+		'max_n_holes': np.full((len(df)), int(filter_params['max_n_holes']), dtype=np.uint32),
+		'line_thickness': np.full((len(df)), int(vis_params['line_thickness']), dtype=np.uint32),
+		'contour_fn': np.full((len(df)), patch_params['contour_fn'])})
+
+	seg_times = 0.
+	patch_times = 0.
+	stitch_times = 0.
+
+	for i in range(total):
+		df.to_csv(os.path.join(save_dir, 'process_list_autogen.csv'), index=False)
+		idx = process_stack.index[i]
+		slide = process_stack.loc[idx, 'slide_id']
+		print("\n\nprogress: {:.2f}, {}/{}".format(i/total, i, total))
+		print('processing {}'.format(slide))
+		
+		df.loc[idx, 'process'] = 0
+		slide_id, _ = os.path.splitext(slide)
+
+		if auto_skip and os.path.isfile(os.path.join(patch_save_dir, slide_id + '.h5')):
+			print('{} already exist in destination location, skipped'.format(slide_id))
+			df.loc[idx, 'status'] = 'already_exist'
+			continue
+
+		# Inialize WSI
+		full_path = os.path.join(source, slide)
+		WSI_object = WholeSlideImage(full_path)
+
+		if use_default_params:
+			current_vis_params = vis_params.copy()
+			current_filter_params = filter_params.copy()
+			current_seg_params = seg_params.copy()
+			current_patch_params = patch_params.copy()
+			
+		else:
+			current_vis_params = {}
+			current_filter_params = {}
+			current_seg_params = {}
+			current_patch_params = {}
+
+
+			for key in vis_params.keys():
+				if legacy_support and key == 'vis_level':
+					df.loc[idx, key] = -1
+				current_vis_params.update({key: df.loc[idx, key]})
+
+			for key in filter_params.keys():
+				if legacy_support and key == 'a_t':
+					old_area = df.loc[idx, 'a']
+					seg_level = df.loc[idx, 'seg_level']
+					scale = WSI_object.level_downsamples[seg_level]
+					adjusted_area = int(old_area * (scale[0] * scale[1]) / (512 * 512))
+					current_filter_params.update({key: adjusted_area})
+					df.loc[idx, key] = adjusted_area
+				current_filter_params.update({key: df.loc[idx, key]})
+
+			for key in seg_params.keys():
+				if legacy_support and key == 'seg_level':
+					df.loc[idx, key] = -1
+				current_seg_params.update({key: df.loc[idx, key]})
+
+			for key in patch_params.keys():
+				current_patch_params.update({key: df.loc[idx, key]})
+
+		if current_vis_params['vis_level'] < 0:
+			if len(WSI_object.level_dim) == 1:
+				current_vis_params['vis_level'] = 0
+			
+			else:	
+				wsi = WSI_object.getOpenSlide()
+				best_level = wsi.get_best_level_for_downsample(64)
+				current_vis_params['vis_level'] = best_level
+
+		if current_seg_params['seg_level'] < 0:
+			if len(WSI_object.level_dim) == 1:
+				current_seg_params['seg_level'] = 0
+			
+			else:
+				wsi = WSI_object.getOpenSlide()
+				best_level = wsi.get_best_level_for_downsample(64)
+				current_seg_params['seg_level'] = best_level
+
+		keep_ids = str(current_seg_params['keep_ids'])
+		if keep_ids != 'none' and len(keep_ids) > 0:
+			str_ids = current_seg_params['keep_ids']
+			current_seg_params['keep_ids'] = np.array(str_ids.split(',')).astype(int)
+		else:
+			current_seg_params['keep_ids'] = []
+
+		exclude_ids = str(current_seg_params['exclude_ids'])
+		if exclude_ids != 'none' and len(exclude_ids) > 0:
+			str_ids = current_seg_params['exclude_ids']
+			current_seg_params['exclude_ids'] = np.array(str_ids.split(',')).astype(int)
+		else:
+			current_seg_params['exclude_ids'] = []
+
+		w, h = WSI_object.level_dim[current_seg_params['seg_level']] 
+		if w * h > 1e8:
+			print('level_dim {} x {} is likely too large for successful segmentation, aborting'.format(w, h))
+			df.loc[idx, 'status'] = 'failed_seg'
+			continue
+
+		df.loc[idx, 'vis_level'] = current_vis_params['vis_level']
+		df.loc[idx, 'seg_level'] = current_seg_params['seg_level']
+
+
+		seg_time_elapsed = -1
+		if seg:
+			WSI_object, seg_time_elapsed = segment(WSI_object, current_seg_params, current_filter_params) 
+
+		if save_mask:
+			mask = WSI_object.visWSI(**current_vis_params)
+			mask_path = os.path.join(mask_save_dir, slide_id+'.jpg')
+			mask.save(mask_path)
+
+		patch_time_elapsed = -1 # Default time
+		if patch:
+			current_patch_params.update({'patch_level': patch_level, 'patch_size': patch_size, 'step_size': step_size, 
+										 'save_path': patch_save_dir})
+			file_path, patch_time_elapsed = patching(WSI_object = WSI_object,  **current_patch_params,)
+		
+		stitch_time_elapsed = -1
+		if stitch:
+			file_path = os.path.join(patch_save_dir, slide_id+'.h5')
+			if os.path.isfile(file_path):
+				heatmap, stitch_time_elapsed = stitching(file_path, WSI_object, downscale=64)
+				stitch_path = os.path.join(stitch_save_dir, slide_id+'.jpg')
+				heatmap.save(stitch_path)
+
+		print("segmentation took {} seconds".format(seg_time_elapsed))
+		print("patching took {} seconds".format(patch_time_elapsed))
+		print("stitching took {} seconds".format(stitch_time_elapsed))
+		df.loc[idx, 'status'] = 'processed'
+
+		seg_times += seg_time_elapsed
+		patch_times += patch_time_elapsed
+		stitch_times += stitch_time_elapsed
+
+	seg_times /= total
+	patch_times /= total
+	stitch_times /= total
+
+	df.to_csv(os.path.join(save_dir, 'process_list_autogen.csv'), index=False)
+	print("average segmentation time in s per slide: {}".format(seg_times))
+	print("average patching time in s per slide: {}".format(patch_times))
+	print("average stiching time in s per slide: {}".format(stitch_times))
+		
+	return seg_times, patch_times
+
+parser = argparse.ArgumentParser(description='seg and patch')
+parser.add_argument('--source', type = str,
+					help='path to folder containing raw wsi image files')
+parser.add_argument('--step_size', type = int, default=256,
+					help='step_size')
+parser.add_argument('--patch_size', type = int, default=256,
+					help='patch_size')
+parser.add_argument('--patch', default=False, action='store_true')
+parser.add_argument('--seg', default=False, action='store_true')
+parser.add_argument('--stitch', default=False, action='store_true')
+parser.add_argument('--no_auto_skip', default=True, action='store_false')
+parser.add_argument('--save_dir', type = str,
+					help='directory to save processed data')
+parser.add_argument('--preset', default=None, type=str,
+					help='predefined profile of default segmentation and filter parameters (.csv)')
+parser.add_argument('--patch_level', type=int, default=0, 
+					help='downsample level at which to patch')
+parser.add_argument('--process_list',  type = str, default=None,
+					help='name of list of images to process with parameters (.csv)')
+
+if __name__ == '__main__':
+	args = parser.parse_args()
+
+	patch_save_dir = os.path.join(args.save_dir, 'patches')
+	mask_save_dir = os.path.join(args.save_dir, 'masks')
+	stitch_save_dir = os.path.join(args.save_dir, 'stitches')
+
+	if args.process_list:
+		process_list = os.path.join(args.save_dir, args.process_list)
+
+	else:
+		process_list = None
+
+	print('source: ', args.source)
+	print('patch_save_dir: ', patch_save_dir)
+	print('mask_save_dir: ', mask_save_dir)
+	print('stitch_save_dir: ', stitch_save_dir)
+	
+	directories = {'source': args.source, 
+				   'save_dir': args.save_dir,
+				   'patch_save_dir': patch_save_dir, 
+				   'mask_save_dir' : mask_save_dir, 
+				   'stitch_save_dir': stitch_save_dir} 
+
+	for key, val in directories.items():
+		print("{} : {}".format(key, val))
+		if key not in ['source']:
+			os.makedirs(val, exist_ok=True)
+
+	seg_params = {'seg_level': -1, 'sthresh': 8, 'mthresh': 7, 'close': 4, 'use_otsu': False,
+				  'keep_ids': 'none', 'exclude_ids': 'none'}
+	filter_params = {'a_t':100, 'a_h': 16, 'max_n_holes':8}
+	vis_params = {'vis_level': -1, 'line_thickness': 250}
+	patch_params = {'use_padding': True, 'contour_fn': 'four_pt'}
+
+	if args.preset:
+		preset_df = pd.read_csv(os.path.join('presets', args.preset))
+		for key in seg_params.keys():
+			seg_params[key] = preset_df.loc[0, key]
+
+		for key in filter_params.keys():
+			filter_params[key] = preset_df.loc[0, key]
+
+		for key in vis_params.keys():
+			vis_params[key] = preset_df.loc[0, key]
+
+		for key in patch_params.keys():
+			patch_params[key] = preset_df.loc[0, key]
+	
+	parameters = {'seg_params': seg_params,
+				  'filter_params': filter_params,
+	 			  'patch_params': patch_params,
+				  'vis_params': vis_params}
+
+	print(parameters)
+
+	seg_times, patch_times = seg_and_patch(**directories, **parameters,
+											patch_size = args.patch_size, step_size=args.step_size, 
+											seg = args.seg,  use_default_params=False, save_mask = True, 
+											stitch= args.stitch,
+											patch_level=args.patch_level, patch = args.patch,
+											process_list = process_list, auto_skip=args.no_auto_skip)