--- a/utils.py
+++ b/utils.py
@@ -0,0 +1,1240 @@
+"""Helper functions.
+
+author: junde
+"""
+
+import sys
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.optim.lr_scheduler import _LRScheduler
+import torchvision
+import torchvision.transforms as transforms
+import torch.optim as optim
+import torchvision.utils as vutils
+from torch.utils.data import DataLoader
+from torch.autograd import Variable
+from torch import autograd
+import random
+import math
+import PIL
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+import collections
+import logging
+import cv2
+import os
+import time
+from datetime import datetime
+
+import dateutil.tz
+
+from typing import Union, Optional, List, Tuple, Text, BinaryIO
+import pathlib
+import warnings
+import numpy as np
+from scipy.ndimage import label, find_objects
+from PIL import Image, ImageDraw, ImageFont, ImageColor
+# from lucent.optvis.param.spatial import pixel_image, fft_image, init_image
+# from lucent.optvis.param.color import to_valid_rgb
+# from lucent.optvis import objectives, transform, param
+# from lucent.misc.io import show
+from torchvision.models import vgg19
+import torch.nn.functional as F
+import cfg
+
+from collections import OrderedDict
+from tqdm import tqdm
+
+# from precpt import run_precpt
+from models.discriminator import Discriminator
+# from siren_pytorch import SirenNet, SirenWrapper
+
+import shutil
+import tempfile
+
+from monai.losses import DiceCELoss
+from monai.inferers import sliding_window_inference
+from monai.transforms import (
+    AsDiscrete,
+    Compose,
+    CropForegroundd,
+    LoadImaged,
+    Orientationd,
+    RandFlipd,
+    RandCropByPosNegLabeld,
+    RandShiftIntensityd,
+    ScaleIntensityRanged,
+    Spacingd,
+    RandRotate90d,
+    EnsureTyped,
+)
+
+from monai.config import print_config
+from monai.metrics import DiceMetric
+from monai.networks.nets import SwinUNETR
+
+from monai.data import (
+    ThreadDataLoader,
+    CacheDataset,
+    load_decathlon_datalist,
+    decollate_batch,
+    set_track_meta,
+)
+
+
+args = cfg.parse_args()
+device = torch.device('cuda', args.gpu_device)
+
+'''preparation of domain loss'''
+# cnn = vgg19(pretrained=True).features.to(device).eval()
+# cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
+# cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
+
+# netD = Discriminator(1).to(device)
+# netD.apply(init_D)
+# beta1 = 0.5
+# dis_lr = 0.0002
+# optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999))
+'''end'''
+
+def get_network(args, net, use_gpu=True, gpu_device = 0, distribution = True):
+    """Return the requested network, optionally wrapped in DataParallel."""
+
+    if net == 'sam':
+        from models.sam import SamPredictor, sam_model_registry
+        from models.sam.utils.transforms import ResizeLongestSide
+
+        net = sam_model_registry['vit_b'](args, checkpoint=args.sam_ckpt).to(device)
+    else:
+        print('the network name you have entered is not supported yet')
+        sys.exit()
+
+    if use_gpu:
+        # net = net.cuda(device = gpu_device)
+        if distribution != 'none':
+            net = torch.nn.DataParallel(net, device_ids=[int(id) for id in args.distributed.split(',')])
+            net = net.to(device=gpu_device)
+        else:
+            net = net.to(device=gpu_device)
+
+    return net
+
+
+def get_decath_loader(args):
+
+    train_transforms = Compose(
+        [
+            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
+            ScaleIntensityRanged(
keys=["image"], + a_min=-175, + a_max=250, + b_min=0.0, + b_max=1.0, + clip=True, + ), + CropForegroundd(keys=["image", "label"], source_key="image"), + Orientationd(keys=["image", "label"], axcodes="RAS"), + Spacingd( + keys=["image", "label"], + pixdim=(1.5, 1.5, 2.0), + mode=("bilinear", "nearest"), + ), + EnsureTyped(keys=["image", "label"], device=device, track_meta=False), + RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=(args.roi_size, args.roi_size, args.chunk), + pos=1, + neg=1, + num_samples=args.num_sample, + image_key="image", + image_threshold=0, + ), + RandFlipd( + keys=["image", "label"], + spatial_axis=[0], + prob=0.10, + ), + RandFlipd( + keys=["image", "label"], + spatial_axis=[1], + prob=0.10, + ), + RandFlipd( + keys=["image", "label"], + spatial_axis=[2], + prob=0.10, + ), + RandRotate90d( + keys=["image", "label"], + prob=0.10, + max_k=3, + ), + RandShiftIntensityd( + keys=["image"], + offsets=0.10, + prob=0.50, + ), + ] + ) + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"], ensure_channel_first=True), + ScaleIntensityRanged( + keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True + ), + CropForegroundd(keys=["image", "label"], source_key="image"), + Orientationd(keys=["image", "label"], axcodes="RAS"), + Spacingd( + keys=["image", "label"], + pixdim=(1.5, 1.5, 2.0), + mode=("bilinear", "nearest"), + ), + EnsureTyped(keys=["image", "label"], device=device, track_meta=True), + ] + ) + + + + data_dir = args.data_path + split_JSON = "dataset_0.json" + + datasets = os.path.join(data_dir, split_JSON) + datalist = load_decathlon_datalist(datasets, True, "training") + val_files = load_decathlon_datalist(datasets, True, "validation") + train_ds = CacheDataset( + data=datalist, + transform=train_transforms, + cache_num=24, + cache_rate=1.0, + num_workers=8, + ) + train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.b, shuffle=True) + val_ds = CacheDataset( + data=val_files, transform=val_transforms, cache_num=2, cache_rate=1.0, num_workers=0 + ) + val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1) + + set_track_meta(False) + + return train_loader, val_loader, train_transforms, val_transforms, datalist, val_files + + +def cka_loss(gram_featureA, gram_featureB): + + scaled_hsic = torch.dot(torch.flatten(gram_featureA),torch.flatten(gram_featureB)) + normalization_x = gram_featureA.norm() + normalization_y = gram_featureB.norm() + return scaled_hsic / (normalization_x * normalization_y) + + +class WarmUpLR(_LRScheduler): + """warmup_training learning rate scheduler + Args: + optimizer: optimzier(e.g. SGD) + total_iters: totoal_iters of warmup phase + """ + def __init__(self, optimizer, total_iters, last_epoch=-1): + + self.total_iters = total_iters + super().__init__(optimizer, last_epoch) + + def get_lr(self): + """we will use the first m batches, and set the learning + rate to base_lr * m / total_iters + """ + return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs] + +def gram_matrix(input): + a, b, c, d = input.size() # a=batch size(=1) + # b=number of feature maps + # (c,d)=dimensions of a f. map (N=c*d) + + features = input.view(a * b, c * d) # resise F_XL into \hat F_XL + + G = torch.mm(features, features.t()) # compute the gram product + + # we 'normalize' the values of the gram matrix + # by dividing by the number of element in each feature maps. 
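+    # i.e. G[i, j] = <feature map i, feature map j> summed over all spatial
+    # positions, scaled by 1 / (a * b * c * d)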
+    return G.div(a * b * c * d)
+
+
+@torch.no_grad()
+def make_grid(
+    tensor: Union[torch.Tensor, List[torch.Tensor]],
+    nrow: int = 8,
+    padding: int = 2,
+    normalize: bool = False,
+    value_range: Optional[Tuple[int, int]] = None,
+    scale_each: bool = False,
+    pad_value: int = 0,
+    **kwargs
+) -> torch.Tensor:
+    if not (torch.is_tensor(tensor) or
+            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+    if "range" in kwargs.keys():
+        warning = "range will be deprecated, please use value_range instead."
+        warnings.warn(warning)
+        value_range = kwargs["range"]
+
+    # if list of tensors, convert to a 4D mini-batch Tensor
+    if isinstance(tensor, list):
+        tensor = torch.stack(tensor, dim=0)
+
+    if tensor.dim() == 2:  # single image H x W
+        tensor = tensor.unsqueeze(0)
+    if tensor.dim() == 3:  # single image
+        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
+            tensor = torch.cat((tensor, tensor, tensor), 0)
+        tensor = tensor.unsqueeze(0)
+
+    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
+        tensor = torch.cat((tensor, tensor, tensor), 1)
+
+    if normalize is True:
+        tensor = tensor.clone()  # avoid modifying tensor in-place
+        if value_range is not None:
+            assert isinstance(value_range, tuple), \
+                "value_range has to be a tuple (min, max) if specified. min and max are numbers"
+
+        def norm_ip(img, low, high):
+            img.clamp_(min=low, max=high)
+            img.sub_(low).div_(max(high - low, 1e-5))
+
+        def norm_range(t, value_range):
+            if value_range is not None:
+                norm_ip(t, value_range[0], value_range[1])
+            else:
+                norm_ip(t, float(t.min()), float(t.max()))
+
+        if scale_each is True:
+            for t in tensor:  # loop over mini-batch dimension
+                norm_range(t, value_range)
+        else:
+            norm_range(tensor, value_range)
+
+    if tensor.size(0) == 1:
+        return tensor.squeeze(0)
+
+    # make the mini-batch of images into a grid
+    nmaps = tensor.size(0)
+    xmaps = min(nrow, nmaps)
+    ymaps = int(math.ceil(float(nmaps) / xmaps))
+    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
+    num_channels = tensor.size(1)
+    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
+    k = 0
+    for y in range(ymaps):
+        for x in range(xmaps):
+            if k >= nmaps:
+                break
+            # Tensor.copy_() is a valid method but seems to be missing from the stubs
+            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
+            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
+                2, x * width + padding, width - padding
+            ).copy_(tensor[k])
+            k = k + 1
+    return grid
+
+
+@torch.no_grad()
+def save_image(
+    tensor: Union[torch.Tensor, List[torch.Tensor]],
+    fp: Union[Text, pathlib.Path, BinaryIO],
+    format: Optional[str] = None,
+    **kwargs
+) -> None:
+    """
+    Save a given Tensor into an image file.
+    Args:
+        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
+            saves the tensor as a grid of images by calling ``make_grid``.
+        fp (string or file object): A filename or a file object
+        format (Optional): If omitted, the format to use is determined from the filename extension.
+            If a file object was used instead of a filename, this parameter should always be used.
+        **kwargs: Other arguments are documented in ``make_grid``.
+ """ + + grid = make_grid(tensor, **kwargs) + # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + im = Image.fromarray(ndarr) + im.save(fp, format=format) + + +def create_logger(log_dir, phase='train'): + time_str = time.strftime('%Y-%m-%d-%H-%M') + log_file = '{}_{}.log'.format(time_str, phase) + final_log_file = os.path.join(log_dir, log_file) + head = '%(asctime)-15s %(message)s' + logging.basicConfig(filename=str(final_log_file), + format=head) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + console = logging.StreamHandler() + logging.getLogger('').addHandler(console) + + return logger + + +def set_log_dir(root_dir, exp_name): + path_dict = {} + os.makedirs(root_dir, exist_ok=True) + + # set log path + exp_path = os.path.join(root_dir, exp_name) + now = datetime.now(dateutil.tz.tzlocal()) + timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') + prefix = exp_path + '_' + timestamp + os.makedirs(prefix) + path_dict['prefix'] = prefix + + # set checkpoint path + ckpt_path = os.path.join(prefix, 'Model') + os.makedirs(ckpt_path) + path_dict['ckpt_path'] = ckpt_path + + log_path = os.path.join(prefix, 'Log') + os.makedirs(log_path) + path_dict['log_path'] = log_path + + # set sample image path for fid calculation + sample_path = os.path.join(prefix, 'Samples') + os.makedirs(sample_path) + path_dict['sample_path'] = sample_path + + return path_dict + + +def save_checkpoint(states, is_best, output_dir, + filename='checkpoint.pth'): + torch.save(states, os.path.join(output_dir, filename)) + if is_best: + torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth')) + + +class RunningStats: + def __init__(self, WIN_SIZE): + self.mean = 0 + self.run_var = 0 + self.WIN_SIZE = WIN_SIZE + + self.window = collections.deque(maxlen=WIN_SIZE) + + def clear(self): + self.window.clear() + self.mean = 0 + self.run_var = 0 + + def is_full(self): + return len(self.window) == self.WIN_SIZE + + def push(self, x): + + if len(self.window) == self.WIN_SIZE: + # Adjusting variance + x_removed = self.window.popleft() + self.window.append(x) + old_m = self.mean + self.mean += (x - x_removed) / self.WIN_SIZE + self.run_var += (x + x_removed - old_m - self.mean) * (x - x_removed) + else: + # Calculating first variance + self.window.append(x) + delta = x - self.mean + self.mean += delta / len(self.window) + self.run_var += delta * (x - self.mean) + + def get_mean(self): + return self.mean if len(self.window) else 0.0 + + def get_var(self): + return self.run_var / len(self.window) if len(self.window) > 1 else 0.0 + + def get_std(self): + return math.sqrt(self.get_var()) + + def get_all(self): + return list(self.window) + + def __str__(self): + return "Current window values: {}".format(list(self.window)) + +def iou(outputs: np.array, labels: np.array): + + SMOOTH = 1e-6 + intersection = (outputs & labels).sum((1, 2)) + union = (outputs | labels).sum((1, 2)) + + iou = (intersection + SMOOTH) / (union + SMOOTH) + + + return iou.mean() + +class DiceCoeff(Function): + """Dice coeff for individual examples""" + + def forward(self, input, target): + self.save_for_backward(input, target) + eps = 0.0001 + self.inter = torch.dot(input.view(-1), target.view(-1)) + self.union = torch.sum(input) + torch.sum(target) + eps + + t = (2 * self.inter.float() + eps) / self.union.float() + return t + + # This function has only a single output, so it gets only one gradient + def backward(self, grad_output): + 
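+        # Gradient of the soft Dice t = (2 * inter + eps) / union w.r.t. the input:
+        # dt/dx_i = 2 * (target_i * union - inter) / union**2 (eps terms dropped),
+        # scaled by the incoming grad_output.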
+ input, target = self.saved_variables + grad_input = grad_target = None + + if self.needs_input_grad[0]: + grad_input = grad_output * 2 * (target * self.union - self.inter) \ + / (self.union * self.union) + if self.needs_input_grad[1]: + grad_target = None + + return grad_input, grad_target + + +def dice_coeff(input, target): + """Dice coeff for batches""" + if input.is_cuda: + s = torch.FloatTensor(1).to(device = input.device).zero_() + else: + s = torch.FloatTensor(1).zero_() + + for i, c in enumerate(zip(input, target)): + s = s + DiceCoeff().forward(c[0], c[1]) + + return s / (i + 1) + +'''parameter''' +def para_image(w, h=None, img = None, mode = 'multi', seg = None, sd=None, batch=None, + fft = False, channels=None, init = None): + h = h or w + batch = batch or 1 + ch = channels or 3 + shape = [batch, ch, h, w] + param_f = fft_image if fft else pixel_image + if init is not None: + param_f = init_image + params, maps_f = param_f(init) + else: + params, maps_f = param_f(shape, sd=sd) + if mode == 'multi': + output = to_valid_out(maps_f,img,seg) + elif mode == 'seg': + output = gene_out(maps_f,img) + elif mode == 'raw': + output = raw_out(maps_f,img) + return params, output + +def to_valid_out(maps_f,img,seg): #multi-rater + def inner(): + maps = maps_f() + maps = maps.to(device = img.device) + maps = torch.nn.Softmax(dim = 1)(maps) + final_seg = torch.multiply(seg,maps).sum(dim = 1, keepdim = True) + return torch.cat((img,final_seg),1) + # return torch.cat((img,maps),1) + return inner + +def gene_out(maps_f,img): #pure seg + def inner(): + maps = maps_f() + maps = maps.to(device = img.device) + # maps = torch.nn.Sigmoid()(maps) + return torch.cat((img,maps),1) + # return torch.cat((img,maps),1) + return inner + +def raw_out(maps_f,img): #raw + def inner(): + maps = maps_f() + maps = maps.to(device = img.device) + # maps = torch.nn.Sigmoid()(maps) + return maps + # return torch.cat((img,maps),1) + return inner + + +class CompositeActivation(torch.nn.Module): + + def forward(self, x): + x = torch.atan(x) + return torch.cat([x/0.67, (x*x)/0.6], 1) + # return x + + +def cppn(args, size, img = None, seg = None, batch=None, num_output_channels=1, num_hidden_channels=128, num_layers=8, + activation_fn=CompositeActivation, normalize=False, device = "cuda:0"): + + r = 3 ** 0.5 + + coord_range = torch.linspace(-r, r, size) + x = coord_range.view(-1, 1).repeat(1, coord_range.size(0)) + y = coord_range.view(1, -1).repeat(coord_range.size(0), 1) + + input_tensor = torch.stack([x, y], dim=0).unsqueeze(0).repeat(batch,1,1,1).to(device) + + layers = [] + kernel_size = 1 + for i in range(num_layers): + out_c = num_hidden_channels + in_c = out_c * 2 # * 2 for composite activation + if i == 0: + in_c = 2 + if i == num_layers - 1: + out_c = num_output_channels + layers.append(('conv{}'.format(i), torch.nn.Conv2d(in_c, out_c, kernel_size))) + if normalize: + layers.append(('norm{}'.format(i), torch.nn.InstanceNorm2d(out_c))) + if i < num_layers - 1: + layers.append(('actv{}'.format(i), activation_fn())) + else: + layers.append(('output', torch.nn.Sigmoid())) + + # Initialize model + net = torch.nn.Sequential(OrderedDict(layers)).to(device) + # Initialize weights + def weights_init(module): + if isinstance(module, torch.nn.Conv2d): + torch.nn.init.normal_(module.weight, 0, np.sqrt(1/module.in_channels)) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + net.apply(weights_init) + # Set last conv2d layer's weights to 0 + 
torch.nn.init.zeros_(dict(net.named_children())['conv{}'.format(num_layers - 1)].weight) + outimg = raw_out(lambda: net(input_tensor),img) if args.netype == 'raw' else to_valid_out(lambda: net(input_tensor),img,seg) + return net.parameters(), outimg + +def get_siren(args): + wrapper = get_network(args, 'siren', use_gpu=args.gpu, gpu_device=torch.device('cuda', args.gpu_device), distribution = args.distributed) + '''load init weights''' + checkpoint = torch.load('./logs/siren_train_init_2022_08_19_21_00_16/Model/checkpoint_best.pth') + wrapper.load_state_dict(checkpoint['state_dict'],strict=False) + '''end''' + + '''load prompt''' + checkpoint = torch.load('./logs/vae_standard_refuge1_2022_08_21_17_56_49/Model/checkpoint500') + vae = get_network(args, 'vae', use_gpu=args.gpu, gpu_device=torch.device('cuda', args.gpu_device), distribution = args.distributed) + vae.load_state_dict(checkpoint['state_dict'],strict=False) + '''end''' + + return wrapper, vae + + +def siren(args, wrapper, vae, img = None, seg = None, batch=None, num_output_channels=1, num_hidden_channels=128, num_layers=8, + activation_fn=CompositeActivation, normalize=False, device = "cuda:0"): + vae_img = torchvision.transforms.Resize(64)(img) + latent = vae.encoder(vae_img).view(-1).detach() + outimg = raw_out(lambda: wrapper(latent = latent),img) if args.netype == 'raw' else to_valid_out(lambda: wrapper(latent = latent),img,seg) + # img = torch.randn(1, 3, 256, 256) + # loss = wrapper(img) + # loss.backward() + + # # after much training ... + # # simply invoke the wrapper without passing in anything + + # pred_img = wrapper() # (1, 3, 256, 256) + return wrapper.parameters(), outimg + + +'''adversary''' +def render_vis( + args, + model, + objective_f, + real_img, + param_f=None, + optimizer=None, + transforms=None, + thresholds=(256,), + verbose=True, + preprocess=True, + progress=True, + show_image=True, + save_image=False, + image_name=None, + show_inline=False, + fixed_image_size=None, + label = 1, + raw_img = None, + prompt = None +): + if label == 1: + sign = 1 + elif label == 0: + sign = -1 + else: + print('label is wrong, label is',label) + if args.reverse: + sign = -sign + if args.multilayer: + sign = 1 + + '''prepare''' + now = datetime.now() + date_time = now.strftime("%m-%d-%Y, %H:%M:%S") + + netD, optD = pre_d() + '''end''' + + if param_f is None: + param_f = lambda: param.image(128) + # param_f is a function that should return two things + # params - parameters to update, which we pass to the optimizer + # image_f - a function that returns an image as a tensor + params, image_f = param_f() + + if optimizer is None: + optimizer = lambda params: torch.optim.Adam(params, lr=5e-1) + optimizer = optimizer(params) + + if transforms is None: + transforms = [] + transforms = transforms.copy() + + # Upsample images smaller than 224 + image_shape = image_f().shape + + if fixed_image_size is not None: + new_size = fixed_image_size + elif image_shape[2] < 224 or image_shape[3] < 224: + new_size = 224 + else: + new_size = None + if new_size: + transforms.append( + torch.nn.Upsample(size=new_size, mode="bilinear", align_corners=True) + ) + + transform_f = transform.compose(transforms) + + hook = hook_model(model, image_f) + objective_f = objectives.as_objective(objective_f) + + if verbose: + model(transform_f(image_f())) + print("Initial loss of ad: {:.3f}".format(objective_f(hook))) + + images = [] + try: + for i in tqdm(range(1, max(thresholds) + 1), disable=(not progress)): + optimizer.zero_grad() + try: + 
model(transform_f(image_f())) + except RuntimeError as ex: + if i == 1: + # Only display the warning message + # on the first iteration, no need to do that + # every iteration + warnings.warn( + "Some layers could not be computed because the size of the " + "image is not big enough. It is fine, as long as the non" + "computed layers are not used in the objective function" + f"(exception details: '{ex}')" + ) + if args.disc: + '''dom loss part''' + # content_img = raw_img + # style_img = raw_img + # precpt_loss = run_precpt(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img, transform_f(image_f())) + for p in netD.parameters(): + p.requires_grad = True + for _ in range(args.drec): + netD.zero_grad() + real = real_img + fake = image_f() + # for _ in range(6): + # errD, D_x, D_G_z1 = update_d(args, netD, optD, real, fake) + + # label = torch.full((args.b,), 1., dtype=torch.float, device=device) + # label.fill_(1.) + # output = netD(fake).view(-1) + # errG = nn.BCELoss()(output, label) + # D_G_z2 = output.mean().item() + # dom_loss = err + one = torch.tensor(1, dtype=torch.float) + mone = one * -1 + one = one.cuda(args.gpu_device) + mone = mone.cuda(args.gpu_device) + + d_loss_real = netD(real) + d_loss_real = d_loss_real.mean() + d_loss_real.backward(mone) + + d_loss_fake = netD(fake) + d_loss_fake = d_loss_fake.mean() + d_loss_fake.backward(one) + + # Train with gradient penalty + gradient_penalty = calculate_gradient_penalty(netD, real.data, fake.data) + gradient_penalty.backward() + + + d_loss = d_loss_fake - d_loss_real + gradient_penalty + Wasserstein_D = d_loss_real - d_loss_fake + optD.step() + + # Generator update + for p in netD.parameters(): + p.requires_grad = False # to avoid computation + + fake_images = image_f() + g_loss = netD(fake_images) + g_loss = -g_loss.mean() + dom_loss = g_loss + g_cost = -g_loss + + if i% 5 == 0: + print(f' loss_fake: {d_loss_fake}, loss_real: {d_loss_real}') + print(f'Generator g_loss: {g_loss}') + '''end''' + + + + '''ssim loss''' + + '''end''' + + if args.disc: + loss = sign * objective_f(hook) + args.pw * dom_loss + # loss = args.pw * dom_loss + else: + loss = sign * objective_f(hook) + # loss = args.pw * dom_loss + + loss.backward() + + # #video the images + # if i % 5 == 0: + # print('1') + # image_name = image_name[0].split('\\')[-1].split('.')[0] + '_' + str(i) + '.png' + # img_path = os.path.join(args.path_helper['sample_path'], str(image_name)) + # export(image_f(), img_path) + # #end + # if i % 50 == 0: + # print('Loss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' + # % (errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) + + optimizer.step() + if i in thresholds: + image = tensor_to_img_array(image_f()) + # if verbose: + # print("Loss at step {}: {:.3f}".format(i, objective_f(hook))) + if save_image: + na = image_name[0].split('\\')[-1].split('.')[0] + '_' + str(i) + '.png' + na = date_time + na + outpath = args.quickcheck if args.quickcheck else args.path_helper['sample_path'] + img_path = os.path.join(outpath, str(na)) + export(image_f(), img_path) + + images.append(image) + except KeyboardInterrupt: + print("Interrupted optimization at step {:d}.".format(i)) + if verbose: + print("Loss at step {}: {:.3f}".format(i, objective_f(hook))) + images.append(tensor_to_img_array(image_f())) + + if save_image: + na = image_name[0].split('\\')[-1].split('.')[0] + '.png' + na = date_time + na + outpath = args.quickcheck if args.quickcheck else args.path_helper['sample_path'] + img_path = os.path.join(outpath, str(na)) 
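+        # export the final optimised image to the quick-check path or the run's sample folder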
+ export(image_f(), img_path) + if show_inline: + show(tensor_to_img_array(image_f())) + elif show_image: + view(image_f()) + return image_f() + + +def tensor_to_img_array(tensor): + image = tensor.cpu().detach().numpy() + image = np.transpose(image, [0, 2, 3, 1]) + return image + + +def view(tensor): + image = tensor_to_img_array(tensor) + assert len(image.shape) in [ + 3, + 4, + ], "Image should have 3 or 4 dimensions, invalid image shape {}".format(image.shape) + # Change dtype for PIL.Image + image = (image * 255).astype(np.uint8) + if len(image.shape) == 4: + image = np.concatenate(image, axis=1) + Image.fromarray(image).show() + + +def export(tensor, img_path=None): + # image_name = image_name or "image.jpg" + c = tensor.size(1) + # if c == 7: + # for i in range(c): + # w_map = tensor[:,i,:,:].unsqueeze(1) + # w_map = tensor_to_img_array(w_map).squeeze() + # w_map = (w_map * 255).astype(np.uint8) + # image_name = image_name[0].split('/')[-1].split('.')[0] + str(i)+ '.png' + # wheat = sns.heatmap(w_map,cmap='coolwarm') + # figure = wheat.get_figure() + # figure.savefig ('./fft_maps/weightheatmap/'+str(image_name), dpi=400) + # figure = 0 + # else: + if c == 3: + vutils.save_image(tensor, fp = img_path) + else: + image = tensor[:,0:3,:,:] + w_map = tensor[:,-1,:,:].unsqueeze(1) + image = tensor_to_img_array(image) + w_map = 1 - tensor_to_img_array(w_map).squeeze() + # w_map[w_map==1] = 0 + assert len(image.shape) in [ + 3, + 4, + ], "Image should have 3 or 4 dimensions, invalid image shape {}".format(image.shape) + # Change dtype for PIL.Image + image = (image * 255).astype(np.uint8) + w_map = (w_map * 255).astype(np.uint8) + + Image.fromarray(w_map,'L').save(img_path) + + +class ModuleHook: + def __init__(self, module): + self.hook = module.register_forward_hook(self.hook_fn) + self.module = None + self.features = None + + + def hook_fn(self, module, input, output): + self.module = module + self.features = output + + + def close(self): + self.hook.remove() + + +def hook_model(model, image_f): + features = OrderedDict() + # recursive hooking function + def hook_layers(net, prefix=[]): + if hasattr(net, "_modules"): + for name, layer in net._modules.items(): + if layer is None: + # e.g. GoogLeNet's aux1 and aux2 layers + continue + features["_".join(prefix + [name])] = ModuleHook(layer) + hook_layers(layer, prefix=prefix + [name]) + + hook_layers(model) + + def hook(layer): + if layer == "input": + out = image_f() + elif layer == "labels": + out = list(features.values())[-1].features + else: + assert layer in features, f"Invalid layer {layer}. Retrieve the list of layers with `lucent.modelzoo.util.get_model_layers(model)`." + out = features[layer].features + assert out is not None, "There are no saved feature maps. Make sure to put the model in eval mode, like so: `model.to(device).eval()`. See README for example." 
+ return out + + return hook + +def vis_image(imgs, pred_masks, gt_masks, save_path, reverse = False, points = None,thre=0.5): + + b,c,h,w = pred_masks.size() + dev = pred_masks.get_device() + row_num = min(b, 4) + + if torch.max(pred_masks) > 1 or torch.min(pred_masks) < 0: + pred_masks = torch.sigmoid(pred_masks) + + pred_masks = torch.tensor(pred_masks>thre) + + if reverse == True: + pred_masks = 1 - pred_masks + gt_masks = 1 - gt_masks + if c == 2: + pred_disc, pred_cup = pred_masks[:,0,:,:].unsqueeze(1).expand(b,3,h,w), pred_masks[:,1,:,:].unsqueeze(1).expand(b,3,h,w) + gt_disc, gt_cup = gt_masks[:,0,:,:].unsqueeze(1).expand(b,3,h,w), gt_masks[:,1,:,:].unsqueeze(1).expand(b,3,h,w) + tup = (imgs[:row_num,:,:,:],pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]) + # compose = torch.cat((imgs[:row_num,:,:,:],pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]),0) + compose = torch.cat((pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]),0) + vutils.save_image(compose, fp = save_path, nrow = row_num, padding = 10) + else: + imgs = torchvision.transforms.Resize((h,w))(imgs) + if imgs.size(1) == 1: + imgs = imgs[:,0,:,:].unsqueeze(1).expand(b,3,h,w) + pred_masks = pred_masks[:,0,:,:].unsqueeze(1).expand(b,3,h,w) + gt_masks = gt_masks[:,0,:,:].unsqueeze(1).expand(b,3,h,w) + if points != None: + for i in range(b): + if args.thd: + p = np.round(points.cpu()/args.roi_size * args.out_size).to(dtype = torch.int) + else: + p = np.round(points.cpu()/args.image_size * args.out_size).to(dtype = torch.int) + # gt_masks[i,:,points[i,0]-5:points[i,0]+5,points[i,1]-5:points[i,1]+5] = torch.Tensor([255, 0, 0]).to(dtype = torch.float32, device = torch.device('cuda:' + str(dev))) + for pmt_id in range(p.shape[1]): + gt_masks[i,0,p[i,pmt_id,0]-3:p[i,pmt_id,0]+3,p[i,pmt_id,1]-3:p[i,pmt_id,1]+3] = 255 + gt_masks[i,1,p[i,pmt_id,0]-3:p[i,pmt_id,0]+3,p[i,pmt_id,1]-3:p[i,pmt_id,1]+3] = 0 + gt_masks[i,2,p[i,pmt_id,0]-3:p[i,pmt_id,0]+3,p[i,pmt_id,1]-3:p[i,pmt_id,1]+3] = 0 + tup = (imgs[:row_num,:,:,:],pred_masks[:row_num,:,:,:], gt_masks[:row_num,:,:,:]) + # compose = torch.cat((imgs[:row_num,:,:,:],pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]),0) + compose = torch.cat(tup,0) + vutils.save_image(compose, fp = save_path, nrow = row_num, padding = 10) + + return + +def eval_seg(pred,true_mask_p,threshold): + ''' + threshold: a int or a tuple of int + masks: [b,2,h,w] + pred: [b,2,h,w] + ''' + b, c, h, w = pred.size() + if c == 2: + iou_d, iou_c, disc_dice, cup_dice = 0,0,0,0 + for th in threshold: + + gt_vmask_p = (true_mask_p > th).float() + vpred = (pred > th).float() + vpred_cpu = vpred.cpu() + disc_pred = vpred_cpu[:,0,:,:].numpy().astype('int32') + cup_pred = vpred_cpu[:,1,:,:].numpy().astype('int32') + + disc_mask = gt_vmask_p [:,0,:,:].squeeze(1).cpu().numpy().astype('int32') + cup_mask = gt_vmask_p [:, 1, :, :].squeeze(1).cpu().numpy().astype('int32') + + '''iou for numpy''' + iou_d += iou(disc_pred,disc_mask) + iou_c += iou(cup_pred,cup_mask) + + '''dice for torch''' + disc_dice += dice_coeff(vpred[:,0,:,:], gt_vmask_p[:,0,:,:]).item() + cup_dice += dice_coeff(vpred[:,1,:,:], gt_vmask_p[:,1,:,:]).item() + + return iou_d / len(threshold), iou_c / len(threshold), disc_dice / len(threshold), cup_dice / len(threshold) + else: + eiou, edice = 0,0 + for th in threshold: + + gt_vmask_p = (true_mask_p > th).float() 
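+            # binarise the prediction at the same threshold before scoring IoU / Dice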
+ vpred = (pred > th).float() + vpred_cpu = vpred.cpu() + disc_pred = vpred_cpu[:,0,:,:].numpy().astype('int32') + + disc_mask = gt_vmask_p [:,0,:,:].squeeze(1).cpu().numpy().astype('int32') + + '''iou for numpy''' + eiou += iou(disc_pred,disc_mask) + + '''dice for torch''' + edice += dice_coeff(vpred[:,0,:,:], gt_vmask_p[:,0,:,:]).item() + + return eiou / len(threshold), edice / len(threshold) + +# @objectives.wrap_objective() +def dot_compare(layer, batch=1, cossim_pow=0): + def inner(T): + dot = (T(layer)[batch] * T(layer)[0]).sum() + mag = torch.sqrt(torch.sum(T(layer)[0]**2)) + cossim = dot/(1e-6 + mag) + return -dot * cossim ** cossim_pow + return inner + +def init_D(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + +def pre_d(): + netD = Discriminator(3).to(device) + # netD.apply(init_D) + beta1 = 0.5 + dis_lr = 0.00002 + optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999)) + return netD, optimizerD + +def update_d(args, netD, optimizerD, real, fake): + criterion = nn.BCELoss() + + label = torch.full((args.b,), 1., dtype=torch.float, device=device) + output = netD(real).view(-1) + # Calculate loss on all-real batch + errD_real = criterion(output, label) + # Calculate gradients for D in backward pass + errD_real.backward() + D_x = output.mean().item() + + label.fill_(0.) + # Classify all fake batch with D + output = netD(fake.detach()).view(-1) + # Calculate D's loss on the all-fake batch + errD_fake = criterion(output, label) + # Calculate the gradients for this batch, accumulated (summed) with previous gradients + errD_fake.backward() + D_G_z1 = output.mean().item() + # Compute error of D as sum over the fake and the real batches + errD = errD_real + errD_fake + # Update D + optimizerD.step() + + return errD, D_x, D_G_z1 + +def calculate_gradient_penalty(netD, real_images, fake_images): + eta = torch.FloatTensor(args.b,1,1,1).uniform_(0,1) + eta = eta.expand(args.b, real_images.size(1), real_images.size(2), real_images.size(3)).to(device = device) + + interpolated = (eta * real_images + ((1 - eta) * fake_images)).to(device = device) + + # define it to calculate gradient + interpolated = Variable(interpolated, requires_grad=True) + + # calculate probability of interpolated examples + prob_interpolated = netD(interpolated) + + # calculate gradients of probabilities with respect to examples + gradients = autograd.grad(outputs=prob_interpolated, inputs=interpolated, + grad_outputs=torch.ones( + prob_interpolated.size()).to(device = device), + create_graph=True, retain_graph=True)[0] + + grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10 + return grad_penalty + + +def random_click(mask, point_labels = 1, inout = 1): + indices = np.argwhere(mask == inout) + return indices[np.random.randint(len(indices))] + + +def generate_click_prompt(img, msk, pt_label = 1): + # return: prompt, prompt mask + pt_list = [] + msk_list = [] + b, c, h, w, d = msk.size() + msk = msk[:,0,:,:,:] + for i in range(d): + pt_list_s = [] + msk_list_s = [] + for j in range(b): + msk_s = msk[j,:,:,i] + indices = torch.nonzero(msk_s) + if indices.size(0) == 0: + # generate a random array between [0-h, 0-h]: + random_index = torch.randint(0, h, (2,)).to(device = msk.device) + new_s = msk_s + else: + random_index = random.choice(indices) + label = msk_s[random_index[0], random_index[1]] + new_s = 
torch.zeros_like(msk_s)
+                # binary mask of the pixels sharing the clicked label, as float
+                new_s = (msk_s == label).to(dtype = torch.float)
+                # new_s[msk_s == label] = 1
+            pt_list_s.append(random_index)
+            msk_list_s.append(new_s)
+        pts = torch.stack(pt_list_s, dim=0)
+        msks = torch.stack(msk_list_s, dim=0)
+        pt_list.append(pts)
+        msk_list.append(msks)
+    pt = torch.stack(pt_list, dim=-1)
+    msk = torch.stack(msk_list, dim=-1)
+
+    msk = msk.unsqueeze(1)
+
+    return img, pt, msk  # pt: [b, 2, d], msk: [b, 1, h, w, d]
+
+
+def drawContour(m, s, RGB, size, a=0.8):
+    """Draw the edges of segmentation 's' onto image 'm' in colour 'RGB' with line width 'size'."""
+    # ratio = int(255/np.max(s))
+    # s = np.uint(s*ratio)
+
+    # Find the edges of the segmentation and turn them into OpenCV contours
+    contours, _ = cv2.findContours(np.uint8(s), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+    m_old = m.copy()
+    # Paint the found edges in colour "RGB" onto "m", then blend with the original
+    cv2.drawContours(m, contours, -1, RGB, size)
+    m = cv2.addWeighted(np.uint8(m), a, np.uint8(m_old), 1 - a, 0)
+    return m
+
+def IOU(pm, gt):
+    a = np.sum(np.bitwise_and(pm, gt))
+    b = np.sum(pm) + np.sum(gt) - a + 1e-8
+    return a / b
+
+
+def inverse_normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
+    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
+    if mean.ndim == 1:
+        mean = mean.view(-1, 1, 1)
+    if std.ndim == 1:
+        std = std.view(-1, 1, 1)
+    tensor.mul_(std).add_(mean)
+    return tensor
+
+
+def remove_small_objects(array_2d, min_size=30):
+    """
+    Removes small connected components from a 2D array using SciPy's labelling.
+
+    :param array_2d: Input 2D array.
+    :param min_size: Minimum size (in pixels) of objects to keep.
+    :return: 2D array with small objects removed (modified in place).
+    """
+    # Label connected components (8-connectivity)
+    structure = np.ones((3, 3), dtype=int)
+    labeled, ncomponents = label(array_2d, structure)
+
+    # Zero out every component smaller than min_size
+    for i in range(1, ncomponents + 1):
+        locations = np.where(labeled == i)
+        if len(locations[0]) < min_size:
+            array_2d[locations] = 0
+
+    return array_2d
+
+def create_box_mask(boxes, imgs):
+    b, _, h, w = imgs.shape
+    box_mask = torch.zeros((b, h, w))
+    for k in range(b):
+        k_box = boxes[k]
+        for box in k_box:
+            x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
+            box_mask[k, y1:y2, x1:x2] = 1
+    return box_mask
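+
+
+def _demo_mask_utils():
+    """Illustrative sketch: how `random_click`, `remove_small_objects` and
+    `create_box_mask` can be combined on dummy data.
+
+    Not used by the training code; shapes and values are arbitrary choices made
+    only for demonstration.
+    """
+    dummy_mask = np.zeros((64, 64), dtype=np.int64)
+    dummy_mask[20:40, 20:40] = 1   # one square foreground object (400 px)
+    dummy_mask[2:4, 2:4] = 1       # a 4-px speckle that should be removed
+
+    # pick a random (row, col) click on a foreground pixel
+    click = random_click(dummy_mask, point_labels=1, inout=1)
+
+    # drop connected components smaller than 30 px (the speckle)
+    cleaned = remove_small_objects(dummy_mask.copy(), min_size=30)
+
+    # rasterise one (x1, y1, x2, y2) box per image into a binary mask
+    imgs = torch.zeros(1, 3, 64, 64)
+    boxes = [[(20, 20, 40, 40)]]
+    box_mask = create_box_mask(boxes, imgs)   # shape [1, 64, 64]
+
+    return click, cleaned, box_mask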