# pathflowai/datasets.py
|
|
"""
datasets.py
=======================
Houses the DynamicImageDataset class, along with helper functions for image color channel normalization, data transformers, etc.
"""
|
|
import torch
from torchvision import transforms
import os
import dask
#from dask.distributed import Client; Client()
import dask.array as da, pandas as pd, numpy as np
from pathflowai.utils import *
import pysnooper
import nonechucks as nc
from torch.utils.data import Dataset, DataLoader
import random
import albumentations as alb
import copy
from albumentations import pytorch as albtorch
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils.class_weight import compute_class_weight
from pathflowai.losses import class2one_hot
import cv2
from scipy.ndimage.morphology import generate_binary_structure
from dask_image.ndmorph import binary_dilation
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
|
|
def RandomRotate90():
    """Transformer for a random 90-degree image rotation.

    Returns
    -------
    function
        Transformer function for operation.

    """
    return (lambda img: img.rotate(random.sample([0, 90, 180, 270], k=1)[0]))
|
|
def get_data_transforms(patch_size=None, mean=[], std=[], resize=False, transform_platform='torch', elastic=True, user_transforms=dict()):
    """Get data transformers for the training, test, and validation sets.

    Parameters
    ----------
    patch_size:int
        Original patch size being transformed.
    mean:list of float
        Mean RGB values for normalization.
    std:list of float
        RGB standard deviations for normalization.
    resize:int
        Which patch size to resize to.
    transform_platform:str
        Use pytorch or albumentations transforms.
    elastic:bool
        Whether to add elastic deformations from albumentations.
    user_transforms:dict
        User-specified transform parameters keyed by transform name; overrides the defaults when supplied.

    Returns
    -------
    dict
        Transformers.

    """
|
|
    transform_dict = dict(torch=dict(
        colorjitter=lambda kargs: transforms.ColorJitter(**kargs),
        hflip=lambda kargs: transforms.RandomHorizontalFlip(),
        vflip=lambda kargs: transforms.RandomVerticalFlip(),
        r90=lambda kargs: RandomRotate90()
        ),
        albumentations=dict(
        huesaturation=lambda kargs: alb.augmentations.transforms.HueSaturationValue(**kargs),
        flip=lambda kargs: alb.augmentations.transforms.Flip(**kargs),
        transpose=lambda kargs: alb.augmentations.transforms.Transpose(**kargs),
        affine=lambda kargs: alb.augmentations.transforms.ShiftScaleRotate(**kargs),
        r90=lambda kargs: alb.augmentations.transforms.RandomRotate90(**kargs),
        elastic=lambda kargs: alb.augmentations.transforms.ElasticTransform(**kargs)
        ))
    if 'normalization' in user_transforms:
        mean = user_transforms['normalization'].pop('mean')
        std = user_transforms['normalization'].pop('std')
        del user_transforms['normalization']
    default_transforms = dict()  # add normalization custom
    default_transforms['torch'] = dict(
        colorjitter=dict(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.5),
        hflip=dict(),
        vflip=dict(),
        r90=dict())
    default_transforms['albumentations'] = dict(
        huesaturation=dict(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
        r90=dict(p=0.5),
        elastic=dict(p=0.5))
    main_transforms = default_transforms[transform_platform] if not user_transforms else user_transforms
    print(main_transforms)
    train_transforms = [transform_dict[transform_platform][k](v) for k, v in main_transforms.items()]
    torch_init = [transforms.ToPILImage(), transforms.Resize((patch_size, patch_size)), transforms.CenterCrop(patch_size)]
    albu_init = [alb.augmentations.transforms.Resize(patch_size, patch_size),
                 alb.augmentations.transforms.CenterCrop(patch_size, patch_size)]
    tensor_norm = [transforms.ToTensor(),
                   transforms.Normalize(mean if mean else [0.7, 0.6, 0.7], std if std else [0.15, 0.15, 0.15])]  # mean and standard deviations for lung adenocarcinoma resection slides
|
|
    data_transforms = {'torch': {
        'train': transforms.Compose(torch_init + train_transforms + tensor_norm),
        'val': transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((patch_size, patch_size)),
            transforms.CenterCrop(patch_size),
            transforms.ToTensor(),
            transforms.Normalize(mean if mean else [0.7, 0.6, 0.7], std if std else [0.15, 0.15, 0.15])
        ]),
        'test': transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((patch_size, patch_size)),
            transforms.CenterCrop(patch_size),
            transforms.ToTensor(),
            transforms.Normalize(mean if mean else [0.7, 0.6, 0.7], std if std else [0.15, 0.15, 0.15])
        ]),
        'pass': transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop(patch_size),
            transforms.ToTensor(),
        ])
    },
        'albumentations': {
        'train': alb.core.composition.Compose(albu_init + train_transforms),
        'val': alb.core.composition.Compose([
            alb.augmentations.transforms.Resize(patch_size, patch_size),
            alb.augmentations.transforms.CenterCrop(patch_size, patch_size)
        ]),
        'test': alb.core.composition.Compose([
            alb.augmentations.transforms.Resize(patch_size, patch_size),
            alb.augmentations.transforms.CenterCrop(patch_size, patch_size)
        ]),
        'normalize': transforms.Compose([transforms.Normalize(mean if mean else [0.7, 0.6, 0.7], std if std else [0.15, 0.15, 0.15])])
    }}

    return data_transforms[transform_platform]
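
# Hedged usage sketch (not part of the original module): illustrates how the dict
# returned by get_data_transforms might be consumed. The mean/std values and the
# platform choice below are hypothetical placeholders; the function is never called
# by the library itself.
def _example_build_transformers():
    transformers = get_data_transforms(patch_size=224, mean=[0.7, 0.6, 0.7],
                                       std=[0.15, 0.15, 0.15],
                                       transform_platform='torch')
    # transformers['train'] / ['val'] / ['test'] are torchvision Compose objects that
    # accept an HxWx3 uint8 array and return a normalized 3x224x224 float tensor.
    return transformers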
|
|
def create_transforms(mean, std):
    """Create transformers.

    Parameters
    ----------
    mean:list
        See get_data_transforms.
    std:list
        See get_data_transforms.

    Returns
    -------
    dict
        Transformers.

    """
    return get_data_transforms(patch_size=224, mean=mean, std=std, resize=True)
|
|
def get_normalizer(normalization_file, dataset_opts):
    """Find the mean and standard deviation of the images, computed in batches.

    Parameters
    ----------
    normalization_file:str
        File in which to store normalization information.
    dataset_opts:dict
        Dictionary storing information used to create the DynamicImageDataset.

    Returns
    -------
    dict
        Stores RGB mean, stdev.

    """
|
|
    if os.path.exists(normalization_file):
        norm_dict = torch.load(normalization_file)
    else:
        norm_dict = {'normalization_file': normalization_file}

    if 'normalization_file' in norm_dict:

        transformers = get_data_transforms(patch_size=224, mean=[], std=[], resize=True, transform_platform='torch')

        dataset_opts['transformers'] = transformers
        #print(dict(pos_annotation_class=pos_annotation_class, segmentation=segmentation, patch_size=patch_size, fix_names=fix_names, other_annotations=other_annotations))

        dataset = DynamicImageDataset(**dataset_opts)  # nc.SafeDataset(DynamicImageDataset(**dataset_opts))

        if dataset_opts['classify_annotations']:
            dataset.binarize_annotations()

        dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

        all_mean = torch.tensor([0., 0., 0.], dtype=torch.float)
        all_std = torch.tensor([0., 0., 0.], dtype=torch.float)

        if torch.cuda.is_available():
            all_mean = all_mean.cuda()
            all_std = all_std.cuda()

        with torch.no_grad():
            for i, (X, _) in enumerate(dataloader):  # x,3,224,224
                if torch.cuda.is_available():
                    X = X.cuda()
                all_mean += torch.mean(X, (0, 2, 3))
                all_std += torch.std(X, (0, 2, 3))

        N = i + 1

        all_mean /= float(N)  # (np.array(all_mean).mean(axis=0)).tolist()
        all_std /= float(N)  # (np.array(all_std).mean(axis=0)).tolist()

        all_mean = all_mean.detach().cpu().numpy().tolist()
        all_std = all_std.detach().cpu().numpy().tolist()

        torch.save(dict(mean=all_mean, std=all_std), norm_dict['normalization_file'])

        norm_dict = torch.load(norm_dict['normalization_file'])
    return norm_dict
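
# Hedged usage sketch (not part of the original module): ties get_normalizer to
# create_transforms. 'normalization.pkl' and the contents of dataset_opts are
# hypothetical placeholders for whatever the training pipeline would normally supply.
def _example_normalize_then_transform(dataset_opts):
    norm_dict = get_normalizer('normalization.pkl', dataset_opts)
    # norm_dict carries the batch-averaged RGB mean/std, which can seed the transformers
    return create_transforms(mean=norm_dict['mean'], std=norm_dict['std'])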
|
|
def segmentation_transform(img, mask, transformer, normalizer, alb_reduction):
    """Run albumentations and return an image and its segmentation mask.

    Parameters
    ----------
    img:array
        Image as array.
    mask:array
        Categorical pixel-by-pixel annotations.
    transformer:
        Albumentations transformation object.
    normalizer:
        Torchvision normalization transform applied to the output image tensor.
    alb_reduction:float
        Value the image is divided by (e.g. 255.) before normalization.

    Returns
    -------
    tuple of arrays
        Image and mask array.

    """
    res = transformer(True, image=img, mask=mask)
    #res_mask_shape = res['mask'].size()
    return normalizer(torch.tensor(np.transpose(res['image'] / alb_reduction, axes=(2, 0, 1)), dtype=torch.float)).float(), torch.tensor(res['mask']).long()  # .view(res_mask_shape[0],res_mask_shape[1],res_mask_shape[2])
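
# Hedged usage sketch (not part of the original module): shows how segmentation_transform
# might be wired to the albumentations transformers built above. `img` and `mask` stand in
# for hypothetical uint8 HxWx3 and HxW arrays; the function is never called by the library.
def _example_segmentation_transform(img, mask):
    transformers = get_data_transforms(patch_size=224, transform_platform='albumentations')
    return segmentation_transform(img, mask, transformers['train'],
                                  transformers['normalize'], alb_reduction=255.)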
|
|
class DilationJitter:
    """Randomly dilate the regions of selected classes in a segmentation mask (training-set augmentation)."""

    def __init__(self, dilation_jitter=dict(), segmentation=True, train_set=False):
        if dilation_jitter and segmentation and train_set:
            self.run_jitter = True
            self.dilation_jitter = dilation_jitter
            self.struct = generate_binary_structure(2, 1)  # structure=self.struct,
        else:
            self.run_jitter = False

    def __call__(self, mask):
        if self.run_jitter:
            for k in self.dilation_jitter:
                # number of dilation iterations drawn from a normal distribution, floored at 1
                amount_jitter = int(round(max(np.random.normal(self.dilation_jitter[k]['mean'],
                                                               self.dilation_jitter[k]['std']), 1)))
                #print((mask==k).compute())
                mask[binary_dilation(mask == k, structure=self.struct, iterations=amount_jitter)] = k
        return mask
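
# Hedged usage sketch (not part of the original module): a hypothetical dilation_jitter
# configuration. Keys are segmentation class indices; each class is dilated by a number of
# iterations drawn from N(mean, std) when the training set is augmented. `mask` is assumed
# to be a dask array segmentation mask, as produced elsewhere in this module via npy2da.
def _example_dilation_jitter(mask):
    jitter = DilationJitter(dilation_jitter={1: dict(mean=3, std=1)},
                            segmentation=True, train_set=True)
    return jitter(mask)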
|
|
class DynamicImageDataset(Dataset):
    """Generate image dataset that accesses images and annotations via dask.

    Parameters
    ----------
    dataset_df:dataframe
        Dataframe with WSI, which set it is in (train/test/val) and corresponding WSI labels if applicable.
    set:str
        Whether train, test, val or pass (normalization) set.
    patch_info_file:str
        SQL db with positional and annotation information on each slide.
    transformers:dict
        Contains transformers to apply on images.
    input_dir:str
        Directory where images come from.
    target_names:list/str
        Names of initial targets, which may be modified.
    pos_annotation_class:str
        If selected and predicting on WSI, this class is labeled as a positive from the WSI, while the other classes are not.
    other_annotations:list
        Other annotations to consider from patch info db.
    segmentation:bool
        Whether to conduct a segmentation task.
    patch_size:int
        Patch size.
    fix_names:bool
        Whether to change the names of dataset_df.
    target_segmentation_class:list
        Can also be used for classification; together with the two options below, samples images only from this class. This and the two options below can be specified multiple times.
    target_threshold:list
        Sampled only if above this threshold of occurrence in the patches.
    oversampling_factor:list
        Oversample the selected patches by this factor.
    n_segmentation_classes:int
        Number of classes to segment.
    gdl:bool
        Whether the generalized dice loss is being used.
    mt_bce:bool
        For multi-target prediction tasks.
    classify_annotations:bool
        For classifying annotations.
    dilation_jitter:dict
        Parameters for DilationJitter, which randomly dilates segmentation classes in training masks.
    modify_patches:bool
        Passed through to modify_patch_info.

    """
|
|
    # when building transformers, need a resize patch size to make patches 224 by 224
    #@pysnooper.snoop('init_data.log')
    def __init__(self, dataset_df, set, patch_info_file, transformers, input_dir, target_names, pos_annotation_class, other_annotations=[], segmentation=False, patch_size=224, fix_names=True, target_segmentation_class=-1, target_threshold=0., oversampling_factor=1., n_segmentation_classes=4, gdl=False, mt_bce=False, classify_annotations=False, dilation_jitter=dict(), modify_patches=True):

        #print('check',classify_annotations)
        reduce_alb = True
        self.patch_size = patch_size
        self.input_dir = input_dir
        self.alb_reduction = 255. if reduce_alb else 1.
        self.transformer = transformers[set]
        original_set = copy.deepcopy(set)
        if set == 'pass':
            set = 'train'
        self.targets = target_names
        self.mt_bce = mt_bce
        self.set = set
        self.segmentation = segmentation
        self.alb_normalizer = None
        if 'normalize' in transformers:
            self.alb_normalizer = transformers['normalize']
        if len(self.targets) == 1:
            self.targets = self.targets[0]
        if original_set == 'pass':
            self.transform_fn = lambda x, y: (self.transformer(x), torch.tensor(1., dtype=torch.float))
        else:
            if self.segmentation:
                self.transform_fn = lambda x, y: segmentation_transform(x, y, self.transformer, self.alb_normalizer, self.alb_reduction)
            else:
                if 'p' in dir(self.transformer):
                    self.transform_fn = lambda x, y: (self.alb_normalizer(torch.tensor(np.transpose(self.transformer(True, image=x)['image'] / self.alb_reduction, axes=(2, 0, 1)), dtype=torch.float)), torch.from_numpy(y).float())
                else:
                    self.transform_fn = lambda x, y: (self.transformer(x), torch.from_numpy(y).float())
        self.image_set = dataset_df[dataset_df['set'] == set]
        if self.segmentation:
            self.targets = 'target'
            self.image_set[self.targets] = 1.
        if not self.segmentation and fix_names:
            self.image_set.loc[:, 'ID'] = self.image_set['ID'].map(fix_name)
        self.slide_info = pd.DataFrame(self.image_set.set_index('ID').loc[:, self.targets])
        if self.mt_bce and not self.segmentation:
            if pos_annotation_class:
                self.targets = [pos_annotation_class] + list(other_annotations)
            else:
                self.targets = None
        print(self.targets)
|
|
        IDs = self.slide_info.index.tolist()
        pi_dict = dict(input_info_db=patch_info_file,
                       slide_labels=self.slide_info,
                       pos_annotation_class=pos_annotation_class,
                       patch_size=patch_size,
                       segmentation=self.segmentation,
                       other_annotations=other_annotations,
                       target_segmentation_class=target_segmentation_class,
                       target_threshold=target_threshold,
                       classify_annotations=classify_annotations,
                       modify_patches=modify_patches)
        self.patch_info = modify_patch_info(**pi_dict)

        if self.segmentation and original_set != 'pass':
            #IDs = self.patch_info['ID'].unique()
            self.segmentation_maps = {slide: npy2da(join(input_dir, '{}_mask.npy'.format(slide))) for slide in IDs}
        self.slides = {slide: load_preprocessed_img(join(input_dir, '{}.zarr'.format(slide))) for slide in IDs}
        #print(self.slide_info)
        if original_set == 'pass':
            self.segmentation = False
        #print(self.patch_info[self.targets].unique())
        if oversampling_factor > 1:
            self.patch_info = pd.concat([self.patch_info] * int(oversampling_factor), axis=0).reset_index(drop=True)
        elif oversampling_factor < 1:
            self.patch_info = self.patch_info.sample(frac=oversampling_factor).reset_index(drop=True)
        self.length = self.patch_info.shape[0]
        self.n_segmentation_classes = n_segmentation_classes
        self.gdl = gdl if self.segmentation else False
        self.binarized = False
        self.classify_annotations = classify_annotations
        print(self.targets)
        self.dilation_jitter = DilationJitter(dilation_jitter, self.segmentation, (original_set == 'train'))
        if not self.targets:
            self.targets = [pos_annotation_class] + list(other_annotations)
|
|
    def concat(self, other_dataset):
        """Concatenate this dataset with others. Updates its own internal attributes.

        Parameters
        ----------
        other_dataset:DynamicImageDataset
            Other image dataset.

        """
        self.patch_info = pd.concat([self.patch_info, other_dataset.patch_info], axis=0).reset_index(drop=True)
        self.length = self.patch_info.shape[0]
        if self.segmentation:
            self.segmentation_maps.update(other_dataset.segmentation_maps)
            #print(self.segmentation_maps.keys())
|
|
    def retain_ID(self, ID):
        """Reduce the sample set to just images from one ID.

        Parameters
        ----------
        ID:str
            Basename/ID to predict on.

        Returns
        -------
        self

        """
        self.patch_info = self.patch_info.loc[self.patch_info['ID'] == ID]
        self.length = self.patch_info.shape[0]
        self.segmentation_maps = {ID: self.segmentation_maps[ID]}
        return self

    def split_by_ID(self):
        """Generator similar to groupby: splits the dataset by ID and yields (ID, dataset) pairs using retain_ID.

        Returns
        -------
        generator
            ID, DynamicImageDataset

        """
        for ID in self.patch_info['ID'].unique():
            new_dataset = copy.deepcopy(self)
            yield ID, new_dataset.retain_ID(ID)
|
|
    def select_IDs(self, IDs):
        """Like split_by_ID, but only yields (ID, dataset) pairs for the requested IDs."""
        for ID in IDs:
            if ID in self.patch_info['ID'].unique():
                new_dataset = copy.deepcopy(self)
                yield ID, new_dataset.retain_ID(ID)
|
|
|
    def get_class_weights(self, i=0):  # [0,1]
        """Weight the loss function with weights inversely proportional to each class's appearance.

        Parameters
        ----------
        i:int
            If multi-target, class used for weighting.

        Returns
        -------
        array
            Class weights.

        """
        if self.segmentation:
            label_counts = self.patch_info[list(map(str, list(range(self.n_segmentation_classes))))].sum(axis=0).values
            freq = label_counts / sum(label_counts)
            weights = 1. / (freq)
        elif self.mt_bce:
            weights = 1. / (self.patch_info.loc[:, self.targets].sum(axis=0).values)
            weights = weights / sum(weights)
        else:
            if self.binarized and len(self.targets) > 1:
                y = np.argmax(self.patch_info.loc[:, self.targets].values, axis=1)
            elif (type(self.targets) == type('')):
                y = self.patch_info.loc[:, self.targets]
            else:
                y = self.patch_info.loc[:, self.targets[i]]
            y = y.values.astype(int).flatten()
            weights = compute_class_weight(class_weight='balanced', classes=np.unique(y), y=y)
        return weights
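
    # Hedged usage sketch (not part of the original module): the weights returned above
    # could be handed to a PyTorch loss, e.g.
    #   weights = train_dataset.get_class_weights()
    #   loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float))
    # `train_dataset` is a hypothetical DynamicImageDataset instance.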
|
|
    def binarize_annotations(self, binarizer=None, num_targets=1, binary_threshold=0.):
        """Label binarize some annotations or threshold them if classifying slide annotations.

        Parameters
        ----------
        binarizer:LabelBinarizer
            Binarizes the labels of a column(s).
        num_targets:int
            Number of desired targets to predict on.
        binary_threshold:float
            Amount of annotation in a patch before it counts as a positive annotation.

        Returns
        -------
        binarizer

        """

        annotations = self.patch_info['annotation']
        annots = [annot for annot in list(self.patch_info.iloc[:, 6:]) if annot != 'area']
        if not self.mt_bce and num_targets > 1:
            if binarizer is None:
                self.binarizer = LabelBinarizer().fit(annotations)
            else:
                self.binarizer = copy.deepcopy(binarizer)
            self.targets = self.binarizer.classes_
            annotation_labels = pd.DataFrame(self.binarizer.transform(annotations), index=self.patch_info.index, columns=self.targets).astype(float)
            for col in list(annotation_labels):
                if col in list(self.patch_info):
                    self.patch_info.loc[:, col] = annotation_labels[col].values
                else:
                    self.patch_info[col] = annotation_labels[col].values
        else:
            self.binarizer = None
            self.targets = annots
            if num_targets == 1:
                self.targets = [self.targets[-1]]
            if binary_threshold > 0.:
                self.patch_info.loc[:, self.targets] = (self.patch_info[self.targets] >= binary_threshold).values.astype(np.float32)
        print(self.targets)
        #self.patch_info = pd.concat([self.patch_info,annotation_labels],axis=1)
        self.binarized = True
        return self.binarizer
|
|
    def subsample(self, p):
        """Sample a subset of the dataset.

        Parameters
        ----------
        p:float
            Fraction to subsample.

        """
        np.random.seed(42)
        self.patch_info = self.patch_info.sample(frac=p)
        self.length = self.patch_info.shape[0]
|
|
|
    def update_dataset(self, input_dir, new_db, prediction_basename=[]):
        """Experimental. Only use for segmentation for now."""
        self.input_dir = input_dir
        self.patch_info = load_sql_df(new_db, self.patch_size)
        IDs = self.patch_info['ID'].unique()
        self.slides = {slide: load_preprocessed_img(join(self.input_dir, '{}.zarr'.format(slide))) for slide in IDs}
        if self.segmentation:
            if prediction_basename:
                self.segmentation_maps = {slide: npy2da(join(self.input_dir, '{}_mask.npy'.format(slide))) for slide in IDs if slide in prediction_basename}
            else:
                self.segmentation_maps = {slide: npy2da(join(self.input_dir, '{}_mask.npy'.format(slide))) for slide in IDs}
        self.length = self.patch_info.shape[0]
|
|
    #@pysnooper.snoop("getitem.log")
    def __getitem__(self, i):
        patch_info = self.patch_info.iloc[i]
        ID = patch_info['ID']
        xs = patch_info['x']
        ys = patch_info['y']
        patch_size = patch_info['patch_size']
        if pd.isna(xs):
            entire_image = True
        else:
            entire_image = False
        targets = self.targets
        use_long = False
        if not self.segmentation:
            y = patch_info.loc[list(self.targets) if not isinstance(self.targets, str) else self.targets]
            if isinstance(y, pd.Series):
                y = y.values.astype(float)
                if self.binarized and not self.mt_bce and len(y) > 1:
                    y = np.array(y.argmax())
                    use_long = True
            y = np.array(y)
            if not y.shape:
                y = y.reshape(1)
        if self.segmentation:
            arr = self.segmentation_maps[ID]
            if not entire_image:
                arr = arr[xs:xs + patch_size, ys:ys + patch_size]
            arr = self.dilation_jitter(arr)
        y = (y if not self.segmentation else np.array(arr))
        #print(y)
        arr = self.slides[ID]
        if not entire_image:
            arr = arr[xs:xs + patch_size, ys:ys + patch_size, :3]
        image, y = self.transform_fn(arr.compute().astype(np.uint8), y)  # .unsqueeze(0) # transpose .transpose([1,0,2])
        if not self.segmentation and not self.mt_bce and self.classify_annotations and use_long:
            y = y.long()
        #image_size=image.size()
        if self.gdl:
            y = class2one_hot(y, self.n_segmentation_classes)
            # y=one_hot2dist(y)
        return image, y
|
|
    def __len__(self):
        return self.length
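
# Hedged usage sketch (not part of the original module): illustrates how DynamicImageDataset
# is typically assembled with the pieces above. The file names, dataframe, and target names
# are hypothetical placeholders; the function is never called by the library itself.
def _example_dynamic_dataset(dataset_df):
    transformers = get_data_transforms(patch_size=224, transform_platform='torch')
    train_dataset = DynamicImageDataset(dataset_df, 'train', 'patch_info.db', transformers,
                                        './inputs', ['target'], pos_annotation_class='',
                                        patch_size=224, segmentation=False)
    # class-balanced loss weights and a standard PyTorch loader
    weights = train_dataset.get_class_weights()
    loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    return weights, loader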
|
|
class NPYDataset(Dataset):
    """Dataset that serves patches sliced directly from a preprocessed NumPy slide array, e.g. for embedding."""

    def __init__(self, patch_info, patch_size, npy_file, transform, mmap=False):
        self.ID = os.path.basename(npy_file).split('.')[0]
        patch_info = load_sql_df(patch_info, patch_size)
        self.patch_info = patch_info.loc[patch_info["ID"] == self.ID].reset_index()
        self.X = np.load(npy_file, mmap_mode=(None if not mmap else 'r+'))
        self.transform = transform

    def __getitem__(self, i):
        x, y, patch_size = self.patch_info.loc[i, ["x", "y", "patch_size"]]
        return self.transform(self.X[x:x + patch_size, y:y + patch_size])

    def __len__(self):
        return self.patch_info.shape[0]

    def embed(self, model, batch_size, out_dir):
        Z = []
        dataloader = DataLoader(self, batch_size=batch_size, shuffle=False)
        n_batches = len(self) // batch_size
        with torch.no_grad():
            for i, X in enumerate(dataloader):
                if torch.cuda.is_available():
                    X = X.cuda()
                z = model(X).detach().cpu().numpy()
                Z.append(z)
                print(f"Processed batch {i}/{n_batches}")
        Z = np.vstack(Z)
        torch.save(dict(embeddings=Z, patch_info=self.patch_info), os.path.join(out_dir, f"{self.ID}.pkl"))
        print("Embeddings saved")
        quit()
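
# Hedged usage sketch (not part of the original module): embedding every patch of one
# preprocessed slide with a pretrained encoder. The npy path, patch-info database, and
# model are hypothetical placeholders; the function is never called by the library itself.
def _example_embed_slide(model):
    transform = create_transforms(mean=[0.7, 0.6, 0.7], std=[0.15, 0.15, 0.15])['test']
    dataset = NPYDataset('patch_info.db', 224, 'slide_001.npy', transform, mmap=True)
    dataset.embed(model, batch_size=64, out_dir='./embeddings')  # writes ./embeddings/slide_001.pkl, then exits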