Diff of /dataset_bone.py [000000] .. [dff9e0]

import os, torch
import json  # needed by load_and_remap_classes_based_on_frequency()
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import cv2
import random
import torchio as tio
import slicerio
import nrrd
import monai
import pickle
import nibabel as nib
from scipy.ndimage import zoom
from monai.transforms import OneOf
import einops
from funcs import *
from torchvision.transforms import InterpolationMode
#from .utils.transforms import ResizeLongestSide


class MRI_dataset(Dataset):
    def __init__(self, args, img_folder, mask_folder, img_list, phase='train', sample_num=50, channel_num=1, crop=False, crop_size=1024, targets=['femur','hip'], part_list=['all'], cls=1, if_prompt=True, prompt_type='point', region_type='largest_15', prompt_num=15, delete_empty_masks=False, if_attention_map=None):
        super(MRI_dataset, self).__init__()
        self.img_folder = img_folder
        self.mask_folder = mask_folder
        self.crop = crop
        self.crop_size = crop_size
        self.phase = phase
        self.channel_num = channel_num
        self.targets = targets
        self.segment_names_to_labels = []
        self.args = args
        self.cls = cls
        self.if_prompt = if_prompt
        self.region_type = region_type
        self.prompt_type = prompt_type
        self.prompt_num = prompt_num
        self.if_attention_map = if_attention_map

        for i, tag in enumerate(targets):
            self.segment_names_to_labels.append((tag, i))

        with open(img_list, 'r') as namefiles:
            self.data_list = namefiles.read().split('\n')[:-1]

        if delete_empty_masks == 'delete' or delete_empty_masks == 'subsample':
            keep_idx = []
            for idx, data in enumerate(self.data_list):
                mask_path = data.split(' ')[1]
                if os.path.exists(os.path.join(self.mask_folder, mask_path)):
                    msk = Image.open(os.path.join(self.mask_folder, mask_path)).convert('L')
                else:
                    msk = Image.open(os.path.join(self.mask_folder.replace('2D-slices', '2D-slices-generated'), mask_path)).convert('L')
                if 'all' in self.targets:  # combine all targets into a single foreground class
                    mask_cls = np.array(np.array(msk, dtype=int) > 0, dtype=int)
                else:
                    # compare on the array, not the PIL image; comparing the image
                    # object to an int is always False and every slice looks empty
                    mask_cls = np.array(np.array(msk, dtype=int) == self.cls, dtype=int)
                if part_list[0] == 'all' and np.sum(mask_cls) > 0:
                    keep_idx.append(idx)
                elif np.sum(mask_cls) > 0:
                    if_keep = False
                    for part in part_list:
                        if mask_path.find(part) >= 0:
                            if_keep = True
                    if if_keep:
                        keep_idx.append(idx)
            print('num with non-empty masks', len(keep_idx), 'num with all masks', len(self.data_list))
            if delete_empty_masks == 'subsample':
                empty_idx = list(set(range(len(self.data_list))) - set(keep_idx))
                # keep a random 10% of the empty slices alongside the non-empty ones
                keep_empty_idx = random.sample(empty_idx, int(len(empty_idx) * 0.1))
                keep_idx = keep_empty_idx + keep_idx
            self.data_list = [self.data_list[i] for i in keep_idx]  # keep only the selected slices

        if phase == 'train':
            self.aug_img = [transforms.RandomEqualize(p=0.1),
                            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
                            transforms.RandomAdjustSharpness(0.5, p=0.5),
                            ]
            self.transform_spatial = transforms.Compose([transforms.RandomResizedCrop(crop_size, scale=(0.8, 1.2)),
                                                         transforms.RandomRotation(45)])
            transform_img = [transforms.ToTensor()]
        else:
            transform_img = [transforms.ToTensor()]
        self.transform_img = transforms.Compose(transform_img)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        # load the image and its mask
        data = self.data_list[index]
        img_path = data.split(' ')[0]
        mask_path = data.split(' ')[1]
        slice_num = data.split(' ')[3]  # total slice count for this volume
        try:
            if os.path.exists(os.path.join(self.img_folder, img_path)):
                img = Image.open(os.path.join(self.img_folder, img_path)).convert('RGB')
            else:
                img = Image.open(os.path.join(self.img_folder.replace('2D-slices', '2D-slices-generated'), img_path)).convert('RGB')
        except Exception:
            # fall back to loading the image as a numpy file
            img_arr = np.load(os.path.join(self.img_folder, img_path))
            img_arr = np.array((img_arr - img_arr.min()) / (img_arr.max() - img_arr.min() + 1e-8) * 255, dtype=np.uint8)
            img_3c = np.tile(img_arr[:, :, None], [1, 1, 3])
            img = Image.fromarray(img_3c, 'RGB')
        if os.path.exists(os.path.join(self.mask_folder, mask_path)):
            msk = Image.open(os.path.join(self.mask_folder, mask_path)).convert('L')
        else:
            msk = Image.open(os.path.join(self.mask_folder.replace('2D-slices', '2D-slices-generated'), mask_path)).convert('L')

        if self.if_attention_map:
            slice_id = int(img_path.split('-')[-1].split('.')[0])
            slice_fraction = int(slice_id / int(slice_num) * 4)  # which quarter of the volume this slice falls in
            img_id = '/'.join(img_path.split('-')[:-1]) + '_' + str(slice_fraction) + '.npy'
            attention_map = torch.tensor(np.load(os.path.join(self.if_attention_map, img_id)))
        else:
            attention_map = torch.zeros((64, 64))

        img = transforms.Resize((self.args.image_size, self.args.image_size))(img)
        msk = transforms.Resize((self.args.image_size, self.args.image_size), interpolation=InterpolationMode.NEAREST)(msk)

        state = torch.get_rng_state()
        if self.crop:
            im_w, im_h = img.size
            diff_w = max(0, self.crop_size - im_w)
            diff_h = max(0, self.crop_size - im_h)
            padding = (diff_w // 2, diff_h // 2, diff_w - diff_w // 2, diff_h - diff_h // 2)
            img = transforms.functional.pad(img, padding, 0, 'constant')
            torch.set_rng_state(state)
            t, l, h, w = transforms.RandomCrop.get_params(img, (self.crop_size, self.crop_size))
            img = transforms.functional.crop(img, t, l, h, w)
            msk = transforms.functional.pad(msk, padding, 0, 'constant')
            msk = transforms.functional.crop(msk, t, l, h, w)
        if self.phase == 'train':
            # apply one randomly chosen photometric augmentation
            aug_img_func = transforms.RandomChoice(self.aug_img)
            img = aug_img_func(img)

        img = self.transform_img(img)
        if self.phase == 'train':
            # randomly apply one of the MRI-style noise/artifact transforms
            random_transform = OneOf([monai.transforms.RandGaussianNoise(prob=0.5, mean=0.0, std=0.1),
                                      monai.transforms.RandKSpaceSpikeNoise(prob=0.5, intensity_range=None, channel_wise=True),
                                      monai.transforms.RandBiasField(degree=3),
                                      monai.transforms.RandGibbsNoise(prob=0.5, alpha=(0.0, 1.0))
                                      ], weights=[0.3, 0.3, 0.2, 0.2])
            img = random_transform(img).as_tensor()
        else:
            # brighten unusually dark validation/test slices
            if img.mean() < 0.05:
                img = min_max_normalize(img)
                img = monai.transforms.AdjustContrast(gamma=0.8)(img)

        if 'all' in self.targets:  # combine all targets into a single foreground class
            msk = np.array(np.array(msk, dtype=int) > 0, dtype=int)
        else:
            msk = np.array(msk, dtype=int)

        mask_cls = np.array(msk == self.cls, dtype=int)

        if self.phase == 'train' and self.if_attention_map is not None:
            # stack image and mask so both receive the same spatial transform
            mask_cls = np.repeat(mask_cls[np.newaxis, :, :], 3, axis=0)
            both_targets = torch.cat((img.unsqueeze(0), torch.tensor(mask_cls, dtype=img.dtype).unsqueeze(0)), 0)
            transformed_targets = self.transform_spatial(both_targets)
            img = transformed_targets[0]
            mask_cls = np.array(transformed_targets[1][0].detach(), dtype=int)

        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

        # generate the mask and, optionally, a prompt
        if self.if_prompt:
            if self.prompt_type == 'point':
                prompt, mask_now = get_first_prompt(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
                pc = torch.as_tensor(prompt[:, :2], dtype=torch.float)
                pl = torch.as_tensor(prompt[:, -1], dtype=torch.float)
                msk = torch.unsqueeze(torch.tensor(mask_now, dtype=torch.long), 0)
                return {'image': img,
                        'mask': msk,
                        'point_coords': pc,
                        'point_labels': pl,
                        'img_name': img_path,
                        'atten_map': attention_map,
                        }
            elif self.prompt_type == 'box':
                prompt, mask_now = get_top_boxes(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
                box = torch.as_tensor(prompt, dtype=torch.float)
                msk = torch.unsqueeze(torch.tensor(mask_now, dtype=torch.long), 0)
                return {'image': img,
                        'mask': msk,
                        'boxes': box,
                        'img_name': img_path,
                        'atten_map': attention_map,
                        }
        else:
            msk = torch.unsqueeze(torch.tensor(mask_cls, dtype=torch.long), 0)
            return {'image': img,
                    'mask': msk,
                    'img_name': img_path,
                    'atten_map': attention_map,
                    }

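# --- Hedged usage sketch (not part of the original file) ---------------------
# A minimal example of how MRI_dataset is typically wired into a DataLoader,
# assuming an `args` namespace with an `image_size` field and a list file whose
# lines are "img_path mask_path <id> <slice_num>"; every path is a placeholder.
#
#   from types import SimpleNamespace
#   from torch.utils.data import DataLoader
#
#   args = SimpleNamespace(image_size=1024)
#   dataset = MRI_dataset(args, 'data/2D-slices/images', 'data/2D-slices/masks',
#                         'lists/train.txt', phase='train', targets=['all'],
#                         if_prompt=True, prompt_type='point')
#   loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
#   batch = next(iter(loader))  # dict with 'image', 'mask', 'point_coords',
#                               # 'point_labels', 'img_name', 'atten_map'
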
class MRI_dataset_multicls(Dataset):
    def __init__(self, args, img_folder, mask_folder, img_list, phase='train', sample_num=50, channel_num=1,
                 crop=False, crop_size=1024, targets=['combine_all'], part_list=['all'], if_prompt=True,
                 prompt_type='point', if_spatial=True, region_type='largest_20', prompt_num=20, delete_empty_masks=False,
                 label_mapping=None, reference_slice_num=0, if_attention_map=None, label_frequency_path=None):
        super(MRI_dataset_multicls, self).__init__()
        self.initialize_parameters(args, img_folder, mask_folder, img_list, phase, sample_num, channel_num,
                                   crop, crop_size, targets, part_list, if_prompt, prompt_type, if_spatial, region_type,
                                   prompt_num, delete_empty_masks, label_mapping, reference_slice_num, if_attention_map, label_frequency_path)
        self.load_label_mapping()
        self.prepare_data_list()
        self.filter_data_list()
        if phase == 'train':
            self.setup_transformations_train(crop_size)
        else:
            self.setup_transformations_other()

    def initialize_parameters(self, args, img_folder, mask_folder, img_list, phase, sample_num, channel_num,
                              crop, crop_size, targets, part_list, if_prompt, prompt_type, if_spatial, region_type,
                              prompt_num, delete_empty_masks, label_mapping, reference_slice_num, if_attention_map, label_frequency_path):
        self.args = args
        self.img_folder = img_folder
        self.mask_folder = mask_folder
        self.img_list = img_list
        self.phase = phase
        self.sample_num = sample_num
        self.channel_num = channel_num
        self.crop = crop
        self.crop_size = crop_size
        self.targets = targets
        self.part_list = part_list
        self.if_prompt = if_prompt
        self.prompt_type = prompt_type
        self.if_spatial = if_spatial
        self.region_type = region_type
        self.prompt_num = prompt_num
        self.delete_empty_masks = delete_empty_masks
        self.label_mapping = label_mapping
        self.reference_slice_num = reference_slice_num
        self.if_attention_map = if_attention_map
        self.label_dic = {}
        self.label_frequency_path = label_frequency_path
        # volume name -> list of (img_path, mask_path, slice_num); filled by add_reference_slice()
        self.reference_slices = {}

    def load_label_mapping(self):
        # Load the basic label mappings from a pickle file of (name, label) tuples
        if self.label_mapping:
            with open(self.label_mapping, 'rb') as handle:
                self.segment_names_to_labels = pickle.load(handle)
            self.label_dic = {seg[1]: seg[0] for seg in self.segment_names_to_labels}
            self.label_name_list = [seg[0] for seg in self.segment_names_to_labels]
            print(self.label_dic)
        else:
            self.label_dic = {value: 'all' for value in range(1, 256)}
            self.label_name_list = []  # avoid AttributeError in determine_mask_class()
        # Resolve the numeric label of a single named target; determine_mask_class()
        # and prepare_output() read self.cls when targets[0] is a label name.
        # Defaulting to 1 is an assumption for the un-named case.
        name_to_cls = {seg[0]: seg[1] for seg in getattr(self, 'segment_names_to_labels', [])}
        self.cls = name_to_cls.get(self.targets[0], 1)

        # Load frequency data and remap classes if required
        if 'remap_frequency' in self.targets:
            self.load_and_remap_classes_based_on_frequency()

    def load_and_remap_classes_based_on_frequency(self):
        if self.label_frequency_path:
            with open(self.label_frequency_path, 'r') as file:
                all_label_frequencies = json.load(file)
            all_label_frequencies = all_label_frequencies['train']

            # select the target region dynamically from the configured part list
            target_region = self.part_list[0]
            if target_region in all_label_frequencies:
                label_frequencies = all_label_frequencies[target_region]
                self.label_frequencies = label_frequencies
                self.remap_classes_based_on_frequency(label_frequencies)
            else:
                print(f"Warning: No frequency data found for the target region '{target_region}'. No remapping applied.")

    def remap_classes_based_on_frequency(self, label_frequencies):
        # The threshold separating high- from low-frequency classes is half the
        # count of the most frequent class; adjust as needed
        max_freq = max(label_frequencies.values())
        high_freq_threshold = max_freq * 0.5

        # Split classes into high and low frequency based on the threshold
        high_freq_classes = {}
        low_freq_classes = {}
        for label, freq in label_frequencies.items():
            if freq >= high_freq_threshold:
                high_freq_classes[label] = freq
            else:
                low_freq_classes[label] = freq

        # Update the label dictionary based on the frequency classification
        # self.label_dic: {old_cls: old_name}
        new_label_dic = {}
        for cls, name in self.label_dic.items():
            if name in high_freq_classes:
                new_label_dic[cls] = name  # retain the original name for high-frequency classes
            elif name in low_freq_classes:
                new_label_dic[cls] = 'combined_low_freq'  # merge low-frequency classes into one
        # new_label_dic: {old_cls: new_name}
        self.updated_label_dic = new_label_dic

        # Sort high-frequency names by frequency in descending order and assign
        # new class ids 1..N in that order
        sorted_high_freq_labels = sorted(high_freq_classes.items(), key=lambda item: item[1], reverse=True)
        original_to_new = {label: idx + 1 for idx, (label, _) in enumerate(sorted_high_freq_labels)}

        # All low-frequency names share the next class id
        combined_low_freq_class_id = len(original_to_new) + 1
        if 'combined_low_freq' in new_label_dic.values():
            for name in low_freq_classes.keys():
                original_to_new[name] = combined_low_freq_class_id
        # original_to_new: {old_name: new_cls}

        # Derived lookup tables
        self.old_name_to_new_name = {self.label_dic[cls]: new_label for cls, new_label in new_label_dic.items()}
        self.old_cls_to_new_cls = {cls: original_to_new[self.label_dic[cls]] for cls in self.label_dic.keys() if self.label_dic[cls] in original_to_new}

        print('remapped label dic:', self.old_name_to_new_name)
        print('remapped cls dic:', self.old_cls_to_new_cls)
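
    # Hedged worked example (toy numbers, not from any real frequency file):
    #   label_frequencies = {'femur': 100, 'hip': 80, 'spine': 10}
    #   max_freq = 100 -> threshold = 50
    #   high frequency: femur, hip -> new classes 1, 2 (sorted by frequency)
    #   low frequency:  spine      -> 'combined_low_freq' -> class 3
    #   old_cls_to_new_cls then maps each original label id onto {1, 2, 3}.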
            
    def prepare_data_list(self):
        with open(self.img_list, 'r') as namefiles:
            self.data_list = namefiles.read().split('\n')[:-1]
        # list files may be comma- or space-separated
        self.sp_symbol = ',' if ',' in self.data_list[0] else ' '

    def filter_data_list(self):
        keep_idx = []
        for idx, data in enumerate(self.data_list):
            img_path, mask_path = self.extract_paths(data)
            msk = Image.open(os.path.join(self.mask_folder, mask_path)).convert('L')
            mask_cls = self.determine_mask_class(msk)

            if self.should_keep(mask_cls, mask_path):
                keep_idx.append(idx)
                if self.reference_slice_num > 1:
                    self.add_reference_slice(img_path, mask_path, data)

        # report the counts before overwriting the full list
        print('num with non-empty masks', len(keep_idx), 'num with all masks', len(self.data_list))
        self.data_list = [self.data_list[i] for i in keep_idx]

    def extract_paths(self, data):
        img_path = data.split(self.sp_symbol)[0]
        mask_path = data.split(self.sp_symbol)[1]
        return img_path.lstrip('/'), mask_path.lstrip('/')

    def determine_mask_class(self, msk):
        if 'combine_all' in self.targets:
            return np.array(msk, dtype=int) > 0
        elif self.targets[0] in self.label_name_list:
            return np.array(msk, dtype=int) == self.cls
        return np.array(msk, dtype=int)

    def should_keep(self, mask_cls, mask_path):
        if self.delete_empty_masks:
            has_mask = np.any(mask_cls > 0)
            if has_mask:
                if self.part_list[0] == 'all':
                    return True
                return any(mask_path.find(part) >= 0 for part in self.part_list)
            return False
        return True

    def add_reference_slice(self, img_path, mask_path, data):
        volume_name = ''.join(img_path.split('-')[:-1])  # volume key; must match handle_reference_slices()
        slice_num = data.split(self.sp_symbol)[2]
        if volume_name not in self.reference_slices:
            self.reference_slices[volume_name] = []
        self.reference_slices[volume_name].append((img_path, mask_path, slice_num))

    def setup_transformations_train(self, crop_size):
        self.transform_img = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.aug_img = transforms.RandomChoice([
            transforms.RandomEqualize(p=0.1),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
            transforms.RandomAdjustSharpness(0.5, p=0.5),
        ])
        if self.if_spatial:
            self.transform_spatial = transforms.Compose([transforms.RandomResizedCrop(self.crop_size, scale=(0.5, 1.5), interpolation=InterpolationMode.NEAREST),
                                                         transforms.RandomRotation(45, interpolation=InterpolationMode.NEAREST)])

    def setup_transformations_other(self):
        self.transform_img = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        # Load image and mask, handling missing files
        data = self.data_list[index]
        img, msk, img_path, mask_path, slice_num = self.load_image_and_mask(data)

        # Optionally load the attention map
        attention_map = self.load_attention_map(img_path, slice_num) if self.if_attention_map else torch.zeros((64, 64))

        # Handle reference slices if necessary
        if self.reference_slice_num > 1:
            img, msk = self.handle_reference_slices(img_path, mask_path, slice_num)

        # Apply transformations
        img, msk = self.apply_transformations(img, msk)

        # Generate and process masks and prompts
        output_dict = self.prepare_output(img, msk, img_path, mask_path, attention_map)

        return output_dict

    def load_image_and_mask(self, data):
        img_path, mask_path = self.extract_paths(data)
        slice_num = data.split(self.sp_symbol)[3]  # total slice count for this volume

        img = Image.open(os.path.join(self.img_folder, img_path)).convert('RGB')
        msk = Image.open(os.path.join(self.mask_folder, mask_path)).convert('L')

        # Resize images for processing
        img = transforms.Resize((self.args.image_size, self.args.image_size))(img)
        msk = transforms.Resize((self.args.image_size, self.args.image_size), interpolation=InterpolationMode.NEAREST)(msk)

        return img, msk, img_path, mask_path, int(slice_num)

    def load_attention_map(self, img_path, slice_num):
        slice_id = int(img_path.split('-')[-1].split('.')[0])
        slice_fraction = int(slice_id / slice_num * 4)  # which quarter of the volume this slice falls in
        img_id = '/'.join(img_path.split('-')[:-1]) + '_' + str(slice_fraction) + '.npy'
        attention_map = torch.tensor(np.load(os.path.join(self.if_attention_map, img_id)))
        return attention_map

    def apply_crop(self, img, msk):
        # pad up to crop_size, then take the same random crop from image and mask
        im_w, im_h = img.size
        diff_w = max(0, self.crop_size - im_w)
        diff_h = max(0, self.crop_size - im_h)
        padding = (diff_w // 2, diff_h // 2, diff_w - diff_w // 2, diff_h - diff_h // 2)
        img = transforms.functional.pad(img, padding, 0, 'constant')
        msk = transforms.functional.pad(msk, padding, 0, 'constant')
        t, l, h, w = transforms.RandomCrop.get_params(img, (self.crop_size, self.crop_size))
        img = transforms.functional.crop(img, t, l, h, w)
        msk = transforms.functional.crop(msk, t, l, h, w)
        return img, msk

    def apply_transformations(self, img, msk):
        if self.crop:
            img, msk = self.apply_crop(img, msk)
        if self.phase == 'train':
            img = self.aug_img(img)
        img = self.transform_img(img)
        if self.phase == 'train' and self.if_spatial:
            # stack image and mask so both receive the same spatial transform
            mask_cls = np.array(msk, dtype=int)
            mask_cls = np.repeat(mask_cls[np.newaxis, :, :], 3, axis=0)
            both_targets = torch.cat((img.unsqueeze(0), torch.tensor(mask_cls, dtype=img.dtype).unsqueeze(0)), 0)
            transformed_targets = self.transform_spatial(both_targets)
            img = transformed_targets[0]
            mask_cls = np.array(transformed_targets[1][0].detach(), dtype=int)
            msk = torch.tensor(mask_cls)
        return img, msk

    def handle_reference_slices(self, img_path, mask_path, slice_num):
        volume_name = ''.join(img_path.split('-')[:-1])
        ref_slices, ref_msks = [], []
        reference_slices = self.reference_slices.get(volume_name, [])
        for ref_slice in reference_slices:
            ref_img_path, ref_msk_path, _ = ref_slice
            ref_img = Image.open(os.path.join(self.img_folder, ref_img_path)).convert('RGB')
            ref_img = transforms.Resize((self.args.image_size, self.args.image_size))(ref_img)
            ref_img = self.transform_img(ref_img)
            ref_slices.append(torch.unsqueeze(ref_img, 0))

            ref_msk = Image.open(os.path.join(self.mask_folder, ref_msk_path)).convert('L')
            ref_msk = transforms.Resize((self.args.image_size, self.args.image_size), interpolation=InterpolationMode.NEAREST)(ref_msk)
            # convert the PIL image to an array before building the tensor
            ref_msk = torch.tensor(np.array(ref_msk), dtype=torch.long)
            ref_msks.append(torch.unsqueeze(ref_msk, 0))

        img = torch.cat(ref_slices, dim=0)
        msk = torch.cat(ref_msks, dim=0)
        return img, msk

    def remap_classes_sequentially(self, mask, label_frequencies):
        # Apply the old->new class mapping to the mask; indexing against the
        # original `mask` avoids chained remapping
        remapped_mask = mask.copy()
        for old_cls, new_cls in self.old_cls_to_new_cls.items():
            remapped_mask[mask == old_cls] = new_cls
        return remapped_mask
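
    # Hedged illustration: with old_cls_to_new_cls = {3: 1, 7: 2, 9: 3}, a mask
    # containing labels {0, 3, 9} is rewritten to {0, 1, 3}; background (0) and
    # any label absent from the mapping are left untouched.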

    def prepare_output(self, img, msk, img_path, mask_path, attention_map):
        # Normalize the image
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

        msk = np.array(msk, dtype=int)
        if self.label_frequency_path:
            # remap the original label ids onto the frequency-based ids
            msk = self.remap_classes_sequentially(msk, self.label_frequencies)

        unique_classes = np.unique(msk).tolist()
        if 0 in unique_classes:
            unique_classes.remove(0)

        if len(unique_classes) > 0:
            selected_dic = {k: self.label_dic[k] for k in unique_classes if k in self.label_dic}
        else:
            selected_dic = {}

        if self.targets[0] == 'random':
            mask_cls, selected_label, cls_one_hot = self.handle_random_target(msk, unique_classes, selected_dic)
        elif self.targets[0] in self.label_name_list:
            selected_label = self.targets[0]
            mask_cls = np.array(msk == self.cls, dtype=int)
            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)
            cls_one_hot[self.cls - 1] = 1
        else:
            selected_label = self.targets[0]
            mask_cls = msk
            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)

        # Handle prompts
        if self.if_prompt:
            prompt, mask_now, mask_cls = self.generate_prompt(mask_cls)
            ref_msk, _ = torch.max(mask_now > 0, dim=0)
            return_dict = {'image': img, 'mask': mask_now, 'selected_label_name': selected_label,
                           'cls_one_hot': cls_one_hot, 'prompt': prompt, 'img_name': img_path,
                           'mask_ori': msk, 'mask_cls': mask_cls, 'all_label_dic': selected_dic, 'ref_mask': ref_msk}
        else:
            if len(mask_cls.shape) == 2:
                msk = torch.unsqueeze(torch.tensor(mask_cls, dtype=torch.long), 0)
            elif len(mask_cls.shape) == 4:
                msk = torch.squeeze(torch.tensor(mask_cls, dtype=torch.long))
            else:
                msk = torch.tensor(mask_cls, dtype=torch.long)
            ref_msk, _ = torch.max(msk > 0, dim=0)
            return_dict = {'image': img, 'mask': msk, 'selected_label_name': selected_label,
                           'cls_one_hot': cls_one_hot, 'img_name': img_path, 'mask_ori': msk, 'ref_mask': ref_msk}

        return return_dict

    def generate_prompt(self, mask_cls):
        if self.prompt_type == 'point':
            prompt, mask_now = get_first_prompt(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
        elif self.prompt_type == 'box':
            prompt, mask_now = get_top_boxes(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
        else:
            prompt = mask_now = None

        # Normalize the shape of mask_now for return
        if mask_now is not None:
            if len(mask_now.shape) == 2:
                mask_now = torch.unsqueeze(torch.tensor(mask_now, dtype=torch.long), 0)
                mask_cls = torch.unsqueeze(torch.tensor(mask_cls, dtype=torch.long), 0)
            elif len(mask_now.shape) == 4:
                mask_now = torch.squeeze(torch.tensor(mask_now, dtype=torch.long))
            else:
                mask_now = torch.tensor(mask_now, dtype=torch.long)
                mask_cls = torch.tensor(mask_cls, dtype=torch.long)

        return prompt, mask_now, mask_cls

    def handle_random_target(self, msk, unique_classes, selected_dic):
        if len(unique_classes) > 0:
            random_selected_cls = random.choice(unique_classes)
            selected_label = selected_dic[random_selected_cls]
            mask_cls = np.array(msk == random_selected_cls, dtype=int)

            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)
            cls_one_hot[random_selected_cls - 1] = 1
        else:
            selected_label = None
            mask_cls = np.zeros_like(msk)  # msk is a numpy array here, so use np.zeros_like
            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)

        return mask_cls, selected_label, cls_one_hot
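

# --- Hedged smoke-test sketch (not part of the original file) ----------------
# Assumes an `args` namespace with an `image_size` field, a list file whose
# lines are "img_path mask_path <id> <slice_num>", and a label-mapping pickle
# of (name, label) tuples; every path below is a placeholder.
if __name__ == '__main__':
    from types import SimpleNamespace
    args = SimpleNamespace(image_size=1024)
    dataset = MRI_dataset_multicls(args, 'data/2D-slices/images', 'data/2D-slices/masks',
                                   'lists/val.txt', phase='val', targets=['random'],
                                   if_prompt=False, label_mapping='lists/labels.pkl')
    sample = dataset[0]
    print(sample['image'].shape, sample['mask'].shape, sample['selected_label_name'])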