--- a
+++ b/data_brain_parcellation.py
@@ -0,0 +1,349 @@
+__author__ = 'adeb'
+
+import os
+import glob
+import time
+import numpy as np
+import nibabel as nib
+import theano
+import matplotlib.cm as cm
+from matplotlib import pyplot as plt
+
+from spynet.utils.utilities import distrib_balls_in_bins
+from spynet.utils.multiprocess import parmap
+from spynet.data.dataset import Dataset
+from spynet.data.utils_3d.pick_patch import *
+from spynet.data.utils_3d.pick_voxel import *
+from spynet.data.utils_3d.pick_target import *
+
+from spynet.utils.utilities import tile_raster_images
+import PIL
+
+class DataGeneratorBrain():
+    """
+    Generate (voxel, features, target) datapoints from MICCAI brain atlases.
+
+    Attributes:
+        pick_vx: Voxel-picking strategy; its ``pick`` method yields batches of voxel coordinates
+        pick_features: Feature-picking strategy (patches, centroid distances, ...) applied to each voxel batch
+        pick_tg: Target-picking strategy that builds the output vector of each voxel
+
+        files: List of pairs (mri_file, label_file)
+        atlases: List of triples (mri array, label array, affine) of the cropped, padded and
+            normalised atlases, aligned with ``files``
+        n_files (int): Number of files
+        ls_region_centroids: One RegionCentroids per atlas, or None when the features do not use them
+
+        n_out_features (int): Number of output classes in the datasets (one per true label + background)
+    """
+
+    # See Miccai rules
+    ignored_labels = (list(range(1, 4)) + list(range(5, 11)) + list(range(12, 23)) + list(range(24, 30)) +
+                      [33, 34] + [42, 43] + [53, 54] + list(range(63, 69)) + [70, 74] +
+                      list(range(80, 100)) + [110, 111] + [126, 127] + [130, 131] + [158, 159] + [188, 189])
+
+    # Labels kept for the parcellation. The label files are assumed to be already remapped so that
+    # true_labels[k] is stored as k + 1 (0 = background) -- TODO confirm against the preprocessing.
+    true_labels = [4, 11, 23, 30, 31, 32, 35, 36, 37, 38, 39, 40, 41, 44, 45, 46, 47, 48, 49, 50, 51, 52, 55, 56, 57,
+                   58, 59, 60, 61, 62, 69, 71, 72, 73, 75, 76, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 112,
+                   113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 128, 129, 132, 133, 134, 135, 136,
+                   137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156,
+                   157, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178,
+                   179, 180, 181, 182, 183, 184, 185, 186, 187, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200,
+                   201, 202, 203, 204, 205, 206, 207]
+
+    def __init__(self):
+
+        self.pick_vx = None
+        self.pick_features = None
+        self.pick_tg = None
+
+        self.files = None
+        self.n_files = None
+        self.atlases = []
+        self.ls_region_centroids = []
+
+        self.n_out_features = None
+
+    def init_from_config(self, config):
+        """Build the picking strategies and the file list from a config object, then load the atlases."""
+        self.pick_vx = create_pick_voxel(config)
+        self.pick_features = create_pick_features(config)
+        self.pick_tg = create_pick_target(config)
+        self.files = list_miccai_files(**config.general["source_kwargs"])
+        self.__init_common()
+
+    def init_from(self, files, pick_vx, pick_patch, pick_tg):
+        """Use explicit picking strategies (pick_patch is stored as pick_features) and files, then load the atlases."""
+        self.pick_vx = pick_vx
+        self.pick_features = pick_patch
+        self.pick_tg = pick_tg
+        self.files = files
+        self.__init_common()
+
+    def __init_common(self):
+        """Load, crop, pad and normalise every atlas of self.files (and their centroids if needed)."""
+        self.n_files = len(self.files)
+        self.ls_region_centroids = [None] * self.n_files
+        # Reset so that re-initialising the generator does not append duplicates and
+        # misalign self.atlases with self.files.
+        self.atlases = []
+        n_regions = len(self.true_labels)
+
+        print("    preprocess the atlases ...")
+        for i, (mri_file, lab_file) in enumerate(self.files):
+            print("        " + mri_file)
+            # nib.nifti1.FLOAT32_EPS_3 = -1e-6
+            # Load the MRI once: it provides both the voxel data and the affine.
+            mri_img = nib.load(mri_file)
+            mri = mri_img.get_data().squeeze().astype(np.float32, copy=False)
+            lab = nib.load(lab_file).get_data().squeeze().astype(np.int16, copy=False)
+
+            mri, lab = crop_brain_and_pad(mri, lab, self.pick_features.required_pad)
+            self.scale_atlas(mri, lab)
+            self.atlases.append((mri, lab, mri_img.get_affine()))
+
+            # Region centroids are only needed by centroid-distance features
+            if self.pick_features.has_instance_of(PickCentroidDistances):
+                region_centroids = RegionCentroids(n_regions)
+                nonzero = lab.nonzero()
+                region_centroids.update_barycentres(np.asarray(nonzero).T, lab[nonzero])
+                self.ls_region_centroids[i] = region_centroids
+
+        # Number of classes: one per true label plus the background class 0
+        self.n_out_features = n_regions + 1
+
+    def scale_atlas(self, mri, label):
+        """
+        Standardise mri in place with the mean and std of the voxels where label is non-zero.
+
+        Raises:
+            ValueError: if label has no non-zero voxel (nothing to normalise against)
+        """
+        only_brain = mri[label.nonzero()]
+        if only_brain.size == 0:
+            raise ValueError("Cannot scale the atlas: the label volume is empty.")
+        scalar_mean = np.mean(only_brain)
+        scalar_std = np.std(only_brain)
+        mri -= scalar_mean
+        # A constant-intensity brain has no spread: centre it without dividing by zero
+        if scalar_std > 0:
+            mri /= scalar_std
+
+    def generate_single_atlas(self, atlas_id, n_points, region_centroid, batch_size, verbose=False):
+        """
+        Yield (vx_batch, features, targets) for n_points voxels drawn from one atlas.
+
+        Args:
+            atlas_id (int): index into self.atlases / self.files
+            n_points (int): number of voxels to draw from this atlas
+            region_centroid: RegionCentroids of the atlas (or None), forwarded to pick_features
+            batch_size (int): forwarded to pick_vx.pick, presumably the maximum batch length
+            verbose (bool): forwarded to pick_vx.pick
+        """
+        print("    file {} \n".format(self.files[atlas_id]))
+
+        mri, lab, _ = self.atlases[atlas_id]
+        for vx_batch in self.pick_vx.pick(n_points, lab, verbose=verbose, batch_size=batch_size):
+            patch = self.pick_features.pick(vx_batch, mri, lab, region_centroid)[0]
+            tg = self.pick_tg.pick(vx_batch, self.n_out_features, mri, lab)
+            yield vx_batch, patch, tg
+
+    def generate_parallel(self, batch_size):
+        """
+        Draw batch_size datapoints spread over all atlases (one worker per atlas).
+
+        Returns:
+            vx (batch_size x 3, int), patch (batch_size x n_features, floatX),
+            tg (batch_size x n_out_features, floatX), file_id (batch_size x 1, int)
+        """
+        print("Generate data ...")
+
+        # Number of voxels to extract from each atlas
+        voxels_per_atlas = distrib_balls_in_bins(batch_size, self.n_files)
+        n_features = self.pick_features.n_features
+
+        # Function that will be run in parallel
+        def generate_from_one_brain(atlas_id):
+            n_points = voxels_per_atlas[atlas_id]
+            if n_points == 0:
+                # Fewer datapoints than atlases: nothing to draw here, and next() on a
+                # generator that may yield nothing would raise StopIteration.
+                return (np.zeros((0, 3), dtype=int),
+                        np.zeros((0, n_features), dtype=theano.config.floatX),
+                        np.zeros((0, self.n_out_features), dtype=theano.config.floatX),
+                        atlas_id)
+            # Large batch_size so it can not be reached: we want everything in a single batch
+            vx, patch, tg = next(self.generate_single_atlas(atlas_id, n_points,
+                                                            self.ls_region_centroids[atlas_id],
+                                                            batch_size=1000000))
+            return vx, patch, tg, atlas_id
+
+        # Generate the patches in parallel
+        if self.n_files == 1:  # This special case is necessary to avoid a bug on the server
+            res_all = map(generate_from_one_brain, range(self.n_files))
+        else:
+            res_all = parmap(generate_from_one_brain, range(self.n_files))
+
+        # Initialize the containers
+        vx = np.zeros((batch_size, 3), dtype=int)
+        patch = np.zeros((batch_size, n_features), dtype=theano.config.floatX)
+        tg = np.zeros((batch_size, self.n_out_features), dtype=theano.config.floatX)
+        file_id = np.zeros((batch_size, 1), dtype=int)
+
+        # Aggregate the per-atlas results into consecutive rows
+        idx1 = 0
+        for res in res_all:
+            idx2 = idx1 + res[0].shape[0]
+            vx[idx1:idx2], patch[idx1:idx2], tg[idx1:idx2], file_id[idx1:idx2] = res
+            idx1 = idx2
+
+        return vx, patch, tg, file_id
+
+
+def list_miccai_files(**kwargs):
+    """
+    List the pairs (mri_file_name, label_file_name) of the MICCAI data.
+
+    The MRI files are the ``*.nii`` files of ``<path>/mri/``; the label file of ``<name>.nii``
+    is ``<path>/label/<name>_glm.nii``.
+
+    Keyword Args:
+        mode (str): "folder" to list every MRI file, or "idx_files" to keep only the files whose
+            positions (in sorted file-name order) are listed in ``idx_files``
+        path (str): root directory of the data
+        idx_files (iterable of int): required when mode == "idx_files"
+
+    Raises:
+        ValueError: if mode is not one of the values above
+    """
+    mode = kwargs["mode"]
+    path = kwargs["path"]
+    # glob returns files in an arbitrary, filesystem-dependent order; sort them so that
+    # idx_files selects the same atlases on every machine.
+    mri_files = sorted(glob.glob(os.path.join(path, "mri", "*.nii")))
+
+    if mode == "folder":
+        idx_files = range(len(mri_files))
+    elif mode == "idx_files":
+        idx_files = kwargs["idx_files"]
+    else:
+        raise ValueError("Error to list the MICCAI files, the mode does not exist.")
+
+    label_dir = os.path.join(path, "label")
+    pairs = []
+    for i in idx_files:
+        mri_file = mri_files[i]
+        stem = os.path.splitext(os.path.basename(mri_file))[0]
+        pairs.append((mri_file, os.path.join(label_dir, stem + "_glm.nii")))
+    return pairs
+
+
+def check_img_limits(img):
+    """
+    Find the boundaries of the non-zero region of a 3D image.
+
+    Returns:
+        int array of shape (3, 2): lim[j] = (first, last) index along axis j of the slices that
+        contain a non-zero value; both bounds are inclusive. If the image has no non-zero value,
+        lim[j] = (dims[j] - 1, 0).
+    """
+    img = np.asarray(img)
+    lim = np.zeros((3, 2), dtype=int)
+    for j in range(3):
+        other_axes = tuple(a for a in (0, 1, 2) if a != j)
+        # Indices i such that the slice i along axis j holds at least one non-zero value
+        occupied = np.flatnonzero(np.any(img, axis=other_axes))
+        if occupied.size:
+            lim[j] = occupied[0], occupied[-1]
+        else:
+            # Empty image: fall back to (last index, first index) instead of crashing
+            lim[j] = img.shape[j] - 1, 0
+    return lim
+
+
+def crop_brain_and_pad(mri, lab, pad):
+    """
+    Extract the brain from an mri image.
+
+    Both volumes are cropped to the bounding box of the non-zero labels enlarged by ``pad``
+    voxels on every side; where that box goes beyond the volume, zeros are added, so the
+    labelled region keeps a margin of ``pad`` voxels on each side.
+
+    Args:
+        mri: 3D intensity volume
+        lab: 3D label volume, assumed to have the same shape as mri (0 = background)
+        pad (int): margin to keep around the labelled region, in voxels
+
+    Returns:
+        (mri, lab) cropped and zero-padded
+
+    Raises:
+        ValueError: if lab contains no non-zero voxel
+    """
+    if not np.any(lab):
+        raise ValueError("Cannot crop the brain: the label volume is empty.")
+
+    lim = check_img_limits(lab)
+    # check_img_limits returns inclusive bounds: make the upper one an exclusive slice end so
+    # the last labelled slice is kept and the margin is symmetric.
+    lim[:, 0] -= pad
+    lim[:, 1] += pad + 1
+
+    dim_orig = np.array(mri.shape)
+
+    # Part of the box lying before index 0 becomes padding before the data
+    too_low = lim[:, 0] < 0
+    pad_inf = np.where(too_low, -lim[:, 0], 0)
+    lim[too_low, 0] = 0
+
+    # Part of the box lying past the end of the volume becomes padding after the data
+    too_high = lim[:, 1] > dim_orig
+    pad_sup = np.where(too_high, lim[:, 1] - dim_orig, 0)
+    lim[too_high, 1] = dim_orig[too_high]
+
+    box = tuple(slice(lo, hi) for lo, hi in lim)
+    pad_width = list(zip(pad_inf, pad_sup))
+
+    mri = np.pad(mri[box], pad_width, 'constant', constant_values=0)
+    lab = np.pad(lab[box], pad_width, 'constant', constant_values=0)
+    return mri, lab
+
+
+class DatasetBrainParcellation(Dataset):
+    """
+    Specialized dataset class for the brain parcellation data.
+
+    Attributes:
+        vx (array n_data x 3): coordinates x, y, z of the voxel of each datapoint, in the
+            cropped and padded atlas space built by DataGeneratorBrain
+        file_ids (array n_data x 1): id of the atlas file each datapoint was drawn from
+    """
+    def __init__(self):
+        Dataset.__init__(self)
+
+        # Initialize the additional containers
+        self.vx = None
+        self.file_ids = None
+
+    def populate_from_config(self, config):
+        """Generate config.general["n_data"] datapoints from the atlases described by config, then shuffle them."""
+        data_generator = DataGeneratorBrain()
+        data_generator.init_from_config(config)
+        vx, inputs, outputs, file_ids = data_generator.generate_parallel(config.general["n_data"])
+        self.populate(inputs, outputs, vx, file_ids)
+        self.shuffle_data()
+
+    def populate(self, inputs, outputs, vx, file_ids):
+        """Store the datapoints and their voxel coordinates / file ids."""
+        self.inputs = inputs
+        self.outputs = outputs
+        self.vx = vx
+        self.file_ids = file_ids
+
+    def shuffle_data_virtual(self, perm):
+        # Keep the extra containers aligned with the permutation (presumably the one that
+        # Dataset.shuffle_data applies to inputs/outputs).
+        self.vx = self.vx[perm]
+        self.file_ids = self.file_ids[perm]
+
+    def write_virtual(self, h5file):
+        # Stored as float32 to keep the existing file format; read_virtual casts them back to int.
+        h5file.create_dataset("voxels", data=self.vx, dtype='f')
+        h5file.create_dataset("file_id", data=self.file_ids, dtype='f')
+
+    def read_virtual(self, h5file):
+        # [()] reads the whole dataset on every h5py version (Dataset.value was removed in h5py 3)
+        self.vx = h5file["voxels"][()].astype(int)
+        self.file_ids = h5file["file_id"][()].astype(int)
+
+    def duplicate_datapoints_slice_virtual(self, ds, slice_idx):
+        # Both containers must be sliced so they stay aligned with the sliced datapoints
+        ds.vx = self.vx[slice_idx]
+        ds.file_ids = self.file_ids[slice_idx]
+
+
+class RegionCentroids():
+    """
+    Barycentres (mean voxel coordinates) of the labelled regions 1..n_regions of one atlas.
+
+    Attributes:
+        n_regions (int): number of regions; barycentres[i] belongs to label i + 1
+        barycentres (array n_regions x 3): barycentre coordinates of each region
+    """
+    def __init__(self, n_regions):
+        self.n_regions = n_regions
+        self.barycentres = np.zeros((n_regions, 3))
+
+    def update_barycentres(self, vxs, regions):
+        """
+        Recompute the barycentres from labelled voxels.
+
+        Args:
+            vxs (array n x 3): voxel coordinates
+            regions (array n): label of each voxel; labels outside 1..n_regions are ignored
+
+        A region without any voxel gets a placeholder: the mean of all coordinates of the
+        regions that are present, repeated on the 3 axes. If no region is present, the
+        barycentres stay at zero.
+        """
+        self.barycentres = np.zeros((self.n_regions, 3))
+        present = np.zeros(self.n_regions, dtype=bool)
+        for i in range(self.n_regions):
+            region_vxs = vxs[regions == i + 1]
+            if region_vxs.size == 0:
+                continue
+            self.barycentres[i] = np.mean(region_vxs, axis=0)
+            present[i] = True
+
+        # Absent regions are tracked explicitly: 0.0 is also a valid coordinate and must not be
+        # mistaken for "missing".
+        # NOTE(review): the placeholder mixes the x, y and z axes; a per-axis mean may be more
+        # meaningful but would change the features seen by already-trained models.
+        if present.any() and not present.all():
+            self.barycentres[~present] = self.barycentres[present].mean()
+
+    def compute_scaled_distances(self, vx):
+        """Return the Euclidean distance from voxel vx to every barycentre (no scaling is applied)."""
+        return np.linalg.norm(self.barycentres - vx, axis=1)
+
+
+def generate_and_save(config):
+    """Build a DatasetBrainParcellation from config and write it to config.general["file_path"]."""
+    destination = config.general["file_path"]  # read first so a missing key fails before any work
+    dataset = DatasetBrainParcellation()
+    dataset.populate_from_config(config)
+    dataset.write(destination)
+
+
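+# Legacy one-off preprocessing, kept for reference only (never executed). It zeroed the
+# ignored MICCAI labels, remapped DataGeneratorBrain.true_labels[idx] to idx + 1, and
+# overwrote the original MRI and label NIfTI files in place with the result.
+# for ignored_label in self.ignored_labels:
+#     lab[lab == ignored_label] = 0
+# for idx, label in enumerate(self.true_labels):
+#     lab[lab==label] = idx+1
+#
+# aa = nib.Nifti1Image(mri, nib.load(mri_file).get_affine())
+# nib.save(aa, mri_file)
+#
+# bb = nib.Nifti1Image(lab, nib.load(lab_file).get_affine())
+# nib.save(bb, lab_file)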
\ No newline at end of file