[dff9e0]: / demo.ipynb

Download this file

369 lines (368 with data), 13.8 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fbb3c1a2",
   "metadata": {},
   "source": [
    "## Package Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5678569",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the packages listed in requirements.txt into this kernel's environment.\n",
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "cef2a006-01b3-48ec-a631-ba22fcbec5a4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from models.sam import SamPredictor, sam_model_registry\n",
    "from models.sam.modeling.prompt_encoder import attention_fusion\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "from pathlib import Path\n",
    "from dsc import dice_coeff\n",
    "import torchio as tio\n",
    "import nrrd\n",
    "import PIL\n",
    "import cfg\n",
    "# NOTE(review): helpers used later without an explicit import (torch_percentile,\n",
    "# pred_attention, evaluate_1_volume_withattention, drawContour) presumably come\n",
    "# from these star imports. Later imports shadow earlier ones, so keep this order.\n",
    "from funcs import *\n",
    "from predict_funs import *\n",
    "from monai.networks.nets import VNet\n",
    "\n",
    "# Project configuration: parse cfg's arguments, then enable the adapter layers.\n",
    "# These flags presumably have to match how the fine-tuned checkpoint was built,\n",
    "# because that checkpoint is loaded with strict=True.\n",
    "args = cfg.parse_args()\n",
    "args.if_mask_decoder_adapter = True\n",
    "args.if_encoder_adapter = True\n",
    "args.decoder_adapt_depth = 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34c4f647",
   "metadata": {},
   "source": [
    "## Load Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a578226-354b-4833-bd17-3f57ff143ee9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(device)\n",
    "\n",
    "# Input/output locations, relative to the notebook's working directory.\n",
    "checkpoint_directory = './'  # path to your checkpoint\n",
    "model_directory = './'  # path to the attention-VNet weights (atten.pth)\n",
    "img_folder = 'images'\n",
    "gt_msk_folder = 'masks'\n",
    "predicted_msk_folder = 'predicted_masks'\n",
    "cls = 1  # predictions are binarised as ((pred > 0) == cls)\n",
    "\n",
    "# NOTE: torch.load unpickles its input -- only load checkpoints from trusted sources.\n",
    "\n",
    "# Fine-tuned SAM: 'vit_t' model built with checkpoint='mobile_sam.pt', extended with\n",
    "# the attention-fusion module, then overwritten with the weights in bone_sam.pth.\n",
    "sam_fine_tune = sam_model_registry[\"vit_t\"](args, checkpoint='mobile_sam.pt', num_classes=2)\n",
    "sam_fine_tune.attention_fusion = attention_fusion()\n",
    "sam_fine_tune.load_state_dict(\n",
    "    torch.load(os.path.join(checkpoint_directory, 'bone_sam.pth'), map_location=torch.device(device)),\n",
    "    strict=True,\n",
    ")\n",
    "sam_fine_tune = sam_fine_tune.to(device).eval()\n",
    "\n",
    "# VNet passed to pred_attention to produce the attention maps. Switch it to eval\n",
    "# mode (as done for SAM above) so any dropout / batch-norm layers behave\n",
    "# deterministically at inference time.\n",
    "vnet = VNet().to(device)\n",
    "vnet.load_state_dict(torch.load(os.path.join(model_directory, 'atten.pth'), map_location=torch.device(device)))\n",
    "vnet = vnet.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "755aea6f-151c-4872-81cb-a4ef66973a15",
   "metadata": {},
   "source": [
    "## 2D Slice Prediction & Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "7f410bb5-b33f-4d98-a1be-caf578cfa7b7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def evaluateSlicePrediction(mask_pred, mask_name, slice_id):\n",
    "    \"\"\"Score one predicted slice against its ground-truth annotation.\n",
    "\n",
    "    Reads the NRRD ground-truth volume `mask_name` from `gt_msk_folder`, takes\n",
    "    index `slice_id` along its last axis, resizes that slice to 256x256 and\n",
    "    binarises it (> 0).\n",
    "\n",
    "    Returns:\n",
    "        (msk_gt, dsc_gt): the binarised ground-truth slice as a float tensor and\n",
    "        the Dice score between `mask_pred` and it (also printed).\n",
    "    \"\"\"\n",
    "    gt_volume, _ = nrrd.read(os.path.join(gt_msk_folder, mask_name))\n",
    "    gt_slice = Image.fromarray(gt_volume[:, :, slice_id].astype(np.uint8), 'L')\n",
    "    gt_slice = transforms.Resize((256, 256))(gt_slice)\n",
    "    gt_binary = (transforms.ToTensor()(gt_slice) > 0).float()\n",
    "\n",
    "    score = dice_coeff(mask_pred.cpu(), gt_binary).item()\n",
    "    print(\"dsc_gt:\", score)\n",
    "    return gt_binary, score\n",
    "\n",
    "def predictSlice(image_name, lower_percentile, upper_percentile, slice_id, attention_enabled):\n",
    "    \"\"\"Segment one slice of a volume with the fine-tuned SAM model.\n",
    "\n",
    "    The volume `image_name` (relative to `img_folder`) is clipped to its\n",
    "    [lower_percentile, upper_percentile] intensity percentiles and rescaled to\n",
    "    [0, 1]. The VNet attention map for `slice_id` is always computed and\n",
    "    returned (e.g. for plotting), but it is only passed to SAM when\n",
    "    `attention_enabled` is True.\n",
    "\n",
    "    Returns:\n",
    "        (ori_img, mask_pred, atten_map): the image returned by\n",
    "        evaluate_1_volume_withattention, the binarised prediction as a float\n",
    "        CPU tensor, and the attention map as a float tensor on `device` with a\n",
    "        leading dimension added.\n",
    "    \"\"\"\n",
    "    image1_vol = tio.ScalarImage(os.path.join(img_folder, image_name))\n",
    "    print('vol shape: %s vol spacing %s' % (image1_vol.shape, image1_vol.spacing))\n",
    "\n",
    "    # Clip to the percentile window, then rescale to [0, 1].\n",
    "    image_tensor = image1_vol.data\n",
    "    lower_bound = torch_percentile(image_tensor, lower_percentile)\n",
    "    upper_bound = torch_percentile(image_tensor, upper_percentile)\n",
    "    image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n",
    "    value_range = upper_bound - lower_bound\n",
    "    if value_range == 0:\n",
    "        # Constant intensities inside the window: avoid 0/0, which would fill\n",
    "        # the volume with NaNs; every voxel maps to 0 instead.\n",
    "        value_range = 1\n",
    "    image1_vol.set_data((image_tensor - lower_bound) / value_range)\n",
    "\n",
    "    # Pure inference: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        atten_map = pred_attention(image1_vol, vnet, slice_id, device)\n",
    "        atten_map = torch.unsqueeze(torch.tensor(atten_map), 0).float().to(device)\n",
    "\n",
    "        sam_kwargs = {'slice_id': slice_id}\n",
    "        if attention_enabled:\n",
    "            sam_kwargs['atten_map'] = atten_map\n",
    "        ori_img, pred_1, _, _, _ = evaluate_1_volume_withattention(\n",
    "            image1_vol, sam_fine_tune, device, **sam_kwargs)\n",
    "\n",
    "    mask_pred = ((pred_1 > 0) == cls).float().cpu()\n",
    "    return ori_img, mask_pred, atten_map\n",
    "\n",
    "def visualizeSlicePrediction(ori_img, image_name, atten_map, msk_gt, mask_pred, dsc_gt):\n",
    "    \"\"\"Draw a 1x4 figure: input slice, GT overlay, prediction overlay, attention map.\n",
    "\n",
    "    `ori_img` is resized to (args.out_size, args.out_size) and index 0 of its\n",
    "    leading dimension is shown; every panel is rotated with np.rot90. The masks\n",
    "    are painted onto a 3-channel uint8 copy of the slice with drawContour, and\n",
    "    the prediction panel's title reports `dsc_gt`.\n",
    "    \"\"\"\n",
    "    resize = torchvision.transforms.Resize((args.out_size, args.out_size))\n",
    "    gray = np.rot90(resize(ori_img)[0])\n",
    "    # uint8 3-channel version of the slice for drawContour to paint on.\n",
    "    rgb = np.repeat(np.array(gray * 255, dtype=np.uint8).copy()[:, :, np.newaxis], 3, axis=2)\n",
    "\n",
    "    pred_colour = [103, 169, 237]\n",
    "    gt_colour = [100, 255, 106]\n",
    "    pred_overlay = drawContour(rgb.copy(), np.rot90(mask_pred[0] * 255), pred_colour, size=-1, a=0.6)\n",
    "    gt_overlay = drawContour(rgb.copy(), np.rot90((msk_gt.cpu()[0] > 0) * 255), gt_colour, size=-1, a=0.6)\n",
    "\n",
    "    fig, axes = plt.subplots(1, 4, figsize=(20, 15))\n",
    "    ax_img, ax_gt, ax_pred, ax_att = axes\n",
    "\n",
    "    ax_img.imshow(gray, cmap='gray', vmin=0, vmax=1)\n",
    "    ax_img.set_title(image_name)\n",
    "\n",
    "    ax_gt.imshow(gt_overlay, cmap='gray', vmin=0, vmax=255)\n",
    "    ax_gt.set_title('gt_mask', fontsize=10)\n",
    "\n",
    "    ax_pred.imshow(pred_overlay, cmap='gray', vmin=0, vmax=255)\n",
    "    ax_pred.set_title('pre_mask_auto, dsc %.2f' % (dsc_gt), fontsize=10)\n",
    "\n",
    "    ax_att.imshow(np.rot90(atten_map.cpu()[0]), vmin=0, vmax=1, cmap='coolwarm')\n",
    "    ax_att.set_title('atten_map', fontsize=10)\n",
    "\n",
    "    for ax in axes:\n",
    "        ax.axis(False)\n",
    "\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f5a3d21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single source of truth for the example: prediction and evaluation must refer\n",
    "# to the same volume and the same slice, or the Dice score is meaningless.\n",
    "IMAGE_NAME = '2.nii.gz'\n",
    "MASK_NAME = '2.nrrd'\n",
    "SLICE_ID = 50  # slice number\n",
    "\n",
    "ori_img, predictedSliceMask, atten_map = predictSlice(\n",
    "    image_name=IMAGE_NAME,\n",
    "    lower_percentile=1,\n",
    "    upper_percentile=99,\n",
    "    slice_id=SLICE_ID,\n",
    "    attention_enabled=True,  # if you want to use the depth attention\n",
    ")\n",
    "\n",
    "msk_gt, dsc_gt = evaluateSlicePrediction(\n",
    "    mask_pred=predictedSliceMask,\n",
    "    mask_name=MASK_NAME,\n",
    "    slice_id=SLICE_ID,\n",
    ")\n",
    "\n",
    "visualizeSlicePrediction(\n",
    "    ori_img=ori_img,\n",
    "    image_name=IMAGE_NAME,\n",
    "    atten_map=atten_map,\n",
    "    msk_gt=msk_gt,\n",
    "    mask_pred=predictedSliceMask,\n",
    "    dsc_gt=dsc_gt,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58216ba5",
   "metadata": {},
   "source": [
    "## 3D Volume Prediction & Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "d551caea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predictVolume(image_name, lower_percentile, upper_percentile):\n",
    "    \"\"\"Segment every slice of a volume with attention-guided SAM and save the mask.\n",
    "\n",
    "    The volume `image_name` (relative to `img_folder`) is clipped to its\n",
    "    [lower_percentile, upper_percentile] intensity percentiles and rescaled to\n",
    "    [0, 1]. Each index along the last axis is segmented with its VNet attention\n",
    "    map, and the binary prediction is resized (nearest-neighbour) to the size of\n",
    "    the PIL slice returned by evaluate_1_volume_withattention.\n",
    "\n",
    "    The stacked mask is saved under `predicted_msk_folder`, mirroring any\n",
    "    sub-directories in `image_name`, with '.nii.gz' replaced by\n",
    "    '_predicted_SAMatten_paired.nrrd' in the file name.\n",
    "\n",
    "    Returns:\n",
    "        tio.LabelMap: the predicted mask volume, using the input's affine.\n",
    "    \"\"\"\n",
    "    image1_vol = tio.ScalarImage(os.path.join(img_folder, image_name))\n",
    "    print('vol shape: %s vol spacing %s' % (image1_vol.shape, image1_vol.spacing))\n",
    "\n",
    "    # Clip to the percentile window, then rescale to [0, 1].\n",
    "    image_tensor = image1_vol.data\n",
    "    lower_bound = torch_percentile(image_tensor, lower_percentile)\n",
    "    upper_bound = torch_percentile(image_tensor, upper_percentile)\n",
    "    image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n",
    "    value_range = upper_bound - lower_bound\n",
    "    if value_range == 0:\n",
    "        # Constant intensities inside the window: avoid a 0/0 NaN volume.\n",
    "        value_range = 1\n",
    "    image1_vol.set_data((image_tensor - lower_bound) / value_range)\n",
    "\n",
    "    mask_vol_numpy = np.zeros(image1_vol.shape)\n",
    "    # Pure inference over every slice: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        for slice_idx in range(image1_vol.shape[3]):\n",
    "            atten_map = pred_attention(image1_vol, vnet, slice_idx, device)\n",
    "            atten_map = torch.unsqueeze(torch.tensor(atten_map), 0).float().to(device)\n",
    "\n",
    "            _, pred_1, _, Pil_img1, _ = evaluate_1_volume_withattention(\n",
    "                image1_vol, sam_fine_tune, device, slice_id=slice_idx, atten_map=atten_map)\n",
    "            mask_pred = ((pred_1 > 0) == cls).float().cpu()\n",
    "            # Nearest-neighbour resampling keeps the resized mask binary.\n",
    "            pil_mask1 = Image.fromarray(np.array(mask_pred[0], dtype=np.uint8), 'L').resize(\n",
    "                Pil_img1.size, resample=PIL.Image.NEAREST)\n",
    "            mask_vol_numpy[0, :, :, slice_idx] = np.asarray(pil_mask1)\n",
    "\n",
    "    mask_vol = tio.LabelMap(tensor=torch.tensor(mask_vol_numpy, dtype=torch.int), affine=image1_vol.affine)\n",
    "    mask_save_folder = os.path.join(predicted_msk_folder, os.path.dirname(image_name))\n",
    "    Path(mask_save_folder).mkdir(parents=True, exist_ok=True)\n",
    "    mask_file = os.path.basename(image_name).replace('.nii.gz', '_predicted_SAMatten_paired.nrrd')\n",
    "    mask_vol.save(os.path.join(mask_save_folder, mask_file))\n",
    "    return mask_vol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "c48e328a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predictAndEvaluateVolume(image_name, mask_name, lower_percentile, upper_percentile):\n",
    "    \"\"\"Segment and save a whole volume, then score it against its ground truth.\n",
    "\n",
    "    Prediction and saving mirror predictVolume. The NRRD ground truth\n",
    "    `mask_name` is read from `gt_msk_folder` and two Dice scores are computed\n",
    "    and printed:\n",
    "      * slice-wise: mean over slices of the Dice between each per-slice\n",
    "        prediction and the GT slice resized to 256x256 and binarised (> 0);\n",
    "      * volume-wise: Dice between the saved mask volume and the binarised GT volume.\n",
    "\n",
    "    Returns:\n",
    "        (slice_wise_dsc, vol_wise_dsc) as Python floats.\n",
    "    \"\"\"\n",
    "    image1_vol = tio.ScalarImage(os.path.join(img_folder, image_name))\n",
    "    print('vol shape: %s vol spacing %s' % (image1_vol.shape, image1_vol.spacing))\n",
    "\n",
    "    # Clip to the percentile window, then rescale to [0, 1].\n",
    "    image_tensor = image1_vol.data\n",
    "    lower_bound = torch_percentile(image_tensor, lower_percentile)\n",
    "    upper_bound = torch_percentile(image_tensor, upper_percentile)\n",
    "    image_tensor = torch.clamp(image_tensor, lower_bound, upper_bound)\n",
    "    value_range = upper_bound - lower_bound\n",
    "    if value_range == 0:\n",
    "        # Constant intensities inside the window: avoid a 0/0 NaN volume.\n",
    "        value_range = 1\n",
    "    image1_vol.set_data((image_tensor - lower_bound) / value_range)\n",
    "\n",
    "    mask_gt, _ = nrrd.read(os.path.join(gt_msk_folder, mask_name))\n",
    "    num_slices = image1_vol.shape[3]\n",
    "    mask_vol_numpy = np.zeros(image1_vol.shape)\n",
    "    dsc_gt = 0.0\n",
    "    # Pure inference over every slice: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        for slice_idx in range(num_slices):\n",
    "            atten_map = pred_attention(image1_vol, vnet, slice_idx, device)\n",
    "            atten_map = torch.unsqueeze(torch.tensor(atten_map), 0).float().to(device)\n",
    "\n",
    "            _, pred_1, _, Pil_img1, _ = evaluate_1_volume_withattention(\n",
    "                image1_vol, sam_fine_tune, device, slice_id=slice_idx, atten_map=atten_map)\n",
    "            mask_pred = ((pred_1 > 0) == cls).float().cpu()\n",
    "\n",
    "            # Ground-truth slice, resized to 256x256 and binarised, for the slice-wise Dice.\n",
    "            msk = Image.fromarray(mask_gt[:, :, slice_idx].astype(np.uint8), 'L')\n",
    "            msk = transforms.Resize((256, 256))(msk)\n",
    "            msk_gt = (transforms.ToTensor()(msk) > 0).float()\n",
    "            dsc_gt += dice_coeff(mask_pred, msk_gt).item()\n",
    "\n",
    "            # Nearest-neighbour resampling keeps the resized mask binary.\n",
    "            pil_mask1 = Image.fromarray(np.array(mask_pred[0], dtype=np.uint8), 'L').resize(\n",
    "                Pil_img1.size, resample=PIL.Image.NEAREST)\n",
    "            mask_vol_numpy[0, :, :, slice_idx] = np.asarray(pil_mask1)\n",
    "\n",
    "    mask_vol = tio.LabelMap(tensor=torch.tensor(mask_vol_numpy, dtype=torch.int), affine=image1_vol.affine)\n",
    "    mask_save_folder = os.path.join(predicted_msk_folder, os.path.dirname(image_name))\n",
    "    Path(mask_save_folder).mkdir(parents=True, exist_ok=True)\n",
    "    mask_file = os.path.basename(image_name).replace('.nii.gz', '_predicted_SAMatten_paired.nrrd')\n",
    "    mask_vol.save(os.path.join(mask_save_folder, mask_file))\n",
    "\n",
    "    # Guard the mean against a volume with no slices (ZeroDivisionError).\n",
    "    dsc_gt = dsc_gt / num_slices if num_slices else 0.0\n",
    "    gt_vol = tio.LabelMap(tensor=torch.unsqueeze(torch.Tensor(mask_gt > 0), 0), affine=image1_vol.affine)\n",
    "    dsc_vol = dice_coeff(mask_vol.data.float().cpu(), gt_vol.data).item()\n",
    "    print('volume %s: slice_wise_dsc %.2f; vol_wise_dsc %.2f' % (image_name, dsc_gt, dsc_vol))\n",
    "    return dsc_gt, dsc_vol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a2f4789",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict the full volume and save the mask under predicted_msk_folder.\n",
    "mask = predictVolume(image_name='2.nii.gz', lower_percentile=1, upper_percentile=99)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5352a6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict and save the full volume, then report slice- and volume-wise Dice.\n",
    "predictAndEvaluateVolume(image_name='2.nii.gz', mask_name='2.nrrd', lower_percentile=1, upper_percentile=99)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}