a b/CellGraph/Graph Construction.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "markdown",
5
   "metadata": {},
6
   "source": [
7
    "## Calculate features for nuclei and generate .pt files for each graph\n",
8
    "## 15 features in total"
9
   ]
10
  },
11
  {
12
   "cell_type": "code",
13
   "execution_count": 1,
14
   "metadata": {},
15
   "outputs": [],
16
   "source": [
17
    "import re, os\n",
18
    "import cv2\n",
19
    "import math\n",
20
    "import random\n",
21
    "import torch\n",
22
    "import resnet\n",
23
    "import skimage.feature\n",
24
    "import pdb\n",
25
    "from PIL import Image\n",
26
    "from pyflann import *\n",
27
    "from torch_geometric.data import Data\n",
28
    "from collections import OrderedDict\n",
29
    "\n",
30
    "import networkx as nx\n",
31
    "import numpy as np\n",
32
    "import pandas as pd\n",
33
    "import torchvision.transforms.functional as F\n",
34
    "import torch_geometric.data as data\n",
35
    "import torch_geometric.utils as utils\n",
36
    "import pdb\n",
37
    "import torch_geometric"
38
   ]
39
  },
40
  {
41
   "cell_type": "code",
42
   "execution_count": 2,
43
   "metadata": {},
44
   "outputs": [
45
    {
46
     "data": {
47
      "text/plain": [
48
       "<All keys matched successfully>"
49
      ]
50
     },
51
     "execution_count": 2,
52
     "metadata": {},
53
     "output_type": "execute_result"
54
    }
55
   ],
56
   "source": [
57
    "from model import CPC_model\n",
58
    "device = torch.device('cuda:{}'.format('0'))\n",
59
    "model = CPC_model(1024, 256)\n",
60
    "encoder = model.encoder.to(device)\n",
61
    "ckpt_dir = './pretrained_models/cpc.pt'\n",
62
    "ckpt = torch.load(ckpt_dir)\n",
63
    "encoder.load_state_dict(ckpt['encoder_state_dict'])"
64
   ]
65
  },
66
  {
67
   "cell_type": "code",
68
   "execution_count": 4,
69
   "metadata": {},
70
   "outputs": [],
71
   "source": [
72
    "def from_networkx(G):\n",
73
    "    r\"\"\"Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a\n",
74
    "    :class:`torch_geometric.data.Data` instance.\n",
75
    "    Args:\n",
76
    "        G (networkx.Graph or networkx.DiGraph): A networkx graph.\n",
77
    "    \"\"\"\n",
78
    "\n",
79
    "    G = G.to_directed() if not nx.is_directed(G) else G\n",
80
    "    edge_index = torch.tensor(list(G.edges)).t().contiguous()\n",
81
    "\n",
82
    "    keys = []\n",
83
    "    keys += list(list(G.nodes(data=True))[0][1].keys())\n",
84
    "    keys += list(list(G.edges(data=True))[0][2].keys())\n",
85
    "    data = {key: [] for key in keys}\n",
86
    "\n",
87
    "    for _, feat_dict in G.nodes(data=True):\n",
88
    "        for key, value in feat_dict.items():\n",
89
    "            data[key].append(value)\n",
90
    "\n",
91
    "    for _, _, feat_dict in G.edges(data=True):\n",
92
    "        for key, value in feat_dict.items():\n",
93
    "            data[key].append(value)\n",
94
    "\n",
95
    "    for key, item in data.items():\n",
96
    "        data[key] = torch.tensor(item)\n",
97
    "\n",
98
    "    data['edge_index'] = edge_index\n",
99
    "    data = torch_geometric.data.Data.from_dict(data)\n",
100
    "    data.num_nodes = G.number_of_nodes()\n",
101
    "\n",
102
    "    return data"
103
   ]
104
  },
105
  {
106
   "cell_type": "code",
107
   "execution_count": 6,
108
   "metadata": {},
109
   "outputs": [],
110
   "source": [
111
    "from torchvision import transforms\n",
112
    "import itertools\n",
113
    "\n",
114
    "def get_cell_image(img, cx, cy, size=512):\n",
115
    "    cx = 32 if cx < 32 else size-32 if cx > size-32 else cx\n",
116
    "    cy = 32 if cy < 32 else size-32 if cy > size-32 else cy\n",
117
    "    return img[cy-32:cy+32, cx-32:cx+32, :]\n",
118
    "\n",
119
    "def get_cpc_features(cell):\n",
120
    "    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
121
    "    cell = transform(cell)\n",
122
    "    cell = cell.unsqueeze(0)\n",
123
    "    device = torch.device('cuda:{}'.format('0'))\n",
124
    "    feats = encoder(cell.to(device)).cpu().detach().numpy()[0]\n",
125
    "    return feats\n",
126
    "\n",
127
    "def get_cell_features(img, contour):\n",
128
    "    \n",
129
    "    # Get contour coordinates from contour\n",
130
    "    (cx, cy), (short_axis, long_axis), angle = cv2.fitEllipse(contour)\n",
131
    "    cx, cy = int(cx), int(cy)\n",
132
    "    \n",
133
    "    # Get a 64 x 64 center crop over each cell    \n",
134
    "    img_cell = get_cell_image(img, cx, cy)\n",
135
    "\n",
136
    "    grey_region = cv2.cvtColor(img_cell, cv2.COLOR_RGB2GRAY)\n",
137
    "    img_cell_grey = np.pad(grey_region, [(0, 64-grey_region.shape[0]), (0, 64-grey_region.shape[1])], mode = 'reflect') \n",
138
    "\n",
139
    "\n",
140
    "    # 1. Generating contour features\n",
141
    "    eccentricity = math.sqrt(1-(short_axis/long_axis)**2)\n",
142
    "    convex_hull = cv2.convexHull(contour)\n",
143
    "    area, hull_area = cv2.contourArea(contour), cv2.contourArea(convex_hull)\n",
144
    "    solidity = float(area)/hull_area\n",
145
    "    arc_length = cv2.arcLength(contour, True)\n",
146
    "    roundness = (arc_length/(2*math.pi))/(math.sqrt(area/math.pi))\n",
147
    "    \n",
148
    "    # 2. Generating GLCM features\n",
149
    "    out_matrix = skimage.feature.greycomatrix(img_cell_grey, [1], [0])\n",
150
    "    dissimilarity = skimage.feature.greycoprops(out_matrix, 'dissimilarity')[0][0]\n",
151
    "    homogeneity = skimage.feature.greycoprops(out_matrix, 'homogeneity')[0][0]\n",
152
    "    energy = skimage.feature.greycoprops(out_matrix, 'energy')[0][0]\n",
153
    "    ASM = skimage.feature.greycoprops(out_matrix, 'ASM')[0][0]\n",
154
    "    \n",
155
    "    # 3. Generating CPC features\n",
156
    "    cpc_feats = get_cpc_features(img_cell)\n",
157
    "    \n",
158
    "\n",
159
    "    # Concatenate + Return all features\n",
160
    "    x = [[short_axis, long_axis, angle, area, arc_length, eccentricity, roundness, solidity],\n",
161
    "         [dissimilarity, homogeneity, energy, ASM], \n",
162
    "         cpc_feats]\n",
163
    "    \n",
164
    "    return np.array(list(itertools.chain(*x)), dtype=np.float64), cx, cy\n",
165
    "\n",
166
    "\n",
167
    "def seg2graph(img, contours):\n",
168
    "    G = nx.Graph()\n",
169
    "    \n",
170
    "    contours = [c for c in contours if c.shape[0] > 5]\n",
171
    "\n",
172
    "    for v, contour in enumerate(contours):\n",
173
    "\n",
174
    "        features, cx, cy = get_cell_features(img, contour)\n",
175
    "        G.add_node(v, centroid = [cx, cy], x = features)\n",
176
    "\n",
177
    "    if v < 5: return None\n",
178
    "    return G"
179
   ]
180
  },
181
  {
182
   "cell_type": "code",
183
   "execution_count": 1,
184
   "metadata": {},
185
   "outputs": [],
186
   "source": [
187
    "data_dir = \"./example_data/\"\n",
188
    "img_dir = os.path.join(data_dir, 'imgs')\n",
189
    "seg_dir =  os.path.join(data_dir,'segs')\n",
190
    "\n",
191
    "roi1 = 'TCGA-06-0174-01Z-00-DX3.23b6e12e-dfc1-4c6f-903e-170038a0e055_1.png'\n",
192
    "roi2 = 'TCGA-HT-7470-01Z-00-DX4.204D0CF2-A22E-4428-8E8C-572432B86280_1.png'\n",
193
    "roi3 = 'TCGA-26-1442-01Z-00-DX1.FD8D4EB7-AD5E-49E8-BD0B-6CDDEA8DDB35_1.png'\n",
194
    "\n",
195
    "assert roi1 in os.listdir(seg_dir)\n",
196
    "assert roi2 in os.listdir(seg_dir)\n",
197
    "assert roi3 in os.listdir(seg_dir)"
198
   ]
199
  },
200
  {
201
   "cell_type": "code",
202
   "execution_count": 26,
203
   "metadata": {
204
    "scrolled": true
205
   },
206
   "outputs": [
207
    {
208
     "name": "stderr",
209
     "output_type": "stream",
210
     "text": [
211
      "100%|██████████| 3/3 [00:13<00:00,  4.41s/it]\n"
212
     ]
213
    }
214
   ],
215
   "source": [
216
    "save_dir = data_dir\n",
217
    "pt_dir = os.path.join(save_dir, 'pts')\n",
218
    "graph_dir = os.path.join(save_dir, 'graphs')\n",
219
    "fail_list = []\n",
220
    "\n",
221
    "from tqdm import tqdm\n",
222
    "\n",
223
    "for img_fname in tqdm([roi1, roi2, roi3]):\n",
224
    "    \n",
225
    "    #if int(img_fname.split('_')[2]) > 2: continue\n",
226
    "    #print(\"Processing...(%d/%d):\\t%s\" % (idx+1, len(os.listdir(seg_dir)), img_fname))\n",
227
    "    \n",
228
    "    img = np.array(Image.open(os.path.join(img_dir, img_fname)))\n",
229
    "    seg = np.array(Image.open(os.path.join(seg_dir, img_fname)))\n",
230
    "    ret, binary = cv2.threshold(seg, 127, 255, cv2.THRESH_BINARY) \n",
231
    "    contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\n",
232
    "    if len(contours) < 1: continue\n",
233
    "    \n",
234
    "    G = seg2graph(img, contours)\n",
235
    "\n",
236
    "    if G is None: \n",
237
    "        fail_list.append(img_fname)\n",
238
    "        continue\n",
239
    "\n",
240
    "\n",
241
    "    centroids = []\n",
242
    "    for u, attrib in G.nodes(data=True):\n",
243
    "        centroids.append(attrib['centroid'])\n",
244
    "    \n",
245
    "    cell_centroids = np.array(centroids).astype(np.float64)\n",
246
    "    dataset = cell_centroids\n",
247
    "    \n",
248
    "    start = None\n",
249
    "            \n",
250
    "    for idx, attrib in list(G.nodes(data=True)):\n",
251
    "        start = idx\n",
252
    "        flann = FLANN()\n",
253
    "        testset = np.array([attrib['centroid']]).astype(np.float64)\n",
254
    "        results, dists = flann.nn(dataset, testset, num_neighbors=5, algorithm = 'kmeans', branching = 32, iterations = 100, checks = 16)\n",
255
    "        results, dists = results[0], dists[0]\n",
256
    "        nns_fin = []\n",
257
    "       # assert (results.shape[0] < 6)\n",
258
    "        for i in range(1, len(results)):\n",
259
    "            G.add_edge(idx, results[i], weight = dists[i])\n",
260
    "            nns_fin.append(results[i])\n",
261
    "        #attrib['nn'] = list(nns_fin)\n",
262
    "\n",
263
    "    G = G.subgraph(max(nx.connected_components(G), key=len))\n",
264
    "\n",
265
    "    #for idx, attrib in list(G.nodes(data=True)):\n",
266
    "    #    cv2.circle(img, tuple(attrib['centroid']), 8, (0, 255, 0), -1)\n",
267
    "    \n",
268
    "    cv2.drawContours(img, contours, -1, (0,255,0), 2)\n",
269
    "    \n",
270
    "    for n, nbrs in G.adjacency():\n",
271
    "        for nbr, eattr in nbrs.items():\n",
272
    "            cv2.line(img, tuple(G.nodes[n]['centroid']),  tuple(G.nodes[nbr]['centroid']), (0, 0, 255), 2)\n",
273
    "\n",
274
    "    Image.fromarray(img).save(os.path.join(graph_dir, img_fname))\n",
275
    "    \n",
276
    "    G = from_networkx(G)\n",
277
    "    \n",
278
    "    edge_attr_long = (G.weight.unsqueeze(1)).type(torch.LongTensor)\n",
279
    "    G.edge_attr = edge_attr_long \n",
280
    "    \n",
281
    "    edge_index_long = G['edge_index'].type(torch.LongTensor)\n",
282
    "    G.edge_index = edge_index_long\n",
283
    "    \n",
284
    "    x_float = G['x'].type(torch.FloatTensor)\n",
285
    "    G.x = x_float\n",
286
    "    \n",
287
    "    G['weight'] = None\n",
288
    "    G['nn'] = None\n",
289
    "    torch.save(G, os.path.join(pt_dir, img_fname[:-4]+'.pt'))"
290
   ]
291
  }
292
 ],
293
 "metadata": {
294
  "kernelspec": {
295
   "display_name": "Python 3",
296
   "language": "python",
297
   "name": "python3"
298
  },
299
  "language_info": {
300
   "codemirror_mode": {
301
    "name": "ipython",
302
    "version": 3
303
   },
304
   "file_extension": ".py",
305
   "mimetype": "text/x-python",
306
   "name": "python",
307
   "nbconvert_exporter": "python",
308
   "pygments_lexer": "ipython3",
309
   "version": "3.8.3"
310
  }
311
 },
312
 "nbformat": 4,
313
 "nbformat_minor": 2
314
}