--- /dev/null
+++ b/predict_funs.py
@@ -0,0 +1,209 @@
+import sys
+sys.path.append('../')
+
+#from segment_anything import SamPredictor, sam_model_registry
+from models.sam import SamPredictor, sam_model_registry
+from models.sam.utils.transforms import ResizeLongestSide
+from models.sam.modeling.prompt_encoder import attention_fusion
+import pandas as pd
+from skimage.measure import label
+#Scientific computing 
+import numpy as np
+import os
+#Pytorch packages
+import torch
+from torch import nn
+import torch.optim as optim
+import torchvision
+from torchvision import datasets
+#Visualization
+import matplotlib.pyplot as plt
+from torchvision import transforms
+from PIL import Image
+#Others
+from torch.utils.data import DataLoader, Subset
+from torch.autograd import Variable
+import copy
+from dataset_bone import MRI_dataset
+import torch.nn.functional as F
+from torch.nn.functional import one_hot
+from pathlib import Path
+from tqdm import tqdm
+from losses import DiceLoss
+from dsc import dice_coeff
+import cv2
+import torchio as tio
+import slicerio
+import pickle
+import nrrd
+import PIL
+import monai
+import cfg
+from funcs import *
+from monai.networks.nets import VNet
+
+args = cfg.parse_args()
+
+def drawContour(m, s, RGB, size, a=0.8):
+    """Draw the contours of segmentation mask 's' onto image 'm' in colour 'RGB'
+    with line thickness 'size', alpha-blended onto 'm' with opacity 'a'."""
+    # Find the contours of the mask as point arrays
+    contours, _ = cv2.findContours(np.uint8(s), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+    m_old = m.copy()
+    # Paint the found contours in colour "RGB", then blend with the original image
+    cv2.drawContours(m, contours, -1, RGB, size)
+    m = cv2.addWeighted(np.uint8(m), a, np.uint8(m_old), 1 - a, 0)
+    return m
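+# Illustrative usage (assumes 'img' is an HxWx3 uint8 array and 'mask' a binary HxW array):
+#   overlay = drawContour(img.copy(), mask, RGB=(255, 0, 0), size=2, a=0.8)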
+
+def IOU(pm, gt):
+    """Intersection-over-union between two binary masks (eps avoids division by zero)."""
+    a = np.sum(np.bitwise_and(pm, gt))
+    b = np.sum(pm) + np.sum(gt) - a + 1e-8
+    return a / b
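+# Illustrative usage with two boolean masks of the same shape:
+#   score = IOU(pred_mask.astype(bool), gt_mask.astype(bool))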
+
+
+def inverse_normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+    """Undo channel-wise normalization in place (note: mutates and returns 'tensor')."""
+    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
+    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
+    if mean.ndim == 1:
+        mean = mean.view(-1, 1, 1)
+    if std.ndim == 1:
+        std = std.view(-1, 1, 1)
+    tensor.mul_(std).add_(mean)
+    return tensor
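+# Illustrative usage: clone first if the normalized tensor is still needed elsewhere
+#   img_vis = inverse_normalize(img_tensor.clone()).permute(1, 2, 0).numpy()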
+
+
+
+def remove_small_objects(array_2d, min_size=30):
+    """
+    Removes connected components smaller than 'min_size' pixels from a 2D array.
+
+    :param array_2d: Input 2D array (non-zero pixels are foreground); modified in place.
+    :param min_size: Minimum size (in pixels) of objects to keep.
+    :return: 2D array with small objects removed.
+    """
+    # Label connected components (connectivity=2 gives 8-connectivity in 2D)
+    labeled, ncomponents = label(array_2d > 0, return_num=True, connectivity=2)
+
+    # Zero out components smaller than the minimum size
+    for i in range(1, ncomponents + 1):
+        locations = np.where(labeled == i)
+        if len(locations[0]) < min_size:
+            array_2d[locations] = 0
+
+    return array_2d
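+# Illustrative usage on a thresholded prediction ('binary_pred' is a placeholder array):
+#   cleaned = remove_small_objects(binary_pred.copy(), min_size=30)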
+
+def create_box_mask(boxes, imgs):
+    """Build one binary (H, W) mask per image with 1s inside every (x1, y1, x2, y2) box."""
+    b, _, h, w = imgs.shape
+    box_mask = torch.zeros((b, h, w))
+    for k in range(b):
+        k_box = boxes[k]
+        for box in k_box:
+            x1,y1,x2,y2 = int(box[0]),int(box[1]),int(box[2]),int(box[3])
+            box_mask[k,y1:y2,x1:x2] = 1
+    return box_mask
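+# Illustrative usage with a (B, C, H, W) batch and one box per image:
+#   masks = create_box_mask([[(10, 20, 50, 60)]] * imgs.shape[0], imgs)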
+
+
+
+# Calculate the value at a given percentile (0-100) of a tensor via kthvalue
+def torch_percentile(tensor, percentile):
+    k = 1 + round(.01 * float(percentile) * (tensor.numel() - 1))
+    return tensor.reshape(-1).kthvalue(k).values.item()
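+# Illustrative usage: clip a tensor at its 1st and 99th percentiles
+#   lo, hi = torch_percentile(t, 1), torch_percentile(t, 99)
+#   t = t.clamp(lo, hi)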
+
+def pred_attention(image, vnet, slice_id, device):
+    """Predict a 2D depth-direction attention map around 'slice_id' with the 3D VNet."""
+    class Normalize3D:
+        """Normalize a tensor to a specified mean and standard deviation."""
+        def __init__(self, mean, std):
+            self.mean = mean
+            self.std = std
+
+        def __call__(self, x):
+            # Normalize x
+            return (x - self.mean) / self.std
+    def prob_rescale(prob, x_thres=0.05, y_thres=0.8,eps=1e-3):
+        grad_1 = y_thres / x_thres
+        grad_2 = (1 - y_thres) / (1 - x_thres)
+
+        mask_eps = prob<=eps
+        mask_1 =  (eps < prob) & (prob <= x_thres)
+        mask_2 = prob > x_thres
+        prob[mask_1] = prob[mask_1] * grad_1
+        prob[mask_2] = (prob[mask_2] - x_thres) * grad_2 + y_thres
+        prob[mask_eps]=0
+        return prob
+
+    def view_attention_2d(mask_volume, axis=2,eps=0.1):
+        mask_eps = mask_volume<=eps
+        mask_volume[mask_eps]=0
+        attention = np.sum(mask_volume, axis=axis)
+        return (attention) / (np.max(attention) +1e-8)
+    
+    norm_transform = transforms.Compose([
+        Normalize3D(0.5, 0.5)
+    ])
+    depth_image = image.shape[3]
+    resize = tio.Resize((64,64,64))
+    image = resize(image)
+    image_tensor = image.data
+    image_tensor = torch.unsqueeze(image_tensor,0)
+    image_tensor = norm_transform(image_tensor).float().to(device)
+    with torch.set_grad_enabled(False):
+        pred_mask = vnet(image_tensor)
+    pred_mask = torch.sigmoid(pred_mask)
+    pred_mask = pred_mask.detach().cpu().numpy()
+    
+    # the slice index after resampling the volume to 64 x 64 x 64
+    slice_id_reshape = int(slice_id*64/depth_image)
+    slice_min = max(slice_id_reshape-8,0)
+    slice_max = min(slice_id_reshape+8,64)
+    return prob_rescale(view_attention_2d(np.squeeze(pred_mask[:,:,:,:,slice_min:slice_max])))
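+# Illustrative usage (assumes 'vnet' is a trained VNet and 'vol' a torchio ScalarImage):
+#   atten_map = pred_attention(vol, vnet, slice_id=40, device=device)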
+
+
+def evaluate_1_volume_withattention(image_vol, model, device, slice_id=None, target_spacing=None, atten_map=None):
+    """Segment one slice of a torchio image volume with the SAM-based model, optionally
+    fusing a depth-direction attention map, and return the prediction plus metadata."""
+    image_vol.data = image_vol.data / (image_vol.data.max() * 1.0)
+    voxel_spacing = image_vol.spacing
+    if target_spacing and (voxel_spacing != target_spacing):
+        resample = tio.Resample(target_spacing,image_interpolation='nearest') 
+        image_vol = resample(image_vol)
+    image_vol = image_vol.data[0]
+    slice_num = image_vol.shape[2]
+    if slice_id is not None:
+        # Fall back to the last slice if the requested index is out of range
+        if slice_id >= slice_num:
+            slice_id = -1
+    else:
+        slice_id = slice_num//2
+    img_arr = image_vol[:,:,slice_id]
+    img_arr = np.array((img_arr-img_arr.min())/(img_arr.max()-img_arr.min()+0.00001)*255,dtype=np.uint8)
+    img_3c = np.tile(img_arr[:, :,None], [1, 1, 3])
+    img = Image.fromarray(img_3c, 'RGB')
+    Pil_img = img.copy()
+    img = transforms.Resize((1024,1024))(img)
+    transform_img = transforms.Compose([
+                 transforms.ToTensor()
+                     ])
+    img = transform_img(img)
+    img = min_max_normalize(img)
+    if img.mean()<0.1:
+        img = monai.transforms.AdjustContrast(gamma=0.8)(img)
+    imgs = torch.unsqueeze(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img),0).to(device)
+
+    with torch.no_grad():
+        img_emb= model.image_encoder(imgs)
+        sparse_emb, dense_emb = model.prompt_encoder(points=None,boxes=None,masks=None)
+        if atten_map is not None:
+            # fuse the depth-direction attention into the image embedding
+            img_emb = model.attention_fusion(img_emb,atten_map)
+        pred, _ = model.mask_decoder(
+                        image_embeddings=img_emb,
+                        image_pe=model.prompt_encoder.get_dense_pe(), 
+                        sparse_prompt_embeddings=sparse_emb,
+                        dense_prompt_embeddings=dense_emb, 
+                        multimask_output=True,
+                      )
+        pred = pred[:,1,:,:]
+    ori_img = inverse_normalize(imgs.cpu()[0])
+    return ori_img,pred,voxel_spacing,Pil_img,slice_id
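+# Illustrative end-to-end usage (file name, device, and model construction are placeholders;
+# see the repository's inference scripts for the exact model setup):
+#   vol = tio.ScalarImage('case_0001.nii.gz')
+#   atten = pred_attention(vol, vnet, slice_id=40, device=device)
+#   ori_img, pred, spacing, pil_img, sid = evaluate_1_volume_withattention(
+#       vol, model, device, slice_id=40, target_spacing=(1.0, 1.0, 1.0), atten_map=atten)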