--- a
+++ b/experimental/datasets.py
@@ -0,0 +1,99 @@
+#
+#   Lightnet dataset that works with brambox annotations
+#   Copyright EAVISE
+#
+# https://eavise.gitlab.io/lightnet/_modules/lightnet/models/_dataset_brambox.html#BramboxDataset
+# https://eavise.gitlab.io/brambox/notes/02-getting_started.html#Loading-data
+
+import os
+import copy
+import logging
+from PIL import Image
+import numpy as np
+import lightnet.data as lnd
+from pathflowai.utils import load_sql_df
+import dask.array as da
+from os.path import join
+
+try:
+    import brambox as bb
+except ImportError:
+    bb = None
+
+__all__ = ['BramboxPathFlowDataset']  # must name the class this module actually defines
+log = logging.getLogger(__name__)
+
+# TODO: add image/annotation transform
+# TODO: add train/val/test info
+
+class BramboxPathFlowDataset(lnd.Dataset):
+    """ Dataset serving PathFlowAI slide patches together with their brambox annotations.
+
+    Args:
+        input_dir (str): Directory holding one ``<slide ID>.zarr`` array per slide listed in the patch info
+        patch_info_file (str): Patch database, passed to ``load_sql_df``
+        patch_size (int): Patch size, passed to ``load_sql_df``
+        annotations (dataframe): Brambox annotations; ``image`` must be categorical with ``ID/x/y/patch_size`` keys
+        input_dimension (tuple): (width,height) tuple with default dimensions of the network
+        class_label_map (list): List of class_labels; Default **sorted unique class labels of the annotations**
+        identify (function, optional): Currently unused, only kept for signature compatibility
+        img_transform (torchvision.transforms.Compose): Transforms to perform on the images
+        anno_transform (torchvision.transforms.Compose): Transforms to perform on the annotations
+
+    Note:
+        Patches are sliced from dask arrays (nothing is opened with Pillow); the given annotations are copied, never modified
+    """
+    def __init__(self, input_dir, patch_info_file, patch_size, annotations, input_dimension, class_label_map=None, identify=None, img_transform=None, anno_transform=None):
+        if bb is None:
+            raise ImportError('Brambox needs to be installed to use this dataset')
+        super().__init__(input_dimension)
+
+        self.annos = annotations.copy()  # private copy: the edits below (esp. the swap) must not leak to the caller
+        self.annos['ignore'] = 0
+        self.annos['class_label'] = self.annos['class_label'].astype(int)
+        log.debug('Class labels found in annotations: %s', self.annos['class_label'].unique())
+        self.keys = self.annos.image.cat.categories  # one key per unique patch
+        self.img_tf = img_transform
+        self.anno_tf = anno_transform
+        self.patch_info = load_sql_df(patch_info_file, patch_size)
+        slide_ids = self.patch_info['ID'].unique()
+        self.slides = {slide: da.from_zarr(join(input_dir, f'{slide}.zarr')) for slide in slide_ids}
+        self.id = lambda k: k.split('/')  # patch key -> [ID, x, y, patch_size] (all strings)
+        # Box width/height are swapped on purpose (the author marked this "experiment"); presumably to match the patch axis order - TODO confirm
+        self.annos['width'], self.annos['height'] = self.annos['height'], self.annos['width']
+        # Add class_ids
+        if class_label_map is None:
+            log.warning('No class_label_map given, generating it by sorting unique class labels from data alphabetically, which is not always deterministic behaviour')
+            class_label_map = list(np.sort(self.annos.class_label.unique()))
+        self.annos['class_id'] = self.annos.class_label.map({label: idx for idx, label in enumerate(class_label_map)})
+
+    def __len__(self):
+        return len(self.keys)
+
+    @lnd.Dataset.resize_getitem
+    def __getitem__(self, index):
+        """ Get transformed patch and annotations based of the index of ``self.keys``
+
+        Args:
+            index (int): index of the ``self.keys`` list containing all the patch identifiers of the dataset.
+
+        Returns:
+            tuple: (transformed patch image, transformed brambox annotations of that patch)
+        """
+        if index >= len(self):
+            raise IndexError(f'list index out of range [{index}/{len(self)-1}]')
+
+        # Load
+        key = self.keys[index]
+        ID, x, y, patch_size = self.id(key)
+        x, y, patch_size = int(x), int(y), int(patch_size)
+        img = self.slides[ID][x:x+patch_size, y:y+patch_size].compute()  # first array axis is sliced by x, second by y
+        anno = bb.util.select_images(self.annos, [key])
+
+        # Transform
+        if self.img_tf is not None:
+            img = self.img_tf(img)
+        if self.anno_tf is not None:
+            anno = self.anno_tf(anno)
+
+        return img, anno