[f2cb69]: / CellGraph / Graph Construction.ipynb

Download this file

315 lines (314 with data), 10.8 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate features for nuclei and generate a .pt graph file for each ROI\n",
    "## Features per nucleus: 8 contour (shape) features + 4 GLCM texture features + the CPC embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re, os\n",
    "import cv2\n",
    "import math\n",
    "import random\n",
    "import torch\n",
    "import resnet\n",
    "import skimage.feature\n",
    "import pdb\n",
    "from PIL import Image\n",
    "from pyflann import *\n",
    "from torch_geometric.data import Data\n",
    "from collections import OrderedDict\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torchvision.transforms.functional as F\n",
    "import torch_geometric.data as data\n",
    "import torch_geometric.utils as utils\n",
    "import pdb\n",
    "import torch_geometric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from model import CPC_model\n",
    "# Fall back to the CPU so the notebook still runs on a machine without a GPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "model = CPC_model(1024, 256)\n",
    "encoder = model.encoder.to(device)\n",
    "ckpt_dir = './pretrained_models/cpc.pt'\n",
    "# map_location lets a checkpoint saved on a GPU be loaded onto the selected device.\n",
    "ckpt = torch.load(ckpt_dir, map_location=device)\n",
    "load_result = encoder.load_state_dict(ckpt['encoder_state_dict'])\n",
    "# The encoder is only used for feature extraction: eval() switches any dropout /\n",
    "# batch-norm layers it may contain to their deterministic inference behaviour.\n",
    "encoder.eval()\n",
    "load_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def from_networkx(G):\n",
    "    r\"\"\"Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a\n",
    "    :class:`torch_geometric.data.Data` instance.\n",
    "\n",
    "    Nodes are first relabelled to consecutive integers 0..N-1 (in ``G.nodes``\n",
    "    order), so ``edge_index`` always refers to rows of the stacked node\n",
    "    attributes -- even when ``G`` is a subgraph whose labels have gaps.\n",
    "\n",
    "    Args:\n",
    "        G (networkx.Graph or networkx.DiGraph): A networkx graph. Attribute keys\n",
    "            are read from the first node and the first edge; every other node /\n",
    "            edge is expected to carry the same keys.\n",
    "\n",
    "    Returns:\n",
    "        torch_geometric.data.Data: one tensor per node / edge attribute key,\n",
    "        plus ``edge_index`` (``[2, num_edges]``, long) and ``num_nodes``. An\n",
    "        undirected graph contributes both directions of every edge.\n",
    "    \"\"\"\n",
    "    # Without relabelling, a subgraph keeps its original (non-contiguous) labels and\n",
    "    # edge_index could point past the end of (or at the wrong row of) the node features.\n",
    "    G = nx.convert_node_labels_to_integers(G)\n",
    "    G = G.to_directed() if not nx.is_directed(G) else G\n",
    "\n",
    "    edges = list(G.edges)\n",
    "    if edges:\n",
    "        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()\n",
    "    else:\n",
    "        # Keep the [2, 0] shape PyG expects for a graph without edges.\n",
    "        edge_index = torch.empty((2, 0), dtype=torch.long)\n",
    "\n",
    "    node_feats = [feat_dict for _, feat_dict in G.nodes(data=True)]\n",
    "    edge_feats = [feat_dict for _, _, feat_dict in G.edges(data=True)]\n",
    "    keys = list(node_feats[0].keys()) if node_feats else []\n",
    "    keys += list(edge_feats[0].keys()) if edge_feats else []\n",
    "    attrs = {key: [] for key in keys}\n",
    "\n",
    "    for feat_dict in node_feats + edge_feats:\n",
    "        for key, value in feat_dict.items():\n",
    "            attrs[key].append(value)\n",
    "\n",
    "    for key, item in attrs.items():\n",
    "        # Going through one NumPy array avoids torch's slow element-wise conversion of\n",
    "        # a Python list of ndarrays (e.g. the per-node feature vectors).\n",
    "        attrs[key] = torch.as_tensor(np.asarray(item))\n",
    "\n",
    "    attrs['edge_index'] = edge_index\n",
    "    graph_data = torch_geometric.data.Data.from_dict(attrs)\n",
    "    graph_data.num_nodes = G.number_of_nodes()\n",
    "\n",
    "    return graph_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "import itertools\n",
    "\n",
    "def get_cell_image(img, cx, cy, size=512):\n",
    "    \"\"\"Return the 64 x 64 window of ``img`` centred on (cx, cy).\n",
    "\n",
    "    Each coordinate is first pulled into [32, size - 32] (32 when it is below 32,\n",
    "    otherwise at most ``size - 32``) so that, for a ``size`` x ``size`` image, the\n",
    "    window stays inside the image. ``cx`` indexes columns and ``cy`` rows.\n",
    "    \"\"\"\n",
    "    half = 32\n",
    "\n",
    "    def clamp(coord):\n",
    "        if coord < half:\n",
    "            return half\n",
    "        return min(coord, size - half)\n",
    "\n",
    "    row, col = clamp(cy), clamp(cx)\n",
    "    return img[row - half:row + half, col - half:col + half, :]\n",
    "\n",
    "def get_cpc_features(cell):\n",
    "    \"\"\"Embed one cell crop with the pretrained CPC ``encoder``.\n",
    "\n",
    "    Args:\n",
    "        cell: colour crop of the cell, in a form ``transforms.ToTensor`` accepts\n",
    "            (assumed to be an H x W x 3 uint8 NumPy array -- TODO confirm).\n",
    "\n",
    "    Returns:\n",
    "        NumPy array: the encoder output for this crop with the batch dimension removed.\n",
    "    \"\"\"\n",
    "    # ToTensor (scales uint8 input to [0, 1]) followed by a per-channel mean/std of 0.5.\n",
    "    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "    batch = transform(cell).unsqueeze(0)\n",
    "    # Run on whatever device the encoder's weights live on instead of hard-coding cuda:0.\n",
    "    device = next(encoder.parameters()).device\n",
    "    # Inference only: no_grad avoids building an autograd graph for every cell.\n",
    "    # NOTE(review): assumes `encoder` has already been switched to eval() mode.\n",
    "    with torch.no_grad():\n",
    "        feats = encoder(batch.to(device))\n",
    "    return feats.cpu().numpy()[0]\n",
    "\n",
    "def get_cell_features(img, contour):\n",
    "    \"\"\"Compute the feature vector and integer centre of one segmented nucleus.\n",
    "\n",
    "    Args:\n",
    "        img: RGB image (NumPy array) the contour was extracted from.\n",
    "        contour: an OpenCV contour; ``cv2.fitEllipse`` needs at least 5 points.\n",
    "\n",
    "    Returns:\n",
    "        (features, cx, cy) where ``features`` is a float64 vector made of\n",
    "        8 contour features (short axis, long axis, angle, area, perimeter,\n",
    "        eccentricity, roundness, solidity), 4 GLCM texture features\n",
    "        (dissimilarity, homogeneity, energy, ASM) and the CPC embedding of the\n",
    "        64 x 64 crop; ``cx, cy`` is the fitted-ellipse centre truncated to int\n",
    "        (column, row).\n",
    "    \"\"\"\n",
    "    # Ellipse fit: centre, the two full axis lengths and the rotation angle.\n",
    "    (cx, cy), axes, angle = cv2.fitEllipse(contour)\n",
    "    cx, cy = int(cx), int(cy)\n",
    "    # Order the axes explicitly so the eccentricity formula never sees short > long.\n",
    "    short_axis, long_axis = min(axes), max(axes)\n",
    "\n",
    "    # Get a 64 x 64 crop around the cell (the centre is clamped to stay in the image).\n",
    "    img_cell = get_cell_image(img, cx, cy)\n",
    "\n",
    "    grey_region = cv2.cvtColor(img_cell, cv2.COLOR_RGB2GRAY)\n",
    "    # Reflect-pad up to 64 x 64 in case the crop came back smaller than that.\n",
    "    img_cell_grey = np.pad(grey_region, [(0, 64-grey_region.shape[0]), (0, 64-grey_region.shape[1])], mode='reflect')\n",
    "\n",
    "    # 1. Contour (shape) features. Degenerate contours (zero area, zero hull area or a\n",
    "    #    zero-length axis) would raise ZeroDivisionError; their ratios are set to 0.0.\n",
    "    eccentricity = math.sqrt(1 - (short_axis/long_axis)**2) if long_axis > 0 else 0.0\n",
    "    convex_hull = cv2.convexHull(contour)\n",
    "    area, hull_area = cv2.contourArea(contour), cv2.contourArea(convex_hull)\n",
    "    solidity = float(area)/hull_area if hull_area > 0 else 0.0\n",
    "    arc_length = cv2.arcLength(contour, True)\n",
    "    # Perimeter divided by the circumference of a circle with the same area (1.0 for a circle).\n",
    "    roundness = (arc_length/(2*math.pi))/(math.sqrt(area/math.pi)) if area > 0 else 0.0\n",
    "\n",
    "    # 2. GLCM texture features (pixel distance 1, angle 0). scikit-image 0.19 renamed\n",
    "    #    grey* -> gray*; use whichever name the installed version provides.\n",
    "    graycomatrix = getattr(skimage.feature, 'graycomatrix', None) or skimage.feature.greycomatrix\n",
    "    graycoprops = getattr(skimage.feature, 'graycoprops', None) or skimage.feature.greycoprops\n",
    "    out_matrix = graycomatrix(img_cell_grey, [1], [0])\n",
    "    dissimilarity = graycoprops(out_matrix, 'dissimilarity')[0][0]\n",
    "    homogeneity = graycoprops(out_matrix, 'homogeneity')[0][0]\n",
    "    energy = graycoprops(out_matrix, 'energy')[0][0]\n",
    "    ASM = graycoprops(out_matrix, 'ASM')[0][0]\n",
    "\n",
    "    # 3. CPC embedding of the (unpadded) colour crop.\n",
    "    cpc_feats = get_cpc_features(img_cell)\n",
    "\n",
    "    # Concatenate + return all features.\n",
    "    x = [[short_axis, long_axis, angle, area, arc_length, eccentricity, roundness, solidity],\n",
    "         [dissimilarity, homogeneity, energy, ASM],\n",
    "         cpc_feats]\n",
    "    return np.array(list(itertools.chain(*x)), dtype=np.float64), cx, cy\n",
    "\n",
    "\n",
    "def seg2graph(img, contours, min_nodes=6):\n",
    "    \"\"\"Build an edge-less networkx graph with one node per segmented nucleus.\n",
    "\n",
    "    Contours with 5 or fewer points (``cv2.fitEllipse`` needs at least 5) or with\n",
    "    zero area are skipped. Node ``v`` stores ``centroid=[cx, cy]`` and the feature\n",
    "    vector ``x`` produced by ``get_cell_features``.\n",
    "\n",
    "    Args:\n",
    "        img: RGB image the contours were extracted from.\n",
    "        contours: sequence of OpenCV contours.\n",
    "        min_nodes: return ``None`` when fewer nuclei than this remain. The default\n",
    "            of 6 matches the previous ``v < 5`` check on the last node index.\n",
    "\n",
    "    Returns:\n",
    "        networkx.Graph with nodes labelled 0..N-1, or ``None`` if N < ``min_nodes``.\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "\n",
    "    # Zero-area contours would make the area-based ratio features divide by zero.\n",
    "    contours = [c for c in contours if c.shape[0] > 5 and cv2.contourArea(c) > 0]\n",
    "\n",
    "    for v, contour in enumerate(contours):\n",
    "        features, cx, cy = get_cell_features(img, contour)\n",
    "        G.add_node(v, centroid=[cx, cy], x=features)\n",
    "\n",
    "    # Counting nodes (rather than reading the loop variable) also works when no contour\n",
    "    # survived the filter, in which case `v` would never have been bound.\n",
    "    if G.number_of_nodes() < min_nodes:\n",
    "        return None\n",
    "    return G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = \"./example_data/\"\n",
    "img_dir = os.path.join(data_dir, 'imgs')\n",
    "seg_dir = os.path.join(data_dir, 'segs')\n",
    "\n",
    "# Example regions of interest; each ROI needs an image and a mask with the same file name.\n",
    "roi1 = 'TCGA-06-0174-01Z-00-DX3.23b6e12e-dfc1-4c6f-903e-170038a0e055_1.png'\n",
    "roi2 = 'TCGA-HT-7470-01Z-00-DX4.204D0CF2-A22E-4428-8E8C-572432B86280_1.png'\n",
    "roi3 = 'TCGA-26-1442-01Z-00-DX1.FD8D4EB7-AD5E-49E8-BD0B-6CDDEA8DDB35_1.png'\n",
    "\n",
    "# Fail fast, with a readable message, if any ROI is missing from either folder.\n",
    "for folder in (img_dir, seg_dir):\n",
    "    available = set(os.listdir(folder))\n",
    "    for roi in (roi1, roi2, roi3):\n",
    "        assert roi in available, '{} not found in {}'.format(roi, folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:13<00:00,  4.41s/it]\n"
     ]
    }
   ],
   "source": [
    "save_dir = data_dir\n",
    "pt_dir = os.path.join(save_dir, 'pts')\n",
    "graph_dir = os.path.join(save_dir, 'graphs')\n",
    "# Create the output folders up front so the first save cannot fail on a fresh checkout.\n",
    "os.makedirs(pt_dir, exist_ok=True)\n",
    "os.makedirs(graph_dir, exist_ok=True)\n",
    "fail_list = []  # ROIs for which no graph could be built\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# FLANN is asked for 5 neighbours per nucleus; the query nucleus itself is normally one\n",
    "# of them, so each nucleus gets (up to) 4 edges to other nuclei.\n",
    "NUM_NEIGHBORS = 5\n",
    "\n",
    "for img_fname in tqdm([roi1, roi2, roi3]):\n",
    "    # convert('RGB') drops any alpha channel / palette so the crops are always H x W x 3.\n",
    "    img = np.array(Image.open(os.path.join(img_dir, img_fname)).convert('RGB'))\n",
    "    # NOTE(review): assumes the mask loads as a single-channel 8-bit image -- confirm.\n",
    "    seg = np.array(Image.open(os.path.join(seg_dir, img_fname)))\n",
    "    ret, binary = cv2.threshold(seg, 127, 255, cv2.THRESH_BINARY)\n",
    "    # [-2] is the contour list under both the OpenCV 3 (3 return values) and the\n",
    "    # OpenCV 4 (2 return values) signature of findContours.\n",
    "    # NOTE(review): RETR_TREE also returns inner (hole) contours, which become nodes too.\n",
    "    contours = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]\n",
    "    if len(contours) < 1:\n",
    "        fail_list.append(img_fname)\n",
    "        continue\n",
    "\n",
    "    G = seg2graph(img, contours)\n",
    "    if G is None:\n",
    "        fail_list.append(img_fname)\n",
    "        continue\n",
    "\n",
    "    # Nodes were added as 0..N-1, so row i of `centroids` belongs to node i.\n",
    "    centroids = np.array([attrib['centroid'] for _, attrib in G.nodes(data=True)], dtype=np.float64)\n",
    "\n",
    "    # Build the k-means index once and query every nucleus in a single call.\n",
    "    flann = FLANN()\n",
    "    results, dists = flann.nn(centroids, centroids, num_neighbors=NUM_NEIGHBORS, algorithm='kmeans', branching=32, iterations=100, checks=16)\n",
    "    for idx, (nbr_ids, nbr_dists) in enumerate(zip(results, dists)):\n",
    "        # Drop the query nucleus itself instead of assuming it is always the first hit\n",
    "        # (approximate search or duplicate centroids can reorder the results).\n",
    "        neighbours = [(int(j), d) for j, d in zip(nbr_ids, nbr_dists) if j != idx]\n",
    "        for j, d in neighbours[:NUM_NEIGHBORS - 1]:\n",
    "            G.add_edge(idx, j, weight=d)\n",
    "\n",
    "    # Keep the largest connected component and relabel it 0..N-1 so edge indices line up\n",
    "    # with the rows of the node-feature tensor.\n",
    "    largest_cc = max(nx.connected_components(G), key=len)\n",
    "    G = nx.convert_node_labels_to_integers(G.subgraph(largest_cc).copy())\n",
    "\n",
    "    # Visual check: all contours in green, graph edges in blue (img is RGB).\n",
    "    cv2.drawContours(img, contours, -1, (0, 255, 0), 2)\n",
    "    for u, v in G.edges():\n",
    "        cv2.line(img, tuple(G.nodes[u]['centroid']), tuple(G.nodes[v]['centroid']), (0, 0, 255), 2)\n",
    "    Image.fromarray(img).save(os.path.join(graph_dir, img_fname))\n",
    "\n",
    "    graph_data = from_networkx(G)\n",
    "    # NOTE: casting to long truncates any fractional part of the FLANN distances.\n",
    "    graph_data.edge_attr = graph_data.weight.unsqueeze(1).type(torch.LongTensor)\n",
    "    graph_data.edge_index = graph_data.edge_index.type(torch.LongTensor)\n",
    "    graph_data.x = graph_data.x.type(torch.FloatTensor)\n",
    "    # The distances now live in edge_attr; remove the raw attribute before saving.\n",
    "    graph_data['weight'] = None\n",
    "    torch.save(graph_data, os.path.join(pt_dir, os.path.splitext(img_fname)[0] + '.pt'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}