Diff of /experimental/datasets.py [000000] .. [e9500f]

#
#   Lightnet dataset that works with brambox annotations
#   Copyright EAVISE
#
# https://eavise.gitlab.io/lightnet/_modules/lightnet/models/_dataset_brambox.html#BramboxDataset
# https://eavise.gitlab.io/brambox/notes/02-getting_started.html#Loading-data

import os
import copy
import logging
from PIL import Image
import numpy as np
import lightnet.data as lnd
from pathflowai.utils import load_sql_df
import dask.array as da
from os.path import join

try:
    import brambox as bb
except ImportError:
    bb = None

__all__ = ['BramboxPathFlowDataset']
log = logging.getLogger(__name__)

# TODO: add image annotation transform
# TODO: add train/val/test split info

class BramboxPathFlowDataset(lnd.Dataset):
    """ Dataset for brambox annotations on PathFlowAI whole-slide patches.

    Args:
        input_dir (str): Directory containing one Zarr array per slide, named ``<ID>.zarr``
        patch_info_file (str): SQL patch database read with ``pathflowai.utils.load_sql_df``
        patch_size (int): Size of the square patches referenced in the patch database
        annotations (dataframe): Dataframe containing brambox annotations
        input_dimension (tuple): (width, height) tuple with default dimensions of the network
        class_label_map (list): List of class_labels
        identify (function, optional): Lambda/function to get the image based on the annotation filename or image id (currently unused; patch keys are parsed as ``ID/x/y/patch_size``)
        img_transform (torchvision.transforms.Compose): Transforms to perform on the images
        anno_transform (torchvision.transforms.Compose): Transforms to perform on the annotations

    Note:
        Patches are read from Zarr-backed dask arrays rather than opened with Pillow.
    """
    def __init__(self, input_dir, patch_info_file, patch_size, annotations, input_dimension, class_label_map=None, identify=None, img_transform=None, anno_transform=None):
        if bb is None:
            raise ImportError('Brambox needs to be installed to use this dataset')
        super().__init__(input_dimension)

        self.annos = annotations
        self.annos['ignore'] = 0
        self.annos['class_label'] = self.annos['class_label'].astype(int)
        log.debug('Class labels: %s', self.annos['class_label'].unique())
        self.keys = self.annos.image.cat.categories  # unique patch keys of the form 'ID/x/y/patch_size'
        self.img_tf = img_transform
        self.anno_tf = anno_transform
        self.patch_info = load_sql_df(patch_info_file, patch_size)
        IDs = self.patch_info['ID'].unique()
        # Lazily open one Zarr-backed dask array per slide
        self.slides = {slide: da.from_zarr(join(input_dir, '{}.zarr'.format(slide))) for slide in IDs}
        # Split a patch key into its (ID, x, y, patch_size) components
        self.id = lambda k: k.split('/')
        # experiment: swap annotation width and height
        # self.annos['x_top_left'], self.annos['y_top_left'] = self.annos['y_top_left'], self.annos['x_top_left']
        self.annos['width'], self.annos['height'] = self.annos['height'], self.annos['width']
        # Add class_ids
        if class_label_map is None:
            log.warning('No class_label_map given, generating it by sorting unique class labels from the data alphabetically, which is not always deterministic behaviour')
            class_label_map = list(np.sort(self.annos.class_label.unique()))
        self.annos['class_id'] = self.annos.class_label.map(dict((l, i) for i, l in enumerate(class_label_map)))

    def __len__(self):
        return len(self.keys)

    @lnd.Dataset.resize_getitem
    def __getitem__(self, index):
        """ Get transformed image and annotations based on the index of ``self.keys``

        Args:
            index (int): index of the ``self.keys`` list containing all the image identifiers of the dataset.

        Returns:
            tuple: (transformed image, list of transformed brambox boxes)
        """
        if index >= len(self):
            raise IndexError(f'list index out of range [{index}/{len(self)-1}]')

        # Load the patch from the slide's dask array and select its annotations
        ID, x, y, patch_size = self.id(self.keys[index])
        x, y, patch_size = int(x), int(y), int(patch_size)
        img = self.slides[ID][x:x + patch_size, y:y + patch_size].compute()
        anno = bb.util.select_images(self.annos, [self.keys[index]])

        # Transform
        if self.img_tf is not None:
            img = self.img_tf(img)
        if self.anno_tf is not None:
            anno = self.anno_tf(anno)

        return img, anno
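
# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): one way this
# dataset might be constructed. The file paths, patch size, and input
# dimension below are assumptions for the example, not values from this
# repository. The annotations dataframe is assumed to already be in
# brambox format (see the getting-started link above), with a categorical
# 'image' column whose keys look like 'ID/x/y/patch_size'.
if __name__ == '__main__':
    import pandas as pd

    # Hypothetical annotation file; any brambox-style dataframe works here.
    annotations = pd.read_pickle('annotations.pkl')

    dataset = BramboxPathFlowDataset(
        input_dir='inputs',               # directory holding <ID>.zarr arrays (assumed layout)
        patch_info_file='patch_info.db',  # PathFlowAI patch database (assumed path)
        patch_size=512,                   # example value; must match the patch database
        annotations=annotations,
        input_dimension=(512, 512),       # example network input size
    )

    # Index the dataset to get a patch array and its brambox boxes.
    img, anno = dataset[0]
    print(img.shape, len(anno))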