Diff of /standalone_aug.py [000000] .. [bb7f56]

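# Standalone script: builds a batchgenerators augmentation pipeline for 3D volumes,
# runs it once over a dummy batch, and profiles the run with cProfile.
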
import numpy as np
import os
import cProfile
import pstats

from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
from batchgenerators.transforms.spatial_transforms import SpatialTransform
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter


def augmentation(patient_data):
    """Build the augmentation pipeline and wrap the batch generator in an augmenter."""
    my_transforms = []

    # Random mirroring along all three spatial axes.
    mirror_transform = Mirror(axes=np.arange(3))
    my_transforms.append(mirror_transform)

    # Spatial augmentation: rotation around z (0 to 2*pi) and scaling (0.8 to 1.1);
    # elastic deformation and random cropping are disabled.
    spatial_transform = SpatialTransform(patch_size=[256, 256, 256],
                                         patch_center_dist_from_border=(125.0, 125.0),
                                         do_elastic_deform=False, alpha=(0.0, 1500.0), sigma=(30.0, 50.0),
                                         do_rotation=True, angle_x=(0.0, 0.0), angle_y=(0.0, 0.0),
                                         angle_z=(0.0, 2 * np.pi),
                                         do_scale=True, scale=(0.8, 1.1),
                                         random_crop=False, order_data=2, order_seg=2)
    my_transforms.append(spatial_transform)

    # Convert the segmentation map into 3D bounding-box coordinates for detection targets.
    my_transforms.append(ConvertSegToBoundingBoxCoordinates(3, get_rois_from_seg_flag=False,
                                                            class_specific_seg_flag=False))
    all_transforms = Compose(my_transforms)

    augmenter = SingleThreadedAugmenter(patient_data, all_transforms)
    # augmenter = MultiThreadedAugmenter(patient_data, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))

    return augmenter

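# The MultiThreadedAugmenter call commented out above references `cf.n_workers`,
# which is not defined in this standalone script. Below is a minimal sketch of the
# multi-process variant, assuming a hard-coded worker count in place of cf.n_workers;
# `make_multithreaded_augmenter` is a hypothetical helper, not part of the original script.
def make_multithreaded_augmenter(patient_data, all_transforms, n_workers=4):
    # Same (data_loader, transform, num_processes, seeds) call pattern as the
    # commented-out line in augmentation().
    return MultiThreadedAugmenter(patient_data, all_transforms,
                                  num_processes=n_workers,
                                  seeds=list(range(n_workers)))
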
## Dummy Data Creation

# Batch of 8 identical random 3-channel volumes in [-0.5, 0.5). Note: astype()
# returns a new array, so its result must be reassigned for the cast to take effect.
dumb_img = np.random.random_sample((3, 256, 256, 256)) - 0.5
dumb_img = dumb_img.astype('float16')
data = [dumb_img for _ in range(8)]

# Single-channel segmentation with one cubic foreground ROI.
dumb_seg = np.zeros(shape=(1, 256, 256, 256))
dumb_seg[0, 120:135, 120:135, 120:135] = 1
dumb_seg = dumb_seg.astype('uint8')
seg = [dumb_seg for _ in range(8)]

# One class label per sample and string patient IDs for the batch.
class_target = [[1] for _ in range(8)]
batch_ids = [['1'], ['2'], ['3'], ['4'], ['5'], ['6'], ['7'], ['8']]

# Optional: load real preprocessed volumes instead of the dummy arrays above.
# pp_dir = "/home/aisinai/data/preprocessed_data/pp_groin_256_f16"
# batch_ids = [['g11'],['g1'],['g4'],['g5'],['g14'],['g17'],['g19'],['g29']]
# data = list()
# seg = list()
# pids = list()

# img_batch = [os.path.join(pp_dir,"{}_img.npy".format(i)) for i in batch_ids]
# seg_batch = [os.path.join(pp_dir,"{}_rois.npy".format(i)) for i in batch_ids]

# for j in img_batch:
#     img = np.load(j)
#     data.append(img)

# for k in seg_batch:
#     roi = np.load(k)
#     seg.append(roi)

# Stack the lists into batched arrays: data (8, 3, 256, 256, 256),
# seg (8, 1, 256, 256, 256), class_target (8, 1).
data = np.array(data)
seg = np.array(seg)
class_target = np.array(class_target)
print(data.shape, seg.shape, class_target.shape, class_target)

# A single batch dictionary with data, seg, pid, and class_target keys.
batches = list()
batch_one = {'data': data, 'seg': seg, 'pid': batch_ids, 'class_target': class_target}
batches.append(batch_one)

batches_i = iter(batches)

### Run and Profile Standalone Script

# Profile building the pipeline and pulling one augmented batch through it.
profiler = cProfile.Profile()
profiler.enable()

augmented_data = augmentation(batches_i)
result = next(augmented_data)

profiler.disable()
stats = pstats.Stats(profiler).sort_stats('cumtime')
stats.print_stats()
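
# Optionally persist the raw profile for later inspection (e.g. with a profile viewer);
# the output filename below is an arbitrary choice, not part of the original script.
stats.dump_stats('standalone_aug.prof')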