[96354c]: / src / dataset / loaders / batch_sampler.py

Download this file

105 lines (67 with data), 3.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
import random
from torch.utils.data import Sampler, DataLoader
from src.dataset.loaders.brats_dataset_whole_volume import BratsDataset
class BratsSampler(Sampler):
    """Batch sampler that repeats each patient index ``n_samples`` times per batch.

    Every yielded batch contains ``n_patients`` distinct patient indices, each
    repeated ``n_samples`` times consecutively, so downstream code can extract
    several patches/augmentations per patient from a single batch. The final
    batch may be smaller when the patient count is not divisible by
    ``n_patients``.
    """

    def __init__(self, dataset, n_patients, n_samples):
        """
        :param dataset: sized dataset; only ``len(dataset)`` is used here.
        :param n_patients: number of distinct patients per batch.
        :param n_samples: repetitions of each patient index within a batch.
        """
        self.batch_size = n_patients * n_samples
        self.n_samples = n_samples
        self.n_patients = n_patients
        self.dataset_indices = list(range(len(dataset)))

    def __iter__(self):
        batch_indices = []
        # New random patient order every epoch.
        random.shuffle(self.dataset_indices)
        for patient_id in self.dataset_indices:
            # Repeat the same patient index n_samples times.
            batch_indices.extend(patient_id for _ in range(self.n_samples))
            if len(batch_indices) == self.batch_size:
                yield batch_indices
                batch_indices = []
        # Flush the final partial batch (fewer than n_patients patients left).
        if batch_indices:
            yield batch_indices

    def __len__(self):
        # BUG FIX: each batch consumes n_patients patient indices (each already
        # repeated n_samples times inside the batch), so the number of batches
        # is ceil(num_patients / n_patients) -- NOT ceil(num_patients / batch_size)
        # as before, which under-reported the length by a factor of n_samples.
        return (len(self.dataset_indices) + self.n_patients - 1) // self.n_patients
class BratsPatchSampler(Sampler):
    """Batch sampler grouping pre-extracted patches by patient.

    Each batch draws up to ``n_patients`` random patients and up to
    ``n_samples`` random patch indices from each, without replacement within
    an epoch, until every patch has been yielded once.
    """

    def __init__(self, dataset, n_patients, n_samples):
        """
        :param dataset: dataset exposing ``__len__`` and a ``data`` attribute;
            each item in ``data`` has a ``patient`` attribute identifying its
            source patient.
        :param n_patients: max number of distinct patients per batch.
        :param n_samples: max patches drawn from each patient per batch.
        """
        self.batch_size = n_patients * n_samples
        self.n_samples = n_samples
        self.n_patients = n_patients
        self.dataset_indices = list(range(len(dataset)))
        self.dataset = dataset.data
        self.patches_by_patient = self._generate_structure()

    def _generate_structure(self):
        """Map each patient id to the list of its patch indices in the dataset."""
        patches_by_patient = {}
        for index, patient_patch in enumerate(self.dataset):
            patches_by_patient.setdefault(patient_patch.patient, []).append(index)
        return patches_by_patient

    def __iter__(self):
        # BUG FIX: operate on a per-epoch copy. The previous code drained
        # self.patches_by_patient in place, so the sampler yielded nothing
        # from the second epoch onward.
        remaining = {patient: list(indices)
                     for patient, indices in self.patches_by_patient.items()}
        while remaining:
            batch_indices = []
            # BUG FIX: random.sample requires a sequence; passing dict_items
            # directly raises TypeError on Python 3.11+ (deprecated since 3.9).
            patients_for_batch = random.sample(
                list(remaining.items()), min(self.n_patients, len(remaining)))
            for patient, patches in patients_for_batch:
                selected_patches = random.sample(
                    patches, min(self.n_samples, len(patches)))
                batch_indices.extend(selected_patches)
                for patch_id in selected_patches:
                    patches.remove(patch_id)
                # Drop patients with no patches left so the loop terminates.
                if not patches:
                    del remaining[patient]
            yield batch_indices

    def __len__(self):
        # Approximate batch count assuming full batches; the true count can be
        # larger when patients have fewer than n_samples patches.
        return (len(self.dataset_indices) + self.batch_size - 1) // self.batch_size
if __name__ == "__main__":
from src.dataset.utils import dataset
from torchvision import transforms as T
import importlib
data, _ = dataset.read_brats("/Users/lauramora/Documents/MASTER/TFM/Data/2020/train/random_tumor_distribution/brats20_data.csv")
data_train = data[:40]
data_val = data[:40]
modalities_to_use = {BratsDataset.flair_idx: True, BratsDataset.t1_idx: True, BratsDataset.t2_idx: True,
BratsDataset.t1ce_idx: True}
sampling_method = importlib.import_module("src.dataset.patching.random_tumor_distribution")
dataset = BratsDataset(data_train, modalities_to_use, sampling_method, (64,64,64), T.Compose([T.ToTensor()]))
sampler = BratsPatchSampler(dataset, n_patients=2, n_samples=3)
train_loader = DataLoader(dataset=dataset, batch_sampler=sampler, num_workers=1)
for i, (patient, volumes, segmentations) in enumerate(train_loader):
print(f"item {i} --> {volumes.shape}")
print(f"item {i} --> {segmentations.shape}")