Diff of /Evalution.ipynb [000000] .. [2507a0]

Switch to unified view

a b/Evalution.ipynb
1
{
2
  "nbformat": 4,
3
  "nbformat_minor": 0,
4
  "metadata": {
5
    "colab": {
6
      "provenance": []
7
    },
8
    "kernelspec": {
9
      "name": "python3",
10
      "display_name": "Python 3"
11
    },
12
    "language_info": {
13
      "name": "python"
14
    }
15
  },
16
  "cells": [
17
    {
18
      "cell_type": "code",
19
      "execution_count": null,
20
      "metadata": {
21
        "id": "iXo6Nj7M5GK8"
22
      },
23
      "outputs": [],
24
      "source": [
25
        "import torch\n",
26
        "import torchvision\n",
27
        "import os\n",
28
        "import glob\n",
29
        "import time\n",
30
        "import pickle\n",
31
        "import sys\n",
32
        "sys.path.append('/content/drive/MyDrive/Batoul_Code/')\n",
33
        "sys.path.append('/content/drive/MyDrive/Batoul_Code/src')\n",
34
        "\n",
35
        "import pandas as pd\n",
36
        "import numpy as np\n",
37
        "import matplotlib.pyplot as plt\n",
38
        "from pathlib import Patha\n",
39
        "from PIL import Image\n",
40
        "from sklearn.model_selection import train_test_split\n",
41
        "\n",
42
        "from data import LungDataset, blend, Pad, Crop, Resize\n",
43
        "from data2 import LungDataset2, blend, Pad, Crop, Resize\n",
44
        "\n",
45
        "from  OurModel import CxlNet\n",
46
        "\n",
47
        "from  metrics import jaccard, dice,get_accuracy, get_sensitivity, get_specificity"
48
      ]
49
    },
50
    {
51
      "cell_type": "code",
52
      "source": [
53
        "in_channels=1\n",
54
        "out_channels=2\n",
55
        "batch_norm=True\n",
56
        "upscale_mode=\"bilinear\"\n",
57
        "image_size=512\n",
58
        "def selectModel():\n",
59
        "    return CxlNet(\n",
60
        "            in_channels=in_channels,\n",
61
        "            out_channels=out_channels,\n",
62
        "            batch_norm=batch_norm,\n",
63
        "            upscale_mode=upscale_mode,\n",
64
        "            image_size=image_size)"
65
      ],
66
      "metadata": {
67
        "id": "BGn0Pjmb5gRA"
68
      },
69
      "execution_count": null,
70
      "outputs": []
71
    },
72
    {
73
      "cell_type": "code",
74
      "source": [
75
        "dataset_name=\"dataset\"\n",
76
        "dataset_types={\"dataset\":\"png\",\"CT\":\"jpg\"}\n",
77
        "dataset_type=dataset_types[dataset_name]\n",
78
        "print(dataset_type)\n",
79
        "image_size=512\n",
80
        "split_file = \"/content/drive/MyDrive/Batoul_Code/splits.pk\"\n",
81
        "list_data_file = \"/content/drive/MyDrive/Batoul_Code/list_data.pk\"\n",
82
        "version=\"UNet\"\n",
83
        "approach=\"contour\"\n",
84
        "model = selectModel()\n",
85
        "\n",
86
        "base_path=\"/content/drive/MyDrive/Batoul_Code/\"\n",
87
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
88
        "device\n",
89
        "\n",
90
        "\n",
91
        "\n",
92
        "\n",
93
        "data_folder = Path(base_path+\"input\", base_path+\"input/\"+dataset_name)\n",
94
        "origins_folder = data_folder / \"images\"\n",
95
        "masks_folder = data_folder / \"masks\"\n",
96
        "masks_contour_folder = data_folder / \"masks_contour\"\n",
97
        "masks_folder =masks_contour_folder\n",
98
        "models_folder = Path(base_path+\"models\")\n",
99
        "images_folder = Path(base_path+\"images\")\n"
100
      ],
101
      "metadata": {
102
        "id": "5Sp6Ga1y50He"
103
      },
104
      "execution_count": null,
105
      "outputs": []
106
    },
107
    {
108
      "cell_type": "code",
109
      "source": [
110
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
111
        "models_folder = Path(base_path+\"models\")\n",
112
        "model_name = \"unet-6v.pt\"\n",
113
        "model_name=\"ournet_\"+version+\".pt\"\n",
114
        "print(model_name)\n",
115
        "model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
116
        "model.to(device)\n",
117
        "model.eval()\n",
118
        "\n",
119
        "\n",
120
        "test_loss = 0.0\n",
121
        "test_jaccard = 0.0\n",
122
        "test_dice = 0.0\n",
123
        "test_accuracy=0.0\n",
124
        "test_sensitivity=0.0\n",
125
        "test_specificity=0.0\n",
126
        "batch_size = 4\n",
127
        "\n",
128
        "if os.path.isfile(list_data_file):\n",
129
        "  with open(list_data_file, \"rb\") as f:\n",
130
        "    list_data = pickle.load(f)\n",
131
        "    origins_list=list_data[0]\n",
132
        "    masks_list=list_data[1]\n",
133
        "else:\n",
134
        "  origins_list = [f.stem  for f in origins_folder.glob(f\"*.{dataset_type}\")]\n",
135
        "  masks_list = [f.stem  for f in masks_folder.glob(f\"*.{dataset_type}\")]\n",
136
        "  with open(list_data_file, \"wb\") as f:\n",
137
        "    pickle.dump([origins_list,masks_list], f)\n",
138
        "\n",
139
        "\n",
140
        "#origins_list = [f.stem for f in origins_folder.glob(\"*.png\")]\n",
141
        "#masks_list = [f.stem for f in masks_folder.glob(\"*.png\")]\n",
142
        "\n",
143
        "\n",
144
        "origin_mask_list = [(mask_name.replace(\"_mask\", \"\"), mask_name) for mask_name in masks_list]\n",
145
        "\n",
146
        "\n",
147
        "\n",
148
        "if os.path.isfile(split_file):\n",
149
        "    with open(split_file, \"rb\") as f:\n",
150
        "        splits = pickle.load(f)\n",
151
        "else:\n",
152
        "    splits = {}\n",
153
        "    splits[\"train\"], splits[\"test\"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)\n",
154
        "    splits[\"train\"], splits[\"val\"] = train_test_split(splits[\"train\"], test_size=0.1, random_state=42)\n",
155
        "    with open(split_file, \"wb\") as f:\n",
156
        "        pickle.dump(splits, f)\n",
157
        "\n",
158
        "val_test_transforms = torchvision.transforms.Compose([\n",
159
        "    Resize((image_size, image_size)),\n",
160
        "])\n",
161
        "\n",
162
        "if dataset_name!=\"dataset\":\n",
163
        "  train_transforms = torchvision.transforms.Compose([\n",
164
        "  Pad(200),\n",
165
        "  Crop(300),\n",
166
        "  val_test_transforms,\n",
167
        "  ])\n",
168
        "  datasets = {x: LungDataset2(\n",
169
        "  splits[x],\n",
170
        "  origins_folder,\n",
171
        "  masks_folder,\n",
172
        "  train_transforms if x == \"train\" else val_test_transforms,\n",
173
        "  dataset_type=dataset_type\n",
174
        "  ) for x in [\"train\", \"test\", \"val\"]}\n",
175
        "else:\n",
176
        "  train_transforms = torchvision.transforms.Compose([\n",
177
        "  Pad(200),\n",
178
        "  Crop(300),\n",
179
        "  val_test_transforms,])\n",
180
        "\n",
181
        "  datasets = {x: LungDataset(\n",
182
        "  splits[x],\n",
183
        "  origins_folder,\n",
184
        "  masks_folder,\n",
185
        "  train_transforms if x == \"train\" else val_test_transforms,\n",
186
        "  dataset_type=dataset_type\n",
187
        "  ) for x in [\"train\", \"test\", \"val\"]}\n",
188
        "\n",
189
        "num_samples = 9\n",
190
        "phase = \"test\"\n",
191
        "print(len(datasets[phase]))\n",
192
        "\n",
193
        "dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=batch_size) for x in [\"train\", \"test\", \"val\"]}\n",
194
        "\n",
195
        "for origins, masks in dataloaders[\"test\"]:\n",
196
        "    num = origins.size(0)\n",
197
        "\n",
198
        "    origins = origins.to(device)\n",
199
        "    masks = masks.to(device)\n",
200
        "\n",
201
        "    with torch.no_grad():\n",
202
        "        outs = model(origins)\n",
203
        "        softmax = torch.nn.functional.log_softmax(outs, dim=1)\n",
204
        "        test_loss += torch.nn.functional.nll_loss(softmax, masks).item() * num\n",
205
        "        outs = torch.argmax(softmax, dim=1)\n",
206
        "        outs = outs.float()\n",
207
        "        masks = masks.float()\n",
208
        "        test_jaccard += jaccard(masks, outs).item() * num\n",
209
        "        test_dice += dice(masks, outs).item() * num\n",
210
        "        test_accuracy += get_accuracy(masks, outs) * num\n",
211
        "        test_sensitivity += get_sensitivity(masks, outs) * num\n",
212
        "        test_specificity += get_specificity(masks, outs) * num\n",
213
        "    print(\".\", end=\"\")\n",
214
        "\n",
215
        "test_loss = test_loss / len(datasets[\"test\"])\n",
216
        "test_jaccard = test_jaccard / len(datasets[\"test\"])\n",
217
        "test_dice = test_dice / len(datasets[\"test\"])\n",
218
        "test_accuracy = test_accuracy / len(datasets[\"test\"])\n",
219
        "print()\n",
220
        "print(f\"avg test loss: {test_loss}\")\n",
221
        "print(f\"avg test jaccard: {test_jaccard}\")\n",
222
        "print(f\"avg test dice: {test_dice}\")\n",
223
        "print(f\"avg test accuracy: {test_accuracy}\")\n",
224
        "print(f\"avg test sensitivity: {test_sensitivity}\")\n",
225
        "print(f\"avg test specificity: {test_specificity}\")\n",
226
        "\n",
227
        "\n",
228
        "\n",
229
        "subset = torch.utils.data.Subset(\n",
230
        "    datasets[phase],\n",
231
        "    np.random.randint(0, len(datasets[phase]), num_samples)\n",
232
        ")\n",
233
        "random_samples_loader = torch.utils.data.DataLoader(subset, batch_size=2)\n",
234
        "plt.figure(figsize=(20, 25))\n",
235
        "\n",
236
        "for idx, (origin, mask) in enumerate(random_samples_loader):\n",
237
        "    plt.subplot((num_samples // 3) + 1, 3, idx + 1)\n",
238
        "\n",
239
        "    origin = origin.to(device)\n",
240
        "    mask = mask.to(device)\n",
241
        "\n",
242
        "    with torch.no_grad():\n",
243
        "        out = model(origin)\n",
244
        "        softmax = torch.nn.functional.log_softmax(out, dim=1)\n",
245
        "        out = torch.argmax(softmax, dim=1)\n",
246
        "\n",
247
        "        jaccard_score = jaccard(mask.float(), out.float()).item()\n",
248
        "        dice_score = dice(mask.float(), out.float()).item()\n",
249
        "\n",
250
        "        origin = origin[0].to(\"cpu\")\n",
251
        "        out = out[0].to(\"cpu\")\n",
252
        "        mask = mask[0].to(\"cpu\")\n",
253
        "        #plt.imshow(np.array(blend(origin, mask, out)))\n",
254
        "        plt.imshow(np.array(blend(origin, out, out)))\n",
255
        "        plt.title(f\"jaccard: {jaccard_score:.4f}, dice: {dice_score:.4f}\")\n",
256
        "        print(\".\", end=\"\")\n",
257
        "\n",
258
        "plt.savefig(images_folder / \"obtained-results.png\", bbox_inches='tight')\n",
259
        "plt.show()\n",
260
        "print()\n",
261
        "print(\"red area - predict\")\n",
262
        "print(\"green area - ground truth\")\n",
263
        "print(\"yellow area - intersection\")\n",
264
        "\n",
265
        "\n",
266
        "model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
267
        "model.to(device)\n",
268
        "model.eval()\n",
269
        "\n",
270
        "device\n",
271
        "\n",
272
        "#%%\n",
273
        "\n",
274
        "origin_filename = base_path+ f\"input/{dataset_name}/images/ID00015637202177877247924_110.jpg\"\n",
275
        "#origin_filename=base_path + \"external_samples/1.jpg\"\n",
276
        "\n",
277
        "origin = Image.open(origin_filename).convert(\"P\")\n",
278
        "origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))\n",
279
        "origin = torchvision.transforms.functional.to_tensor(origin) - 0.5\n",
280
        "\n",
281
        "with torch.no_grad():\n",
282
        "    origin = torch.stack([origin])\n",
283
        "    origin = origin.to(device)\n",
284
        "    out = model(origin)\n",
285
        "    softmax = torch.nn.functional.log_softmax(out, dim=1)\n",
286
        "    out = torch.argmax(softmax, dim=1)\n",
287
        "\n",
288
        "    origin = origin[0].to(\"cpu\")\n",
289
        "    out = out[0].to(\"cpu\")\n",
290
        "\n",
291
        "\n",
292
        "plt.figure(figsize=(20, 10))\n",
293
        "\n",
294
        "pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n",
295
        "\n",
296
        "plt.subplot(1, 2, 1)\n",
297
        "plt.title(\"origin image\")\n",
298
        "plt.imshow(np.array(pil_origin))\n",
299
        "plt.show()\n",
300
        "plt.subplot(1, 2, 2)\n",
301
        "plt.title(\"blended origin + predict\")\n",
302
        "plt.imshow(np.array(blend(origin, out)))\n",
303
        "plt.show()\n"
304
      ],
305
      "metadata": {
306
        "id": "jFpE8AwG5-Nm"
307
      },
308
      "execution_count": null,
309
      "outputs": []
310
    },
311
    {
312
      "cell_type": "code",
313
      "source": [
314
        "\n",
315
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
316
        "models_folder = Path(base_path+\"models\")\n",
317
        "\n",
318
        "\n",
319
        "\n",
320
        "test_loss = 0.0\n",
321
        "test_jaccard = 0.0\n",
322
        "test_dice = 0.0\n",
323
        "\n",
324
        "batch_size = 4\n",
325
        "\n",
326
        "if os.path.isfile(list_data_file):\n",
327
        "  with open(list_data_file, \"rb\") as f:\n",
328
        "    list_data = pickle.load(f)\n",
329
        "    origins_list=list_data[0]\n",
330
        "    masks_list=list_data[1]\n",
331
        "else:\n",
332
        "  origins_list = [f.stem  for f in origins_folder.glob(f\"*.{dataset_type}\")]\n",
333
        "  masks_list = [f.stem  for f in masks_folder.glob(f\"*.{dataset_type}\")]\n",
334
        "  with open(list_data_file, \"wb\") as f:\n",
335
        "    pickle.dump([origins_list,masks_list], f)\n",
336
        "\n",
337
        "\n",
338
        "#origins_list = [f.stem for f in origins_folder.glob(\"*.png\")]\n",
339
        "#masks_list = [f.stem for f in masks_folder.glob(\"*.png\")]\n",
340
        "\n",
341
        "\n",
342
        "origin_mask_list = [(mask_name.replace(\"_mask\", \"\"), mask_name) for mask_name in masks_list]\n",
343
        "\n",
344
        "\n",
345
        "\n",
346
        "if os.path.isfile(split_file):\n",
347
        "    with open(split_file, \"rb\") as f:\n",
348
        "        splits = pickle.load(f)\n",
349
        "else:\n",
350
        "    splits = {}\n",
351
        "    splits[\"train\"], splits[\"test\"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)\n",
352
        "    splits[\"train\"], splits[\"val\"] = train_test_split(splits[\"train\"], test_size=0.1, random_state=42)\n",
353
        "    with open(split_file, \"wb\") as f:\n",
354
        "        pickle.dump(splits, f)\n",
355
        "\n",
356
        "val_test_transforms = torchvision.transforms.Compose([\n",
357
        "    Resize((image_size, image_size)),\n",
358
        "])\n",
359
        "\n",
360
        "if dataset_name!=\"dataset\":\n",
361
        "  train_transforms = torchvision.transforms.Compose([\n",
362
        "  #Pad(200),\n",
363
        "  #Crop(300),\n",
364
        "  #val_test_transforms,\n",
365
        "  ])\n",
366
        "  datasets = {x: LungDataset2(\n",
367
        "  splits[x],\n",
368
        "  origins_folder,\n",
369
        "  masks_folder,\n",
370
        "  train_transforms if x == \"train\" else val_test_transforms,\n",
371
        "  dataset_type=dataset_type\n",
372
        "  ) for x in [\"train\", \"test\", \"val\"]}\n",
373
        "else:\n",
374
        "  train_transforms = torchvision.transforms.Compose([\n",
375
        "  Pad(200),\n",
376
        "  Crop(300),\n",
377
        "  val_test_transforms,])\n",
378
        "\n",
379
        "  datasets = {x: LungDataset(\n",
380
        "  splits[x],\n",
381
        "  origins_folder,\n",
382
        "  masks_folder,\n",
383
        "  train_transforms if x == \"train\" else val_test_transforms,\n",
384
        "  dataset_type=dataset_type\n",
385
        "  ) for x in [\"train\", \"test\", \"val\"]}\n",
386
        "\n",
387
        "\n",
388
        "def mask_to_class_rgb1(mask):\n",
389
        "        #print('----mask->rgb----')\n",
390
        "  mask = torch.from_numpy(np.array(mask))\n",
391
        "  mask = torch.squeeze(mask)  # remove 1\n",
392
        "\n",
393
        "  class_mask = mask\n",
394
        "\n",
395
        "  class_mask = class_mask.permute(2, 0, 1).contiguous()\n",
396
        "  h, w = class_mask.shape[1], class_mask.shape[2]\n",
397
        "  mask_out = torch.zeros((h, w))\n",
398
        "\n",
399
        "  threshold=200\n",
400
        "  for i in range(0,3):\n",
401
        "    class_mask[i][class_mask[i] < threshold] = 0\n",
402
        "\n",
403
        "  for i in range(2, 3):\n",
404
        "    mask_out[class_mask[i] >= threshold]=1\n",
405
        "  return mask_out\n",
406
        "\n",
407
        "\n",
408
        "def mask_to_class_rgb(mask):\n",
409
        "        #print('----mask->rgb----')\n",
410
        "  mask = torch.from_numpy(np.array(mask))\n",
411
        "  mask = torch.squeeze(mask)  # remove 1\n",
412
        "\n",
413
        "  class_mask = mask\n",
414
        "\n",
415
        "  class_mask = class_mask.permute(2, 0, 1).contiguous()\n",
416
        "  h, w = class_mask.shape[1], class_mask.shape[2]\n",
417
        "  mask_out = torch.zeros((h, w))\n",
418
        "\n",
419
        "  threshold=200\n",
420
        "  for i in range(0,3):\n",
421
        "    class_mask[i][class_mask[i] < threshold] = 0\n",
422
        "\n",
423
        "  for i in range(2, 3):\n",
424
        "    mask_out[class_mask[i] >= threshold]=1\n",
425
        "  return mask_out\n",
426
        "\n",
427
        "def getitem2(path):\n",
428
        "  mask = Image.open(path)\n",
429
        "  mask = mask_to_class_rgb(mask)\n",
430
        "  mask=mask.long()\n",
431
        "  #mask = (torch.tensor(mask) > 128).long()\n",
432
        "  return mask\n",
433
        "\n",
434
        "def getitem1(path):\n",
435
        "  mask = Image.open(path)\n",
436
        "  mask = mask.resize((image_size,image_size))\n",
437
        "  mask = np.array(mask)\n",
438
        "  mask = (torch.tensor(mask) > 128).long()\n",
439
        "  return mask\n",
440
        "\n",
441
        "\n",
442
        "idx=1\n",
443
        "phase = \"test\"\n",
444
        "fig = plt.figure(figsize=(20, 10))\n",
445
        "input=0\n",
446
        "if dataset_name!=\"dataset\":\n",
447
        "  samples=[\"ID00015637202177877247924_110.jpg\",\n",
448
        "           \"ID00009637202177434476278_173.jpg\",\n",
449
        "           \"ID00009637202177434476278_316.jpg\",\n",
450
        "           \"ID00009637202177434476278_204.jpg\",]\n",
451
        "  masks = [mask_name.replace(\"_\", \"_mask_\").replace(\"images\", \"masks\") for mask_name in samples]\n",
452
        "\n",
453
        "else:\n",
454
        "  samples=[\"CHNCXR_0060_0.png\",\n",
455
        "           \"CHNCXR_0074_0.png\",\n",
456
        "           \"CHNCXR_0129_0.png\",\n",
457
        "           \"CHNCXR_0167_0.png\",]\n",
458
        "  masks = [mask_name.replace(\"_0.png\", \"_0_mask.png\").replace(\"images\", \"masks\") for mask_name in samples]\n",
459
        "\n",
460
        "\n",
461
        "samples=[base_path + f\"input/{dataset_name}/images/\"+ sample_name for sample_name in samples]\n",
462
        "masks=[base_path + f\"input/{dataset_name}/masks/\"+ mask_name for mask_name in masks]\n",
463
        "models=[\"ResNetDUCHDC\",\"OueNetNew3\",\"NestedUNet\",\"ResNetDUC\",\"FCN_GCN\",\"SegNet2\",\"UNet\"]\n",
464
        "\n",
465
        "for input in range(0,len(samples)) :\n",
466
        "  for m in range(0,len(models)):\n",
467
        "    origin_filename = samples[input]\n",
468
        "    origin = Image.open(origin_filename).convert(\"P\")\n",
469
        "    origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))\n",
470
        "    origin = torchvision.transforms.functional.to_tensor(origin) - 0.5\n",
471
        "    if dataset_name!=\"dataset\":\n",
472
        "      mask= getitem2(masks[input])\n",
473
        "    else:\n",
474
        "      mask= getitem1(masks[input])\n",
475
        "    version=models[m]\n",
476
        "    if dataset_name!=\"dataset\":\n",
477
        "      version=version+\"_\"+dataset_name\n",
478
        "    model = selectModel(models[m])\n",
479
        "    model_name=\"ournet_\"+version+\".pt\"\n",
480
        "    model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
481
        "    model.to(device)\n",
482
        "    with torch.no_grad():\n",
483
        "        origin = torch.stack([origin])\n",
484
        "        origin = origin.to(device)\n",
485
        "        out = model(origin)\n",
486
        "        softmax = torch.nn.functional.log_softmax(out, dim=1)\n",
487
        "        out = torch.argmax(softmax, dim=1)\n",
488
        "\n",
489
        "        origin = origin[0].to(\"cpu\")\n",
490
        "        out = out[0].to(\"cpu\")\n",
491
        "\n",
492
        "    pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n",
493
        "    plt.subplots_adjust(hspace=0)\n",
494
        "    if m==0:\n",
495
        "      ax=fig.add_subplot(len(samples), len(models)+2,idx)\n",
496
        "      ax.set_axis_off()\n",
497
        "      #plt.title(\"origin image\")\n",
498
        "      plt.imshow(np.array(pil_origin))\n",
499
        "      idx=idx+1\n",
500
        "      ax=fig.add_subplot(len(samples), len(models)+2,idx)\n",
501
        "      ax.set_axis_off()\n",
502
        "      plt.imshow(np.array(blend(origin, mask,amount=0.4)))\n",
503
        "      idx=idx+1\n",
504
        "    ax=fig.add_subplot(len(samples), len(models)+2,idx)\n",
505
        "    ax.set_axis_off()\n",
506
        "    #plt.title(\"blended origin + predict\")\n",
507
        "    plt.imshow(np.array(blend(origin, out,amount=0.5)))\n",
508
        "    #plt.savefig(images_folder / f\"results/{version} {input}\", bbox_inches='tight')\n",
509
        "    idx=idx+1\n",
510
        "\n",
511
        "plt.show()\n"
512
      ],
513
      "metadata": {
514
        "id": "ksO_-uVR6V4F"
515
      },
516
      "execution_count": null,
517
      "outputs": []
518
    }
519
  ]
520
}