|
a |
|
b/Training.ipynb |
|
|
1 |
{ |
|
|
2 |
"nbformat": 4, |
|
|
3 |
"nbformat_minor": 0, |
|
|
4 |
"metadata": { |
|
|
5 |
"colab": { |
|
|
6 |
"provenance": [] |
|
|
7 |
}, |
|
|
8 |
"kernelspec": { |
|
|
9 |
"name": "python3", |
|
|
10 |
"display_name": "Python 3" |
|
|
11 |
}, |
|
|
12 |
"language_info": { |
|
|
13 |
"name": "python" |
|
|
14 |
} |
|
|
15 |
}, |
|
|
16 |
"cells": [ |
|
|
17 |
{ |
|
|
18 |
"cell_type": "code", |
|
|
19 |
"execution_count": null, |
|
|
20 |
"metadata": { |
|
|
21 |
"id": "fqPJUT863Ps2" |
|
|
22 |
}, |
|
|
23 |
"outputs": [], |
|
|
24 |
"source": [ |
|
|
25 |
"from google.colab import drive\n", |
|
|
26 |
"drive.mount('/content/drive')\n", |
|
|
27 |
"#!unzip '/content/drive/MyDrive/Batoul_Code/input/CT/images.zip' -d '/content/drive/MyDrive/Batoul_Code/input/CT'" |
|
|
28 |
] |
|
|
29 |
}, |
|
|
30 |
{ |
|
|
31 |
"cell_type": "code", |
|
|
32 |
"source": [ |
|
|
33 |
"import torch\n", |
|
|
34 |
"import torchvision\n", |
|
|
35 |
"import os\n", |
|
|
36 |
"import glob\n", |
|
|
37 |
"import time\n", |
|
|
38 |
"import pickle\n", |
|
|
39 |
"import sys\n", |
|
|
40 |
"sys.path.append('/content/drive/MyDrive/Batoul_Code/')\n", |
|
|
41 |
"sys.path.append('/content/drive/MyDrive/Batoul_Code/src')\n", |
|
|
42 |
"import pandas as pd\n", |
|
|
43 |
"import numpy as np\n", |
|
|
44 |
"import matplotlib.pyplot as plt\n", |
|
|
45 |
"\n", |
|
|
46 |
"from pathlib import Path\n", |
|
|
47 |
"from PIL import Image\n", |
|
|
48 |
"from sklearn.model_selection import train_test_split\n", |
|
|
49 |
"\n", |
|
|
50 |
"\n", |
|
|
51 |
"from data import LungDataset, blend, Pad, Crop, Resize\n", |
|
|
52 |
"from OurModel import CxlNet\n", |
|
|
53 |
"from metrics import jaccard, dice" |
|
|
54 |
], |
|
|
55 |
"metadata": { |
|
|
56 |
"id": "N-IZfgid3Xwc" |
|
|
57 |
}, |
|
|
58 |
"execution_count": null, |
|
|
59 |
"outputs": [] |
|
|
60 |
}, |
|
|
61 |
{ |
|
|
62 |
"cell_type": "code", |
|
|
63 |
"source": [ |
|
|
64 |
"in_channels=1\n", |
|
|
65 |
"out_channels=2\n", |
|
|
66 |
"batch_norm=True\n", |
|
|
67 |
"upscale_mode=\"bilinear\"\n", |
|
|
68 |
"image_size=512\n", |
|
|
69 |
"def selectModel():\n", |
|
|
70 |
" return CxlNet(\n", |
|
|
71 |
" in_channels=in_channels,\n", |
|
|
72 |
" out_channels=out_channels,\n", |
|
|
73 |
" batch_norm=batch_norm,\n", |
|
|
74 |
" upscale_mode=upscale_mode,\n", |
|
|
75 |
" image_size=image_size)" |
|
|
76 |
], |
|
|
77 |
"metadata": { |
|
|
78 |
"id": "J349vBjr31Ir" |
|
|
79 |
}, |
|
|
80 |
"execution_count": null, |
|
|
81 |
"outputs": [] |
|
|
82 |
}, |
|
|
83 |
{ |
|
|
84 |
"cell_type": "code", |
|
|
85 |
"source": [ |
|
|
86 |
"dataset_name=\"dataset\"\n", |
|
|
87 |
"dataset_type=\"png\"\n", |
|
|
88 |
"split_file = \"/content/drive/MyDrive/Batoul_Code/splits.pk\"\n", |
|
|
89 |
"version=\"CxlNet\"\n", |
|
|
90 |
"approach=\"contour\"\n", |
|
|
91 |
"model = selectModel()\n", |
|
|
92 |
"\n", |
|
|
93 |
"base_path=\"/content/drive/MyDrive/Batoul_Code/\"\n", |
|
|
94 |
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", |
|
|
95 |
"device\n", |
|
|
96 |
"\n", |
|
|
97 |
"data_folder = Path(base_path+\"input\", base_path+\"input/\"+dataset_name)\n", |
|
|
98 |
"origins_folder = data_folder / \"images\"\n", |
|
|
99 |
"masks_folder = data_folder / \"masks\"\n", |
|
|
100 |
"masks_contour_folder = data_folder / \"masks_contour\"\n", |
|
|
101 |
"models_folder = Path(base_path+\"models\")\n", |
|
|
102 |
"images_folder = Path(base_path+\"images\")\n", |
|
|
103 |
"\n", |
|
|
104 |
"\n" |
|
|
105 |
], |
|
|
106 |
"metadata": { |
|
|
107 |
"id": "MhOSosnw4Q96" |
|
|
108 |
}, |
|
|
109 |
"execution_count": null, |
|
|
110 |
"outputs": [] |
|
|
111 |
}, |
|
|
112 |
{ |
|
|
113 |
"cell_type": "code", |
|
|
114 |
"source": [ |
|
|
115 |
"#@title\n", |
|
|
116 |
"batch_size = 4\n", |
|
|
117 |
"torch.cuda.empty_cache()\n", |
|
|
118 |
"origins_list = [f.stem for f in origins_folder.glob(f\"*.{dataset_type}\")]\n", |
|
|
119 |
"#masks_list = [f.stem for f in masks_folder.glob(f\"*.{dataset_type}\")]\n", |
|
|
120 |
"masks_list = [f.stem for f in masks_contour_folder.glob(f\"*.{dataset_type}\")]\n", |
|
|
121 |
"\n", |
|
|
122 |
"\n", |
|
|
123 |
"\n", |
|
|
124 |
"\n", |
|
|
125 |
"origin_mask_list = [(mask_name.replace(\"_mask\", \"\"), mask_name) for mask_name in masks_list]\n", |
|
|
126 |
"\n", |
|
|
127 |
"\n", |
|
|
128 |
"if os.path.isfile(split_file):\n", |
|
|
129 |
" with open(split_file, \"rb\") as f:\n", |
|
|
130 |
" splits = pickle.load(f)\n", |
|
|
131 |
"else:\n", |
|
|
132 |
" splits = {}\n", |
|
|
133 |
" splits[\"train\"], splits[\"test\"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)\n", |
|
|
134 |
" splits[\"train\"], splits[\"val\"] = train_test_split(splits[\"train\"], test_size=0.1, random_state=42)\n", |
|
|
135 |
" with open(split_file, \"wb\") as f:\n", |
|
|
136 |
" pickle.dump(splits, f)\n", |
|
|
137 |
"\n", |
|
|
138 |
"val_test_transforms = torchvision.transforms.Compose([\n", |
|
|
139 |
" Resize((image_size, image_size)),\n", |
|
|
140 |
"])\n", |
|
|
141 |
"\n", |
|
|
142 |
"train_transforms = torchvision.transforms.Compose([\n", |
|
|
143 |
" Pad(200),\n", |
|
|
144 |
" Crop(300),\n", |
|
|
145 |
" val_test_transforms,\n", |
|
|
146 |
"])\n", |
|
|
147 |
"\n", |
|
|
148 |
"datasets = {x: LungDataset(\n", |
|
|
149 |
" splits[x],\n", |
|
|
150 |
" origins_folder,\n", |
|
|
151 |
" #masks_folder,\n", |
|
|
152 |
" masks_contour_folder,\n", |
|
|
153 |
" train_transforms if x == \"train\" else val_test_transforms,\n", |
|
|
154 |
" dataset_type=dataset_type\n", |
|
|
155 |
") for x in [\"train\", \"test\", \"val\"]}\n", |
|
|
156 |
"\n", |
|
|
157 |
"dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=batch_size)\n", |
|
|
158 |
" for x in [\"train\", \"test\", \"val\"]}\n", |
|
|
159 |
"\n", |
|
|
160 |
"print(len(dataloaders['train']))\n", |
|
|
161 |
"\n", |
|
|
162 |
"idx = 0\n", |
|
|
163 |
"phase = \"train\"\n", |
|
|
164 |
"\n", |
|
|
165 |
"plt.figure(figsize=(20, 20))\n", |
|
|
166 |
"origin, mask = datasets[phase][idx]\n", |
|
|
167 |
"\n", |
|
|
168 |
"pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n", |
|
|
169 |
"print(origin.size())\n", |
|
|
170 |
"print(mask.size())\n", |
|
|
171 |
"pil_origin.save(\"1.png\")\n", |
|
|
172 |
"\n", |
|
|
173 |
"\n", |
|
|
174 |
"print(mask.size())\n", |
|
|
175 |
"pil_mask = torchvision.transforms.functional.to_pil_image(mask.float())\n", |
|
|
176 |
"pil_mask.save(\"2.png\")\n", |
|
|
177 |
"plt.subplot(1, 3, 1)\n", |
|
|
178 |
"plt.title(\"origin image\")\n", |
|
|
179 |
"plt.imshow(np.array(pil_origin))\n", |
|
|
180 |
"\n", |
|
|
181 |
"plt.subplot(1, 3, 2)\n", |
|
|
182 |
"plt.title(\"manually labeled mask\")\n", |
|
|
183 |
"plt.imshow(np.array(pil_mask))\n", |
|
|
184 |
"\n", |
|
|
185 |
"plt.subplot(1, 3, 3)\n", |
|
|
186 |
"plt.title(\"blended origin + mask\")\n", |
|
|
187 |
"plt.imshow(np.array(blend(origin, mask)));\n", |
|
|
188 |
"\n", |
|
|
189 |
"plt.savefig(images_folder / \"data-example.png\", bbox_inches='tight')\n", |
|
|
190 |
"plt.show()\n", |
|
|
191 |
"train=True\n", |
|
|
192 |
"model_name = \"ournet_\"+version+\".pt\"\n", |
|
|
193 |
"if train==True:\n", |
|
|
194 |
"\n", |
|
|
195 |
" if os.path.isfile(models_folder / model_name):\n", |
|
|
196 |
" model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n", |
|
|
197 |
" print(\"load_state_dict\")\n", |
|
|
198 |
"\n", |
|
|
199 |
" model = model.to(device)\n", |
|
|
200 |
" # optimizer = torch.optim.SGD(unet.parameters(), lr=0.0005, momentum=0.9)\n", |
|
|
201 |
" optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)\n", |
|
|
202 |
"\n", |
|
|
203 |
" train_log_filename = base_path + \"train-log-\"+version+\".txt\"\n", |
|
|
204 |
" epochs = 50\n", |
|
|
205 |
" best_val_loss = np.inf\n", |
|
|
206 |
"\n", |
|
|
207 |
"\n", |
|
|
208 |
" hist = []\n", |
|
|
209 |
"\n", |
|
|
210 |
" for e in range(epochs):\n", |
|
|
211 |
" start_t = time.time()\n", |
|
|
212 |
"\n", |
|
|
213 |
" print(\"Epoch \"+str(e))\n", |
|
|
214 |
" model.train()\n", |
|
|
215 |
"\n", |
|
|
216 |
" train_loss = 0.0\n", |
|
|
217 |
"\n", |
|
|
218 |
" for origins, masks in dataloaders[\"train\"]:\n", |
|
|
219 |
" num = origins.size(0)\n", |
|
|
220 |
"\n", |
|
|
221 |
" origins = origins.to(device)\n", |
|
|
222 |
" #print(masks.size())\n", |
|
|
223 |
" #if dataset_name!=\"dataset\":\n", |
|
|
224 |
" #masks = masks.permute((0,3,1, 2))\n", |
|
|
225 |
" #masks=masks[:,0,:,:]\n", |
|
|
226 |
" #print(masks.size())\n", |
|
|
227 |
"\n", |
|
|
228 |
" masks = masks.to(device)\n", |
|
|
229 |
" optimizer.zero_grad()\n", |
|
|
230 |
" outs = model(origins)\n", |
|
|
231 |
" softmax = torch.nn.functional.log_softmax(outs, dim=1)\n", |
|
|
232 |
" loss = torch.nn.functional.nll_loss(softmax, masks)\n", |
|
|
233 |
" loss.backward()\n", |
|
|
234 |
" optimizer.step()\n", |
|
|
235 |
"\n", |
|
|
236 |
" train_loss += loss.item() * num\n", |
|
|
237 |
" print(\".\", end=\"\")\n", |
|
|
238 |
"\n", |
|
|
239 |
" train_loss = train_loss / len(datasets['train'])\n", |
|
|
240 |
" print()\n", |
|
|
241 |
"\n", |
|
|
242 |
" print(\"validation phase\")\n", |
|
|
243 |
" model.eval()\n", |
|
|
244 |
" val_loss = 0.0\n", |
|
|
245 |
" val_jaccard = 0.0\n", |
|
|
246 |
" val_dice = 0.0\n", |
|
|
247 |
"\n", |
|
|
248 |
" for origins, masks in dataloaders[\"val\"]:\n", |
|
|
249 |
" num = origins.size(0)\n", |
|
|
250 |
" origins = origins.to(device)\n", |
|
|
251 |
" masks = masks.to(device)\n", |
|
|
252 |
"\n", |
|
|
253 |
" with torch.no_grad():\n", |
|
|
254 |
" outs = model(origins)\n", |
|
|
255 |
" softmax = torch.nn.functional.log_softmax(outs, dim=1)\n", |
|
|
256 |
" val_loss += torch.nn.functional.nll_loss(softmax, masks).item() * num\n", |
|
|
257 |
"\n", |
|
|
258 |
" outs = torch.argmax(softmax, dim=1)\n", |
|
|
259 |
" outs = outs.float()\n", |
|
|
260 |
" masks = masks.float()\n", |
|
|
261 |
" val_jaccard += jaccard(masks, outs.float()).item() * num\n", |
|
|
262 |
" val_dice += dice(masks, outs).item() * num\n", |
|
|
263 |
"\n", |
|
|
264 |
" print(\".\", end=\"\")\n", |
|
|
265 |
" val_loss = val_loss / len(datasets[\"val\"])\n", |
|
|
266 |
" val_jaccard = val_jaccard / len(datasets[\"val\"])\n", |
|
|
267 |
" val_dice = val_dice / len(datasets[\"val\"])\n", |
|
|
268 |
" print()\n", |
|
|
269 |
"\n", |
|
|
270 |
" end_t = time.time()\n", |
|
|
271 |
" spended_t = end_t - start_t\n", |
|
|
272 |
"\n", |
|
|
273 |
" with open(train_log_filename, \"a\") as train_log_file:\n", |
|
|
274 |
" report = f\"epoch: {e + 1}/{epochs}, time: {spended_t}, train loss: {train_loss}, \\n\" \\\n", |
|
|
275 |
" + f\"val loss: {val_loss}, val jaccard: {val_jaccard}, val dice: {val_dice}\"\n", |
|
|
276 |
"\n", |
|
|
277 |
" hist.append({\n", |
|
|
278 |
" \"time\": spended_t,\n", |
|
|
279 |
" \"train_loss\": train_loss,\n", |
|
|
280 |
" \"val_loss\": val_loss,\n", |
|
|
281 |
" \"val_jaccard\": val_jaccard,\n", |
|
|
282 |
" \"val_dice\": val_dice,\n", |
|
|
283 |
" })\n", |
|
|
284 |
"\n", |
|
|
285 |
" print(report)\n", |
|
|
286 |
" train_log_file.write(report + \"\\n\")\n", |
|
|
287 |
"\n", |
|
|
288 |
" if val_loss < best_val_loss:\n", |
|
|
289 |
" best_val_loss = val_loss\n", |
|
|
290 |
" torch.save(model.state_dict(), models_folder / model_name)\n", |
|
|
291 |
" print(\"model saved\")\n", |
|
|
292 |
" train_log_file.write(\"model saved\\n\")\n", |
|
|
293 |
" print()\n", |
|
|
294 |
"\n", |
|
|
295 |
" #if val_jaccard >=0.9179:\n", |
|
|
296 |
" #break\n", |
|
|
297 |
" plt.figure(figsize=(15, 7))\n", |
|
|
298 |
" train_loss_hist = [h[\"train_loss\"] for h in hist]\n", |
|
|
299 |
" plt.plot(range(len(hist)), train_loss_hist, \"b\", label=\"train loss\")\n", |
|
|
300 |
"\n", |
|
|
301 |
" val_loss_hist = [h[\"val_loss\"] for h in hist]\n", |
|
|
302 |
" plt.plot(range(len(hist)), val_loss_hist, \"r\", label=\"val loss\")\n", |
|
|
303 |
"\n", |
|
|
304 |
" val_dice_hist = [h[\"val_dice\"] for h in hist]\n", |
|
|
305 |
" plt.plot(range(len(hist)), val_dice_hist, \"g\", label=\"val dice\")\n", |
|
|
306 |
"\n", |
|
|
307 |
" val_jaccard_hist = [h[\"val_jaccard\"] for h in hist]\n", |
|
|
308 |
" plt.plot(range(len(hist)), val_jaccard_hist, \"y\", label=\"val jaccard\")\n", |
|
|
309 |
"\n", |
|
|
310 |
" plt.legend()\n", |
|
|
311 |
" plt.xlabel(\"epoch\")\n", |
|
|
312 |
" plt.savefig(images_folder / model_name.replace(\".pt\", \"-train-hist.png\"))\n", |
|
|
313 |
"\n", |
|
|
314 |
" time_hist = [h[\"time\"] for h in hist]\n", |
|
|
315 |
" overall_time = sum(time_hist) // 60\n", |
|
|
316 |
" mean_epoch_time = sum(time_hist) / len(hist)\n", |
|
|
317 |
" print(f\"epochs: {len(hist)}, overall time: {overall_time}m, mean epoch time: {mean_epoch_time}s\")\n", |
|
|
318 |
"\n", |
|
|
319 |
" torch.cuda.empty_cache()\n", |
|
|
320 |
"else:\n", |
|
|
321 |
"\n", |
|
|
322 |
" model_name = \"ournet_\"+version+\".pt\"\n", |
|
|
323 |
" model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n", |
|
|
324 |
" model.to(device)\n", |
|
|
325 |
" model.eval()\n", |
|
|
326 |
"\n", |
|
|
327 |
" test_loss = 0.0\n", |
|
|
328 |
" test_jaccard = 0.0\n", |
|
|
329 |
" test_dice = 0.0\n", |
|
|
330 |
"\n", |
|
|
331 |
" for origins, masks in dataloaders[\"test\"]:\n", |
|
|
332 |
" num = origins.size(0)\n", |
|
|
333 |
" origins = origins.to(device)\n", |
|
|
334 |
" masks = masks.to(device)\n", |
|
|
335 |
"\n", |
|
|
336 |
" with torch.no_grad():\n", |
|
|
337 |
" outs = model(origins)\n", |
|
|
338 |
" softmax = torch.nn.functional.log_softmax(outs, dim=1)\n", |
|
|
339 |
" test_loss += torch.nn.functional.nll_loss(softmax, masks).item() * num\n", |
|
|
340 |
"\n", |
|
|
341 |
" outs = torch.argmax(softmax, dim=1)\n", |
|
|
342 |
" outs = outs.float()\n", |
|
|
343 |
" masks = masks.float()\n", |
|
|
344 |
" test_jaccard += jaccard(masks, outs).item() * num\n", |
|
|
345 |
" test_dice += dice(masks, outs).item() * num\n", |
|
|
346 |
" print(\".\", end=\"\")\n", |
|
|
347 |
"\n", |
|
|
348 |
" test_loss = test_loss / len(datasets[\"test\"])\n", |
|
|
349 |
" test_jaccard = test_jaccard / len(datasets[\"test\"])\n", |
|
|
350 |
" test_dice = test_dice / len(datasets[\"test\"])\n", |
|
|
351 |
"\n", |
|
|
352 |
" print()\n", |
|
|
353 |
" print(f\"avg test loss: {test_loss}\")\n", |
|
|
354 |
" print(f\"avg test jaccard: {test_jaccard}\")\n", |
|
|
355 |
" print(f\"avg test dice: {test_dice}\")\n", |
|
|
356 |
"\n", |
|
|
357 |
" num_samples = 9\n", |
|
|
358 |
" phase = \"test\"\n", |
|
|
359 |
"\n", |
|
|
360 |
" subset = torch.utils.data.Subset(\n", |
|
|
361 |
" datasets[phase],\n", |
|
|
362 |
" np.random.randint(0, len(datasets[phase]), num_samples)\n", |
|
|
363 |
" )\n", |
|
|
364 |
" random_samples_loader = torch.utils.data.DataLoader(subset, batch_size=1)\n", |
|
|
365 |
" plt.figure(figsize=(20, 25))\n", |
|
|
366 |
"\n", |
|
|
367 |
" for idx, (origin, mask) in enumerate(random_samples_loader):\n", |
|
|
368 |
" plt.subplot((num_samples // 3) + 1, 3, idx + 1)\n", |
|
|
369 |
"\n", |
|
|
370 |
" origin = origin.to(device)\n", |
|
|
371 |
" mask = mask.to(device)\n", |
|
|
372 |
"\n", |
|
|
373 |
" with torch.no_grad():\n", |
|
|
374 |
" out = model(origin)\n", |
|
|
375 |
" softmax = torch.nn.functional.log_softmax(out, dim=1)\n", |
|
|
376 |
" out = torch.argmax(softmax, dim=1)\n", |
|
|
377 |
"\n", |
|
|
378 |
" jaccard_score = jaccard(mask.float(), out.float()).item()\n", |
|
|
379 |
" dice_score = dice(mask.float(), out.float()).item()\n", |
|
|
380 |
"\n", |
|
|
381 |
" origin = origin[0].to(\"cpu\")\n", |
|
|
382 |
" out = out[0].to(\"cpu\")\n", |
|
|
383 |
" mask = mask[0].to(\"cpu\")\n", |
|
|
384 |
"\n", |
|
|
385 |
" plt.imshow(np.array(blend(origin, mask, out)))\n", |
|
|
386 |
" plt.title(f\"jaccard: {jaccard_score:.4f}, dice: {dice_score:.4f}\")\n", |
|
|
387 |
" print(\".\", end=\"\")\n", |
|
|
388 |
" plt.show()\n", |
|
|
389 |
" plt.savefig(images_folder / \"obtained-results.png\", bbox_inches='tight')\n", |
|
|
390 |
" print()\n", |
|
|
391 |
" print(\"red area - predict\")\n", |
|
|
392 |
" print(\"green area - ground truth\")\n", |
|
|
393 |
" print(\"yellow area - intersection\")\n", |
|
|
394 |
"\n", |
|
|
395 |
"\n", |
|
|
396 |
"\n", |
|
|
397 |
" model_name = \"ournet_\"+version+\".pt\"\n", |
|
|
398 |
" model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n", |
|
|
399 |
" model.to(device)\n", |
|
|
400 |
" model.eval()\n", |
|
|
401 |
"\n", |
|
|
402 |
" device\n", |
|
|
403 |
"\n", |
|
|
404 |
" # %%\n", |
|
|
405 |
"\n", |
|
|
406 |
" origin_filename = \"input/dataset/images/CHNCXR_0042_0.png\"\n", |
|
|
407 |
"\n", |
|
|
408 |
" origin = Image.open(origin_filename).convert(\"P\")\n", |
|
|
409 |
" origin = torchvision.transforms.functional.resize(origin, (200, 200))\n", |
|
|
410 |
" origin = torchvision.transforms.functional.to_tensor(origin) - 0.5\n", |
|
|
411 |
"\n", |
|
|
412 |
" with torch.no_grad():\n", |
|
|
413 |
" origin = torch.stack([origin])\n", |
|
|
414 |
" origin = origin.to(device)\n", |
|
|
415 |
" out = model(origin)\n", |
|
|
416 |
" softmax = torch.nn.functional.log_softmax(out, dim=1)\n", |
|
|
417 |
" out = torch.argmax(softmax, dim=1)\n", |
|
|
418 |
"\n", |
|
|
419 |
" origin = origin[0].to(\"cpu\")\n", |
|
|
420 |
" out = out[0].to(\"cpu\")\n", |
|
|
421 |
"\n", |
|
|
422 |
" plt.figure(figsize=(20, 10))\n", |
|
|
423 |
"\n", |
|
|
424 |
" pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n", |
|
|
425 |
"\n", |
|
|
426 |
" plt.subplot(1, 2, 1)\n", |
|
|
427 |
" plt.title(\"origin image\")\n", |
|
|
428 |
" plt.imshow(np.array(pil_origin))\n", |
|
|
429 |
"\n", |
|
|
430 |
" plt.subplot(1, 2, 2)\n", |
|
|
431 |
" plt.title(\"blended origin + predict\")\n", |
|
|
432 |
" plt.imshow(np.array(blend(origin, out)))\n", |
|
|
433 |
" plt.show()\n" |
|
|
434 |
], |
|
|
435 |
"metadata": { |
|
|
436 |
"id": "sGk2UEtw4JLr" |
|
|
437 |
}, |
|
|
438 |
"execution_count": null, |
|
|
439 |
"outputs": [] |
|
|
440 |
} |
|
|
441 |
] |
|
|
442 |
} |