# predict_funs.py
import sys
sys.path.append('../')

#from segment_anything import SamPredictor, sam_model_registry
from models.sam import SamPredictor, sam_model_registry
from models.sam.utils.transforms import ResizeLongestSide
from models.sam.modeling.prompt_encoder import attention_fusion

# Scientific computing
import numpy as np
import pandas as pd
from skimage.measure import label
import os

# PyTorch packages
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.functional import one_hot
from torch.utils.data import DataLoader, Subset
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms

# Visualization
import matplotlib.pyplot as plt
import PIL
from PIL import Image

# Others
import copy
import pickle
from pathlib import Path
from tqdm import tqdm
import cv2
import torchio as tio
import slicerio
import nrrd
import monai

# Project-local modules
from dataset_bone import MRI_dataset
from losses import DiceLoss
from dsc import dice_coeff
import cfg
from funcs import *  # provides min_max_normalize (used below)

from monai.networks.nets import VNet

# Parse command-line arguments at import time, as in the original script.
args = cfg.parse_args()

def drawContour(m, s, RGB, size, a=0.8):
    """Draw the contour edges of segmented image 's' onto image 'm' in colour 'RGB',
    blending the outline with weight 'a'."""
    #ratio = int(255/np.max(s))
    #s = np.uint(s*ratio)

    # Find the edges of the segmentation as a list of contour point arrays.
    contours, _ = cv2.findContours(np.uint8(s), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    m_old = m.copy()
    # Paint the found edges in colour 'RGB' onto 'm', then alpha-blend with the original.
    cv2.drawContours(m, contours, -1, RGB, size)
    m = cv2.addWeighted(np.uint8(m), a, np.uint8(m_old), 1 - a, 0)
    return m

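# Usage sketch (hypothetical arrays): overlay the outline of a binary mask in red
# on an HxWx3 uint8 image.
#   outlined = drawContour(rgb_slice.copy(), binary_mask, RGB=(255, 0, 0), size=2, a=0.8)
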
def IOU(pm, gt):
    """Intersection over union between two binary masks (integer or boolean arrays)."""
    a = np.sum(np.bitwise_and(pm, gt))
    b = np.sum(pm) + np.sum(gt) - a + 1e-8
    return a / b

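# Usage sketch (hypothetical masks): IOU between a thresholded prediction and ground truth.
#   iou = IOU(np.uint8(pred_mask > 0.5), np.uint8(gt_mask))
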
def inverse_normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo a channel-wise Normalize transform in place (defaults to ImageNet statistics)."""
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    tensor.mul_(std).add_(mean)
    return tensor

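# Usage sketch (hypothetical tensor): recover a displayable image from a normalized
# (3, H, W) tensor before plotting; clone first because the op is in place.
#   img_vis = inverse_normalize(norm_img.clone()).permute(1, 2, 0).numpy()
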
def remove_small_objects(array_2d, min_size=30):
    """
    Removes small connected components from a 2D binary array (in place).

    :param array_2d: Input 2D array.
    :param min_size: Minimum size, in pixels, of objects to keep.
    :return: 2D array with small objects removed.
    """
    # Label connected components with 8-connectivity (full 3x3 neighbourhood).
    labeled, ncomponents = label(array_2d, connectivity=2, return_num=True)

    # Zero out every component smaller than min_size.
    for i in range(1, ncomponents + 1):
        locations = np.where(labeled == i)
        if len(locations[0]) < min_size:
            array_2d[locations] = 0

    return array_2d

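# Usage sketch (hypothetical mask): drop connected blobs smaller than 50 pixels
# from a binarised prediction.
#   cleaned = remove_small_objects((prob_map > 0.5).astype(np.uint8), min_size=50)
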
def create_box_mask(boxes, imgs):
    """Rasterise per-image bounding boxes into binary masks.

    boxes: per image, an iterable of (x1, y1, x2, y2) boxes in pixel coordinates.
    imgs:  image batch of shape (B, C, H, W); only the shape is used.
    """
    b, _, h, w = imgs.shape
    box_mask = torch.zeros((b, h, w))
    for k in range(b):
        k_box = boxes[k]
        for box in k_box:
            x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
            box_mask[k, y1:y2, x1:x2] = 1
    return box_mask

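# Usage sketch (hypothetical batch): one box per image for a (2, 3, 1024, 1024) batch.
#   masks = create_box_mask([[(100, 120, 300, 340)], [(50, 60, 200, 220)]], imgs)
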
# Calculate percentile values directly on a tensor.
def torch_percentile(tensor, percentile):
    """Return the given percentile (0-100) of a tensor as a Python float."""
    k = 1 + round(.01 * float(percentile) * (tensor.numel() - 1))
    return tensor.reshape(-1).kthvalue(k).values.item()

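# Usage sketch (hypothetical tensor): 95th-percentile intensity, e.g. for windowing.
#   p95 = torch_percentile(image_tensor, 95)
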
def pred_attention(image, vnet, slice_id, device):
    """Predict a coarse 2D attention map for one slice by running a low-resolution
    3D segmentation network (VNet) over the whole volume and projecting the
    probabilities along the depth axis."""

    class Normalize3D:
        """Normalize a tensor to a specified mean and standard deviation."""
        def __init__(self, mean, std):
            self.mean = mean
            self.std = std

        def __call__(self, x):
            # Normalize x
            return (x - self.mean) / self.std

    def prob_rescale(prob, x_thres=0.05, y_thres=0.8, eps=1e-3):
        # Piecewise-linear rescaling: values <= eps go to 0, (eps, x_thres] is
        # stretched to (0, y_thres], and (x_thres, 1] is compressed into (y_thres, 1].
        grad_1 = y_thres / x_thres
        grad_2 = (1 - y_thres) / (1 - x_thres)

        mask_eps = prob <= eps
        mask_1 = (eps < prob) & (prob <= x_thres)
        mask_2 = prob > x_thres
        prob[mask_1] = prob[mask_1] * grad_1
        prob[mask_2] = (prob[mask_2] - x_thres) * grad_2 + y_thres
        prob[mask_eps] = 0
        return prob

    def view_attention_2d(mask_volume, axis=2, eps=0.1):
        # Suppress low probabilities, sum along `axis`, and normalise to [0, 1].
        mask_eps = mask_volume <= eps
        mask_volume[mask_eps] = 0
        attention = np.sum(mask_volume, axis=axis)
        return attention / (np.max(attention) + 1e-8)

    norm_transform = transforms.Compose([
        Normalize3D(0.5, 0.5)
    ])
    depth_image = image.shape[3]
    # Resample the whole volume to 64^3 before feeding the VNet.
    resize = tio.Resize((64, 64, 64))
    image = resize(image)
    image_tensor = image.data
    image_tensor = torch.unsqueeze(image_tensor, 0)
    image_tensor = norm_transform(image_tensor).float().to(device)
    with torch.set_grad_enabled(False):
        pred_mask = vnet(image_tensor)
        pred_mask = torch.sigmoid(pred_mask)
        pred_mask = pred_mask.detach().cpu().numpy()

    # Slice index after the depth axis has been rescaled to 64.
    slice_id_reshape = int(slice_id * 64 / depth_image)
    slice_min = max(slice_id_reshape - 8, 0)
    slice_max = min(slice_id_reshape + 8, 64)
    # Project a 16-slice window around the slice of interest into a 2D attention map.
    return prob_rescale(view_attention_2d(np.squeeze(pred_mask[:, :, :, :, slice_min:slice_max])))

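# Usage sketch (hypothetical objects): `vol` is a torchio ScalarImage with shape
# (1, W, H, D) and `vnet` a trained 3D VNet that outputs a single-channel logit volume.
#   atten_map = pred_attention(vol, vnet, slice_id=40, device='cuda')
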
def evaluate_1_volume_withattention(image_vol, model, device, slice_id=None, target_spacing=None, atten_map=None):
    """Segment one slice of a torchio volume with the SAM-style model, optionally fusing
    a depth-attention map into the image embedding, and return the prediction."""
    image_vol.data = image_vol.data / (image_vol.data.max() * 1.0)
    voxel_spacing = image_vol.spacing
    if target_spacing and (voxel_spacing != target_spacing):
        resample = tio.Resample(target_spacing, image_interpolation='nearest')
        image_vol = resample(image_vol)
    image_vol = image_vol.data[0]
    slice_num = image_vol.shape[2]
    if slice_id is not None:
        if slice_id > slice_num:
            slice_id = -1
    else:
        slice_id = slice_num // 2

    # Extract the slice, rescale intensities to [0, 255] and replicate to 3 channels.
    img_arr = image_vol[:, :, slice_id]
    img_arr = np.array((img_arr - img_arr.min()) / (img_arr.max() - img_arr.min() + 0.00001) * 255, dtype=np.uint8)
    img_3c = np.tile(img_arr[:, :, None], [1, 1, 3])
    img = Image.fromarray(img_3c, 'RGB')
    Pil_img = img.copy()
    img = transforms.Resize((1024, 1024))(img)
    transform_img = transforms.Compose([
        transforms.ToTensor()
    ])
    img = transform_img(img)
    img = min_max_normalize(img)
    # Brighten very dark slices before the ImageNet normalisation.
    if img.mean() < 0.1:
        img = monai.transforms.AdjustContrast(gamma=0.8)(img)
    imgs = torch.unsqueeze(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img), 0).to(device)

    with torch.no_grad():
        img_emb = model.image_encoder(imgs)
        sparse_emb, dense_emb = model.prompt_encoder(points=None, boxes=None, masks=None)
        if atten_map is not None:
            # Fuse the depth-direction attention into the image embedding.
            img_emb = model.attention_fusion(img_emb, atten_map)
        pred, _ = model.mask_decoder(
            image_embeddings=img_emb,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=True,
        )
    # Keep the second of the multi-mask outputs and un-normalise the input for display.
    pred = pred[:, 1, :, :]
    ori_img = inverse_normalize(imgs.cpu()[0])
    return ori_img, pred, voxel_spacing, Pil_img, slice_id
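
# Usage sketch (hypothetical objects): `sam` is a model built via sam_model_registry
# with the custom attention_fusion head, moved to `device`; `vol` is a tio.ScalarImage.
# The optional atten_map, when used, is expected to come from pred_attention above; its
# exact shape/type is defined by the project's attention_fusion module.
#   ori_img, pred, spacing, pil_img, sid = evaluate_1_volume_withattention(
#       vol, sam, device, slice_id=40, target_spacing=(1.0, 1.0, 1.0))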