{
"cells": [
{
"cell_type": "markdown",
"id": "fbb3c1a2",
"metadata": {},
"source": [
"## Package Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5678569",
"metadata": {},
"outputs": [],
"source": [
"%pip install -r requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "cef2a006-01b3-48ec-a631-ba22fcbec5a4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from models.sam import SamPredictor, sam_model_registry\n",
"from models.sam.modeling.prompt_encoder import attention_fusion\n",
"import numpy as np\n",
"import os\n",
"import torch\n",
"import torchvision\n",
"import matplotlib.pyplot as plt\n",
"from torchvision import transforms\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"from dsc import dice_coeff\n",
"import torchio as tio\n",
"import nrrd\n",
"import PIL\n",
"import cfg\n",
"from funcs import *\n",
"from predict_funs import *\n",
"args = cfg.parse_args()\n",
"from monai.networks.nets import VNet\n",
"args.if_mask_decoder_adapter=True\n",
"args.if_encoder_adapter = True\n",
"args.decoder_adapt_depth = 2\n",
"%matplotlib inline"
]
},
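{
"cell_type": "markdown",
"id": "a3f7e2c9",
"metadata": {},
"source": [
"Helpers such as `torch_percentile`, `pred_attention`, `evaluate_1_volume_withattention`, and `drawContour` come in through the wildcard imports of `funcs` and `predict_funs`. As a reference point, here is a minimal sketch of what `torch_percentile` is assumed to compute; the actual implementation in `funcs.py` may differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8d1c4e7",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch of torch_percentile; the notebook uses the real helper from funcs.\n",
"# Assumes it returns the tensor value at the given percentile (0-100).\n",
"def torch_percentile_sketch(tensor, percentile):\n",
"    # torch.quantile expects a fraction in [0, 1], so divide the percentile by 100\n",
"    return torch.quantile(tensor.float(), percentile / 100.0).item()"
]
},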
{
"cell_type": "markdown",
"id": "34c4f647",
"metadata": {},
"source": [
"## Load models"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a578226-354b-4833-bd17-3f57ff143ee9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"print(device)\n",
"checkpoint_directory = './' # path to your checkpoint\n",
"img_folder = os.path.join('images')\n",
"gt_msk_folder = os.path.join('masks')\n",
"predicted_msk_folder = os.path.join('predicted_masks')\n",
"cls = 1\n",
"\n",
"sam_fine_tune = sam_model_registry[\"vit_t\"](args,checkpoint=os.path.join('mobile_sam.pt'),num_classes=2)\n",
"sam_fine_tune.attention_fusion = attention_fusion() \n",
"sam_fine_tune.load_state_dict(torch.load(os.path.join(checkpoint_directory,'bone_sam.pth'),map_location=torch.device(device)), strict = True)\n",
"sam_fine_tune = sam_fine_tune.to(device).eval()\n",
"\n",
"vnet = VNet().to(device)\n",
"model_directory = \"./\"\n",
"vnet.load_state_dict(torch.load(os.path.join(model_directory,'atten.pth'),map_location=torch.device(device)))"
]
},
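{
"cell_type": "markdown",
"id": "c5e9a1f3",
"metadata": {},
"source": [
"As a quick sanity check after loading, the cell below counts each model's parameters and confirms the device they live on (a small diagnostic sketch; the exact counts depend on the checkpoints)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7f2b6a4",
"metadata": {},
"outputs": [],
"source": [
"# Diagnostic sketch: parameter counts and device placement for the loaded models\n",
"for name, model in [('sam_fine_tune', sam_fine_tune), ('vnet', vnet)]:\n",
"    n_params = sum(p.numel() for p in model.parameters())\n",
"    print('%s: %.1fM parameters on %s' % (name, n_params / 1e6, next(model.parameters()).device))"
]
},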
{
"cell_type": "markdown",
"id": "755aea6f-151c-4872-81cb-a4ef66973a15",
"metadata": {},
"source": [
"## 2D Slice Prediction & Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "7f410bb5-b33f-4d98-a1be-caf578cfa7b7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def evaluateSlicePrediction(mask_pred, mask_name, slice_id):\n",
" voxels, header = nrrd.read(os.path.join(gt_msk_folder,mask_name))\n",
" mask_gt = voxels\n",
"\n",
" msk = Image.fromarray(mask_gt[:,:,slice_id].astype(np.uint8), 'L')\n",
" msk = transforms.Resize((256,256))(msk)\n",
" msk_gt = (transforms.ToTensor()(msk)>0).float()\n",
"\n",
" dsc_gt = dice_coeff(mask_pred.cpu(), msk_gt).item()\n",
" \n",
" print(\"dsc_gt:\", dsc_gt)\n",
" return msk_gt, dsc_gt\n",
"\n",
"def predictSlice(image_name, lower_percentile, upper_percentile, slice_id, attention_enabled):\n",
" \n",
" image1_vol = tio.ScalarImage(os.path.join(img_folder, image_name))\n",
" print('vol shape: %s vol spacing %s' %(image1_vol.shape,image1_vol.spacing))\n",
"\n",
" image_tensor = image1_vol.data\n",
" lower_bound = torch_percentile(image_tensor, lower_percentile)\n",
" upper_bound = torch_percentile(image_tensor, upper_percentile)\n",
"\n",
" # Clip the data\n",
" image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n",
"\n",
" # Normalize the data to [0, 1] \n",
" image_tensor = (image_tensor - lower_bound) / (upper_bound - lower_bound)\n",
"\n",
" image1_vol.set_data(image_tensor)\n",
" atten_map= pred_attention(image1_vol,vnet,slice_id,device)\n",
" \n",
" atten_map = torch.unsqueeze(torch.tensor(atten_map),0).float().to(device)\n",
" print(atten_map.device)\n",
" if attention_enabled:\n",
" ori_img,pred_1,voxel_spacing1,Pil_img1,slice_id1 = evaluate_1_volume_withattention(image1_vol,sam_fine_tune,device,slice_id=slice_id,atten_map=atten_map)\n",
" else:\n",
" ori_img,pred_1,voxel_spacing1,Pil_img1,slice_id1 = evaluate_1_volume_withattention(image1_vol,sam_fine_tune,device,slice_id=slice_id)\n",
" \n",
" mask_pred = ((pred_1>0)==cls).float().cpu()\n",
"\n",
" return ori_img, mask_pred, atten_map\n",
"\n",
"def visualizeSlicePrediction(ori_img, image_name, atten_map, msk_gt, mask_pred, dsc_gt):\n",
" image = np.rot90(torchvision.transforms.Resize((args.out_size,args.out_size))(ori_img)[0])\n",
" image_3d = np.repeat(np.array(image*255,dtype=np.uint8).copy()[:, :, np.newaxis], 3, axis=2)\n",
"\n",
" pred_mask_auto = (mask_pred[0])*255\n",
" mask = (msk_gt.cpu()[0]>0)*255\n",
"\n",
" target_prediction = [103,169,237] \n",
" image_pred_auto = drawContour(image_3d.copy(), np.rot90(pred_mask_auto),target_prediction,size=-1,a=0.6)\n",
"\n",
" target_prediction = [100,255,106] \n",
" image_mask = drawContour(image_3d.copy(),np.rot90(mask),target_prediction,size=-1,a=0.6)\n",
"\n",
" fig, a = plt.subplots(1,4, figsize=(20,15))\n",
"\n",
" a[0].imshow(image,cmap='gray',vmin=0, vmax=1)\n",
" a[0].set_title(image_name)\n",
" a[0].axis(False)\n",
"\n",
" a[1].imshow(image_mask,cmap='gray',vmin=0, vmax=255)\n",
" a[1].set_title('gt_mask',fontsize=10)\n",
" a[1].axis(False)\n",
"\n",
" a[2].imshow(image_pred_auto,cmap='gray',vmin=0, vmax=255)\n",
" a[2].set_title('pre_mask_auto, dsc %.2f'%(dsc_gt),fontsize=10)\n",
" a[2].axis(False)\n",
"\n",
" a[3].imshow(np.rot90(atten_map.cpu()[0]),vmin=0, vmax=1,cmap='coolwarm')\n",
" a[3].set_title('atten_map',fontsize=10)\n",
" a[3].axis(False)\n",
"\n",
" plt.tight_layout()"
]
},
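{
"cell_type": "markdown",
"id": "e4a8c7d2",
"metadata": {},
"source": [
"Evaluation relies on `dice_coeff` from `dsc.py`. For reference, the Dice score is the standard overlap measure 2|A∩B| / (|A| + |B|); the cell below is a minimal sketch of it (the actual `dice_coeff` may add smoothing or batch handling)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1b3d5e8",
"metadata": {},
"outputs": [],
"source": [
"# Minimal Dice sketch for two binary masks; dice_coeff from dsc.py is what the\n",
"# notebook actually uses.\n",
"def dice_sketch(pred, target, eps=1e-6):\n",
"    pred = (pred > 0).float()\n",
"    target = (target > 0).float()\n",
"    intersection = (pred * target).sum()\n",
"    # 2 * |A intersect B| / (|A| + |B|); eps guards against empty masks\n",
"    return ((2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)).item()"
]
},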
{
"cell_type": "code",
"execution_count": null,
"id": "2f5a3d21",
"metadata": {},
"outputs": [],
"source": [
"ori_img, predictedSliceMask, atten_map = predictSlice(\n",
" image_name = '2.nii.gz', \n",
" lower_percentile = 1,\n",
" upper_percentile = 99,\n",
" slice_id = 50, # slice number\n",
" attention_enabled = True, # if you want to use the depth attention\n",
")\n",
"\n",
"msk_gt, dsc_gt = evaluateSlicePrediction(\n",
" mask_pred = predictedSliceMask, \n",
" mask_name = '2.nrrd', \n",
" slice_id = 50\n",
")\n",
"\n",
"visualizeSlicePrediction(\n",
" ori_img=ori_img, \n",
" image_name='2.nii.gz', \n",
" atten_map=atten_map, \n",
" msk_gt=msk_gt, \n",
" mask_pred=predictedSliceMask, \n",
" dsc_gt=dsc_gt\n",
")"
]
},
{
"cell_type": "markdown",
"id": "58216ba5",
"metadata": {},
"source": [
"## 3D Volume Prediction & Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "d551caea",
"metadata": {},
"outputs": [],
"source": [
"def predictVolume(image_name, lower_percentile, upper_percentile):\n",
" dsc_gt = 0\n",
" image1_vol = tio.ScalarImage(os.path.join(img_folder,image_name))\n",
" print('vol shape: %s vol spacing %s' %(image1_vol.shape,image1_vol.spacing))\n",
"\n",
" # Define the percentiles\n",
" image_tensor = image1_vol.data\n",
" lower_bound = torch_percentile(image_tensor, lower_percentile)\n",
" upper_bound = torch_percentile(image_tensor, upper_percentile)\n",
"\n",
" # Clip the data\n",
" image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n",
" # Normalize the data to [0, 1] \n",
" image_tensor = (image_tensor - lower_bound) / (upper_bound - lower_bound)\n",
" image1_vol.set_data(image_tensor)\n",
" \n",
" mask_vol_numpy = np.zeros(image1_vol.shape)\n",
" id_list = list(range(image1_vol.shape[3]))\n",
" for id in id_list:\n",
" atten_map = pred_attention(image1_vol,vnet,id,device)\n",
" atten_map = torch.unsqueeze(torch.tensor(atten_map),0).float().to(device)\n",
" \n",
" ori_img,pred_1,voxel_spacing1,Pil_img1,slice_id1 = evaluate_1_volume_withattention(image1_vol,sam_fine_tune,device,slice_id=id,atten_map=atten_map)\n",
" img1_size = Pil_img1.size\n",
" mask_pred = ((pred_1>0)==cls).float().cpu()\n",
" pil_mask1 = Image.fromarray(np.array(mask_pred[0],dtype=np.uint8),'L').resize(img1_size,resample= PIL.Image.NEAREST)\n",
" mask_vol_numpy[0,:,:,id] = np.asarray(pil_mask1)\n",
" \n",
" mask_vol = tio.LabelMap(tensor=torch.tensor(mask_vol_numpy,dtype=torch.int), affine=image1_vol.affine)\n",
" mask_save_folder = os.path.join(predicted_msk_folder,'/'.join(image_name.split('/')[:-1]))\n",
" Path(mask_save_folder).mkdir(parents=True, exist_ok = True)\n",
" mask_vol.save(os.path.join(mask_save_folder,image_name.split('/')[-1].replace('.nii.gz','_predicted_SAMatten_paired.nrrd')))\n",
" return mask_vol"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "c48e328a",
"metadata": {},
"outputs": [],
"source": [
"def predictAndEvaluateVolume(image_name, mask_name, lower_percentile, upper_percentile):\n",
" dsc_gt = 0\n",
" image1_vol = tio.ScalarImage(os.path.join(img_folder,image_name))\n",
" print('vol shape: %s vol spacing %s' %(image1_vol.shape,image1_vol.spacing))\n",
"\n",
" # Define the percentiles\n",
" image_tensor = image1_vol.data\n",
" lower_bound = torch_percentile(image_tensor, lower_percentile)\n",
" upper_bound = torch_percentile(image_tensor, upper_percentile)\n",
"\n",
" # Clip the data\n",
" image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n",
" # Normalize the data to [0, 1] \n",
" image_tensor = (image_tensor - lower_bound) / (upper_bound - lower_bound)\n",
" image1_vol.set_data(image_tensor)\n",
" \n",
" voxels, header = nrrd.read(os.path.join(gt_msk_folder,mask_name))\n",
" mask_gt = voxels\n",
" mask_vol_numpy = np.zeros(image1_vol.shape)\n",
" id_list = list(range(image1_vol.shape[3]))\n",
" for id in id_list:\n",
" atten_map = pred_attention(image1_vol,vnet,id,device)\n",
" atten_map = torch.unsqueeze(torch.tensor(atten_map),0).float().to(device)\n",
" \n",
" ori_img,pred_1,voxel_spacing1,Pil_img1,slice_id1 = evaluate_1_volume_withattention(image1_vol,sam_fine_tune,device,slice_id=id,atten_map=atten_map)\n",
" img1_size = Pil_img1.size\n",
"\n",
" mask_pred = ((pred_1>0)==cls).float().cpu()\n",
" msk = Image.fromarray(mask_gt[:,:,id].astype(np.uint8), 'L')\n",
" msk = transforms.Resize((256,256))(msk)\n",
" msk_gt = (transforms.ToTensor()(msk)>0).float().cpu()\n",
" dsc_gt += dice_coeff(mask_pred.cpu(),msk_gt).item()\n",
" pil_mask1 = Image.fromarray(np.array(mask_pred[0],dtype=np.uint8),'L').resize(img1_size,resample= PIL.Image.NEAREST)\n",
" mask_vol_numpy[0,:,:,id] = np.asarray(pil_mask1)\n",
" \n",
" mask_vol = tio.LabelMap(tensor=torch.tensor(mask_vol_numpy,dtype=torch.int), affine=image1_vol.affine)\n",
" mask_save_folder = os.path.join(predicted_msk_folder,'/'.join(image_name.split('/')[:-1]))\n",
" Path(mask_save_folder).mkdir(parents=True,exist_ok = True)\n",
" mask_vol.save(os.path.join(mask_save_folder,image_name.split('/')[-1].replace('.nii.gz','_predicted_SAMatten_paired.nrrd')))\n",
" dsc_gt /= len(id_list)\n",
" gt_vol = tio.LabelMap(tensor=torch.unsqueeze(torch.Tensor(mask_gt>0),0), affine=image1_vol.affine)\n",
" dsc_vol = dice_coeff(mask_vol.data.float().cpu(),gt_vol.data).item()\n",
" print('volume %s: slice_wise_dsc %.2f; vol_wise_dsc %.2f'%(image_name,dsc_gt,dsc_vol))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a2f4789",
"metadata": {},
"outputs": [],
"source": [
"mask = predictVolume(\n",
" image_name = '2.nii.gz', \n",
" lower_percentile = 1, \n",
" upper_percentile = 99\n",
")"
]
},
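{
"cell_type": "markdown",
"id": "a9c2e6b1",
"metadata": {},
"source": [
"The saved prediction can be reloaded with `torchio` for inspection. The path below assumes the `_predicted_SAMatten_paired.nrrd` naming used inside `predictVolume`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6d4f8c3",
"metadata": {},
"outputs": [],
"source": [
"# Reload the saved label map for a quick check (path follows predictVolume's naming)\n",
"pred_vol = tio.LabelMap(os.path.join(predicted_msk_folder, '2_predicted_SAMatten_paired.nrrd'))\n",
"print('prediction shape:', pred_vol.shape, '| foreground voxels:', int(pred_vol.data.sum()))"
]
},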
{
"cell_type": "code",
"execution_count": null,
"id": "f5352a6c",
"metadata": {},
"outputs": [],
"source": [
"predictAndEvaluateVolume(\n",
" image_name = '2.nii.gz', \n",
" mask_name = '2.nrrd',\n",
" lower_percentile = 1, \n",
" upper_percentile = 99\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}