--- a +++ b/demo.ipynb @@ -0,0 +1,368 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fbb3c1a2", + "metadata": {}, + "source": [ + "## Package Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5678569", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -r requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "id": "cef2a006-01b3-48ec-a631-ba22fcbec5a4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from models.sam import SamPredictor, sam_model_registry\n", + "from models.sam.modeling.prompt_encoder import attention_fusion\n", + "import numpy as np\n", + "import os\n", + "import torch\n", + "import torchvision\n", + "import matplotlib.pyplot as plt\n", + "from torchvision import transforms\n", + "from PIL import Image\n", + "import matplotlib.pyplot as plt\n", + "from pathlib import Path\n", + "from dsc import dice_coeff\n", + "import torchio as tio\n", + "import nrrd\n", + "import PIL\n", + "import cfg\n", + "from funcs import *\n", + "from predict_funs import *\n", + "args = cfg.parse_args()\n", + "from monai.networks.nets import VNet\n", + "args.if_mask_decoder_adapter=True\n", + "args.if_encoder_adapter = True\n", + "args.decoder_adapt_depth = 2\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "id": "34c4f647", + "metadata": {}, + "source": [ + "## Load models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a578226-354b-4833-bd17-3f57ff143ee9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "print(device)\n", + "checkpoint_directory = './' # path to your checkpoint\n", + "img_folder = os.path.join('images')\n", + "gt_msk_folder = os.path.join('masks')\n", + "predicted_msk_folder = os.path.join('predicted_masks')\n", + "cls = 1\n", + "\n", + "sam_fine_tune = sam_model_registry[\"vit_t\"](args,checkpoint=os.path.join('mobile_sam.pt'),num_classes=2)\n", + "sam_fine_tune.attention_fusion = attention_fusion() \n", + "sam_fine_tune.load_state_dict(torch.load(os.path.join(checkpoint_directory,'bone_sam.pth'),map_location=torch.device(device)), strict = True)\n", + "sam_fine_tune = sam_fine_tune.to(device).eval()\n", + "\n", + "vnet = VNet().to(device)\n", + "model_directory = \"./\"\n", + "vnet.load_state_dict(torch.load(os.path.join(model_directory,'atten.pth'),map_location=torch.device(device)))" + ] + }, + { + "cell_type": "markdown", + "id": "755aea6f-151c-4872-81cb-a4ef66973a15", + "metadata": {}, + "source": [ + "## 2D Slice Prediction & Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "7f410bb5-b33f-4d98-a1be-caf578cfa7b7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def evaluateSlicePrediction(mask_pred, mask_name, slice_id):\n", + " voxels, header = nrrd.read(os.path.join(gt_msk_folder,mask_name))\n", + " mask_gt = voxels\n", + "\n", + " msk = Image.fromarray(mask_gt[:,:,slice_id].astype(np.uint8), 'L')\n", + " msk = transforms.Resize((256,256))(msk)\n", + " msk_gt = (transforms.ToTensor()(msk)>0).float()\n", + "\n", + " dsc_gt = dice_coeff(mask_pred.cpu(), msk_gt).item()\n", + " \n", + " print(\"dsc_gt:\", dsc_gt)\n", + " return msk_gt, dsc_gt\n", + "\n", + "def predictSlice(image_name, lower_percentile, upper_percentile, slice_id, attention_enabled):\n", + " \n", + " image1_vol = tio.ScalarImage(os.path.join(img_folder, image_name))\n", + " print('vol shape: %s vol spacing %s' %(image1_vol.shape,image1_vol.spacing))\n", + "\n", + " image_tensor = image1_vol.data\n", + " lower_bound = torch_percentile(image_tensor, lower_percentile)\n", + " upper_bound = torch_percentile(image_tensor, upper_percentile)\n", + "\n", + " # Clip the data\n", + " image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n", + "\n", + " # Normalize the data to [0, 1] \n", + " image_tensor = (image_tensor - lower_bound) / (upper_bound - lower_bound)\n", + "\n", + " image1_vol.set_data(image_tensor)\n", + " atten_map= pred_attention(image1_vol,vnet,slice_id,device)\n", + " \n", + " atten_map = torch.unsqueeze(torch.tensor(atten_map),0).float().to(device)\n", + " print(atten_map.device)\n", + " if attention_enabled:\n", + " ori_img,pred_1,voxel_spacing1,Pil_img1,slice_id1 = evaluate_1_volume_withattention(image1_vol,sam_fine_tune,device,slice_id=slice_id,atten_map=atten_map)\n", + " else:\n", + " ori_img,pred_1,voxel_spacing1,Pil_img1,slice_id1 = evaluate_1_volume_withattention(image1_vol,sam_fine_tune,device,slice_id=slice_id)\n", + " \n", + " mask_pred = ((pred_1>0)==cls).float().cpu()\n", + "\n", + " return ori_img, mask_pred, atten_map\n", + "\n", + "def visualizeSlicePrediction(ori_img, image_name, atten_map, msk_gt, mask_pred, dsc_gt):\n", + " image = np.rot90(torchvision.transforms.Resize((args.out_size,args.out_size))(ori_img)[0])\n", + " image_3d = np.repeat(np.array(image*255,dtype=np.uint8).copy()[:, :, np.newaxis], 3, axis=2)\n", + "\n", + " pred_mask_auto = (mask_pred[0])*255\n", + " mask = (msk_gt.cpu()[0]>0)*255\n", + "\n", + " target_prediction = [103,169,237] \n", + " image_pred_auto = drawContour(image_3d.copy(), np.rot90(pred_mask_auto),target_prediction,size=-1,a=0.6)\n", + "\n", + " target_prediction = [100,255,106] \n", + " image_mask = drawContour(image_3d.copy(),np.rot90(mask),target_prediction,size=-1,a=0.6)\n", + "\n", + " fig, a = plt.subplots(1,4, figsize=(20,15))\n", + "\n", + " a[0].imshow(image,cmap='gray',vmin=0, vmax=1)\n", + " a[0].set_title(image_name)\n", + " a[0].axis(False)\n", + "\n", + " a[1].imshow(image_mask,cmap='gray',vmin=0, vmax=255)\n", + " a[1].set_title('gt_mask',fontsize=10)\n", + " a[1].axis(False)\n", + "\n", + " a[2].imshow(image_pred_auto,cmap='gray',vmin=0, vmax=255)\n", + " a[2].set_title('pre_mask_auto, dsc %.2f'%(dsc_gt),fontsize=10)\n", + " a[2].axis(False)\n", + "\n", + " a[3].imshow(np.rot90(atten_map.cpu()[0]),vmin=0, vmax=1,cmap='coolwarm')\n", + " a[3].set_title('atten_map',fontsize=10)\n", + " a[3].axis(False)\n", + "\n", + " plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f5a3d21", + "metadata": {}, + "outputs": [], + "source": [ + "ori_img, predictedSliceMask, atten_map = predictSlice(\n", + " image_name = '2.nii.gz', \n", + " lower_percentile = 1,\n", + " upper_percentile = 99,\n", + " slice_id = 50, # slice number\n", + " attention_enabled = True, # if you want to use the depth attention\n", + ")\n", + "\n", + "msk_gt, dsc_gt = evaluateSlicePrediction(\n", + " mask_pred = predictedSliceMask, \n", + " mask_name = '2.nrrd', \n", + " slice_id = 50\n", + ")\n", + "\n", + "visualizeSlicePrediction(\n", + " ori_img=ori_img, \n", + " image_name='2.nii.gz', \n", + " atten_map=atten_map, \n", + " msk_gt=msk_gt, \n", + " mask_pred=predictedSliceMask, \n", + " dsc_gt=dsc_gt\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "58216ba5", + "metadata": {}, + "source": [ + "## 3D Volume Prediction & Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "d551caea", + "metadata": {}, + "outputs": [], + "source": [ + "def predictVolume(image_name, lower_percentile, upper_percentile):\n", + " dsc_gt = 0\n", + " image1_vol = tio.ScalarImage(os.path.join(img_folder,image_name))\n", + " print('vol shape: %s vol spacing %s' %(image1_vol.shape,image1_vol.spacing))\n", + "\n", + " # Define the percentiles\n", + " image_tensor = image1_vol.data\n", + " lower_bound = torch_percentile(image_tensor, lower_percentile)\n", + " upper_bound = torch_percentile(image_tensor, upper_percentile)\n", + "\n", + " # Clip the data\n", + " image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n", + " # Normalize the data to [0, 1] \n", + " image_tensor = (image_tensor - lower_bound) / (upper_bound - lower_bound)\n", + " image1_vol.set_data(image_tensor)\n", + " \n", + " mask_vol_numpy = np.zeros(image1_vol.shape)\n", + " id_list = list(range(image1_vol.shape[3]))\n", + " for id in id_list:\n", + " atten_map = pred_attention(image1_vol,vnet,id,device)\n", + " atten_map = torch.unsqueeze(torch.tensor(atten_map),0).float().to(device)\n", + " \n", + " ori_img,pred_1,voxel_spacing1,Pil_img1,slice_id1 = evaluate_1_volume_withattention(image1_vol,sam_fine_tune,device,slice_id=id,atten_map=atten_map)\n", + " img1_size = Pil_img1.size\n", + " mask_pred = ((pred_1>0)==cls).float().cpu()\n", + " pil_mask1 = Image.fromarray(np.array(mask_pred[0],dtype=np.uint8),'L').resize(img1_size,resample= PIL.Image.NEAREST)\n", + " mask_vol_numpy[0,:,:,id] = np.asarray(pil_mask1)\n", + " \n", + " mask_vol = tio.LabelMap(tensor=torch.tensor(mask_vol_numpy,dtype=torch.int), affine=image1_vol.affine)\n", + " mask_save_folder = os.path.join(predicted_msk_folder,'/'.join(image_name.split('/')[:-1]))\n", + " Path(mask_save_folder).mkdir(parents=True, exist_ok = True)\n", + " mask_vol.save(os.path.join(mask_save_folder,image_name.split('/')[-1].replace('.nii.gz','_predicted_SAMatten_paired.nrrd')))\n", + " return mask_vol" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "c48e328a", + "metadata": {}, + "outputs": [], + "source": [ + "def predictAndEvaluateVolume(image_name, mask_name, lower_percentile, upper_percentile):\n", + " dsc_gt = 0\n", + " image1_vol = tio.ScalarImage(os.path.join(img_folder,image_name))\n", + " print('vol shape: %s vol spacing %s' %(image1_vol.shape,image1_vol.spacing))\n", + "\n", + " # Define the percentiles\n", + " image_tensor = image1_vol.data\n", + " lower_bound = torch_percentile(image_tensor, lower_percentile)\n", + " upper_bound = torch_percentile(image_tensor, upper_percentile)\n", + "\n", + " # Clip the data\n", + " image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n", + " # Normalize the data to [0, 1] \n", + " image_tensor = (image_tensor - lower_bound) / (upper_bound - lower_bound)\n", + " image1_vol.set_data(image_tensor)\n", + " \n", + " voxels, header = nrrd.read(os.path.join(gt_msk_folder,mask_name))\n", + " mask_gt = voxels\n", + " mask_vol_numpy = np.zeros(image1_vol.shape)\n", + " id_list = list(range(image1_vol.shape[3]))\n", + " for id in id_list:\n", + " atten_map = pred_attention(image1_vol,vnet,id,device)\n", + " atten_map = torch.unsqueeze(torch.tensor(atten_map),0).float().to(device)\n", + " \n", + " ori_img,pred_1,voxel_spacing1,Pil_img1,slice_id1 = evaluate_1_volume_withattention(image1_vol,sam_fine_tune,device,slice_id=id,atten_map=atten_map)\n", + " img1_size = Pil_img1.size\n", + "\n", + " mask_pred = ((pred_1>0)==cls).float().cpu()\n", + " msk = Image.fromarray(mask_gt[:,:,id].astype(np.uint8), 'L')\n", + " msk = transforms.Resize((256,256))(msk)\n", + " msk_gt = (transforms.ToTensor()(msk)>0).float().cpu()\n", + " dsc_gt += dice_coeff(mask_pred.cpu(),msk_gt).item()\n", + " pil_mask1 = Image.fromarray(np.array(mask_pred[0],dtype=np.uint8),'L').resize(img1_size,resample= PIL.Image.NEAREST)\n", + " mask_vol_numpy[0,:,:,id] = np.asarray(pil_mask1)\n", + " \n", + " mask_vol = tio.LabelMap(tensor=torch.tensor(mask_vol_numpy,dtype=torch.int), affine=image1_vol.affine)\n", + " mask_save_folder = os.path.join(predicted_msk_folder,'/'.join(image_name.split('/')[:-1]))\n", + " Path(mask_save_folder).mkdir(parents=True,exist_ok = True)\n", + " mask_vol.save(os.path.join(mask_save_folder,image_name.split('/')[-1].replace('.nii.gz','_predicted_SAMatten_paired.nrrd')))\n", + " dsc_gt /= len(id_list)\n", + " gt_vol = tio.LabelMap(tensor=torch.unsqueeze(torch.Tensor(mask_gt>0),0), affine=image1_vol.affine)\n", + " dsc_vol = dice_coeff(mask_vol.data.float().cpu(),gt_vol.data).item()\n", + " print('volume %s: slice_wise_dsc %.2f; vol_wise_dsc %.2f'%(image_name,dsc_gt,dsc_vol))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a2f4789", + "metadata": {}, + "outputs": [], + "source": [ + "mask = predictVolume(\n", + " image_name = '2.nii.gz', \n", + " lower_percentile = 1, \n", + " upper_percentile = 99\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5352a6c", + "metadata": {}, + "outputs": [], + "source": [ + "predictAndEvaluateVolume(\n", + " image_name = '2.nii.gz', \n", + " mask_name = '2.nrrd',\n", + " lower_percentile = 1, \n", + " upper_percentile = 99\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}