Diff of /predict_funs.py [000000] .. [dff9e0]

Switch to unified view

a b/predict_funs.py
1
import sys
2
sys.path.append('../')
3
4
#from segment_anything import SamPredictor, sam_model_registry
5
from models.sam import SamPredictor, sam_model_registry
6
from models.sam.utils.transforms import ResizeLongestSide
7
from models.sam.modeling.prompt_encoder import attention_fusion
8
import pandas as pd
9
from skimage.measure import label
10
#Scientific computing 
11
import numpy as np
12
import os
13
#Pytorch packages
14
import torch
15
from torch import nn
16
import torch.optim as optim
17
import torchvision
18
from torchvision import datasets
19
#Visulization
20
import matplotlib.pyplot as plt
21
from torchvision import transforms
22
from PIL import Image
23
#Others
24
from torch.utils.data import DataLoader, Subset
25
from torch.autograd import Variable
26
import matplotlib.pyplot as plt
27
import copy
28
from dataset_bone import MRI_dataset
29
import torch.nn.functional as F
30
from torch.nn.functional import one_hot
31
from pathlib import Path
32
from tqdm import tqdm
33
from losses import DiceLoss
34
from dsc import dice_coeff
35
import cv2
36
import torchio as tio
37
import slicerio
38
import pickle
39
import nrrd
40
import PIL
41
import monai
42
import cfg
43
from funcs import *
44
args = cfg.parse_args()
45
from monai.networks.nets import VNet
46
47
def drawContour(m,s,RGB,size,a=0.8):
48
    """Draw edges of contour 'c' from segmented image 's' onto 'm' in colour 'RGB'"""
49
    # Fill contour "c" with white, make all else black
50
    
51
    #ratio = int(255/np.max(s))
52
    #s = np.uint(s*ratio)
53
54
    # Find edges of this contour and make into Numpy array
55
    contours, _ = cv2.findContours(np.uint8(s),cv2.RETR_LIST,cv2.CHAIN_APPROX_SIMPLE)
56
    m_old = m.copy()
57
    # Paint locations of found edges in color "RGB" onto "main"
58
    cv2.drawContours(m,contours,-1,RGB,size)
59
    m = cv2.addWeighted(np.uint8(m), a, np.uint8(m_old), 1-a,0)
60
    return m
61
62
def IOU(pm, gt):
63
    a = np.sum(np.bitwise_and(pm, gt))
64
    b = np.sum(pm) + np.sum(gt) - a +1e-8
65
    return a / b
66
67
68
def inverse_normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
69
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
70
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
71
    if mean.ndim == 1:
72
        mean = mean.view(-1, 1, 1)
73
    if std.ndim == 1:
74
        std = std.view(-1, 1, 1)
75
    tensor.mul_(std).add_(mean)
76
    return tensor
77
78
79
80
def remove_small_objects(array_2d, min_size=30):
81
    """
82
    Removes small objects from a 2D array using only NumPy.
83
84
    :param array_2d: Input 2D array.
85
    :param min_size: Minimum size of objects to keep.
86
    :return: 2D array with small objects removed.
87
    """
88
    # Label connected components
89
    structure = np.ones((3, 3), dtype=int)  # Define connectivity
90
    labeled, ncomponents = label(array_2d, structure)
91
92
    # Iterate through labeled components and remove small ones
93
    for i in range(1, ncomponents + 1):
94
        locations = np.where(labeled == i)
95
        if len(locations[0]) < min_size:
96
            array_2d[locations] = 0
97
98
    return array_2d
99
100
def create_box_mask(boxes,imgs):
101
    b,_,w,h = imgs.shape
102
    box_mask = torch.zeros((b,w,h))
103
    for k in range(b):
104
        k_box = boxes[k]
105
        for box in k_box:
106
            x1,y1,x2,y2 = int(box[0]),int(box[1]),int(box[2]),int(box[3])
107
            box_mask[k,y1:y2,x1:x2] = 1
108
    return box_mask
109
110
111
112
# Calculate the percentile values
113
def torch_percentile(tensor, percentile):
114
    k = 1 + round(.01 * float(percentile) * (tensor.numel() - 1))
115
    return tensor.reshape(-1).kthvalue(k).values.item()
116
117
def pred_attention(image,vnet,slice_id,device):
118
    class Normalize3D:
119
        """Normalize a tensor to a specified mean and standard deviation."""
120
        def __init__(self, mean, std):
121
            self.mean = mean
122
            self.std = std
123
124
        def __call__(self, x):
125
            # Normalize x
126
            return (x - self.mean) / self.std
127
    def prob_rescale(prob, x_thres=0.05, y_thres=0.8,eps=1e-3):
128
        grad_1 = y_thres / x_thres
129
        grad_2 = (1 - y_thres) / (1 - x_thres)
130
131
        mask_eps = prob<=eps
132
        mask_1 =  (eps < prob) & (prob <= x_thres)
133
        mask_2 = prob > x_thres
134
        prob[mask_1] = prob[mask_1] * grad_1
135
        prob[mask_2] = (prob[mask_2] - x_thres) * grad_2 + y_thres
136
        prob[mask_eps]=0
137
        return prob
138
139
    def view_attention_2d(mask_volume, axis=2,eps=0.1):
140
        mask_eps = mask_volume<=eps
141
        mask_volume[mask_eps]=0
142
        attention = np.sum(mask_volume, axis=axis)
143
        return (attention) / (np.max(attention) +1e-8)
144
    
145
    norm_transform = transforms.Compose([
146
        Normalize3D(0.5, 0.5)
147
    ])
148
    depth_image = image.shape[3]
149
    resize = tio.Resize((64,64,64))
150
    image = resize(image)
151
    image_tensor = image.data
152
    image_tensor = torch.unsqueeze(image_tensor,0)
153
    image_tensor = norm_transform(image_tensor).float().to(device)
154
    with torch.set_grad_enabled(False):
155
        pred_mask = vnet(image_tensor)
156
    pred_mask = torch.sigmoid(pred_mask)
157
    pred_mask = pred_mask.detach().cpu().numpy()
158
    
159
    # the slice id after rescale to 64*64*64
160
    slice_id_reshape = int(slice_id*64/depth_image)
161
    slice_min = max(slice_id_reshape-8,0)
162
    slice_max = min(slice_id_reshape+8,64)
163
    return prob_rescale(view_attention_2d(np.squeeze(pred_mask[:,:,:,:,slice_min:slice_max])))
164
    
165
        
166
def evaluate_1_volume_withattention(image_vol,model,device,slice_id=None,target_spacing=None,atten_map=None):
167
    image_vol.data = image_vol.data / (image_vol.data.max()*1.0)
168
    voxel_spacing = image_vol.spacing
169
    if target_spacing and (voxel_spacing != target_spacing):
170
        resample = tio.Resample(target_spacing,image_interpolation='nearest') 
171
        image_vol = resample(image_vol)
172
    image_vol = image_vol.data[0]
173
    slice_num = image_vol.shape[2]
174
    if slice_id is not None:
175
        if slice_id>slice_num:
176
            slice_id = -1
177
    else:
178
        slice_id = slice_num//2
179
    img_arr = image_vol[:,:,slice_id]
180
    img_arr = np.array((img_arr-img_arr.min())/(img_arr.max()-img_arr.min()+0.00001)*255,dtype=np.uint8)
181
    img_3c = np.tile(img_arr[:, :,None], [1, 1, 3])
182
    img = Image.fromarray(img_3c, 'RGB')
183
    Pil_img = img.copy()
184
    img = transforms.Resize((1024,1024))(img)
185
    transform_img = transforms.Compose([
186
                 transforms.ToTensor()
187
                     ])
188
    img = transform_img(img)
189
    img = min_max_normalize(img)
190
    if img.mean()<0.1:
191
        img = monai.transforms.AdjustContrast(gamma=0.8)(img)
192
    imgs = torch.unsqueeze(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img),0).to(device)
193
194
    with torch.no_grad():
195
        img_emb= model.image_encoder(imgs)
196
        sparse_emb, dense_emb = model.prompt_encoder(points=None,boxes=None,masks=None)
197
        if not atten_map is None:
198
            # fuse the depth direction attention
199
            img_emb = model.attention_fusion(img_emb,atten_map)
200
        pred, _ = model.mask_decoder(
201
                        image_embeddings=img_emb,
202
                        image_pe=model.prompt_encoder.get_dense_pe(), 
203
                        sparse_prompt_embeddings=sparse_emb,
204
                        dense_prompt_embeddings=dense_emb, 
205
                        multimask_output=True,
206
                      )
207
        pred = pred[:,1,:,:]
208
    ori_img = inverse_normalize(imgs.cpu()[0])
209
    return ori_img,pred,voxel_spacing,Pil_img,slice_id