a b/semseg/data_loader.py
1
import torch
2
import os
3
import torchio
4
from torchio import Image, ImagesDataset, SubjectsDataset
5
6
7
class SemSegConfig():
    """Plain attribute container configuring the TorchIO data loaders.

    Every setting is a class attribute with a default; callers override the
    ones they need on the class or on an instance before passing the config
    to ``TorchIODataLoader3DTraining`` / ``TorchIODataLoader3DValidation``.
    """

    # Sequences of image / label file paths. The loaders zip each image
    # sequence with its label sequence, so they are paired by position.
    train_images = None
    train_labels = None
    val_images   = None
    val_labels   = None

    # Transforms handed to ``SubjectsDataset(..., transform=...)`` by the
    # training and validation loaders. Declared here so that a config which
    # never sets them falls back to "no transform" instead of raising
    # AttributeError when a loader reads the attribute.
    transform_train = None
    transform_val   = None

    # NOTE(review): the four settings below are not read by any code in this
    # module; presumably the caller consults them when building the
    # transforms above (z-score normalisation, augmentation, padding to
    # ``pad_ref``) - confirm against the training script.
    do_normalize = True
    augmentation = None
    zero_pad     = True
    pad_ref      = (64,64,64)

    # Forwarded unchanged to ``torch.utils.data.DataLoader``.
    batch_size   = 4
    num_workers  = 0
18
19
20
def TorchIODataLoader3DTraining(config: SemSegConfig) -> torch.utils.data.DataLoader:
    """Build the shuffled training ``DataLoader`` from the paths in ``config``.

    ``config.train_images`` and ``config.train_labels`` are paired by
    position. Each pair becomes a ``torchio.Subject`` holding the image under
    the key ``t1`` and the label under the key ``label``. The subjects are
    wrapped in a ``SubjectsDataset`` that applies ``config.transform_train``.

    Args:
        config: Supplies ``train_images``, ``train_labels``,
            ``transform_train``, ``batch_size`` and ``num_workers``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the training subjects, created
        with ``shuffle=True``.
    """
    print('Building TorchIO Training Set Loader...')

    # zip() stops at the shorter sequence, so unmatched trailing paths are
    # silently left out.
    training_subjects = [
        torchio.Subject(
            t1=Image(type=torchio.INTENSITY, path=image_path),
            label=Image(type=torchio.LABEL, path=label_path),
        )
        for image_path, label_path in zip(config.train_images, config.train_labels)
    ]

    dataset = SubjectsDataset(training_subjects, transform=config.transform_train)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
    )
    print('TorchIO Training Loader built!')
    return loader
38
39
40
def TorchIODataLoader3DValidation(config: SemSegConfig) -> torch.utils.data.DataLoader:
    """Build the validation ``DataLoader`` from the paths in ``config``.

    ``config.val_images`` and ``config.val_labels`` are paired by position.
    Each pair becomes a ``torchio.Subject`` holding the image under the key
    ``t1`` and the label under the key ``label``. The subjects are wrapped in
    a ``SubjectsDataset`` that applies ``config.transform_val``.

    Args:
        config: Supplies ``val_images``, ``val_labels``, ``transform_val``,
            ``batch_size`` and ``num_workers``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the validation subjects,
        created with ``shuffle=False`` so the iteration order is fixed.
    """
    print('Building TorchIO Validation Set Loader...')

    # zip() stops at the shorter sequence, so unmatched trailing paths are
    # silently left out.
    validation_subjects = [
        torchio.Subject(
            t1=Image(type=torchio.INTENSITY, path=image_path),
            label=Image(type=torchio.LABEL, path=label_path),
        )
        for image_path, label_path in zip(config.val_images, config.val_labels)
    ]

    dataset = SubjectsDataset(validation_subjects, transform=config.transform_val)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
    )
    print('TorchIO Validation Loader built!')
    return loader
58
59
60
def get_pad_3d_image(pad_ref: tuple = (64, 64, 64), zero_pad: bool = True):
    """Return a callable that pads a 4-D image up to a fixed spatial size.

    Args:
        pad_ref: Target extent of the three trailing axes of the output.
        zero_pad: If true the padding value is 0; otherwise it is the minimum
            of the image being padded.

    Returns:
        A function ``pad_3d_image(image)`` that allocates a tensor of shape
        ``(image.shape[0], *pad_ref)`` filled with the padding value and
        copies ``image`` into its leading corner (index 0 on every padded
        axis). The output is created with ``torch.zeros`` / ``torch.ones``
        without an explicit dtype or device. The function does not crop: an
        image larger than ``pad_ref`` along any axis makes the copy fail.
    """
    def pad_3d_image(image):
        fill_value = 0 if zero_pad else image.min()
        target_shape = (image.shape[0], *pad_ref)

        if fill_value == 0:
            canvas = torch.zeros(target_shape)
        else:
            canvas = fill_value * torch.ones(target_shape)

        # Extents of the region the source image occupies inside the canvas.
        extent_1 = image.shape[1]
        extent_2 = image.shape[2]
        extent_3 = image.shape[3]
        canvas[:, :extent_1, :extent_2, :extent_3] = image
        return canvas

    return pad_3d_image
76
77
78
def z_score_normalization(inputs):
    """Standardise a tensor to zero mean and unit standard deviation.

    The mean and standard deviation are computed over all elements of
    ``inputs`` (``torch.mean`` / ``torch.std`` with no ``dim`` argument).

    Args:
        inputs: Floating-point tensor to normalise.

    Returns:
        ``(inputs - mean) / std``. When the standard deviation is zero
        (constant input) or not finite (for example a single-element input,
        for which ``torch.std`` yields NaN), only the mean is subtracted, so
        the result is not filled with NaN/inf by the division.
    """
    centered = inputs - torch.mean(inputs)
    input_std = torch.std(inputs)

    # Guard the division: a zero or NaN std would turn every element into
    # NaN/inf, which then propagates through anything that consumes the image.
    if not torch.isfinite(input_std) or input_std == 0:
        return centered
    return centered / input_std