|
a |
|
b/CellGraph/Graph Construction.ipynb |
|
|
1 |
{ |
|
|
2 |
"cells": [ |
|
|
3 |
{ |
|
|
4 |
"cell_type": "markdown", |
|
|
5 |
"metadata": {}, |
|
|
6 |
"source": [ |
|
|
7 |
"## Calculate features for nuclei and generate .pt files for each graph\n", |
|
|
8 |
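"## Node features: 8 contour/shape + 4 GLCM texture descriptors (12 hand-crafted), concatenated with the CPC encoder embedding" |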
|
|
9 |
] |
|
|
10 |
}, |
|
|
11 |
{ |
|
|
12 |
"cell_type": "code", |
|
|
13 |
"execution_count": 1, |
|
|
14 |
"metadata": {}, |
|
|
15 |
"outputs": [], |
|
|
16 |
"source": [ |
|
|
17 |
"import re, os\n", |
|
|
18 |
"import cv2\n", |
|
|
19 |
"import math\n", |
|
|
20 |
"import random\n", |
|
|
21 |
"import torch\n", |
|
|
22 |
"import resnet\n", |
|
|
23 |
"import skimage.feature\n", |
|
|
24 |
"import pdb\n", |
|
|
25 |
"from PIL import Image\n", |
|
|
26 |
"from pyflann import *\n", |
|
|
27 |
"from torch_geometric.data import Data\n", |
|
|
28 |
"from collections import OrderedDict\n", |
|
|
29 |
"\n", |
|
|
30 |
"import networkx as nx\n", |
|
|
31 |
"import numpy as np\n", |
|
|
32 |
"import pandas as pd\n", |
|
|
33 |
"import torchvision.transforms.functional as F\n", |
|
|
34 |
"import torch_geometric.data as data\n", |
|
|
35 |
"import torch_geometric.utils as utils\n", |
|
|
36 |
"import pdb\n", |
|
|
37 |
"import torch_geometric" |
|
|
38 |
] |
|
|
39 |
}, |
|
|
40 |
{ |
|
|
41 |
"cell_type": "code", |
|
|
42 |
"execution_count": 2, |
|
|
43 |
"metadata": {}, |
|
|
44 |
"outputs": [ |
|
|
45 |
{ |
|
|
46 |
"data": { |
|
|
47 |
"text/plain": [ |
|
|
48 |
"<All keys matched successfully>" |
|
|
49 |
] |
|
|
50 |
}, |
|
|
51 |
"execution_count": 2, |
|
|
52 |
"metadata": {}, |
|
|
53 |
"output_type": "execute_result" |
|
|
54 |
} |
|
|
55 |
], |
|
|
56 |
"source": [ |
|
|
57 |
"from model import CPC_model\n", |
|
|
58 |
"device = torch.device('cuda:{}'.format('0'))\n", |
|
|
59 |
"model = CPC_model(1024, 256)\n", |
|
|
60 |
"encoder = model.encoder.to(device)\n", |
|
|
61 |
"ckpt_dir = './pretrained_models/cpc.pt'\n", |
|
|
62 |
"ckpt = torch.load(ckpt_dir)\n", |
|
|
63 |
"encoder.load_state_dict(ckpt['encoder_state_dict'])" |
|
|
64 |
] |
|
|
65 |
}, |
|
|
66 |
{ |
|
|
67 |
"cell_type": "code", |
|
|
68 |
"execution_count": 4, |
|
|
69 |
"metadata": {}, |
|
|
70 |
"outputs": [], |
|
|
71 |
"source": [ |
|
|
72 |
"def from_networkx(G):\n", |
|
|
73 |
" r\"\"\"Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a\n", |
|
|
74 |
" :class:`torch_geometric.data.Data` instance.\n", |
|
|
75 |
" Args:\n", |
|
|
76 |
" G (networkx.Graph or networkx.DiGraph): A networkx graph.\n", |
|
|
77 |
" \"\"\"\n", |
|
|
78 |
"\n", |
|
|
79 |
" G = G.to_directed() if not nx.is_directed(G) else G\n", |
|
|
80 |
" edge_index = torch.tensor(list(G.edges)).t().contiguous()\n", |
|
|
81 |
"\n", |
|
|
82 |
" keys = []\n", |
|
|
83 |
" keys += list(list(G.nodes(data=True))[0][1].keys())\n", |
|
|
84 |
" keys += list(list(G.edges(data=True))[0][2].keys())\n", |
|
|
85 |
" data = {key: [] for key in keys}\n", |
|
|
86 |
"\n", |
|
|
87 |
" for _, feat_dict in G.nodes(data=True):\n", |
|
|
88 |
" for key, value in feat_dict.items():\n", |
|
|
89 |
" data[key].append(value)\n", |
|
|
90 |
"\n", |
|
|
91 |
" for _, _, feat_dict in G.edges(data=True):\n", |
|
|
92 |
" for key, value in feat_dict.items():\n", |
|
|
93 |
" data[key].append(value)\n", |
|
|
94 |
"\n", |
|
|
95 |
" for key, item in data.items():\n", |
|
|
96 |
" data[key] = torch.tensor(item)\n", |
|
|
97 |
"\n", |
|
|
98 |
" data['edge_index'] = edge_index\n", |
|
|
99 |
" data = torch_geometric.data.Data.from_dict(data)\n", |
|
|
100 |
" data.num_nodes = G.number_of_nodes()\n", |
|
|
101 |
"\n", |
|
|
102 |
" return data" |
|
|
103 |
] |
|
|
104 |
}, |
|
|
105 |
{ |
|
|
106 |
"cell_type": "code", |
|
|
107 |
"execution_count": 6, |
|
|
108 |
"metadata": {}, |
|
|
109 |
"outputs": [], |
|
|
110 |
"source": [ |
|
|
111 |
"from torchvision import transforms\n", |
|
|
112 |
"import itertools\n", |
|
|
113 |
"\n", |
|
|
114 |
"def get_cell_image(img, cx, cy, size=512):\n", |
|
|
115 |
" cx = 32 if cx < 32 else size-32 if cx > size-32 else cx\n", |
|
|
116 |
" cy = 32 if cy < 32 else size-32 if cy > size-32 else cy\n", |
|
|
117 |
" return img[cy-32:cy+32, cx-32:cx+32, :]\n", |
|
|
118 |
"\n", |
|
|
119 |
"def get_cpc_features(cell):\n", |
|
|
120 |
" transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n", |
|
|
121 |
" cell = transform(cell)\n", |
|
|
122 |
" cell = cell.unsqueeze(0)\n", |
|
|
123 |
" device = torch.device('cuda:{}'.format('0'))\n", |
|
|
124 |
" feats = encoder(cell.to(device)).cpu().detach().numpy()[0]\n", |
|
|
125 |
" return feats\n", |
|
|
126 |
"\n", |
|
|
127 |
"def get_cell_features(img, contour):\n", |
|
|
128 |
" \n", |
|
|
129 |
" # Get contour coordinates from contour\n", |
|
|
130 |
" (cx, cy), (short_axis, long_axis), angle = cv2.fitEllipse(contour)\n", |
|
|
131 |
" cx, cy = int(cx), int(cy)\n", |
|
|
132 |
" \n", |
|
|
133 |
" # Get a 64 x 64 center crop over each cell \n", |
|
|
134 |
" img_cell = get_cell_image(img, cx, cy)\n", |
|
|
135 |
"\n", |
|
|
136 |
" grey_region = cv2.cvtColor(img_cell, cv2.COLOR_RGB2GRAY)\n", |
|
|
137 |
" img_cell_grey = np.pad(grey_region, [(0, 64-grey_region.shape[0]), (0, 64-grey_region.shape[1])], mode = 'reflect') \n", |
|
|
138 |
"\n", |
|
|
139 |
"\n", |
|
|
140 |
" # 1. Generating contour features\n", |
|
|
141 |
" eccentricity = math.sqrt(1-(short_axis/long_axis)**2)\n", |
|
|
142 |
" convex_hull = cv2.convexHull(contour)\n", |
|
|
143 |
" area, hull_area = cv2.contourArea(contour), cv2.contourArea(convex_hull)\n", |
|
|
144 |
" solidity = float(area)/hull_area\n", |
|
|
145 |
" arc_length = cv2.arcLength(contour, True)\n", |
|
|
146 |
" roundness = (arc_length/(2*math.pi))/(math.sqrt(area/math.pi))\n", |
|
|
147 |
" \n", |
|
|
148 |
" # 2. Generating GLCM features\n", |
|
|
149 |
" out_matrix = skimage.feature.greycomatrix(img_cell_grey, [1], [0])\n", |
|
|
150 |
" dissimilarity = skimage.feature.greycoprops(out_matrix, 'dissimilarity')[0][0]\n", |
|
|
151 |
" homogeneity = skimage.feature.greycoprops(out_matrix, 'homogeneity')[0][0]\n", |
|
|
152 |
" energy = skimage.feature.greycoprops(out_matrix, 'energy')[0][0]\n", |
|
|
153 |
" ASM = skimage.feature.greycoprops(out_matrix, 'ASM')[0][0]\n", |
|
|
154 |
" \n", |
|
|
155 |
" # 3. Generating CPC features\n", |
|
|
156 |
" cpc_feats = get_cpc_features(img_cell)\n", |
|
|
157 |
" \n", |
|
|
158 |
"\n", |
|
|
159 |
" # Concatenate + Return all features\n", |
|
|
160 |
" x = [[short_axis, long_axis, angle, area, arc_length, eccentricity, roundness, solidity],\n", |
|
|
161 |
" [dissimilarity, homogeneity, energy, ASM], \n", |
|
|
162 |
" cpc_feats]\n", |
|
|
163 |
" \n", |
|
|
164 |
" return np.array(list(itertools.chain(*x)), dtype=np.float64), cx, cy\n", |
|
|
165 |
"\n", |
|
|
166 |
"\n", |
|
|
167 |
"def seg2graph(img, contours):\n", |
|
|
168 |
" G = nx.Graph()\n", |
|
|
169 |
" \n", |
|
|
170 |
" contours = [c for c in contours if c.shape[0] > 5]\n", |
|
|
171 |
"\n", |
|
|
172 |
" for v, contour in enumerate(contours):\n", |
|
|
173 |
"\n", |
|
|
174 |
" features, cx, cy = get_cell_features(img, contour)\n", |
|
|
175 |
" G.add_node(v, centroid = [cx, cy], x = features)\n", |
|
|
176 |
"\n", |
|
|
177 |
" if v < 5: return None\n", |
|
|
178 |
" return G" |
|
|
179 |
] |
|
|
180 |
}, |
|
|
181 |
{ |
|
|
182 |
"cell_type": "code", |
|
|
183 |
"execution_count": 1, |
|
|
184 |
"metadata": {}, |
|
|
185 |
"outputs": [], |
|
|
186 |
"source": [ |
|
|
187 |
"data_dir = \"./example_data/\"\n", |
|
|
188 |
"img_dir = os.path.join(data_dir, 'imgs')\n", |
|
|
189 |
"seg_dir = os.path.join(data_dir,'segs')\n", |
|
|
190 |
"\n", |
|
|
191 |
"roi1 = 'TCGA-06-0174-01Z-00-DX3.23b6e12e-dfc1-4c6f-903e-170038a0e055_1.png'\n", |
|
|
192 |
"roi2 = 'TCGA-HT-7470-01Z-00-DX4.204D0CF2-A22E-4428-8E8C-572432B86280_1.png'\n", |
|
|
193 |
"roi3 = 'TCGA-26-1442-01Z-00-DX1.FD8D4EB7-AD5E-49E8-BD0B-6CDDEA8DDB35_1.png'\n", |
|
|
194 |
"\n", |
|
|
195 |
"assert roi1 in os.listdir(seg_dir)\n", |
|
|
196 |
"assert roi2 in os.listdir(seg_dir)\n", |
|
|
197 |
"assert roi3 in os.listdir(seg_dir)" |
|
|
198 |
] |
|
|
199 |
}, |
|
|
200 |
{ |
|
|
201 |
"cell_type": "code", |
|
|
202 |
"execution_count": 26, |
|
|
203 |
"metadata": { |
|
|
204 |
"scrolled": true |
|
|
205 |
}, |
|
|
206 |
"outputs": [ |
|
|
207 |
{ |
|
|
208 |
"name": "stderr", |
|
|
209 |
"output_type": "stream", |
|
|
210 |
"text": [ |
|
|
211 |
"100%|██████████| 3/3 [00:13<00:00, 4.41s/it]\n" |
|
|
212 |
] |
|
|
213 |
} |
|
|
214 |
], |
|
|
215 |
"source": [ |
|
|
216 |
"save_dir = data_dir\n", |
|
|
217 |
"pt_dir = os.path.join(save_dir, 'pts')\n", |
|
|
218 |
"graph_dir = os.path.join(save_dir, 'graphs')\n", |
|
|
219 |
"fail_list = []\n", |
|
|
220 |
"\n", |
|
|
221 |
"from tqdm import tqdm\n", |
|
|
222 |
"\n", |
|
|
223 |
"for img_fname in tqdm([roi1, roi2, roi3]):\n", |
|
|
224 |
" \n", |
|
|
225 |
" #if int(img_fname.split('_')[2]) > 2: continue\n", |
|
|
226 |
" #print(\"Processing...(%d/%d):\\t%s\" % (idx+1, len(os.listdir(seg_dir)), img_fname))\n", |
|
|
227 |
" \n", |
|
|
228 |
" img = np.array(Image.open(os.path.join(img_dir, img_fname)))\n", |
|
|
229 |
" seg = np.array(Image.open(os.path.join(seg_dir, img_fname)))\n", |
|
|
230 |
" ret, binary = cv2.threshold(seg, 127, 255, cv2.THRESH_BINARY) \n", |
|
|
231 |
" contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)\n", |
|
|
232 |
" if len(contours) < 1: continue\n", |
|
|
233 |
" \n", |
|
|
234 |
" G = seg2graph(img, contours)\n", |
|
|
235 |
"\n", |
|
|
236 |
" if G is None: \n", |
|
|
237 |
" fail_list.append(img_fname)\n", |
|
|
238 |
" continue\n", |
|
|
239 |
"\n", |
|
|
240 |
"\n", |
|
|
241 |
" centroids = []\n", |
|
|
242 |
" for u, attrib in G.nodes(data=True):\n", |
|
|
243 |
" centroids.append(attrib['centroid'])\n", |
|
|
244 |
" \n", |
|
|
245 |
" cell_centroids = np.array(centroids).astype(np.float64)\n", |
|
|
246 |
" dataset = cell_centroids\n", |
|
|
247 |
" \n", |
|
|
248 |
" start = None\n", |
|
|
249 |
" \n", |
|
|
250 |
" for idx, attrib in list(G.nodes(data=True)):\n", |
|
|
251 |
" start = idx\n", |
|
|
252 |
" flann = FLANN()\n", |
|
|
253 |
" testset = np.array([attrib['centroid']]).astype(np.float64)\n", |
|
|
254 |
" results, dists = flann.nn(dataset, testset, num_neighbors=5, algorithm = 'kmeans', branching = 32, iterations = 100, checks = 16)\n", |
|
|
255 |
" results, dists = results[0], dists[0]\n", |
|
|
256 |
" nns_fin = []\n", |
|
|
257 |
" # assert (results.shape[0] < 6)\n", |
|
|
258 |
" for i in range(1, len(results)):\n", |
|
|
259 |
" G.add_edge(idx, results[i], weight = dists[i])\n", |
|
|
260 |
" nns_fin.append(results[i])\n", |
|
|
261 |
" #attrib['nn'] = list(nns_fin)\n", |
|
|
262 |
"\n", |
|
|
263 |
" G = G.subgraph(max(nx.connected_components(G), key=len))\n", |
|
|
264 |
"\n", |
|
|
265 |
" #for idx, attrib in list(G.nodes(data=True)):\n", |
|
|
266 |
" # cv2.circle(img, tuple(attrib['centroid']), 8, (0, 255, 0), -1)\n", |
|
|
267 |
" \n", |
|
|
268 |
" cv2.drawContours(img, contours, -1, (0,255,0), 2)\n", |
|
|
269 |
" \n", |
|
|
270 |
" for n, nbrs in G.adjacency():\n", |
|
|
271 |
" for nbr, eattr in nbrs.items():\n", |
|
|
272 |
" cv2.line(img, tuple(G.nodes[n]['centroid']), tuple(G.nodes[nbr]['centroid']), (0, 0, 255), 2)\n", |
|
|
273 |
"\n", |
|
|
274 |
" Image.fromarray(img).save(os.path.join(graph_dir, img_fname))\n", |
|
|
275 |
" \n", |
|
|
276 |
" G = from_networkx(G)\n", |
|
|
277 |
" \n", |
|
|
278 |
" edge_attr_long = (G.weight.unsqueeze(1)).type(torch.LongTensor)\n", |
|
|
279 |
" G.edge_attr = edge_attr_long \n", |
|
|
280 |
" \n", |
|
|
281 |
" edge_index_long = G['edge_index'].type(torch.LongTensor)\n", |
|
|
282 |
" G.edge_index = edge_index_long\n", |
|
|
283 |
" \n", |
|
|
284 |
" x_float = G['x'].type(torch.FloatTensor)\n", |
|
|
285 |
" G.x = x_float\n", |
|
|
286 |
" \n", |
|
|
287 |
" G['weight'] = None\n", |
|
|
288 |
" G['nn'] = None\n", |
|
|
289 |
" torch.save(G, os.path.join(pt_dir, img_fname[:-4]+'.pt'))" |
|
|
290 |
] |
|
|
291 |
} |
|
|
292 |
], |
|
|
293 |
"metadata": { |
|
|
294 |
"kernelspec": { |
|
|
295 |
"display_name": "Python 3", |
|
|
296 |
"language": "python", |
|
|
297 |
"name": "python3" |
|
|
298 |
}, |
|
|
299 |
"language_info": { |
|
|
300 |
"codemirror_mode": { |
|
|
301 |
"name": "ipython", |
|
|
302 |
"version": 3 |
|
|
303 |
}, |
|
|
304 |
"file_extension": ".py", |
|
|
305 |
"mimetype": "text/x-python", |
|
|
306 |
"name": "python", |
|
|
307 |
"nbconvert_exporter": "python", |
|
|
308 |
"pygments_lexer": "ipython3", |
|
|
309 |
"version": "3.8.3" |
|
|
310 |
} |
|
|
311 |
}, |
|
|
312 |
"nbformat": 4, |
|
|
313 |
"nbformat_minor": 2 |
|
|
314 |
} |