{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/maragraziani/interpretAI_DigiPath/blob/main/hands-on-session-2/hands-on-session-2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <center> Hands-on Session 2</center>\n",
    "## <center> Explainable Graph Representations in Digital Pathology</center>\n",
    "\n",
    "**Presented by:**\n",
    "- Guillaume Jaume\n",
    "    - Pre-doc researcher with EPFL & IBM Research \n",
    "    - gja@zurich.ibm.com  \n",
    "<br/>\n",
    "- Pushpak Pati \n",
    "    - Pre-doc researcher with ETH & IBM Research\n",
    "    - pus@zurich.ibm.com\n",
    "    \n",
    "#### Content\n",
    "\n",
    "* [Introduction & Motivation](#Intro)\n",
    "* [Installation & Data](#Section0)\n",
    "* [(1) Cell Graph construction](#Section1)\n",
    "* [(2) Cell Graph classification](#Section2)\n",
    "* [(3) Cell Graph explanation](#Section3)\n",
    "* [(4) Nuclei concept analysis](#Section4)\n",
    "\n",
    "#### Take-away\n",
    "\n",
    "* Motivation of entity-graph modeling for model explainability\n",
    "* Getting familiar with the histocartography library and BRACS dataset\n",
    "* Tools to construct and analyze cell-graphs \n",
    "* Understand and use post-hoc graph explainability techniques"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction & Motivation:\n",
    "\n",
    "The first part of this tutorial guides you through building **interpretable entity-based representations** of tissue regions.\n",
    "\n",
    "The motivation for shifting from pixel- to entity-based analysis is as follows:\n",
    "\n",
    "- Cancer diagnosis and prognosis from tissue specimens depend heavily on the phenotype and topological distribution of the constituent histological entities, *e.g.,* cells, nuclei, and tissue regions. To adequately characterize the tissue composition and exploit the relationship between tissue structure and function, an entity-based paradigm is imperative.\n",
    "### <center> \"*Tissue composition matters for analyzing tissue functionality.*\" </center> \n",
    "<figure class=\"image\">\n",
    "  <img src=\"Figures/fig1_1.png\" width=\"750\">\n",
    "  <img src=\"Figures/fig1_2.png\" width=\"750\">\n",
    "</figure>\n",
    "\n",
    "- Entity-based processing makes it possible to delineate diagnostically relevant and irrelevant histopathological entities. The set of entities, and the corresponding inter- and intra-entity interactions, can be customized using task-specific prior pathological knowledge.\n",
    "### <center> \"*The entity paradigm enables incorporating pathological priors into diagnosis.*\" </center> \n",
    "<figure class=\"image\">\n",
    "  <img src=\"Figures/fig2.png\" width=\"750\">\n",
    "</figure>\n",
    "\n",
    "- Unlike most deep learning techniques, which operate at the pixel level, entity-based analysis preserves the notion of histopathological entities, which pathologists can relate to and reason with. Thus, the explanations of entity-graph based methods can be interpreted by pathologists, which can help build trust in, and adoption of, AI in clinical practice. Notably, explanations produced in the entity space are better localized and therefore easier to discern.\n",
    "### <center> \"*Pathologically comprehensible and localized explanations in the entity-space.*\" </center> \n",
    "<figure class=\"image\">\n",
    "  <img src=\"Figures/fig3.png\" width=\"750\">\n",
    "</figure>\n",
    "\n",
    "- Further, the lightweight and flexible graph representation scales to large, arbitrarily shaped tissue regions by including an arbitrary number of nodes and edges.\n",
    "### <center> \"*Context vs Resolution trade-off.*\" </center> \n",
    "<figure class=\"image\">\n",
    "  <img src=\"Figures/context.png\" width=\"550\">\n",
    "</figure>\n",
    "\n",
    "In this tutorial, we focus on nuclei as entities to build **Cell-graphs**. The same approach naturally extends to other histopathological entities, such as tissue regions or glands.\n",
    "\n",
    "**References:**\n",
    "\n",
    "- [Hierarchical Graph Representations in Digital Pathology.](https://arxiv.org/pdf/2102.11057.pdf) Pati et al., arXiv:2102.11057, 2021.\n",
    "- [CGC-Net: Cell Graph Convolutional Network for Grading of Colorectal Cancer Histology Images.](https://arxiv.org/pdf/1909.01068.pdf) Zhou et al., IEEE CVPR Workshops, 2019."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div id=\"Section0\"></div> \n",
    "\n",
    "## Installation and Data \n",
    "\n",
    "- Running on **Colab**: this tutorial requires a GPU. Colab allows you to use a K80 GPU for 12h. Please follow these steps:\n",
    "    - Open the tab *Runtime*\n",
    "    - Click on *Change Runtime Type*\n",
    "    - Set the hardware to *GPU* and *Save*\n",
    "\n",
    "\n",
    "- Installation of the **histocartography** library, a Python library that facilitates entity-graph analysis and explainability in Computational Pathology. Documentation and examples can be found [here](https://github.com/histocartography/histocartography).\n",
    "\n",
    "<figure class=\"image\">\n",
    "  <img src=\"Figures/hcg_logo.png\" width=\"450\">\n",
    "</figure>\n",
    "\n",
    "- Downloading samples from the **BRACS** dataset, a large cohort of H&E stained breast carcinoma tissue regions. More information about the dataset, including a download link, can be found [here](https://www.bracs.icar.cnr.it/). \n",
    "<figure class=\"image\">\n",
    "  <img src=\"Figures/bracs_logo.png\" width=\"450\">\n",
    "</figure>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# installing missing packages \n",
    "!pip install histocartography\n",
    "!pip install mpld3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Required only if you run this code on Colab:\n",
    "# Get dependent files\n",
    "!wget https://raw.githubusercontent.com/maragraziani/interpretAI_DigiPath/main/hands-on-session-2/cg_bracs_cggnn_3_classes_gin.yml\n",
    "!wget https://raw.githubusercontent.com/maragraziani/interpretAI_DigiPath/main/hands-on-session-2/utils.py\n",
    "    \n",
    "# Get images\n",
    "import os\n",
    "os.makedirs('images', exist_ok=True)  # safe to re-run\n",
    "os.chdir('images')\n",
    "!wget --content-disposition https://ibm.box.com/shared/static/6320wnhxsjte9tjlqb02zn0jaxlca5vb.png\n",
    "!wget --content-disposition https://ibm.box.com/shared/static/d8rdupnzbo9ufcnc4qaluh0s2w7jt8mh.png\n",
    "!wget --content-disposition https://ibm.box.com/shared/static/yj6kho8j5ovypafnheoju7y18bvtk32h.png\n",
    "os.chdir('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "from glob import glob \n",
    "from PIL import Image \n",
    "\n",
    "# 1. set up inline show\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import mpld3\n",
    "mpld3.enable_notebook() \n",
    "\n",
    "# 2. visualize the images: We will work with these 3 samples throughout the tutorial\n",
    "images = [(Image.open(path), os.path.basename(path).split('.')[0]) \n",
    "          for path in glob(os.path.join('images', '*.png'))]\n",
    "for image, image_name in images:\n",
    "    print('Image:', image_name)\n",
    "    display(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div id=\"Section1\"></div> \n",
    "\n",
    "## 1) Image-to-Graph: Cell-Graph construction\n",
    "\n",
    "The code below builds a cell-graph from an input H&E image. The step-by-step procedure to define a cell-graph is as follows:\n",
    "\n",
    "- **Nodes**: Detecting nuclei using HoverNet\n",
    "- **Node features**: Extracting features to characterize the nuclei\n",
    "- **Edges**: Constructing a k-NN graph to denote the inter-nuclei interactions\n",
    "\n",
    "**References:**\n",
    "\n",
    "- [Hierarchical Graph Representations in Digital Pathology.](https://arxiv.org/pdf/2102.11057.pdf) Pati et al., arXiv:2102.11057, 2021.\n",
    "- [Hover-Net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images.](https://arxiv.org/pdf/1812.06499.pdf) Graham et al., Medical Image Analysis, 2019.\n",
    "- [PanNuke Dataset Extension, Insights and Baselines.](https://arxiv.org/abs/2003.10778) Gamper et al., arXiv:2003.10778, 2020."
   ]
  },
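  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the edge definition concrete, the cell below is a minimal NumPy sketch of a thresholded k-NN graph built on nuclei centroids. It illustrates the idea under stated assumptions and is not the `KNNGraphBuilder` implementation: the helper `knn_edges` and the toy centroids are hypothetical, and `k` and `thresh` only mirror the roles of the `KNNGraphBuilder(k=5, thresh=50)` arguments used in the next cell (number of neighbors and maximum edge length in pixels)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch of a thresholded k-NN cell-graph topology (NumPy only)\n",
    "import numpy as np\n",
    "\n",
    "def knn_edges(centroids, k=5, thresh=50.0):\n",
    "    # Link each nucleus to its k nearest neighbors, then drop edges longer than\n",
    "    # thresh pixels. Returns a list of directed (src, dst) index pairs.\n",
    "    centroids = np.asarray(centroids, dtype=float)\n",
    "    # Pairwise Euclidean distances between all centroids, shape (N, N)\n",
    "    diff = centroids[:, None, :] - centroids[None, :, :]\n",
    "    dist = np.sqrt((diff ** 2).sum(axis=-1))\n",
    "    np.fill_diagonal(dist, np.inf)  # a nucleus is not its own neighbor\n",
    "    k = min(k, len(centroids) - 1)\n",
    "    edges = []\n",
    "    for src in range(len(centroids)):\n",
    "        # a stable sort makes tie-breaking between equidistant nuclei deterministic\n",
    "        for dst in np.argsort(dist[src], kind='stable')[:k]:\n",
    "            if dist[src, dst] <= thresh:  # prune long-range edges\n",
    "                edges.append((src, int(dst)))\n",
    "    return edges\n",
    "\n",
    "# Toy example: three nearby nuclei and one isolated nucleus far away\n",
    "centroids = [[0, 0], [10, 0], [0, 10], [200, 200]]\n",
    "edges = knn_edges(centroids, k=2, thresh=50)\n",
    "print(edges)  # the isolated nucleus (index 3) ends up with no edges"
   ]
  },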
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import os \n",
    "from glob import glob\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import torch \n",
    "from tqdm import tqdm  \n",
    "from dgl.data.utils import save_graphs\n",
    "\n",
    "from histocartography.preprocessing import NucleiExtractor, DeepFeatureExtractor, KNNGraphBuilder, NucleiConceptExtractor\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Define nuclei extractor: HoverNet pre-trained on the PanNuke dataset. \n",
    "nuclei_detector = NucleiExtractor()\n",
    "\n",
    "# Define a deep feature extractor with ResNet34: 72x72 patches are resized to 224x224 to match the ResNet input size\n",
    "feature_extractor = DeepFeatureExtractor(architecture='resnet34', patch_size=72, resize_size=224)\n",
    "\n",
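    "# Define a graph builder to build a DGLGraph object: each nucleus is linked to its k=5 nearest neighbors,\n",
    "# edges longer than thresh=50 pixels are discarded, and centroid coordinates are appended to the node features\n",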
    "graph_builder = KNNGraphBuilder(k=5, thresh=50, add_loc_feats=True)\n",
    "\n",
    "# Define nuclei concept extractor: extract nuclei-level attributes - will be useful later for understanding the model\n",
    "nuclei_concept_extractor = NucleiConceptExtractor(\n",
    "  concept_names='area,eccentricity,roundness,roughness,shape_factor,mean_crowdedness,glcm_entropy,glcm_contrast'\n",
    ")\n",
    "\n",
    "# Load image fnames to process \n",
    "image_fnames = glob(os.path.join('images', '*.png'))\n",
    "\n",
    "# Create output directories \n",
    "os.makedirs('cell_graphs', exist_ok=True)\n",
    "os.makedirs('nuclei_concepts', exist_ok=True)\n",
    "\n",
    "for image_name in tqdm(image_fnames):\n",
    "    print('Processing...', image_name)\n",
    "    \n",
    "    # 1. load image\n",
    "    image = np.array(Image.open(image_name))\n",
    "\n",
    "    # 2. nuclei detection \n",
    "    nuclei_map, nuclei_centroids = nuclei_detector.process(image)\n",
    "\n",
    "    # 3. nuclei feature extraction \n",
    "    features = feature_extractor.process(image, nuclei_map)\n",
    "\n",
    "    # 4. build the cell graph\n",
    "    cell_graph = graph_builder.process(\n",
    "        instance_map=nuclei_map,\n",
    "        features=features\n",
    "    )\n",
    "    \n",
    "    # 5. extract the nuclei-level concept, i.e., properties: shape, size, etc.\n",
    "    concepts = nuclei_concept_extractor.process(image, nuclei_map)\n",
    "\n",
    "    # 6. print graph properties\n",
    "    print('Number of nodes:', cell_graph.number_of_nodes())\n",
    "    print('Number of edges:', cell_graph.number_of_edges())\n",
    "    print('Number of features per node:', cell_graph.ndata['feat'].shape[1])\n",
    "\n",
    "    # 7. save graph with DGL library and concepts \n",
    "    image_id = os.path.basename(image_name).split('.')[0]\n",
    "    save_graphs(os.path.join('cell_graphs', image_id + '.bin'), [cell_graph])\n",
    "    with open(os.path.join('nuclei_concepts', image_id + '.npy'), 'wb') as f:\n",
    "        np.save(f, concepts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from histocartography.visualization import OverlayGraphVisualization, InstanceImageVisualization\n",
    "from utils import *\n",
    "\n",
    "# Visualize the nuclei detection of the last processed image \n",
    "visualizer = InstanceImageVisualization()\n",
    "viz_nuclei = visualizer.process(image, instance_map=nuclei_map)\n",
    "show_inline(viz_nuclei)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the resulting cell graph of the last processed image \n",
    "visualizer = OverlayGraphVisualization(\n",
    "    instance_visualizer=InstanceImageVisualization(\n",
    "        instance_style=\"filled+outline\"\n",
    "    )\n",
    ")\n",
    "viz_cg = visualizer.process(\n",
    "  canvas=image,\n",
    "  graph=cell_graph,\n",
    "  instance_map=nuclei_map\n",
    ")\n",
    "\n",
    "show_inline(viz_cg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div id=\"Section2\"></div> \n",
    "\n",
    "## 2) Cell-graph classification\n",
    "\n",
    "Given the set of cell graphs generated for the 4000 H&E images in the BRACS dataset, a Graph Neural Network (GNN) is trained to classify each sample as either *Benign*, *Atypical* or *Malignant*. \n",
    "\n",
    "A GNN is an artificial neural network designed to operate on graph-structured data. It works analogously to a Convolutional Neural Network (CNN): for each node, a GNN layer aggregates information from the node's neighbors and updates the node's feature representation, thereby contextualizing it. More information about GNNs can be found [here](https://github.com/guillaumejaume/graph-neural-networks-roadmap).\n",
    "\n",
    "\n",
    "<figure class=\"image\">\n",
    "  <img src=\"Figures/gnn.png\" width=\"650\">\n",
    "</figure>\n",
    "\n",
    "**References:**\n",
    "\n",
    "- [Hierarchical Graph Representations in Digital Pathology.](https://arxiv.org/pdf/2102.11057.pdf) Pati et al., arXiv:2102.11057, 2021.\n",
    "- [Benchmarking Graph Neural Networks.](https://arxiv.org/pdf/2003.00982.pdf) Dwivedi et al., NeurIPS, 2020. "
   ]
  },
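  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The aggregate-and-update scheme described above can be sketched in NumPy, followed by a graph-level readout that yields one logit per class. This is a toy illustration, not the CG-GNN used below: it assumes a GIN-style layer (the configuration file name `cg_bracs_cggnn_3_classes_gin.yml` suggests GIN layers, but the actual architecture is defined by the library and the config), and the helpers `gin_layer` and `classify_graph`, the 4-node graph, and all weights are hypothetical random placeholders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy GIN-style message passing + mean readout (hypothetical sketch, NumPy only)\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "def gin_layer(H, A, W, eps=0.0):\n",
    "    # Aggregate: sum of neighbor features plus (1 + eps) times the node's own features\n",
    "    aggregated = (1.0 + eps) * H + A @ H  # A[v, u] = 1 if u is a neighbor of v\n",
    "    # Update: a single linear map + ReLU stands in for the MLP of a GIN layer\n",
    "    return np.maximum(aggregated @ W, 0.0)\n",
    "\n",
    "def classify_graph(H, A, layer_weights, W_out):\n",
    "    for W in layer_weights:\n",
    "        H = gin_layer(H, A, W)\n",
    "    graph_embedding = H.mean(axis=0)  # permutation-invariant mean readout\n",
    "    return graph_embedding @ W_out     # one logit per class\n",
    "\n",
    "# Toy cell graph: 4 nuclei with 6-dimensional features and an undirected adjacency\n",
    "num_nodes, in_dim, hidden_dim, num_classes = 4, 6, 8, 3\n",
    "H0 = rng.normal(size=(num_nodes, in_dim))\n",
    "A = np.array([[0, 1, 1, 0],\n",
    "              [1, 0, 1, 0],\n",
    "              [1, 1, 0, 1],\n",
    "              [0, 0, 1, 0]], dtype=float)\n",
    "layer_weights = [rng.normal(size=(in_dim, hidden_dim)),\n",
    "                 rng.normal(size=(hidden_dim, hidden_dim))]\n",
    "W_out = rng.normal(size=(hidden_dim, num_classes))\n",
    "\n",
    "logits = classify_graph(H0, A, layer_weights, W_out)\n",
    "print(logits.shape)  # (3,): one logit each for Benign, Atypical, Malignant"
   ]
  },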
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os \n",
    "import yaml \n",
    "\n",
    "from histocartography.ml import CellGraphModel \n",
    "\n",
    "# 1. load CG-GNN config \n",
    "config_fname = 'cg_bracs_cggnn_3_classes_gin.yml'\n",
    "with open(config_fname, 'r') as file:\n",
    "    config = yaml.safe_load(file)\n",
    "\n",
    "# 2. declare cell graph model: A pytorch model for predicting the tumor type given an input cell-graph\n",
    "model = CellGraphModel(\n",
    "    gnn_params=config['gnn_params'],\n",
    "    classification_params=config['classification_params'],\n",
    "    node_dim=514,  # 512 ResNet34 features + 2 centroid coordinates (add_loc_feats=True)\n",
    "    num_classes=3,\n",
    "    pretrained=True\n",
    ")\n",
    "\n",
    "# 3. print model \n",
    "print('PyTorch Model is defined as:', model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div id=\"Section3\"></div> \n",
    "\n",
    "## 3) Cell Graph explanation: Apply GraphGradCAM to CG-GNN\n",
    "\n",
    "As presented in the first hands-on session, GradCAM is a popular post-hoc feature attribution method that highlights the regions of the input that drive the network's prediction, *i.e.,* the elements of the input that *explain* the prediction. As the input is now a set of *interpretable*, biologically defined nuclei, the explanation is also biologically *interpretable*. \n",
    "\n",
    "We use a modified version of GradCAM that works with GNNs: GraphGradCAM. Specifically, GraphGradCAM follows two steps:\n",
    "\n",
    "- Computation of channel-wise importance score:\n",
    "<figure class=\"image\">\n",
    "    <img src=\"Figures/eq1.png\" width=\"180\">\n",
    "</figure>\n",
    "\n",
    "where $w_k^{(l)}$ is the importance score of channel $k$ in layer $l$, $|V|$ is the number of nodes in the graph, $H^{(l)}_{n, k}$ is the embedding of node $n$ in channel $k$ at layer $l$, and $y_{\\max}$ is the logit value of the predicted class. \n",
    "\n",
    "- Node-wise importance score computation:\n",
    "<figure class=\"image\">\n",
    "  <img src=\"Figures/eq2.png\" width=\"250\">\n",
    "</figure>\n",
    "\n",
    "where $L(l, v)$ denotes the importance of node $v \\in V$ in layer $l$, and $d(l)$ denotes the number of node attributes at layer $l$.\n",
    "    \n",
    "**Note:** GraphGradCAM is one of many feature attribution methods for computing input-level importance scores. A rich literature proposes other approaches, *e.g.,* GNNExplainer, GraphGradCAM++, and GraphLRP.\n",
    "\n",
    "**References:**\n",
    "\n",
    "- [Grad-CAM : Visual Explanations from Deep Networks.](https://arxiv.org/pdf/1610.02391.pdf) Selvaraju et al., ICCV, 2017. \n",
    "- [Explainability Methods for Graph Convolutional Neural Networks.](https://openaccess.thecvf.com/content_CVPR_2019/papers/Pope_Explainability_Methods_for_Graph_Convolutional_Neural_Networks_CVPR_2019_paper.pdf) Pope et al., CVPR, 2019. \n",
    "- [Quantifying Explainers of Graph Neural Networks in Computational Pathology.](https://arxiv.org/pdf/2011.12646.pdf) Jaume et al., CVPR, 2021."
   ]
  },
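  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two GraphGradCAM steps above can be sketched in a few lines of NumPy. The helper `graph_gradcam` is hypothetical and is not the `GraphGradCAMExplainer` implementation. To avoid autograd, the sketch assumes a toy model whose logits are a linear map `W_out` applied to the mean of the node embeddings; for such a model, the gradient of the predicted-class logit with respect to $H^{(l)}_{n, k}$ is `W_out[k, c] / |V|` for every node $n$. In a real GNN, this gradient comes from backpropagation through the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical NumPy sketch of the two GraphGradCAM steps on a toy model\n",
    "import numpy as np\n",
    "\n",
    "def graph_gradcam(H, grad):\n",
    "    # H:    node embeddings at layer l, shape (|V|, d(l))\n",
    "    # grad: gradient of the predicted-class logit w.r.t. H, same shape as H\n",
    "    w = grad.mean(axis=0)          # step 1: channel-wise importance w_k (average over nodes)\n",
    "    return np.maximum(H @ w, 0.0)  # step 2: node-wise importance L(l, v) = ReLU(sum_k w_k H_vk)\n",
    "\n",
    "# Toy model (assumption): logits = mean-pooled embeddings @ W_out\n",
    "rng = np.random.default_rng(0)\n",
    "H = rng.normal(size=(5, 4))        # 5 nodes, 4 channels\n",
    "W_out = rng.normal(size=(4, 3))    # 3 classes\n",
    "logits = H.mean(axis=0) @ W_out\n",
    "c = int(np.argmax(logits))         # predicted class\n",
    "# Analytic gradient for this toy model: identical for every node\n",
    "grad = np.tile(W_out[:, c] / H.shape[0], (H.shape[0], 1))\n",
    "\n",
    "node_importance = graph_gradcam(H, grad)\n",
    "print(node_importance)  # one non-negative score per node"
   ]
  },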
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from glob import glob\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from dgl.data.utils import load_graphs\n",
    "\n",
    "from histocartography.interpretability import GraphGradCAMExplainer\n",
    "from histocartography.utils.graph import set_graph_on_cuda\n",
    "\n",
    "is_cuda = torch.cuda.is_available()\n",
    "\n",
    "INDEX_TO_TUMOR_TYPE = {\n",
    "  0: 'Benign',\n",
    "  1: 'Atypical',\n",
    "  2: 'Malignant'\n",
    "}\n",
    "\n",
    "# 1. Define a GraphGradCAM explainer\n",
    "explainer = GraphGradCAMExplainer(model=model)\n",
    "\n",
    "# 2. Load preprocessed cell graphs, concepts & images \n",
    "cg_fnames = glob(os.path.join('cell_graphs', '*.bin'))\n",
    "image_fnames = glob(os.path.join('images', '*.png'))\n",
    "concept_fnames = glob(os.path.join('nuclei_concepts', '*.npy'))\n",
    "\n",
    "cg_fnames.sort()\n",
    "image_fnames.sort()\n",
    "concept_fnames.sort()\n",
    "\n",
    "# 3. Explain all our samples \n",
    "output = []\n",
    "for cg_name, image_name, concept_name in zip(cg_fnames, image_fnames, concept_fnames):\n",
    "    print('Processing...', image_name)\n",
    "    \n",
    "    image = np.array(Image.open(image_name))\n",
    "    concepts = np.load(concept_name)\n",
    "    graph, _ = load_graphs(cg_name)\n",
    "    graph = graph[0]\n",
    "    graph = set_graph_on_cuda(graph) if is_cuda else graph\n",
    "    \n",
    "    importance_scores, logits = explainer.process(\n",
    "        graph,\n",
    "        output_name=cg_name.replace('.bin', '')\n",
    "    )\n",
    "    print('logits: ', logits)\n",
    "    print('prediction: ', INDEX_TO_TUMOR_TYPE[np.argmax(logits)], '\\n')\n",
    "    \n",
    "    output.append({\n",
    "        'image_name': os.path.basename(image_name).split('.')[0],\n",
    "        'image': image,\n",
    "        'graph': graph,\n",
    "        'importance_scores': importance_scores,\n",
    "        'logits': logits,\n",
    "        'concepts': concepts\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from histocartography.visualization import OverlayGraphVisualization, InstanceImageVisualization\n",
    "\n",
    "INDEX_TO_TUMOR_TYPE = {\n",
    "  0: 'Benign',\n",
    "  1: 'Atypical',\n",
    "  2: 'Malignant'\n",
    "}\n",
    "\n",
    "# Visualize the cell graph along with its relative node importance \n",
    "visualizer = OverlayGraphVisualization(\n",
    "    instance_visualizer=InstanceImageVisualization(),\n",
    "    colormap='plasma'\n",
    ")\n",
    "\n",
    "for i, instance in enumerate(output):\n",
    "    print(instance['image_name'], instance['logits'])\n",
    "    node_attributes = {}\n",
    "    node_attributes[\"color\"] = instance['importance_scores']\n",
    "    node_attributes[\"thickness\"] = 15\n",
    "    node_attributes[\"radius\"] = 10\n",
    "    \n",
    "    viz_cg = visualizer.process(\n",
    "        canvas=instance['image'],\n",
    "        graph=instance['graph'],\n",
    "        node_attributes=node_attributes,\n",
    "    )\n",
    "    \n",
    "    show_inline(viz_cg, title='Prediction: {}'.format(INDEX_TO_TUMOR_TYPE[np.argmax(instance['logits'])]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div id=\"Section4\"></div> \n",
    "\n",
    "## 4) Nuclei concept analysis: These nodes are important, but why? \n",
    "\n",
    "Using GraphGradCAM, we identified the important nuclei, *i.e.,* the discriminative nodes. We now push the analysis one step further to check whether the attributes (shape, size, etc.) of the important nuclei match prior pathological knowledge. For instance, cancerous nuclei are known to be larger than benign ones, and atypical nuclei are expected to have irregular shapes.\n",
    "\n",
    "To this end, we extract a set of nuclei-level attributes for the most important nuclei.\n",
    "\n",
    "**Note**: A *quantitative* analysis can be performed by studying nuclei-concept distributions and how they align with prior pathological knowledge. However, this analysis is beyond the scope of this tutorial. The reader can refer to [this work](https://arxiv.org/pdf/2011.12646.pdf) for more details. \n",
    "\n",
    "**References:**\n",
    "\n",
    "- [Quantifying Explainers of Graph Neural Networks in Computational Pathology.](https://arxiv.org/pdf/2011.12646.pdf) Jaume et al., CVPR, 2021."
   ]
  },
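  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of what nuclei-level concepts measure, the cell below computes two of them, area and eccentricity, directly from a nuclei instance map with NumPy. The helper `area_and_eccentricity` and the toy instance map are hypothetical; this is not how `NucleiConceptExtractor` computes its concepts. Eccentricity here follows the usual second-moment definition (the eccentricity of the ellipse with the same second central moments as the nucleus): it is 0 for an isotropic blob and approaches 1 for an elongated one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch: area and eccentricity per nucleus from an instance map\n",
    "import numpy as np\n",
    "\n",
    "def area_and_eccentricity(instance_map):\n",
    "    # instance_map: 2-D int array, 0 = background, label i > 0 = pixels of nucleus i\n",
    "    concepts = {}\n",
    "    for label in np.unique(instance_map):\n",
    "        if label == 0:\n",
    "            continue\n",
    "        rows, cols = np.nonzero(instance_map == label)\n",
    "        area = rows.size  # number of pixels in the nucleus\n",
    "        # 2x2 covariance of pixel coordinates = second central moments\n",
    "        cov = np.cov(np.stack([rows, cols]).astype(float), bias=True)\n",
    "        lam_small, lam_large = np.linalg.eigvalsh(cov)  # eigenvalues, ascending\n",
    "        ratio = lam_small / lam_large if lam_large > 0 else 1.0\n",
    "        ecc = float(np.sqrt(max(0.0, 1.0 - ratio)))  # clip rounding noise\n",
    "        concepts[int(label)] = (int(area), ecc)\n",
    "    return concepts\n",
    "\n",
    "# Toy instance map: a 4x4 square nucleus (label 1) and a 1x8 elongated one (label 2)\n",
    "instance_map = np.zeros((12, 12), dtype=int)\n",
    "instance_map[1:5, 1:5] = 1\n",
    "instance_map[8, 2:10] = 2\n",
    "print(area_and_eccentricity(instance_map))"
   ]
  },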
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group the explained samples by class, using the class name in each image file name\n",
    "for out in output:\n",
    "    if 'benign' in out['image_name']:\n",
    "        benign_data = out\n",
    "    elif 'atypical' in out['image_name']:\n",
    "        atypical_data = out\n",
    "    elif 'malignant' in out['image_name']:\n",
    "        malignant_data = out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Nuclei visualization\n",
    "\n",
    "- Visualizing the 20 most important nuclei \n",
    "- Visualizing 20 random nuclei for comparison"
   ]
  },
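  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both visualizations boil down to picking nuclei indices from the GraphGradCAM importance scores: the k highest-scoring nuclei, or k random nuclei as a baseline. The cell below sketches both selections with NumPy; `select_nuclei` is a hypothetical helper written for illustration and is not the `get_patches` function from `utils.py`, whose internals are not shown here. The top-k branch uses the same descending `argsort` as `compute_concept_ratio` further below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sketch: top-k vs. random nuclei selection from importance scores\n",
    "import numpy as np\n",
    "\n",
    "def select_nuclei(importance_scores, k=20, random_baseline=False, seed=0):\n",
    "    # Indices of the k most important nuclei, or of k random nuclei\n",
    "    # (sampled without replacement) as a baseline for comparison.\n",
    "    scores = np.asarray(importance_scores)\n",
    "    k = min(k, scores.size)\n",
    "    if random_baseline:\n",
    "        rng = np.random.default_rng(seed)\n",
    "        return rng.choice(scores.size, size=k, replace=False)\n",
    "    return np.argsort(-scores)[:k]  # highest importance first\n",
    "\n",
    "scores = np.array([0.1, 0.9, 0.3, 0.7, 0.0])\n",
    "print(select_nuclei(scores, k=2))                        # [1 3]\n",
    "print(select_nuclei(scores, k=2, random_baseline=True))  # 2 distinct random indices"
   ]
  },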
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top k nuclei\n",
    "from utils import get_patches, plot_patches\n",
    "\n",
    "k = 20\n",
    "\n",
    "nuclei = get_patches(out=benign_data, k=k)\n",
    "plot_patches(nuclei, ncol=10)\n",
    "\n",
    "nuclei = get_patches(out=atypical_data, k=k)\n",
    "plot_patches(nuclei, ncol=10)\n",
    "\n",
    "nuclei = get_patches(out=malignant_data, k=k)\n",
    "plot_patches(nuclei, ncol=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# k random nuclei (baseline for comparison)\n",
    "from utils import get_patches, plot_patches\n",
    "k = 20\n",
    "\n",
    "nuclei = get_patches(out=benign_data, k=k, random=True)\n",
    "plot_patches(nuclei, ncol=10)\n",
    "\n",
    "nuclei = get_patches(out=atypical_data, k=k, random=True)\n",
    "plot_patches(nuclei, ncol=10)\n",
    "\n",
    "nuclei = get_patches(out=malignant_data, k=k, random=True)\n",
    "plot_patches(nuclei, ncol=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#area,eccentricity,roundness,roughness,shape_factor,mean_crowdedness,glcm_entropy,glcm_contrast\n",
    "FEATURE_TO_INDEX = {\n",
    "  'area': 0,\n",
    "  'eccentricity': 1,\n",
    "  'roundness': 2,\n",
    "  'roughness': 3,\n",
    "  'shape_factor': 4,\n",
    "  'mean_crowdedness': 5,\n",
    "  'glcm_entropy': 6,\n",
    "  'glcm_contrast': 7,\n",
    "}\n",
    "    \n",
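    "def compute_concept_ratio(data1, data2, feature, k):\n",
    "    # Ratio between the summed values of `feature` over the k most important nuclei of data1\n",
    "    # and of data2 (the ratio of means when both graphs have at least k nodes).\n",
    "    # A ratio > 1 means the concept is larger for the important nuclei of data1.\n",
    "    index = FEATURE_TO_INDEX[feature]\n",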
    "    \n",
    "    important_indices = (-data1['importance_scores']).argsort()[:k]\n",
    "    important_data1 = data1['concepts'][important_indices, index]\n",
    "\n",
    "    important_indices = (-data2['importance_scores']).argsort()[:k]\n",
    "    important_data2 = data2['concepts'][important_indices, index]\n",
    "\n",
    "    return sum(important_data1) / sum(important_data2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Pathological fact: \"Cancerous nuclei are expected to be larger than benign ones\": area(Malignant) > area(Benign)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 20\n",
    "ratio = compute_concept_ratio(malignant_data, benign_data, 'area', k)\n",
    "print('Ratio between the area of important malignant and benign nuclei: ', round(ratio, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Pathological fact: \"Atypical nuclei are hyperchromatic (solid) and Malignant are vesicular (porous)\": contrast(Malignant) > contrast(Atypical)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 20\n",
    "ratio = compute_concept_ratio(malignant_data, atypical_data, 'glcm_contrast', k)\n",
    "print('Ratio between the contrast of important malignant and atypical nuclei: ', round(ratio, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Pathological fact: \"Benign nuclei are more crowded than Atypical ones\": crowdedness(Atypical) > crowdedness(Benign)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 20\n",
    "ratio = compute_concept_ratio(atypical_data, benign_data, 'mean_crowdedness', k)\n",
    "print('Ratio between the crowdedness of important atypical and benign nuclei: ', round(ratio, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion:\n",
    "\n",
    "Considering the adoption of Graph Neural Networks in various domains, such as pathology, radiology, computational biology, and satellite and natural images, graph interpretability and explainability are imperative. The presented algorithms and tools aim to motivate and guide work in this direction. Although the presented technologies are demonstrated for digital pathology, they can be seamlessly transferred to other domains by building relevant domain-specific graph representations. Potentially, entity-graph modeling and analysis can identify relevant cues for explainable stratification.\n",
    "\n",
    "\n",
    "<figure class=\"image\">\n",
    "  <img src=\"Figures/conclusion.png\" width=\"850\">\n",
    "</figure>\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}