315 lines (314 with data), 10.8 kB
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Calculate features for nuclei and generate .pt files for each graph\n",
"## 15 features in total"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import re, os\n",
"import cv2\n",
"import math\n",
"import random\n",
"import torch\n",
"import resnet\n",
"import skimage.feature\n",
"import pdb\n",
"from PIL import Image\n",
"from pyflann import *\n",
"from torch_geometric.data import Data\n",
"from collections import OrderedDict\n",
"\n",
"import networkx as nx\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torchvision.transforms.functional as F\n",
"import torch_geometric.data as data\n",
"import torch_geometric.utils as utils\n",
"import pdb\n",
"import torch_geometric"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from model import CPC_model\n",
"device = torch.device('cuda:{}'.format('0'))\n",
"model = CPC_model(1024, 256)\n",
"encoder = model.encoder.to(device)\n",
"ckpt_dir = './pretrained_models/cpc.pt'\n",
"ckpt = torch.load(ckpt_dir)\n",
"encoder.load_state_dict(ckpt['encoder_state_dict'])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def from_networkx(G):\n",
" r\"\"\"Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a\n",
" :class:`torch_geometric.data.Data` instance.\n",
" Args:\n",
" G (networkx.Graph or networkx.DiGraph): A networkx graph.\n",
" \"\"\"\n",
"\n",
" G = G.to_directed() if not nx.is_directed(G) else G\n",
" edge_index = torch.tensor(list(G.edges)).t().contiguous()\n",
"\n",
" keys = []\n",
" keys += list(list(G.nodes(data=True))[0][1].keys())\n",
" keys += list(list(G.edges(data=True))[0][2].keys())\n",
" data = {key: [] for key in keys}\n",
"\n",
" for _, feat_dict in G.nodes(data=True):\n",
" for key, value in feat_dict.items():\n",
" data[key].append(value)\n",
"\n",
" for _, _, feat_dict in G.edges(data=True):\n",
" for key, value in feat_dict.items():\n",
" data[key].append(value)\n",
"\n",
" for key, item in data.items():\n",
" data[key] = torch.tensor(item)\n",
"\n",
" data['edge_index'] = edge_index\n",
" data = torch_geometric.data.Data.from_dict(data)\n",
" data.num_nodes = G.number_of_nodes()\n",
"\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from torchvision import transforms\n",
"import itertools\n",
"\n",
"def get_cell_image(img, cx, cy, size=512):\n",
" cx = 32 if cx < 32 else size-32 if cx > size-32 else cx\n",
" cy = 32 if cy < 32 else size-32 if cy > size-32 else cy\n",
" return img[cy-32:cy+32, cx-32:cx+32, :]\n",
"\n",
"def get_cpc_features(cell):\n",
" transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
" cell = transform(cell)\n",
" cell = cell.unsqueeze(0)\n",
" device = torch.device('cuda:{}'.format('0'))\n",
" feats = encoder(cell.to(device)).cpu().detach().numpy()[0]\n",
" return feats\n",
"\n",
"def get_cell_features(img, contour):\n",
" \n",
" # Get contour coordinates from contour\n",
" (cx, cy), (short_axis, long_axis), angle = cv2.fitEllipse(contour)\n",
" cx, cy = int(cx), int(cy)\n",
" \n",
" # Get a 64 x 64 center crop over each cell \n",
" img_cell = get_cell_image(img, cx, cy)\n",
"\n",
" grey_region = cv2.cvtColor(img_cell, cv2.COLOR_RGB2GRAY)\n",
" img_cell_grey = np.pad(grey_region, [(0, 64-grey_region.shape[0]), (0, 64-grey_region.shape[1])], mode = 'reflect') \n",
"\n",
"\n",
" # 1. Generating contour features\n",
" eccentricity = math.sqrt(1-(short_axis/long_axis)**2)\n",
" convex_hull = cv2.convexHull(contour)\n",
" area, hull_area = cv2.contourArea(contour), cv2.contourArea(convex_hull)\n",
" solidity = float(area)/hull_area\n",
" arc_length = cv2.arcLength(contour, True)\n",
" roundness = (arc_length/(2*math.pi))/(math.sqrt(area/math.pi))\n",
" \n",
" # 2. Generating GLCM features\n",
" out_matrix = skimage.feature.greycomatrix(img_cell_grey, [1], [0])\n",
" dissimilarity = skimage.feature.greycoprops(out_matrix, 'dissimilarity')[0][0]\n",
" homogeneity = skimage.feature.greycoprops(out_matrix, 'homogeneity')[0][0]\n",
" energy = skimage.feature.greycoprops(out_matrix, 'energy')[0][0]\n",
" ASM = skimage.feature.greycoprops(out_matrix, 'ASM')[0][0]\n",
" \n",
" # 3. Generating CPC features\n",
" cpc_feats = get_cpc_features(img_cell)\n",
" \n",
"\n",
" # Concatenate + Return all features\n",
" x = [[short_axis, long_axis, angle, area, arc_length, eccentricity, roundness, solidity],\n",
" [dissimilarity, homogeneity, energy, ASM], \n",
" cpc_feats]\n",
" \n",
" return np.array(list(itertools.chain(*x)), dtype=np.float64), cx, cy\n",
"\n",
"\n",
"def seg2graph(img, contours):\n",
" G = nx.Graph()\n",
" \n",
" contours = [c for c in contours if c.shape[0] > 5]\n",
"\n",
" for v, contour in enumerate(contours):\n",
"\n",
" features, cx, cy = get_cell_features(img, contour)\n",
" G.add_node(v, centroid = [cx, cy], x = features)\n",
"\n",
" if v < 5: return None\n",
" return G"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"data_dir = \"./example_data/\"\n",
"img_dir = os.path.join(data_dir, 'imgs')\n",
"seg_dir = os.path.join(data_dir,'segs')\n",
"\n",
"roi1 = 'TCGA-06-0174-01Z-00-DX3.23b6e12e-dfc1-4c6f-903e-170038a0e055_1.png'\n",
"roi2 = 'TCGA-HT-7470-01Z-00-DX4.204D0CF2-A22E-4428-8E8C-572432B86280_1.png'\n",
"roi3 = 'TCGA-26-1442-01Z-00-DX1.FD8D4EB7-AD5E-49E8-BD0B-6CDDEA8DDB35_1.png'\n",
"\n",
"assert roi1 in os.listdir(seg_dir)\n",
"assert roi2 in os.listdir(seg_dir)\n",
"assert roi3 in os.listdir(seg_dir)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 3/3 [00:13<00:00, 4.41s/it]\n"
]
}
],
"source": [
"save_dir = data_dir\n",
"pt_dir = os.path.join(save_dir, 'pts')\n",
"graph_dir = os.path.join(save_dir, 'graphs')\n",
"fail_list = []\n",
"\n",
"from tqdm import tqdm\n",
"\n",
"for img_fname in tqdm([roi1, roi2, roi3]):\n",
" \n",
" #if int(img_fname.split('_')[2]) > 2: continue\n",
" #print(\"Processing...(%d/%d):\\t%s\" % (idx+1, len(os.listdir(seg_dir)), img_fname))\n",
" \n",
" img = np.array(Image.open(os.path.join(img_dir, img_fname)))\n",
" seg = np.array(Image.open(os.path.join(seg_dir, img_fname)))\n",
" ret, binary = cv2.threshold(seg, 127, 255, cv2.THRESH_BINARY) \n",
" contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\n",
" if len(contours) < 1: continue\n",
" \n",
" G = seg2graph(img, contours)\n",
"\n",
" if G is None: \n",
" fail_list.append(img_fname)\n",
" continue\n",
"\n",
"\n",
" centroids = []\n",
" for u, attrib in G.nodes(data=True):\n",
" centroids.append(attrib['centroid'])\n",
" \n",
" cell_centroids = np.array(centroids).astype(np.float64)\n",
" dataset = cell_centroids\n",
" \n",
" start = None\n",
" \n",
" for idx, attrib in list(G.nodes(data=True)):\n",
" start = idx\n",
" flann = FLANN()\n",
" testset = np.array([attrib['centroid']]).astype(np.float64)\n",
" results, dists = flann.nn(dataset, testset, num_neighbors=5, algorithm = 'kmeans', branching = 32, iterations = 100, checks = 16)\n",
" results, dists = results[0], dists[0]\n",
" nns_fin = []\n",
" # assert (results.shape[0] < 6)\n",
" for i in range(1, len(results)):\n",
" G.add_edge(idx, results[i], weight = dists[i])\n",
" nns_fin.append(results[i])\n",
" #attrib['nn'] = list(nns_fin)\n",
"\n",
" G = G.subgraph(max(nx.connected_components(G), key=len))\n",
"\n",
" #for idx, attrib in list(G.nodes(data=True)):\n",
" # cv2.circle(img, tuple(attrib['centroid']), 8, (0, 255, 0), -1)\n",
" \n",
" cv2.drawContours(img, contours, -1, (0,255,0), 2)\n",
" \n",
" for n, nbrs in G.adjacency():\n",
" for nbr, eattr in nbrs.items():\n",
" cv2.line(img, tuple(G.nodes[n]['centroid']), tuple(G.nodes[nbr]['centroid']), (0, 0, 255), 2)\n",
"\n",
" Image.fromarray(img).save(os.path.join(graph_dir, img_fname))\n",
" \n",
" G = from_networkx(G)\n",
" \n",
" edge_attr_long = (G.weight.unsqueeze(1)).type(torch.LongTensor)\n",
" G.edge_attr = edge_attr_long \n",
" \n",
" edge_index_long = G['edge_index'].type(torch.LongTensor)\n",
" G.edge_index = edge_index_long\n",
" \n",
" x_float = G['x'].type(torch.FloatTensor)\n",
" G.x = x_float\n",
" \n",
" G['weight'] = None\n",
" G['nn'] = None\n",
" torch.save(G, os.path.join(pt_dir, img_fname[:-4]+'.pt'))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}