Diff of /standalone_aug.py [000000] .. [bb7f56]

Switch to side-by-side view

--- a
+++ b/standalone_aug.py
@@ -0,0 +1,98 @@
+
+import numpy as np
+import os
+import cProfile, pstats
+
+from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
+from batchgenerators.transforms.abstract_transforms import Compose
+from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+from batchgenerators.dataloading import SingleThreadedAugmenter
+from batchgenerators.transforms.spatial_transforms import SpatialTransform
+from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates
+
+
+def augmentation(patient_data):
+    my_transforms = []
+
+    mirror_transform = Mirror(axes=np.arange(3))
+    my_transforms.append(mirror_transform)
+
+    spatial_transform = SpatialTransform(patch_size=[256, 256, 256], patch_center_dist_from_border= (125.0, 125.0), 
+                                do_elastic_deform=False, alpha=(0.0, 1500.0), sigma=(30.0, 50.0), do_rotation=True, 
+                                angle_x= (0, 0.0),angle_y=(0, 0.0),angle_z=(0.0, 6.283185307179586), do_scale=True, 
+                                scale=(0.8, 1.1),random_crop=False,order_data=2,order_seg=2)
+
+    my_transforms.append(spatial_transform)
+
+
+    my_transforms.append(ConvertSegToBoundingBoxCoordinates(3, get_rois_from_seg_flag=False, class_specific_seg_flag=False))
+    all_transforms = Compose(my_transforms)
+
+    multithreaded_generator = SingleThreadedAugmenter(patient_data, all_transforms)
+    #multithreaded_generator = MultiThreadedAugmenter(patient_data, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
+    
+    return multithreaded_generator
+
+##Dummy Data Creation
+
+dumb_img = np.random.random_sample((3,256,256,256))-0.5
+dumb_img.astype('float16')
+data = list()
+for i in range(0,8):
+    data.append(dumb_img)
+
+dumb_seg = np.zeros(shape=(1,256,256,256))
+dumb_seg[0][120:135,120:135,120:135] = 1
+dumb_seg.astype('uint8')
+seg = list()
+for i in range(0,8):
+    seg.append(dumb_seg)
+
+class_target = list()
+for i in range(0,8):
+    class_target.append([1])
+batch_ids = [['1'],['2'],['3'],['4'],['5'],['6'],['7'],['8']]
+
+# pp_dir = "/home/aisinai/data/preprocessed_data/pp_groin_256_f16"
+# batch_ids = [['g11'],['g1'],['g4'],['g5'],['g14'],['g17'],['g19'],['g29']]
+# data = list()
+# seg = list()
+# pids = list()
+
+# img_batch = [os.path.join(pp_dir,"{}_img.npy".format(i)) for i in batch_ids]
+
+# seg_batch = [os.path.join(pp_dir,"{}_rois.npy".format(i)) for i in batch_ids]
+
+# for j in img_batch:
+#     img = np.load(j)
+#     data.append(img)
+
+# for k in seg_batch:
+#     roi = np.load(k)
+#     seg.append(roi)
+
+data = np.array(data)
+seg = np.array(seg)
+class_target = np.array(class_target)
+print (data.shape,seg.shape,class_target.shape,class_target)
+
+batches = list()
+
+batch_one = {'data':data,'seg':seg,'pid':batch_ids,'class_target':class_target} #Data, Seg, PID dictionary
+
+batches.append(batch_one)
+
+batches_i = iter(batches)
+
+### Run and Profile Standalone Script
+
+profiler = cProfile.Profile()
+profiler.enable()
+
+augmented_data = augmentation(batches_i)
+
+result = next(augmented_data)
+
+profiler.disable()
+stats = pstats.Stats(profiler).sort_stats('cumtime')
+stats.print_stats()
\ No newline at end of file