# extra/torchio_ex.py
import os

import torch
import torchio
from torchio import ImagesDataset, Image, ZNormalization, Compose

# Star import kept as-is: it is where SemSegMRIConfig (used below) comes from.
from config.config import *
from config.paths import train_images_folder, train_labels_folder, train_images, train_labels
from semseg.data_loader_torchio import get_pad_3d_image
from augm.lambda_channel import LambdaChannel
9
# Preprocessing applied to every subject on access. The float values record
# the intended application probability of each transform (cf. the 0.25 on the
# commented-out augmentations).
# NOTE(review): Compose takes a plain sequence of transforms, so these weights
# are not consumed here -- use per-transform probability arguments or
# torchio's OneOf if they are meant to matter.
transforms_dict = {
    ZNormalization(): 1.0,
    # LambdaChannel(z_score_normalization, types_to_apply=torchio.INTENSITY): 1.0,
    # RandomAffine(): 0.25,
    # RandomElasticDeformation(max_displacement=3): 0.25,
    LambdaChannel(get_pad_3d_image(pad_ref=(48, 64, 48), zero_pad=False)): 1.0,
}

# Hand Compose an explicit list (the dict's keys, in insertion order) instead
# of the dict itself, so it is clear that only the transforms are used.
transform = Compose(list(transforms_dict))


def build_subjects_dataset(images_folder, labels_folder, image_names, label_names,
                           transform=None):
    """Pair each training image with its label and wrap them in an ImagesDataset.

    Args:
        images_folder: directory joined in front of every entry of ``image_names``.
        labels_folder: directory joined in front of every entry of ``label_names``.
        image_names: image file names, matched positionally with ``label_names``.
        label_names: label file names. If the two sequences differ in length,
            the extra entries of the longer one are ignored (``zip`` semantics).
        transform: optional transform the dataset applies to each subject.

    Returns:
        An ``ImagesDataset`` with one ``torchio.Subject`` per (image, label)
        pair, exposing the image as ``t1`` and the label map as ``label``.

    Raises:
        ValueError: if there is no (image, label) pair to build a subject from.
    """
    subjects = [
        torchio.Subject(
            t1=Image(type=torchio.INTENSITY, path=os.path.join(images_folder, image_name)),
            label=Image(type=torchio.LABEL, path=os.path.join(labels_folder, label_name)),
        )
        for image_name, label_name in zip(image_names, label_names)
    ]
    if not subjects:
        # Fail early with a clear message instead of an IndexError/StopIteration later.
        raise ValueError("No (image, label) pairs found to build the dataset")
    return ImagesDataset(subjects, transform=transform)


def print_samples(dataset, num_samples=10):
    """Print shapes, intensity range and label values of the first subjects.

    Args:
        dataset: indexable dataset returning subjects with ``t1`` and ``label``.
        num_samples: maximum number of subjects to inspect (clamped to the
            dataset length).
    """
    num_samples = min(num_samples, len(dataset))
    for idx in range(num_samples):
        sample = dataset[idx]
        # Report progress against the number of samples actually visited.
        print("Iter {} on {}".format(idx + 1, num_samples))
        print("t1.shape       = {}".format(sample.t1.shape))
        print("label.shape    = {}".format(sample.label.shape))
        print("t1 [min - max] = [{:.1f} : {:.1f}]".format(sample.t1.data.min(),
                                                         sample.t1.data.max()))
        print("label.unique   = {}".format(sample.label.data.unique()))


def _batch_tensors(batch):
    """Return the (input, label) data tensors of one collated DataLoader batch."""
    return batch['t1']['data'], batch['label']['data']


def print_batches(dataset, batch_size, num_workers):
    """Iterate ``dataset`` through a DataLoader and print per-batch statistics.

    The first batch is fetched once up front to show the batch shape, then the
    whole loader is iterated, printing shape and input range for every batch.

    Args:
        dataset: dataset to wrap in a ``torch.utils.data.DataLoader``.
        batch_size: number of subjects per batch.
        num_workers: number of DataLoader worker processes.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=num_workers)

    inputs, labels = _batch_tensors(next(iter(loader)))
    print("Shape of Batch: [input {}] [label {}]".format(inputs.shape, labels.shape))

    for i, batch in enumerate(loader):
        print("Iteration {} on {}".format(i + 1, len(loader)))
        inputs, labels = _batch_tensors(batch)
        print("Shape of Batch: [input {}] [label {}]".format(inputs.shape, labels.shape))
        # str.format has no '%%' escape (that is %-formatting); use ':' as above.
        print("Range Inputs  : [{:.2f} : {:.2f}]".format(inputs.min(), inputs.max()))


def main():
    """Build the training dataset and print per-sample and per-batch diagnostics."""
    subjects_dataset = build_subjects_dataset(train_images_folder, train_labels_folder,
                                              train_images, train_labels,
                                              transform=transform)
    print_samples(subjects_dataset, num_samples=10)

    config = SemSegMRIConfig()
    print_batches(subjects_dataset, batch_size=config.batch_size,
                  num_workers=config.num_workers)


# Guard required when the DataLoader uses worker processes with the "spawn"
# start method (default on Windows and macOS): each worker re-imports this
# module, and unguarded top-level code would try to start new workers there.
if __name__ == "__main__":
    main()