--- a
+++ b/mylib/dataloader/dataset.py
@@ -0,0 +1,156 @@
+from collections.abc import Sequence
+import random
+import os
+
+import numpy as np
+from mylib.dataloader.path_manager import PATH
+from mylib.utils.misc import rotation, reflection, crop, random_center, _triple
+
+INFO = PATH.info
+
+LABEL = ['AAH', 'AIS', 'MIA', 'IAC']
+
+
+class ClfDataset(Sequence):
+    def __init__(self, crop_size=32, move=3, subset=[0, 1, 2, 3],
+                 define_label=lambda l: [l[0] + l[1], l[2], l[3]]):
+        '''The classification-only dataset.
+
+        :param crop_size: the input size
+        :param move: the random move
+        :param subset: choose which subset to use
+        :param define_label: how to define the label. default: for 3-output classification one hot encoding.
+        '''
+        index = []
+        for sset in subset:
+            index += list(INFO[INFO['subset'] == sset].index)
+        self.index = tuple(sorted(index))  # the index in the info
+        self.label = np.array([[label == s for label in LABEL] for s in INFO.loc[self.index, 'diagnosis']])
+        self.transform = Transform(crop_size, move)
+        self.define_label = define_label
+
+    def __getitem__(self, item):
+        name = INFO.loc[self.index[item], 'name']
+        with np.load(os.path.join(PATH.nodule_path, '%s.npz' % name)) as npz:
+            voxel = self.transform(npz['voxel'])
+        label = self.label[item]
+        return voxel, self.define_label(label)
+
+    def __len__(self):
+        return len(self.index)
+
+    @staticmethod
+    def _collate_fn(data):
+        xs = []
+        ys = []
+        for x, y in data:
+            xs.append(x)
+            ys.append(y)
+        return np.array(xs), np.array(ys)
+
+
+class ClfSegDataset(ClfDataset):
+    '''Classification and segmentation dataset.'''
+
+    def __getitem__(self, item):
+        name = INFO.loc[self.index[item], 'name']
+        with np.load(os.path.join(PATH.nodule_path, '%s.npz' % name)) as npz:
+            voxel, seg = self.transform(npz['voxel'], npz['seg'])
+            # voxel = self.transform(npz['voxel'] * (npz['seg'] * 0.8 + 0.2))
+        label = self.label[item]
+        return voxel, (self.define_label(label), seg)
+
+    @staticmethod
+    def _collate_fn(data):
+        xs = []
+        ys = []
+        segs = []
+        for x, y in data:
+            xs.append(x)
+            ys.append(y[0])
+            segs.append(y[1])
+        return np.array(xs), {"clf": np.array(ys), "seg": np.array(segs)}
+
+
+def get_loader(dataset, batch_size):
+    total_size = len(dataset)
+    print('Size', total_size)
+    index_generator = shuffle_iterator(range(total_size))
+    while True:
+        data = []
+        for _ in range(batch_size):
+            idx = next(index_generator)
+            data.append(dataset[idx])
+        yield dataset._collate_fn(data)
+
+
+def get_balanced_loader(dataset, batch_sizes):
+    assert len(batch_sizes) == len(LABEL)
+    total_size = len(dataset)
+    print('Size', total_size)
+    index_generators = []
+    for l_idx in range(len(batch_sizes)):
+        # this must be list, or `l_idx` will not be eval
+        iterator = [i for i in range(total_size) if dataset.label[i, l_idx]]
+        index_generators.append(shuffle_iterator(iterator))
+    while True:
+        data = []
+        for i, batch_size in enumerate(batch_sizes):
+            generator = index_generators[i]
+            for _ in range(batch_size):
+                idx = next(generator)
+                data.append(dataset[idx])
+        yield dataset._collate_fn(data)
+
+
+class Transform:
+    '''The online data augmentation, including:
+    1) random move the center by `move`
+    2) rotation 90 degrees increments
+    3) reflection in any axis
+    '''
+
+    def __init__(self, size, move):
+        self.size = _triple(size)
+        self.move = move
+
+    def __call__(self, arr, aux=None):
+        shape = arr.shape
+        if self.move is not None:
+            center = random_center(shape, self.move)
+            arr_ret = crop(arr, center, self.size)
+            angle = np.random.randint(4, size=3)
+            arr_ret = rotation(arr_ret, angle=angle)
+            axis = np.random.randint(4) - 1
+            arr_ret = reflection(arr_ret, axis=axis)
+            arr_ret = np.expand_dims(arr_ret, axis=-1)
+            if aux is not None:
+                aux_ret = crop(aux, center, self.size)
+                aux_ret = rotation(aux_ret, angle=angle)
+                aux_ret = reflection(aux_ret, axis=axis)
+                aux_ret = np.expand_dims(aux_ret, axis=-1)
+                return arr_ret, aux_ret
+            return arr_ret
+        else:
+            center = np.array(shape) // 2
+            arr_ret = crop(arr, center, self.size)
+            arr_ret = np.expand_dims(arr_ret, axis=-1)
+            if aux is not None:
+                aux_ret = crop(aux, center, self.size)
+                aux_ret = np.expand_dims(aux_ret, axis=-1)
+                return arr_ret, aux_ret
+            return arr_ret
+
+
+def shuffle_iterator(iterator):
+    # iterator should have limited size
+    index = list(iterator)
+    total_size = len(index)
+    i = 0
+    random.shuffle(index)
+    while True:
+        yield index[i]
+        i += 1
+        if i >= total_size:
+            i = 0
+            random.shuffle(index)