Diff of /demo.ipynb [000000] .. [dff9e0]

b/demo.ipynb
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fbb3c1a2",
   "metadata": {},
   "source": [
    "## Package Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5678569",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "cef2a006-01b3-48ec-a631-ba22fcbec5a4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from models.sam import SamPredictor, sam_model_registry\n",
    "from models.sam.modeling.prompt_encoder import attention_fusion\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "from pathlib import Path\n",
    "from dsc import dice_coeff\n",
    "import torchio as tio\n",
    "import nrrd\n",
    "import PIL\n",
    "import cfg\n",
    "from funcs import *\n",
    "from predict_funs import *\n",
    "from monai.networks.nets import VNet\n",
    "\n",
    "args = cfg.parse_args()\n",
    "args.if_mask_decoder_adapter = True\n",
    "args.if_encoder_adapter = True\n",
    "args.decoder_adapt_depth = 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34c4f647",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a578226-354b-4833-bd17-3f57ff143ee9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(device)\n",
    "checkpoint_directory = './'  # path to your checkpoint\n",
    "img_folder = os.path.join('images')\n",
    "gt_msk_folder = os.path.join('masks')\n",
    "predicted_msk_folder = os.path.join('predicted_masks')\n",
    "cls = 1  # target class\n",
    "\n",
    "# Fine-tuned SAM (ViT-T backbone) with the attention-fusion prompt encoder\n",
    "sam_fine_tune = sam_model_registry[\"vit_t\"](args, checkpoint=os.path.join('mobile_sam.pt'), num_classes=2)\n",
    "sam_fine_tune.attention_fusion = attention_fusion()\n",
    "sam_fine_tune.load_state_dict(torch.load(os.path.join(checkpoint_directory, 'bone_sam.pth'), map_location=torch.device(device)), strict=True)\n",
    "sam_fine_tune = sam_fine_tune.to(device).eval()\n",
    "\n",
    "# VNet that predicts the depth attention maps\n",
    "vnet = VNet().to(device).eval()  # eval mode for inference\n",
    "model_directory = \"./\"\n",
    "vnet.load_state_dict(torch.load(os.path.join(model_directory, 'atten.pth'), map_location=torch.device(device)))"
   ]
  },
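  {
   "cell_type": "markdown",
   "id": "0d9c2e71",
   "metadata": {},
   "source": [
    "Optional sanity check (added for illustration): confirm both checkpoints loaded by printing parameter counts. It uses only the `sam_fine_tune` and `vnet` objects created above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b4f8d12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: report parameter counts of the two loaded models.\n",
    "for name, model in [('sam_fine_tune', sam_fine_tune), ('vnet', vnet)]:\n",
    "    n_params = sum(p.numel() for p in model.parameters())\n",
    "    print(f'{name}: {n_params / 1e6:.1f}M parameters')"
   ]
  },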
  {
   "cell_type": "markdown",
   "id": "755aea6f-151c-4872-81cb-a4ef66973a15",
   "metadata": {},
   "source": [
    "## 2D Slice Prediction & Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "7f410bb5-b33f-4d98-a1be-caf578cfa7b7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def evaluateSlicePrediction(mask_pred, mask_name, slice_id):\n",
    "    # Compare the predicted slice mask against the ground-truth NRRD slice\n",
    "    mask_gt, header = nrrd.read(os.path.join(gt_msk_folder, mask_name))\n",
    "\n",
    "    msk = Image.fromarray(mask_gt[:, :, slice_id].astype(np.uint8), 'L')\n",
    "    msk = transforms.Resize((256, 256))(msk)\n",
    "    msk_gt = (transforms.ToTensor()(msk) > 0).float()\n",
    "\n",
    "    dsc_gt = dice_coeff(mask_pred.cpu(), msk_gt).item()\n",
    "    print(\"dsc_gt:\", dsc_gt)\n",
    "    return msk_gt, dsc_gt\n",
    "\n",
    "def predictSlice(image_name, lower_percentile, upper_percentile, slice_id, attention_enabled):\n",
    "    image1_vol = tio.ScalarImage(os.path.join(img_folder, image_name))\n",
    "    print('vol shape: %s vol spacing %s' % (image1_vol.shape, image1_vol.spacing))\n",
    "\n",
    "    image_tensor = image1_vol.data\n",
    "    lower_bound = torch_percentile(image_tensor, lower_percentile)\n",
    "    upper_bound = torch_percentile(image_tensor, upper_percentile)\n",
    "\n",
    "    # Clip the data\n",
    "    image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n",
    "\n",
    "    # Normalize the data to [0, 1]\n",
    "    image_tensor = (image_tensor - lower_bound) / (upper_bound - lower_bound)\n",
    "    image1_vol.set_data(image_tensor)\n",
    "\n",
    "    # Depth attention map predicted by the VNet\n",
    "    atten_map = pred_attention(image1_vol, vnet, slice_id, device)\n",
    "    atten_map = torch.unsqueeze(torch.tensor(atten_map), 0).float().to(device)\n",
    "    print(atten_map.device)\n",
    "    if attention_enabled:\n",
    "        ori_img, pred_1, voxel_spacing1, Pil_img1, slice_id1 = evaluate_1_volume_withattention(image1_vol, sam_fine_tune, device, slice_id=slice_id, atten_map=atten_map)\n",
    "    else:\n",
    "        ori_img, pred_1, voxel_spacing1, Pil_img1, slice_id1 = evaluate_1_volume_withattention(image1_vol, sam_fine_tune, device, slice_id=slice_id)\n",
    "\n",
    "    mask_pred = ((pred_1 > 0) == cls).float().cpu()\n",
    "    return ori_img, mask_pred, atten_map\n",
    "\n",
    "def visualizeSlicePrediction(ori_img, image_name, atten_map, msk_gt, mask_pred, dsc_gt):\n",
    "    image = np.rot90(torchvision.transforms.Resize((args.out_size, args.out_size))(ori_img)[0])\n",
    "    image_3d = np.repeat(np.array(image * 255, dtype=np.uint8).copy()[:, :, np.newaxis], 3, axis=2)\n",
    "\n",
    "    pred_mask_auto = (mask_pred[0]) * 255\n",
    "    mask = (msk_gt.cpu()[0] > 0) * 255\n",
    "\n",
    "    # Overlay prediction (blue) and ground truth (green) on the image\n",
    "    target_prediction = [103, 169, 237]\n",
    "    image_pred_auto = drawContour(image_3d.copy(), np.rot90(pred_mask_auto), target_prediction, size=-1, a=0.6)\n",
    "\n",
    "    target_prediction = [100, 255, 106]\n",
    "    image_mask = drawContour(image_3d.copy(), np.rot90(mask), target_prediction, size=-1, a=0.6)\n",
    "\n",
    "    fig, a = plt.subplots(1, 4, figsize=(20, 15))\n",
    "\n",
    "    a[0].imshow(image, cmap='gray', vmin=0, vmax=1)\n",
    "    a[0].set_title(image_name)\n",
    "    a[0].axis(False)\n",
    "\n",
    "    a[1].imshow(image_mask, cmap='gray', vmin=0, vmax=255)\n",
    "    a[1].set_title('gt_mask', fontsize=10)\n",
    "    a[1].axis(False)\n",
    "\n",
    "    a[2].imshow(image_pred_auto, cmap='gray', vmin=0, vmax=255)\n",
    "    a[2].set_title('pre_mask_auto, dsc %.2f' % (dsc_gt), fontsize=10)\n",
    "    a[2].axis(False)\n",
    "\n",
    "    a[3].imshow(np.rot90(atten_map.cpu()[0]), vmin=0, vmax=1, cmap='coolwarm')\n",
    "    a[3].set_title('atten_map', fontsize=10)\n",
    "    a[3].axis(False)\n",
    "\n",
    "    plt.tight_layout()"
   ]
  },
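  {
   "cell_type": "markdown",
   "id": "7c1d5a90",
   "metadata": {},
   "source": [
    "`predictSlice` clips intensities to the chosen percentiles and rescales them to [0, 1]. `torch_percentile` is pulled in by the star import from `funcs`; the sketch below reproduces the same clip-and-rescale on synthetic data, with `torch.quantile` standing in for it (an assumed equivalent, not this repo's implementation)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e2a6b43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: the percentile clip-and-rescale used in predictSlice,\n",
    "# applied to synthetic data. torch.quantile stands in for funcs.torch_percentile\n",
    "# (an assumed equivalent, not this repo's implementation).\n",
    "t = torch.randn(1, 64, 64, 8) * 100\n",
    "lo = torch.quantile(t, 0.01)  # 1st percentile\n",
    "hi = torch.quantile(t, 0.99)  # 99th percentile\n",
    "t_norm = (torch.clamp(t, lo, hi) - lo) / (hi - lo)\n",
    "print(t_norm.min().item(), t_norm.max().item())  # ~0.0, ~1.0"
   ]
  },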
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f5a3d21",
   "metadata": {},
   "outputs": [],
   "source": [
    "ori_img, predictedSliceMask, atten_map = predictSlice(\n",
    "    image_name='2.nii.gz',\n",
    "    lower_percentile=1,\n",
    "    upper_percentile=99,\n",
    "    slice_id=50,  # slice number\n",
    "    attention_enabled=True,  # if you want to use the depth attention\n",
    ")\n",
    "\n",
    "msk_gt, dsc_gt = evaluateSlicePrediction(\n",
    "    mask_pred=predictedSliceMask,\n",
    "    mask_name='2.nrrd',\n",
    "    slice_id=50\n",
    ")\n",
    "\n",
    "visualizeSlicePrediction(\n",
    "    ori_img=ori_img,\n",
    "    image_name='2.nii.gz',\n",
    "    atten_map=atten_map,\n",
    "    msk_gt=msk_gt,\n",
    "    mask_pred=predictedSliceMask,\n",
    "    dsc_gt=dsc_gt\n",
    ")"
   ]
  },
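  {
   "cell_type": "markdown",
   "id": "5f8e3c27",
   "metadata": {},
   "source": [
    "`dice_coeff` comes from the repo's `dsc` module. For reference, a minimal Dice similarity for binary masks looks like the hypothetical stand-in below; it shows what `evaluateSlicePrediction` measures but is not the repo's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a7b1c56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical stand-in for dsc.dice_coeff on binary masks, shown only to\n",
    "# make explicit what evaluateSlicePrediction measures.\n",
    "def dice_binary(pred, target, eps=1e-6):\n",
    "    pred = pred.float().flatten()\n",
    "    target = target.float().flatten()\n",
    "    inter = (pred * target).sum()\n",
    "    return float((2 * inter + eps) / (pred.sum() + target.sum() + eps))\n",
    "\n",
    "print(dice_binary(predictedSliceMask, msk_gt))  # should track dsc_gt above"
   ]
  },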
  {
   "cell_type": "markdown",
   "id": "58216ba5",
   "metadata": {},
   "source": [
    "## 3D Volume Prediction & Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "d551caea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predictVolume(image_name, lower_percentile, upper_percentile):\n",
    "    image1_vol = tio.ScalarImage(os.path.join(img_folder, image_name))\n",
    "    print('vol shape: %s vol spacing %s' % (image1_vol.shape, image1_vol.spacing))\n",
    "\n",
    "    # Compute the intensity bounds at the given percentiles\n",
    "    image_tensor = image1_vol.data\n",
    "    lower_bound = torch_percentile(image_tensor, lower_percentile)\n",
    "    upper_bound = torch_percentile(image_tensor, upper_percentile)\n",
    "\n",
    "    # Clip the data\n",
    "    image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n",
    "    # Normalize the data to [0, 1]\n",
    "    image_tensor = (image_tensor - lower_bound) / (upper_bound - lower_bound)\n",
    "    image1_vol.set_data(image_tensor)\n",
    "\n",
    "    # Predict slice by slice along the last axis\n",
    "    mask_vol_numpy = np.zeros(image1_vol.shape)\n",
    "    for slice_id in range(image1_vol.shape[3]):\n",
    "        atten_map = pred_attention(image1_vol, vnet, slice_id, device)\n",
    "        atten_map = torch.unsqueeze(torch.tensor(atten_map), 0).float().to(device)\n",
    "\n",
    "        ori_img, pred_1, voxel_spacing1, Pil_img1, slice_id1 = evaluate_1_volume_withattention(image1_vol, sam_fine_tune, device, slice_id=slice_id, atten_map=atten_map)\n",
    "        img1_size = Pil_img1.size\n",
    "        mask_pred = ((pred_1 > 0) == cls).float().cpu()\n",
    "        pil_mask1 = Image.fromarray(np.array(mask_pred[0], dtype=np.uint8), 'L').resize(img1_size, resample=PIL.Image.NEAREST)\n",
    "        mask_vol_numpy[0, :, :, slice_id] = np.asarray(pil_mask1)\n",
    "\n",
    "    # Assemble the slice masks into a labelmap and save it\n",
    "    mask_vol = tio.LabelMap(tensor=torch.tensor(mask_vol_numpy, dtype=torch.int), affine=image1_vol.affine)\n",
    "    mask_save_folder = os.path.join(predicted_msk_folder, '/'.join(image_name.split('/')[:-1]))\n",
    "    Path(mask_save_folder).mkdir(parents=True, exist_ok=True)\n",
    "    mask_vol.save(os.path.join(mask_save_folder, image_name.split('/')[-1].replace('.nii.gz', '_predicted_SAMatten_paired.nrrd')))\n",
    "    return mask_vol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "c48e328a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predictAndEvaluateVolume(image_name, mask_name, lower_percentile, upper_percentile):\n",
    "    dsc_gt = 0  # running sum of slice-wise Dice scores\n",
    "    image1_vol = tio.ScalarImage(os.path.join(img_folder, image_name))\n",
    "    print('vol shape: %s vol spacing %s' % (image1_vol.shape, image1_vol.spacing))\n",
    "\n",
    "    # Compute the intensity bounds at the given percentiles\n",
    "    image_tensor = image1_vol.data\n",
    "    lower_bound = torch_percentile(image_tensor, lower_percentile)\n",
    "    upper_bound = torch_percentile(image_tensor, upper_percentile)\n",
    "\n",
    "    # Clip the data\n",
    "    image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n",
    "    # Normalize the data to [0, 1]\n",
    "    image_tensor = (image_tensor - lower_bound) / (upper_bound - lower_bound)\n",
    "    image1_vol.set_data(image_tensor)\n",
    "\n",
    "    mask_gt, header = nrrd.read(os.path.join(gt_msk_folder, mask_name))\n",
    "    mask_vol_numpy = np.zeros(image1_vol.shape)\n",
    "    id_list = list(range(image1_vol.shape[3]))\n",
    "    for slice_id in id_list:\n",
    "        atten_map = pred_attention(image1_vol, vnet, slice_id, device)\n",
    "        atten_map = torch.unsqueeze(torch.tensor(atten_map), 0).float().to(device)\n",
    "\n",
    "        ori_img, pred_1, voxel_spacing1, Pil_img1, slice_id1 = evaluate_1_volume_withattention(image1_vol, sam_fine_tune, device, slice_id=slice_id, atten_map=atten_map)\n",
    "        img1_size = Pil_img1.size\n",
    "\n",
    "        mask_pred = ((pred_1 > 0) == cls).float().cpu()\n",
    "        msk = Image.fromarray(mask_gt[:, :, slice_id].astype(np.uint8), 'L')\n",
    "        msk = transforms.Resize((256, 256))(msk)\n",
    "        msk_gt = (transforms.ToTensor()(msk) > 0).float().cpu()\n",
    "        dsc_gt += dice_coeff(mask_pred.cpu(), msk_gt).item()\n",
    "        pil_mask1 = Image.fromarray(np.array(mask_pred[0], dtype=np.uint8), 'L').resize(img1_size, resample=PIL.Image.NEAREST)\n",
    "        mask_vol_numpy[0, :, :, slice_id] = np.asarray(pil_mask1)\n",
    "\n",
    "    mask_vol = tio.LabelMap(tensor=torch.tensor(mask_vol_numpy, dtype=torch.int), affine=image1_vol.affine)\n",
    "    mask_save_folder = os.path.join(predicted_msk_folder, '/'.join(image_name.split('/')[:-1]))\n",
    "    Path(mask_save_folder).mkdir(parents=True, exist_ok=True)\n",
    "    mask_vol.save(os.path.join(mask_save_folder, image_name.split('/')[-1].replace('.nii.gz', '_predicted_SAMatten_paired.nrrd')))\n",
    "\n",
    "    # Slice-wise Dice: mean over slices; volume-wise Dice: whole volume at once\n",
    "    dsc_gt /= len(id_list)\n",
    "    gt_vol = tio.LabelMap(tensor=torch.unsqueeze(torch.Tensor(mask_gt > 0), 0), affine=image1_vol.affine)\n",
    "    dsc_vol = dice_coeff(mask_vol.data.float().cpu(), gt_vol.data).item()\n",
    "    print('volume %s: slice_wise_dsc %.2f; vol_wise_dsc %.2f' % (image_name, dsc_gt, dsc_vol))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a2f4789",
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = predictVolume(\n",
    "    image_name='2.nii.gz',\n",
    "    lower_percentile=1,\n",
    "    upper_percentile=99\n",
    ")"
   ]
  },
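  {
   "cell_type": "markdown",
   "id": "2c6d4e85",
   "metadata": {},
   "source": [
    "Optional check (added for illustration): reload the mask that `predictVolume` just saved and inspect it. For `'2.nii.gz'`, the filename below follows the replace rule inside `predictVolume`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b9a0f34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the prediction saved by predictVolume; the filename follows its\n",
    "# '.nii.gz' -> '_predicted_SAMatten_paired.nrrd' replace rule.\n",
    "saved_path = os.path.join(predicted_msk_folder, '2_predicted_SAMatten_paired.nrrd')\n",
    "pred_vol = tio.LabelMap(saved_path)\n",
    "print('saved mask shape:', pred_vol.shape, '| foreground voxels:', int(pred_vol.data.sum()))"
   ]
  },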
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5352a6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictAndEvaluateVolume(\n",
    "    image_name='2.nii.gz',\n",
    "    mask_name='2.nrrd',\n",
    "    lower_percentile=1,\n",
    "    upper_percentile=99\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}