--- /dev/null
+++ b/dataset_bone.py
@@ -0,0 +1,601 @@
+import os, torch
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+import cv2
+import random
+import torchio as tio
+import slicerio
+import nrrd
+import monai
+import json
+import pickle
+import nibabel as nib
+from scipy.ndimage import zoom
+from monai.transforms import OneOf
+import einops
+from funcs import *
+from torchvision.transforms import InterpolationMode
+#from .utils.transforms import ResizeLongestSide
+
+
+class MRI_dataset(Dataset):
+    def __init__(self, args, img_folder, mask_folder, img_list, phase='train', sample_num=50, channel_num=1,
+                 crop=False, crop_size=1024, targets=['femur', 'hip'], part_list=['all'], cls=1,
+                 if_prompt=True, prompt_type='point', region_type='largest_15', prompt_num=15,
+                 delete_empty_masks=False, if_attention_map=None):
+        super(MRI_dataset, self).__init__()
+        self.img_folder = img_folder
+        self.mask_folder = mask_folder
+        self.crop = crop
+        self.crop_size = crop_size
+        self.phase = phase
+        self.channel_num = channel_num
+        self.targets = targets
+        self.segment_names_to_labels = []
+        self.args = args
+        self.cls = cls
+        self.if_prompt = if_prompt
+        self.region_type = region_type
+        self.prompt_type = prompt_type
+        self.prompt_num = prompt_num
+        self.if_attention_map = if_attention_map
+
+        for i, tag in enumerate(targets):
+            self.segment_names_to_labels.append((tag, i))
+
+        with open(img_list, 'r') as namefiles:
+            self.data_list = namefiles.read().split('\n')[:-1]
+
+        if delete_empty_masks == 'delete' or delete_empty_masks == 'subsample':
+            keep_idx = []
+            for idx, data in enumerate(self.data_list):
+                mask_path = data.split(' ')[1]
+                if os.path.exists(os.path.join(self.mask_folder, mask_path)):
+                    msk = Image.open(os.path.join(self.mask_folder, mask_path)).convert('L')
+                else:
+                    msk = Image.open(os.path.join(self.mask_folder.replace('2D-slices', '2D-slices-generated'), mask_path)).convert('L')
+                if 'all' in self.targets:  # combine all targets into a single foreground class
+                    mask_cls = np.array(np.array(msk, dtype=int) > 0, dtype=int)
+                else:
+                    mask_cls = np.array(msk == self.cls, dtype=int)
+                if part_list[0] == 'all' and np.sum(mask_cls) > 0:
+                    keep_idx.append(idx)
+                elif np.sum(mask_cls) > 0:
+                    if_keep = False
+                    for part in part_list:
+                        if mask_path.find(part) >= 0:
+                            if_keep = True
+                    if if_keep:
+                        keep_idx.append(idx)
+            print('num with non-empty masks', len(keep_idx), 'num with all masks', len(self.data_list))
+            if delete_empty_masks == 'subsample':
+                # keep all non-empty slices plus a 10% random sample of the empty ones
+                empty_idx = list(set(range(len(self.data_list))) - set(keep_idx))
+                keep_empty_idx = random.sample(empty_idx, int(len(empty_idx) * 0.1))
+                keep_idx = keep_empty_idx + keep_idx
+            self.data_list = [self.data_list[i] for i in keep_idx]  # keep the slices that contain the target mask
+
+        if phase == 'train':
+            self.aug_img = [transforms.RandomEqualize(p=0.1),
+                            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
+                            transforms.RandomAdjustSharpness(0.5, p=0.5),
+                            ]
+            self.transform_spatial = transforms.Compose([transforms.RandomResizedCrop(crop_size, scale=(0.8, 1.2)),
+                                                         transforms.RandomRotation(45)])
+            transform_img = [transforms.ToTensor()]
+        else:
+            transform_img = [
+                transforms.ToTensor(),
+            ]
+        self.transform_img = transforms.Compose(transform_img)
+
+    def __len__(self):
+        return len(self.data_list)
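+
+    # Each line of `img_list` holds space-separated fields; based on the split()
+    # indices used below, a line looks roughly like (hypothetical example):
+    #   images/case01-012.png masks/case01-012.png 12 40
+    # [0] image path, [1] mask path, [3] total slice count of the source volume;
+    # field [2] is not read by this class.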
+    def __getitem__(self, index):
+        # load the image and the mask
+        data = self.data_list[index]
+        img_path = data.split(' ')[0]
+        mask_path = data.split(' ')[1]
+        slice_num = data.split(' ')[3]  # total slice num for this volume
+        try:
+            if os.path.exists(os.path.join(self.img_folder, img_path)):
+                img = Image.open(os.path.join(self.img_folder, img_path)).convert('RGB')
+            else:
+                img = Image.open(os.path.join(self.img_folder.replace('2D-slices', '2D-slices-generated'), img_path)).convert('RGB')
+        except Exception:
+            # fall back to loading the image as a numpy file
+            img_arr = np.load(os.path.join(self.img_folder, img_path))
+            img_arr = np.array((img_arr - img_arr.min()) / (img_arr.max() - img_arr.min() + 1e-8) * 255, dtype=np.uint8)
+            img_3c = np.tile(img_arr[:, :, None], [1, 1, 3])
+            img = Image.fromarray(img_3c, 'RGB')
+        if os.path.exists(os.path.join(self.mask_folder, mask_path)):
+            msk = Image.open(os.path.join(self.mask_folder, mask_path)).convert('L')
+        else:
+            msk = Image.open(os.path.join(self.mask_folder.replace('2D-slices', '2D-slices-generated'), mask_path)).convert('L')
+
+        if self.if_attention_map:
+            slice_id = int(img_path.split('-')[-1].split('.')[0])
+            slice_fraction = int(slice_id / int(slice_num) * 4)
+            img_id = '/'.join(img_path.split('-')[:-1]) + '_' + str(slice_fraction) + '.npy'
+            attention_map = torch.tensor(np.load(os.path.join(self.if_attention_map, img_id)))
+        else:
+            attention_map = torch.zeros((64, 64))
+
+        img = transforms.Resize((self.args.image_size, self.args.image_size))(img)
+        msk = transforms.Resize((self.args.image_size, self.args.image_size), InterpolationMode.NEAREST)(msk)
+
+        state = torch.get_rng_state()
+        if self.crop:
+            im_w, im_h = img.size
+            diff_w = max(0, self.crop_size - im_w)
+            diff_h = max(0, self.crop_size - im_h)
+            padding = (diff_w // 2, diff_h // 2, diff_w - diff_w // 2, diff_h - diff_h // 2)
+            img = transforms.functional.pad(img, padding, 0, 'constant')
+            torch.set_rng_state(state)
+            t, l, h, w = transforms.RandomCrop.get_params(img, (self.crop_size, self.crop_size))
+            img = transforms.functional.crop(img, t, l, h, w)
+            msk = transforms.functional.pad(msk, padding, 0, 'constant')
+            msk = transforms.functional.crop(msk, t, l, h, w)
+        if self.phase == 'train':
+            # apply one randomly chosen intensity augmentation
+            aug_img_fuc = transforms.RandomChoice(self.aug_img)
+            img = aug_img_fuc(img)
+
+        img = self.transform_img(img)
+        if self.phase == 'train':
+            # OneOf randomly applies exactly one of these MRI artifact simulations
+            random_transform = OneOf([monai.transforms.RandGaussianNoise(prob=0.5, mean=0.0, std=0.1),
+                                      monai.transforms.RandKSpaceSpikeNoise(prob=0.5, intensity_range=None, channel_wise=True),
+                                      monai.transforms.RandBiasField(degree=3),
+                                      monai.transforms.RandGibbsNoise(prob=0.5, alpha=(0.0, 1.0))
+                                      ], weights=[0.3, 0.3, 0.2, 0.2])
+            img = random_transform(img).as_tensor()
+        else:
+            if img.mean() < 0.05:
+                img = min_max_normalize(img)
+                img = monai.transforms.AdjustContrast(gamma=0.8)(img)
+
+        if 'all' in self.targets:  # combine all targets into a single foreground class
+            msk = np.array(np.array(msk, dtype=int) > 0, dtype=int)
+        else:
+            msk = np.array(msk, dtype=int)
+
+        mask_cls = np.array(msk == self.cls, dtype=int)
+
+        if self.phase == 'train' and self.if_attention_map is not None:
+            mask_cls = np.repeat(mask_cls[np.newaxis, :, :], 3, axis=0)
+            # cast the mask to the image dtype so both can be stacked and
+            # pushed through the same spatial transform
+            both_targets = torch.cat((img.unsqueeze(0), torch.tensor(mask_cls, dtype=img.dtype).unsqueeze(0)), 0)
+            transformed_targets = self.transform_spatial(both_targets)
+            img = transformed_targets[0]
+            mask_cls = np.array(transformed_targets[1][0].detach(), dtype=int)
+
+        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
+        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
+
+        # generate the mask and prompt
+        if self.if_prompt:
+            if self.prompt_type == 'point':
+                prompt, mask_now = get_first_prompt(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
+                pc = torch.as_tensor(prompt[:, :2], dtype=torch.float)
+                pl = torch.as_tensor(prompt[:, -1], dtype=torch.float)
+                msk = torch.unsqueeze(torch.tensor(mask_now, dtype=torch.long), 0)
+                return {'image': img,
+                        'mask': msk,
+                        'point_coords': pc,
+                        'point_labels': pl,
+                        'img_name': img_path,
+                        'atten_map': attention_map,
+                        }
+            elif self.prompt_type == 'box':
+                prompt, mask_now = get_top_boxes(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
+                box = torch.as_tensor(prompt, dtype=torch.float)
+                msk = torch.unsqueeze(torch.tensor(mask_now, dtype=torch.long), 0)
+                return {'image': img,
+                        'mask': msk,
+                        'boxes': box,
+                        'img_name': img_path,
+                        'atten_map': attention_map,
+                        }
+        else:
+            msk = torch.unsqueeze(torch.tensor(mask_cls, dtype=torch.long), 0)
+            return {'image': img,
+                    'mask': msk,
+                    'img_name': img_path,
+                    'atten_map': attention_map,
+                    }
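+
+
+# Minimal usage sketch for MRI_dataset (hypothetical paths; assumes an args
+# namespace exposing the `image_size` attribute the dataset reads):
+#
+#   from types import SimpleNamespace
+#   args = SimpleNamespace(image_size=1024)
+#   train_set = MRI_dataset(args, 'data/2D-slices/images', 'data/2D-slices/masks',
+#                           'data/train_list.txt', phase='train', cls=1,
+#                           if_prompt=True, prompt_type='point')
+#   sample = train_set[0]  # dict with 'image', 'mask', 'point_coords',
+#                          # 'point_labels', 'img_name', 'atten_map'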
+
+
+class MRI_dataset_multicls(Dataset):
+    def __init__(self, args, img_folder, mask_folder, img_list, phase='train', sample_num=50, channel_num=1,
+                 crop=False, crop_size=1024, targets=['combine_all'], part_list=['all'], if_prompt=True,
+                 prompt_type='point', if_spatial=True, region_type='largest_20', prompt_num=20, delete_empty_masks=False,
+                 label_mapping=None, reference_slice_num=0, if_attention_map=None, label_frequency_path=None):
+        super(MRI_dataset_multicls, self).__init__()
+        self.initialize_parameters(args, img_folder, mask_folder, img_list, phase, sample_num, channel_num,
+                                   crop, crop_size, targets, part_list, if_prompt, prompt_type, if_spatial, region_type,
+                                   prompt_num, delete_empty_masks, label_mapping, reference_slice_num, if_attention_map, label_frequency_path)
+        self.load_label_mapping()
+        self.prepare_data_list()
+        self.filter_data_list()
+        if phase == 'train':
+            self.setup_transformations_train(crop_size)
+        else:
+            self.setup_transformations_other()
+
+    def initialize_parameters(self, args, img_folder, mask_folder, img_list, phase, sample_num, channel_num,
+                              crop, crop_size, targets, part_list, if_prompt, prompt_type, if_spatial, region_type,
+                              prompt_num, delete_empty_masks, label_mapping, reference_slice_num, if_attention_map, label_frequency_path):
+        self.args = args
+        self.img_folder = img_folder
+        self.mask_folder = mask_folder
+        self.img_list = img_list
+        self.phase = phase
+        self.sample_num = sample_num
+        self.channel_num = channel_num
+        self.crop = crop
+        self.crop_size = crop_size
+        self.targets = targets
+        self.part_list = part_list
+        self.if_prompt = if_prompt
+        self.prompt_type = prompt_type
+        self.if_spatial = if_spatial
+        self.region_type = region_type
+        self.prompt_num = prompt_num
+        self.delete_empty_masks = delete_empty_masks
+        self.label_mapping = label_mapping
+        self.reference_slice_num = reference_slice_num
+        self.if_attention_map = if_attention_map
+        self.label_dic = {}
+        self.label_frequency_path = label_frequency_path
+        self.reference_slices = {}  # volume name -> list of (img_path, mask_path, slice_num)
+
+    def load_label_mapping(self):
+        # Load the basic label mappings from a pickle file of (name, label) pairs
+        if self.label_mapping:
+            with open(self.label_mapping, 'rb') as handle:
+                self.segment_names_to_labels = pickle.load(handle)
+            self.label_dic = {seg[1]: seg[0] for seg in self.segment_names_to_labels}
+            self.label_name_list = [seg[0] for seg in self.segment_names_to_labels]
+            print(self.label_dic)
+        else:
+            self.label_dic = {value: 'all' for value in range(1, 256)}
+            self.label_name_list = []
+        # Resolve the numeric class id when a single named target is requested;
+        # determine_mask_class and prepare_output read self.cls
+        name_to_cls = {name: cls for cls, name in self.label_dic.items()}
+        self.cls = name_to_cls.get(self.targets[0], 1)
+
+        # Load frequency data and remap classes if required
+        if 'remap_frequency' in self.targets:
+            self.load_and_remap_classes_based_on_frequency()
+
+    def load_and_remap_classes_based_on_frequency(self):
+        if self.label_frequency_path:
+            with open(self.label_frequency_path, 'r') as file:
+                all_label_frequencies = json.load(file)
+            all_label_frequencies = all_label_frequencies['train']
+
+            # Select the target region dynamically from the configured part list
+            target_region = self.part_list[0]
+            if target_region in all_label_frequencies:
+                label_frequencies = all_label_frequencies[target_region]
+                self.label_frequencies = label_frequencies
+                self.remap_classes_based_on_frequency(label_frequencies)
+            else:
+                print(f"Warning: No frequency data found for the target region '{target_region}'. No remapping applied.")
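+
+    # Based on the lookups above, the frequency JSON is expected to be nested as
+    # split -> region -> {label name: frequency}, e.g. (hypothetical values):
+    #   {"train": {"hip": {"femur": 100, "talus": 10}, "knee": {...}}}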
+
+    def remap_classes_based_on_frequency(self, label_frequencies):
+        # Classes at or above half of the maximum frequency stay separate;
+        # everything below is merged into one 'combined_low_freq' class
+        total = max(label_frequencies.values())
+        high_freq_threshold = total * 0.5  # adjust this threshold as needed
+
+        # Split classes into high- and low-frequency groups
+        high_freq_classes = {}
+        low_freq_classes = {}
+        for label, freq in label_frequencies.items():
+            if freq >= high_freq_threshold:
+                high_freq_classes[label] = freq
+            else:
+                low_freq_classes[label] = freq
+
+        # Update the label dictionary based on the frequency classification
+        # self.label_dic: {old_cls: old_name}
+        new_label_dic = {}
+        for cls, name in self.label_dic.items():
+            if name in high_freq_classes:
+                new_label_dic[cls] = name  # retain the original name for high-frequency classes
+            elif name in low_freq_classes:
+                new_label_dic[cls] = 'combined_low_freq'  # merge low-frequency classes into one
+        self.updated_label_dic = new_label_dic
+        # new_label_dic: {old_cls: new_name}
+
+        # Sort high-frequency names by frequency, descending, and assign new ids 1..N
+        sorted_high_freq_labels = sorted(high_freq_classes.items(), key=lambda item: item[1], reverse=True)
+        original_to_new = {label: idx + 1 for idx, (label, _) in enumerate(sorted_high_freq_labels)}
+
+        # The merged low-frequency class gets the next free id
+        combined_low_freq_class_id = len(original_to_new) + 1
+        if 'combined_low_freq' in new_label_dic.values():
+            for cls in low_freq_classes.keys():
+                original_to_new[cls] = combined_low_freq_class_id
+        # original_to_new: {old_name: new_cls}
+
+        # Derived lookup tables
+        self.old_name_to_new_name = {self.label_dic[cls]: new_label for cls, new_label in new_label_dic.items()}
+        self.old_cls_to_new_cls = {cls: original_to_new[self.label_dic[cls]] for cls in self.label_dic.keys() if self.label_dic[cls] in original_to_new}
+
+        print('remapped label dic:', self.old_name_to_new_name)
+        print('remapped cls dic:', self.old_cls_to_new_cls)
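+
+    # Worked example with hypothetical numbers: given label_dic
+    # {1: 'femur', 2: 'hip', 3: 'talus', 4: 'tibia'} and frequencies
+    # {'femur': 100, 'hip': 80, 'talus': 10, 'tibia': 5}, the threshold is
+    # 100 * 0.5 = 50, so 'femur' and 'hip' keep their own classes (sorted by
+    # frequency -> new ids 1 and 2) while 'talus' and 'tibia' collapse into
+    # 'combined_low_freq' with id 3; old_cls_to_new_cls = {1: 1, 2: 2, 3: 3, 4: 3}.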
+
+    def prepare_data_list(self):
+        with open(self.img_list, 'r') as namefiles:
+            self.data_list = namefiles.read().split('\n')[:-1]
+        self.sp_symbol = ',' if ',' in self.data_list[0] else ' '
+
+    def filter_data_list(self):
+        keep_idx = []
+        for idx, data in enumerate(self.data_list):
+            img_path, mask_path = self.extract_paths(data)
+            msk = Image.open(os.path.join(self.mask_folder, mask_path)).convert('L')
+            mask_cls = self.determine_mask_class(msk)
+
+            if self.should_keep(mask_cls, mask_path):
+                keep_idx.append(idx)
+                if self.reference_slice_num > 1:
+                    self.add_reference_slice(img_path, mask_path, data)
+
+        total_num = len(self.data_list)
+        self.data_list = [self.data_list[i] for i in keep_idx]
+        print('num with non-empty masks', len(keep_idx), 'num with all masks', total_num)
+
+    def extract_paths(self, data):
+        img_path = data.split(self.sp_symbol)[0]
+        mask_path = data.split(self.sp_symbol)[1]
+        return img_path.lstrip('/'), mask_path.lstrip('/')
+
+    def determine_mask_class(self, msk):
+        if 'combine_all' in self.targets:
+            return np.array(msk, dtype=int) > 0
+        elif self.targets[0] in self.label_name_list:
+            return np.array(msk, dtype=int) == self.cls
+        return np.array(msk, dtype=int)
+
+    def should_keep(self, mask_cls, mask_path):
+        if self.delete_empty_masks:
+            has_mask = np.any(mask_cls > 0)
+            if has_mask:
+                if self.part_list[0] == 'all':
+                    return True
+                return any(mask_path.find(part) >= 0 for part in self.part_list)
+            return False
+        return True
+
+    def add_reference_slice(self, img_path, mask_path, data):
+        volume_name = ''.join(img_path.split('-')[:-1])  # drop the slice suffix to get the volume key
+        slice_num = data.split(self.sp_symbol)[2]
+        if volume_name not in self.reference_slices:
+            self.reference_slices[volume_name] = []
+        self.reference_slices[volume_name].append((img_path, mask_path, slice_num))
+
+    def setup_transformations_train(self, crop_size):
+        self.transform_img = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+        self.aug_img = transforms.RandomChoice([
+            transforms.RandomEqualize(p=0.1),
+            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
+            transforms.RandomAdjustSharpness(0.5, p=0.5),
+        ])
+        if self.if_spatial:
+            # NEAREST keeps mask labels intact when image and mask are warped together
+            self.transform_spatial = transforms.Compose([
+                transforms.RandomResizedCrop(self.crop_size, scale=(0.5, 1.5), interpolation=InterpolationMode.NEAREST),
+                transforms.RandomRotation(45, interpolation=InterpolationMode.NEAREST)])
+
+    def setup_transformations_other(self):
+        self.transform_img = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+    def __len__(self):
+        return len(self.data_list)
+
+    def __getitem__(self, index):
+        # Load image and mask, handling missing files
+        data = self.data_list[index]
+        img, msk, img_path, mask_path, slice_num = self.load_image_and_mask(data)
+
+        # Optional: load the attention map
+        attention_map = self.load_attention_map(img_path, slice_num) if self.if_attention_map else torch.zeros((64, 64))
+
+        # Handle reference slices if necessary
+        if self.reference_slice_num > 1:
+            img, msk = self.handle_reference_slices(img_path, mask_path, slice_num)
+
+        # Apply transformations
+        img, msk = self.apply_transformations(img, msk)
+
+        # Generate and process masks and prompts
+        output_dict = self.prepare_output(img, msk, img_path, mask_path, attention_map)
+        return output_dict
+
+    def load_image_and_mask(self, data):
+        img_path, mask_path = self.extract_paths(data)
+        slice_num = data.split(self.sp_symbol)[3]  # total slice number for this volume
+
+        img = Image.open(os.path.join(self.img_folder, img_path)).convert('RGB')
+        msk = Image.open(os.path.join(self.mask_folder, mask_path)).convert('L')
+
+        # Resize for processing
+        img = transforms.Resize((self.args.image_size, self.args.image_size))(img)
+        msk = transforms.Resize((self.args.image_size, self.args.image_size), InterpolationMode.NEAREST)(msk)
+
+        return img, msk, img_path, mask_path, int(slice_num)
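+
+    # Naming convention assumed by load_attention_map below (hypothetical names):
+    # img_path 'case01-012.png' with slice_num 40 gives slice_id 12,
+    # slice_fraction int(12 / 40 * 4) = 1, and the attention-map file
+    # 'case01_1.npy' under the if_attention_map folder.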
+    def load_attention_map(self, img_path, slice_num):
+        slice_id = int(img_path.split('-')[-1].split('.')[0])
+        slice_fraction = int(slice_id / slice_num * 4)
+        img_id = '/'.join(img_path.split('-')[:-1]) + '_' + str(slice_fraction) + '.npy'
+        attention_map = torch.tensor(np.load(os.path.join(self.if_attention_map, img_id)))
+        return attention_map
+
+    def apply_crop(self, img, msk):
+        # Pad up to crop_size if needed, then take the same random crop from image and mask
+        im_w, im_h = img.size
+        diff_w = max(0, self.crop_size - im_w)
+        diff_h = max(0, self.crop_size - im_h)
+        padding = (diff_w // 2, diff_h // 2, diff_w - diff_w // 2, diff_h - diff_h // 2)
+        img = transforms.functional.pad(img, padding, 0, 'constant')
+        msk = transforms.functional.pad(msk, padding, 0, 'constant')
+        t, l, h, w = transforms.RandomCrop.get_params(img, (self.crop_size, self.crop_size))
+        img = transforms.functional.crop(img, t, l, h, w)
+        msk = transforms.functional.crop(msk, t, l, h, w)
+        return img, msk
+
+    def apply_transformations(self, img, msk):
+        if self.crop:
+            img, msk = self.apply_crop(img, msk)
+        if self.phase == 'train':
+            img = self.aug_img(img)
+        img = self.transform_img(img)
+        if self.phase == 'train' and self.if_spatial:
+            mask_cls = np.array(msk, dtype=int)
+            mask_cls = np.repeat(mask_cls[np.newaxis, :, :], 3, axis=0)
+            # cast the mask to the image dtype so both can share one spatial transform
+            both_targets = torch.cat((img.unsqueeze(0), torch.tensor(mask_cls, dtype=img.dtype).unsqueeze(0)), 0)
+            transformed_targets = self.transform_spatial(both_targets)
+            img = transformed_targets[0]
+            mask_cls = np.array(transformed_targets[1][0].detach(), dtype=int)
+            msk = torch.tensor(mask_cls)
+        return img, msk
+
+    def handle_reference_slices(self, img_path, mask_path, slice_num):
+        volume_name = ''.join(img_path.split('-')[:-1])
+        ref_slices, ref_msks = [], []
+        reference_slices = self.reference_slices.get(volume_name, [])
+        for ref_slice in reference_slices:
+            ref_img_path, ref_msk_path, _ = ref_slice
+            ref_img = Image.open(os.path.join(self.img_folder, ref_img_path)).convert('RGB')
+            ref_img = transforms.Resize((self.args.image_size, self.args.image_size))(ref_img)
+            ref_img = self.transform_img(ref_img)
+            ref_slices.append(torch.unsqueeze(ref_img, 0))
+
+            ref_msk = Image.open(os.path.join(self.mask_folder, ref_msk_path)).convert('L')
+            ref_msk = transforms.Resize((self.args.image_size, self.args.image_size), InterpolationMode.NEAREST)(ref_msk)
+            ref_msk = torch.tensor(np.array(ref_msk), dtype=torch.long)
+            ref_msks.append(torch.unsqueeze(ref_msk, 0))
+
+        img = torch.cat(ref_slices, dim=0)
+        msk = torch.cat(ref_msks, dim=0)
+        return img, msk
+
+    def remap_classes_sequentially(self, mask):
+        # Apply the precomputed old->new class mapping to the mask
+        remapped_mask = mask.copy()
+        for old_cls, new_cls in self.old_cls_to_new_cls.items():
+            remapped_mask[mask == old_cls] = new_cls
+        return remapped_mask
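+
+    # e.g. with old_cls_to_new_cls {1: 1, 2: 2, 3: 3, 4: 3}, every pixel labelled
+    # 4 is rewritten to 3, while ids absent from the mapping pass through
+    # unchanged via the initial copy.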
+
+    def prepare_output(self, img, msk, img_path, mask_path, attention_map):
+        # Normalize the image
+        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
+        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
+
+        msk = np.array(msk, dtype=int)
+        if self.label_frequency_path:
+            msk = self.remap_classes_sequentially(msk)  # msk now uses the remapped class ids
+
+        unique_classes = np.unique(msk).tolist()
+        if 0 in unique_classes:
+            unique_classes.remove(0)  # drop the background class
+
+        selected_dic = {k: self.label_dic[k] for k in unique_classes if k in self.label_dic}
+
+        if self.targets[0] == 'random':
+            mask_cls, selected_label, cls_one_hot = self.handle_random_target(msk, unique_classes, selected_dic)
+        elif self.targets[0] in self.label_name_list:
+            selected_label = self.targets[0]
+            mask_cls = np.array(msk == self.cls, dtype=int)
+            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)
+            cls_one_hot[self.cls - 1] = 1
+        else:
+            selected_label = self.targets[0]
+            mask_cls = msk
+            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)
+
+        # Handle prompts
+        if self.if_prompt:
+            prompt, mask_now, mask_cls = self.generate_prompt(mask_cls)
+            ref_msk, _ = torch.max(mask_now > 0, dim=0)
+            return_dict = {'image': img, 'mask': mask_now, 'selected_label_name': selected_label,
+                           'cls_one_hot': cls_one_hot, 'prompt': prompt, 'img_name': img_path,
+                           'mask_ori': msk, 'mask_cls': mask_cls, 'all_label_dic': selected_dic, 'ref_mask': ref_msk}
+        else:
+            if len(mask_cls.shape) == 2:
+                msk = torch.unsqueeze(torch.tensor(mask_cls, dtype=torch.long), 0)
+            elif len(mask_cls.shape) == 4:
+                msk = torch.squeeze(torch.tensor(mask_cls, dtype=torch.long))
+            else:
+                msk = torch.tensor(mask_cls, dtype=torch.long)
+            ref_msk, _ = torch.max(msk > 0, dim=0)
+            return_dict = {'image': img, 'mask': msk, 'selected_label_name': selected_label,
+                           'cls_one_hot': cls_one_hot, 'img_name': img_path, 'mask_ori': msk, 'ref_mask': ref_msk}
+
+        return return_dict
+
+    def generate_prompt(self, mask_cls):
+        if self.prompt_type == 'point':
+            prompt, mask_now = get_first_prompt(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
+        elif self.prompt_type == 'box':
+            prompt, mask_now = get_top_boxes(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
+        else:
+            prompt = mask_now = None
+
+        # Normalize the shape of mask_now for return
+        if mask_now is not None:
+            if len(mask_now.shape) == 2:
+                mask_now = torch.unsqueeze(torch.tensor(mask_now, dtype=torch.long), 0)
+                mask_cls = torch.unsqueeze(torch.tensor(mask_cls, dtype=torch.long), 0)
+            elif len(mask_now.shape) == 4:
+                mask_now = torch.squeeze(torch.tensor(mask_now, dtype=torch.long))
+            else:
+                mask_now = torch.tensor(mask_now, dtype=torch.long)
+                mask_cls = torch.tensor(mask_cls, dtype=torch.long)
+
+        return prompt, mask_now, mask_cls
+
+    def handle_random_target(self, msk, unique_classes, selected_dic):
+        if len(unique_classes) > 0:
+            random_selected_cls = random.choice(unique_classes)
+            selected_label = selected_dic[random_selected_cls]
+            mask_cls = np.array(msk == random_selected_cls, dtype=int)
+
+            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)
+            cls_one_hot[random_selected_cls - 1] = 1
+        else:
+            selected_label = None
+            mask_cls = np.zeros_like(msk)  # msk is a numpy array here, so use numpy
+            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)
+
+        return mask_cls, selected_label, cls_one_hot
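+
+
+# Minimal usage sketch for MRI_dataset_multicls (hypothetical paths; assumes a
+# label-mapping pickle of (name, label) pairs and an args namespace exposing
+# `image_size`):
+#
+#   from types import SimpleNamespace
+#   from torch.utils.data import DataLoader
+#   args = SimpleNamespace(image_size=1024)
+#   dataset = MRI_dataset_multicls(args, 'data/2D-slices/images',
+#                                  'data/2D-slices/masks', 'data/train_list.txt',
+#                                  phase='train', targets=['combine_all'],
+#                                  if_prompt=False, label_mapping='labels.pkl')
+#   loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
+#   batch = next(iter(loader))  # keys include 'image', 'mask', 'img_name'
\ No newline at end of file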