import os
import sys
import argparse
from datetime import datetime
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, accuracy_score,confusion_matrix
import torchvision
import torchvision.transforms as transforms
from skimage import io
from torch.utils.data import DataLoader
#from dataset import *
from torch.autograd import Variable
from PIL import Image
from tensorboardX import SummaryWriter
#from models.discriminatorlayer import discriminator
from conf import settings
import time
import cfg
from conf import settings
from tqdm import tqdm
from utils import *
import torch.nn.functional as F
import torch
from einops import rearrange
import pytorch_ssim
import models.sam.utils.transforms as samtrans
# from lucent.modelzoo.util import get_model_layers
# from lucent.optvis import render, param, transform, objectives
# from lucent.modelzoo import inceptionv1
import shutil
import tempfile
import matplotlib.pyplot as plt
from tqdm import tqdm
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.transforms import (
    AsDiscrete,
)
import torch
# Command-line options are parsed once, at import time; the globals below
# depend on them.
args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)
# BCE-with-logits loss that weights positive targets by 2; train_sam and
# validation_sam fall back to it when args.thd is false. The weight tensor is
# created directly on the GPU instead of being built on the CPU and copied.
pos_weight = torch.full((1,), 2.0, device=GPUdevice)
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# Random integers in [1, 10] with shape (args.b, 7).
# NOTE(review): not read by the functions in this chunk — confirm it is used.
seed = torch.randint(1, 11, (args.b, 7))

torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
# 14-channel one-hot post-processing for labels / argmax'd predictions.
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
# DiceMetric must be imported explicitly (monai.metrics) rather than relying
# on `from utils import *` to re-export it.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

# Best-score / history trackers (not updated by the functions in this chunk).
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
def train_sam(args, net: nn.Module, optimizer, train_loader,
              epoch, writer, schedulers=None, vis=50):
    """Train ``net`` for one epoch over ``train_loader``.

    Every parameter of ``net.image_encoder`` whose name does not contain
    ``"Adapter"`` is frozen (``requires_grad = False``). The prompt encoder
    runs under ``torch.no_grad()``; the image encoder and the mask decoder
    run with autograd enabled so the loss can back-propagate into the
    trainable parameters.

    Args:
        args: parsed options. Reads ``gpu_device``, ``thd``, ``image_size``,
            ``out_size`` and ``path_helper['sample_path']``.
        net: model exposing ``image_encoder``, ``prompt_encoder`` and
            ``mask_decoder``.
        optimizer: stepped and zeroed once per batch.
        train_loader: yields dicts with keys ``'image'``, ``'label'``,
            ``'image_meta_dict'`` and, optionally, ``'pt'`` / ``'p_label'``.
        epoch: epoch index, shown in the progress bar and in sample filenames.
        writer: not used by this function; kept for interface compatibility.
        schedulers: not used by this function; kept for interface
            compatibility.
        vis: save a sample visualisation every ``vis`` batches; a falsy value
            disables visualisation.

    Returns:
        The loss tensor of the last processed batch, or ``None`` when
        ``train_loader`` yields no batches.
    """
    net.train()
    optimizer.zero_grad()
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    # Freeze the image encoder except its adapter layers. Nothing below
    # re-enables these flags, so doing it once (not per batch) is enough.
    for n, value in net.image_encoder.named_parameters():
        if "Adapter" not in n:
            value.requires_grad = False

    loss = None
    ind = 0
    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
        for pack in train_loader:
            imgs = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masks = pack['label'].to(dtype=torch.float32, device=GPUdevice)

            if 'pt' not in pack:
                imgs, pt, masks = generate_click_prompt(imgs, masks)
                # NOTE(review): no 'p_label' is available here; outside 3-D
                # mode `point_labels` stays unset (or stale from a previous
                # batch) when it is checked below. Confirm this path is only
                # used with args.thd.
            else:
                pt = pack['pt']
                point_labels = pack['p_label']
            name = pack['image_meta_dict']['filename_or_obj']

            if args.thd:
                # Fold the trailing depth axis into the batch so each slice is
                # processed as its own 2-D sample, replicated to 3 channels.
                pt = rearrange(pt, 'b n d -> (b d) n')
                imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
                imgs = imgs.repeat(1, 3, 1, 1)
                point_labels = torch.ones(imgs.size(0))
                imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

            showp = pt  # raw points, kept for visualisation
            ind += 1

            if point_labels[0] != -1:
                coords_torch = torch.as_tensor(pt, dtype=torch.float, device=GPUdevice)
                labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                # Prepend a leading axis to both coordinates and labels.
                pt = (coords_torch[None, :, :], labels_torch[None, :])

            imgs = imgs.to(dtype=torch.float32, device=GPUdevice)

            imge = net.image_encoder(imgs)
            with torch.no_grad():
                se, de = net.prompt_encoder(
                    points=pt,
                    boxes=None,
                    masks=None,
                )
            # The decoder must run with autograd enabled: under no_grad the
            # loss would carry no graph and backward() would fail.
            pred, _ = net.mask_decoder(
                image_embeddings=imge,
                image_pe=net.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=se,
                dense_prompt_embeddings=de,
                multimask_output=False,
            )
            loss = lossfunc(pred, masks)
            pbar.set_postfix(**{'loss (batch)': loss.item()})
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if vis and ind % vis == 0:
                namecat = 'Train' + ''.join(
                    na.split('/')[-1].split('.')[0] + '+' for na in name
                )
                vis_image(
                    imgs, pred, masks,
                    os.path.join(args.path_helper['sample_path'],
                                 namecat + 'epoch+' + str(epoch) + '.jpg'),
                    reverse=False, points=showp,
                )
            pbar.update()
    return loss
def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
    """Evaluate ``net`` on ``val_loader`` with gradient tracking disabled.

    When ``args.evl_chunk`` is set, every batch is evaluated in consecutive,
    non-overlapping windows of that many positions along the last tensor axis
    (trailing positions that do not fill a whole window are skipped);
    otherwise each batch is evaluated as one window spanning that axis.

    Args:
        args: parsed options. Reads ``gpu_device``, ``thd``, ``evl_chunk``,
            ``image_size``, ``out_size``, ``vis`` and
            ``path_helper['sample_path']``.
        val_loader: yields dicts with keys ``'image'``, ``'label'``,
            ``'image_meta_dict'`` and, optionally, ``'pt'`` / ``'p_label'``.
        epoch: epoch index, used in sample filenames.
        net: model exposing ``image_encoder``, ``prompt_encoder`` and
            ``mask_decoder``.
        clean_dir: not used by this function; kept for interface
            compatibility.

    Returns:
        ``(mean_loss, mean_metrics)``: the summed loss divided by the number
        of evaluated windows, and the element-wise mean of the tuples returned
        by ``eval_seg`` (truncated to at most four entries). If no window was
        evaluated, the loss is ``0.0`` and the metrics are four zeros.
    """
    net.eval()
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    tot = 0
    mix_res = (0, 0, 0, 0)
    n_windows = 0  # windows actually evaluated; the divisor for the means

    with tqdm(total=len(val_loader), desc='Validation round', unit='batch', leave=False) as pbar:
        for ind, pack in enumerate(val_loader):
            imgsw = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masksw = pack['label'].to(dtype=torch.float32, device=GPUdevice)

            if 'pt' not in pack:
                imgsw, ptw, masksw = generate_click_prompt(imgsw, masksw)
                # NOTE(review): no 'p_label' is available here; outside 3-D
                # mode `point_labels` stays unset (or stale) when checked
                # below. Confirm this path is only used with args.thd.
            else:
                ptw = pack['pt']
                point_labels = pack['p_label']
            name = pack['image_meta_dict']['filename_or_obj']

            depth = imgsw.size(-1)
            evl_ch = int(args.evl_chunk) if args.evl_chunk else int(depth)
            window_starts = range(0, depth - evl_ch + 1, evl_ch)

            for chunk_no, buoy in enumerate(window_starts, start=1):
                pt = ptw[:, :, buoy:buoy + evl_ch] if args.thd else ptw
                imgs = imgsw[..., buoy:buoy + evl_ch]
                masks = masksw[..., buoy:buoy + evl_ch]

                if args.thd:
                    # Fold the depth axis into the batch; replicate to 3 channels.
                    pt = rearrange(pt, 'b n d -> (b d) n')
                    imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                    masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
                    imgs = imgs.repeat(1, 3, 1, 1)
                    point_labels = torch.ones(imgs.size(0))
                    imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                    masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

                showp = pt  # raw points, kept for visualisation

                if point_labels[0] != -1:
                    coords_torch = torch.as_tensor(pt, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    # Prepend a leading axis to both coordinates and labels.
                    pt = (coords_torch[None, :, :], labels_torch[None, :])

                imgs = imgs.to(dtype=torch.float32, device=GPUdevice)

                with torch.no_grad():
                    imge = net.image_encoder(imgs)
                    se, de = net.prompt_encoder(
                        points=pt,
                        boxes=None,
                        masks=None,
                    )
                    pred, _ = net.mask_decoder(
                        image_embeddings=imge,
                        image_pe=net.prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=se,
                        dense_prompt_embeddings=de,
                        multimask_output=False,
                    )
                    tot += lossfunc(pred, masks)

                    # Visualisation index = batch index + 1-based window number
                    # (same schedule as before); skipped when args.vis is falsy.
                    if args.vis and (ind + chunk_no) % args.vis == 0:
                        namecat = 'Test' + ''.join(
                            na.split('/')[-1].split('.')[0] + '+' for na in name
                        )
                        vis_image(
                            imgs, pred, masks,
                            os.path.join(args.path_helper['sample_path'],
                                         namecat + 'epoch+' + str(epoch) + '.jpg'),
                            reverse=False, points=showp,
                        )

                    temp = eval_seg(pred, masks, threshold)
                    # zip truncates the running totals to len(temp).
                    mix_res = tuple(sum(a) for a in zip(mix_res, temp))
                n_windows += 1
            pbar.update()

    # Divide by the windows actually evaluated instead of extrapolating from
    # the last batch's depth: exact for mixed depths, and no NameError /
    # ZeroDivisionError when nothing was evaluated.
    n_eval = max(n_windows, 1)
    return tot / n_eval, tuple(a / n_eval for a in mix_res)