# dataset_bone.py
import os
import json  # needed by load_and_remap_classes_based_on_frequency (missing in the original)
import pickle
import random

import cv2
import einops
import monai
import nibabel as nib
import nrrd
import numpy as np
import slicerio
import torch
import torchio as tio
from monai.transforms import OneOf
from PIL import Image
from scipy.ndimage import zoom
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from funcs import *
# from .utils.transforms import ResizeLongestSide


class MRI_dataset(Dataset):
    def __init__(self, args, img_folder, mask_folder, img_list, phase='train', sample_num=50,
                 channel_num=1, crop=False, crop_size=1024, targets=['femur', 'hip'],
                 part_list=['all'], cls=1, if_prompt=True, prompt_type='point',
                 region_type='largest_15', prompt_num=15, delete_empty_masks=False,
                 if_attention_map=None):
        super(MRI_dataset, self).__init__()
        self.img_folder = img_folder
        self.mask_folder = mask_folder
        self.crop = crop
        self.crop_size = crop_size
        self.phase = phase
        self.channel_num = channel_num
        self.targets = targets
        self.segment_names_to_labels = []
        self.args = args
        self.cls = cls
        self.if_prompt = if_prompt
        self.region_type = region_type
        self.prompt_type = prompt_type
        self.prompt_num = prompt_num
        self.if_attention_map = if_attention_map

        for i, tag in enumerate(targets):
            self.segment_names_to_labels.append((tag, i))

        with open(img_list, 'r') as namefiles:
            self.data_list = namefiles.read().split('\n')[:-1]

        if delete_empty_masks == 'delete' or delete_empty_masks == 'subsample':
            keep_idx = []
            for idx, data in enumerate(self.data_list):
                mask_path = data.split(' ')[1]
                if os.path.exists(os.path.join(self.mask_folder, mask_path)):
                    msk = Image.open(os.path.join(self.mask_folder, mask_path)).convert('L')
                else:
                    msk = Image.open(os.path.join(self.mask_folder.replace('2D-slices', '2D-slices-generated'), mask_path)).convert('L')
                if 'all' in self.targets:  # combine all targets as a single target
                    mask_cls = np.array(np.array(msk, dtype=int) > 0, dtype=int)
                else:
                    # compare on the numpy array, not the PIL image, so the test is elementwise
                    mask_cls = np.array(np.array(msk, dtype=int) == self.cls, dtype=int)
                if part_list[0] == 'all' and np.sum(mask_cls) > 0:
                    keep_idx.append(idx)
                elif np.sum(mask_cls) > 0:
                    if_keep = False
                    for part in part_list:
                        if mask_path.find(part) >= 0:
                            if_keep = True
                    if if_keep:
                        keep_idx.append(idx)
            print('num with non-empty masks', len(keep_idx), 'num with all masks', len(self.data_list))
            if delete_empty_masks == 'subsample':
                # keep every non-empty slice plus a random 10% of the empty ones
                empty_idx = list(set(range(len(self.data_list))) - set(keep_idx))
                keep_empty_idx = random.sample(empty_idx, int(len(empty_idx) * 0.1))
                keep_idx = keep_empty_idx + keep_idx
            self.data_list = [self.data_list[i] for i in keep_idx]  # keep the slices that contain the target mask

        if phase == 'train':
            self.aug_img = [transforms.RandomEqualize(p=0.1),
                            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
                            transforms.RandomAdjustSharpness(0.5, p=0.5),
                            ]
            self.transform_spatial = transforms.Compose([transforms.RandomResizedCrop(crop_size, scale=(0.8, 1.2)),
                                                         transforms.RandomRotation(45)])
            transform_img = [transforms.ToTensor()]
        else:
            transform_img = [
                transforms.ToTensor(),
            ]
        self.transform_img = transforms.Compose(transform_img)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        # load the image and the mask
        data = self.data_list[index]
        img_path = data.split(' ')[0]
        mask_path = data.split(' ')[1]
        slice_num = data.split(' ')[3]  # total slice count for this volume
        try:
            if os.path.exists(os.path.join(self.img_folder, img_path)):
                img = Image.open(os.path.join(self.img_folder, img_path)).convert('RGB')
            else:
                img = Image.open(os.path.join(self.img_folder.replace('2D-slices', '2D-slices-generated'), img_path)).convert('RGB')
        except Exception:
            # fall back to loading the image as a numpy file
            img_arr = np.load(os.path.join(self.img_folder, img_path))
            img_arr = np.array((img_arr - img_arr.min()) / (img_arr.max() - img_arr.min() + 1e-8) * 255, dtype=np.uint8)
            img_3c = np.tile(img_arr[:, :, None], [1, 1, 3])
            img = Image.fromarray(img_3c, 'RGB')
        if os.path.exists(os.path.join(self.mask_folder, mask_path)):
            msk = Image.open(os.path.join(self.mask_folder, mask_path)).convert('L')
        else:
            msk = Image.open(os.path.join(self.mask_folder.replace('2D-slices', '2D-slices-generated'), mask_path)).convert('L')

        if self.if_attention_map:
            # pick the precomputed attention map for the quarter of the volume this slice falls in
            slice_id = int(img_path.split('-')[-1].split('.')[0])
            slice_fraction = int(slice_id / int(slice_num) * 4)
            img_id = '/'.join(img_path.split('-')[:-1]) + '_' + str(slice_fraction) + '.npy'
            attention_map = torch.tensor(np.load(os.path.join(self.if_attention_map, img_id)))
        else:
            attention_map = torch.zeros((64, 64))

        img = transforms.Resize((self.args.image_size, self.args.image_size))(img)
        msk = transforms.Resize((self.args.image_size, self.args.image_size), InterpolationMode.NEAREST)(msk)

        state = torch.get_rng_state()
        if self.crop:
            # pad up to crop_size if needed, then apply the same random crop to image and mask
            im_w, im_h = img.size
            diff_w = max(0, self.crop_size - im_w)
            diff_h = max(0, self.crop_size - im_h)
            padding = (diff_w // 2, diff_h // 2, diff_w - diff_w // 2, diff_h - diff_h // 2)
            img = transforms.functional.pad(img, padding, 0, 'constant')
            torch.set_rng_state(state)
            t, l, h, w = transforms.RandomCrop.get_params(img, (self.crop_size, self.crop_size))
            img = transforms.functional.crop(img, t, l, h, w)
            msk = transforms.functional.pad(msk, padding, 0, 'constant')
            msk = transforms.functional.crop(msk, t, l, h, w)
        if self.phase == 'train':
            # apply one randomly chosen photometric augmentation
            aug_img_fuc = transforms.RandomChoice(self.aug_img)
            img = aug_img_fuc(img)

        img = self.transform_img(img)
        if self.phase == 'train':
            # OneOf randomly applies exactly one of these MRI-style corruptions
            random_transform = OneOf([monai.transforms.RandGaussianNoise(prob=0.5, mean=0.0, std=0.1),
                                      monai.transforms.RandKSpaceSpikeNoise(prob=0.5, intensity_range=None, channel_wise=True),
                                      monai.transforms.RandBiasField(degree=3),
                                      monai.transforms.RandGibbsNoise(prob=0.5, alpha=(0.0, 1.0))
                                      ], weights=[0.3, 0.3, 0.2, 0.2])
            img = random_transform(img).as_tensor()
        else:
            if img.mean() < 0.05:
                img = min_max_normalize(img)
                img = monai.transforms.AdjustContrast(gamma=0.8)(img)

        if 'all' in self.targets:  # combine all targets as a single target
            msk = np.array(np.array(msk, dtype=int) > 0, dtype=int)
        else:
            msk = np.array(msk, dtype=int)

        mask_cls = np.array(msk == self.cls, dtype=int)

        if self.phase == 'train' and self.if_attention_map is not None:
            # stack image and mask so the random spatial transform is applied identically to both;
            # the mask is cast to the image dtype because torch.cat requires matching dtypes
            mask_cls = np.repeat(mask_cls[np.newaxis, :, :], 3, axis=0)
            both_targets = torch.cat((img.unsqueeze(0), torch.tensor(mask_cls, dtype=img.dtype).unsqueeze(0)), 0)
            transformed_targets = self.transform_spatial(both_targets)
            img = transformed_targets[0]
            mask_cls = np.array(transformed_targets[1][0].detach(), dtype=int)

        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

        # generate the mask and, optionally, a prompt
        if self.if_prompt:
            if self.prompt_type == 'point':
                prompt, mask_now = get_first_prompt(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
                pc = torch.as_tensor(prompt[:, :2], dtype=torch.float)
                pl = torch.as_tensor(prompt[:, -1], dtype=torch.float)
                msk = torch.unsqueeze(torch.tensor(mask_now, dtype=torch.long), 0)
                return {'image': img,
                        'mask': msk,
                        'point_coords': pc,
                        'point_labels': pl,
                        'img_name': img_path,
                        'atten_map': attention_map,
                        }
            elif self.prompt_type == 'box':
                prompt, mask_now = get_top_boxes(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
                box = torch.as_tensor(prompt, dtype=torch.float)
                msk = torch.unsqueeze(torch.tensor(mask_now, dtype=torch.long), 0)
                return {'image': img,
                        'mask': msk,
                        'boxes': box,
                        'img_name': img_path,
                        'atten_map': attention_map,
                        }
        else:
            msk = torch.unsqueeze(torch.tensor(mask_cls, dtype=torch.long), 0)
            return {'image': img,
                    'mask': msk,
                    'img_name': img_path,
                    'atten_map': attention_map,
                    }
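
# A minimal usage sketch for MRI_dataset (illustrative only, not part of the original
# file): `args` can be any object exposing `image_size`; the folder paths and list
# file below are placeholders.
#
#   from types import SimpleNamespace
#   from torch.utils.data import DataLoader
#
#   args = SimpleNamespace(image_size=1024)
#   train_set = MRI_dataset(args, 'data/2D-slices/images', 'data/2D-slices/masks',
#                           'data/train_list.txt', phase='train',
#                           if_prompt=True, prompt_type='point')
#   loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)
#   batch = next(iter(loader))  # dict with 'image', 'mask', 'point_coords', 'point_labels', ...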

class MRI_dataset_multicls(Dataset):
    def __init__(self, args, img_folder, mask_folder, img_list, phase='train', sample_num=50, channel_num=1,
                 crop=False, crop_size=1024, targets=['combine_all'], part_list=['all'], if_prompt=True,
                 prompt_type='point', if_spatial=True, region_type='largest_20', prompt_num=20, delete_empty_masks=False,
                 label_mapping=None, reference_slice_num=0, if_attention_map=None, label_frequency_path=None):
        super(MRI_dataset_multicls, self).__init__()
        self.initialize_parameters(args, img_folder, mask_folder, img_list, phase, sample_num, channel_num,
                                   crop, crop_size, targets, part_list, if_prompt, prompt_type, if_spatial, region_type,
                                   prompt_num, delete_empty_masks, label_mapping, reference_slice_num, if_attention_map, label_frequency_path)
        self.load_label_mapping()
        self.prepare_data_list()
        self.filter_data_list()
        if phase == 'train':
            self.setup_transformations_train(crop_size)
        else:
            self.setup_transformations_other()

    def initialize_parameters(self, args, img_folder, mask_folder, img_list, phase, sample_num, channel_num,
                              crop, crop_size, targets, part_list, if_prompt, prompt_type, if_spatial, region_type,
                              prompt_num, delete_empty_masks, label_mapping, reference_slice_num, if_attention_map, label_frequency_path):
        self.args = args
        self.img_folder = img_folder
        self.mask_folder = mask_folder
        self.img_list = img_list
        self.phase = phase
        self.sample_num = sample_num
        self.channel_num = channel_num
        self.crop = crop
        self.crop_size = crop_size
        self.targets = targets
        self.part_list = part_list
        self.if_prompt = if_prompt
        self.prompt_type = prompt_type
        self.if_spatial = if_spatial
        self.region_type = region_type
        self.prompt_num = prompt_num
        self.delete_empty_masks = delete_empty_masks
        self.label_mapping = label_mapping
        self.reference_slice_num = reference_slice_num
        self.if_attention_map = if_attention_map
        self.label_dic = {}
        self.label_frequency_path = label_frequency_path
        # volume name -> [(img_path, mask_path, slice_num), ...]; filled by
        # add_reference_slice when reference_slice_num > 1 (never initialized in the original)
        self.reference_slices = {}

    def load_label_mapping(self):
        # Load the basic label mappings from a pickle file
        if self.label_mapping:
            with open(self.label_mapping, 'rb') as handle:
                self.segment_names_to_labels = pickle.load(handle)
            self.label_dic = {seg[1]: seg[0] for seg in self.segment_names_to_labels}
            self.label_name_list = [seg[0] for seg in self.segment_names_to_labels]
            print(self.label_dic)
            # NOTE: the original code read self.cls without ever assigning it; here it
            # is assumed to be the class id of the first named target, when one exists
            name_to_cls = {name: cls for cls, name in self.label_dic.items()}
            if self.targets and self.targets[0] in name_to_cls:
                self.cls = name_to_cls[self.targets[0]]
        else:
            self.label_dic = {value: 'all' for value in range(1, 256)}
            self.label_name_list = []  # no named labels without a mapping file

        # Load frequency data and remap classes if required
        if 'remap_frequency' in self.targets:
            self.load_and_remap_classes_based_on_frequency()

    def load_and_remap_classes_based_on_frequency(self):
        if self.label_frequency_path:
            with open(self.label_frequency_path, 'r') as file:
                all_label_frequencies = json.load(file)
            all_label_frequencies = all_label_frequencies['train']

            # Select the target region dynamically based on the configured part list
            target_region = self.part_list[0]
            if target_region in all_label_frequencies:
                label_frequencies = all_label_frequencies[target_region]
                self.label_frequencies = label_frequencies
                self.remap_classes_based_on_frequency(label_frequencies)
            else:
                print(f"Warning: No frequency data found for the target region '{target_region}'. No remapping applied.")

    def remap_classes_based_on_frequency(self, label_frequencies):
        # Determine the frequency threshold separating high- from low-frequency classes
        max_freq = max(label_frequencies.values())
        high_freq_threshold = max_freq * 0.5  # adjust this threshold as needed

        # Assign classes to high or low frequency based on the threshold
        high_freq_classes = {}
        low_freq_classes = {}
        for label, freq in label_frequencies.items():
            if freq >= high_freq_threshold:
                high_freq_classes[label] = freq
            else:
                low_freq_classes[label] = freq

        # Update the label dictionary based on the frequency classification
        # self.label_dic: {old_cls: old_name}
        new_label_dic = {}
        for cls, name in self.label_dic.items():
            if name in high_freq_classes:
                new_label_dic[cls] = name  # retain the original name for high-frequency classes
            elif name in low_freq_classes:
                new_label_dic[cls] = 'combined_low_freq'  # merge low-frequency classes into one

        # new_label_dic: {old_cls: new_name}
        self.updated_label_dic = new_label_dic

        # Sort high-frequency labels by their frequency in descending order
        sorted_high_freq_labels = sorted(high_freq_classes.items(), key=lambda item: item[1], reverse=True)

        # Map high-frequency class names to new ids based on the sorted order
        original_to_new = {label: idx + 1 for idx, (label, _) in enumerate(sorted_high_freq_labels)}

        # All merged low-frequency classes share the next id
        combined_low_freq_class_id = len(original_to_new) + 1
        if 'combined_low_freq' in new_label_dic.values():
            for cls in low_freq_classes.keys():
                original_to_new[cls] = combined_low_freq_class_id

        # original_to_new: {old_name: new_cls}
        self.old_name_to_new_name = {self.label_dic[cls]: new_label for cls, new_label in new_label_dic.items()}
        self.old_cls_to_new_cls = {cls: original_to_new[self.label_dic[cls]] for cls in self.label_dic.keys() if self.label_dic[cls] in original_to_new}

        print('remapped label dic:', self.old_name_to_new_name)
        print('remapped cls dic:', self.old_cls_to_new_cls)
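
    # Worked example of the remapping above (illustrative numbers only): with
    # label_frequencies = {'femur': 100, 'hip': 80, 'tibia': 10}, the threshold is
    # 0.5 * 100 = 50, so 'femur' -> 1 and 'hip' -> 2 (sorted by descending
    # frequency), while 'tibia' falls below the threshold and is folded into the
    # shared 'combined_low_freq' class with id 3.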

    def prepare_data_list(self):
        with open(self.img_list, 'r') as namefiles:
            self.data_list = namefiles.read().split('\n')[:-1]
        self.sp_symbol = ',' if ',' in self.data_list[0] else ' '
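
    # Each line of img_list is assumed to follow roughly
    #   <img_path><sep><mask_path><sep><slice_idx><sep><total_slices>
    # with <sep> the comma or space detected above; the positional splits in
    # extract_paths, add_reference_slice and load_image_and_mask depend on this layout.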

    def filter_data_list(self):
        keep_idx = []
        for idx, data in enumerate(self.data_list):
            img_path, mask_path = self.extract_paths(data)
            msk = Image.open(os.path.join(self.mask_folder, mask_path)).convert('L')
            mask_cls = self.determine_mask_class(msk)

            if self.should_keep(mask_cls, mask_path):
                keep_idx.append(idx)
                if self.reference_slice_num > 1:
                    self.add_reference_slice(img_path, mask_path, data)

        # report the counts before the list is filtered down
        print('num with non-empty masks', len(keep_idx), 'num with all masks', len(self.data_list))
        self.data_list = [self.data_list[i] for i in keep_idx]

    def extract_paths(self, data):
        img_path = data.split(self.sp_symbol)[0]
        mask_path = data.split(self.sp_symbol)[1]
        return img_path.lstrip('/'), mask_path.lstrip('/')

    def determine_mask_class(self, msk):
        if 'combine_all' in self.targets:
            return np.array(msk, dtype=int) > 0
        elif self.targets[0] in self.label_name_list:
            return np.array(msk, dtype=int) == self.cls
        return np.array(msk, dtype=int)

    def should_keep(self, mask_cls, mask_path):
        if self.delete_empty_masks:
            has_mask = np.any(mask_cls > 0)
            if has_mask:
                if self.part_list[0] == 'all':
                    return True
                return any(mask_path.find(part) >= 0 for part in self.part_list)
            return False
        return True

    def add_reference_slice(self, img_path, mask_path, data):
        volume_name = ''.join(img_path.split('-')[:-1])  # get the volume name
        slice_num = data.split(self.sp_symbol)[2]
        if volume_name not in self.reference_slices:
            self.reference_slices[volume_name] = []
        self.reference_slices[volume_name].append((img_path, mask_path, slice_num))

    def setup_transformations_train(self, crop_size):
        self.transform_img = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.aug_img = transforms.RandomChoice([
            transforms.RandomEqualize(p=0.1),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
            transforms.RandomAdjustSharpness(0.5, p=0.5),
        ])
        if self.if_spatial:
            self.transform_spatial = transforms.Compose([
                transforms.RandomResizedCrop(self.crop_size, scale=(0.5, 1.5), interpolation=InterpolationMode.NEAREST),
                transforms.RandomRotation(45, interpolation=InterpolationMode.NEAREST)
            ])

    def setup_transformations_other(self):
        self.transform_img = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        # Load image and mask, handling missing files
        data = self.data_list[index]
        img, msk, img_path, mask_path, slice_num = self.load_image_and_mask(data)

        # Optional: load the attention map
        attention_map = self.load_attention_map(img_path, slice_num) if self.if_attention_map else torch.zeros((64, 64))

        # Handle reference slices if necessary
        if self.reference_slice_num > 1:
            img, msk = self.handle_reference_slices(img_path, mask_path, slice_num)

        # Apply transformations
        img, msk = self.apply_transformations(img, msk)

        # Generate and process masks and prompts
        output_dict = self.prepare_output(img, msk, img_path, mask_path, attention_map)

        return output_dict

    def load_image_and_mask(self, data):
        img_path, mask_path = self.extract_paths(data)
        slice_num = data.split(self.sp_symbol)[3]  # total slice count for this volume

        img_folder = self.img_folder
        msk_folder = self.mask_folder

        img = Image.open(os.path.join(img_folder, img_path)).convert('RGB')
        msk = Image.open(os.path.join(msk_folder, mask_path)).convert('L')

        # Resize images for processing
        img = transforms.Resize((self.args.image_size, self.args.image_size))(img)
        msk = transforms.Resize((self.args.image_size, self.args.image_size), InterpolationMode.NEAREST)(msk)

        return img, msk, img_path, mask_path, int(slice_num)

    def load_attention_map(self, img_path, slice_num):
        # pick the precomputed map for the quarter of the volume this slice falls in
        slice_id = int(img_path.split('-')[-1].split('.')[0])
        slice_fraction = int(slice_id / slice_num * 4)
        img_id = '/'.join(img_path.split('-')[:-1]) + '_' + str(slice_fraction) + '.npy'
        attention_map = torch.tensor(np.load(os.path.join(self.if_attention_map, img_id)))
        return attention_map

    def apply_crop(self, img, msk):
        # pad up to crop_size if needed, then apply the same random crop to image and mask
        im_w, im_h = img.size
        diff_w = max(0, self.crop_size - im_w)
        diff_h = max(0, self.crop_size - im_h)
        padding = (diff_w // 2, diff_h // 2, diff_w - diff_w // 2, diff_h - diff_h // 2)
        img = transforms.functional.pad(img, padding, 0, 'constant')
        msk = transforms.functional.pad(msk, padding, 0, 'constant')
        t, l, h, w = transforms.RandomCrop.get_params(img, (self.crop_size, self.crop_size))
        img = transforms.functional.crop(img, t, l, h, w)
        msk = transforms.functional.crop(msk, t, l, h, w)
        return img, msk

    def apply_transformations(self, img, msk):
        if self.crop:
            img, msk = self.apply_crop(img, msk)
        if self.phase == 'train':
            img = self.aug_img(img)
        img = self.transform_img(img)
        if self.phase == 'train' and self.if_spatial:
            # stack image and mask so the random spatial transform is applied identically to both;
            # the mask is cast to the image dtype because torch.cat requires matching dtypes
            mask_cls = np.array(msk, dtype=int)
            mask_cls = np.repeat(mask_cls[np.newaxis, :, :], 3, axis=0)
            both_targets = torch.cat((img.unsqueeze(0), torch.tensor(mask_cls, dtype=img.dtype).unsqueeze(0)), 0)
            transformed_targets = self.transform_spatial(both_targets)
            img = transformed_targets[0]
            mask_cls = np.array(transformed_targets[1][0].detach(), dtype=int)
            msk = torch.tensor(mask_cls)
        return img, msk

    def handle_reference_slices(self, img_path, mask_path, slice_num):
        volume_name = ''.join(img_path.split('-')[:-1])
        ref_slices, ref_msks = [], []
        reference_slices = self.reference_slices.get(volume_name, [])
        for ref_slice in reference_slices:
            ref_img_path, ref_msk_path, _ = ref_slice
            ref_img = Image.open(os.path.join(self.img_folder, ref_img_path)).convert('RGB')
            ref_img = transforms.Resize((self.args.image_size, self.args.image_size))(ref_img)
            ref_img = self.transform_img(ref_img)
            # the original built this list but never appended to it, so the cat below failed
            ref_slices.append(torch.unsqueeze(ref_img, 0))

            ref_msk = Image.open(os.path.join(self.mask_folder, ref_msk_path)).convert('L')
            ref_msk = transforms.Resize((self.args.image_size, self.args.image_size), InterpolationMode.NEAREST)(ref_msk)
            # convert the PIL image to a numpy array before building the tensor
            ref_msk = torch.tensor(np.array(ref_msk), dtype=torch.long)
            ref_msks.append(torch.unsqueeze(ref_msk, 0))

        img = torch.cat(ref_slices, dim=0)
        msk = torch.cat(ref_msks, dim=0)
        return img, msk

    def remap_classes_sequentially(self, mask, label_frequencies):
        # Apply the frequency-based mapping to the mask
        remapped_mask = mask.copy()
        for old_cls, new_cls in self.old_cls_to_new_cls.items():
            remapped_mask[mask == old_cls] = new_cls
        return remapped_mask

    def prepare_output(self, img, msk, img_path, mask_path, attention_map):
        # Normalize the image
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

        msk = np.array(msk, dtype=int)
        if self.label_frequency_path:
            # remap the original class ids to the frequency-based ids
            msk = self.remap_classes_sequentially(msk, self.label_frequencies)

        # Collect the classes present in this slice (excluding background)
        unique_classes = np.unique(msk).tolist()
        if 0 in unique_classes:
            unique_classes.remove(0)

        if len(unique_classes) > 0:
            selected_dic = {k: self.label_dic[k] for k in unique_classes if k in self.label_dic}
        else:
            selected_dic = {}

        if self.targets[0] == 'random':
            mask_cls, selected_label, cls_one_hot = self.handle_random_target(msk, unique_classes, selected_dic)
        elif self.targets[0] in self.label_name_list:
            selected_label = self.targets[0]
            mask_cls = np.array(msk == self.cls, dtype=int)
            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)
            cls_one_hot[self.cls - 1] = 1
        else:
            selected_label = self.targets[0]
            mask_cls = msk
            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)

        # Handle prompts
        if self.if_prompt:
            prompt, mask_now, mask_cls = self.generate_prompt(mask_cls)
            ref_msk, _ = torch.max(mask_now > 0, dim=0)
            return_dict = {'image': img, 'mask': mask_now, 'selected_label_name': selected_label,
                           'cls_one_hot': cls_one_hot, 'prompt': prompt, 'img_name': img_path,
                           'mask_ori': msk, 'mask_cls': mask_cls, 'all_label_dic': selected_dic, 'ref_mask': ref_msk}
        else:
            if len(mask_cls.shape) == 2:
                msk = torch.unsqueeze(torch.tensor(mask_cls, dtype=torch.long), 0)
            elif len(mask_cls.shape) == 4:
                msk = torch.squeeze(torch.tensor(mask_cls, dtype=torch.long))
            else:
                msk = torch.tensor(mask_cls, dtype=torch.long)
            ref_msk, _ = torch.max(msk > 0, dim=0)
            return_dict = {'image': img, 'mask': msk, 'selected_label_name': selected_label,
                           'cls_one_hot': cls_one_hot, 'img_name': img_path, 'mask_ori': msk, 'ref_mask': ref_msk}

        return return_dict

    def generate_prompt(self, mask_cls):
        if self.prompt_type == 'point':
            prompt, mask_now = get_first_prompt(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
        elif self.prompt_type == 'box':
            prompt, mask_now = get_top_boxes(mask_cls, region_type=self.region_type, prompt_num=self.prompt_num)
        else:
            prompt = mask_now = None

        # Normalize the shape of mask_now for return
        if mask_now is not None:
            if len(mask_now.shape) == 2:
                mask_now = torch.unsqueeze(torch.tensor(mask_now, dtype=torch.long), 0)
                mask_cls = torch.unsqueeze(torch.tensor(mask_cls, dtype=torch.long), 0)
            elif len(mask_now.shape) == 4:
                mask_now = torch.squeeze(torch.tensor(mask_now, dtype=torch.long))
            else:
                mask_now = torch.tensor(mask_now, dtype=torch.long)
                mask_cls = torch.tensor(mask_cls, dtype=torch.long)

        return prompt, mask_now, mask_cls

    def handle_random_target(self, msk, unique_classes, selected_dic):
        if len(unique_classes) > 0:
            # pick one of the classes present in this slice at random
            random_selected_cls = random.choice(unique_classes)
            selected_label = selected_dic[random_selected_cls]
            mask_cls = np.array(msk == random_selected_cls, dtype=int)

            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)
            cls_one_hot[random_selected_cls - 1] = 1
        else:
            selected_label = None
            mask_cls = np.zeros_like(msk)  # msk is a numpy array here, so use numpy, not torch
            cls_one_hot = torch.zeros(len(self.label_dic), dtype=torch.long)

        return mask_cls, selected_label, cls_one_hot
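

# A minimal smoke test, kept under a __main__ guard so importing the module has no
# side effects. The paths and image_size are placeholders, and `funcs` is assumed to
# provide get_first_prompt / get_top_boxes / min_max_normalize as used above.
if __name__ == '__main__':
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    args = SimpleNamespace(image_size=1024)
    dataset = MRI_dataset_multicls(args, 'data/2D-slices/images', 'data/2D-slices/masks',
                                   'data/val_list.txt', phase='val',
                                   targets=['combine_all'], if_prompt=False)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    sample = next(iter(loader))
    print(sample['image'].shape, sample['mask'].shape, sample['selected_label_name'])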