Switch to unified view

a b/Serialized/helper/mydata.py
1
'''
2
Data loaders and manipulation
3
BY: Yuval
4
'''
5
6
import os
7
from skimage.io import imread
8
import numpy as np
9
import pandas as pd
10
import torch
11
from skimage import io, transform
12
from torch.utils.data import Dataset, DataLoader
13
from torchvision import transforms, utils
14
from .myio import load_one_image
15
from itertools import product
16
from scipy.ndimage import zoom
17
from tqdm import tqdm_notebook
18
from multiprocessing import Pool
19
20
def narrow(arr,axis,start,end):
    '''
    A numpy counterpart of torch.narrow: slice a single axis and keep all the others whole.
    Unlike torch.narrow, the range is given as [start,end) and not as (start,length).
    Args:
        arr  : numpy array to slice
        axis : axis to slice along, negative values count from the last axis
        start: first index to keep along axis
        end  : index one past the last element to keep along axis
    Return:
        arr restricted to [start,end) along axis (basic slicing, no copy is made)
    '''
    if axis < 0:
        # a negative axis counts back from the end, so it is ADDED to ndim
        axis += arr.ndim
    index = [slice(None)] * arr.ndim
    index[axis] = slice(start, end, 1)
    return arr[tuple(index)]
25
26
def _shift(arr, shift, axis, fill_value=None):
27
    ''' Shits an image. Fill the empty space
28
        Arges:
29
            arr  : numpy array with the image to shift
30
            shift: shift step
31
            axis : axis to shift
32
            fill_value: a value to fill empty spaces - default - None => Use min value in array
33
        Return:
34
            new shifted array
35
        Update: Yuval 12/10/2019
36
    '''
37
            
38
    if shift == 0:
39
        return arr
40
    if fill_value is None:
41
        fill_value = arr.min()
42
    if axis < 0:
43
        axis += arr.ndim
44
45
    dim_size = arr.shape[axis]
46
    after_start = dim_size - shift
47
    slice_shape=list(arr.shape)
48
    slice_shape[axis]=abs(shift)
49
    if shift < 0:
50
        after_start = -shift
51
        shift = dim_size - abs(shift)
52
        before = np.ones(slice_shape)*fill_value
53
        after = narrow(arr,axis, after_start, shift)
54
    else:
55
        before = narrow(arr,axis, 0, dim_size - shift)
56
        after = np.ones(slice_shape)*fill_value
57
    return np.concatenate([after, before], axis)
58
59
class MyTransform():
    '''
    My implementation for image transformation class
    Args:
     flip        : do a right/left mirroring (with probability 0.5).  default - False
     mean_change : scale of the random additive change.   default - 0
     std_change  : scale of the random multiplicative change (around 1). default - 0
     crop=None   : crop image to size, int (x=y) or tuple (x,y). The crop is taken around
                   the image center.                      default - None, keep input size
     seed        : seed for the numpy and torch global random generators.
                                                          default - None
     zoom=0.0    : Zoom scale, float or tuple (x,y-relative-to-x).
                                                          default - 0 (stay the same)
     rotate=0    : Max rotation angle, Deg                default - 0
     shift=0     : Max shift, int (x=y) or tuple (x,y)    default - 0
     out_size    : Output image size, int (x=y) or tuple (x,y)
                                                          default - None => keep input
     normal      : True - random factors are drawn from a standard normal distribution,
                   False - uniform in [-1,1).             default - True
     anti_aliasing: passed to skimage.transform.resize    default - True
    Methods:
       random: random transform using the init parameters
       Args:
           imgs - numpy array with one image (H,W) or several (C,H,W), the first dim is the channel
       Returns
           numpy array with randomly transformed images of size out_size
    '''

    def __init__(self,
                 flip=False,
                 mean_change=0,
                 std_change=0,
                 crop=None,
                 seed=None,
                 zoom=0.0,
                 rotate=0,
                 shift=0,
                 out_size=None,
                 normal=True,
                 anti_aliasing=True
                ):

        self.do_flip = flip
        self.rotate_angle = rotate
        self.mean_change = mean_change
        self.std_change = std_change
        np.random.seed(seed)
        if seed is not None:
            # most of the random draws below come from torch, so seed it as well
            torch.manual_seed(seed)
        self.zoom_factor = zoom
        self.anti_aliasing = anti_aliasing
        self.normal = normal

        self.cropx, self.cropy = self._as_pair(crop)
        self.shiftx, self.shifty = self._as_pair(shift)
        self.out_sizex, self.out_sizey = self._as_pair(out_size)

    @staticmethod
    def _as_pair(value):
        # Expand a scalar (or None) setting to an (x,y) pair, tuples are taken as (x,y)
        if isinstance(value, tuple):
            return value[0], value[1]
        return value, value

    def randf(self, n):
        # n random factors, normal or uniform in [-1,1) according to the `normal` setting.
        # A regular method (not a lambda attribute) so the object stays picklable.
        if self.normal:
            return torch.randn(n).numpy()
        return 2.0*torch.rand(n).numpy()-1.0

    def random(self,imgs):
        sqz = False
        imgs = imgs.copy()
        if imgs.ndim == 2:
            imgs = np.expand_dims(imgs, axis=0)
            sqz = True
        # input is (C,H,W): axis 1 is y (rows) and axis 2 is x (columns)
        in_y, in_x = imgs.shape[1:3]
        cropx,cropy = (in_x,in_y) if self.cropx is None else (self.cropx,self.cropy)
        out_sizex,out_sizey = (in_x,in_y) if self.out_sizex is None else (self.out_sizex,self.out_sizey)
        # work in (H,W,C), the layout used by skimage
        imgs = imgs.transpose(1,2,0)
        if (self.std_change>0) or (self.mean_change>0):
            mean = self.randf(1)[0]*self.mean_change
            std = 1+self.randf(1)[0]*self.std_change
            imgs = self.change_mean_std(imgs,mean,std)
        if self.do_flip:
            if (torch.randint(low=0,high=2,size=(1,))[0]>0):
                imgs = self.flip(imgs)
        if self.rotate_angle>0:
            angle = int(torch.randint(-self.rotate_angle,self.rotate_angle,(1,))[0])
            imgs = self.rotate(imgs,angle)
        if (self.shiftx>0) or (self.shifty>0):
            # each axis is drawn only when its range is not empty, randint(0,0) is an error
            dx = np.random.randint(-self.shiftx,self.shiftx) if self.shiftx>0 else 0
            dy = np.random.randint(-self.shifty,self.shifty) if self.shifty>0 else 0
            imgs = self.img_shift(imgs,dx,dy)
        if self.zoom_factor!=0:
            if isinstance(self.zoom_factor,tuple):
                factor_x = 1+self.randf(1)[0]*self.zoom_factor[0]
                factor_y = (1+self.randf(1)[0]*self.zoom_factor[1])*factor_x
                # NOTE(review): skimage applies a tuple scale per array axis (rows first),
                # confirm that (x,y) is the intended order here.
                factor = (factor_x,factor_y)
            else:
                factor = 1+self.randf(1)[0]*self.zoom_factor
            imgs = self.zoom(imgs,factor)
        x0 = max(imgs.shape[1]//2-cropx//2,0)
        y0 = max(imgs.shape[0]//2-cropy//2,0)
        imgs = self.crop(imgs,x0,y0,cropx,cropy)
        if (imgs.shape[0]!=out_sizey) or (imgs.shape[1]!=out_sizex):
            imgs = self.resize(imgs,out_sizex, out_sizey)
        imgs = imgs.transpose(2,0,1)
        if sqz:
            imgs = imgs.squeeze(0)
        return imgs

    def flip(self,img,axis=1):
        # Mirror the (H,W,C) image, axis=1 is right/left
        return np.flip(img,axis=axis)

    def img_shift(self,img,x,y):
        # Shift the (H,W,C) image x pixels along the columns and y pixels along the rows
        return _shift(_shift(img,x,1),y,0)

    def crop(self,img,x,y,width,hight):
        # Cut a (hight,width) window whose top-left corner is (x,y) out of an (H,W,C) image.
        # An image smaller than the window is first padded on both sides with its min value.
        if width>img.shape[1]:
            pad = np.ones((img.shape[0],(width-img.shape[1])//2+1,img.shape[-1]))*img.min()
            img = np.concatenate([pad,img,pad],1)
        if hight>img.shape[0]:
            pad = np.ones(((hight-img.shape[0])//2+1,img.shape[1],img.shape[-1]))*img.min()
            img = np.concatenate([pad,img,pad],0)
        # rows are y, columns are x
        return img[y:y+hight,x:x+width,:]

    def change_mean_std(self,img,mean,std):
        # Return img*std+mean. mean/std are scalars, or one value per channel (last axis).
        return np.asarray(img)*np.asarray(std)+np.asarray(mean)

    def resize(self,img,width,hight):
        return transform.resize(img,(hight,width),anti_aliasing=self.anti_aliasing)

    def zoom(self,img,factor):
        # NOTE(review): newer scikit-image releases replace `multichannel` with
        # `channel_axis` - confirm against the installed version.
        return transform.rescale(img,factor,multichannel=True,mode='constant',cval=float(img.min()))

    def rotate(self,img,angle,resize=True):
        # Rotate by angle (Deg), new areas are filled with the image min value
        return transform.rotate(img, angle, resize=resize, center=None, order=1,
                                mode='constant', cval=img.min(), clip=True, preserve_range=False)
204
    
205
206
    
207
    
208
class sampler():
    '''
    sampler class for RSNA 2019. sample the images according to the targets vector

    Args:
        arr:            2D numpy array with the target vectors, one row per sample
        norm_ratio:     float - the sampling ratio for rows whose targets are all zero.
                        The integer part is the number of full passes,
                        the fraction is the share taken from one extra pass.
        sampled_ratios: a numpy vector, len: arr.shape[-1], sampling ratio by target column,
                        used for the rows where that column is > 0.
        unique_col:     numpy vector length arr.shape[0],
                        with values which will be uniqued (no 2 samples of one pass would have
                        the same value in this column)
                        default: None (don't use)
    Methods:
        __call__:
        Args:
            index_arr: index vector, sample only from arr[index_arr]. default: None
        Returns:
            numpy vector with the sampled indexes (positions inside arr[index_arr])
    '''

    def __init__(self,arr,norm_ratio,sampled_ratios,unique_col=None):
        self.arr = arr
        self.norm_ratio = norm_ratio
        self.sampled_ratios = sampled_ratios
        self.unique_col = unique_col

    def do_unique(self,indxes):
        # Keep one index per distinct unique_col value, no-op when unique_col is not set
        if self.unique_col is not None:
            u,ind = np.unique(self.unique_col[indxes],return_index=True)
            return indxes[ind]
        else:
            return indxes

    def _one_pass(self,mask):
        # One shuffled (and optionally uniqued) pass over the rows selected by mask.
        # flatnonzero always gives a 1D vector, also for a single match.
        indxes = np.flatnonzero(mask)
        np.random.shuffle(indxes)
        return self.do_unique(indxes)

    def __call__(self,index_arr=None):
        if index_arr is None:
            index_arr = Ellipsis
        arr = self.arr[index_arr]
        sampled = []

        # rows with no positive target at all
        normal_mask = ~arr.any(axis=1)
        indxes = self._one_pass(normal_mask)
        fraction = self.norm_ratio-np.floor(self.norm_ratio)
        sampled.append(indxes[:int(indxes.shape[0]*fraction)])
        for _ in range(int(self.norm_ratio)):
            sampled.append(self._one_pass(normal_mask))

        for i,s in enumerate(self.sampled_ratios):
            mask = arr[...,i]>0
            s_ = s
            if s_>1:
                # full passes for the integer part of the ratio
                s_ = np.floor(s_)
                for _ in range(int(s_)):
                    sampled.append(self._one_pass(mask))
                s_ = s-s_
            if s_>0:
                # partial pass for the remaining ratio
                indxes = self._one_pass(mask)
                sampled.append(indxes[:int(indxes.shape[0]*s_)])
        return np.concatenate(sampled)
268
269
    
270
class simple_sampler():
    '''
    Simple sampler: shuffles the row indexes of an array (dim 0) and keeps a share of them
    Args:
        arr   :numpy array, only its first dimension length is used
        ratio :float, the share of rows to keep, a value above 1 gives all the rows
    Methods:
        __call__:
            Return numpy vector with int(arr.shape[0]*ratio) shuffled indexes
    '''

    def __init__(self,arr,ratio):
        self.ratio = ratio
        self.arr = arr

    def __call__(self):
        total = self.arr.shape[0]
        order = np.arange(total)
        np.random.shuffle(order)
        keep = int(total*self.ratio)
        return order[:keep]
290
291
class Mixup():
    '''
    Mixup augmentation: blends every sample of a batch with another, randomly picked,
    sample of the same batch.
    Args:
        alpha : float, both parameters of the Beta distribution the mixing weight is drawn from.
                default - 0.4
        device: device for the shuffle and weight tensors, should match the device of the
                images/targets. 'gpu' is accepted as an alias of 'cuda'. default - 'gpu'
    Methods:
        __call__:
        Args:
            images : tensor, batch is the first dim
            targets: tensor, batch is the first dim
        Returns:
            out_images  - the blended images, same shape as images
            out_targets - tensor of shape targets.shape+(3,), the last dim holds:
                          the original target, the target of the mixed-in sample and the
                          weight of the original sample (always in [0.5,1])
    '''
    def __init__(self,alpha=0.4,device='gpu'):
        self.alpha = alpha
        # torch has no 'gpu' device type, translate the alias to the real name
        self.device = 'cuda' if device=='gpu' else device

    def __call__(self,images,targets):
        batch = targets.size(0)
        lambd = np.random.beta(self.alpha, self.alpha, batch)
        # fold around 0.5, the original sample always gets the bigger weight
        lambd = np.abs(lambd-0.5)+0.5
        shuffle = torch.randperm(batch).to(self.device)
        lambd = torch.tensor(lambd,dtype=torch.float).to(self.device)
        # one weight per sample, broadcast over all the other dims
        img_w = lambd.view(-1,*([1]*(images.dim()-1)))
        out_images = img_w*images+(1-img_w)*images[shuffle]
        tgt_w = lambd.view(-1,*([1]*(targets.dim()-1))).expand_as(targets)
        out_targets = torch.cat([targets.unsqueeze(-1),
                                 targets[shuffle].unsqueeze(-1),
                                 tgt_w.unsqueeze(-1)],-1)
        return out_images, out_targets
310
311
class ImageDataset(Dataset):
    '''
        RSNA 2019 Image (DICOM) dataset to use in Pytorch dataloader.
        Base class: Dataset
        Args:
            df              : Data frame with the image ids (column PatientID)
            base_path       : File path for the images
            transform=None  : Transform method, called on the loaded image. default: None - no transform
            out_shape=None  : Expected output shape (tuple/list) - used only for sanity check.
                              A sample of another shape is reported and replaced by low
                              amplitude noise of the expected shape. default: None - no check
            window_eq=False : passed to load_one_image
            equalize=True   : passed to load_one_image
            rescale=False   : passed to load_one_image
        Methods:
            __getitem__:
                return: float tensor with the image, a 2D image gets a leading channel dim
    '''

    def __init__(self, df, base_path, transform=None,out_shape=None,window_eq=False,equalize=True,rescale=False):
        super(ImageDataset, self).__init__()
        self.df = df
        self.pids = df.PatientID.values
        self.transform = transform
        self.base_path = base_path
        self.out_shape = out_shape
        self.window_eq = window_eq
        self.equalize = equalize
        self.rescale = rescale

    def __len__(self):
        return self.pids.shape[0]

    def __getitem__(self, idx):
        sample = load_one_image(self.pids[idx],equalize=self.equalize,base_path=self.base_path,file_path='',
                                window_eq=self.window_eq,rescale=self.rescale)
        if self.transform is not None:
            sample = self.transform(sample)
        sample = torch.tensor(sample,dtype=torch.float)
        if sample.dim()==2:
            sample = sample.unsqueeze(0)
        return self._check_shape(sample,idx)

    def _check_shape(self,sample,idx):
        # Sanity check: return sample when it has the expected shape, noise otherwise
        if self.out_shape is None:
            return sample
        # compare as tuples, a list out_shape never equals a torch.Size
        expected = tuple(self.out_shape)
        if tuple(sample.shape)==expected:
            return sample
        print ("Error in idx {}".format(idx))
        print (sample.shape,sample)
        return torch.randn(expected)*1e-5
356
357
class FeatursDataset(Dataset):
    '''
        RSNA 2019 features dataset to use in Pytorch dataloader.
        Base class: Dataset
        Args:
            df              : Data frame. The row indexes used for features/targets are the
                              df index values, so the index is expected to be 0..len(df)-1.
            features        : pytorch tensor with features.
                              Shape: (df.shape[0],num_of_features)
            num_neighbors   : int, Number of neighbors to return on each side
            ref_column      : string, The name of the column in df with the series id
            order_column    : string, The name of the column in df with the data which will determine the neighbors
            target_columns  : list of strings/None, names of column in df with the target data,
                              default None - no target data will be returned
        Methods:
            __getitem__:
                return: sample - tensor size (1+2*num_neighbors,features.shape[-1]).
                                 The center row is the sample itself. A neighbor which is out
                                 of the series is replaced by the sample itself.
                        if target column is defined, return tuple with the 2nd variable:
                            targets - tensor size (len(target_columns),) of the center sample,
                                      dtype=torch.float
    '''
    def __init__(self, df, features,num_neighbors, ref_column,order_column,target_columns=None):
        super(FeatursDataset, self).__init__()
        self.df = df.sort_values([ref_column,order_column])
        self.num_neighbors = num_neighbors
        self.ref_column = ref_column
        self.target_columns = target_columns
        self.target_tensor = None if target_columns is None else torch.tensor(df[self.target_columns].values,dtype=torch.float)
        self.features = features

        series = self.df[ref_column]
        index_values = self.df.index.values
        # np.int64 and not np.long, which is not available in every numpy release
        ref_arr = np.zeros((self.df.shape[0],1+2*self.num_neighbors),dtype=np.int64)
        for i in range(-self.num_neighbors,self.num_neighbors+1):
            # the row i places away (in the sorted order) when it is in the same series,
            # the row itself otherwise
            same_series = (series==series.shift(i)).values
            ref_arr[:,i+self.num_neighbors] = np.where(same_series,
                                                       np.roll(index_values,i),
                                                       index_values)
        # back to the order of the index values, row k describes the sample with index k
        self.ref_arr = torch.tensor(ref_arr[np.argsort(ref_arr[:,self.num_neighbors])])

    def __len__(self):
        return self.ref_arr.shape[0]

    def __getitem__(self, idx):
        sample = self.features[self.ref_arr[idx]]
        return sample if self.target_tensor is None else (sample, self.target_tensor[idx])
409
410
411
class FeatursDatasetCor(Dataset):
    '''
        Not Used, like FeatursDataset but determine neighbors according to feature distance
        Args:
            df              : Data frame, row k of df matches row k of features
            features        : pytorch tensor with features. Shape: (df.shape[0],num_of_features)
            num_neighbors   : int, each sample returns 1+2*num_neighbors feature vectors
            ref_column      : string, The name of the column in df with the series id
            target_columns  : list of strings/None, names of column in df with the target data,
                              default None - no target data will be returned
        Methods:
            __getitem__:
                return: sample - tensor size (1+2*num_neighbors,features.shape[-1]) with the
                                 rows of the same series, ordered from the highest correlation
                                 with the sample downward. Slots a short series can not
                                 fill hold the sample itself.
                        if target column is defined, return tuple with the 2nd variable:
                            targets - the target row of the sample, dtype=torch.float
    '''

    def __init__(self, df, features,num_neighbors, ref_column,target_columns=None):
        super(FeatursDatasetCor, self).__init__()
        self.num_neighbors = num_neighbors
        self.ref_column = ref_column
        self.target_columns = target_columns
        self.target_tensor = None if target_columns is None else torch.tensor(df[self.target_columns].values,dtype=torch.float)
        self.features = features
        width = 1+2*self.num_neighbors
        # every slot starts as the sample's own row, series shorter than width keep it.
        # np.int64 and not np.long, which is not available in every numpy release
        self.ref_arr = np.repeat(np.arange(df.shape[0],dtype=np.int64)[:,None],width,axis=1)
        unq,si = np.unique(df[self.ref_column].values,return_inverse=True)
        for i in tqdm_notebook(range(unq.shape[0]), leave=False):
            sinx = np.where(si==i)[0]
            # atleast_2d: corrcoef of a single row is a scalar
            r = np.atleast_2d(np.corrcoef(self.features[sinx].numpy()))
            ranked = sinx[np.argsort(-r)]
            k = min(width,ranked.shape[1])
            self.ref_arr[sinx,:k] = ranked[:,:k]

    def __len__(self):
        return self.ref_arr.shape[0]

    def __getitem__(self, idx):
        sample = self.features[self.ref_arr[idx]]
        return sample if self.target_tensor is None else (sample, self.target_tensor[idx])
438
439
class FullHeadImageDataset(Dataset):
    '''
        RSNA 2019  full head dataset to use in Pytorch dataloader.
        Each item holds all the slices of one scan series, sorted by order_column.
        Base class: Dataset
        Args:
            df              : Data frame (column PatientID holds the image ids)
            base_path       : File path for the images
            SeriesIDs       : numpy array with scan series ids, item k is the full series SeriesIDs[k]
            ref_column      : string, The name of the column in df with the series id
            order_column    : string, The name of the column in df which sets the slices order
            transform       : Transform method, called on the loaded images.
                              default: None - no transform
            window_eq       : passed to load_one_image, default - False
            equalize        : passed to load_one_image, default - False
            rescale         : passed to load_one_image, default - True
            target_columns  : list of strings/None, names of column in df with the target data,
                              default None - no target data will be returned
            full_transform  : True  - transform is called once on the stacked series
                              False - transform is called on every slice separately
                              default - True
        Methods:
            __getitem__:
                return: headimages - float tensor, the stacked slices with a singleton
                                     dim inserted at position 1
                        if target column is defined, return tuple with the 2nd variable:
                            targets - tensor size (# of images in series,len(target_columns)), dtype=torch.float
    '''

    def __init__(self, df,
                 base_path,
                 SeriesIDs,
                 ref_column,
                 order_column,
                 transform=None,
                 window_eq=False,
                 equalize=False,
                 rescale=True,
                 target_columns=None,
                 full_transform=True):
        super(FullHeadImageDataset, self).__init__()
        # data frame and the vectors taken out of it
        self.df = df
        self.pids = df.PatientID.values
        self.ref_arr = df[ref_column].values
        self.order_arr = df[order_column].values
        self.ref_column = ref_column
        self.order_column = order_column
        self.target_columns = target_columns
        if target_columns is None:
            self.target_tensor = None
        else:
            self.target_tensor = torch.tensor(df[self.target_columns].values,dtype=torch.float)
        # series to serve
        self.SeriesIDs = SeriesIDs
        # image loading and transformation settings
        self.base_path = base_path
        self.window_eq = window_eq
        self.equalize = equalize
        self.rescale = rescale
        self.transform = transform
        self.full_transform = full_transform

    def __len__(self):
        return self.SeriesIDs.shape[0]

    def __getitem__(self, idx):
        head_idx = np.where(self.ref_arr==self.SeriesIDs[idx])[0]
        sorted_head_idx = head_idx[np.argsort(self.order_arr[head_idx])]
        has_transform = self.transform is not None
        per_slice = has_transform and (not self.full_transform)
        whole_series = has_transform and self.full_transform

        samples = []
        for row in sorted_head_idx:
            one = load_one_image(self.pids[row],equalize=self.equalize,base_path=self.base_path,file_path='',
                                 window_eq=self.window_eq,rescale=self.rescale)
            # leading dim of size 1, the slices are concatenated along it
            one = one[None]
            if per_slice:
                one = self.transform(one)
            samples.append(one)

        headimages = np.concatenate(samples,0)
        if whole_series:
            headimages = self.transform(headimages)
        headimages = torch.tensor(headimages,dtype=torch.float)
        # let's make a batch out of it.
        headimages = headimages[:,None]

        if self.target_tensor is None:
            return headimages
        return (headimages, self.target_tensor[sorted_head_idx])
522
523
class FullHeadDataset(Dataset):
    '''
        RSNA 2019 full head scan features dataset to use in Pytorch dataloader.
        Base class: Dataset
        Args:
            df              : Data frame, row k of df matches row k of features
            SeriesIDs       : numpy array with scan series ids, item k is the full series SeriesIDs[k]
            features        : pytorch tensor with features.
                              Shape:
                                  option1 - (df.shape[0],num_of_features) - normal mode
                                  option2 - (df.shape[0],N,num_of_features) - TTA mode,
                                            will select random feature vector(s) from same row.
            ref_column      : string, The name of the column in df with the series id
            order_column    : string, The name of the column in df which sets the slices order
            target_columns  : list of strings/None, names of column in df with the target data,
                              default None - no target data will be returned
            max_len         : int, number of rows in every returned item. Shorter series are
                              padded, longer series are cut to their first max_len slices.
                              default - 60
            multi           : int, TTA mode only - number of different TTA feature vectors
                              concatenated for every slice. default - 1
        Methods:
            __getitem__:
                return: sample - float tensor size (max_len,features.shape[-1]) in normal mode,
                                 (max_len,features.shape[-1]*multi) in TTA mode.
                                 Rows after the series end are 0.
                        if target column is defined, return tuple with the 2nd variable:
                            targets - tensor size (max_len,len(target_columns)), dtype=torch.float
                                      Rows after the series end are -1.
    '''
    def __init__(self, df, SeriesIDs,features, ref_column,order_column,target_columns=None,max_len=60,multi=1):
        super(FullHeadDataset, self).__init__()
        self.ref_column = ref_column
        self.target_columns = target_columns
        self.target_tensor = None if target_columns is None else torch.tensor(df[self.target_columns].values,dtype=torch.float)
        self.features = features
        self.ref_arr = df[ref_column].values
        self.order_arr = df[order_column].values
        self.max_len = max_len
        self.SeriesIDs = SeriesIDs
        self.multi = multi

    def __len__(self):
        return self.SeriesIDs.shape[0]

    def __getitem__(self, idx):
        tta_mode = self.features.dim()==3
        # only the TTA mode concatenates `multi` feature vectors per slice
        width = self.features.shape[-1]*(self.multi if tta_mode else 1)
        sample = torch.zeros((self.max_len,width),dtype=torch.float)

        head_idx = np.where(self.ref_arr==self.SeriesIDs[idx])[0]
        # slices in scan order, a series longer than max_len is cut to fit the output
        sorted_head_idx = head_idx[np.argsort(self.order_arr[head_idx])][:self.max_len]
        n = sorted_head_idx.shape[0]
        rows = torch.as_tensor(sorted_head_idx,dtype=torch.long)

        if tta_mode:
            if self.multi>1:
                # for every slice, `multi` different TTA variants
                tta_idx = torch.zeros((n,self.multi),dtype=torch.long)
                for i in range(n):
                    tta_idx[i] = torch.randperm(self.features.shape[1],dtype=torch.long)[:self.multi]
                sample[:n] = torch.cat([self.features[rows,tta_idx[:,i]] for i in range(self.multi)],-1)
            else:
                # for every slice, one random TTA variant
                tta_idx = torch.LongTensor(n).random_(0, self.features.shape[1])
                sample[:n] = self.features[rows,tta_idx]
        else:
            sample[:n] = self.features[rows]

        if self.target_tensor is None:
            return sample
        targets = -1*torch.ones((self.max_len,self.target_tensor.shape[-1]),dtype=torch.float)
        targets[:n] = self.target_tensor[rows]
        return (sample, targets)
591
592
593
594
class DatasetCat(Dataset):
    '''
    Concatenate datasets for Pytorch dataloader
    The normal pytorch implementation does it only for rows. this is a "column" implementation:
    item k is one flat tuple with the outputs of item k of every dataset.
    Args:
        datasets: list of datasets of the same length. An entry may also be a tuple of
                  datasets, which is handled as if its members were listed directly.
    Methods:
        __getitem__:
            return: tuple with the outputs of all the datasets, a dataset which returns
                    a tuple contributes each of its members as a separate entry
    '''

    def __init__(self,datasets):
        '''
        Args: datasets - a sequence containing the datasets
        '''
        super(DatasetCat, self).__init__()
        self.datasets = datasets
        assert len(self.datasets)>0
        # unpack tuple entries once, so the length check looks at real datasets
        self._flat = [dataset for entry in datasets
                      for dataset in (entry if isinstance(entry, tuple) else (entry,))]
        assert len(self._flat)>0
        for dataset in self._flat:
            assert len(self._flat[0])==len(dataset),"Datasets length should be equal"

    def __len__(self):
        return len(self._flat[0])

    def __getitem__(self, idx):
        outputs = []
        for dataset in self._flat:
            item = dataset[idx]
            # a tuple output is spread into the result, anything else is one entry
            if isinstance(item, tuple):
                outputs.extend(item)
            else:
                outputs.append(item)
        return tuple(outputs)