Diff of /utils.py [000000] .. [dff9e0]

b/utils.py
"""Helper functions.

author: junde
"""

import sys

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.optim.lr_scheduler import _LRScheduler
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch import autograd
import random
import math
import PIL
import matplotlib.pyplot as plt
import seaborn as sns

import collections
import logging
import cv2
import os
import time
from datetime import datetime

import dateutil.tz

from typing import Union, Optional, List, Tuple, Text, BinaryIO
import pathlib
import warnings
import numpy as np
from scipy.ndimage import label, find_objects
from PIL import Image, ImageDraw, ImageFont, ImageColor
# from lucent.optvis.param.spatial import pixel_image, fft_image, init_image
# from lucent.optvis.param.color import to_valid_rgb
# from lucent.optvis import objectives, transform, param
# from lucent.misc.io import show
from torchvision.models import vgg19
import torch.nn.functional as F
import cfg

from collections import OrderedDict
from tqdm import tqdm

# from precpt import run_precpt
from models.discriminator import Discriminator
# from siren_pytorch import SirenNet, SirenWrapper

import shutil
import tempfile

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR

from monai.data import (
    ThreadDataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
)


args = cfg.parse_args()
device = torch.device('cuda', args.gpu_device)

'''preparation of domain loss'''
# cnn = vgg19(pretrained=True).features.to(device).eval()
# cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
# cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

# netD = Discriminator(1).to(device)
# netD.apply(init_D)
# beta1 = 0.5
# dis_lr = 0.0002
# optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999))
'''end'''


def get_network(args, net, use_gpu=True, gpu_device=0, distribution=True):
    """Return the requested network."""

    if net == 'sam':
        from models.sam import SamPredictor, sam_model_registry
        from models.sam.utils.transforms import ResizeLongestSide

        net = sam_model_registry['vit_b'](args, checkpoint=args.sam_ckpt).to(device)
    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if use_gpu:
        # net = net.cuda(device=gpu_device)
        if distribution != 'none':
            net = torch.nn.DataParallel(net, device_ids=[int(id) for id in args.distributed.split(',')])
            net = net.to(device=gpu_device)
        else:
            net = net.to(device=gpu_device)

    return net

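# Usage sketch (a minimal example; `args` comes from cfg.parse_args() as above,
# with `args.sam_ckpt` pointing at a SAM ViT-B checkpoint and `args.distributed`
# either 'none' or a comma-separated GPU-id string):
#
#   net = get_network(args, 'sam', use_gpu=args.gpu,
#                     gpu_device=torch.device('cuda', args.gpu_device),
#                     distribution=args.distributed)
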
def get_decath_loader(args):

    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-175,
                a_max=250,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(args.roi_size, args.roi_size, args.chunk),
                pos=1,
                neg=1,
                num_samples=args.num_sample,
                image_key="image",
                image_threshold=0,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[0],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[1],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[2],
                prob=0.10,
            ),
            RandRotate90d(
                keys=["image", "label"],
                prob=0.10,
                max_k=3,
            ),
            RandShiftIntensityd(
                keys=["image"],
                offsets=0.10,
                prob=0.50,
            ),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            ScaleIntensityRanged(
                keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
        ]
    )

    data_dir = args.data_path
    split_JSON = "dataset_0.json"

    datasets = os.path.join(data_dir, split_JSON)
    datalist = load_decathlon_datalist(datasets, True, "training")
    val_files = load_decathlon_datalist(datasets, True, "validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=8,
    )
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.b, shuffle=True)
    val_ds = CacheDataset(
        data=val_files, transform=val_transforms, cache_num=2, cache_rate=1.0, num_workers=0
    )
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    set_track_meta(False)

    return train_loader, val_loader, train_transforms, val_transforms, datalist, val_files

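# Usage sketch (assumes args provides data_path, roi_size, chunk, num_sample and b,
# and that <data_path>/dataset_0.json follows the Decathlon datalist format):
#
#   train_loader, val_loader, *_ = get_decath_loader(args)
#   for batch in train_loader:
#       imgs, masks = batch["image"], batch["label"]
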
def cka_loss(gram_featureA, gram_featureB):
    # Linear CKA between two Gram matrices: the HSIC term normalized by the
    # Frobenius norms of the inputs.
    scaled_hsic = torch.dot(torch.flatten(gram_featureA), torch.flatten(gram_featureB))
    normalization_x = gram_featureA.norm()
    normalization_y = gram_featureB.norm()
    return scaled_hsic / (normalization_x * normalization_y)

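# Example (a minimal sketch; gA and gB stand in for Gram matrices produced by
# gram_matrix() below):
#
#   gA = gram_matrix(featsA)   # featsA: [1, C, H, W] activations
#   gB = gram_matrix(featsB)
#   loss = cka_loss(gA, gB)    # scalar similarity; values near 1 mean similar structure
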
class WarmUpLR(_LRScheduler):
    """Warmup learning rate scheduler.
    Args:
        optimizer: the wrapped optimizer (e.g. SGD)
        total_iters: total number of iterations in the warmup phase
    """
    def __init__(self, optimizer, total_iters, last_epoch=-1):

        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """For the first m batches, set the learning
        rate to base_lr * m / total_iters.
        """
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]

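# Usage sketch (hypothetical names; step the scheduler once per batch during the
# warmup epochs so the LR ramps linearly from ~0 to base_lr):
#
#   warmup_scheduler = WarmUpLR(optimizer, total_iters=len(train_loader) * warm_epochs)
#   for imgs, masks in train_loader:
#       ...  # forward / backward / optimizer.step()
#       warmup_scheduler.step()
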
def gram_matrix(input):
    a, b, c, d = input.size()  # a = batch size (=1)
    # b = number of feature maps
    # (c, d) = dimensions of a feature map (N = c * d)

    features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL

    G = torch.mm(features, features.t())  # compute the gram product

    # we 'normalize' the values of the gram matrix
    # by dividing by the number of elements in each feature map
    return G.div(a * b * c * d)

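# Example (a minimal sketch): the Gram matrix of convolutional activations
# captures feature co-occurrence and feeds directly into cka_loss() above.
#
#   feats = torch.randn(1, 64, 32, 32)   # [batch, channels, H, W]
#   G = gram_matrix(feats)               # [64, 64], normalized by a*b*c*d
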
@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    **kwargs
) -> torch.Tensor:
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if "range" in kwargs.keys():
        warning = "range will be deprecated, please use value_range instead."
        warnings.warn(warning)
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None:
            assert isinstance(value_range, tuple), \
                "value_range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            img.clamp_(min=low, max=high)  # must be in-place; plain clamp() discards the result
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid

@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs
) -> None:
    """
    Save a given Tensor into an image file.
    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object
        format (Optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """

    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)

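# Usage sketch (a minimal example; any float tensor in [0, 1] works):
#
#   batch = torch.rand(16, 3, 64, 64)
#   save_image(batch, 'grid.png', nrow=4, normalize=True)
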
def create_logger(log_dir, phase='train'):
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}.log'.format(time_str, phase)
    final_log_file = os.path.join(log_dir, log_file)
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    return logger

def set_log_dir(root_dir, exp_name):
    path_dict = {}
    os.makedirs(root_dir, exist_ok=True)

    # set log path
    exp_path = os.path.join(root_dir, exp_name)
    now = datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    prefix = exp_path + '_' + timestamp
    os.makedirs(prefix)
    path_dict['prefix'] = prefix

    # set checkpoint path
    ckpt_path = os.path.join(prefix, 'Model')
    os.makedirs(ckpt_path)
    path_dict['ckpt_path'] = ckpt_path

    log_path = os.path.join(prefix, 'Log')
    os.makedirs(log_path)
    path_dict['log_path'] = log_path

    # set sample image path for fid calculation
    sample_path = os.path.join(prefix, 'Samples')
    os.makedirs(sample_path)
    path_dict['sample_path'] = sample_path

    return path_dict

def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth'):
    torch.save(states, os.path.join(output_dir, filename))
    if is_best:
        torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))

class RunningStats:
    """Running mean/variance over a fixed-size sliding window."""
    def __init__(self, WIN_SIZE):
        self.mean = 0
        self.run_var = 0
        self.WIN_SIZE = WIN_SIZE

        self.window = collections.deque(maxlen=WIN_SIZE)

    def clear(self):
        self.window.clear()
        self.mean = 0
        self.run_var = 0

    def is_full(self):
        return len(self.window) == self.WIN_SIZE

    def push(self, x):

        if len(self.window) == self.WIN_SIZE:
            # Adjusting variance
            x_removed = self.window.popleft()
            self.window.append(x)
            old_m = self.mean
            self.mean += (x - x_removed) / self.WIN_SIZE
            self.run_var += (x + x_removed - old_m - self.mean) * (x - x_removed)
        else:
            # Calculating first variance
            self.window.append(x)
            delta = x - self.mean
            self.mean += delta / len(self.window)
            self.run_var += delta * (x - self.mean)

    def get_mean(self):
        return self.mean if len(self.window) else 0.0

    def get_var(self):
        return self.run_var / len(self.window) if len(self.window) > 1 else 0.0

    def get_std(self):
        return math.sqrt(self.get_var())

    def get_all(self):
        return list(self.window)

    def __str__(self):
        return "Current window values: {}".format(list(self.window))

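# Usage sketch (`losses` is any stream of scalars, e.g. per-batch loss values):
#
#   stats = RunningStats(WIN_SIZE=50)
#   for v in losses:
#       stats.push(v)
#   print(stats.get_mean(), stats.get_std())
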
def iou(outputs: np.ndarray, labels: np.ndarray):
    # expects binary integer masks of shape [b, h, w]
    SMOOTH = 1e-6
    intersection = (outputs & labels).sum((1, 2))
    union = (outputs | labels).sum((1, 2))

    iou = (intersection + SMOOTH) / (union + SMOOTH)

    return iou.mean()

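# Example (a minimal sketch with int32 masks, matching how eval_seg() calls it):
#
#   pred = (np.random.rand(2, 8, 8) > 0.5).astype('int32')
#   gt = (np.random.rand(2, 8, 8) > 0.5).astype('int32')
#   score = iou(pred, gt)   # mean IoU over the batch
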
class DiceCoeff(Function):
    """Dice coeff for individual examples.

    Note: this uses the legacy (instance-based) autograd Function API;
    forward() is called directly in dice_coeff() below.
    """

    def forward(self, input, target):
        self.save_for_backward(input, target)
        eps = 0.0001
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):

        input, target = self.saved_variables
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                         / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


def dice_coeff(input, target):
    """Dice coeff for batches"""
    if input.is_cuda:
        s = torch.FloatTensor(1).to(device=input.device).zero_()
    else:
        s = torch.FloatTensor(1).zero_()

    for i, c in enumerate(zip(input, target)):
        s = s + DiceCoeff().forward(c[0], c[1])

    return s / (i + 1)

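# Example (a minimal sketch; inputs are thresholded probability maps of shape
# [b, h, w], as in eval_seg() below):
#
#   d = dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()
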
'''parameter'''
def para_image(w, h=None, img=None, mode='multi', seg=None, sd=None, batch=None,
               fft=False, channels=None, init=None):
    # NOTE: pixel_image / fft_image / init_image come from the lucent imports
    # commented out at the top of this file.
    h = h or w
    batch = batch or 1
    ch = channels or 3
    shape = [batch, ch, h, w]
    param_f = fft_image if fft else pixel_image
    if init is not None:
        param_f = init_image
        params, maps_f = param_f(init)
    else:
        params, maps_f = param_f(shape, sd=sd)
    if mode == 'multi':
        output = to_valid_out(maps_f, img, seg)
    elif mode == 'seg':
        output = gene_out(maps_f, img)
    elif mode == 'raw':
        output = raw_out(maps_f, img)
    return params, output


def to_valid_out(maps_f, img, seg):  # multi-rater
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        maps = torch.nn.Softmax(dim=1)(maps)
        final_seg = torch.multiply(seg, maps).sum(dim=1, keepdim=True)
        return torch.cat((img, final_seg), 1)
        # return torch.cat((img, maps), 1)
    return inner


def gene_out(maps_f, img):  # pure seg
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        # maps = torch.nn.Sigmoid()(maps)
        return torch.cat((img, maps), 1)
    return inner


def raw_out(maps_f, img):  # raw
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        # maps = torch.nn.Sigmoid()(maps)
        return maps
        # return torch.cat((img, maps), 1)
    return inner

class CompositeActivation(torch.nn.Module):

    def forward(self, x):
        x = torch.atan(x)
        # concatenate two normalized nonlinearities along the channel dim
        return torch.cat([x / 0.67, (x * x) / 0.6], 1)
        # return x

def cppn(args, size, img=None, seg=None, batch=None, num_output_channels=1, num_hidden_channels=128, num_layers=8,
         activation_fn=CompositeActivation, normalize=False, device="cuda:0"):

    r = 3 ** 0.5

    coord_range = torch.linspace(-r, r, size)
    x = coord_range.view(-1, 1).repeat(1, coord_range.size(0))
    y = coord_range.view(1, -1).repeat(coord_range.size(0), 1)

    input_tensor = torch.stack([x, y], dim=0).unsqueeze(0).repeat(batch, 1, 1, 1).to(device)

    layers = []
    kernel_size = 1
    for i in range(num_layers):
        out_c = num_hidden_channels
        in_c = out_c * 2  # * 2 for composite activation
        if i == 0:
            in_c = 2
        if i == num_layers - 1:
            out_c = num_output_channels
        layers.append(('conv{}'.format(i), torch.nn.Conv2d(in_c, out_c, kernel_size)))
        if normalize:
            layers.append(('norm{}'.format(i), torch.nn.InstanceNorm2d(out_c)))
        if i < num_layers - 1:
            layers.append(('actv{}'.format(i), activation_fn()))
        else:
            layers.append(('output', torch.nn.Sigmoid()))

    # Initialize model
    net = torch.nn.Sequential(OrderedDict(layers)).to(device)

    # Initialize weights
    def weights_init(module):
        if isinstance(module, torch.nn.Conv2d):
            torch.nn.init.normal_(module.weight, 0, np.sqrt(1 / module.in_channels))
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
    net.apply(weights_init)
    # Set the last conv2d layer's weights to 0
    torch.nn.init.zeros_(dict(net.named_children())['conv{}'.format(num_layers - 1)].weight)
    outimg = raw_out(lambda: net(input_tensor), img) if args.netype == 'raw' else to_valid_out(lambda: net(input_tensor), img, seg)
    return net.parameters(), outimg

def get_siren(args):
    wrapper = get_network(args, 'siren', use_gpu=args.gpu, gpu_device=torch.device('cuda', args.gpu_device), distribution=args.distributed)
    '''load init weights'''
    checkpoint = torch.load('./logs/siren_train_init_2022_08_19_21_00_16/Model/checkpoint_best.pth')
    wrapper.load_state_dict(checkpoint['state_dict'], strict=False)
    '''end'''

    '''load prompt'''
    checkpoint = torch.load('./logs/vae_standard_refuge1_2022_08_21_17_56_49/Model/checkpoint500')
    vae = get_network(args, 'vae', use_gpu=args.gpu, gpu_device=torch.device('cuda', args.gpu_device), distribution=args.distributed)
    vae.load_state_dict(checkpoint['state_dict'], strict=False)
    '''end'''

    return wrapper, vae

def siren(args, wrapper, vae, img=None, seg=None, batch=None, num_output_channels=1, num_hidden_channels=128, num_layers=8,
          activation_fn=CompositeActivation, normalize=False, device="cuda:0"):
    vae_img = torchvision.transforms.Resize(64)(img)
    latent = vae.encoder(vae_img).view(-1).detach()
    outimg = raw_out(lambda: wrapper(latent=latent), img) if args.netype == 'raw' else to_valid_out(lambda: wrapper(latent=latent), img, seg)
    # img = torch.randn(1, 3, 256, 256)
    # loss = wrapper(img)
    # loss.backward()

    # # after much training ...
    # # simply invoke the wrapper without passing in anything
    # pred_img = wrapper()  # (1, 3, 256, 256)
    return wrapper.parameters(), outimg


'''adversary'''
def render_vis(
    args,
    model,
    objective_f,
    real_img,
    param_f=None,
    optimizer=None,
    transforms=None,
    thresholds=(256,),
    verbose=True,
    preprocess=True,
    progress=True,
    show_image=True,
    save_image=False,
    image_name=None,
    show_inline=False,
    fixed_image_size=None,
    label=1,
    raw_img=None,
    prompt=None
):
    if label == 1:
        sign = 1
    elif label == 0:
        sign = -1
    else:
        print('label is wrong, label is', label)
    if args.reverse:
        sign = -sign
    if args.multilayer:
        sign = 1

    '''prepare'''
    now = datetime.now()
    date_time = now.strftime("%m-%d-%Y, %H:%M:%S")

    netD, optD = pre_d()
    '''end'''

    if param_f is None:
        param_f = lambda: param.image(128)
    # param_f is a function that should return two things
    # params - parameters to update, which we pass to the optimizer
    # image_f - a function that returns an image as a tensor
    params, image_f = param_f()

    if optimizer is None:
        optimizer = lambda params: torch.optim.Adam(params, lr=5e-1)
    optimizer = optimizer(params)

    if transforms is None:
        transforms = []
    transforms = transforms.copy()

    # Upsample images smaller than 224
    image_shape = image_f().shape

    if fixed_image_size is not None:
        new_size = fixed_image_size
    elif image_shape[2] < 224 or image_shape[3] < 224:
        new_size = 224
    else:
        new_size = None
    if new_size:
        transforms.append(
            torch.nn.Upsample(size=new_size, mode="bilinear", align_corners=True)
        )

    transform_f = transform.compose(transforms)

    hook = hook_model(model, image_f)
    objective_f = objectives.as_objective(objective_f)

    if verbose:
        model(transform_f(image_f()))
        print("Initial loss of ad: {:.3f}".format(objective_f(hook)))

    images = []
    try:
        for i in tqdm(range(1, max(thresholds) + 1), disable=(not progress)):
            optimizer.zero_grad()
            try:
                model(transform_f(image_f()))
            except RuntimeError as ex:
                if i == 1:
                    # Only display the warning message
                    # on the first iteration, no need to do that
                    # every iteration
                    warnings.warn(
                        "Some layers could not be computed because the size of the "
                        "image is not big enough. It is fine, as long as the non-"
                        "computed layers are not used in the objective function"
                        f" (exception details: '{ex}')"
                    )
            if args.disc:
                '''dom loss part'''
                # content_img = raw_img
                # style_img = raw_img
                # precpt_loss = run_precpt(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img, transform_f(image_f()))

                # WGAN-GP discriminator updates
                for p in netD.parameters():
                    p.requires_grad = True
                for _ in range(args.drec):
                    netD.zero_grad()
                    real = real_img
                    fake = image_f()
                    # for _ in range(6):
                    #     errD, D_x, D_G_z1 = update_d(args, netD, optD, real, fake)

                    # label = torch.full((args.b,), 1., dtype=torch.float, device=device)
                    # label.fill_(1.)
                    # output = netD(fake).view(-1)
                    # errG = nn.BCELoss()(output, label)
                    # D_G_z2 = output.mean().item()
                    # dom_loss = err
                    one = torch.tensor(1, dtype=torch.float)
                    mone = one * -1
                    one = one.cuda(args.gpu_device)
                    mone = mone.cuda(args.gpu_device)

                    d_loss_real = netD(real)
                    d_loss_real = d_loss_real.mean()
                    d_loss_real.backward(mone)

                    d_loss_fake = netD(fake)
                    d_loss_fake = d_loss_fake.mean()
                    d_loss_fake.backward(one)

                    # Train with gradient penalty
                    gradient_penalty = calculate_gradient_penalty(netD, real.data, fake.data)
                    gradient_penalty.backward()

                    d_loss = d_loss_fake - d_loss_real + gradient_penalty
                    Wasserstein_D = d_loss_real - d_loss_fake
                    optD.step()

                # Generator update
                for p in netD.parameters():
                    p.requires_grad = False  # to avoid computation

                fake_images = image_f()
                g_loss = netD(fake_images)
                g_loss = -g_loss.mean()
                dom_loss = g_loss
                g_cost = -g_loss

                if i % 5 == 0:
                    print(f'loss_fake: {d_loss_fake}, loss_real: {d_loss_real}')
                    print(f'Generator g_loss: {g_loss}')
                '''end'''

            '''ssim loss'''

            '''end'''

            if args.disc:
                loss = sign * objective_f(hook) + args.pw * dom_loss
                # loss = args.pw * dom_loss
            else:
                loss = sign * objective_f(hook)
                # loss = args.pw * dom_loss

            loss.backward()

            # # video the images
            # if i % 5 == 0:
            #     image_name = image_name[0].split('\\')[-1].split('.')[0] + '_' + str(i) + '.png'
            #     img_path = os.path.join(args.path_helper['sample_path'], str(image_name))
            #     export(image_f(), img_path)
            # # end
            # if i % 50 == 0:
            #     print('Loss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
            #           % (errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            optimizer.step()
            if i in thresholds:
                image = tensor_to_img_array(image_f())
                # if verbose:
                #     print("Loss at step {}: {:.3f}".format(i, objective_f(hook)))
                if save_image:
                    na = image_name[0].split('\\')[-1].split('.')[0] + '_' + str(i) + '.png'
                    na = date_time + na
                    outpath = args.quickcheck if args.quickcheck else args.path_helper['sample_path']
                    img_path = os.path.join(outpath, str(na))
                    export(image_f(), img_path)

                images.append(image)
    except KeyboardInterrupt:
        print("Interrupted optimization at step {:d}.".format(i))
        if verbose:
            print("Loss at step {}: {:.3f}".format(i, objective_f(hook)))
        images.append(tensor_to_img_array(image_f()))

    if save_image:
        na = image_name[0].split('\\')[-1].split('.')[0] + '.png'
        na = date_time + na
        outpath = args.quickcheck if args.quickcheck else args.path_helper['sample_path']
        img_path = os.path.join(outpath, str(na))
        export(image_f(), img_path)
    if show_inline:
        show(tensor_to_img_array(image_f()))
    elif show_image:
        view(image_f())
    return image_f()

def tensor_to_img_array(tensor):
    image = tensor.cpu().detach().numpy()
    image = np.transpose(image, [0, 2, 3, 1])
    return image


def view(tensor):
    image = tensor_to_img_array(tensor)
    assert len(image.shape) in [
        3,
        4,
    ], "Image should have 3 or 4 dimensions, invalid image shape {}".format(image.shape)
    # Change dtype for PIL.Image
    image = (image * 255).astype(np.uint8)
    if len(image.shape) == 4:
        image = np.concatenate(image, axis=1)
    Image.fromarray(image).show()

def export(tensor, img_path=None):
    # image_name = image_name or "image.jpg"
    c = tensor.size(1)
    # if c == 7:
    #     for i in range(c):
    #         w_map = tensor[:, i, :, :].unsqueeze(1)
    #         w_map = tensor_to_img_array(w_map).squeeze()
    #         w_map = (w_map * 255).astype(np.uint8)
    #         image_name = image_name[0].split('/')[-1].split('.')[0] + str(i) + '.png'
    #         wheat = sns.heatmap(w_map, cmap='coolwarm')
    #         figure = wheat.get_figure()
    #         figure.savefig('./fft_maps/weightheatmap/' + str(image_name), dpi=400)
    # else:
    if c == 3:
        vutils.save_image(tensor, fp=img_path)
    else:
        image = tensor[:, 0:3, :, :]
        w_map = tensor[:, -1, :, :].unsqueeze(1)
        image = tensor_to_img_array(image)
        w_map = 1 - tensor_to_img_array(w_map).squeeze()
        # w_map[w_map == 1] = 0
        assert len(image.shape) in [
            3,
            4,
        ], "Image should have 3 or 4 dimensions, invalid image shape {}".format(image.shape)
        # Change dtype for PIL.Image
        image = (image * 255).astype(np.uint8)
        w_map = (w_map * 255).astype(np.uint8)

        Image.fromarray(w_map, 'L').save(img_path)

class ModuleHook:
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.module = None
        self.features = None

    def hook_fn(self, module, input, output):
        self.module = module
        self.features = output

    def close(self):
        self.hook.remove()


def hook_model(model, image_f):
    features = OrderedDict()

    # recursive hooking function
    def hook_layers(net, prefix=[]):
        if hasattr(net, "_modules"):
            for name, layer in net._modules.items():
                if layer is None:
                    # e.g. GoogLeNet's aux1 and aux2 layers
                    continue
                features["_".join(prefix + [name])] = ModuleHook(layer)
                hook_layers(layer, prefix=prefix + [name])

    hook_layers(model)

    def hook(layer):
        if layer == "input":
            out = image_f()
        elif layer == "labels":
            out = list(features.values())[-1].features
        else:
            assert layer in features, f"Invalid layer {layer}. Retrieve the list of layers with `lucent.modelzoo.util.get_model_layers(model)`."
            out = features[layer].features
        assert out is not None, "There are no saved feature maps. Make sure to put the model in eval mode, like so: `model.to(device).eval()`. See README for example."
        return out

    return hook

def vis_image(imgs, pred_masks, gt_masks, save_path, reverse=False, points=None, thre=0.5):

    b, c, h, w = pred_masks.size()
    dev = pred_masks.get_device()
    row_num = min(b, 4)

    if torch.max(pred_masks) > 1 or torch.min(pred_masks) < 0:
        pred_masks = torch.sigmoid(pred_masks)

    pred_masks = (pred_masks > thre).float()  # binarize; torch.tensor(tensor) copies and warns

    if reverse == True:
        pred_masks = 1 - pred_masks
        gt_masks = 1 - gt_masks
    if c == 2:
        pred_disc, pred_cup = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w), pred_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)
        gt_disc, gt_cup = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w), gt_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)
        tup = (imgs[:row_num, :, :, :], pred_disc[:row_num, :, :, :], pred_cup[:row_num, :, :, :], gt_disc[:row_num, :, :, :], gt_cup[:row_num, :, :, :])
        # compose = torch.cat((imgs[:row_num,:,:,:], pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]), 0)
        compose = torch.cat((pred_disc[:row_num, :, :, :], pred_cup[:row_num, :, :, :], gt_disc[:row_num, :, :, :], gt_cup[:row_num, :, :, :]), 0)
        vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)
    else:
        imgs = torchvision.transforms.Resize((h, w))(imgs)
        if imgs.size(1) == 1:
            imgs = imgs[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        pred_masks = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        gt_masks = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        if points is not None:
            for i in range(b):
                if args.thd:
                    p = np.round(points.cpu() / args.roi_size * args.out_size).to(dtype=torch.int)
                else:
                    p = np.round(points.cpu() / args.image_size * args.out_size).to(dtype=torch.int)
                # gt_masks[i,:,points[i,0]-5:points[i,0]+5,points[i,1]-5:points[i,1]+5] = torch.Tensor([255, 0, 0]).to(dtype=torch.float32, device=torch.device('cuda:' + str(dev)))
                for pmt_id in range(p.shape[1]):
                    # paint a red square at each prompt point
                    gt_masks[i, 0, p[i, pmt_id, 0] - 3:p[i, pmt_id, 0] + 3, p[i, pmt_id, 1] - 3:p[i, pmt_id, 1] + 3] = 255
                    gt_masks[i, 1, p[i, pmt_id, 0] - 3:p[i, pmt_id, 0] + 3, p[i, pmt_id, 1] - 3:p[i, pmt_id, 1] + 3] = 0
                    gt_masks[i, 2, p[i, pmt_id, 0] - 3:p[i, pmt_id, 0] + 3, p[i, pmt_id, 1] - 3:p[i, pmt_id, 1] + 3] = 0
        tup = (imgs[:row_num, :, :, :], pred_masks[:row_num, :, :, :], gt_masks[:row_num, :, :, :])
        compose = torch.cat(tup, 0)
        vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)

    return

def eval_seg(pred, true_mask_p, threshold):
    '''
    threshold: an iterable of threshold values; results are averaged over it
    masks: [b,2,h,w]
    pred: [b,2,h,w]
    '''
    b, c, h, w = pred.size()
    if c == 2:
        iou_d, iou_c, disc_dice, cup_dice = 0, 0, 0, 0
        for th in threshold:

            gt_vmask_p = (true_mask_p > th).float()
            vpred = (pred > th).float()
            vpred_cpu = vpred.cpu()
            disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')
            cup_pred = vpred_cpu[:, 1, :, :].numpy().astype('int32')

            disc_mask = gt_vmask_p[:, 0, :, :].squeeze(1).cpu().numpy().astype('int32')
            cup_mask = gt_vmask_p[:, 1, :, :].squeeze(1).cpu().numpy().astype('int32')

            '''iou for numpy'''
            iou_d += iou(disc_pred, disc_mask)
            iou_c += iou(cup_pred, cup_mask)

            '''dice for torch'''
            disc_dice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()
            cup_dice += dice_coeff(vpred[:, 1, :, :], gt_vmask_p[:, 1, :, :]).item()

        return iou_d / len(threshold), iou_c / len(threshold), disc_dice / len(threshold), cup_dice / len(threshold)
    else:
        eiou, edice = 0, 0
        for th in threshold:

            gt_vmask_p = (true_mask_p > th).float()
            vpred = (pred > th).float()
            vpred_cpu = vpred.cpu()
            disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')

            disc_mask = gt_vmask_p[:, 0, :, :].squeeze(1).cpu().numpy().astype('int32')

            '''iou for numpy'''
            eiou += iou(disc_pred, disc_mask)

            '''dice for torch'''
            edice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()

        return eiou / len(threshold), edice / len(threshold)

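# Usage sketch (2-channel disc/cup case; scores are averaged over the thresholds):
#
#   iou_d, iou_c, dice_d, dice_c = eval_seg(pred, masks, (0.1, 0.3, 0.5, 0.7, 0.9))
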
# @objectives.wrap_objective()
def dot_compare(layer, batch=1, cossim_pow=0):
    def inner(T):
        dot = (T(layer)[batch] * T(layer)[0]).sum()
        mag = torch.sqrt(torch.sum(T(layer)[0] ** 2))
        cossim = dot / (1e-6 + mag)
        return -dot * cossim ** cossim_pow
    return inner

def init_D(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


def pre_d():
    netD = Discriminator(3).to(device)
    # netD.apply(init_D)
    beta1 = 0.5
    dis_lr = 0.00002
    optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999))
    return netD, optimizerD

def update_d(args, netD, optimizerD, real, fake):
    criterion = nn.BCELoss()

    label = torch.full((args.b,), 1., dtype=torch.float, device=device)
    output = netD(real).view(-1)
    # Calculate loss on all-real batch
    errD_real = criterion(output, label)
    # Calculate gradients for D in backward pass
    errD_real.backward()
    D_x = output.mean().item()

    label.fill_(0.)
    # Classify all fake batch with D
    output = netD(fake.detach()).view(-1)
    # Calculate D's loss on the all-fake batch
    errD_fake = criterion(output, label)
    # Calculate the gradients for this batch, accumulated (summed) with previous gradients
    errD_fake.backward()
    D_G_z1 = output.mean().item()
    # Compute error of D as sum over the fake and the real batches
    errD = errD_real + errD_fake
    # Update D
    optimizerD.step()

    return errD, D_x, D_G_z1

def calculate_gradient_penalty(netD, real_images, fake_images):
    eta = torch.FloatTensor(args.b, 1, 1, 1).uniform_(0, 1)
    eta = eta.expand(args.b, real_images.size(1), real_images.size(2), real_images.size(3)).to(device=device)

    interpolated = (eta * real_images + ((1 - eta) * fake_images)).to(device=device)

    # define it to calculate gradient
    interpolated = Variable(interpolated, requires_grad=True)

    # calculate probability of interpolated examples
    prob_interpolated = netD(interpolated)

    # calculate gradients of probabilities with respect to examples
    gradients = autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                              grad_outputs=torch.ones(
                                  prob_interpolated.size()).to(device=device),
                              create_graph=True, retain_graph=True)[0]

    grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
    return grad_penalty

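# Usage sketch (the WGAN-GP term used in render_vis() above; the factor 10 is the
# standard lambda coefficient from the WGAN-GP paper):
#
#   gp = calculate_gradient_penalty(netD, real.data, fake.data)
#   d_loss = d_loss_fake - d_loss_real + gp
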
def random_click(mask, point_labels=1, inout=1):
    # sample a random coordinate from the pixels whose value equals `inout`
    indices = np.argwhere(mask == inout)
    return indices[np.random.randint(len(indices))]

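# Example (a minimal sketch; `mask` is a binary numpy array):
#
#   mask = (np.random.rand(256, 256) > 0.5).astype(np.uint8)
#   pt = random_click(mask, inout=1)   # (row, col) of a random foreground pixel
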
def generate_click_prompt(img, msk, pt_label=1):
    # return: prompt, prompt mask
    pt_list = []
    msk_list = []
    b, c, h, w, d = msk.size()
    msk = msk[:, 0, :, :, :]
    for i in range(d):
        pt_list_s = []
        msk_list_s = []
        for j in range(b):
            msk_s = msk[j, :, :, i]
            indices = torch.nonzero(msk_s)
            if indices.size(0) == 0:
                # generate a random point in [0, h) x [0, h)
                random_index = torch.randint(0, h, (2,)).to(device=msk.device)
                new_s = msk_s
            else:
                random_index = random.choice(indices)
                label = msk_s[random_index[0], random_index[1]]
                new_s = torch.zeros_like(msk_s)
                # convert the bool tensor to float
                new_s = (msk_s == label).to(dtype=torch.float)
                # new_s[msk_s == label] = 1
            pt_list_s.append(random_index)
            msk_list_s.append(new_s)
        pts = torch.stack(pt_list_s, dim=0)
        msks = torch.stack(msk_list_s, dim=0)
        pt_list.append(pts)
        msk_list.append(msks)
    pt = torch.stack(pt_list, dim=-1)
    msk = torch.stack(msk_list, dim=-1)

    msk = msk.unsqueeze(1)

    return img, pt, msk  # [b, 2, d], [b, c, h, w, d]

def drawContour(m, s, RGB, size, a=0.8):
    """Draw the edges of the contours found in segmented image 's' onto 'm' in colour 'RGB'."""
    # ratio = int(255/np.max(s))
    # s = np.uint(s*ratio)

    # Find edges of the contours and make into a Numpy array
    contours, _ = cv2.findContours(np.uint8(s), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    m_old = m.copy()
    # Paint locations of found edges in color "RGB" onto "main"
    cv2.drawContours(m, contours, -1, RGB, size)
    m = cv2.addWeighted(np.uint8(m), a, np.uint8(m_old), 1 - a, 0)
    return m

def IOU(pm, gt):
    a = np.sum(np.bitwise_and(pm, gt))
    b = np.sum(pm) + np.sum(gt) - a + 1e-8
    return a / b

def inverse_normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    tensor.mul_(std).add_(mean)
    return tensor

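# Usage sketch (undoes the standard ImageNet normalization; note it modifies the
# tensor in place, so clone first if the normalized original is still needed):
#
#   img = inverse_normalize(img.clone())
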
def remove_small_objects(array_2d, min_size=30):
    """
    Removes small connected components from a 2D array
    using scipy.ndimage.label.

    :param array_2d: Input 2D array.
    :param min_size: Minimum size of objects to keep.
    :return: 2D array with small objects removed.
    """
    # Label connected components
    structure = np.ones((3, 3), dtype=int)  # define 8-connectivity
    labeled, ncomponents = label(array_2d, structure)

    # Iterate through labeled components and remove small ones
    for i in range(1, ncomponents + 1):
        locations = np.where(labeled == i)
        if len(locations[0]) < min_size:
            array_2d[locations] = 0

    return array_2d

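# Example (a minimal sketch on a binary mask; the input is modified in place too):
#
#   mask = (np.random.rand(64, 64) > 0.9).astype(np.uint8)
#   cleaned = remove_small_objects(mask, min_size=30)
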
def create_box_mask(boxes, imgs):
    # rasterize per-image boxes [(x1, y1, x2, y2), ...] into binary masks
    b, _, w, h = imgs.shape
    box_mask = torch.zeros((b, w, h))
    for k in range(b):
        k_box = boxes[k]
        for box in k_box:
            x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
            box_mask[k, y1:y2, x1:x2] = 1
    return box_mask

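# Example (a minimal sketch; one box per image):
#
#   imgs = torch.zeros(2, 3, 64, 64)
#   boxes = [[(4, 4, 20, 20)], [(10, 10, 30, 30)]]
#   m = create_box_mask(boxes, imgs)   # [2, 64, 64] with 1s inside each box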