import h5py
import torch
import numpy as np
import itertools
from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
class LAHeart(Dataset):
""" LA Dataset """
def __init__(self, base_dir=None, split='train', num=None, transform=None):
self._base_dir = base_dir
self.transform = transform
self.sample_list = []
        if split == 'train':
            with open(self._base_dir + '/../train.list', 'r') as f:
                self.image_list = f.readlines()
        elif split == 'test':
            with open(self._base_dir + '/../test.list', 'r') as f:
                self.image_list = f.readlines()
        else:
            raise ValueError("split must be 'train' or 'test', got {!r}".format(split))
        self.image_list = [item.strip() for item in self.image_list]
if num is not None:
self.image_list = self.image_list[:num]
print("total {} samples".format(len(self.image_list)))
def __len__(self):
return len(self.image_list)
    def __getitem__(self, idx):
        image_name = self.image_list[idx]
        # each case directory contains one normalized volume and its segmentation mask
        with h5py.File(self._base_dir + "/" + image_name + "/mri_norm2.h5", 'r') as h5f:
            image = h5f['image'][:]
            label = h5f['label'][:]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample
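# Expected on-disk layout, inferred from the paths used in LAHeart above:
#   <base_dir>/<case_name>/mri_norm2.h5   HDF5 file with datasets 'image' and 'label'
#   <base_dir>/../train.list              one case name per line
#   <base_dir>/../test.list               one case name per line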
class RandomCrop(object):
"""
Crop randomly the image in a sample
Args:
output_size (int): Desired output size
"""
def __init__(self, output_size):
self.output_size = output_size
def __call__(self, sample):
image, label = sample['image'], sample['label']
        # pad the volume if any dimension is not larger than the crop size;
        # the extra +3 voxels of margin keep the random-offset range below non-empty
if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
self.output_size[2]:
pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
(w, h, d) = image.shape
w1 = np.random.randint(0, w - self.output_size[0])
h1 = np.random.randint(0, h - self.output_size[1])
d1 = np.random.randint(0, d - self.output_size[2])
label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
return {'image': image, 'label': label}
class CenterCrop(object):
    """Crop the image and label at the center; output_size is a (w, h, d) tuple."""
    def __init__(self, output_size):
        self.output_size = output_size
def __call__(self, sample):
image, label = sample['image'], sample['label']
        # pad the volume if any dimension is not larger than the crop size (same margin as RandomCrop)
if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
self.output_size[2]:
pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
(w, h, d) = image.shape
w1 = int(round((w - self.output_size[0]) / 2.))
h1 = int(round((h - self.output_size[1]) / 2.))
d1 = int(round((d - self.output_size[2]) / 2.))
label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
return {'image': image, 'label': label}
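# Note: RandomCrop is the usual training-time crop, while CenterCrop is commonly
# used for deterministic validation/test crops (a usage convention, not enforced here).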
class RandomRotFlip(object):
    """
    Randomly rotate the volume by a multiple of 90 degrees and flip it along a
    random in-plane axis. The same transform is applied to image and label.
    """
    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        # rotate by k * 90 degrees in the first two (in-plane) axes
        k = np.random.randint(0, 4)
        image = np.rot90(image, k)
        label = np.rot90(label, k)
        # flip along a random in-plane axis; .copy() keeps the arrays contiguous for torch.from_numpy
        axis = np.random.randint(0, 2)
        image = np.flip(image, axis=axis).copy()
        label = np.flip(label, axis=axis).copy()
return {'image': image, 'label': label}
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
image = sample['image']
        # add a leading channel dimension: (W, H, D) -> (1, W, H, D)
        image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32)
return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long()}
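# A minimal sketch of chaining these transforms for training. The (112, 112, 80)
# crop size and the torchvision Compose import are assumptions for illustration,
# not something this file defines:
#   from torchvision import transforms
#   train_transform = transforms.Compose([
#       RandomRotFlip(),
#       RandomCrop((112, 112, 80)),
#       ToTensor(),
#   ])
#   db_train = LAHeart(base_dir='E:/data/LASet/data', split='train', transform=train_transform)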
class TwoStreamBatchSampler(Sampler):
    """Iterate over two sets of indices simultaneously.

    An 'epoch' is one pass through the primary indices; during that pass the
    secondary indices are cycled through (reshuffled) as many times as needed.
    Each batch holds primary_batch_size primary indices followed by
    secondary_batch_size secondary indices.
    """
def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
self.primary_indices = primary_indices
self.secondary_indices = secondary_indices
self.secondary_batch_size = secondary_batch_size
self.primary_batch_size = batch_size - secondary_batch_size
assert len(self.primary_indices) >= self.primary_batch_size > 0
assert len(self.secondary_indices) >= self.secondary_batch_size > 0
def __iter__(self):
primary_iter = iterate_once(self.primary_indices)
secondary_iter = iterate_eternally(self.secondary_indices)
return (
primary_batch + secondary_batch
for (primary_batch, secondary_batch)
in zip(grouper(primary_iter, self.primary_batch_size),
grouper(secondary_iter, self.secondary_batch_size))
)
def __len__(self):
return len(self.primary_indices) // self.primary_batch_size
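# A minimal sketch of wiring the sampler into a DataLoader (when batch_sampler is
# given, batch_size/shuffle/sampler/drop_last must be left at their defaults);
# db_train and the index splits are assumptions for illustration:
#   from torch.utils.data import DataLoader
#   sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size=4, secondary_batch_size=2)
#   loader = DataLoader(db_train, batch_sampler=sampler, num_workers=4, pin_memory=True)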
def iterate_once(iterable):
    # a single random permutation of the indices
    return np.random.permutation(iterable)
def iterate_eternally(indices):
    # an endless stream of indices, reshuffled on every pass
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())
def grouper(iterable, n):
    """Collect data into fixed-length chunks, dropping any incomplete final chunk."""
    # grouper('ABCDEFG', 3) --> ABC DEF
    args = [iter(iterable)] * n
    return zip(*args)
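# Note: grouper(range(7), 3) yields (0, 1, 2) and (3, 4, 5); the incomplete tail (6,)
# is dropped because zip stops at the shortest iterator, so TwoStreamBatchSampler
# silently skips a final partial batch.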
if __name__ == '__main__':
train_set = LAHeart('E:/data/LASet/data')
print(len(train_set))
# data = train_set[0]
# image, label = data['image'], data['label']
# print(image.shape, label.shape)
labeled_idxs = list(range(25))
unlabeled_idxs = list(range(25,123))
batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, 4, 2)
i = 0
for x in batch_sampler:
i += 1
print('%02d'%i,'\t',x)