Diff of /function.py [000000] .. [dff9e0]

Switch to side-by-side view

--- a
+++ b/function.py
@@ -0,0 +1,294 @@
+
# Standard library
import argparse
import os
import shutil
import sys
import tempfile
import time
from collections import OrderedDict
from datetime import datetime

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pytorch_ssim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from einops import rearrange
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric  # required by the module-level DiceMetric(...) call
from monai.transforms import (
    AsDiscrete,
)
from PIL import Image
from skimage import io
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local
import cfg
import models.sam.utils.transforms as samtrans
from conf import settings
from utils import *
#from dataset import *
#from models.discriminatorlayer import discriminator

# from lucent.modelzoo.util import get_model_layers
# from lucent.optvis import render, param, transform, objectives
# from lucent.modelzoo import inceptionv1
+
+
# ---- Module-level experiment setup (executes at import time) ----
# NOTE(review): parsing CLI args and allocating CUDA tensors at import time
# makes this module unimportable without a GPU and CLI flags; kept as-is
# because the rest of the project relies on these module-level globals.
args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)
# Positive-class weight of 2 biases BCE toward the (rarer) foreground class.
pos_weight = torch.ones([1]).cuda(device=GPUdevice) * 2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# Random integers in [1, 10], one row of 7 per batch element — presumably
# click-prompt seeds; appears unused in this file — TODO confirm.
seed = torch.randint(1, 11, (args.b, 7))

# Fixed input sizes per step, so let cuDNN autotune kernels.
torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
# 14 classes: presumably the BTCV multi-organ label set — TODO confirm.
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
# Fix: DiceMetric was referenced without ever being imported, which raised
# NameError at import time. monai is already a dependency of this file.
from monai.metrics import DiceMetric
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
+
def train_sam(args, net: nn.Module, optimizer, train_loader,
          epoch, writer, schedulers=None, vis = 50):
    """Run one training epoch of a SAM-style model with adapter fine-tuning.

    Only parameters of the image encoder whose name contains "Adapter" are
    trained; the remaining encoder weights are frozen and the prompt encoder
    runs under ``no_grad``. The mask decoder is trained normally.

    Args:
        args: parsed CLI namespace; reads ``gpu_device``, ``thd``,
            ``image_size``, ``out_size`` and ``path_helper``.
        net: SAM-style model exposing ``image_encoder``, ``prompt_encoder``
            and ``mask_decoder``.
        optimizer: optimizer over the trainable parameters of ``net``.
        train_loader: yields dicts with ``'image'``, ``'label'``,
            ``'image_meta_dict'`` and optionally ``'pt'``/``'p_label'``.
        epoch: current epoch index, used in progress-bar and file names.
        writer: tensorboard writer (not used inside the loop; kept for
            interface compatibility).
        schedulers: unused; kept for interface compatibility.
        vis: save a visualisation every ``vis``-th batch (falsy disables).

    Returns:
        The loss tensor of the *last* batch — not the epoch average
        (``epoch_loss`` is accumulated but never returned); kept as-is
        for backward compatibility with existing callers.
    """
    hard = 0
    epoch_loss = 0
    ind = 0
    # train mode
    net.train()
    optimizer.zero_grad()

    GPUdevice = torch.device('cuda:' + str(args.gpu_device))

    # 3D ("thd") data uses Dice+CE; 2D uses the module-level weighted BCE.
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    # Freeze every image-encoder parameter except the adapters. Hoisted out
    # of the batch loop: the original re-ran this loop on every batch even
    # though the flags never change within an epoch.
    for n, value in net.image_encoder.named_parameters():
        if "Adapter" not in n:
            value.requires_grad = False

    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
        for pack in train_loader:
            imgs = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masks = pack['label'].to(dtype=torch.float32, device=GPUdevice)
            if 'pt' not in pack:
                imgs, pt, masks = generate_click_prompt(imgs, masks)
                # Fix: point_labels was undefined on this path when
                # args.thd is false, raising NameError at the
                # ``point_labels[0] != -1`` check below. Generated clicks
                # are foreground prompts, so label them all 1.
                point_labels = torch.ones(imgs.size(0))
            else:
                pt = pack['pt']
                point_labels = pack['p_label']
            name = pack['image_meta_dict']['filename_or_obj']

            if args.thd:
                # Fold the depth axis into the batch so each slice is an
                # independent 2D sample for SAM.
                pt = rearrange(pt, 'b n d -> (b d) n')
                imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                masks = rearrange(masks, 'b c h w d -> (b d) c h w ')

                # Replicate single-channel slices to SAM's 3-channel input.
                imgs = imgs.repeat(1, 3, 1, 1)
                point_labels = torch.ones(imgs.size(0))

                imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

            showp = pt  # keep raw coords for visualisation before wrapping

            mask_type = torch.float32
            ind += 1
            b_size, c, w, h = imgs.size()
            longsize = w if w >= h else h

            # -1 marks "no point prompt"; otherwise pack coords/labels into
            # the (coords, labels) tuple the prompt encoder expects.
            if point_labels[0] != -1:
                # point_coords = samtrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
                point_coords = pt
                coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
                pt = (coords_torch, labels_torch)

            '''init'''
            if hard:
                # Dead branch while hard == 0; true_mask_ave is never bound.
                true_mask_ave = (true_mask_ave > 0.5).float()
                #true_mask_ave = cons_tensor(true_mask_ave)
            imgs = imgs.to(dtype=mask_type, device=GPUdevice)

            '''Train'''
            imge = net.image_encoder(imgs)

            # Prompt encoder is fully frozen — no gradients needed.
            with torch.no_grad():
                se, de = net.prompt_encoder(
                    points=pt,
                    boxes=None,
                    masks=None,
                )
            pred, _ = net.mask_decoder(
                image_embeddings=imge,
                image_pe=net.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=se,
                dense_prompt_embeddings=de,
                multimask_output=False,
            )

            loss = lossfunc(pred, masks)

            pbar.set_postfix(**{'loss (batch)': loss.item()})
            epoch_loss += loss.item()
            loss.backward()

            # nn.utils.clip_grad_value_(net.parameters(), 0.1)
            optimizer.step()
            optimizer.zero_grad()

            '''vis images'''
            if vis:
                if ind % vis == 0:
                    namecat = 'Train'
                    for na in name:
                        namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
                    vis_image(imgs, pred, masks,
                              os.path.join(args.path_helper['sample_path'],
                                           namecat + 'epoch+' + str(epoch) + '.jpg'),
                              reverse=False, points=showp)

            pbar.update()

    return loss
+
def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
    """Evaluate the SAM-style model on the validation set.

    Volumes are processed in chunks of ``args.evl_chunk`` along the last
    (depth) axis; for 2D data the whole tensor is one chunk. Returns
    ``(total_loss / n_val, per-threshold segmentation metrics / n_val)``
    where the metrics tuple comes from ``eval_seg`` (defined in utils;
    presumably IoU/Dice per threshold — TODO confirm).

    Note: ``clean_dir`` is accepted but never used inside this function.
    """
     # eval mode
    net.eval()

    mask_type = torch.float32
    n_val = len(val_loader)  # the number of batch
    # Accumulators: ave_res/rater_res are initialised but never updated
    # in this function; only mix_res and tot are.
    ave_res, mix_res = (0,0,0,0), (0,0,0,0)
    rater_res = [(0,0,0,0) for _ in range(6)]
    tot = 0
    hard = 0
    # Binarisation thresholds passed to eval_seg.
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    device = GPUdevice

    # Same loss selection as train_sam: Dice+CE for 3D, weighted BCE for 2D.
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for ind, pack in enumerate(val_loader):
            imgsw = pack['image'].to(dtype = torch.float32, device = GPUdevice)
            masksw = pack['label'].to(dtype = torch.float32, device = GPUdevice)
            # for k,v in pack['image_meta_dict'].items():
            #     print(k)
            # NOTE(review): when 'pt' is absent and args.thd is false,
            # point_labels is never assigned before its use below — same
            # latent NameError as in train_sam; left untouched here.
            if 'pt' not in pack:
                imgsw, ptw, masksw = generate_click_prompt(imgsw, masksw)
            else:
                ptw = pack['pt']
                point_labels = pack['p_label']
            name = pack['image_meta_dict']['filename_or_obj']
            
            # buoy is the sliding start index along the last (depth) axis;
            # evl_ch is the chunk size (whole extent when no chunking).
            buoy = 0
            if args.evl_chunk:
                evl_ch = int(args.evl_chunk)
            else:
                evl_ch = int(imgsw.size(-1))

            # NOTE(review): a trailing remainder smaller than evl_ch is
            # silently dropped by this <= condition — presumably intended.
            while (buoy + evl_ch) <= imgsw.size(-1):
                if args.thd:
                    pt = ptw[:,:,buoy: buoy + evl_ch]
                else:
                    pt = ptw

                imgs = imgsw[...,buoy:buoy + evl_ch]
                masks = masksw[...,buoy:buoy + evl_ch]
                buoy += evl_ch

                if args.thd:
                    # Fold depth into the batch dim, replicate to 3 channels,
                    # and resize to the model's expected input/output sizes.
                    pt = rearrange(pt, 'b n d -> (b d) n')
                    imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                    masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
                    imgs = imgs.repeat(1,3,1,1)
                    point_labels = torch.ones(imgs.size(0))

                    imgs = torchvision.transforms.Resize((args.image_size,args.image_size))(imgs)
                    masks = torchvision.transforms.Resize((args.out_size,args.out_size))(masks)
                
                showp = pt  # raw coords kept for visualisation

                mask_type = torch.float32
                # NOTE(review): ind (the enumerate batch index) doubles as
                # the visualisation counter and is bumped per chunk here;
                # enumerate reassigns it each outer iteration.
                ind += 1
                b_size,c,w,h = imgs.size()
                longsize = w if w >=h else h

                # -1 marks "no point prompt"; otherwise build the
                # (coords, labels) tuple the prompt encoder expects.
                if point_labels[0] != -1:
                    # point_coords = samtrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
                    pt = (coords_torch, labels_torch)

                '''init'''
                if hard:
                    # Dead branch while hard == 0; true_mask_ave never bound.
                    true_mask_ave = (true_mask_ave > 0.5).float()
                    #true_mask_ave = cons_tensor(true_mask_ave)
                imgs = imgs.to(dtype = mask_type,device = GPUdevice)
                
                '''test'''
                # Full forward pass without gradients.
                with torch.no_grad():
                    imge= net.image_encoder(imgs)

                    se, de = net.prompt_encoder(
                        points=pt,
                        boxes=None,
                        masks=None,
                    )

                    pred, _ = net.mask_decoder(
                        image_embeddings=imge,
                        image_pe=net.prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=se,
                        dense_prompt_embeddings=de, 
                        multimask_output=False,
                    )
                
                    # tot accumulates loss tensors, so the return is a tensor.
                    tot += lossfunc(pred, masks)

                    '''vis images'''
                    if ind % args.vis == 0:
                        namecat = 'Test'
                        for na in name:
                            img_name = na.split('/')[-1].split('.')[0]
                            namecat = namecat + img_name + '+'
                        vis_image(imgs,pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
                    

                    # Element-wise accumulate the per-threshold metrics.
                    temp = eval_seg(pred, masks, threshold)
                    mix_res = tuple([sum(a) for a in zip(mix_res, temp)])

            pbar.update()

    # Rescale the divisor to count chunks instead of batches. NOTE(review):
    # uses the *last* batch's depth — assumes all volumes share it.
    if args.evl_chunk:
        n_val = n_val * (imgsw.size(-1) // evl_ch)

    return tot/ n_val , tuple([a/n_val for a in mix_res])