--- a
+++ b/dataloaders/la_heart.py
@@ -0,0 +1,187 @@
+import h5py
+import torch
+import numpy as np
+import itertools
+from torch.utils.data import Dataset
+from torch.utils.data.sampler import Sampler
+
+
+class LAHeart(Dataset):
+    """ LA Dataset """
+
+    def __init__(self, base_dir=None, split='train', num=None, transform=None):
+        self._base_dir = base_dir
+        self.transform = transform
+        self.sample_list = []
+        if split == 'train':
+            with open(self._base_dir + '/../train.list', 'r') as f:
+                self.image_list = f.readlines()
+        elif split == 'test':
+            with open(self._base_dir + '/../test.list', 'r') as f:
+                self.image_list = f.readlines()
+        self.image_list = [item.strip() for item in self.image_list]
+        if num is not None:
+            self.image_list = self.image_list[:num]
+        print("total {} samples".format(len(self.image_list)))
+
+    def __len__(self):
+        return len(self.image_list)
+
+    def __getitem__(self, idx):
+        image_name = self.image_list[idx]
+        # print(image_name)
+        h5f = h5py.File(self._base_dir + "/" + image_name + "/mri_norm2.h5", 'r')
+        image = h5f['image'][:]
+        label = h5f['label'][:]
+        sample = {'image': image, 'label': label}
+        if self.transform:
+            sample = self.transform(sample)
+        return sample
+
+class RandomCrop(object):
+    """
+    Crop randomly the image in a sample
+    Args:
+    output_size (int): Desired output size
+    """
+
+    def __init__(self, output_size):
+        self.output_size = output_size
+
+    def __call__(self, sample):
+        image, label = sample['image'], sample['label']
+
+        # pad the sample if necessary
+        if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
+                self.output_size[2]:
+            pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
+            ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
+            pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
+            image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
+            label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
+
+        (w, h, d) = image.shape
+        w1 = np.random.randint(0, w - self.output_size[0])
+        h1 = np.random.randint(0, h - self.output_size[1])
+        d1 = np.random.randint(0, d - self.output_size[2])
+
+        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
+        image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
+        return {'image': image, 'label': label}
+
+class CenterCrop(object):
+    def __init__(self, output_size):
+        self.output_size = output_size
+
+    def __call__(self, sample):
+        image, label = sample['image'], sample['label']
+
+        # pad the sample if necessary
+        if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
+                self.output_size[2]:
+            pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
+            ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
+            pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
+            image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
+            label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
+
+        (w, h, d) = image.shape
+
+        w1 = int(round((w - self.output_size[0]) / 2.))
+        h1 = int(round((h - self.output_size[1]) / 2.))
+        d1 = int(round((d - self.output_size[2]) / 2.))
+
+        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
+        image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
+
+        return {'image': image, 'label': label}
+
+
+class RandomRotFlip(object):
+    """
+    Crop randomly flip the dataset in a sample
+    Args:
+    output_size (int): Desired output size
+    """
+
+    def __call__(self, sample):
+        image, label = sample['image'], sample['label']
+        k = np.random.randint(0, 4)
+        image = np.rot90(image, k)
+        label = np.rot90(label, k)
+        axis = np.random.randint(0, 2)
+        image = np.flip(image, axis=axis).copy()
+        label = np.flip(label, axis=axis).copy()
+
+        return {'image': image, 'label': label}
+
+
+class ToTensor(object):
+    """Convert ndarrays in sample to Tensors."""
+
+    def __call__(self, sample):
+        image = sample['image']
+        image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32)
+        return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long()}
+
+
+class TwoStreamBatchSampler(Sampler):
+    """Iterate two sets of indices
+
+    An 'epoch' is one iteration through the primary indices.
+    During the epoch, the secondary indices are iterated through
+    as many times as needed.
+    """
+    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
+        self.primary_indices = primary_indices
+        self.secondary_indices = secondary_indices
+        self.secondary_batch_size = secondary_batch_size
+        self.primary_batch_size = batch_size - secondary_batch_size
+
+        assert len(self.primary_indices) >= self.primary_batch_size > 0
+        assert len(self.secondary_indices) >= self.secondary_batch_size > 0
+
+    def __iter__(self):
+        primary_iter = iterate_once(self.primary_indices)
+        secondary_iter = iterate_eternally(self.secondary_indices)
+        return (
+            primary_batch + secondary_batch
+            for (primary_batch, secondary_batch)
+            in zip(grouper(primary_iter, self.primary_batch_size),
+                    grouper(secondary_iter, self.secondary_batch_size))
+        )
+
+    def __len__(self):
+        return len(self.primary_indices) // self.primary_batch_size
+
+def iterate_once(iterable):
+    return np.random.permutation(iterable)
+
+
+def iterate_eternally(indices):
+    def infinite_shuffles():
+        while True:
+            yield np.random.permutation(indices)
+    return itertools.chain.from_iterable(infinite_shuffles())
+
+
+def grouper(iterable, n):
+    "Collect data into fixed-length chunks or blocks"
+    # grouper('ABCDEFG', 3) --> ABC DEF"
+    args = [iter(iterable)] * n
+    return zip(*args)
+
+
+if __name__ == '__main__':
+    train_set = LAHeart('E:/data/LASet/data')
+    print(len(train_set))
+    # data = train_set[0]
+    # image, label = data['image'], data['label']
+    # print(image.shape, label.shape)
+    labeled_idxs = list(range(25))
+    unlabeled_idxs = list(range(25,123))
+    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, 4, 2)
+    i = 0
+    for x in batch_sampler:
+        i += 1
+        print('%02d'%i,'\t',x)
\ No newline at end of file