--- a +++ b/predict_funs.py @@ -0,0 +1,209 @@ +import sys +sys.path.append('../') + +#from segment_anything import SamPredictor, sam_model_registry +from models.sam import SamPredictor, sam_model_registry +from models.sam.utils.transforms import ResizeLongestSide +from models.sam.modeling.prompt_encoder import attention_fusion +import pandas as pd +from skimage.measure import label +#Scientific computing +import numpy as np +import os +#Pytorch packages +import torch +from torch import nn +import torch.optim as optim +import torchvision +from torchvision import datasets +#Visulization +import matplotlib.pyplot as plt +from torchvision import transforms +from PIL import Image +#Others +from torch.utils.data import DataLoader, Subset +from torch.autograd import Variable +import matplotlib.pyplot as plt +import copy +from dataset_bone import MRI_dataset +import torch.nn.functional as F +from torch.nn.functional import one_hot +from pathlib import Path +from tqdm import tqdm +from losses import DiceLoss +from dsc import dice_coeff +import cv2 +import torchio as tio +import slicerio +import pickle +import nrrd +import PIL +import monai +import cfg +from funcs import * +args = cfg.parse_args() +from monai.networks.nets import VNet + +def drawContour(m,s,RGB,size,a=0.8): + """Draw edges of contour 'c' from segmented image 's' onto 'm' in colour 'RGB'""" + # Fill contour "c" with white, make all else black + + #ratio = int(255/np.max(s)) + #s = np.uint(s*ratio) + + # Find edges of this contour and make into Numpy array + contours, _ = cv2.findContours(np.uint8(s),cv2.RETR_LIST,cv2.CHAIN_APPROX_SIMPLE) + m_old = m.copy() + # Paint locations of found edges in color "RGB" onto "main" + cv2.drawContours(m,contours,-1,RGB,size) + m = cv2.addWeighted(np.uint8(m), a, np.uint8(m_old), 1-a,0) + return m + +def IOU(pm, gt): + a = np.sum(np.bitwise_and(pm, gt)) + b = np.sum(pm) + np.sum(gt) - a +1e-8 + return a / b + + +def inverse_normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): + mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device) + std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device) + if mean.ndim == 1: + mean = mean.view(-1, 1, 1) + if std.ndim == 1: + std = std.view(-1, 1, 1) + tensor.mul_(std).add_(mean) + return tensor + + + +def remove_small_objects(array_2d, min_size=30): + """ + Removes small objects from a 2D array using only NumPy. + + :param array_2d: Input 2D array. + :param min_size: Minimum size of objects to keep. + :return: 2D array with small objects removed. + """ + # Label connected components + structure = np.ones((3, 3), dtype=int) # Define connectivity + labeled, ncomponents = label(array_2d, structure) + + # Iterate through labeled components and remove small ones + for i in range(1, ncomponents + 1): + locations = np.where(labeled == i) + if len(locations[0]) < min_size: + array_2d[locations] = 0 + + return array_2d + +def create_box_mask(boxes,imgs): + b,_,w,h = imgs.shape + box_mask = torch.zeros((b,w,h)) + for k in range(b): + k_box = boxes[k] + for box in k_box: + x1,y1,x2,y2 = int(box[0]),int(box[1]),int(box[2]),int(box[3]) + box_mask[k,y1:y2,x1:x2] = 1 + return box_mask + + + +# Calculate the percentile values +def torch_percentile(tensor, percentile): + k = 1 + round(.01 * float(percentile) * (tensor.numel() - 1)) + return tensor.reshape(-1).kthvalue(k).values.item() + +def pred_attention(image,vnet,slice_id,device): + class Normalize3D: + """Normalize a tensor to a specified mean and standard deviation.""" + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, x): + # Normalize x + return (x - self.mean) / self.std + def prob_rescale(prob, x_thres=0.05, y_thres=0.8,eps=1e-3): + grad_1 = y_thres / x_thres + grad_2 = (1 - y_thres) / (1 - x_thres) + + mask_eps = prob<=eps + mask_1 = (eps < prob) & (prob <= x_thres) + mask_2 = prob > x_thres + prob[mask_1] = prob[mask_1] * grad_1 + prob[mask_2] = (prob[mask_2] - x_thres) * grad_2 + y_thres + prob[mask_eps]=0 + return prob + + def view_attention_2d(mask_volume, axis=2,eps=0.1): + mask_eps = mask_volume<=eps + mask_volume[mask_eps]=0 + attention = np.sum(mask_volume, axis=axis) + return (attention) / (np.max(attention) +1e-8) + + norm_transform = transforms.Compose([ + Normalize3D(0.5, 0.5) + ]) + depth_image = image.shape[3] + resize = tio.Resize((64,64,64)) + image = resize(image) + image_tensor = image.data + image_tensor = torch.unsqueeze(image_tensor,0) + image_tensor = norm_transform(image_tensor).float().to(device) + with torch.set_grad_enabled(False): + pred_mask = vnet(image_tensor) + pred_mask = torch.sigmoid(pred_mask) + pred_mask = pred_mask.detach().cpu().numpy() + + # the slice id after rescale to 64*64*64 + slice_id_reshape = int(slice_id*64/depth_image) + slice_min = max(slice_id_reshape-8,0) + slice_max = min(slice_id_reshape+8,64) + return prob_rescale(view_attention_2d(np.squeeze(pred_mask[:,:,:,:,slice_min:slice_max]))) + + +def evaluate_1_volume_withattention(image_vol,model,device,slice_id=None,target_spacing=None,atten_map=None): + image_vol.data = image_vol.data / (image_vol.data.max()*1.0) + voxel_spacing = image_vol.spacing + if target_spacing and (voxel_spacing != target_spacing): + resample = tio.Resample(target_spacing,image_interpolation='nearest') + image_vol = resample(image_vol) + image_vol = image_vol.data[0] + slice_num = image_vol.shape[2] + if slice_id is not None: + if slice_id>slice_num: + slice_id = -1 + else: + slice_id = slice_num//2 + img_arr = image_vol[:,:,slice_id] + img_arr = np.array((img_arr-img_arr.min())/(img_arr.max()-img_arr.min()+0.00001)*255,dtype=np.uint8) + img_3c = np.tile(img_arr[:, :,None], [1, 1, 3]) + img = Image.fromarray(img_3c, 'RGB') + Pil_img = img.copy() + img = transforms.Resize((1024,1024))(img) + transform_img = transforms.Compose([ + transforms.ToTensor() + ]) + img = transform_img(img) + img = min_max_normalize(img) + if img.mean()<0.1: + img = monai.transforms.AdjustContrast(gamma=0.8)(img) + imgs = torch.unsqueeze(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img),0).to(device) + + with torch.no_grad(): + img_emb= model.image_encoder(imgs) + sparse_emb, dense_emb = model.prompt_encoder(points=None,boxes=None,masks=None) + if not atten_map is None: + # fuse the depth direction attention + img_emb = model.attention_fusion(img_emb,atten_map) + pred, _ = model.mask_decoder( + image_embeddings=img_emb, + image_pe=model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_emb, + dense_prompt_embeddings=dense_emb, + multimask_output=True, + ) + pred = pred[:,1,:,:] + ori_img = inverse_normalize(imgs.cpu()[0]) + return ori_img,pred,voxel_spacing,Pil_img,slice_id \ No newline at end of file