--- a +++ b/CellGraph/Graph Construction.ipynb @@ -0,0 +1,314 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Calculate features for nuclei and generate .pt files for each graph\n", + "## 15 features in total" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import re, os\n", + "import cv2\n", + "import math\n", + "import random\n", + "import torch\n", + "import resnet\n", + "import skimage.feature\n", + "import pdb\n", + "from PIL import Image\n", + "from pyflann import *\n", + "from torch_geometric.data import Data\n", + "from collections import OrderedDict\n", + "\n", + "import networkx as nx\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torchvision.transforms.functional as F\n", + "import torch_geometric.data as data\n", + "import torch_geometric.utils as utils\n", + "import pdb\n", + "import torch_geometric" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<All keys matched successfully>" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from model import CPC_model\n", + "device = torch.device('cuda:{}'.format('0'))\n", + "model = CPC_model(1024, 256)\n", + "encoder = model.encoder.to(device)\n", + "ckpt_dir = './pretrained_models/cpc.pt'\n", + "ckpt = torch.load(ckpt_dir)\n", + "encoder.load_state_dict(ckpt['encoder_state_dict'])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def from_networkx(G):\n", + " r\"\"\"Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a\n", + " :class:`torch_geometric.data.Data` instance.\n", + " Args:\n", + " G (networkx.Graph or networkx.DiGraph): A networkx graph.\n", + " \"\"\"\n", + "\n", + " G = G.to_directed() if not nx.is_directed(G) else G\n", + " edge_index = torch.tensor(list(G.edges)).t().contiguous()\n", + "\n", + " keys = []\n", + " keys += list(list(G.nodes(data=True))[0][1].keys())\n", + " keys += list(list(G.edges(data=True))[0][2].keys())\n", + " data = {key: [] for key in keys}\n", + "\n", + " for _, feat_dict in G.nodes(data=True):\n", + " for key, value in feat_dict.items():\n", + " data[key].append(value)\n", + "\n", + " for _, _, feat_dict in G.edges(data=True):\n", + " for key, value in feat_dict.items():\n", + " data[key].append(value)\n", + "\n", + " for key, item in data.items():\n", + " data[key] = torch.tensor(item)\n", + "\n", + " data['edge_index'] = edge_index\n", + " data = torch_geometric.data.Data.from_dict(data)\n", + " data.num_nodes = G.number_of_nodes()\n", + "\n", + " return data" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from torchvision import transforms\n", + "import itertools\n", + "\n", + "def get_cell_image(img, cx, cy, size=512):\n", + " cx = 32 if cx < 32 else size-32 if cx > size-32 else cx\n", + " cy = 32 if cy < 32 else size-32 if cy > size-32 else cy\n", + " return img[cy-32:cy+32, cx-32:cx+32, :]\n", + "\n", + "def get_cpc_features(cell):\n", + " transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n", + " cell = transform(cell)\n", + " cell = cell.unsqueeze(0)\n", + " device = torch.device('cuda:{}'.format('0'))\n", + " feats = encoder(cell.to(device)).cpu().detach().numpy()[0]\n", + " return feats\n", + "\n", + "def get_cell_features(img, contour):\n", + " \n", + " # Get contour coordinates from contour\n", + " (cx, cy), (short_axis, long_axis), angle = cv2.fitEllipse(contour)\n", + " cx, cy = int(cx), int(cy)\n", + " \n", + " # Get a 64 x 64 center crop over each cell \n", + " img_cell = get_cell_image(img, cx, cy)\n", + "\n", + " grey_region = cv2.cvtColor(img_cell, cv2.COLOR_RGB2GRAY)\n", + " img_cell_grey = np.pad(grey_region, [(0, 64-grey_region.shape[0]), (0, 64-grey_region.shape[1])], mode = 'reflect') \n", + "\n", + "\n", + " # 1. Generating contour features\n", + " eccentricity = math.sqrt(1-(short_axis/long_axis)**2)\n", + " convex_hull = cv2.convexHull(contour)\n", + " area, hull_area = cv2.contourArea(contour), cv2.contourArea(convex_hull)\n", + " solidity = float(area)/hull_area\n", + " arc_length = cv2.arcLength(contour, True)\n", + " roundness = (arc_length/(2*math.pi))/(math.sqrt(area/math.pi))\n", + " \n", + " # 2. Generating GLCM features\n", + " out_matrix = skimage.feature.greycomatrix(img_cell_grey, [1], [0])\n", + " dissimilarity = skimage.feature.greycoprops(out_matrix, 'dissimilarity')[0][0]\n", + " homogeneity = skimage.feature.greycoprops(out_matrix, 'homogeneity')[0][0]\n", + " energy = skimage.feature.greycoprops(out_matrix, 'energy')[0][0]\n", + " ASM = skimage.feature.greycoprops(out_matrix, 'ASM')[0][0]\n", + " \n", + " # 3. Generating CPC features\n", + " cpc_feats = get_cpc_features(img_cell)\n", + " \n", + "\n", + " # Concatenate + Return all features\n", + " x = [[short_axis, long_axis, angle, area, arc_length, eccentricity, roundness, solidity],\n", + " [dissimilarity, homogeneity, energy, ASM], \n", + " cpc_feats]\n", + " \n", + " return np.array(list(itertools.chain(*x)), dtype=np.float64), cx, cy\n", + "\n", + "\n", + "def seg2graph(img, contours):\n", + " G = nx.Graph()\n", + " \n", + " contours = [c for c in contours if c.shape[0] > 5]\n", + "\n", + " for v, contour in enumerate(contours):\n", + "\n", + " features, cx, cy = get_cell_features(img, contour)\n", + " G.add_node(v, centroid = [cx, cy], x = features)\n", + "\n", + " if v < 5: return None\n", + " return G" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = \"./example_data/\"\n", + "img_dir = os.path.join(data_dir, 'imgs')\n", + "seg_dir = os.path.join(data_dir,'segs')\n", + "\n", + "roi1 = 'TCGA-06-0174-01Z-00-DX3.23b6e12e-dfc1-4c6f-903e-170038a0e055_1.png'\n", + "roi2 = 'TCGA-HT-7470-01Z-00-DX4.204D0CF2-A22E-4428-8E8C-572432B86280_1.png'\n", + "roi3 = 'TCGA-26-1442-01Z-00-DX1.FD8D4EB7-AD5E-49E8-BD0B-6CDDEA8DDB35_1.png'\n", + "\n", + "assert roi1 in os.listdir(seg_dir)\n", + "assert roi2 in os.listdir(seg_dir)\n", + "assert roi3 in os.listdir(seg_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 3/3 [00:13<00:00, 4.41s/it]\n" + ] + } + ], + "source": [ + "save_dir = data_dir\n", + "pt_dir = os.path.join(save_dir, 'pts')\n", + "graph_dir = os.path.join(save_dir, 'graphs')\n", + "fail_list = []\n", + "\n", + "from tqdm import tqdm\n", + "\n", + "for img_fname in tqdm([roi1, roi2, roi3]):\n", + " \n", + " #if int(img_fname.split('_')[2]) > 2: continue\n", + " #print(\"Processing...(%d/%d):\\t%s\" % (idx+1, len(os.listdir(seg_dir)), img_fname))\n", + " \n", + " img = np.array(Image.open(os.path.join(img_dir, img_fname)))\n", + " seg = np.array(Image.open(os.path.join(seg_dir, img_fname)))\n", + " ret, binary = cv2.threshold(seg, 127, 255, cv2.THRESH_BINARY) \n", + " contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\n", + " if len(contours) < 1: continue\n", + " \n", + " G = seg2graph(img, contours)\n", + "\n", + " if G is None: \n", + " fail_list.append(img_fname)\n", + " continue\n", + "\n", + "\n", + " centroids = []\n", + " for u, attrib in G.nodes(data=True):\n", + " centroids.append(attrib['centroid'])\n", + " \n", + " cell_centroids = np.array(centroids).astype(np.float64)\n", + " dataset = cell_centroids\n", + " \n", + " start = None\n", + " \n", + " for idx, attrib in list(G.nodes(data=True)):\n", + " start = idx\n", + " flann = FLANN()\n", + " testset = np.array([attrib['centroid']]).astype(np.float64)\n", + " results, dists = flann.nn(dataset, testset, num_neighbors=5, algorithm = 'kmeans', branching = 32, iterations = 100, checks = 16)\n", + " results, dists = results[0], dists[0]\n", + " nns_fin = []\n", + " # assert (results.shape[0] < 6)\n", + " for i in range(1, len(results)):\n", + " G.add_edge(idx, results[i], weight = dists[i])\n", + " nns_fin.append(results[i])\n", + " #attrib['nn'] = list(nns_fin)\n", + "\n", + " G = G.subgraph(max(nx.connected_components(G), key=len))\n", + "\n", + " #for idx, attrib in list(G.nodes(data=True)):\n", + " # cv2.circle(img, tuple(attrib['centroid']), 8, (0, 255, 0), -1)\n", + " \n", + " cv2.drawContours(img, contours, -1, (0,255,0), 2)\n", + " \n", + " for n, nbrs in G.adjacency():\n", + " for nbr, eattr in nbrs.items():\n", + " cv2.line(img, tuple(G.nodes[n]['centroid']), tuple(G.nodes[nbr]['centroid']), (0, 0, 255), 2)\n", + "\n", + " Image.fromarray(img).save(os.path.join(graph_dir, img_fname))\n", + " \n", + " G = from_networkx(G)\n", + " \n", + " edge_attr_long = (G.weight.unsqueeze(1)).type(torch.LongTensor)\n", + " G.edge_attr = edge_attr_long \n", + " \n", + " edge_index_long = G['edge_index'].type(torch.LongTensor)\n", + " G.edge_index = edge_index_long\n", + " \n", + " x_float = G['x'].type(torch.FloatTensor)\n", + " G.x = x_float\n", + " \n", + " G['weight'] = None\n", + " G['nn'] = None\n", + " torch.save(G, os.path.join(pt_dir, img_fname[:-4]+'.pt'))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}