Diff of /tests/loader_testing.py [000000] .. [96354c]

Switch to side-by-side view

--- a
+++ b/tests/loader_testing.py
@@ -0,0 +1,60 @@
+import importlib
+
+import torch
+import torchvision
+from torch.utils.data import DataLoader
+from torchvision import transforms as T
+from src.dataset.loaders.brats_dataset_whole_volume import BratsDataset
+from src.dataset.utils import dataset
+from src.dataset.utils.visualization import plot_3_view, plot_batch_slice
+from matplotlib import pyplot as plt
+import numpy as np
+
+
+def matplotlib_imshow(images, normalized=False):
+
+    img = torchvision.utils.make_grid(images, padding=10)
+    npimg = img.numpy()
+    # c h w
+    trans_im = np.transpose(npimg, (1, 2, 0))
+    plt.imshow(trans_im, cmap="gray")
+    plt.savefig("batch")
+
+csv = "/Users/lauramora/Documents/MASTER/TFM/Data/2020/train/no_patch/brats20_data.csv"
+data, data_test = dataset.read_brats(csv)
+
+modalities_to_use = {BratsDataset.flair_idx: True, BratsDataset.t1_idx: True, BratsDataset.t2_idx: True,
+                     BratsDataset.t1ce_idx: True}
+sampling_method = importlib.import_module("src.dataset.patching.random_tumor_distribution")
+transforms = T.Compose([T.ToTensor()])
+
+data_train = data * 100
+batch_size = 4
+train_dataset = BratsDataset(data_train, modalities_to_use, sampling_method, (128, 128, 128), transforms)
+train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, num_workers=1)
+
+
+i = 0
+s = 20
+seg_maps_2d, volumes_2d = [], []
+for patients_ids, data_batch, labels_batch in train_loader:
+    for seg_map, volume in zip(labels_batch, data_batch):
+        slice = seg_map[:, s, :].unsqueeze(0)
+        seg_maps_2d.append(slice)
+
+       #  volume_slice = volume_mod[0, :, s, :].unsqueeze(0)
+        volume_slice = torch.cat((volume[0, :, s, :].unsqueeze(0),
+                                  volume[1, :, s, :].unsqueeze(0),
+                                  volume[2, :, s, :].unsqueeze(0)), 0)
+
+        volumes_2d.append(volume_slice)
+
+
+    seg_maps_2d_tensor = torch.stack(seg_maps_2d)
+    volumes_2d_tensor = torch.stack(volumes_2d)
+    matplotlib_imshow(seg_maps_2d)
+    print(f"Batch {i}")
+    i +=1
+    break
+
+