--- a +++ b/tests/loader_testing.py @@ -0,0 +1,60 @@ +import importlib + +import torch +import torchvision +from torch.utils.data import DataLoader +from torchvision import transforms as T +from src.dataset.loaders.brats_dataset_whole_volume import BratsDataset +from src.dataset.utils import dataset +from src.dataset.utils.visualization import plot_3_view, plot_batch_slice +from matplotlib import pyplot as plt +import numpy as np + + +def matplotlib_imshow(images, normalized=False): + + img = torchvision.utils.make_grid(images, padding=10) + npimg = img.numpy() + # c h w + trans_im = np.transpose(npimg, (1, 2, 0)) + plt.imshow(trans_im, cmap="gray") + plt.savefig("batch") + +csv = "/Users/lauramora/Documents/MASTER/TFM/Data/2020/train/no_patch/brats20_data.csv" +data, data_test = dataset.read_brats(csv) + +modalities_to_use = {BratsDataset.flair_idx: True, BratsDataset.t1_idx: True, BratsDataset.t2_idx: True, + BratsDataset.t1ce_idx: True} +sampling_method = importlib.import_module("src.dataset.patching.random_tumor_distribution") +transforms = T.Compose([T.ToTensor()]) + +data_train = data * 100 +batch_size = 4 +train_dataset = BratsDataset(data_train, modalities_to_use, sampling_method, (128, 128, 128), transforms) +train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, num_workers=1) + + +i = 0 +s = 20 +seg_maps_2d, volumes_2d = [], [] +for patients_ids, data_batch, labels_batch in train_loader: + for seg_map, volume in zip(labels_batch, data_batch): + slice = seg_map[:, s, :].unsqueeze(0) + seg_maps_2d.append(slice) + + # volume_slice = volume_mod[0, :, s, :].unsqueeze(0) + volume_slice = torch.cat((volume[0, :, s, :].unsqueeze(0), + volume[1, :, s, :].unsqueeze(0), + volume[2, :, s, :].unsqueeze(0)), 0) + + volumes_2d.append(volume_slice) + + + seg_maps_2d_tensor = torch.stack(seg_maps_2d) + volumes_2d_tensor = torch.stack(volumes_2d) + matplotlib_imshow(seg_maps_2d) + print(f"Batch {i}") + i +=1 + break + +