<a href="https://colab.research.google.com/github/maragraziani/interpretAI_DigiPath/blob/main/hands-on-session-2/hands-on-session-2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# <center> Hands-on Session 2</center>
## <center> Explainable Graph Representations in Digital Pathology</center>

**Presented by:**
- Guillaume Jaume
    - Pre-doc researcher with EPFL & IBM Research 
    - gja@zurich.ibm.com  
<br/>
- Pushpak Pati 
    - Pre-doc researcher with ETH & IBM Research
    - pus@zurich.ibm.com
    
#### Content

* [Introduction & Motivation](#Intro)
* [Installation & Data](#Section0)
* [(1) Cell Graph construction](#Section1)
* [(2) Cell Graph classification](#Section2)
* [(3) Cell Graph explanation](#Section3)
* [(4) Nuclei concept analysis](#Section4)

#### Take-away

* Motivation for entity-graph modeling toward model explainability
* Getting familiar with the histocartography library and the BRACS dataset
* Tools to construct and analyze cell-graphs 
* Understand and use post-hoc graph explainability techniques

<div id="Intro"></div>

## Introduction & Motivation:

The first part of this tutorial guides you through building **interpretable entity-based representations** of tissue regions.

The motivation for shifting from pixel- to entity-based analysis is as follows:

- Cancer diagnosis and prognosis from tissue specimens depend strongly on the phenotype and topological distribution of the constituent histological entities, *e.g.,* cells, nuclei, and tissue regions. To adequately characterize tissue composition and exploit the tissue structure-to-function relationship, an entity-based paradigm is imperative.
### <center> "*Tissue composition matters for analyzing tissue functionality.*" </center> 
<figure class="image">
  <img src="Figures/fig1_1.png" width="750">
  <img src="Figures/fig1_2.png" width="750">
</figure>

- Entity-based processing makes it possible to separate diagnostically relevant histopathological entities from irrelevant ones. The set of entities and their inter- and intra-entity interactions can be customized using task-specific prior pathological knowledge.
### <center> "*The entity paradigm makes it possible to incorporate pathological priors during diagnosis.*" </center> 
<figure class="image">
  <img src="Figures/fig2.png" width="750">
</figure>

- Unlike most deep learning techniques, which operate at the pixel level, entity-based analysis preserves the notion of histopathological entities, which pathologists can relate to and reason with. The explanations of entity-graph based methods can therefore be interpreted by pathologists, which can help build trust in, and adoption of, AI in clinical practice. Notably, explanations produced in the entity space are better localized, and therefore easier to discern.
### <center> "*Pathologically comprehensible and localized explanations in the entity-space.*" </center> 
<figure class="image">
  <img src="Figures/fig3.png" width="750">
</figure>

- Further, the lightweight and flexible graph representation scales to large tissue regions of arbitrary size, since it can include an arbitrary number of nodes and edges.
### <center> "*Context vs Resolution trade-off.*" </center> 
<figure class="image">
  <img src="Figures/context.png" width="550">
</figure>

In this tutorial, we focus on nuclei as entities to build **Cell-graphs**. The same approach naturally extends to other histopathological entities, such as tissue regions or glands.

**References:**

- [Hierarchical Graph Representations in Digital Pathology.](https://arxiv.org/pdf/2102.11057.pdf) Pati et al., arXiv:2102.11057, 2021.
- [CGC-Net: Cell Graph Convolutional Network for Grading of Colorectal Cancer Histology Images.](https://arxiv.org/pdf/1909.01068.pdf) Zhou et al., IEEE CVPR Workshops, 2019.

<div id="Section0"></div> 

## Installation and Data 

- Running on **Colab**: this tutorial requires a GPU. Colab lets you use a K80 GPU for up to 12h. To enable it:
    - Open the tab *Runtime*
    - Click on *Change Runtime Type*
    - Set the hardware to *GPU* and *Save*


- Installation of the **histocartography** library, a Python library for entity-graph analysis and explainability in Computational Pathology. Documentation and examples are available [here](https://github.com/histocartography/histocartography).

<figure class="image">
  <img src="Figures/hcg_logo.png" width="450">
</figure>

- Downloading samples from the **BRACS** dataset, a large cohort of H&E-stained breast carcinoma tissue regions. More information and a download link for the dataset are available [here](https://www.bracs.icar.cnr.it/). 
<figure class="image">
  <img src="Figures/bracs_logo.png" width="450">
</figure>

In [None]:
# installing missing packages 
!pip install histocartography
!pip install mpld3

In [None]:
# Required only if you run this code on Colab:
# Get dependent files
!wget https://raw.githubusercontent.com/maragraziani/interpretAI_DigiPath/main/hands-on-session-2/cg_bracs_cggnn_3_classes_gin.yml
!wget https://raw.githubusercontent.com/maragraziani/interpretAI_DigiPath/main/hands-on-session-2/utils.py
    
# Get images
import os
!mkdir -p images
os.chdir('images')
!wget --content-disposition https://ibm.box.com/shared/static/6320wnhxsjte9tjlqb02zn0jaxlca5vb.png
!wget --content-disposition https://ibm.box.com/shared/static/d8rdupnzbo9ufcnc4qaluh0s2w7jt8mh.png
!wget --content-disposition https://ibm.box.com/shared/static/yj6kho8j5ovypafnheoju7y18bvtk32h.png
os.chdir('..')

In [None]:
import os 
from glob import glob 
from PIL import Image 

# 1. set up inline show
import matplotlib.pyplot as plt
%matplotlib inline
import mpld3
mpld3.enable_notebook() 

# 2. visualize the images: We will work with these 3 samples throughout the tutorial
images = [(Image.open(path), os.path.basename(path).split('.')[0]) 
          for path in glob(os.path.join('images', '*.png'))]
for image, image_name in images:
    print('Image:', image_name)
    display(image)

<div id="Section1"></div> 

## 1) Image-to-Graph: Cell-Graph construction

The code below builds a cell graph from an input H&E image. The procedure is as follows:

- **Nodes**: Detecting nuclei using HoverNet
- **Node features**: Extracting features to characterize the nuclei
- **Edges**: Constructing a k-NN graph to denote the inter-nuclei interactions
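
To make the edge step concrete, here is a minimal NumPy sketch of k-NN edge construction over nuclei centroids. It assumes (this is not taken from the histocartography source) that each nucleus is linked to its `k` nearest neighbours by Euclidean centroid distance and that edges longer than `thresh` pixels are then discarded, mirroring the `k=5, thresh=50` arguments passed to `KNNGraphBuilder` in the next cell. The function `knn_edges` is hypothetical.

```python
import numpy as np

def knn_edges(centroids, k=5, thresh=50.0):
    """Hypothetical sketch of k-NN edge construction over nuclei centroids.

    centroids: (N, 2) array of nucleus centres in pixels.
    Returns (src, dst) index arrays: each node is linked to at most k of its
    nearest neighbours, keeping only edges no longer than `thresh` pixels.
    """
    centroids = np.asarray(centroids, dtype=float)
    n = centroids.shape[0]
    # Pairwise Euclidean distances between all nuclei, shape (N, N)
    diff = centroids[:, None, :] - centroids[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # A nucleus is not its own neighbour: push the diagonal to infinity
    np.fill_diagonal(dist, np.inf)
    k_eff = max(0, min(k, n - 1))
    # Indices of the k closest nuclei for every node, shape (N, k_eff)
    nbrs = np.argsort(dist, axis=1)[:, :k_eff]
    src = np.repeat(np.arange(n), k_eff)
    dst = nbrs.reshape(-1)
    # Drop edges longer than the distance threshold
    keep = dist[src, dst] <= thresh
    return src[keep], dst[keep]

# Toy example: three close nuclei and one isolated nucleus (index 3)
centroids = np.array([[0, 0], [10, 0], [0, 10], [200, 200]])
src, dst = knn_edges(centroids, k=2, thresh=50)
# The isolated nucleus ends up with no edges because of the threshold
print(list(zip(src.tolist(), dst.tolist())))
```

The distance threshold is what keeps a k-NN graph from linking nuclei that are far apart in sparse tissue: without it, every nucleus would receive exactly `k` edges regardless of how isolated it is.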

**References:**

- [Hierarchical Graph Representations in Digital Pathology.](https://arxiv.org/pdf/2102.11057.pdf) Pati et al., arXiv:2102.11057, 2021.
- [Hover-Net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images.](https://arxiv.org/pdf/1812.06499.pdf) Graham et al., Medical Image Analysis, 2019.
- [PanNuke Dataset Extension, Insights and Baselines.](https://arxiv.org/abs/2003.10778) Gamper et al., arXiv:2003.10778, 2020.

In [None]:
import os 
from glob import glob
from PIL import Image
import numpy as np
import torch 
from tqdm import tqdm  
from dgl.data.utils import save_graphs

from histocartography.preprocessing import NucleiExtractor, DeepFeatureExtractor, KNNGraphBuilder, NucleiConceptExtractor

import warnings
warnings.filterwarnings("ignore")

# Define nuclei extractor: HoverNet pre-trained on the PanNuke dataset. 
nuclei_detector = NucleiExtractor()

# Define a deep feature extractor: ResNet34 applied to 72x72 patches around each nucleus, resized to 224x224 to match the ResNet input size
feature_extractor = DeepFeatureExtractor(architecture='resnet34', patch_size=72, resize_size=224)

# Define a graph builder to build a DGLGraph object 
graph_builder = KNNGraphBuilder(k=5, thresh=50, add_loc_feats=True)

# Define nuclei concept extractor: extract nuclei-level attributes - will be useful later for understanding the model
nuclei_concept_extractor = NucleiConceptExtractor(
  concept_names='area,eccentricity,roundness,roughness,shape_factor,mean_crowdedness,glcm_entropy,glcm_contrast'
)

# Load image fnames to process 
image_fnames = glob(os.path.join('images', '*.png'))

# Create output directories 
os.makedirs('cell_graphs', exist_ok=True)
os.makedirs('nuclei_concepts', exist_ok=True)

for image_name in tqdm(image_fnames):
    print('Processing...', image_name)
    
    # 1. load image
    image = np.array(Image.open(image_name))

    # 2. nuclei detection 
    nuclei_map, nuclei_centroids = nuclei_detector.process(image)

    # 3. nuclei feature extraction 
    features = feature_extractor.process(image, nuclei_map)

    # 4. build the cell graph
    cell_graph = graph_builder.process(
        instance_map=nuclei_map,
        features=features
    )
    
    # 5. extract the nuclei-level concept, i.e., properties: shape, size, etc.
    concepts = nuclei_concept_extractor.process(image, nuclei_map)

    # 6. print graph properties
    print('Number of nodes:', cell_graph.number_of_nodes())
    print('Number of edges:', cell_graph.number_of_edges())
    print('Number of features per node:', cell_graph.ndata['feat'].shape[1])

    # 7. save graph with DGL library and concepts 
    image_id = os.path.basename(image_name).split('.')[0]
    save_graphs(os.path.join('cell_graphs', image_id + '.bin'), [cell_graph])
    with open(os.path.join('nuclei_concepts', image_id + '.npy'), 'wb') as f:
        np.save(f, concepts)

In [None]:
from histocartography.visualization import OverlayGraphVisualization, InstanceImageVisualization
from utils import *

# Visualize the nuclei detection (for the last image processed in the loop above)
visualizer = InstanceImageVisualization()
viz_nuclei = visualizer.process(image, instance_map=nuclei_map)
show_inline(viz_nuclei)

In [None]:
# Visualize the resulting cell graph (for the last image processed in the loop above)
visualizer = OverlayGraphVisualization(
    instance_visualizer=InstanceImageVisualization(
        instance_style="filled+outline"
    )
)
viz_cg = visualizer.process(
  canvas=image,
  graph=cell_graph,
  instance_map=nuclei_map
)

show_inline(viz_cg)

<div id="Section2"></div> 

## 2) Cell-graph classification

Given the cell graphs generated for the 4000 H&E images in the BRACS dataset, a Graph Neural Network (GNN) was trained to classify each sample as *Benign*, *Atypical* or *Malignant*. Here, we load this pre-trained model.

A GNN is an artificial neural network designed to operate on graph-structured data. GNNs work analogously to Convolutional Neural Networks (CNNs): for each node, a GNN layer aggregates and updates information from the node's neighbors to contextualize its feature representation. More information about GNNs can be found [here](https://github.com/guillaumejaume/graph-neural-networks-roadmap).


<figure class="image">
  <img src="Figures/gnn.png" width="650">
</figure>
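
The neighbourhood aggregation described above can be illustrated with a GIN-style update (the config file used below, `cg_bracs_cggnn_3_classes_gin.yml`, suggests a Graph Isomorphism Network backbone, which we assume here), followed by a mean readout and a linear classifier producing one logit per class. This is a NumPy sketch with random weights on a toy 4-node graph, not the `CellGraphModel` implementation: the GIN MLP is simplified to a single linear layer with ReLU, and `gin_layer` and `classify_graph` are hypothetical names.

```python
import numpy as np

def gin_layer(H, A, W, eps=0.0):
    """One GIN-style update: each node sums its neighbours' features,
    adds its own features scaled by (1 + eps), then applies a linear map + ReLU.
    H: (N, d_in) node features, A: (N, N) adjacency, W: (d_in, d_out)."""
    aggregated = (1.0 + eps) * H + A @ H    # self + sum over neighbours
    return np.maximum(aggregated @ W, 0.0)  # simplified one-layer "MLP"

def classify_graph(H, A, layer_weights, W_cls):
    """Stack GIN layers, pool node embeddings with a mean readout,
    and map the graph embedding to class logits."""
    for W in layer_weights:
        H = gin_layer(H, A, W)
    graph_embedding = H.mean(axis=0)        # permutation-invariant readout
    return graph_embedding @ W_cls          # shape (num_classes,)

rng = np.random.default_rng(0)
num_nodes, node_dim, hidden, num_classes = 4, 514, 32, 3
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
H0 = rng.normal(size=(num_nodes, node_dim))
layers = [rng.normal(size=(node_dim, hidden)) * 0.05,
          rng.normal(size=(hidden, hidden)) * 0.05]
W_cls = rng.normal(size=(hidden, num_classes))
logits = classify_graph(H0, A, layers, W_cls)
print(logits.shape)  # (3,)
```

Because the aggregation is a sum and the readout a mean, relabelling the nodes does not change the logits; this is what lets a GNN handle cell graphs with any number of nuclei in any order.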

**References:**

- [Hierarchical Graph Representations in Digital Pathology.](https://arxiv.org/pdf/2102.11057.pdf) Pati et al., arXiv:2102.11057, 2021.
- [Benchmarking Graph Neural Networks.](https://arxiv.org/pdf/2003.00982.pdf) Dwivedi et al., NeurIPS, 2020. 

In [None]:
import os 
import yaml 

from histocartography.ml import CellGraphModel 

# 1. load CG-GNN config (safe_load: yaml.load without a Loader fails on PyYAML >= 6)
config_fname = 'cg_bracs_cggnn_3_classes_gin.yml'
with open(config_fname, 'r') as file:
    config = yaml.safe_load(file)

# 2. declare cell graph model: a PyTorch model that predicts the tumor type of an input cell graph
model = CellGraphModel(
    gnn_params=config['gnn_params'],
    classification_params=config['classification_params'],
    node_dim=514,  # 512 ResNet34 features + 2 location features (add_loc_feats=True)
    num_classes=3,
    pretrained=True  # load the weights pre-trained on BRACS
)

# 3. print model 
print('PyTorch Model is defined as:', model)

<div id="Section3"></div> 

## 3) Cell Graph explanation: Apply GraphGradCAM to CG-GNN

As presented in the first hands-on session, GradCAM is a popular post-hoc feature attribution method that highlights the regions of the input that most influence the network's output, *i.e.,* the elements of the input that *explain* the prediction. Since the input is now a set of *interpretable*, biologically defined nuclei, the explanation is also biologically *interpretable*. 

We use GraphGradCAM, a modified version of GradCAM that works with GNNs. GraphGradCAM follows two steps:

- Channel-wise importance score computation:
<figure class="image">
    <img src="Figures/eq1.png" width="180">
</figure>

where $w_k^{(l)}$ is the importance score of channel $k$ in layer $l$, $|V|$ is the number of nodes in the graph, $H^{(l)}_{n, k}$ is the embedding of node $n$ in channel $k$ at layer $l$, and $y_{\max}$ is the logit of the predicted class. 

- Node-wise importance score computation:
<figure class="image">
  <img src="Figures/eq2.png" width="250">
</figure>

where $L(l, v)$ denotes the importance of node $v \in V$ in layer $l$, and $d(l)$ denotes the number of node attributes (channels) at layer $l$.
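
Given the node embeddings $H^{(l)}$ of one layer and the gradients of $y_{\max}$ with respect to them, both steps reduce to a few array operations. The sketch below assumes the standard Grad-CAM form (gradients averaged over the $|V|$ nodes to get $w_k^{(l)}$, then a ReLU over the channel-weighted sum of embeddings), as in Pope et al. It takes the gradients as given instead of computing them with autograd, and `graph_grad_cam` is a hypothetical helper, not the internals of `GraphGradCAMExplainer`.

```python
import numpy as np

def graph_grad_cam(H, dY_dH):
    """H: (|V|, d) node embeddings at layer l.
    dY_dH: (|V|, d) gradients of the predicted-class logit y_max w.r.t. H.
    Returns a (|V|,) array of node importance scores L(l, v)."""
    # Step 1: channel-wise importance w_k = mean over nodes of dy_max / dH[n, k]
    w = dY_dH.mean(axis=0)           # shape (d,)
    # Step 2: node importance = ReLU of the channel-weighted sum of embeddings
    return np.maximum(H @ w, 0.0)    # shape (|V|,)

# Toy example: a linear "classifier" y_max = mean_n (H[n] . v),
# whose gradient w.r.t. H[n, k] is v[k] / |V| for every node n.
H = np.array([[1.0, 0.0],
              [0.0, 2.0],
              [3.0, 1.0]])
v = np.array([1.0, -1.0])
grads = np.tile(v / H.shape[0], (H.shape[0], 1))
scores = graph_grad_cam(H, grads)
# Node 1 contributes negatively, so the ReLU clips its score to 0
print(scores)
```

The ReLU keeps only nodes that push the predicted-class logit up; nodes that argue against the prediction get zero importance rather than a negative score.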
    
**Note:** GraphGradCAM is only one of many feature attribution methods for computing input-level importance scores. A rich literature proposes other approaches, *e.g.,* GNNExplainer, GraphGradCAM++, and GraphLRP.

**References:**

- [Grad-CAM: Visual Explanations from Deep Networks.](https://arxiv.org/pdf/1610.02391.pdf) Selvaraju et al., ICCV, 2017. 
- [Explainability Methods for Graph Convolutional Neural Networks.](https://openaccess.thecvf.com/content_CVPR_2019/papers/Pope_Explainability_Methods_for_Graph_Convolutional_Neural_Networks_CVPR_2019_paper.pdf) Pope et al., CVPR, 2019. 
- [Quantifying Explainers of Graph Neural Networks in Computational Pathology.](https://arxiv.org/pdf/2011.12646.pdf) Jaume et al., CVPR, 2021.

In [None]:
import os
import torch 
from glob import glob 
import numpy as np
from PIL import Image
from dgl.data.utils import load_graphs

from histocartography.interpretability import GraphGradCAMExplainer
from histocartography.utils.graph import set_graph_on_cuda

is_cuda = torch.cuda.is_available()

INDEX_TO_TUMOR_TYPE = {
  0: 'Benign',
  1: 'Atypical',
  2: 'Malignant'
}

# 1. Define a GraphGradCAM explainer
explainer = GraphGradCAMExplainer(model=model)

# 2. Load preprocessed cell graphs, concepts & images 
cg_fnames = glob(os.path.join('cell_graphs', '*.bin'))
image_fnames = glob(os.path.join('images', '*.png'))
concept_fnames = glob(os.path.join('nuclei_concepts', '*.npy'))

cg_fnames.sort()
image_fnames.sort()
concept_fnames.sort()

# 3. Explain all our samples 
output = []
for cg_name, image_name, concept_name in zip(cg_fnames, image_fnames, concept_fnames):
    print('Processing...', image_name)
    
    image = np.array(Image.open(image_name))
    concepts = np.load(concept_name)
    graph, _ = load_graphs(cg_name)
    graph = graph[0]
    graph = set_graph_on_cuda(graph) if is_cuda else graph
    
    importance_scores, logits = explainer.process(
        graph,
        output_name=cg_name.replace('.bin', '')
    )
    print('logits: ', logits)
    print('prediction: ', INDEX_TO_TUMOR_TYPE[np.argmax(logits)], '\n')
    
    output.append({
        'image_name': os.path.basename(image_name).split('.')[0],
        'image': image,
        'graph': graph,
        'importance_scores': importance_scores,
        'logits': logits,
        'concepts': concepts
    })

In [None]:
from histocartography.visualization import OverlayGraphVisualization, InstanceImageVisualization

INDEX_TO_TUMOR_TYPE = {
  0: 'Benign',
  1: 'Atypical',
  2: 'Malignant'
}

# Visualize the cell graph along with its relative node importance 
visualizer = OverlayGraphVisualization(
    instance_visualizer=InstanceImageVisualization(),
    colormap='plasma'
)

for i, instance in enumerate(output):
    print(instance['image_name'], instance['logits'])
    node_attributes = {}
    node_attributes["color"] = instance['importance_scores']
    node_attributes["thickness"] = 15
    node_attributes["radius"] = 10
    
    viz_cg = visualizer.process(
        canvas=instance['image'],
        graph=instance['graph'],
        node_attributes=node_attributes,
    )
    
    show_inline(viz_cg, title='Sample: {}'.format(INDEX_TO_TUMOR_TYPE[np.argmax(instance['logits'])]))

<div id="Section4"></div> 

## 4) Nuclei concept analysis: These nodes are important, but why? 

Using GraphGradCAM, we identified the important nuclei, *i.e.,* the discriminative nodes. We now push the analysis one step further and check whether the attributes (shape, size, etc.) of the important nuclei match prior pathological knowledge. For instance, cancerous nuclei are known to be larger than benign ones, and atypical nuclei are expected to have irregular shapes.

To this end, we analyze the nuclei-level attributes (concepts) extracted in step (1) for the most important nuclei.

**Note**: A *quantitative* analysis can be performed by studying the nuclei-concept distributions and how they align with prior pathological knowledge. This analysis is beyond the scope of this tutorial; the reader can refer to [this work](https://arxiv.org/pdf/2011.12646.pdf) for details. 

**References:**

- [Quantifying Explainers of Graph Neural Networks in Computational Pathology.](https://arxiv.org/pdf/2011.12646.pdf) Jaume et al., CVPR, 2021.

In [None]:
# Pick one sample per class, using the class name contained in the image file name
for out in output:
    if 'benign' in out['image_name']:
        benign_data = out
    elif 'atypical' in out['image_name']:
        atypical_data = out
    elif 'malignant' in out['image_name']:
        malignant_data = out

#### Nuclei visualization

- Visualizing the 20 most important nuclei 
- Visualizing 20 random nuclei for comparison
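
`get_patches` and `plot_patches` come from the tutorial's `utils.py`, whose code is not shown here. The sketch below illustrates one plausible way to do the selection and cropping, assuming nuclei centroids are available as an `(N, 2)` array of `(x, y)` pixel coordinates. `select_nuclei` and `crop_patches` are hypothetical names, and the 40-pixel patch size is an arbitrary choice.

```python
import numpy as np

def select_nuclei(importance_scores, k, random=False, seed=0):
    """Return indices of k nuclei: the k most important ones (descending
    score), or k distinct nuclei drawn uniformly at random for comparison."""
    scores = np.asarray(importance_scores).ravel()
    k = min(k, scores.size)
    if random:
        rng = np.random.default_rng(seed)
        return rng.choice(scores.size, size=k, replace=False)
    return np.argsort(-scores)[:k]

def crop_patches(image, centroids, indices, size=40):
    """Crop a size x size patch around each selected nucleus. At image
    borders the window is shifted inwards, so all patches share the same
    shape (assuming the image is at least size x size)."""
    h, w = image.shape[:2]
    half = size // 2
    patches = []
    for idx in indices:
        x, y = np.round(centroids[idx]).astype(int)
        x0 = int(np.clip(x - half, 0, max(w - size, 0)))
        y0 = int(np.clip(y - half, 0, max(h - size, 0)))
        patches.append(image[y0:y0 + size, x0:x0 + size])
    return patches

# Toy usage on a synthetic RGB image with 5 nuclei
image = np.zeros((100, 100, 3), dtype=np.uint8)
centroids = np.array([[5, 5], [50, 50], [95, 95], [20, 80], [70, 30]], dtype=float)
scores = np.array([0.1, 0.9, 0.4, 0.7, 0.2])
top = select_nuclei(scores, k=3)               # indices 1, 3, 2
rand = select_nuclei(scores, k=3, random=True)
patches = crop_patches(image, centroids, top)
print(top, [p.shape for p in patches])
```

Sampling the random baseline without replacement guarantees `k` distinct nuclei, so the two galleries contain the same number of patches.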

In [None]:
# Top k nuclei
from utils import get_patches, plot_patches

k = 20

nuclei = get_patches(out=benign_data, k=k)
plot_patches(nuclei, ncol=10)

nuclei = get_patches(out=atypical_data, k=k)
plot_patches(nuclei, ncol=10)

nuclei = get_patches(out=malignant_data, k=k)
plot_patches(nuclei, ncol=10)

In [None]:
# k random nuclei (baseline for comparison)
from utils import get_patches, plot_patches
k = 20

nuclei = get_patches(out=benign_data, k=k, random=True)
plot_patches(nuclei, ncol=10)

nuclei = get_patches(out=atypical_data, k=k, random=True)
plot_patches(nuclei, ncol=10)

nuclei = get_patches(out=malignant_data, k=k, random=True)
plot_patches(nuclei, ncol=10)

In [None]:
#area,eccentricity,roundness,roughness,shape_factor,mean_crowdedness,glcm_entropy,glcm_contrast
FEATURE_TO_INDEX = {
  'area': 0,
  'eccentricity': 1,
  'roundness': 2,
  'roughness': 3,
  'shape_factor': 4,
  'mean_crowdedness': 5,
  'glcm_entropy': 6,
  'glcm_contrast': 7,
}
    
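def compute_concept_ratio(data1, data2, feature, k):
    """Ratio of the summed `feature` values of the k most important nuclei in
    data1 to those in data2. When both graphs have at least k nuclei, this
    equals the ratio of the means; a value > 1 means the concept is larger in data1."""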
    index = FEATURE_TO_INDEX[feature]
    
    important_indices = (-data1['importance_scores']).argsort()[:k]
    important_data1 = data1['concepts'][important_indices, index]

    important_indices = (-data2['importance_scores']).argsort()[:k]
    important_data2 = data2['concepts'][important_indices, index]

    return sum(important_data1) / sum(important_data2)

#### Pathological fact: "Cancerous nuclei are expected to be larger than benign ones": area(Malignant) > area(Benign)

In [None]:
k = 20
ratio = compute_concept_ratio(malignant_data, benign_data, 'area', k)
print('Ratio between the area of important malignant and benign nuclei: ', round(ratio, 4))

#### Pathological fact: "Atypical nuclei are hyperchromatic (solid), whereas Malignant nuclei are vesicular (porous)": contrast(Malignant) > contrast(Atypical)

In [None]:
k = 20
ratio = compute_concept_ratio(malignant_data, atypical_data, 'glcm_contrast', k)
print('Ratio between the contrast of important malignant and atypical nuclei: ', round(ratio, 4))
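
To see why a porous (vesicular) texture should raise `glcm_contrast`, here is a minimal NumPy computation of GLCM contrast, $\sum_{i,j} P(i,j)\,(i-j)^2$, where $P$ is the normalized co-occurrence matrix of quantized grey-level pairs at a fixed pixel offset. This is an illustrative re-implementation assuming a one-pixel offset and 8 grey levels; it is not necessarily how `NucleiConceptExtractor` computes the concept (it may use other offsets, quantization, or masking to the nucleus). The two toy textures are synthetic stand-ins, not real nuclei.

```python
import numpy as np

def glcm_contrast(gray, levels=8, offset=(0, 1)):
    """Contrast of the grey-level co-occurrence matrix (GLCM):
    sum_{i,j} P(i, j) * (i - j)**2, with P the normalised co-occurrence counts
    of grey-level pairs separated by `offset` (dy, dx), with dy, dx >= 0."""
    gray = np.asarray(gray, dtype=float)
    # Quantise intensities in [0, 255] into `levels` bins
    q = np.clip((gray / 256.0 * levels).astype(int), 0, levels - 1)
    dy, dx = offset
    h, w = q.shape
    ref = q[:h - dy, :w - dx]   # reference pixels
    nbr = q[dy:, dx:]           # neighbours at the given offset
    glcm = np.zeros((levels, levels))
    np.add.at(glcm, (ref.ravel(), nbr.ravel()), 1)
    P = glcm / glcm.sum()
    i, j = np.indices((levels, levels))
    return float((P * (i - j) ** 2).sum())

solid = np.full((16, 16), 60)                         # homogeneous ("solid")
porous = np.indices((16, 16)).sum(axis=0) % 2 * 255   # checkerboard ("porous")
print(glcm_contrast(solid), glcm_contrast(porous))    # 0.0 49.0
```

A homogeneous texture only produces pairs on the GLCM diagonal, where $(i-j)^2 = 0$, while strong pixel-to-pixel intensity changes populate off-diagonal cells and are weighted quadratically.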

#### Pathological fact: "Benign nuclei are more crowded than Atypical ones": crowdedness(Atypical) > crowdedness(Benign)

In [None]:
k = 20
ratio = compute_concept_ratio(atypical_data, benign_data, 'mean_crowdedness', k)
print('Ratio between the crowdedness of important atypical and benign nuclei: ', round(ratio, 4))

## Conclusion:

Given the adoption of Graph Neural Networks in various domains, such as pathology, radiology, computational biology, and satellite and natural images, graph interpretability and explainability are imperative. The algorithms and tools presented here aim to motivate and guide work in this direction. Although they are demonstrated for digital pathology, they can be transferred to other domains by building relevant domain-specific graph representations. Potentially, entity-graph modeling and analysis can identify relevant cues for explainable stratification.


<figure class="image">
  <img src="Figures/conclusion.png" width="850">
</figure>
