Diff of /Training.ipynb [000000] .. [2507a0]

b/Training.ipynb
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fqPJUT863Ps2"
      },
      "outputs": [],
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "#!unzip '/content/drive/MyDrive/Batoul_Code/input/CT/images.zip' -d '/content/drive/MyDrive/Batoul_Code/input/CT'"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torchvision\n",
        "import os\n",
        "import glob\n",
        "import time\n",
        "import pickle\n",
        "import sys\n",
        "sys.path.append('/content/drive/MyDrive/Batoul_Code/')\n",
        "sys.path.append('/content/drive/MyDrive/Batoul_Code/src')\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from pathlib import Path\n",
        "from PIL import Image\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "from data import LungDataset, blend, Pad, Crop, Resize\n",
        "from OurModel import CxlNet\n",
        "from metrics import jaccard, dice"
      ],
      "metadata": {
        "id": "N-IZfgid3Xwc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "in_channels = 1\n",
        "out_channels = 2\n",
        "batch_norm = True\n",
        "upscale_mode = \"bilinear\"\n",
        "image_size = 512\n",
69
        "def selectModel():\n",
70
        "    return CxlNet(\n",
71
        "            in_channels=in_channels,\n",
72
        "            out_channels=out_channels,\n",
73
        "            batch_norm=batch_norm,\n",
74
        "            upscale_mode=upscale_mode,\n",
75
        "            image_size=image_size)"
76
      ],
77
      "metadata": {
78
        "id": "J349vBjr31Ir"
79
      },
80
      "execution_count": null,
81
      "outputs": []
82
    },
83
    {
84
      "cell_type": "code",
85
      "source": [
86
        "dataset_name=\"dataset\"\n",
87
        "dataset_type=\"png\"\n",
88
        "split_file = \"/content/drive/MyDrive/Batoul_Code/splits.pk\"\n",
89
        "version=\"CxlNet\"\n",
90
        "approach=\"contour\"\n",
91
        "model = selectModel()\n",
92
        "\n",
93
        "base_path=\"/content/drive/MyDrive/Batoul_Code/\"\n",
94
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
95
        "device\n",
96
        "\n",
97
        "data_folder = Path(base_path+\"input\", base_path+\"input/\"+dataset_name)\n",
98
        "origins_folder = data_folder / \"images\"\n",
99
        "masks_folder = data_folder / \"masks\"\n",
100
        "masks_contour_folder = data_folder / \"masks_contour\"\n",
101
        "models_folder = Path(base_path+\"models\")\n",
102
        "images_folder = Path(base_path+\"images\")\n",
103
        "\n",
104
        "\n"
105
      ],
106
      "metadata": {
107
        "id": "MhOSosnw4Q96"
108
      },
109
      "execution_count": null,
110
      "outputs": []
111
    },
112
    {
113
      "cell_type": "code",
114
      "source": [
115
        "#@title\n",
116
        "batch_size = 4\n",
117
        "torch.cuda.empty_cache()\n",
118
        "origins_list = [f.stem  for f in origins_folder.glob(f\"*.{dataset_type}\")]\n",
119
        "#masks_list = [f.stem  for f in masks_folder.glob(f\"*.{dataset_type}\")]\n",
120
        "masks_list = [f.stem  for f in masks_contour_folder.glob(f\"*.{dataset_type}\")]\n",
121
        "\n",
122
        "\n",
123
        "\n",
124
        "\n",
125
        "origin_mask_list = [(mask_name.replace(\"_mask\", \"\"), mask_name) for mask_name in masks_list]\n",
126
        "\n",
127
        "\n",
128
        "if os.path.isfile(split_file):\n",
129
        "    with open(split_file, \"rb\") as f:\n",
130
        "        splits = pickle.load(f)\n",
131
        "else:\n",
132
        "    splits = {}\n",
133
        "    splits[\"train\"], splits[\"test\"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)\n",
134
        "    splits[\"train\"], splits[\"val\"] = train_test_split(splits[\"train\"], test_size=0.1, random_state=42)\n",
135
        "    with open(split_file, \"wb\") as f:\n",
136
        "        pickle.dump(splits, f)\n",
        "\n",
        "val_test_transforms = torchvision.transforms.Compose([\n",
        "    Resize((image_size, image_size)),\n",
        "])\n",
        "\n",
        "train_transforms = torchvision.transforms.Compose([\n",
        "    Pad(200),\n",
        "    Crop(300),\n",
        "    val_test_transforms,\n",
        "])\n",
        "\n",
        "datasets = {x: LungDataset(\n",
        "    splits[x],\n",
        "    origins_folder,\n",
        "    #masks_folder,\n",
        "    masks_contour_folder,\n",
        "    train_transforms if x == \"train\" else val_test_transforms,\n",
        "    dataset_type=dataset_type\n",
        ") for x in [\"train\", \"test\", \"val\"]}\n",
        "\n",
        "dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=batch_size)\n",
        "               for x in [\"train\", \"test\", \"val\"]}\n",
        "\n",
        "print(len(dataloaders['train']))\n",
        "\n",
        "idx = 0\n",
        "phase = \"train\"\n",
        "\n",
        "plt.figure(figsize=(20, 20))\n",
        "origin, mask = datasets[phase][idx]\n",
        "\n",
        "pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n",
        "print(origin.size())\n",
        "print(mask.size())\n",
        "pil_origin.save(\"1.png\")\n",
        "\n",
        "pil_mask = torchvision.transforms.functional.to_pil_image(mask.float())\n",
        "pil_mask.save(\"2.png\")\n",
        "plt.subplot(1, 3, 1)\n",
        "plt.title(\"origin image\")\n",
        "plt.imshow(np.array(pil_origin))\n",
        "\n",
        "plt.subplot(1, 3, 2)\n",
        "plt.title(\"manually labeled mask\")\n",
        "plt.imshow(np.array(pil_mask))\n",
        "\n",
        "plt.subplot(1, 3, 3)\n",
        "plt.title(\"blended origin + mask\")\n",
        "plt.imshow(np.array(blend(origin, mask)));\n",
        "\n",
        "plt.savefig(images_folder / \"data-example.png\", bbox_inches='tight')\n",
        "plt.show()\n",
        "train = True\n",
        "model_name = \"ournet_\" + version + \".pt\"\n",
        "if train:\n",
        "\n",
        "    if os.path.isfile(models_folder / model_name):\n",
        "        model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
        "        print(\"load_state_dict\")\n",
        "\n",
        "    model = model.to(device)\n",
        "    # optimizer = torch.optim.SGD(unet.parameters(), lr=0.0005, momentum=0.9)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)\n",
        "\n",
        "    train_log_filename = base_path + \"train-log-\" + version + \".txt\"\n",
        "    epochs = 50\n",
        "    best_val_loss = np.inf\n",
        "\n",
        "    hist = []\n",
        "\n",
        "    for e in range(epochs):\n",
        "        start_t = time.time()\n",
        "\n",
        "        print(\"Epoch \" + str(e))\n",
        "        model.train()\n",
        "\n",
        "        train_loss = 0.0\n",
        "\n",
        "        for origins, masks in dataloaders[\"train\"]:\n",
        "            num = origins.size(0)\n",
        "\n",
        "            origins = origins.to(device)\n",
        "            #print(masks.size())\n",
        "            #if dataset_name != \"dataset\":\n",
        "              #masks = masks.permute((0,3,1, 2))\n",
        "              #masks=masks[:,0,:,:]\n",
        "              #print(masks.size())\n",
        "\n",
        "            masks = masks.to(device)\n",
        "            optimizer.zero_grad()\n",
        "            outs = model(origins)\n",
        "            softmax = torch.nn.functional.log_softmax(outs, dim=1)\n",
        "            loss = torch.nn.functional.nll_loss(softmax, masks)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "            train_loss += loss.item() * num\n",
        "            print(\".\", end=\"\")\n",
        "\n",
        "        train_loss = train_loss / len(datasets['train'])\n",
        "        print()\n",
        "\n",
        "        print(\"validation phase\")\n",
        "        model.eval()\n",
        "        val_loss = 0.0\n",
        "        val_jaccard = 0.0\n",
        "        val_dice = 0.0\n",
        "\n",
        "        for origins, masks in dataloaders[\"val\"]:\n",
        "            num = origins.size(0)\n",
        "            origins = origins.to(device)\n",
        "            masks = masks.to(device)\n",
        "\n",
        "            with torch.no_grad():\n",
        "                outs = model(origins)\n",
        "                softmax = torch.nn.functional.log_softmax(outs, dim=1)\n",
        "                val_loss += torch.nn.functional.nll_loss(softmax, masks).item() * num\n",
        "\n",
        "                outs = torch.argmax(softmax, dim=1)\n",
        "                outs = outs.float()\n",
        "                masks = masks.float()\n",
        "                val_jaccard += jaccard(masks, outs).item() * num\n",
        "                val_dice += dice(masks, outs).item() * num\n",
        "\n",
        "            print(\".\", end=\"\")\n",
        "        val_loss = val_loss / len(datasets[\"val\"])\n",
        "        val_jaccard = val_jaccard / len(datasets[\"val\"])\n",
        "        val_dice = val_dice / len(datasets[\"val\"])\n",
        "        print()\n",
        "\n",
        "        end_t = time.time()\n",
        "        elapsed_t = end_t - start_t\n",
        "\n",
        "        with open(train_log_filename, \"a\") as train_log_file:\n",
        "            report = f\"epoch: {e + 1}/{epochs}, time: {elapsed_t}, train loss: {train_loss}, \\n\" \\\n",
        "                     + f\"val loss: {val_loss}, val jaccard: {val_jaccard}, val dice: {val_dice}\"\n",
        "\n",
        "            hist.append({\n",
        "                \"time\": elapsed_t,\n",
        "                \"train_loss\": train_loss,\n",
        "                \"val_loss\": val_loss,\n",
        "                \"val_jaccard\": val_jaccard,\n",
        "                \"val_dice\": val_dice,\n",
        "            })\n",
        "\n",
        "            print(report)\n",
        "            train_log_file.write(report + \"\\n\")\n",
        "\n",
        "            if val_loss < best_val_loss:\n",
        "                best_val_loss = val_loss\n",
        "                torch.save(model.state_dict(), models_folder / model_name)\n",
        "                print(\"model saved\")\n",
        "                train_log_file.write(\"model saved\\n\")\n",
        "            print()\n",
        "\n",
        "        #if val_jaccard >= 0.9179:\n",
        "            #break\n",
        "    plt.figure(figsize=(15, 7))\n",
        "    train_loss_hist = [h[\"train_loss\"] for h in hist]\n",
        "    plt.plot(range(len(hist)), train_loss_hist, \"b\", label=\"train loss\")\n",
        "\n",
        "    val_loss_hist = [h[\"val_loss\"] for h in hist]\n",
        "    plt.plot(range(len(hist)), val_loss_hist, \"r\", label=\"val loss\")\n",
        "\n",
        "    val_dice_hist = [h[\"val_dice\"] for h in hist]\n",
        "    plt.plot(range(len(hist)), val_dice_hist, \"g\", label=\"val dice\")\n",
        "\n",
        "    val_jaccard_hist = [h[\"val_jaccard\"] for h in hist]\n",
        "    plt.plot(range(len(hist)), val_jaccard_hist, \"y\", label=\"val jaccard\")\n",
        "\n",
        "    plt.legend()\n",
        "    plt.xlabel(\"epoch\")\n",
        "    plt.savefig(images_folder / model_name.replace(\".pt\", \"-train-hist.png\"))\n",
        "\n",
        "    time_hist = [h[\"time\"] for h in hist]\n",
        "    overall_time = sum(time_hist) // 60\n",
        "    mean_epoch_time = sum(time_hist) / len(hist)\n",
        "    print(f\"epochs: {len(hist)}, overall time: {overall_time}m, mean epoch time: {mean_epoch_time}s\")\n",
        "\n",
        "    torch.cuda.empty_cache()\n",
        "else:\n",
        "\n",
        "    model_name = \"ournet_\" + version + \".pt\"\n",
        "    model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device(\"cpu\")))\n",
        "    model.to(device)\n",
        "    model.eval()\n",
        "\n",
        "    test_loss = 0.0\n",
        "    test_jaccard = 0.0\n",
        "    test_dice = 0.0\n",
        "\n",
        "    for origins, masks in dataloaders[\"test\"]:\n",
        "        num = origins.size(0)\n",
        "        origins = origins.to(device)\n",
        "        masks = masks.to(device)\n",
        "\n",
        "        with torch.no_grad():\n",
        "            outs = model(origins)\n",
        "            softmax = torch.nn.functional.log_softmax(outs, dim=1)\n",
        "            test_loss += torch.nn.functional.nll_loss(softmax, masks).item() * num\n",
        "\n",
        "            outs = torch.argmax(softmax, dim=1)\n",
        "            outs = outs.float()\n",
        "            masks = masks.float()\n",
        "            test_jaccard += jaccard(masks, outs).item() * num\n",
        "            test_dice += dice(masks, outs).item() * num\n",
        "        print(\".\", end=\"\")\n",
        "\n",
        "    test_loss = test_loss / len(datasets[\"test\"])\n",
        "    test_jaccard = test_jaccard / len(datasets[\"test\"])\n",
        "    test_dice = test_dice / len(datasets[\"test\"])\n",
        "\n",
        "    print()\n",
        "    print(f\"avg test loss: {test_loss}\")\n",
        "    print(f\"avg test jaccard: {test_jaccard}\")\n",
        "    print(f\"avg test dice: {test_dice}\")\n",
        "\n",
        "    num_samples = 9\n",
        "    phase = \"test\"\n",
        "\n",
        "    subset = torch.utils.data.Subset(\n",
        "        datasets[phase],\n",
        "        np.random.randint(0, len(datasets[phase]), num_samples)\n",
        "    )\n",
        "    random_samples_loader = torch.utils.data.DataLoader(subset, batch_size=1)\n",
        "    plt.figure(figsize=(20, 25))\n",
        "\n",
        "    for idx, (origin, mask) in enumerate(random_samples_loader):\n",
        "        plt.subplot((num_samples // 3) + 1, 3, idx + 1)\n",
        "\n",
        "        origin = origin.to(device)\n",
        "        mask = mask.to(device)\n",
        "\n",
        "        with torch.no_grad():\n",
        "            out = model(origin)\n",
        "            softmax = torch.nn.functional.log_softmax(out, dim=1)\n",
        "            out = torch.argmax(softmax, dim=1)\n",
        "\n",
        "            jaccard_score = jaccard(mask.float(), out.float()).item()\n",
        "            dice_score = dice(mask.float(), out.float()).item()\n",
        "\n",
        "            origin = origin[0].to(\"cpu\")\n",
        "            out = out[0].to(\"cpu\")\n",
        "            mask = mask[0].to(\"cpu\")\n",
        "\n",
        "            plt.imshow(np.array(blend(origin, mask, out)))\n",
        "            plt.title(f\"jaccard: {jaccard_score:.4f}, dice: {dice_score:.4f}\")\n",
        "            print(\".\", end=\"\")\n",
        "    # Save the grid before showing it; plt.show() inside the loop would\n",
        "    # clear the figure and leave an empty image on disk.\n",
        "    plt.savefig(images_folder / \"obtained-results.png\", bbox_inches='tight')\n",
        "    plt.show()\n",
        "    print()\n",
        "    print(\"red area - predict\")\n",
        "    print(\"green area - ground truth\")\n",
        "    print(\"yellow area - intersection\")\n",
        "\n",
        "    # %%\n",
        "\n",
        "    origin_filename = base_path + \"input/dataset/images/CHNCXR_0042_0.png\"\n",
        "\n",
        "    origin = Image.open(origin_filename).convert(\"P\")\n",
        "    origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))\n",
        "    origin = torchvision.transforms.functional.to_tensor(origin) - 0.5\n",
        "\n",
        "    with torch.no_grad():\n",
        "        origin = torch.stack([origin])\n",
        "        origin = origin.to(device)\n",
        "        out = model(origin)\n",
        "        softmax = torch.nn.functional.log_softmax(out, dim=1)\n",
        "        out = torch.argmax(softmax, dim=1)\n",
        "\n",
        "        origin = origin[0].to(\"cpu\")\n",
        "        out = out[0].to(\"cpu\")\n",
        "\n",
        "    plt.figure(figsize=(20, 10))\n",
        "\n",
        "    pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert(\"RGB\")\n",
        "\n",
        "    plt.subplot(1, 2, 1)\n",
        "    plt.title(\"origin image\")\n",
        "    plt.imshow(np.array(pil_origin))\n",
        "\n",
        "    plt.subplot(1, 2, 2)\n",
        "    plt.title(\"blended origin + predict\")\n",
        "    plt.imshow(np.array(blend(origin, out)))\n",
        "    plt.show()\n"
      ],
      "metadata": {
        "id": "sGk2UEtw4JLr"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}