Diff of /ndv/model.ipynb [000000] .. [64faee]

Switch to unified view

a b/ndv/model.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "markdown",
5
   "metadata": {
6
    "colab_type": "text",
7
    "id": "b6KgSfpahUZ1"
8
   },
9
   "source": [
10
    "# Setup"
11
   ]
12
  },
13
  {
14
   "cell_type": "code",
15
   "execution_count": 1,
16
   "metadata": {
17
    "colab": {},
18
    "colab_type": "code",
19
    "id": "G75A7ey4hUZ3"
20
   },
21
   "outputs": [],
22
   "source": [
23
    "import torch\n",
24
    "from fastai.callbacks import *\n",
25
    "from fastai.vision import *"
26
   ]
27
  },
28
  {
29
   "cell_type": "markdown",
30
   "metadata": {
31
    "colab_type": "text",
32
    "id": "C77ffh0whUZ7"
33
   },
34
   "source": [
35
    "## GPU "
36
   ]
37
  },
38
  {
39
   "cell_type": "code",
40
   "execution_count": 0,
41
   "metadata": {
42
    "colab": {
43
     "base_uri": "https://localhost:8080/",
44
     "height": 34
45
    },
46
    "colab_type": "code",
47
    "id": "syfiMzschUZ8",
48
    "outputId": "a555f3d9-91cb-43e1-8302-a037fbb5efe9"
49
   },
50
   "outputs": [
51
    {
52
     "data": {
53
      "text/plain": [
54
       "True"
55
      ]
56
     },
57
     "execution_count": 2,
58
     "metadata": {
59
      "tags": []
60
     },
61
     "output_type": "execute_result"
62
    }
63
   ],
64
   "source": [
65
    "# Check GPU availablity\n",
66
    "torch.cuda.is_available()"
67
   ]
68
  },
69
  {
70
   "cell_type": "code",
71
   "execution_count": 0,
72
   "metadata": {
73
    "colab": {
74
     "base_uri": "https://localhost:8080/",
75
     "height": 34
76
    },
77
    "colab_type": "code",
78
    "id": "TOznfYKmhUaB",
79
    "outputId": "6b060a7f-72ae-43c1-c0b0-e521b770ffad"
80
   },
81
   "outputs": [
82
    {
83
     "data": {
84
      "text/plain": [
85
       "1"
86
      ]
87
     },
88
     "execution_count": 4,
89
     "metadata": {
90
      "tags": []
91
     },
92
     "output_type": "execute_result"
93
    }
94
   ],
95
   "source": [
96
    "# Check mounted GPU devices\n",
97
    "torch.cuda.device_count()"
98
   ]
99
  },
100
  {
101
   "cell_type": "code",
102
   "execution_count": 0,
103
   "metadata": {
104
    "colab": {
105
     "base_uri": "https://localhost:8080/",
106
     "height": 34
107
    },
108
    "colab_type": "code",
109
    "id": "BfqqvsjvhUaF",
110
    "outputId": "5afac64a-4654-41df-ed6d-6bc968cc8ef7"
111
   },
112
   "outputs": [
113
    {
114
     "data": {
115
      "text/plain": [
116
       "0"
117
      ]
118
     },
119
     "execution_count": 16,
120
     "metadata": {
121
      "tags": []
122
     },
123
     "output_type": "execute_result"
124
    }
125
   ],
126
   "source": [
127
    "# Current device you're using\n",
128
    "# * 0-indexed *\n",
129
    "torch.cuda.current_device()"
130
   ]
131
  },
132
  {
133
   "cell_type": "code",
134
   "execution_count": 2,
135
   "metadata": {
136
    "code_folding": [],
137
    "colab": {
138
     "base_uri": "https://localhost:8080/",
139
     "height": 289
140
    },
141
    "colab_type": "code",
142
    "id": "hIcrVEJWhUaK",
143
    "outputId": "52eb889d-49ef-433b-9528-0ac9b514dc76"
144
   },
145
   "outputs": [
146
    {
147
     "name": "stdout",
148
     "output_type": "stream",
149
     "text": [
150
      "Wed Jun 10 22:31:28 2020       \n",
151
      "+-----------------------------------------------------------------------------+\n",
152
      "| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |\n",
153
      "|-------------------------------+----------------------+----------------------+\n",
154
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
155
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
156
      "|===============================+======================+======================|\n",
157
      "|   0  Tesla V100-PCIE...  On   | 00000000:00:06.0 Off |                    0 |\n",
158
      "| N/A   36C    P0    41W / 250W |   4990MiB / 32480MiB |      0%      Default |\n",
159
      "+-------------------------------+----------------------+----------------------+\n",
160
      "|   1  Tesla V100-PCIE...  On   | 00000000:00:07.0 Off |                    0 |\n",
161
      "| N/A   37C    P0    43W / 250W |    936MiB / 32480MiB |      0%      Default |\n",
162
      "+-------------------------------+----------------------+----------------------+\n",
163
      "|   2  Tesla V100-PCIE...  On   | 00000000:00:08.0 Off |                    0 |\n",
164
      "| N/A   34C    P0    39W / 250W |    936MiB / 32480MiB |      0%      Default |\n",
165
      "+-------------------------------+----------------------+----------------------+\n",
166
      "                                                                               \n",
167
      "+-----------------------------------------------------------------------------+\n",
168
      "| Processes:                                                       GPU Memory |\n",
169
      "|  GPU       PID   Type   Process name                             Usage      |\n",
170
      "|=============================================================================|\n",
171
      "|    0      5744      C   ...u/anaconda3/envs/pytorch_p36/bin/python   941MiB |\n",
172
      "|    0      8881      C   ...u/anaconda3/envs/pytorch_p36/bin/python   941MiB |\n",
173
      "+-----------------------------------------------------------------------------+\n"
174
     ]
175
    }
176
   ],
177
   "source": [
178
    "# Check workloads of your GPU(s)\n",
179
    "!nvidia-smi"
180
   ]
181
  },
182
  {
183
   "cell_type": "code",
184
   "execution_count": 4,
185
   "metadata": {
186
    "colab": {},
187
    "colab_type": "code",
188
    "id": "2BmajeAXhUaO"
189
   },
190
   "outputs": [],
191
   "source": [
192
    "# Reset your current device (if necessary)\n",
193
    "torch.cuda.set_device(1)"
194
   ]
195
  },
196
  {
197
   "cell_type": "code",
198
   "execution_count": 3,
199
   "metadata": {
200
    "colab": {},
201
    "colab_type": "code",
202
    "id": "tOVolw1UhUaS",
203
    "outputId": "eab8d203-31ee-41b2-822c-3a0abb058fe5"
204
   },
205
   "outputs": [
206
    {
207
     "data": {
208
      "text/plain": [
209
       "1"
210
      ]
211
     },
212
     "execution_count": 3,
213
     "metadata": {},
214
     "output_type": "execute_result"
215
    }
216
   ],
217
   "source": [
218
    "# Check change's been made\n",
219
    "torch.cuda.current_device()\n"
220
   ]
221
  },
222
  {
223
   "cell_type": "code",
224
   "execution_count": 5,
225
   "metadata": {
226
    "colab": {
227
     "base_uri": "https://localhost:8080/",
228
     "height": 34
229
    },
230
    "colab_type": "code",
231
    "id": "wXALsotzhUaY",
232
    "outputId": "c2118aaa-bf6c-4d23-db4b-bfc0667ed5a2"
233
   },
234
   "outputs": [
235
    {
236
     "data": {
237
      "text/plain": [
238
       "'Tesla V100-PCIE-32GB'"
239
      ]
240
     },
241
     "execution_count": 5,
242
     "metadata": {},
243
     "output_type": "execute_result"
244
    }
245
   ],
246
   "source": [
247
    "# Check name of your device\n",
248
    "torch.cuda.get_device_name()"
249
   ]
250
  },
251
  {
252
   "cell_type": "markdown",
253
   "metadata": {
254
    "colab_type": "text",
255
    "id": "CSmixVUShUac"
256
   },
257
   "source": [
258
    "# Model Prototyping"
259
   ]
260
  },
261
  {
262
   "cell_type": "markdown",
263
   "metadata": {
264
    "colab_type": "text",
265
    "heading_collapsed": true,
266
    "id": "btBiurUPxx7t"
267
   },
268
   "source": [
269
    "## Helpers "
270
   ]
271
  },
272
  {
273
   "cell_type": "code",
274
   "execution_count": 5,
275
   "metadata": {
276
    "code_folding": [
277
     0
278
    ],
279
    "colab": {},
280
    "colab_type": "code",
281
    "hidden": true,
282
    "id": "jIljYlKohUad"
283
   },
284
   "outputs": [],
285
   "source": [
286
    "def conv_block(c_in, c_out, ks, num_groups=None, **conv_kwargs):\n",
287
    "    \"A sequence of modules composed of Group Norm, ReLU and Conv3d in order\"\n",
288
    "    if not num_groups : num_groups = int(c_in/2) if c_in%2 == 0 else None\n",
289
    "    return nn.Sequential(nn.GroupNorm(num_groups, c_in),\n",
290
    "                         nn.ReLU(),\n",
291
    "                         nn.Conv3d(c_in, c_out, ks, **conv_kwargs))"
292
   ]
293
  },
294
  {
295
   "cell_type": "code",
296
   "execution_count": 6,
297
   "metadata": {
298
    "code_folding": [
299
     0
300
    ],
301
    "colab": {},
302
    "colab_type": "code",
303
    "hidden": true,
304
    "id": "fLbUaZT7hUag"
305
   },
306
   "outputs": [],
307
   "source": [
308
    "def reslike_block(nf, num_groups=None, bottle_neck:bool=False, **conv_kwargs):\n",
309
    "    \"A ResNet-like block with the GroupNorm normalization providing optional bottle-neck functionality\"\n",
310
    "    nf_inner = nf / 2 if bottle_neck else nf\n",
311
    "    return SequentialEx(conv_block(num_groups=num_groups, c_in=nf, c_out=nf_inner, ks=3, stride=1, padding=1, **conv_kwargs),\n",
312
    "                        conv_block(num_groups=num_groups, c_in=nf_inner, c_out=nf, ks=3, stride=1, padding=1, **conv_kwargs),\n",
313
    "                        MergeLayer())"
314
   ]
315
  },
316
  {
317
   "cell_type": "code",
318
   "execution_count": 7,
319
   "metadata": {
320
    "code_folding": [
321
     0
322
    ],
323
    "colab": {},
324
    "colab_type": "code",
325
    "hidden": true,
326
    "id": "BKNJoGdsgbk2"
327
   },
328
   "outputs": [],
329
   "source": [
330
    "def upsize(c_in, c_out, ks=1, scale=2):\n",
331
    "    \"Reduce the number of features by 2 using Conv with kernel size 1x1x1 and double the spatial dimension using 3D trilinear upsampling\"\n",
332
    "    return nn.Sequential(nn.Conv3d(c_in, c_out, ks),\n",
333
    "                       nn.Upsample(scale_factor=scale, mode='trilinear'))"
334
   ]
335
  },
336
  {
337
   "cell_type": "code",
338
   "execution_count": 8,
339
   "metadata": {
340
    "code_folding": [
341
     0
342
    ],
343
    "colab": {},
344
    "colab_type": "code",
345
    "hidden": true,
346
    "id": "6CjiBTnT8LFF"
347
   },
348
   "outputs": [],
349
   "source": [
350
    "def hook_debug(module, input, output):\n",
351
    "    \"\"\"\n",
352
    "    Print out what's been hooked usually for debugging purpose\n",
353
    "    ----------------------------------------------------------\n",
354
    "       Example:\n",
355
    "       Hooks(ms, hook_debug, is_forward=True, detach=False)\n",
356
    "    \n",
357
    "    \"\"\"\n",
358
    "    print('Hooking ' + module.__class__.__name__)\n",
359
    "    print('output size:', output.data.size())\n",
360
    "    return output"
361
   ]
362
  },
363
  {
364
   "cell_type": "markdown",
365
   "metadata": {
366
    "colab_type": "text",
367
    "id": "MY131WWbx3nN"
368
   },
369
   "source": [
370
    "## Encoder Part"
371
   ]
372
  },
373
  {
374
   "cell_type": "code",
375
   "execution_count": 9,
376
   "metadata": {
377
    "code_folding": [
378
     0
379
    ],
380
    "colab": {},
381
    "colab_type": "code",
382
    "id": "f_8ynHTavdvL"
383
   },
384
   "outputs": [],
385
   "source": [
386
    "class Encoder(nn.Module):\n",
387
    "    \"Encoder part\"\n",
388
    "    def __init__(self):\n",
389
    "        super().__init__()\n",
390
    "        self.conv1 = nn.Conv3d(4, 32, 3, stride=1, padding=1)         \n",
391
    "        self.res_block1 = reslike_block(32, num_groups=8)\n",
392
    "        self.conv_block1 = conv_block(32, 64, 3, num_groups=8, stride=2, padding=1)\n",
393
    "        self.res_block2 = reslike_block(64, num_groups=8)\n",
394
    "        self.conv_block2 = conv_block(64, 64, 3, num_groups=8, stride=1, padding=1)\n",
395
    "        self.res_block3 = reslike_block(64, num_groups=8)\n",
396
    "        self.conv_block3 = conv_block(64, 128, 3, num_groups=8, stride=2, padding=1)\n",
397
    "        self.res_block4 = reslike_block(128, num_groups=8)\n",
398
    "        self.conv_block4 = conv_block(128, 128, 3, num_groups=8, stride=1, padding=1)\n",
399
    "        self.res_block5 = reslike_block(128, num_groups=8)\n",
400
    "        self.conv_block5 = conv_block(128, 256, 3, num_groups=8, stride=2, padding=1)\n",
401
    "        self.res_block6 = reslike_block(256, num_groups=8)\n",
402
    "        self.conv_block6 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
403
    "        self.res_block7 = reslike_block(256, num_groups=8)\n",
404
    "        self.conv_block7 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
405
    "        self.res_block8 = reslike_block(256, num_groups=8)\n",
406
    "        self.conv_block8 = conv_block(256, 256, 3, num_groups=8, stride=1, padding=1)\n",
407
    "        self.res_block9 = reslike_block(256, num_groups=8)\n",
408
    "    \n",
409
    "    def forward(self, x):\n",
410
    "        x = self.conv1(x)                                           # Output size: (1, 32, 160, 192, 128)\n",
411
    "        x = self.res_block1(x)                                      # Output size: (1, 32, 160, 192, 128)\n",
412
    "        x = self.conv_block1(x)                                     # Output size: (1, 64, 80, 96, 64)\n",
413
    "        x = self.res_block2(x)                                      # Output size: (1, 64, 80, 96, 64)\n",
414
    "        x = self.conv_block2(x)                                     # Output size: (1, 64, 80, 96, 64)\n",
415
    "        x = self.res_block3(x)                                      # Output size: (1, 64, 80, 96, 64)\n",
416
    "        x = self.conv_block3(x)                                     # Output size: (1, 128, 40, 48, 32)\n",
417
    "        x = self.res_block4(x)                                      # Output size: (1, 128, 40, 48, 32)\n",
418
    "        x = self.conv_block4(x)                                     # Output size: (1, 128, 40, 48, 32)\n",
419
    "        x = self.res_block5(x)                                      # Output size: (1, 128, 40, 48, 32)\n",
420
    "        x = self.conv_block5(x)                                     # Output size: (1, 256, 20, 24, 16)\n",
421
    "        x = self.res_block6(x)                                      # Output size: (1, 256, 20, 24, 16)\n",
422
    "        x = self.conv_block6(x)                                     # Output size: (1, 256, 20, 24, 16)\n",
423
    "        x = self.res_block7(x)                                      # Output size: (1, 256, 20, 24, 16)\n",
424
    "        x = self.conv_block7(x)                                     # Output size: (1, 256, 20, 24, 16)\n",
425
    "        x = self.res_block8(x)                                      # Output size: (1, 256, 20, 24, 16)\n",
426
    "        x = self.conv_block8(x)                                     # Output size: (1, 256, 20, 24, 16)\n",
427
    "        x = self.res_block9(x)                                      # Output size: (1, 256, 20, 24, 16)\n",
428
    "        return x"
429
   ]
430
  },
431
  {
432
   "cell_type": "code",
433
   "execution_count": 23,
434
   "metadata": {
435
    "code_folding": [
436
     0
437
    ],
438
    "colab": {},
439
    "colab_type": "code",
440
    "id": "bgPmjWCq6n1F"
441
   },
442
   "outputs": [],
443
   "source": [
444
    "########## Sanity-check ############\n",
445
    "# input = torch.randn(1, 4, 160, 192, 128)\n",
446
    "# input = input.cuda()\n",
447
    "# encoder = Encoder()\n",
448
    "# encoder.cuda()\n",
449
    "# ms = [encoder.res_block1, encoder.res_block3, encoder.res_block5]\n",
450
    "# hooks = Hooks(ms, hook_debug, is_forward=True, detach=False)\n",
451
    "# output = encoder(input)"
452
   ]
453
  },
454
  {
455
   "cell_type": "markdown",
456
   "metadata": {
457
    "colab_type": "text",
458
    "id": "BKmlf1qY74Fx"
459
   },
460
   "source": [
461
    "## Decoder Part"
462
   ]
463
  },
464
  {
465
   "cell_type": "code",
466
   "execution_count": 10,
467
   "metadata": {
468
    "code_folding": [
469
     0
470
    ],
471
    "colab": {},
472
    "colab_type": "code",
473
    "id": "p9jCdQAeBTch"
474
   },
475
   "outputs": [],
476
   "source": [
477
    "class Decoder(nn.Module):\n",
478
    "    \"Decoder Part\"\n",
479
    "    def __init__(self):\n",
480
    "        super().__init__()\n",
481
    "        self.upsize1 = upsize(256, 128)\n",
482
    "        self.reslike1 = reslike_block(128, num_groups=8)\n",
483
    "        self.upsize2 = upsize(128, 64)\n",
484
    "        self.reslike2 = reslike_block(64, num_groups=8)\n",
485
    "        self.upsize3 = upsize(64, 32)\n",
486
    "        self.reslike3 = reslike_block(32, num_groups=8)\n",
487
    "        self.conv1 = nn.Conv3d(32, 3, 1) \n",
488
    "        self.sigmoid1 = torch.nn.Sigmoid()\n",
489
    "\n",
490
    "    def forward(self, x):\n",
491
    "        x = self.upsize1(x)                                         # Output size: (1, 128, 40, 48, 32)\n",
492
    "        x = x + hooks.stored[2]                                     # Output size: (1, 128, 40, 48, 32)\n",
493
    "        x = self.reslike1(x)                                        # Output size: (1, 128, 40, 48, 32)\n",
494
    "        x = self.upsize2(x)                                         # Output size: (1, 64, 80, 96, 64)\n",
495
    "        x = x + hooks.stored[1]                                     # Output size: (1, 64, 80, 96, 64)\n",
496
    "        x = self.reslike2(x)                                        # Output size: (1, 64, 80, 96, 64)\n",
497
    "        x = self.upsize3(x)                                         # Output size: (1, 32, 160, 192, 128)\n",
498
    "        x = x + hooks.stored[0]                                     # Output size: (1, 32, 160, 192, 128)\n",
499
    "        x = self.reslike3(x)                                        # Output size: (1, 32, 160, 192, 128)\n",
500
    "        x = self.conv1(x)                                           # Output size: (1, 3, 160, 192, 128)\n",
501
    "        x = self.sigmoid1(x)                                        # Output size: (1, 3, 160, 192, 128)\n",
502
    "        return x"
503
   ]
504
  },
505
  {
506
   "cell_type": "code",
507
   "execution_count": 0,
508
   "metadata": {
509
    "code_folding": [
510
     0
511
    ],
512
    "colab": {},
513
    "colab_type": "code",
514
    "id": "54LhlCx7hOt6"
515
   },
516
   "outputs": [],
517
   "source": [
518
    "############ Sanity-check ############\n",
519
    "# input = torch.randn(1, 256, 20, 24, 16)\n",
520
    "# input = input.cuda()\n",
521
    "# decoder = Decoder()\n",
522
    "# decoder.cuda()\n",
523
    "# output = decoder(input)\n",
524
    "# output.shape"
525
   ]
526
  },
527
  {
528
   "cell_type": "markdown",
529
   "metadata": {
530
    "colab_type": "text",
531
    "id": "Sq9kLEFbx8sF"
532
   },
533
   "source": [
534
    "## VAE Part"
535
   ]
536
  },
537
  {
538
   "cell_type": "code",
539
   "execution_count": 11,
540
   "metadata": {
541
    "code_folding": [],
542
    "colab": {},
543
    "colab_type": "code",
544
    "id": "KEpqknq3hUaq"
545
   },
546
   "outputs": [],
547
   "source": [
548
    "class VAEEncoder(nn.Module):\n",
549
    "    \"Variational auto-encoder encoder part\"\n",
550
    "    def __init__(self, latent_dim:int=128):\n",
551
    "        super().__init__()\n",
552
    "        self.latent_dim = latent_dim\n",
553
    "        self.conv_block = conv_block(256, 16, 3, num_groups=8, stride=2, padding=1)\n",
554
    "        self.linear1 = nn.Linear(60, 1)\n",
555
    "        \n",
556
    "        # Assumed latent variable's probability density function parameters\n",
557
    "        self.z_mean = nn.Linear(256, latent_dim)\n",
558
    "        self.z_log_var = nn.Linear(256, latent_dim)\n",
559
    "        #TODO: It should work with or without GPU\n",
560
    "        self.epsilon = torch.randn(1, latent_dim, device='cuda')\n",
561
    "        \n",
562
    "    def forward(self, x):\n",
563
    "        x = self.conv_block(x)                                   # Output size: (1, 16, 10, 12, 8)                                  \n",
564
    "        x = x.view(256, -1)                                      # Output size: (256, 60)                                       \n",
565
    "        x = self.linear1(x)                                      # Output size: (256, 1)\n",
566
    "        x = x.view(1, 256)                                       # Output size: (1, 256)   \n",
567
    "        z_mean = self.z_mean(x)                                  # Output size: (1, 128)\n",
568
    "        z_var = self.z_log_var(x).exp()                          # Output size: (1, 128)              \n",
569
    "        \n",
570
    "        return z_mean + z_var * self.epsilon                     # Output size: (1, 128)                              "
571
   ]
572
  },
573
  {
574
   "cell_type": "code",
575
   "execution_count": 11,
576
   "metadata": {
577
    "code_folding": [
578
     0
579
    ],
580
    "colab": {
581
     "base_uri": "https://localhost:8080/",
582
     "height": 34
583
    },
584
    "colab_type": "code",
585
    "id": "ll26pBm9tj7-",
586
    "outputId": "f1e9300e-8e79-4c66-8d0e-6897ce6b7f80"
587
   },
588
   "outputs": [],
589
   "source": [
590
    "############ Sanity-check ############\n",
591
    "# input = torch.randn(1, 256, 20, 24, 16)\n",
592
    "# input = input.cuda()\n",
593
    "# vae_encoder = VAEEncoder(latent_dim=128)\n",
594
    "# vae_encoder.cuda()\n",
595
    "# output = vae_encoder(output)\n",
596
    "# output.shape"
597
   ]
598
  },
599
  {
600
   "cell_type": "code",
601
   "execution_count": 12,
602
   "metadata": {
603
    "code_folding": [
604
     0
605
    ],
606
    "colab": {},
607
    "colab_type": "code",
608
    "id": "tl4tYTaXe1qw"
609
   },
610
   "outputs": [],
611
   "source": [
612
    "class VAEDecoder(nn.Module):\n",
613
    "    \"Variational auto-encoder decoder part\"\n",
614
    "    def __init__(self):\n",
615
    "        super().__init__()\n",
616
    "        self.linear1 = nn.Linear(128, 256*60)\n",
617
    "        self.relu1 = nn.ReLU()\n",
618
    "        self.upsize1 = upsize(16, 256)\n",
619
    "        self.upsize2 = upsize(256, 128)\n",
620
    "        self.reslike1 = reslike_block(128, num_groups=8)\n",
621
    "        self.upsize3 = upsize(128, 64)\n",
622
    "        self.reslike2 = reslike_block(64, num_groups=8)\n",
623
    "        self.upsize4 = upsize(64, 32)\n",
624
    "        self.reslike3 = reslike_block(32, num_groups=8)\n",
625
    "        self.conv1 = nn.Conv3d(32, 4, 1)\n",
626
    "    \n",
627
    "    def forward(self, x):\n",
628
    "        x = self.linear1(x)                                          # Output size: (1, 256*60)      \n",
629
    "        x = self.relu1(x)                                            # Output size: (1, 256*60)\n",
630
    "        x = x.view(1, 16, 10, 12, 8)                                 # Output size: (1, 16, 10, 12, 8)\n",
631
    "        x = self.upsize1(x)                                          # Output size: (1, 256, 20, 24, 16)\n",
632
    "        x = self.upsize2(x)                                          # Output size: (1, 128, 40, 48, 32)\n",
633
    "        x = self.reslike1(x)                                         # Output size: (1, 128, 40, 48, 32)\n",
634
    "        x = self.upsize3(x)                                          # Output size: (1, 64, 80, 96, 64)\n",
635
    "        x = self.reslike2(x)                                         # Output size: (1, 64, 80, 96, 64)\n",
636
    "        x = self.upsize4(x)                                          # Output size: (1, 32, 160, 192, 128)\n",
637
    "        x = self.reslike3(x)                                         # Output size: (1, 32, 160, 192, 128)\n",
638
    "        x = self.conv1(x)                                            # Output size: (1, 4, 160, 192, 128) \n",
639
    "        return x"
640
   ]
641
  },
642
  {
643
   "cell_type": "code",
644
   "execution_count": 0,
645
   "metadata": {
646
    "code_folding": [
647
     0
648
    ],
649
    "colab": {},
650
    "colab_type": "code",
651
    "id": "RrusoNDpzPOk"
652
   },
653
   "outputs": [],
654
   "source": [
655
    "############ Sanity-check ############\n",
656
    "# input = torch.randn(1, 128)\n",
657
    "# input = input.cuda()\n",
658
    "# vae_decoder = VAEDecoder()\n",
659
    "# vae_decoder.cuda()\n",
660
    "# vae_decoder(output).shape"
661
   ]
662
  },
663
  {
664
   "cell_type": "markdown",
665
   "metadata": {
666
    "colab_type": "text",
667
    "id": "dtLzCKAOEn6c"
668
   },
669
   "source": [
670
    "## AutoUNet"
671
   ]
672
  },
673
  {
674
   "cell_type": "code",
675
   "execution_count": 13,
676
   "metadata": {
677
    "code_folding": [],
678
    "colab": {},
679
    "colab_type": "code",
680
    "id": "9lhVuR2QExrp"
681
   },
682
   "outputs": [],
683
   "source": [
684
    "class AutoUNet(nn.Module):\n",
685
    "  \"3D U-Net using autoencoder regularization\"\n",
686
    "  def __init__(self):\n",
687
    "    super().__init__()\n",
688
    "    self.encoder = Encoder()\n",
689
    "    self.decoder = Decoder()\n",
690
    "    self.vencoder = VAEEncoder(latent_dim=128)\n",
691
    "    self.vdecoder = VAEDecoder()\n",
692
    "\n",
693
    "  def forward(self, input):\n",
694
    "    interm_res = self.encoder(input)\n",
695
    "    top_res = self.decoder(interm_res)                               # Output size: (1, 3, 160, 192, 128)\n",
696
    "    bottom_res = self.vdecoder(self.vencoder(interm_res))            # Output size: (1, 4, 160, 192, 128)\n",
697
    "    return top_res, bottom_res"
698
   ]
699
  },
700
  {
701
   "cell_type": "code",
702
   "execution_count": null,
703
   "metadata": {
704
    "code_folding": [],
705
    "scrolled": true
706
   },
707
   "outputs": [],
708
   "source": [
709
    "############ Sanity-check ############\n",
710
    "input = torch.randn(1, 4, 160, 192, 128)\n",
711
    "input = input.cuda()\n",
712
    "model = AutoUNet()\n",
713
    "model.cuda()\n",
714
    "\n",
715
    "ms = [model.encoder.res_block1, \n",
716
    "      model.encoder.res_block3, \n",
717
    "      model.encoder.res_block5, \n",
718
    "      model.vencoder.z_mean, \n",
719
    "      model.vencoder.z_log_var]\n",
720
    "\n",
721
    "hooks = hook_outputs(ms, detach=False, grad=False) #check: overwrite for each iteration?\n",
722
    "#hooks = Hooks(ms, hook_debug, is_forward=True, detach=False)\n",
723
    "\n",
724
    "output = model(input)"
725
   ]
726
  },
727
  {
728
   "cell_type": "markdown",
729
   "metadata": {
730
    "colab_type": "text",
731
    "id": "ZSPf7atqhOuG"
732
   },
733
   "source": [
734
    "## Custom Loss "
735
   ]
736
  },
737
  {
738
   "cell_type": "code",
739
   "execution_count": null,
740
   "metadata": {
741
    "code_folding": [],
742
    "colab": {
743
     "base_uri": "https://localhost:8080/",
744
     "height": 85
745
    },
746
    "colab_type": "code",
747
    "id": "OQ4vfaR-L9Wz",
748
    "outputId": "cd5cb780-4027-4e12-e0de-07485713db38",
749
    "scrolled": false
750
   },
751
   "outputs": [],
752
   "source": [
753
    "# Set the global variables\n",
754
    "_, C, H, W, D = [input.shape[i] for i in range(len(input.shape))]\n",
755
    "c = output[0].shape[1]\n",
756
    "\n",
757
    "print(\"Channels:\", C)\n",
758
    "print(\"Height:\", H)\n",
759
    "print(\"Width:\", W)\n",
760
    "print(\"Depth:\", D)\n",
761
    "print(\"The Number Of Labels:\", c)"
762
   ]
763
  },
764
  {
765
   "cell_type": "code",
766
   "execution_count": 0,
767
   "metadata": {
768
    "code_folding": [],
769
    "colab": {},
770
    "colab_type": "code",
771
    "id": "j7cmXkIvhOuI"
772
   },
773
   "outputs": [],
774
   "source": [
775
    "class SoftDiceLoss(Module): \n",
776
    "    \"Soft dice loss based on a measure of overlap between prediction and ground truth\"\n",
777
    "    def __init__(self, epsilon=1e-6, c=c):\n",
778
    "        super().__init__()\n",
779
    "        self.epsilon = epsilon\n",
780
    "        self.c = c\n",
781
    "    \n",
782
    "    def forward(self, x:Tensor, y:Tensor):\n",
783
    "        intersection = 2 * ( (x*y).sum() )\n",
784
    "        union = (x**2).sum() + (y**2).sum() \n",
785
    "        return 1 - ( ( intersection / (union + self.epsilon) ) / self.c )"
786
   ]
787
  },
788
  {
789
   "cell_type": "code",
790
   "execution_count": null,
791
   "metadata": {
792
    "code_folding": [
793
     0
794
    ]
795
   },
796
   "outputs": [],
797
   "source": [
798
    "####### Sanity-check ############\n",
799
    "loss = "
800
   ]
801
  },
802
  {
803
   "cell_type": "code",
804
   "execution_count": 16,
805
   "metadata": {
806
    "code_folding": [],
807
    "colab": {},
808
    "colab_type": "code",
809
    "id": "kOjrJ44uhOuK"
810
   },
811
   "outputs": [],
812
   "source": [
813
    "class KLDivergence(Module): \n",
814
    "    \"KL divergence between the estimated normal distribution and a prior distribution\"\n",
815
    "    N = H * W * D  #hyperparameter check\n",
816
    "\n",
817
    "    def __init__(self):\n",
818
    "        super().__init__()\n",
819
    "    \n",
820
    "    def forward(self, z_mean:Tensor, z_log_var:Tensor):\n",
821
    "        z_var = z_log_var.exp()\n",
822
    "        return (1/self.N) * ( (z_mean**2 + z_var**2 - z_log_var**2 - 1).sum() )"
823
   ]
824
  },
825
  {
826
   "cell_type": "code",
827
   "execution_count": null,
828
   "metadata": {
829
    "code_folding": []
830
   },
831
   "outputs": [],
832
   "source": [
833
    "####### Sanity-check ############\n",
834
    "loss2 = KLDivergence()(z_mean=hooks.stored[3], z_log_var=hooks.stored[4])\n",
835
    "print(loss2)\n",
836
    "loss2.backward()"
837
   ]
838
  },
839
  {
840
   "cell_type": "code",
841
   "execution_count": 18,
842
   "metadata": {
843
    "code_folding": [
844
     0
845
    ],
846
    "colab": {},
847
    "colab_type": "code",
848
    "id": "HycYhLrohOuM"
849
   },
850
   "outputs": [],
851
   "source": [
852
    "class L2Loss(Module): \n",
853
    "    \"Measuring the `Euclidian distance` between prediction and ground truh using `L2 Norm`\"\n",
854
    "    def __init__(self):\n",
855
    "        super().__init__()\n",
856
    "        \n",
857
    "    def forward(self, x:Tensor, y:Tensor):\n",
858
    "        return  ( (x - y)**2 ).sum()       "
859
   ]
860
  },
861
  {
862
   "cell_type": "code",
863
   "execution_count": null,
864
   "metadata": {
865
    "code_folding": [
866
     0
867
    ]
868
   },
869
   "outputs": [],
870
   "source": [
871
    "####### Sanity-check ############\n",
872
    "loss3 = L2Loss()(bottom_res=output[1], orig=input)\n",
873
    "print(loss3)\n",
874
    "loss3.backward()"
875
   ]
876
  },
877
  {
878
   "cell_type": "markdown",
879
   "metadata": {
880
    "colab_type": "text",
881
    "id": "MsP_HOw2_6Jd"
882
   },
883
   "source": [
884
    "## Optimizer"
885
   ]
886
  },
887
  {
888
   "cell_type": "code",
889
   "execution_count": 0,
890
   "metadata": {
891
    "colab": {},
892
    "colab_type": "code",
893
    "id": "XYaFQ6nQ_8O4"
894
   },
895
   "outputs": [],
896
   "source": [
897
    "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)"
898
   ]
899
  },
900
  {
901
   "cell_type": "markdown",
902
   "metadata": {
903
    "colab_type": "text",
904
    "id": "GPK9Qfc0_tGL"
905
   },
906
   "source": [
907
    "## Training"
908
   ]
909
  },
910
  {
911
   "cell_type": "code",
912
   "execution_count": 0,
913
   "metadata": {
914
    "code_folding": [],
915
    "colab": {},
916
    "colab_type": "code",
917
    "id": "3YkqxURk_w8K"
918
   },
919
   "outputs": [],
920
   "source": [
921
    "for epoch in range(epochs):\n",
922
    "  \n",
923
    "  model.train()\n",
924
    "  for xb,yb in train_dl:\n",
925
    "    top_res, bottom_res = model(xb)\n",
926
    "    top_y, bottom_y = train_seg, input\n",
927
    "    z_mean, z_log_var = hooks.stored[4], hooks.stored[5] \n",
928
    "    loss = SoftDiceLoss()(top_res, top_y) + \\\n",
929
    "           (0.1 * KLDivergence()(z_mean, z_log_var)) + \\\n",
930
    "           (0.1 * L2Loss()(bottom_res, bottom_y))\n",
931
    "    loss.backward()\n",
932
    "    optimizer.step()\n",
933
    "    optimizer.zero_grad()\n",
934
    "\n",
935
    "  model.eval()\n",
936
    "  with torch.no_grad():\n",
937
    "    tot_loss, tot_acc = 0., 0.\n",
938
    "    for xb, yb in valid_dl:  \n",
939
    "    top_res, bottom_res = model(xb)\n",
940
    "    top_y, bottom_y = valid_seg, input\n",
941
    "    z_mean, z_log_var = hooks.stored[4], hooks.stored[5]\n",
942
    "    loss = SoftDiceLoss()(top_res, top_y) + \\\n",
943
    "           (0.1 * KLDivergence()(z_mean, z_log_var)) + \\\n",
944
    "           (0.1 * L2Loss()(bottom_res, bottom_y))    \n",
945
    "    tot_loss += loss\n",
946
    "    tot_acc += dice_coeff\n",
947
    "\n",
948
    "  nv = len(valid_dl)\n",
949
    "  return tot_loss/nv, tot_acc/nv"
950
   ]
951
  },
952
  {
953
   "cell_type": "markdown",
954
   "metadata": {
955
    "colab_type": "text",
956
    "heading_collapsed": true,
957
    "id": "GXaVq0m5hUbO"
958
   },
959
   "source": [
960
    "## Memory-check"
961
   ]
962
  },
963
  {
964
   "cell_type": "code",
965
   "execution_count": 21,
966
   "metadata": {
967
    "colab": {},
968
    "colab_type": "code",
969
    "hidden": true,
970
    "id": "Xuy-W1NFhUbR",
971
    "outputId": "f9a8cc29-2291-488f-ea8f-44b63cb8bd29"
972
   },
973
   "outputs": [
974
    {
975
     "data": {
976
      "text/plain": [
977
       "9884946432"
978
      ]
979
     },
980
     "execution_count": 21,
981
     "metadata": {},
982
     "output_type": "execute_result"
983
    }
984
   ],
985
   "source": [
986
    "# Memory ocuupied by Pytorch `Tensors`\n",
987
    "torch.cuda.memory_allocated(device=None)"
988
   ]
989
  },
990
  {
991
   "cell_type": "code",
992
   "execution_count": 22,
993
   "metadata": {
994
    "colab": {
995
     "base_uri": "https://localhost:8080/",
996
     "height": 697
997
    },
998
    "colab_type": "code",
999
    "hidden": true,
1000
    "id": "-F1mbF44hUbO",
1001
    "outputId": "6b4ff5a9-766a-48d0-fc2c-bd0675e303e8",
1002
    "scrolled": true
1003
   },
1004
   "outputs": [
1005
    {
1006
     "name": "stdout",
1007
     "output_type": "stream",
1008
     "text": [
1009
      "|===========================================================================|\n",
1010
      "|                  PyTorch CUDA memory summary, device ID 1                 |\n",
1011
      "|---------------------------------------------------------------------------|\n",
1012
      "|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |\n",
1013
      "|===========================================================================|\n",
1014
      "|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |\n",
1015
      "|---------------------------------------------------------------------------|\n",
1016
      "| Allocated memory      |    9427 MB |    9855 MB |   10859 MB |    1432 MB |\n",
1017
      "|       from large pool |    9423 MB |    9851 MB |   10847 MB |    1424 MB |\n",
1018
      "|       from small pool |       3 MB |       3 MB |      11 MB |       8 MB |\n",
1019
      "|---------------------------------------------------------------------------|\n",
1020
      "| Active memory         |    9427 MB |    9855 MB |   10859 MB |    1432 MB |\n",
1021
      "|       from large pool |    9423 MB |    9851 MB |   10847 MB |    1424 MB |\n",
1022
      "|       from small pool |       3 MB |       3 MB |      11 MB |       8 MB |\n",
1023
      "|---------------------------------------------------------------------------|\n",
1024
      "| GPU reserved memory   |    9482 MB |   10012 MB |   10012 MB |  542720 KB |\n",
1025
      "|       from large pool |    9478 MB |   10008 MB |   10008 MB |  542720 KB |\n",
1026
      "|       from small pool |       4 MB |       4 MB |       4 MB |       0 KB |\n",
1027
      "|---------------------------------------------------------------------------|\n",
1028
      "| Non-releasable memory |   56300 KB |  168416 KB |     981 MB |     926 MB |\n",
1029
      "|       from large pool |   55744 KB |  167872 KB |     970 MB |     915 MB |\n",
1030
      "|       from small pool |     556 KB |    2034 KB |      11 MB |      11 MB |\n",
1031
      "|---------------------------------------------------------------------------|\n",
1032
      "| Allocations           |     193    |     265    |     837    |     644    |\n",
1033
      "|       from large pool |      76    |     124    |     145    |      69    |\n",
1034
      "|       from small pool |     117    |     142    |     692    |     575    |\n",
1035
      "|---------------------------------------------------------------------------|\n",
1036
      "| Active allocs         |     193    |     265    |     837    |     644    |\n",
1037
      "|       from large pool |      76    |     124    |     145    |      69    |\n",
1038
      "|       from small pool |     117    |     142    |     692    |     575    |\n",
1039
      "|---------------------------------------------------------------------------|\n",
1040
      "| GPU reserved segments |      66    |      91    |      91    |      25    |\n",
1041
      "|       from large pool |      64    |      89    |      89    |      25    |\n",
1042
      "|       from small pool |       2    |       2    |       2    |       0    |\n",
1043
      "|---------------------------------------------------------------------------|\n",
1044
      "| Non-releasable allocs |      10    |      31    |     100    |      90    |\n",
1045
      "|       from large pool |       8    |      29    |      42    |      34    |\n",
1046
      "|       from small pool |       2    |       5    |      58    |      56    |\n",
1047
      "|===========================================================================|\n",
1048
      "\n"
1049
     ]
1050
    }
1051
   ],
1052
   "source": [
1053
    "# Memory status\n",
1054
    "print(torch.cuda.memory_summary(device=None, abbreviated=False))"
1055
   ]
1056
  }
1057
 ],
1058
 "metadata": {
1059
  "accelerator": "GPU",
1060
  "colab": {
1061
   "collapsed_sections": [
1062
    "b6KgSfpahUZ1",
1063
    "C77ffh0whUZ7",
1064
    "btBiurUPxx7t",
1065
    "MY131WWbx3nN",
1066
    "BKmlf1qY74Fx",
1067
    "Sq9kLEFbx8sF",
1068
    "dtLzCKAOEn6c",
1069
    "ZSPf7atqhOuG",
1070
    "MsP_HOw2_6Jd",
1071
    "GPK9Qfc0_tGL",
1072
    "GXaVq0m5hUbO",
1073
    "v-ODWm3ehUbG"
1074
   ],
1075
   "name": "model_prototype_1.ipynb",
1076
   "provenance": []
1077
  },
1078
  "kernelspec": {
1079
   "display_name": "Python 3",
1080
   "language": "python",
1081
   "name": "python3"
1082
  },
1083
  "language_info": {
1084
   "codemirror_mode": {
1085
    "name": "ipython",
1086
    "version": 3
1087
   },
1088
   "file_extension": ".py",
1089
   "mimetype": "text/x-python",
1090
   "name": "python",
1091
   "nbconvert_exporter": "python",
1092
   "pygments_lexer": "ipython3",
1093
   "version": "3.6.5"
1094
  }
1095
 },
1096
 "nbformat": 4,
1097
 "nbformat_minor": 1
1098
}