|
b/Evalution.ipynb

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iXo6Nj7M5GK8"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torchvision\n",
        "import os\n",
        "import glob\n",
        "import time\n",
        "import pickle\n",
        "import sys\n",
|
|
        "sys.path.append('/content/drive/MyDrive/Batoul_Code/')\n",
        "sys.path.append('/content/drive/MyDrive/Batoul_Code/src')\n",
        "\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from pathlib import Path\n",
        "from PIL import Image\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "from data import LungDataset, blend, Pad, Crop, Resize\n",
        "from data2 import LungDataset2, blend, Pad, Crop, Resize\n",
        "\n",
        "from OurModel import CxlNet\n",
        "\n",
        "from metrics import jaccard, dice, get_accuracy, get_sensitivity, get_specificity"
      ]
    },
|
|
    {
      "cell_type": "code",
      "source": [
        "in_channels = 1\n",
        "out_channels = 2\n",
        "batch_norm = True\n",
        "upscale_mode = \"bilinear\"\n",
        "image_size = 512\n",
|
|
        "def selectModel(model_name=None):\n",
        "    # model_name lets callers pass an architecture label (see the comparison\n",
        "    # cell below); only CxlNet is constructed here, whatever name is given\n",
        "    return CxlNet(\n",
        "        in_channels=in_channels,\n",
        "        out_channels=out_channels,\n",
        "        batch_norm=batch_norm,\n",
        "        upscale_mode=upscale_mode,\n",
        "        image_size=image_size)"
      ],
      "metadata": {
        "id": "BGn0Pjmb5gRA"
      },
      "execution_count": null,
      "outputs": []
    },
|
|
    {
      "cell_type": "code",
      "source": [
        "dataset_name = \"dataset\"\n",
        "dataset_types = {\"dataset\": \"png\", \"CT\": \"jpg\"}\n",
        "dataset_type = dataset_types[dataset_name]\n",
        "print(dataset_type)\n",
        "image_size = 512\n",
        "split_file = \"/content/drive/MyDrive/Batoul_Code/splits.pk\"\n",
        "list_data_file = \"/content/drive/MyDrive/Batoul_Code/list_data.pk\"\n",
        "version = \"UNet\"\n",
        "approach = \"contour\"\n",
        "model = selectModel()\n",
        "\n",
        "base_path = \"/content/drive/MyDrive/Batoul_Code/\"\n",
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "data_folder = Path(base_path, \"input\", dataset_name)\n",
        "origins_folder = data_folder / \"images\"\n",
        "masks_folder = data_folder / \"masks\"\n",
        "masks_contour_folder = data_folder / \"masks_contour\"\n",
|
|
        "masks_folder = masks_contour_folder\n",
        "models_folder = Path(base_path + \"models\")\n",
        "images_folder = Path(base_path + \"images\")\n"
      ],
      "metadata": {
        "id": "5Sp6Ga1y50He"
      },
      "execution_count": null,
      "outputs": []
    },
|
|
    {
      "cell_type": "code",
      "source": [
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "models_folder = Path(base_path + \"models\")\n",
        "model_name = \"unet-6v.pt\"\n",
        "model_name = \"ournet_\" + version + \".pt\"\n",
        "print(model_name)\n",
        "model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
        "model.to(device)\n",
        "model.eval()\n",
        "\n",
        "test_loss = 0.0\n",
        "test_jaccard = 0.0\n",
        "test_dice = 0.0\n",
        "test_accuracy = 0.0\n",
        "test_sensitivity = 0.0\n",
        "test_specificity = 0.0\n",
        "batch_size = 4\n",
|
|
        "if os.path.isfile(list_data_file):\n",
        "    with open(list_data_file, \"rb\") as f:\n",
        "        list_data = pickle.load(f)\n",
        "        origins_list = list_data[0]\n",
        "        masks_list = list_data[1]\n",
        "else:\n",
        "    origins_list = [f.stem for f in origins_folder.glob(f\"*.{dataset_type}\")]\n",
        "    masks_list = [f.stem for f in masks_folder.glob(f\"*.{dataset_type}\")]\n",
        "    with open(list_data_file, \"wb\") as f:\n",
        "        pickle.dump([origins_list, masks_list], f)\n",
        "\n",
        "# origins_list = [f.stem for f in origins_folder.glob(\"*.png\")]\n",
        "# masks_list = [f.stem for f in masks_folder.glob(\"*.png\")]\n",
        "\n",
        "origin_mask_list = [(mask_name.replace(\"_mask\", \"\"), mask_name) for mask_name in masks_list]\n",
|
|
        "if os.path.isfile(split_file):\n",
        "    with open(split_file, \"rb\") as f:\n",
        "        splits = pickle.load(f)\n",
        "else:\n",
        "    splits = {}\n",
        "    splits[\"train\"], splits[\"test\"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)\n",
        "    splits[\"train\"], splits[\"val\"] = train_test_split(splits[\"train\"], test_size=0.1, random_state=42)\n",
        "    with open(split_file, \"wb\") as f:\n",
        "        pickle.dump(splits, f)\n",
        "\n",
        "val_test_transforms = torchvision.transforms.Compose([\n",
        "    Resize((image_size, image_size)),\n",
        "])\n",
        "\n",
        "if dataset_name != \"dataset\":\n",
        "    train_transforms = torchvision.transforms.Compose([\n",
        "        Pad(200),\n",
        "        Crop(300),\n",
        "        val_test_transforms,\n",
        "    ])\n",
        "    datasets = {x: LungDataset2(\n",
        "        splits[x],\n",
        "        origins_folder,\n",
        "        masks_folder,\n",
        "        train_transforms if x == \"train\" else val_test_transforms,\n",
        "        dataset_type=dataset_type\n",
        "    ) for x in [\"train\", \"test\", \"val\"]}\n",
        "else:\n",
        "    train_transforms = torchvision.transforms.Compose([\n",
        "        Pad(200),\n",
        "        Crop(300),\n",
        "        val_test_transforms,\n",
        "    ])\n",
        "    datasets = {x: LungDataset(\n",
        "        splits[x],\n",
        "        origins_folder,\n",
        "        masks_folder,\n",
        "        train_transforms if x == \"train\" else val_test_transforms,\n",
        "        dataset_type=dataset_type\n",
        "    ) for x in [\"train\", \"test\", \"val\"]}\n",
        "\n",
        "num_samples = 9\n",
        "phase = \"test\"\n",
        "print(len(datasets[phase]))\n",
        "\n",
        "dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=batch_size) for x in [\"train\", \"test\", \"val\"]}\n",
|
|
        "for origins, masks in dataloaders[\"test\"]:\n",
        "    num = origins.size(0)\n",
        "\n",
        "    origins = origins.to(device)\n",
        "    masks = masks.to(device)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        outs = model(origins)\n",
        "        softmax = torch.nn.functional.log_softmax(outs, dim=1)\n",
        "        test_loss += torch.nn.functional.nll_loss(softmax, masks).item() * num\n",
        "        outs = torch.argmax(softmax, dim=1)\n",
        "        outs = outs.float()\n",
        "        masks = masks.float()\n",
        "        test_jaccard += jaccard(masks, outs).item() * num\n",
        "        test_dice += dice(masks, outs).item() * num\n",
        "        test_accuracy += get_accuracy(masks, outs) * num\n",
        "        test_sensitivity += get_sensitivity(masks, outs) * num\n",
        "        test_specificity += get_specificity(masks, outs) * num\n",
        "    print(\".\", end=\"\")\n",
|
|
        "\n",
        "test_loss = test_loss / len(datasets[\"test\"])\n",
        "test_jaccard = test_jaccard / len(datasets[\"test\"])\n",
        "test_dice = test_dice / len(datasets[\"test\"])\n",
        "test_accuracy = test_accuracy / len(datasets[\"test\"])\n",
        "test_sensitivity = test_sensitivity / len(datasets[\"test\"])\n",
        "test_specificity = test_specificity / len(datasets[\"test\"])\n",
        "print()\n",
        "print(f\"avg test loss: {test_loss}\")\n",
        "print(f\"avg test jaccard: {test_jaccard}\")\n",
        "print(f\"avg test dice: {test_dice}\")\n",
        "print(f\"avg test accuracy: {test_accuracy}\")\n",
        "print(f\"avg test sensitivity: {test_sensitivity}\")\n",
        "print(f\"avg test specificity: {test_specificity}\")\n",
|
|
        "subset = torch.utils.data.Subset(\n",
        "    datasets[phase],\n",
        "    np.random.randint(0, len(datasets[phase]), num_samples)\n",
        ")\n",
        "random_samples_loader = torch.utils.data.DataLoader(subset, batch_size=2)\n",
        "plt.figure(figsize=(20, 25))\n",
        "\n",
        "for idx, (origin, mask) in enumerate(random_samples_loader):\n",
        "    plt.subplot((num_samples // 3) + 1, 3, idx + 1)\n",
        "\n",
        "    origin = origin.to(device)\n",
        "    mask = mask.to(device)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        out = model(origin)\n",
        "        softmax = torch.nn.functional.log_softmax(out, dim=1)\n",
        "        out = torch.argmax(softmax, dim=1)\n",
        "\n",
        "    jaccard_score = jaccard(mask.float(), out.float()).item()\n",
        "    dice_score = dice(mask.float(), out.float()).item()\n",
        "\n",
        "    origin = origin[0].to(\"cpu\")\n",
        "    out = out[0].to(\"cpu\")\n",
        "    mask = mask[0].to(\"cpu\")\n",
        "    # blend ground truth and prediction so the color legend printed below holds\n",
        "    plt.imshow(np.array(blend(origin, mask, out)))\n",
        "    plt.title(f\"jaccard: {jaccard_score:.4f}, dice: {dice_score:.4f}\")\n",
        "    print(\".\", end=\"\")\n",
        "\n",
        "plt.savefig(images_folder / \"obtained-results.png\", bbox_inches='tight')\n",
        "plt.show()\n",
        "print()\n",
        "print(\"red area - predict\")\n",
        "print(\"green area - ground truth\")\n",
        "print(\"yellow area - intersection\")\n",
|
|
264 |
"\n", |
|
|
265 |
"\n", |
|
|
266 |
"model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n", |
|
|
267 |
"model.to(device)\n", |
|
|
268 |
"model.eval()\n", |
|
|
269 |
"\n", |
|
|
270 |
"device\n", |
|
|
271 |
"\n", |
|
|
272 |
"#%%\n", |
|
|
        "origin_filename = base_path + f\"input/{dataset_name}/images/ID00015637202177877247924_110.jpg\"\n",
        "# origin_filename = base_path + \"external_samples/1.jpg\"\n",
        "\n",
        "origin = Image.open(origin_filename).convert(\"P\")\n",
        "origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))\n",
        "origin = torchvision.transforms.functional.to_tensor(origin) - 0.5\n",
        "\n",
        "with torch.no_grad():\n",
        "    origin = torch.stack([origin])\n",
        "    origin = origin.to(device)\n",
        "    out = model(origin)\n",
        "    softmax = torch.nn.functional.log_softmax(out, dim=1)\n",
        "    out = torch.argmax(softmax, dim=1)\n",
        "\n",
        "    origin = origin[0].to(\"cpu\")\n",
        "    out = out[0].to(\"cpu\")\n",
        "\n",
        "plt.figure(figsize=(20, 10))\n",
        "\n",
        "pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n",
        "\n",
        "plt.subplot(1, 2, 1)\n",
        "plt.title(\"origin image\")\n",
        "plt.imshow(np.array(pil_origin))\n",
        "plt.subplot(1, 2, 2)\n",
        "plt.title(\"blended origin + predict\")\n",
        "plt.imshow(np.array(blend(origin, out)))\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "jFpE8AwG5-Nm"
      },
      "execution_count": null,
      "outputs": []
    },
|
|
    {
      "cell_type": "code",
      "source": [
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "models_folder = Path(base_path + \"models\")\n",
        "\n",
        "test_loss = 0.0\n",
        "test_jaccard = 0.0\n",
        "test_dice = 0.0\n",
        "\n",
        "batch_size = 4\n",
|
|
        "if os.path.isfile(list_data_file):\n",
        "    with open(list_data_file, \"rb\") as f:\n",
        "        list_data = pickle.load(f)\n",
        "        origins_list = list_data[0]\n",
        "        masks_list = list_data[1]\n",
        "else:\n",
        "    origins_list = [f.stem for f in origins_folder.glob(f\"*.{dataset_type}\")]\n",
        "    masks_list = [f.stem for f in masks_folder.glob(f\"*.{dataset_type}\")]\n",
        "    with open(list_data_file, \"wb\") as f:\n",
        "        pickle.dump([origins_list, masks_list], f)\n",
        "\n",
        "# origins_list = [f.stem for f in origins_folder.glob(\"*.png\")]\n",
        "# masks_list = [f.stem for f in masks_folder.glob(\"*.png\")]\n",
        "\n",
        "origin_mask_list = [(mask_name.replace(\"_mask\", \"\"), mask_name) for mask_name in masks_list]\n",
        "\n",
        "if os.path.isfile(split_file):\n",
        "    with open(split_file, \"rb\") as f:\n",
        "        splits = pickle.load(f)\n",
        "else:\n",
        "    splits = {}\n",
        "    splits[\"train\"], splits[\"test\"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)\n",
        "    splits[\"train\"], splits[\"val\"] = train_test_split(splits[\"train\"], test_size=0.1, random_state=42)\n",
        "    with open(split_file, \"wb\") as f:\n",
        "        pickle.dump(splits, f)\n",
        "\n",
        "val_test_transforms = torchvision.transforms.Compose([\n",
        "    Resize((image_size, image_size)),\n",
        "])\n",
        "\n",
        "if dataset_name != \"dataset\":\n",
        "    train_transforms = torchvision.transforms.Compose([\n",
        "        # Pad(200),\n",
        "        # Crop(300),\n",
        "        # val_test_transforms,\n",
        "    ])\n",
        "    datasets = {x: LungDataset2(\n",
        "        splits[x],\n",
        "        origins_folder,\n",
        "        masks_folder,\n",
        "        train_transforms if x == \"train\" else val_test_transforms,\n",
        "        dataset_type=dataset_type\n",
        "    ) for x in [\"train\", \"test\", \"val\"]}\n",
        "else:\n",
        "    train_transforms = torchvision.transforms.Compose([\n",
        "        Pad(200),\n",
        "        Crop(300),\n",
        "        val_test_transforms,\n",
        "    ])\n",
        "    datasets = {x: LungDataset(\n",
        "        splits[x],\n",
        "        origins_folder,\n",
        "        masks_folder,\n",
        "        train_transforms if x == \"train\" else val_test_transforms,\n",
        "        dataset_type=dataset_type\n",
        "    ) for x in [\"train\", \"test\", \"val\"]}\n",
|
|
|
|
        "def mask_to_class_rgb(mask):\n",
        "    mask = torch.from_numpy(np.array(mask))\n",
        "    mask = torch.squeeze(mask)  # remove a singleton dim if present\n",
        "\n",
        "    class_mask = mask.permute(2, 0, 1).contiguous()\n",
        "    h, w = class_mask.shape[1], class_mask.shape[2]\n",
        "    mask_out = torch.zeros((h, w))\n",
        "\n",
        "    threshold = 200\n",
        "    for i in range(0, 3):\n",
        "        class_mask[i][class_mask[i] < threshold] = 0\n",
        "\n",
        "    for i in range(2, 3):\n",
        "        mask_out[class_mask[i] >= threshold] = 1\n",
        "    return mask_out\n",
|
|
        "def getitem2(path):\n",
        "    mask = Image.open(path)\n",
        "    mask = mask_to_class_rgb(mask)\n",
        "    mask = mask.long()\n",
        "    # mask = (torch.tensor(mask) > 128).long()\n",
        "    return mask\n",
        "\n",
        "def getitem1(path):\n",
        "    mask = Image.open(path)\n",
        "    mask = mask.resize((image_size, image_size))\n",
        "    mask = np.array(mask)\n",
        "    mask = (torch.tensor(mask) > 128).long()\n",
        "    return mask\n",
|
|
        "\n",
        "idx = 1\n",
        "phase = \"test\"\n",
        "fig = plt.figure(figsize=(20, 10))\n",
        "if dataset_name != \"dataset\":\n",
        "    samples = [\"ID00015637202177877247924_110.jpg\",\n",
        "               \"ID00009637202177434476278_173.jpg\",\n",
        "               \"ID00009637202177434476278_316.jpg\",\n",
        "               \"ID00009637202177434476278_204.jpg\"]\n",
        "    masks = [mask_name.replace(\"_\", \"_mask_\") for mask_name in samples]\n",
        "else:\n",
        "    samples = [\"CHNCXR_0060_0.png\",\n",
        "               \"CHNCXR_0074_0.png\",\n",
        "               \"CHNCXR_0129_0.png\",\n",
        "               \"CHNCXR_0167_0.png\"]\n",
        "    masks = [mask_name.replace(\"_0.png\", \"_0_mask.png\") for mask_name in samples]\n",
        "\n",
        "samples = [base_path + f\"input/{dataset_name}/images/\" + sample_name for sample_name in samples]\n",
        "masks = [base_path + f\"input/{dataset_name}/masks/\" + mask_name for mask_name in masks]\n",
        "models = [\"ResNetDUCHDC\", \"OueNetNew3\", \"NestedUNet\", \"ResNetDUC\", \"FCN_GCN\", \"SegNet2\", \"UNet\"]\n",
|
|
        "for sample_idx in range(len(samples)):\n",
        "    for m in range(len(models)):\n",
        "        origin_filename = samples[sample_idx]\n",
        "        origin = Image.open(origin_filename).convert(\"P\")\n",
        "        origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))\n",
        "        origin = torchvision.transforms.functional.to_tensor(origin) - 0.5\n",
        "        if dataset_name != \"dataset\":\n",
        "            mask = getitem2(masks[sample_idx])\n",
        "        else:\n",
        "            mask = getitem1(masks[sample_idx])\n",
        "        version = models[m]\n",
        "        if dataset_name != \"dataset\":\n",
        "            version = version + \"_\" + dataset_name\n",
        "        model = selectModel(models[m])\n",
        "        model_name = \"ournet_\" + version + \".pt\"\n",
        "        model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
        "        model.to(device)\n",
        "        model.eval()  # added: inference should run in eval mode (batch_norm is enabled)\n",
        "        with torch.no_grad():\n",
        "            origin = torch.stack([origin])\n",
        "            origin = origin.to(device)\n",
        "            out = model(origin)\n",
        "            softmax = torch.nn.functional.log_softmax(out, dim=1)\n",
        "            out = torch.argmax(softmax, dim=1)\n",
        "\n",
        "        origin = origin[0].to(\"cpu\")\n",
        "        out = out[0].to(\"cpu\")\n",
        "\n",
        "        pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n",
        "        plt.subplots_adjust(hspace=0)\n",
        "        if m == 0:\n",
        "            ax = fig.add_subplot(len(samples), len(models) + 2, idx)\n",
        "            ax.set_axis_off()\n",
        "            # plt.title(\"origin image\")\n",
        "            plt.imshow(np.array(pil_origin))\n",
        "            idx = idx + 1\n",
        "            ax = fig.add_subplot(len(samples), len(models) + 2, idx)\n",
        "            ax.set_axis_off()\n",
        "            plt.imshow(np.array(blend(origin, mask, amount=0.4)))\n",
        "            idx = idx + 1\n",
        "        ax = fig.add_subplot(len(samples), len(models) + 2, idx)\n",
        "        ax.set_axis_off()\n",
        "        # plt.title(\"blended origin + predict\")\n",
        "        plt.imshow(np.array(blend(origin, out, amount=0.5)))\n",
        "        # plt.savefig(images_folder / f\"results/{version} {sample_idx}\", bbox_inches='tight')\n",
        "        idx = idx + 1\n",
        "\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "ksO_-uVR6V4F"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}