Diff of /semseg/data_loader.py [000000] .. [cc8b8f]

Switch to side-by-side view

--- a
+++ b/semseg/data_loader.py
@@ -0,0 +1,81 @@
+import torch
+import os
+import torchio
+from torchio import Image, ImagesDataset, SubjectsDataset
+
+
+class SemSegConfig():
+    train_images = None
+    train_labels = None
+    val_images   = None
+    val_labels   = None
+    do_normalize = True
+    augmentation = None
+    zero_pad     = True
+    pad_ref      = (64,64,64)
+    batch_size   = 4
+    num_workers  = 0
+
+
+def TorchIODataLoader3DTraining(config: SemSegConfig) -> torch.utils.data.DataLoader:
+    print('Building TorchIO Training Set Loader...')
+    subject_list = list()
+    for idx, (image_path, label_path) in enumerate(zip(config.train_images, config.train_labels)):
+        s1 = torchio.Subject(
+            t1=Image(type=torchio.INTENSITY, path=image_path),
+            label=Image(type=torchio.LABEL, path=label_path),
+        )
+
+        subject_list.append(s1)
+
+    # Deprecated
+    # subjects_dataset = ImagesDataset(subject_list, transform=config.transform_train)
+    subjects_dataset = SubjectsDataset(subject_list, transform=config.transform_train)
+    train_data = torch.utils.data.DataLoader(subjects_dataset, batch_size=config.batch_size,
+                                             shuffle=True, num_workers=config.num_workers)
+    print('TorchIO Training Loader built!')
+    return train_data
+
+
+def TorchIODataLoader3DValidation(config: SemSegConfig) -> torch.utils.data.DataLoader:
+    print('Building TorchIO Validation Set Loader...')
+    subject_list = list()
+    for idx, (image_path, label_path) in enumerate(zip(config.val_images, config.val_labels)):
+        s1 = torchio.Subject(
+            t1=Image(type=torchio.INTENSITY, path=image_path),
+            label=Image(type=torchio.LABEL, path=label_path),
+        )
+
+        subject_list.append(s1)
+
+    # Deprecated
+    # subjects_dataset = ImagesDataset(subject_list, transform=config.transform_val)
+    subjects_dataset = SubjectsDataset(subject_list, transform=config.transform_val)
+    val_data = torch.utils.data.DataLoader(subjects_dataset, batch_size=config.batch_size,
+                                           shuffle=False, num_workers=config.num_workers)
+    print('TorchIO Validation Loader built!')
+    return val_data
+
+
+def get_pad_3d_image(pad_ref: tuple = (64, 64, 64), zero_pad: bool = True):
+    def pad_3d_image(image):
+        if zero_pad:
+            value_to_pad = 0
+        else:
+            value_to_pad = image.min()
+        pad_ref_channels = (image.shape[0], *pad_ref)
+        # print("image.shape = {}".format(image.shape))
+        if value_to_pad == 0:
+            image_padded = torch.zeros(pad_ref_channels)
+        else:
+            image_padded = value_to_pad * torch.ones(pad_ref_channels)
+        image_padded[:,:image.shape[1],:image.shape[2],:image.shape[3]] = image
+        # print("image_padded.shape = {}".format(image_padded.shape))
+        return image_padded
+    return pad_3d_image
+
+
+def z_score_normalization(inputs):
+    input_mean = torch.mean(inputs)
+    input_std = torch.std(inputs)
+    return (inputs - input_mean)/input_std