Diff of /pathflowai/datasets.py [000000] .. [e9500f]

"""
datasets.py
=======================
Houses the DynamicImageDataset class, as well as functions to help with image color channel normalization, transformers, etc.
"""

import torch
from torchvision import transforms
import os
import dask
#from dask.distributed import Client; Client()
import dask.array as da, pandas as pd, numpy as np
from pathflowai.utils import *
import pysnooper
import nonechucks as nc
from torch.utils.data import Dataset, DataLoader
import random
import albumentations as alb
import copy
from albumentations import pytorch as albtorch
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils.class_weight import compute_class_weight
from pathflowai.losses import class2one_hot
import cv2
from scipy.ndimage.morphology import generate_binary_structure
from dask_image.ndmorph import binary_dilation

# disable OpenCV threading and OpenCL to avoid contention with DataLoader workers
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)


def RandomRotate90():
    """Transformer for a random 90-degree rotation of an image.

    Returns
    -------
    function
        Transformer function for the operation.

    """
    return (lambda img: img.rotate(random.sample([0, 90, 180, 270], k=1)[0]))

def get_data_transforms(patch_size=None, mean=[], std=[], resize=False, transform_platform='torch', elastic=True, user_transforms=dict()):
    """Get data transformers for the training, test, and validation sets.

    Parameters
    ----------
    patch_size:int
        Original patch size being transformed.
    mean:list of float
        Mean RGB.
    std:list of float
        Std RGB.
    resize:int
        Which patch size to resize to.
    transform_platform:str
        Use pytorch or albumentations transforms.
    elastic:bool
        Whether to add elastic deformations from albumentations.
    user_transforms:dict
        User-supplied transform options that override the defaults.

    Returns
    -------
    dict
        Transformers.

    """
    transform_dict=dict(torch=dict(
                                    colorjitter=lambda kargs: transforms.ColorJitter(**kargs),
                                    hflip=lambda kargs: transforms.RandomHorizontalFlip(),
                                    vflip=lambda kargs: transforms.RandomVerticalFlip(),
                                    r90=lambda kargs: RandomRotate90()
                                    ),
                        albumentations=dict(
                            huesaturation=lambda kargs: alb.augmentations.transforms.HueSaturationValue(**kargs),
                            flip=lambda kargs: alb.augmentations.transforms.Flip(**kargs),
                            transpose=lambda kargs: alb.augmentations.transforms.Transpose(**kargs),
                            affine=lambda kargs: alb.augmentations.transforms.ShiftScaleRotate(**kargs),
                            r90=lambda kargs: alb.augmentations.transforms.RandomRotate90(**kargs),
                            elastic=lambda kargs: alb.augmentations.transforms.ElasticTransform(**kargs)
                        ))
    if 'normalization' in user_transforms:
        mean=user_transforms['normalization'].pop('mean')
        std=user_transforms['normalization'].pop('std')
        del user_transforms['normalization']
    default_transforms=dict() # add normalization custom
    default_transforms['torch']=dict(
                            colorjitter=dict(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.5),
                            hflip=dict(),
                            vflip=dict(),
                            r90=dict())
    default_transforms['albumentations']=dict(
                            huesaturation=dict(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
                            r90=dict(p=0.5),
                            elastic=dict(p=0.5))
    main_transforms = default_transforms[transform_platform] if not user_transforms else user_transforms
    print(main_transforms)
    train_transforms=[transform_dict[transform_platform][k](v) for k,v in main_transforms.items()]
    torch_init=[transforms.ToPILImage(),transforms.Resize((patch_size,patch_size)),transforms.CenterCrop(patch_size)]
    albu_init=[alb.augmentations.transforms.Resize(patch_size, patch_size),
                alb.augmentations.transforms.CenterCrop(patch_size, patch_size)]
    # default mean and standard deviation were estimated from lung adenocarcinoma resection slides
    tensor_norm=[transforms.ToTensor(),transforms.Normalize(mean if mean else [0.7, 0.6, 0.7], std if std else [0.15, 0.15, 0.15])]
    data_transforms = { 'torch': {
        'train': transforms.Compose(torch_init+train_transforms+tensor_norm),
        'val': transforms.Compose(torch_init+tensor_norm),
        'test': transforms.Compose(torch_init+tensor_norm),
        'pass': transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop(patch_size),
            transforms.ToTensor(),
        ])
    },
    'albumentations':{
        'train':alb.core.composition.Compose(albu_init+train_transforms),
        'val':alb.core.composition.Compose(albu_init),
        'test':alb.core.composition.Compose(albu_init),
        'normalize':transforms.Compose([transforms.Normalize(mean if mean else [0.7, 0.6, 0.7], std if std else [0.15, 0.15, 0.15])])
    }}

    return data_transforms[transform_platform]

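# Usage sketch (illustrative only; assumes 224x224 patches and the default
# normalization statistics):
#
#   transformers = get_data_transforms(patch_size=224, mean=[0.7, 0.6, 0.7],
#                                      std=[0.15, 0.15, 0.15],
#                                      transform_platform='torch')
#   train_tensor = transformers['train'](patch)  # patch: HxWx3 uint8 array
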
def create_transforms(mean, std):
    """Create transformers.

    Parameters
    ----------
    mean:list
        See get_data_transforms.
    std:list
        See get_data_transforms.

    Returns
    -------
    dict
        Transformers.

    """
    return get_data_transforms(patch_size=224, mean=mean, std=std, resize=True)



def get_normalizer(normalization_file, dataset_opts):
    """Find the mean and standard deviation of the images, computed in batches.

    Parameters
    ----------
    normalization_file:str
        File in which to store the normalization information.
    dataset_opts:dict
        Dictionary storing information to create the DynamicImageDataset class.

    Returns
    -------
    dict
        Stores RGB mean, stdev.

    """
    if os.path.exists(normalization_file):
        norm_dict = torch.load(normalization_file)
    else:
        norm_dict = {'normalization_file':normalization_file}

    if 'normalization_file' in norm_dict:

        transformers = get_data_transforms(patch_size=224, mean=[], std=[], resize=True, transform_platform='torch')

        dataset_opts['transformers']=transformers

        dataset = DynamicImageDataset(**dataset_opts)#nc.SafeDataset(DynamicImageDataset(**dataset_opts))

        if dataset_opts['classify_annotations']:
            dataset.binarize_annotations()

        dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

        all_mean = torch.tensor([0.,0.,0.],dtype=torch.float)
        all_std = torch.tensor([0.,0.,0.],dtype=torch.float)

        if torch.cuda.is_available():
            all_mean=all_mean.cuda()
            all_std=all_std.cuda()

        with torch.no_grad():
            for i,(X,_) in enumerate(dataloader): # X: batch_size x 3 x 224 x 224
                if torch.cuda.is_available():
                    X=X.cuda()
                all_mean += torch.mean(X, (0,2,3))
                all_std += torch.std(X, (0,2,3))

        # average the per-batch statistics; the last, smaller batch is weighted
        # the same as the full batches, a minor approximation
        N=i+1

        all_mean /= float(N)
        all_std /= float(N)

        all_mean = all_mean.detach().cpu().numpy().tolist()
        all_std = all_std.detach().cpu().numpy().tolist()

        torch.save(dict(mean=all_mean,std=all_std),norm_dict['normalization_file'])

        norm_dict = torch.load(norm_dict['normalization_file'])
    return norm_dict

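# Usage sketch (illustrative only; dataset_opts is hypothetical and mirrors the
# DynamicImageDataset constructor arguments):
#
#   norm = get_normalizer('normalization.pkl', dataset_opts)
#   transformers = get_data_transforms(patch_size=224, mean=norm['mean'],
#                                      std=norm['std'])
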
def segmentation_transform(img, mask, transformer, normalizer, alb_reduction):
    """Run albumentations and return an image and its segmentation mask.

    Parameters
    ----------
    img:array
        Image as array.
    mask:array
        Categorical pixel-by-pixel annotations.
    transformer:
        Albumentations transformation object.
    normalizer:
        Torch normalization transform applied after conversion to a tensor.
    alb_reduction:float
        Value by which to divide image intensities (e.g., 255.) before normalization.

    Returns
    -------
    tuple of tensors
        Image and mask.

    """
    res=transformer(True, image=img, mask=mask) # positional True sets force_apply
    return normalizer(torch.tensor(np.transpose(res['image']/alb_reduction,axes=(2,0,1)),dtype=torch.float)).float(), torch.tensor(res['mask']).long()

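# Usage sketch (illustrative only; the transformer and normalizer come from
# get_data_transforms with transform_platform='albumentations'):
#
#   transformers = get_data_transforms(patch_size=224, transform_platform='albumentations')
#   img_t, mask_t = segmentation_transform(img, mask, transformers['train'],
#                                          transformers['normalize'], 255.)
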
class DilationJitter:
    """Randomly dilate the mask pixels of selected classes, a form of annotation
    jitter applied to segmentation masks during training."""

    def __init__(self, dilation_jitter=dict(), segmentation=True, train_set=False):
        if dilation_jitter and segmentation and train_set:
            self.run_jitter=True
            self.dilation_jitter=dilation_jitter
            self.struct=generate_binary_structure(2,1)
        else:
            self.run_jitter=False

    def __call__(self, mask):
        if self.run_jitter:
            for k in self.dilation_jitter:
                # draw the number of dilation iterations from a normal distribution,
                # with a floor of one iteration
                amount_jitter=int(round(max(np.random.normal(self.dilation_jitter[k]['mean'],
                                                        self.dilation_jitter[k]['std']),1)))
                mask[binary_dilation(mask==k,structure=self.struct,iterations=amount_jitter)]=k

        return mask

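# Usage sketch (illustrative only; dilates mask pixels of class 1 by roughly
# three iterations on average during training):
#
#   jitter = DilationJitter({1: dict(mean=3., std=1.)}, segmentation=True, train_set=True)
#   jittered_mask = jitter(mask)  # mask: integer-labeled dask/numpy array
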
class DynamicImageDataset(Dataset):
    """Generate an image dataset that accesses images and annotations via dask.

    Parameters
    ----------
    dataset_df:dataframe
        Dataframe with each WSI, the set it is in (train/test/val), and the corresponding WSI labels if applicable.
    set:str
        Whether train, test, val or pass (normalization) set.
    patch_info_file:str
        SQL db with positional and annotation information on each slide.
    transformers:dict
        Contains transformers to apply on images.
    input_dir:str
        Directory where images come from.
    target_names:list/str
        Names of initial targets, which may be modified.
    pos_annotation_class:str
        If selected and predicting on WSI, this class is labeled as a positive from the WSI, while the other classes are not.
    other_annotations:list
        Other annotations to consider from patch info db.
    segmentation:bool
        Whether this is a segmentation task.
    patch_size:int
        Patch size.
    fix_names:bool
        Whether to change the names of dataset_df.
    target_segmentation_class:list
        Can now be used for classification as well; matched with the two options below, samples images only from this class. This and the two options below can be specified multiple times.
    target_threshold:list
        Patches are sampled only if the class occurs above this threshold within them.
    oversampling_factor:list
        Oversample the patches at this rate.
    n_segmentation_classes:int
        Number of classes to segment.
    gdl:bool
        Whether generalized dice loss is being used.
    mt_bce:bool
        For multi-target prediction tasks.
    classify_annotations:bool
        For classifying annotations.
    dilation_jitter:dict
        Per-class mean/std options for DilationJitter on segmentation masks.
    modify_patches:bool
        Passed through to modify_patch_info.

    """
    # when building transformers, need a resize patch size to make patches 224 by 224
    #@pysnooper.snoop('init_data.log')
    def __init__(self, dataset_df, set, patch_info_file, transformers, input_dir, target_names, pos_annotation_class, other_annotations=[], segmentation=False, patch_size=224, fix_names=True, target_segmentation_class=-1, target_threshold=0., oversampling_factor=1., n_segmentation_classes=4, gdl=False, mt_bce=False, classify_annotations=False, dilation_jitter=dict(), modify_patches=True):

        reduce_alb=True
        self.patch_size=patch_size
        self.input_dir = input_dir
        self.alb_reduction=255. if reduce_alb else 1.
        self.transformer=transformers[set]
        original_set = copy.deepcopy(set)
        if set=='pass':
            set='train'
        self.targets = target_names
        self.mt_bce=mt_bce
        self.set = set
        self.segmentation = segmentation
        self.alb_normalizer=None
        if 'normalize' in transformers:
            self.alb_normalizer = transformers['normalize']
        if len(self.targets)==1:
            self.targets = self.targets[0]
        if original_set == 'pass':
            self.transform_fn = lambda x,y: (self.transformer(x), torch.tensor(1.,dtype=torch.float))
        else:
            if self.segmentation:
                self.transform_fn = lambda x,y: segmentation_transform(x,y, self.transformer, self.alb_normalizer, self.alb_reduction)
            else:
                if 'p' in dir(self.transformer): # albumentations Compose exposes a `p` attribute; torchvision's does not
                    self.transform_fn = lambda x,y: (self.alb_normalizer(torch.tensor(np.transpose(self.transformer(True, image=x)['image']/self.alb_reduction,axes=(2,0,1)),dtype=torch.float)), torch.from_numpy(y).float())
                else:
                    self.transform_fn = lambda x,y: (self.transformer(x), torch.from_numpy(y).float())
        self.image_set = dataset_df[dataset_df['set']==set]
        if self.segmentation:
            self.targets='target'
            self.image_set[self.targets] = 1.
        if not self.segmentation and fix_names:
            self.image_set.loc[:,'ID'] = self.image_set['ID'].map(fix_name)
        self.slide_info = pd.DataFrame(self.image_set.set_index('ID').loc[:,self.targets])
        if self.mt_bce and not self.segmentation:
            if pos_annotation_class:
                self.targets = [pos_annotation_class]+list(other_annotations)
            else:
                self.targets = None
        print(self.targets)
        IDs = self.slide_info.index.tolist()
        pi_dict=dict(input_info_db=patch_info_file,
                    slide_labels=self.slide_info,
                    pos_annotation_class=pos_annotation_class,
                    patch_size=patch_size,
                    segmentation=self.segmentation,
                    other_annotations=other_annotations,
                    target_segmentation_class=target_segmentation_class,
                    target_threshold=target_threshold,
                    classify_annotations=classify_annotations,
                    modify_patches=modify_patches)
        self.patch_info = modify_patch_info(**pi_dict)

        if self.segmentation and original_set!='pass':
            self.segmentation_maps = {slide:npy2da(join(input_dir,'{}_mask.npy'.format(slide))) for slide in IDs}
        self.slides = {slide:load_preprocessed_img(join(input_dir,'{}.zarr'.format(slide))) for slide in IDs}
        if original_set=='pass':
            self.segmentation=False
        if oversampling_factor > 1:
            self.patch_info = pd.concat([self.patch_info]*int(oversampling_factor),axis=0).reset_index(drop=True)
        elif oversampling_factor < 1:
            self.patch_info = self.patch_info.sample(frac=oversampling_factor).reset_index(drop=True)
        self.length = self.patch_info.shape[0]
        self.n_segmentation_classes = n_segmentation_classes
        self.gdl=gdl if self.segmentation else False
        self.binarized=False
        self.classify_annotations=classify_annotations
        print(self.targets)
        self.dilation_jitter=DilationJitter(dilation_jitter,self.segmentation,(original_set=='train'))
        if not self.targets:
            self.targets = [pos_annotation_class]+list(other_annotations)

    def concat(self, other_dataset):
        """Concatenate this dataset with another. Updates this dataset's internal attributes.

        Parameters
        ----------
        other_dataset:DynamicImageDataset
            Other image dataset.

        """
        self.patch_info = pd.concat([self.patch_info, other_dataset.patch_info],axis=0).reset_index(drop=True)
        self.length = self.patch_info.shape[0]
        self.slides.update(other_dataset.slides) # merge slide handles so the other dataset's patches can be loaded
        if self.segmentation:
            self.segmentation_maps.update(other_dataset.segmentation_maps)

    def retain_ID(self, ID):
        """Reduce the sample set to just images from one ID.

        Parameters
        ----------
        ID:str
            Basename/ID to predict on.

        Returns
        -------
        self

        """
        self.patch_info=self.patch_info.loc[self.patch_info['ID']==ID]
        self.length = self.patch_info.shape[0]
        if self.segmentation: # segmentation maps only exist for segmentation tasks
            self.segmentation_maps={ID:self.segmentation_maps[ID]}
        return self

    def split_by_ID(self):
        """Generator similar to groupby; splits the dataset up by ID and yields (ID, data) pairs using retain_ID.

        Returns
        -------
        generator
            ID, DynamicImageDataset

        """
        for ID in self.patch_info['ID'].unique():
            new_dataset = copy.deepcopy(self)
            yield ID, new_dataset.retain_ID(ID)

    def select_IDs(self, IDs):
        """Like split_by_ID, but only yields datasets for the requested IDs."""
        for ID in IDs:
            if ID in self.patch_info['ID'].unique():
                new_dataset = copy.deepcopy(self)
                yield ID, new_dataset.retain_ID(ID)

    def get_class_weights(self, i=0):
        """Weight the loss function with weights inversely proportional to each class's appearance.

        Parameters
        ----------
        i:int
            If multi-target, class used for weighting.

        Returns
        -------
        array
            Class weights.

        """
        if self.segmentation:
            label_counts=self.patch_info[list(map(str,list(range(self.n_segmentation_classes))))].sum(axis=0).values
            freq = label_counts/sum(label_counts)
            weights=1./freq
        elif self.mt_bce:
            weights=1./(self.patch_info.loc[:,self.targets].sum(axis=0).values)
            weights=weights/sum(weights)
        else:
            if self.binarized and len(self.targets)>1:
                y=np.argmax(self.patch_info.loc[:,self.targets].values,axis=1)
            elif isinstance(self.targets,str):
                y=self.patch_info.loc[:,self.targets]
            else:
                y=self.patch_info.loc[:,self.targets[i]]
            y=y.values.astype(int).flatten()
            weights=compute_class_weight(class_weight='balanced',classes=np.unique(y),y=y)
        return weights

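    # Usage sketch (illustrative only): the returned weights can be fed to a
    # weighted loss, e.g.
    #
    #   weights = train_dataset.get_class_weights()
    #   criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float))
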
    def binarize_annotations(self, binarizer=None, num_targets=1, binary_threshold=0.):
        """Label binarize some annotations, or threshold them if classifying slide annotations.

        Parameters
        ----------
        binarizer:LabelBinarizer
            Binarizes the labels of one or more columns.
        num_targets:int
            Number of desired targets to predict on.
        binary_threshold:float
            Amount of annotation that must be present in a patch before it counts as positive.

        Returns
        -------
        binarizer

        """

        annotations = self.patch_info['annotation']
        annots=[annot for annot in list(self.patch_info.iloc[:,6:]) if annot !='area'] # annotation columns start after the sixth column
        if not self.mt_bce and num_targets > 1:
            if binarizer is None:
                self.binarizer = LabelBinarizer().fit(annotations)
            else:
                self.binarizer = copy.deepcopy(binarizer)
            self.targets = self.binarizer.classes_
            annotation_labels = pd.DataFrame(self.binarizer.transform(annotations),index=self.patch_info.index,columns=self.targets).astype(float)
            for col in list(annotation_labels):
                if col in list(self.patch_info):
                    self.patch_info.loc[:,col]=annotation_labels[col].values
                else:
                    self.patch_info[col]=annotation_labels[col].values
        else:
            self.binarizer=None
            self.targets=annots
            if num_targets == 1:
                self.targets = [self.targets[-1]]
            if binary_threshold>0.:
                self.patch_info.loc[:,self.targets]=(self.patch_info[self.targets]>=binary_threshold).values.astype(np.float32)
            print(self.targets)
        self.binarized=True
        return self.binarizer

    def subsample(self, p):
        """Sample a subset of the dataset.

        Parameters
        ----------
        p:float
            Fraction to subsample.

        """
        np.random.seed(42) # fixed seed for a reproducible subsample
        self.patch_info = self.patch_info.sample(frac=p)
        self.length = self.patch_info.shape[0]

    def update_dataset(self, input_dir, new_db, prediction_basename=[]):
        """Experimental. Only use for segmentation for now."""
        self.input_dir=input_dir
        self.patch_info=load_sql_df(new_db, self.patch_size)
        IDs = self.patch_info['ID'].unique()
        self.slides = {slide:load_preprocessed_img(join(self.input_dir,'{}.zarr'.format(slide))) for slide in IDs}
        if self.segmentation:
            if prediction_basename:
                self.segmentation_maps = {slide:npy2da(join(self.input_dir,'{}_mask.npy'.format(slide))) for slide in IDs if slide in prediction_basename}
            else:
                self.segmentation_maps = {slide:npy2da(join(self.input_dir,'{}_mask.npy'.format(slide))) for slide in IDs}
        self.length = self.patch_info.shape[0]

    #@pysnooper.snoop("getitem.log")
    def __getitem__(self, i):
        patch_info = self.patch_info.iloc[i]
        ID = patch_info['ID']
        xs = patch_info['x']
        ys = patch_info['y']
        patch_size = patch_info['patch_size']
        # a NaN coordinate means the entire image should be used; NaN never
        # compares equal to itself, so test with pd.isnull rather than ==
        if pd.isnull(xs):
            entire_image=True
        else:
            entire_image=False
        use_long=False
        if not self.segmentation:
            y = patch_info.loc[list(self.targets) if not isinstance(self.targets,str) else self.targets]
            if isinstance(y,pd.Series):
                y=y.values.astype(float)
                if self.binarized and not self.mt_bce and len(y)>1:
                    y=np.array(y.argmax())
                    use_long=True
            y=np.array(y)
            if not y.shape:
                y=y.reshape(1)
        if self.segmentation:
            arr=self.segmentation_maps[ID]
            if not entire_image:
                arr=arr[xs:xs+patch_size,ys:ys+patch_size]
            arr=self.dilation_jitter(arr)
        y=(y if not self.segmentation else np.array(arr))
        arr=self.slides[ID]
        if not entire_image:
            arr=arr[xs:xs+patch_size,ys:ys+patch_size,:3]
        image, y = self.transform_fn(arr.compute().astype(np.uint8), y)
        if not self.segmentation and not self.mt_bce and self.classify_annotations and use_long:
            y=y.long()
        if self.gdl:
            y=class2one_hot(y, self.n_segmentation_classes)
        #   y=one_hot2dist(y)
        return image, y

    def __len__(self):
        return self.length

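# Usage sketch for DynamicImageDataset (illustrative only; dataset_df,
# 'patch_info.db', and './inputs' are hypothetical):
#
#   transformers = get_data_transforms(patch_size=224)
#   dataset = DynamicImageDataset(dataset_df, 'train', 'patch_info.db',
#                                 transformers, './inputs', ['target'],
#                                 pos_annotation_class='')
#   dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
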
class NPYDataset(Dataset):
    """Dataset that serves patches directly from a NumPy array of one slide, for embedding extraction."""

    def __init__(self, patch_info, patch_size, npy_file, transform, mmap=False):
        self.ID=os.path.basename(npy_file).split('.')[0]
        patch_info=load_sql_df(patch_info,patch_size)
        self.patch_info=patch_info.loc[patch_info["ID"]==self.ID].reset_index()
        self.X=np.load(npy_file,mmap_mode=(None if not mmap else 'r+'))
        self.transform=transform

    def __getitem__(self,i):
        x,y,patch_size=self.patch_info.loc[i,["x","y","patch_size"]]
        return self.transform(self.X[x:x+patch_size,y:y+patch_size])

    def __len__(self):
        return self.patch_info.shape[0]

    def embed(self,model,batch_size,out_dir):
        Z=[]
        dataloader=DataLoader(self,batch_size=batch_size,shuffle=False)
        n_batches=len(self)//batch_size
        with torch.no_grad():
            for i,X in enumerate(dataloader):
                if torch.cuda.is_available():
                    X=X.cuda()
                z=model(X).detach().cpu().numpy()
                Z.append(z)
                print(f"Processed batch {i}/{n_batches}")
        Z=np.vstack(Z)
        torch.save(dict(embeddings=Z,patch_info=self.patch_info),os.path.join(out_dir,f"{self.ID}.pkl"))
        print("Embeddings saved")
        quit() # exit the process once the embeddings have been written
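
# Usage sketch for NPYDataset.embed (illustrative only; 'slide.npy',
# 'patch_info.db', and the model are hypothetical; note that embed exits the
# process after saving):
#
#   transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
#   dataset = NPYDataset('patch_info.db', 224, 'slide.npy', transform)
#   dataset.embed(model, batch_size=32, out_dir='./embeddings')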