--- a +++ b/dataloaders/la_heart.py @@ -0,0 +1,187 @@ +import h5py +import torch +import numpy as np +import itertools +from torch.utils.data import Dataset +from torch.utils.data.sampler import Sampler + + +class LAHeart(Dataset): + """ LA Dataset """ + + def __init__(self, base_dir=None, split='train', num=None, transform=None): + self._base_dir = base_dir + self.transform = transform + self.sample_list = [] + if split == 'train': + with open(self._base_dir + '/../train.list', 'r') as f: + self.image_list = f.readlines() + elif split == 'test': + with open(self._base_dir + '/../test.list', 'r') as f: + self.image_list = f.readlines() + self.image_list = [item.strip() for item in self.image_list] + if num is not None: + self.image_list = self.image_list[:num] + print("total {} samples".format(len(self.image_list))) + + def __len__(self): + return len(self.image_list) + + def __getitem__(self, idx): + image_name = self.image_list[idx] + # print(image_name) + h5f = h5py.File(self._base_dir + "/" + image_name + "/mri_norm2.h5", 'r') + image = h5f['image'][:] + label = h5f['label'][:] + sample = {'image': image, 'label': label} + if self.transform: + sample = self.transform(sample) + return sample + +class RandomCrop(object): + """ + Crop randomly the image in a sample + Args: + output_size (int): Desired output size + """ + + def __init__(self, output_size): + self.output_size = output_size + + def __call__(self, sample): + image, label = sample['image'], sample['label'] + + # pad the sample if necessary + if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \ + self.output_size[2]: + pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) + ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) + pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) + image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) + label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) + + (w, h, d) = image.shape + w1 = np.random.randint(0, w - self.output_size[0]) + h1 = np.random.randint(0, h - self.output_size[1]) + d1 = np.random.randint(0, d - self.output_size[2]) + + label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] + image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] + return {'image': image, 'label': label} + +class CenterCrop(object): + def __init__(self, output_size): + self.output_size = output_size + + def __call__(self, sample): + image, label = sample['image'], sample['label'] + + # pad the sample if necessary + if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \ + self.output_size[2]: + pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) + ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) + pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) + image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) + label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) + + (w, h, d) = image.shape + + w1 = int(round((w - self.output_size[0]) / 2.)) + h1 = int(round((h - self.output_size[1]) / 2.)) + d1 = int(round((d - self.output_size[2]) / 2.)) + + label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] + image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] + + return {'image': image, 'label': label} + + +class RandomRotFlip(object): + """ + Crop randomly flip the dataset in a sample + Args: + output_size (int): Desired output size + """ + + def __call__(self, sample): + image, label = sample['image'], sample['label'] + k = np.random.randint(0, 4) + image = np.rot90(image, k) + label = np.rot90(label, k) + axis = np.random.randint(0, 2) + image = np.flip(image, axis=axis).copy() + label = np.flip(label, axis=axis).copy() + + return {'image': image, 'label': label} + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + image = sample['image'] + image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32) + return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long()} + + +class TwoStreamBatchSampler(Sampler): + """Iterate two sets of indices + + An 'epoch' is one iteration through the primary indices. + During the epoch, the secondary indices are iterated through + as many times as needed. + """ + def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): + self.primary_indices = primary_indices + self.secondary_indices = secondary_indices + self.secondary_batch_size = secondary_batch_size + self.primary_batch_size = batch_size - secondary_batch_size + + assert len(self.primary_indices) >= self.primary_batch_size > 0 + assert len(self.secondary_indices) >= self.secondary_batch_size > 0 + + def __iter__(self): + primary_iter = iterate_once(self.primary_indices) + secondary_iter = iterate_eternally(self.secondary_indices) + return ( + primary_batch + secondary_batch + for (primary_batch, secondary_batch) + in zip(grouper(primary_iter, self.primary_batch_size), + grouper(secondary_iter, self.secondary_batch_size)) + ) + + def __len__(self): + return len(self.primary_indices) // self.primary_batch_size + +def iterate_once(iterable): + return np.random.permutation(iterable) + + +def iterate_eternally(indices): + def infinite_shuffles(): + while True: + yield np.random.permutation(indices) + return itertools.chain.from_iterable(infinite_shuffles()) + + +def grouper(iterable, n): + "Collect data into fixed-length chunks or blocks" + # grouper('ABCDEFG', 3) --> ABC DEF" + args = [iter(iterable)] * n + return zip(*args) + + +if __name__ == '__main__': + train_set = LAHeart('E:/data/LASet/data') + print(len(train_set)) + # data = train_set[0] + # image, label = data['image'], data['label'] + # print(image.shape, label.shape) + labeled_idxs = list(range(25)) + unlabeled_idxs = list(range(25,123)) + batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, 4, 2) + i = 0 + for x in batch_sampler: + i += 1 + print('%02d'%i,'\t',x) \ No newline at end of file