|
a |
|
b/experimental/datasets.py |
|
|
1 |
# |
|
|
2 |
# Lightnet dataset that works with brambox annotations |
|
|
3 |
# Copyright EAVISE |
|
|
4 |
# |
|
|
5 |
# https://eavise.gitlab.io/lightnet/_modules/lightnet/models/_dataset_brambox.html#BramboxDataset |
|
|
6 |
# https://eavise.gitlab.io/brambox/notes/02-getting_started.html#Loading-data |
|
|
7 |
|
|
|
8 |
import os |
|
|
9 |
import copy |
|
|
10 |
import logging |
|
|
11 |
from PIL import Image |
|
|
12 |
import numpy as np |
|
|
13 |
import lightnet.data as lnd |
|
|
14 |
from pathflowai.utils import load_sql_df |
|
|
15 |
import dask.array as da |
|
|
16 |
from os.path import join |
|
|
17 |
|
|
|
18 |
try: |
|
|
19 |
import brambox as bb |
|
|
20 |
except ImportError: |
|
|
21 |
bb = None |
|
|
22 |
|
|
|
23 |
# Public API of this module. NOTE: this previously exported 'BramboxDataset',
# a name that is not defined in this file (the class below is
# BramboxPathFlowDataset), so `from ... import *` raised ImportError.
__all__ = ['BramboxPathFlowDataset']
log = logging.getLogger(__name__)
|
|
25 |
|
|
|
26 |
# ADD IMAGE ANNOTATION TRANSFORM |
|
|
27 |
# ADD TRAIN VAL TEST INFO |
|
|
28 |
|
|
|
29 |
class BramboxPathFlowDataset(lnd.Dataset):
    """ Dataset for brambox annotations backed by PathFlowAI zarr slides.

    Patches are addressed by keys of the form ``"ID/x/y/patch_size"`` (the
    unique categories of ``annotations.image``); each lookup slices the
    corresponding whole-slide dask/zarr array rather than opening an image
    file from disk.

    Args:
        input_dir (str): Directory containing one ``<ID>.zarr`` array per slide
        patch_info_file (str): SQL database of patch metadata, loaded via
            ``pathflowai.utils.load_sql_df``
        patch_size (int): Patch size used to select the patch-info table
        annotations (dataframe): Dataframe containing brambox annotations
        input_dimension (tuple): (width,height) tuple with default dimensions of the network
        identify (function, optional): Unused; kept for interface compatibility
            with lightnet's BramboxDataset (the keys already encode the patch)
        class_label_map (list): List of class_labels
        img_transform (torchvision.transforms.Compose): Transforms to perform on the images
        anno_transform (torchvision.transforms.Compose): Transforms to perform on the annotations

    Note:
        Unlike lightnet's BramboxDataset this does NOT use Pillow; images are
        numpy arrays computed from dask slices of the zarr slide arrays.
    """

    def __init__(self, input_dir, patch_info_file, patch_size, annotations, input_dimension, class_label_map=None, identify=None, img_transform=None, anno_transform=None):
        if bb is None:
            raise ImportError('Brambox needs to be installed to use this dataset')
        super().__init__(input_dimension)

        # The annotation dataframe is mutated in place below (ignore,
        # class_label, width/height, class_id columns) — callers should not
        # rely on it being left untouched.
        self.annos = annotations
        self.annos['ignore'] = 0
        self.annos['class_label'] = self.annos['class_label'].astype(int)
        log.debug('class labels: %s', self.annos['class_label'].unique())

        # Unique patch identifiers ("ID/x/y/patch_size" strings).
        self.keys = self.annos.image.cat.categories
        self.img_tf = img_transform
        self.anno_tf = anno_transform

        # One lazily-evaluated dask array per whole slide.
        self.patch_info = load_sql_df(patch_info_file, patch_size)
        IDs = self.patch_info['ID'].unique()
        self.slides = {slide: da.from_zarr(join(input_dir, '{}.zarr'.format(slide))) for slide in IDs}

        # Decode a patch key into its (ID, x, y, patch_size) components.
        self.id = lambda k: k.split('/')

        # NOTE(review): width and height are deliberately swapped (tagged
        # "experiment" in the original) — presumably to reconcile brambox's
        # (x, y) box convention with the zarr (row, col) axis order used in
        # __getitem__; confirm against the annotation producer.
        self.annos['width'], self.annos['height'] = self.annos['height'], self.annos['width']

        # Map class labels onto contiguous integer class ids.
        if class_label_map is None:
            log.warning('No class_label_map given, generating it by sorting unique class labels from data alphabetically, which is not always deterministic behaviour')
            class_label_map = list(np.sort(self.annos.class_label.unique()))
        self.annos['class_id'] = self.annos.class_label.map({l: i for i, l in enumerate(class_label_map)})

    def __len__(self):
        # One item per unique patch key.
        return len(self.keys)

    @lnd.Dataset.resize_getitem
    def __getitem__(self, index):
        """ Get transformed image and annotations based of the index of ``self.keys``

        Args:
            index (int): index of the ``self.keys`` list containing all the image identifiers of the dataset.

        Returns:
            tuple: (transformed image, list of transformed brambox boxes)

        Raises:
            IndexError: if ``index`` is past the end of the dataset (negative
                indices are passed through to ``self.keys`` unchecked).
        """
        if index >= len(self):
            raise IndexError(f'list index out of range [{index}/{len(self)-1}]')

        # Decode "ID/x/y/patch_size" and slice the patch out of the slide.
        ID, x, y, patch_size = self.id(self.keys[index])
        x, y, patch_size = int(x), int(y), int(patch_size)
        img = self.slides[ID][x:x + patch_size, y:y + patch_size].compute()
        anno = bb.util.select_images(self.annos, [self.keys[index]])

        # Apply the optional image / annotation transforms.
        if self.img_tf is not None:
            img = self.img_tf(img)
        if self.anno_tf is not None:
            anno = self.anno_tf(anno)

        return img, anno