b/utils.py

"""Helper functions.

Author: junde
"""

import sys

import numpy

import torch
import torch.nn as nn
from torch.autograd import Function, Variable
from torch.optim.lr_scheduler import _LRScheduler
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torch import autograd
import random
import math
import PIL
import matplotlib.pyplot as plt
import seaborn as sns

import collections
import logging
import cv2
import os
import time
from datetime import datetime

import dateutil.tz

from typing import Union, Optional, List, Tuple, Text, BinaryIO
import pathlib
import warnings
import numpy as np
from scipy.ndimage import label, find_objects
from PIL import Image, ImageDraw, ImageFont, ImageColor
# from lucent.optvis.param.spatial import pixel_image, fft_image, init_image
# from lucent.optvis.param.color import to_valid_rgb
# from lucent.optvis import objectives, transform, param
# from lucent.misc.io import show
from torchvision.models import vgg19
import torch.nn.functional as F
import cfg

from collections import OrderedDict
from tqdm import tqdm

# from precpt import run_precpt
from models.discriminator import Discriminator
# from siren_pytorch import SirenNet, SirenWrapper

import shutil
import tempfile

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
)
from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR
from monai.data import (
    ThreadDataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
)
args = cfg.parse_args()
device = torch.device('cuda', args.gpu_device)

'''preparation of domain loss'''
# cnn = vgg19(pretrained=True).features.to(device).eval()
# cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
# cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

# netD = Discriminator(1).to(device)
# netD.apply(init_D)
# beta1 = 0.5
# dis_lr = 0.0002
# optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999))
'''end'''


def get_network(args, net, use_gpu=True, gpu_device=0, distribution=True):
    """Return the requested network, moved to the right device."""

    if net == 'sam':
        from models.sam import SamPredictor, sam_model_registry
        from models.sam.utils.transforms import ResizeLongestSide

        net = sam_model_registry['vit_b'](args, checkpoint=args.sam_ckpt).to(device)
    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if use_gpu:
        # net = net.cuda(device=gpu_device)
        if distribution != 'none':
            net = torch.nn.DataParallel(net, device_ids=[int(id) for id in args.distributed.split(',')])
            net = net.to(device=gpu_device)
        else:
            net = net.to(device=gpu_device)

    return net
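
# Usage sketch (illustrative, not part of the original file; assumes `args`
# comes from cfg.parse_args() with `gpu_device` and `distributed` set):
#     net = get_network(args, 'sam', use_gpu=True,
#                       gpu_device=args.gpu_device, distribution=args.distributed)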


def get_decath_loader(args):

    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-175,
                a_max=250,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(args.roi_size, args.roi_size, args.chunk),
                pos=1,
                neg=1,
                num_samples=args.num_sample,
                image_key="image",
                image_threshold=0,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[0],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[1],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[2],
                prob=0.10,
            ),
            RandRotate90d(
                keys=["image", "label"],
                prob=0.10,
                max_k=3,
            ),
            RandShiftIntensityd(
                keys=["image"],
                offsets=0.10,
                prob=0.50,
            ),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            ScaleIntensityRanged(
                keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
        ]
    )

    data_dir = args.data_path
    split_JSON = "dataset_0.json"

    datasets = os.path.join(data_dir, split_JSON)
    datalist = load_decathlon_datalist(datasets, True, "training")
    val_files = load_decathlon_datalist(datasets, True, "validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=8,
    )
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.b, shuffle=True)
    val_ds = CacheDataset(
        data=val_files, transform=val_transforms, cache_num=2, cache_rate=1.0, num_workers=0
    )
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    set_track_meta(False)

    return train_loader, val_loader, train_transforms, val_transforms, datalist, val_files


def cka_loss(gram_featureA, gram_featureB):
    # Linear CKA similarity between two (flattened) Gram matrices:
    # <vec(A), vec(B)> / (||A||_F * ||B||_F).
    scaled_hsic = torch.dot(torch.flatten(gram_featureA), torch.flatten(gram_featureB))
    normalization_x = gram_featureA.norm()
    normalization_y = gram_featureB.norm()
    return scaled_hsic / (normalization_x * normalization_y)
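
# Usage sketch (illustrative): compare two feature maps via their Gram matrices.
#     gA = gram_matrix(featA)   # featA, featB: [1, C, H, W] activations
#     gB = gram_matrix(featB)
#     similarity = cka_loss(gA, gB)   # scalar; 1 means identical structure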


class WarmUpLR(_LRScheduler):
    """Warmup learning-rate scheduler.

    Args:
        optimizer: wrapped optimizer (e.g. SGD)
        total_iters: total number of iterations of the warmup phase
    """
    def __init__(self, optimizer, total_iters, last_epoch=-1):

        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """We will use the first m batches, and set the learning
        rate to base_lr * m / total_iters.
        """
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
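
# Usage sketch (illustrative; `optimizer`, `train_loader`, and `warm_epochs`
# are assumed to exist in the training script):
#     warmup_scheduler = WarmUpLR(optimizer, total_iters=len(train_loader) * warm_epochs)
#     for images, labels in train_loader:
#         ...  # forward/backward/step
#         warmup_scheduler.step()   # linearly ramps lr from 0 toward base_lr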


def gram_matrix(input):
    a, b, c, d = input.size()  # a = batch size (=1)
    # b = number of feature maps
    # (c, d) = dimensions of a feature map (N = c*d)

    features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL

    G = torch.mm(features, features.t())  # compute the gram product

    # we 'normalize' the values of the gram matrix
    # by dividing by the number of elements in each feature map.
    return G.div(a * b * c * d)
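
# Usage sketch (illustrative): a Gram matrix of VGG features is the usual
# ingredient of a style loss; the layer cutoff here is arbitrary.
#     feat = vgg19(pretrained=True).features[:10](img)   # img: [1, 3, H, W]
#     G = gram_matrix(feat)                              # [C, C]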


@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    **kwargs
) -> torch.Tensor:
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if "range" in kwargs.keys():
        warning = "range will be deprecated, please use value_range instead."
        warnings.warn(warning)
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None:
            assert isinstance(value_range, tuple), \
                "value_range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            img.clamp_(min=low, max=high)  # clamp in place, then shift/scale into [0, 1]
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid
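
# Usage sketch (illustrative): tile a mini-batch of images into one tensor.
#     batch = torch.rand(16, 3, 64, 64)
#     grid = make_grid(batch, nrow=4, normalize=True)   # [3, H', W'] in [0, 1]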


@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs
) -> None:
    """
    Save a given Tensor into an image file.
    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object
        format(Optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """

    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)


def create_logger(log_dir, phase='train'):
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}.log'.format(time_str, phase)
    final_log_file = os.path.join(log_dir, log_file)
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    return logger


def set_log_dir(root_dir, exp_name):
    path_dict = {}
    os.makedirs(root_dir, exist_ok=True)

    # set log path
    exp_path = os.path.join(root_dir, exp_name)
    now = datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    prefix = exp_path + '_' + timestamp
    os.makedirs(prefix)
    path_dict['prefix'] = prefix

    # set checkpoint path
    ckpt_path = os.path.join(prefix, 'Model')
    os.makedirs(ckpt_path)
    path_dict['ckpt_path'] = ckpt_path

    log_path = os.path.join(prefix, 'Log')
    os.makedirs(log_path)
    path_dict['log_path'] = log_path

    # set sample image path for fid calculation
    sample_path = os.path.join(prefix, 'Samples')
    os.makedirs(sample_path)
    path_dict['sample_path'] = sample_path

    return path_dict


def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth'):
    torch.save(states, os.path.join(output_dir, filename))
    if is_best:
        torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))
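
# Usage sketch (illustrative; the keys of `states` are a convention assumed
# here, not fixed by this function):
#     save_checkpoint({'epoch': epoch,
#                      'state_dict': net.state_dict(),
#                      'optimizer': optimizer.state_dict()},
#                     is_best=val_dice > best_dice,
#                     output_dir=args.path_helper['ckpt_path'])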


class RunningStats:
    """Running mean/variance over a sliding window of the last WIN_SIZE values."""
    def __init__(self, WIN_SIZE):
        self.mean = 0
        self.run_var = 0
        self.WIN_SIZE = WIN_SIZE

        self.window = collections.deque(maxlen=WIN_SIZE)

    def clear(self):
        self.window.clear()
        self.mean = 0
        self.run_var = 0

    def is_full(self):
        return len(self.window) == self.WIN_SIZE

    def push(self, x):

        if len(self.window) == self.WIN_SIZE:
            # Adjusting variance
            x_removed = self.window.popleft()
            self.window.append(x)
            old_m = self.mean
            self.mean += (x - x_removed) / self.WIN_SIZE
            self.run_var += (x + x_removed - old_m - self.mean) * (x - x_removed)
        else:
            # Calculating first variance
            self.window.append(x)
            delta = x - self.mean
            self.mean += delta / len(self.window)
            self.run_var += delta * (x - self.mean)

    def get_mean(self):
        return self.mean if len(self.window) else 0.0

    def get_var(self):
        return self.run_var / len(self.window) if len(self.window) > 1 else 0.0

    def get_std(self):
        return math.sqrt(self.get_var())

    def get_all(self):
        return list(self.window)

    def __str__(self):
        return "Current window values: {}".format(list(self.window))


def iou(outputs: np.ndarray, labels: np.ndarray):
    """IoU for binary masks of shape [b, h, w] (integer or boolean arrays)."""
    SMOOTH = 1e-6
    intersection = (outputs & labels).sum((1, 2))
    union = (outputs | labels).sum((1, 2))

    iou = (intersection + SMOOTH) / (union + SMOOTH)

    return iou.mean()
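
# Usage sketch (illustrative):
#     pred = (prob_maps > 0.5).astype('int32')   # [b, h, w]
#     gt = gt_masks.astype('int32')              # [b, h, w]
#     score = iou(pred, gt)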


class DiceCoeff(Function):
    """Dice coeff for individual examples"""
    # NOTE: this uses the legacy autograd.Function style (instance methods and
    # save_for_backward on self); newer PyTorch expects static forward/backward.

    def forward(self, input, target):
        self.save_for_backward(input, target)
        eps = 0.0001
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):

        input, target = self.saved_variables
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                         / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


def dice_coeff(input, target):
    """Dice coeff for batches"""
    if input.is_cuda:
        s = torch.FloatTensor(1).to(device=input.device).zero_()
    else:
        s = torch.FloatTensor(1).zero_()

    for i, c in enumerate(zip(input, target)):
        s = s + DiceCoeff().forward(c[0], c[1])

    return s / (i + 1)


'''parameter'''
def para_image(w, h=None, img=None, mode='multi', seg=None, sd=None, batch=None,
               fft=False, channels=None, init=None):
    # NOTE: pixel_image / fft_image / init_image come from the lucent imports
    # that are commented out at the top of this file.
    h = h or w
    batch = batch or 1
    ch = channels or 3
    shape = [batch, ch, h, w]
    param_f = fft_image if fft else pixel_image
    if init is not None:
        param_f = init_image
        params, maps_f = param_f(init)
    else:
        params, maps_f = param_f(shape, sd=sd)
    if mode == 'multi':
        output = to_valid_out(maps_f, img, seg)
    elif mode == 'seg':
        output = gene_out(maps_f, img)
    elif mode == 'raw':
        output = raw_out(maps_f, img)
    return params, output

def to_valid_out(maps_f, img, seg):  # multi-rater
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        maps = torch.nn.Softmax(dim=1)(maps)
        final_seg = torch.multiply(seg, maps).sum(dim=1, keepdim=True)
        return torch.cat((img, final_seg), 1)
        # return torch.cat((img, maps), 1)
    return inner

def gene_out(maps_f, img):  # pure seg
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        # maps = torch.nn.Sigmoid()(maps)
        return torch.cat((img, maps), 1)
    return inner

def raw_out(maps_f, img):  # raw
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        # maps = torch.nn.Sigmoid()(maps)
        return maps
        # return torch.cat((img, maps), 1)
    return inner


class CompositeActivation(torch.nn.Module):

    def forward(self, x):
        x = torch.atan(x)
        # concatenate atan(x) and atan(x)^2, each rescaled toward unit variance
        return torch.cat([x / 0.67, (x * x) / 0.6], 1)
        # return x


def cppn(args, size, img=None, seg=None, batch=None, num_output_channels=1, num_hidden_channels=128, num_layers=8,
         activation_fn=CompositeActivation, normalize=False, device="cuda:0"):

    r = 3 ** 0.5

    coord_range = torch.linspace(-r, r, size)
    x = coord_range.view(-1, 1).repeat(1, coord_range.size(0))
    y = coord_range.view(1, -1).repeat(coord_range.size(0), 1)

    input_tensor = torch.stack([x, y], dim=0).unsqueeze(0).repeat(batch, 1, 1, 1).to(device)

    layers = []
    kernel_size = 1
    for i in range(num_layers):
        out_c = num_hidden_channels
        in_c = out_c * 2  # * 2 for composite activation
        if i == 0:
            in_c = 2
        if i == num_layers - 1:
            out_c = num_output_channels
        layers.append(('conv{}'.format(i), torch.nn.Conv2d(in_c, out_c, kernel_size)))
        if normalize:
            layers.append(('norm{}'.format(i), torch.nn.InstanceNorm2d(out_c)))
        if i < num_layers - 1:
            layers.append(('actv{}'.format(i), activation_fn()))
        else:
            layers.append(('output', torch.nn.Sigmoid()))

    # Initialize model
    net = torch.nn.Sequential(OrderedDict(layers)).to(device)

    # Initialize weights
    def weights_init(module):
        if isinstance(module, torch.nn.Conv2d):
            torch.nn.init.normal_(module.weight, 0, np.sqrt(1 / module.in_channels))
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
    net.apply(weights_init)

    # Set last conv2d layer's weights to 0
    torch.nn.init.zeros_(dict(net.named_children())['conv{}'.format(num_layers - 1)].weight)
    outimg = raw_out(lambda: net(input_tensor), img) if args.netype == 'raw' else to_valid_out(lambda: net(input_tensor), img, seg)
    return net.parameters(), outimg
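
# Usage sketch (illustrative; assumes `args.netype`, an input image `img`, and,
# unless netype == 'raw', a multi-rater segmentation `seg`):
#     params, image_f = cppn(args, size=256, img=img, seg=seg, batch=img.size(0))
#     out = image_f()   # calling image_f re-renders the CPPN output each step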

def get_siren(args):
    # NOTE: get_network() above currently only handles net == 'sam'; the
    # 'siren'/'vae' branches and the hard-coded checkpoint paths below belong
    # to earlier experiments.
    wrapper = get_network(args, 'siren', use_gpu=args.gpu, gpu_device=torch.device('cuda', args.gpu_device), distribution=args.distributed)
    '''load init weights'''
    checkpoint = torch.load('./logs/siren_train_init_2022_08_19_21_00_16/Model/checkpoint_best.pth')
    wrapper.load_state_dict(checkpoint['state_dict'], strict=False)
    '''end'''

    '''load prompt'''
    checkpoint = torch.load('./logs/vae_standard_refuge1_2022_08_21_17_56_49/Model/checkpoint500')
    vae = get_network(args, 'vae', use_gpu=args.gpu, gpu_device=torch.device('cuda', args.gpu_device), distribution=args.distributed)
    vae.load_state_dict(checkpoint['state_dict'], strict=False)
    '''end'''

    return wrapper, vae


def siren(args, wrapper, vae, img=None, seg=None, batch=None, num_output_channels=1, num_hidden_channels=128, num_layers=8,
          activation_fn=CompositeActivation, normalize=False, device="cuda:0"):
    vae_img = torchvision.transforms.Resize(64)(img)
    latent = vae.encoder(vae_img).view(-1).detach()
    outimg = raw_out(lambda: wrapper(latent=latent), img) if args.netype == 'raw' else to_valid_out(lambda: wrapper(latent=latent), img, seg)
    # img = torch.randn(1, 3, 256, 256)
    # loss = wrapper(img)
    # loss.backward()

    # # after much training ...
    # # simply invoke the wrapper without passing in anything

    # pred_img = wrapper()  # (1, 3, 256, 256)
    return wrapper.parameters(), outimg


'''adversary'''
def render_vis(
    args,
    model,
    objective_f,
    real_img,
    param_f=None,
    optimizer=None,
    transforms=None,
    thresholds=(256,),
    verbose=True,
    preprocess=True,
    progress=True,
    show_image=True,
    save_image=False,
    image_name=None,
    show_inline=False,
    fixed_image_size=None,
    label=1,
    raw_img=None,
    prompt=None
):
    if label == 1:
        sign = 1
    elif label == 0:
        sign = -1
    else:
        # previously this only printed a message and left `sign` undefined
        raise ValueError('label should be 0 or 1, got {}'.format(label))
    if args.reverse:
        sign = -sign
    if args.multilayer:
        sign = 1

    # NOTE: `param`, `transform`, and `objectives` below come from the lucent
    # imports commented out at the top of this file.

    '''prepare'''
    now = datetime.now()
    date_time = now.strftime("%m-%d-%Y, %H:%M:%S")

    netD, optD = pre_d()
    '''end'''

    if param_f is None:
        param_f = lambda: param.image(128)
    # param_f is a function that should return two things
    # params - parameters to update, which we pass to the optimizer
    # image_f - a function that returns an image as a tensor
    params, image_f = param_f()

    if optimizer is None:
        optimizer = lambda params: torch.optim.Adam(params, lr=5e-1)
    optimizer = optimizer(params)

    if transforms is None:
        transforms = []
    transforms = transforms.copy()

    # Upsample images smaller than 224
    image_shape = image_f().shape

    if fixed_image_size is not None:
        new_size = fixed_image_size
    elif image_shape[2] < 224 or image_shape[3] < 224:
        new_size = 224
    else:
        new_size = None
    if new_size:
        transforms.append(
            torch.nn.Upsample(size=new_size, mode="bilinear", align_corners=True)
        )

    transform_f = transform.compose(transforms)

    hook = hook_model(model, image_f)
    objective_f = objectives.as_objective(objective_f)

    if verbose:
        model(transform_f(image_f()))
        print("Initial loss of ad: {:.3f}".format(objective_f(hook)))

    images = []
    try:
        for i in tqdm(range(1, max(thresholds) + 1), disable=(not progress)):
            optimizer.zero_grad()
            try:
                model(transform_f(image_f()))
            except RuntimeError as ex:
                if i == 1:
                    # Only display the warning message
                    # on the first iteration, no need to do that
                    # every iteration
                    warnings.warn(
                        "Some layers could not be computed because the size of the "
                        "image is not big enough. It is fine, as long as the non"
                        "computed layers are not used in the objective function"
                        f"(exception details: '{ex}')"
                    )
            if args.disc:
                '''dom loss part'''
                # content_img = raw_img
                # style_img = raw_img
                # precpt_loss = run_precpt(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img, transform_f(image_f()))
                for p in netD.parameters():
                    p.requires_grad = True
                for _ in range(args.drec):
                    netD.zero_grad()
                    real = real_img
                    fake = image_f()
                    # for _ in range(6):
                    #     errD, D_x, D_G_z1 = update_d(args, netD, optD, real, fake)

                    # label = torch.full((args.b,), 1., dtype=torch.float, device=device)
                    # label.fill_(1.)
                    # output = netD(fake).view(-1)
                    # errG = nn.BCELoss()(output, label)
                    # D_G_z2 = output.mean().item()
                    # dom_loss = err
                    one = torch.tensor(1, dtype=torch.float)
                    mone = one * -1
                    one = one.cuda(args.gpu_device)
                    mone = mone.cuda(args.gpu_device)

                    d_loss_real = netD(real)
                    d_loss_real = d_loss_real.mean()
                    d_loss_real.backward(mone)

                    d_loss_fake = netD(fake)
                    d_loss_fake = d_loss_fake.mean()
                    d_loss_fake.backward(one)

                    # Train with gradient penalty
                    gradient_penalty = calculate_gradient_penalty(netD, real.data, fake.data)
                    gradient_penalty.backward()

                    d_loss = d_loss_fake - d_loss_real + gradient_penalty
                    Wasserstein_D = d_loss_real - d_loss_fake
                    optD.step()

                # Generator update
                for p in netD.parameters():
                    p.requires_grad = False  # to avoid computation

                fake_images = image_f()
                g_loss = netD(fake_images)
                g_loss = -g_loss.mean()
                dom_loss = g_loss
                g_cost = -g_loss

                if i % 5 == 0:
                    print(f'loss_fake: {d_loss_fake}, loss_real: {d_loss_real}')
                    print(f'Generator g_loss: {g_loss}')
                '''end'''

            '''ssim loss'''

            '''end'''

            if args.disc:
                loss = sign * objective_f(hook) + args.pw * dom_loss
                # loss = args.pw * dom_loss
            else:
                loss = sign * objective_f(hook)
                # loss = args.pw * dom_loss

            loss.backward()

            # # video the images
            # if i % 5 == 0:
            #     image_name = image_name[0].split('\\')[-1].split('.')[0] + '_' + str(i) + '.png'
            #     img_path = os.path.join(args.path_helper['sample_path'], str(image_name))
            #     export(image_f(), img_path)
            # # end
            # if i % 50 == 0:
            #     print('Loss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
            #           % (errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            optimizer.step()
            if i in thresholds:
                image = tensor_to_img_array(image_f())
                # if verbose:
                #     print("Loss at step {}: {:.3f}".format(i, objective_f(hook)))
                if save_image:
                    na = image_name[0].split('\\')[-1].split('.')[0] + '_' + str(i) + '.png'
                    na = date_time + na
                    outpath = args.quickcheck if args.quickcheck else args.path_helper['sample_path']
                    img_path = os.path.join(outpath, str(na))
                    export(image_f(), img_path)

                images.append(image)
    except KeyboardInterrupt:
        print("Interrupted optimization at step {:d}.".format(i))
        if verbose:
            print("Loss at step {}: {:.3f}".format(i, objective_f(hook)))
        images.append(tensor_to_img_array(image_f()))

    if save_image:
        na = image_name[0].split('\\')[-1].split('.')[0] + '.png'
        na = date_time + na
        outpath = args.quickcheck if args.quickcheck else args.path_helper['sample_path']
        img_path = os.path.join(outpath, str(na))
        export(image_f(), img_path)
    if show_inline:
        show(tensor_to_img_array(image_f()))
    elif show_image:
        view(image_f())
    return image_f()


def tensor_to_img_array(tensor):
    image = tensor.cpu().detach().numpy()
    image = np.transpose(image, [0, 2, 3, 1])
    return image


def view(tensor):
    image = tensor_to_img_array(tensor)
    assert len(image.shape) in [
        3,
        4,
    ], "Image should have 3 or 4 dimensions, invalid image shape {}".format(image.shape)
    # Change dtype for PIL.Image
    image = (image * 255).astype(np.uint8)
    if len(image.shape) == 4:
        image = np.concatenate(image, axis=1)
    Image.fromarray(image).show()


def export(tensor, img_path=None):
    # image_name = image_name or "image.jpg"
    c = tensor.size(1)
    # if c == 7:
    #     for i in range(c):
    #         w_map = tensor[:, i, :, :].unsqueeze(1)
    #         w_map = tensor_to_img_array(w_map).squeeze()
    #         w_map = (w_map * 255).astype(np.uint8)
    #         image_name = image_name[0].split('/')[-1].split('.')[0] + str(i) + '.png'
    #         wheat = sns.heatmap(w_map, cmap='coolwarm')
    #         figure = wheat.get_figure()
    #         figure.savefig('./fft_maps/weightheatmap/' + str(image_name), dpi=400)
    #         figure = 0
    # else:
    if c == 3:
        vutils.save_image(tensor, fp=img_path)
    else:
        # the last channel is treated as a weight map and saved as grayscale
        image = tensor[:, 0:3, :, :]
        w_map = tensor[:, -1, :, :].unsqueeze(1)
        image = tensor_to_img_array(image)
        w_map = 1 - tensor_to_img_array(w_map).squeeze()
        # w_map[w_map==1] = 0
        assert len(image.shape) in [
            3,
            4,
        ], "Image should have 3 or 4 dimensions, invalid image shape {}".format(image.shape)
        # Change dtype for PIL.Image
        image = (image * 255).astype(np.uint8)
        w_map = (w_map * 255).astype(np.uint8)

        Image.fromarray(w_map, 'L').save(img_path)


class ModuleHook:
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.module = None
        self.features = None

    def hook_fn(self, module, input, output):
        self.module = module
        self.features = output

    def close(self):
        self.hook.remove()


def hook_model(model, image_f):
    features = OrderedDict()

    # recursive hooking function
    def hook_layers(net, prefix=[]):
        if hasattr(net, "_modules"):
            for name, layer in net._modules.items():
                if layer is None:
                    # e.g. GoogLeNet's aux1 and aux2 layers
                    continue
                features["_".join(prefix + [name])] = ModuleHook(layer)
                hook_layers(layer, prefix=prefix + [name])

    hook_layers(model)

    def hook(layer):
        if layer == "input":
            out = image_f()
        elif layer == "labels":
            out = list(features.values())[-1].features
        else:
            assert layer in features, f"Invalid layer {layer}. Retrieve the list of layers with `lucent.modelzoo.util.get_model_layers(model)`."
            out = features[layer].features
        assert out is not None, "There are no saved feature maps. Make sure to put the model in eval mode, like so: `model.to(device).eval()`. See README for example."
        return out

    return hook


def vis_image(imgs, pred_masks, gt_masks, save_path, reverse=False, points=None, thre=0.5):

    b, c, h, w = pred_masks.size()
    dev = pred_masks.get_device()
    row_num = min(b, 4)

    if torch.max(pred_masks) > 1 or torch.min(pred_masks) < 0:
        pred_masks = torch.sigmoid(pred_masks)

    # threshold to a float mask (a bool mask would break the `1 - x` below)
    pred_masks = (pred_masks > thre).float()

    if reverse == True:
        pred_masks = 1 - pred_masks
        gt_masks = 1 - gt_masks
    if c == 2:
        pred_disc, pred_cup = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w), pred_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)
        gt_disc, gt_cup = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w), gt_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)
        tup = (imgs[:row_num, :, :, :], pred_disc[:row_num, :, :, :], pred_cup[:row_num, :, :, :], gt_disc[:row_num, :, :, :], gt_cup[:row_num, :, :, :])
        # compose = torch.cat((imgs[:row_num,:,:,:], pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]), 0)
        compose = torch.cat((pred_disc[:row_num, :, :, :], pred_cup[:row_num, :, :, :], gt_disc[:row_num, :, :, :], gt_cup[:row_num, :, :, :]), 0)
        vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)
    else:
        imgs = torchvision.transforms.Resize((h, w))(imgs)
        if imgs.size(1) == 1:
            imgs = imgs[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        pred_masks = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        gt_masks = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        if points is not None:
            for i in range(b):
                if args.thd:
                    p = torch.round(points.cpu() / args.roi_size * args.out_size).to(dtype=torch.int)
                else:
                    p = torch.round(points.cpu() / args.image_size * args.out_size).to(dtype=torch.int)
                # gt_masks[i,:,points[i,0]-5:points[i,0]+5,points[i,1]-5:points[i,1]+5] = torch.Tensor([255, 0, 0]).to(dtype=torch.float32, device=torch.device('cuda:' + str(dev)))
                # mark each click prompt as a small red square on the ground truth
                for pmt_id in range(p.shape[1]):
                    gt_masks[i, 0, p[i, pmt_id, 0] - 3:p[i, pmt_id, 0] + 3, p[i, pmt_id, 1] - 3:p[i, pmt_id, 1] + 3] = 255
                    gt_masks[i, 1, p[i, pmt_id, 0] - 3:p[i, pmt_id, 0] + 3, p[i, pmt_id, 1] - 3:p[i, pmt_id, 1] + 3] = 0
                    gt_masks[i, 2, p[i, pmt_id, 0] - 3:p[i, pmt_id, 0] + 3, p[i, pmt_id, 1] - 3:p[i, pmt_id, 1] + 3] = 0
        tup = (imgs[:row_num, :, :, :], pred_masks[:row_num, :, :, :], gt_masks[:row_num, :, :, :])
        compose = torch.cat(tup, 0)
        vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)

    return


def eval_seg(pred, true_mask_p, threshold):
    '''
    threshold: an iterable of cutoff values
    masks: [b, 2, h, w]
    pred: [b, 2, h, w]
    '''
    b, c, h, w = pred.size()
    if c == 2:
        iou_d, iou_c, disc_dice, cup_dice = 0, 0, 0, 0
        for th in threshold:

            gt_vmask_p = (true_mask_p > th).float()
            vpred = (pred > th).float()
            vpred_cpu = vpred.cpu()
            disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')
            cup_pred = vpred_cpu[:, 1, :, :].numpy().astype('int32')

            disc_mask = gt_vmask_p[:, 0, :, :].squeeze(1).cpu().numpy().astype('int32')
            cup_mask = gt_vmask_p[:, 1, :, :].squeeze(1).cpu().numpy().astype('int32')

            '''iou for numpy'''
            iou_d += iou(disc_pred, disc_mask)
            iou_c += iou(cup_pred, cup_mask)

            '''dice for torch'''
            disc_dice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()
            cup_dice += dice_coeff(vpred[:, 1, :, :], gt_vmask_p[:, 1, :, :]).item()

        return iou_d / len(threshold), iou_c / len(threshold), disc_dice / len(threshold), cup_dice / len(threshold)
    else:
        eiou, edice = 0, 0
        for th in threshold:

            gt_vmask_p = (true_mask_p > th).float()
            vpred = (pred > th).float()
            vpred_cpu = vpred.cpu()
            disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')

            disc_mask = gt_vmask_p[:, 0, :, :].squeeze(1).cpu().numpy().astype('int32')

            '''iou for numpy'''
            eiou += iou(disc_pred, disc_mask)

            '''dice for torch'''
            edice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()

        return eiou / len(threshold), edice / len(threshold)
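
# Usage sketch (illustrative): average IoU/Dice over several cutoffs.
#     eiou, edice = eval_seg(pred, masks, threshold=(0.1, 0.3, 0.5, 0.7, 0.9))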

# @objectives.wrap_objective()
def dot_compare(layer, batch=1, cossim_pow=0):
    def inner(T):
        dot = (T(layer)[batch] * T(layer)[0]).sum()
        mag = torch.sqrt(torch.sum(T(layer)[0]**2))
        cossim = dot / (1e-6 + mag)
        return -dot * cossim ** cossim_pow
    return inner

def init_D(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def pre_d():
    netD = Discriminator(3).to(device)
    # netD.apply(init_D)
    beta1 = 0.5
    dis_lr = 0.00002
    optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999))
    return netD, optimizerD

def update_d(args, netD, optimizerD, real, fake):
    # NOTE: the caller is expected to zero netD's gradients before calling.
    criterion = nn.BCELoss()

    label = torch.full((args.b,), 1., dtype=torch.float, device=device)
    output = netD(real).view(-1)
    # Calculate loss on all-real batch
    errD_real = criterion(output, label)
    # Calculate gradients for D in backward pass
    errD_real.backward()
    D_x = output.mean().item()

    label.fill_(0.)
    # Classify all fake batch with D
    output = netD(fake.detach()).view(-1)
    # Calculate D's loss on the all-fake batch
    errD_fake = criterion(output, label)
    # Calculate the gradients for this batch, accumulated (summed) with previous gradients
    errD_fake.backward()
    D_G_z1 = output.mean().item()
    # Compute error of D as sum over the fake and the real batches
    errD = errD_real + errD_fake
    # Update D
    optimizerD.step()

    return errD, D_x, D_G_z1

def calculate_gradient_penalty(netD, real_images, fake_images):
    # WGAN-GP gradient penalty with coefficient 10
    eta = torch.FloatTensor(args.b, 1, 1, 1).uniform_(0, 1)
    eta = eta.expand(args.b, real_images.size(1), real_images.size(2), real_images.size(3)).to(device=device)

    interpolated = (eta * real_images + ((1 - eta) * fake_images)).to(device=device)

    # define it to calculate gradient
    interpolated = Variable(interpolated, requires_grad=True)

    # calculate probability of interpolated examples
    prob_interpolated = netD(interpolated)

    # calculate gradients of probabilities with respect to examples
    gradients = autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                              grad_outputs=torch.ones(
                                  prob_interpolated.size()).to(device=device),
                              create_graph=True, retain_graph=True)[0]

    grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
    return grad_penalty


def random_click(mask, point_labels=1, inout=1):
    """Sample a random point from the region of `mask` equal to `inout`."""
    indices = np.argwhere(mask == inout)
    return indices[np.random.randint(len(indices))]
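
# Usage sketch (illustrative): sample a positive click inside the foreground.
#     pt = random_click(np.array(mask_2d), inout=1)   # pt = [row, col]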


def generate_click_prompt(img, msk, pt_label=1):
    # return: prompt, prompt mask
    pt_list = []
    msk_list = []
    b, c, h, w, d = msk.size()
    msk = msk[:, 0, :, :, :]
    for i in range(d):
        pt_list_s = []
        msk_list_s = []
        for j in range(b):
            msk_s = msk[j, :, :, i]
            indices = torch.nonzero(msk_s)
            if indices.size(0) == 0:
                # generate a random point in [0, h) x [0, h):
                random_index = torch.randint(0, h, (2,)).to(device=msk.device)
                new_s = msk_s
            else:
                random_index = random.choice(indices)
                label = msk_s[random_index[0], random_index[1]]
                new_s = torch.zeros_like(msk_s)
                # convert bool tensor to float
                new_s = (msk_s == label).to(dtype=torch.float)
                # new_s[msk_s == label] = 1
            pt_list_s.append(random_index)
            msk_list_s.append(new_s)
        pts = torch.stack(pt_list_s, dim=0)
        msks = torch.stack(msk_list_s, dim=0)
        pt_list.append(pts)
        msk_list.append(msks)
    pt = torch.stack(pt_list, dim=-1)
    msk = torch.stack(msk_list, dim=-1)

    msk = msk.unsqueeze(1)

    return img, pt, msk  # pt: [b, 2, d], msk: [b, 1, h, w, d]


def drawContour(m, s, RGB, size, a=0.8):
    """Draw edges of contours from segmented image 's' onto 'm' in colour 'RGB'"""

    # ratio = int(255/np.max(s))
    # s = np.uint(s*ratio)

    # Find edges of the contours and make into a Numpy array
    contours, _ = cv2.findContours(np.uint8(s), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    m_old = m.copy()
    # Paint locations of found edges in color "RGB" onto "main"
    cv2.drawContours(m, contours, -1, RGB, size)
    m = cv2.addWeighted(np.uint8(m), a, np.uint8(m_old), 1 - a, 0)
    return m

def IOU(pm, gt):
    a = np.sum(np.bitwise_and(pm, gt))
    b = np.sum(pm) + np.sum(gt) - a + 1e-8
    return a / b


def inverse_normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    tensor.mul_(std).add_(mean)
    return tensor
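
# Usage sketch (illustrative): undo ImageNet normalization before saving.
#     img_vis = inverse_normalize(img.clone())   # clone() keeps the original intact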


def remove_small_objects(array_2d, min_size=30):
    """
    Removes small objects from a 2D array using NumPy and scipy.ndimage.label.

    :param array_2d: Input 2D array.
    :param min_size: Minimum size of objects to keep.
    :return: 2D array with small objects removed.
    """
    # Label connected components
    structure = np.ones((3, 3), dtype=int)  # Define connectivity
    labeled, ncomponents = label(array_2d, structure)

    # Iterate through labeled components and remove small ones
    for i in range(1, ncomponents + 1):
        locations = np.where(labeled == i)
        if len(locations[0]) < min_size:
            array_2d[locations] = 0

    return array_2d
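
# Usage sketch (illustrative): drop connected components smaller than 30 pixels.
#     mask = (prob_map > 0.5).astype(np.uint8)
#     cleaned = remove_small_objects(mask, min_size=30)   # also modifies `mask` in place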

def create_box_mask(boxes, imgs):
    # imgs is [b, c, H, W]; the mask is indexed mask[k, y, x]
    b, _, h, w = imgs.shape
    box_mask = torch.zeros((b, h, w))
    for k in range(b):
        k_box = boxes[k]
        for box in k_box:
            x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
            box_mask[k, y1:y2, x1:x2] = 1
    return box_mask
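
# Usage sketch (illustrative): rasterize xyxy boxes into per-image binary masks.
#     boxes = [[(10, 20, 50, 60)]]             # one box for one image
#     masks = create_box_mask(boxes, imgs)     # imgs: [1, 3, H, W] -> masks: [1, H, W]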