Diff of /MRNet_EDA_ns.ipynb [000000] .. [dc3c86]

Switch to unified view

a b/MRNet_EDA_ns.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": 1,
6
   "metadata": {},
7
   "outputs": [],
8
   "source": [
9
    "import numpy as np\n",
10
    "import pandas as pd\n",
11
    "import matplotlib.pyplot as plt\n",
12
    "from fastai.core import *\n",
13
    "\n",
14
    "%matplotlib notebook"
15
   ]
16
  },
17
  {
18
   "cell_type": "code",
19
   "execution_count": 2,
20
   "metadata": {},
21
   "outputs": [
22
    {
23
     "name": "stdout",
24
     "output_type": "stream",
25
     "text": [
26
      "LICENSE         MRNet_EDA.ipynb README.md\r\n"
27
     ]
28
    }
29
   ],
30
   "source": [
31
    "! ls -R "
32
   ]
33
  },
34
  {
35
   "cell_type": "code",
36
   "execution_count": 3,
37
   "metadata": {},
38
   "outputs": [],
39
   "source": [
40
    "data_path = Path('../data')\n",
41
    "train_path = data_path/'smalltrain'/'train'\n",
42
    "valid_path = data_path/'smallvalid'/'valid'"
43
   ]
44
  },
45
  {
46
   "cell_type": "code",
47
   "execution_count": 5,
48
   "metadata": {},
49
   "outputs": [
50
    {
51
     "name": "stdout",
52
     "output_type": "stream",
53
     "text": [
54
      "          Case\n",
55
      "Abnormal      \n",
56
      "0          217\n",
57
      "1          913\n"
58
     ]
59
    },
60
    {
61
     "data": {
62
      "text/html": [
63
       "<div>\n",
64
       "<style scoped>\n",
65
       "    .dataframe tbody tr th:only-of-type {\n",
66
       "        vertical-align: middle;\n",
67
       "    }\n",
68
       "\n",
69
       "    .dataframe tbody tr th {\n",
70
       "        vertical-align: top;\n",
71
       "    }\n",
72
       "\n",
73
       "    .dataframe thead th {\n",
74
       "        text-align: right;\n",
75
       "    }\n",
76
       "</style>\n",
77
       "<table border=\"1\" class=\"dataframe\">\n",
78
       "  <thead>\n",
79
       "    <tr style=\"text-align: right;\">\n",
80
       "      <th></th>\n",
81
       "      <th>Case</th>\n",
82
       "      <th>Abnormal</th>\n",
83
       "    </tr>\n",
84
       "  </thead>\n",
85
       "  <tbody>\n",
86
       "    <tr>\n",
87
       "      <th>0</th>\n",
88
       "      <td>0000</td>\n",
89
       "      <td>1</td>\n",
90
       "    </tr>\n",
91
       "    <tr>\n",
92
       "      <th>1</th>\n",
93
       "      <td>0001</td>\n",
94
       "      <td>1</td>\n",
95
       "    </tr>\n",
96
       "    <tr>\n",
97
       "      <th>2</th>\n",
98
       "      <td>0002</td>\n",
99
       "      <td>1</td>\n",
100
       "    </tr>\n",
101
       "    <tr>\n",
102
       "      <th>3</th>\n",
103
       "      <td>0003</td>\n",
104
       "      <td>1</td>\n",
105
       "    </tr>\n",
106
       "    <tr>\n",
107
       "      <th>4</th>\n",
108
       "      <td>0004</td>\n",
109
       "      <td>1</td>\n",
110
       "    </tr>\n",
111
       "  </tbody>\n",
112
       "</table>\n",
113
       "</div>"
114
      ],
115
      "text/plain": [
116
       "   Case  Abnormal\n",
117
       "0  0000         1\n",
118
       "1  0001         1\n",
119
       "2  0002         1\n",
120
       "3  0003         1\n",
121
       "4  0004         1"
122
      ]
123
     },
124
     "execution_count": 5,
125
     "metadata": {},
126
     "output_type": "execute_result"
127
    }
128
   ],
129
   "source": [
130
    "train_abnl = pd.read_csv(data_path/'train-abnormal.csv', header=None,\n",
131
    "                       names=['Case', 'Abnormal'], \n",
132
    "                       dtype={'Case': str, 'Abnormal': np.int64})\n",
133
    "print(train_abnl.groupby('Abnormal').count())\n",
134
    "train_abnl.head()"
135
   ]
136
  },
137
  {
138
   "cell_type": "code",
139
   "execution_count": 6,
140
   "metadata": {},
141
   "outputs": [
142
    {
143
     "name": "stdout",
144
     "output_type": "stream",
145
     "text": [
146
      "          Case\n",
147
      "ACL_tear      \n",
148
      "0          922\n",
149
      "1          208\n"
150
     ]
151
    },
152
    {
153
     "data": {
154
      "text/html": [
155
       "<div>\n",
156
       "<style scoped>\n",
157
       "    .dataframe tbody tr th:only-of-type {\n",
158
       "        vertical-align: middle;\n",
159
       "    }\n",
160
       "\n",
161
       "    .dataframe tbody tr th {\n",
162
       "        vertical-align: top;\n",
163
       "    }\n",
164
       "\n",
165
       "    .dataframe thead th {\n",
166
       "        text-align: right;\n",
167
       "    }\n",
168
       "</style>\n",
169
       "<table border=\"1\" class=\"dataframe\">\n",
170
       "  <thead>\n",
171
       "    <tr style=\"text-align: right;\">\n",
172
       "      <th></th>\n",
173
       "      <th>Case</th>\n",
174
       "      <th>ACL_tear</th>\n",
175
       "    </tr>\n",
176
       "  </thead>\n",
177
       "  <tbody>\n",
178
       "    <tr>\n",
179
       "      <th>0</th>\n",
180
       "      <td>0000</td>\n",
181
       "      <td>0</td>\n",
182
       "    </tr>\n",
183
       "    <tr>\n",
184
       "      <th>1</th>\n",
185
       "      <td>0001</td>\n",
186
       "      <td>1</td>\n",
187
       "    </tr>\n",
188
       "    <tr>\n",
189
       "      <th>2</th>\n",
190
       "      <td>0002</td>\n",
191
       "      <td>0</td>\n",
192
       "    </tr>\n",
193
       "    <tr>\n",
194
       "      <th>3</th>\n",
195
       "      <td>0003</td>\n",
196
       "      <td>0</td>\n",
197
       "    </tr>\n",
198
       "    <tr>\n",
199
       "      <th>4</th>\n",
200
       "      <td>0004</td>\n",
201
       "      <td>0</td>\n",
202
       "    </tr>\n",
203
       "  </tbody>\n",
204
       "</table>\n",
205
       "</div>"
206
      ],
207
      "text/plain": [
208
       "   Case  ACL_tear\n",
209
       "0  0000         0\n",
210
       "1  0001         1\n",
211
       "2  0002         0\n",
212
       "3  0003         0\n",
213
       "4  0004         0"
214
      ]
215
     },
216
     "execution_count": 6,
217
     "metadata": {},
218
     "output_type": "execute_result"
219
    }
220
   ],
221
   "source": [
222
    "train_acl = pd.read_csv(data_path/'train-acl.csv', header=None,\n",
223
    "                       names=['Case', 'ACL_tear'], \n",
224
    "                       dtype={'Case': str, 'ACL_tear': np.int64})\n",
225
    "print(train_acl.groupby('ACL_tear').count())\n",
226
    "train_acl.head()"
227
   ]
228
  },
229
  {
230
   "cell_type": "code",
231
   "execution_count": 7,
232
   "metadata": {},
233
   "outputs": [
234
    {
235
     "name": "stdout",
236
     "output_type": "stream",
237
     "text": [
238
      "               Case\n",
239
      "Meniscus_tear      \n",
240
      "0               733\n",
241
      "1               397\n"
242
     ]
243
    },
244
    {
245
     "data": {
246
      "text/html": [
247
       "<div>\n",
248
       "<style scoped>\n",
249
       "    .dataframe tbody tr th:only-of-type {\n",
250
       "        vertical-align: middle;\n",
251
       "    }\n",
252
       "\n",
253
       "    .dataframe tbody tr th {\n",
254
       "        vertical-align: top;\n",
255
       "    }\n",
256
       "\n",
257
       "    .dataframe thead th {\n",
258
       "        text-align: right;\n",
259
       "    }\n",
260
       "</style>\n",
261
       "<table border=\"1\" class=\"dataframe\">\n",
262
       "  <thead>\n",
263
       "    <tr style=\"text-align: right;\">\n",
264
       "      <th></th>\n",
265
       "      <th>Case</th>\n",
266
       "      <th>Meniscus_tear</th>\n",
267
       "    </tr>\n",
268
       "  </thead>\n",
269
       "  <tbody>\n",
270
       "    <tr>\n",
271
       "      <th>0</th>\n",
272
       "      <td>0000</td>\n",
273
       "      <td>0</td>\n",
274
       "    </tr>\n",
275
       "    <tr>\n",
276
       "      <th>1</th>\n",
277
       "      <td>0001</td>\n",
278
       "      <td>1</td>\n",
279
       "    </tr>\n",
280
       "    <tr>\n",
281
       "      <th>2</th>\n",
282
       "      <td>0002</td>\n",
283
       "      <td>0</td>\n",
284
       "    </tr>\n",
285
       "    <tr>\n",
286
       "      <th>3</th>\n",
287
       "      <td>0003</td>\n",
288
       "      <td>1</td>\n",
289
       "    </tr>\n",
290
       "    <tr>\n",
291
       "      <th>4</th>\n",
292
       "      <td>0004</td>\n",
293
       "      <td>0</td>\n",
294
       "    </tr>\n",
295
       "  </tbody>\n",
296
       "</table>\n",
297
       "</div>"
298
      ],
299
      "text/plain": [
300
       "   Case  Meniscus_tear\n",
301
       "0  0000              0\n",
302
       "1  0001              1\n",
303
       "2  0002              0\n",
304
       "3  0003              1\n",
305
       "4  0004              0"
306
      ]
307
     },
308
     "execution_count": 7,
309
     "metadata": {},
310
     "output_type": "execute_result"
311
    }
312
   ],
313
   "source": [
314
    "train_meniscus = pd.read_csv(data_path/'train-meniscus.csv', header=None,\n",
315
    "                       names=['Case', 'Meniscus_tear'], \n",
316
    "                       dtype={'Case': str, 'Meniscus_tear': np.int64})\n",
317
    "print(train_meniscus.groupby('Meniscus_tear').count())\n",
318
    "train_meniscus.head()"
319
   ]
320
  },
321
  {
322
   "cell_type": "markdown",
323
   "metadata": {},
324
   "source": [
325
    "### Co-occurrence of ACL and Meniscus tears"
326
   ]
327
  },
328
  {
329
   "cell_type": "code",
330
   "execution_count": null,
331
   "metadata": {},
332
   "outputs": [],
333
   "source": [
334
    "train = pd.merge(train_abnl, train_acl, on='Case')"
335
   ]
336
  },
337
  {
338
   "cell_type": "code",
339
   "execution_count": 10,
340
   "metadata": {},
341
   "outputs": [],
342
   "source": [
343
    "train = pd.merge(train, train_meniscus, on='Case')"
344
   ]
345
  },
346
  {
347
   "cell_type": "code",
348
   "execution_count": 16,
349
   "metadata": {},
350
   "outputs": [
351
    {
352
     "data": {
353
      "text/html": [
354
       "<div>\n",
355
       "<style scoped>\n",
356
       "    .dataframe tbody tr th:only-of-type {\n",
357
       "        vertical-align: middle;\n",
358
       "    }\n",
359
       "\n",
360
       "    .dataframe tbody tr th {\n",
361
       "        vertical-align: top;\n",
362
       "    }\n",
363
       "\n",
364
       "    .dataframe thead th {\n",
365
       "        text-align: right;\n",
366
       "    }\n",
367
       "</style>\n",
368
       "<table border=\"1\" class=\"dataframe\">\n",
369
       "  <thead>\n",
370
       "    <tr style=\"text-align: right;\">\n",
371
       "      <th></th>\n",
372
       "      <th>Case</th>\n",
373
       "      <th>Abnormal</th>\n",
374
       "      <th>ACL_tear</th>\n",
375
       "      <th>Meniscus_tear</th>\n",
376
       "    </tr>\n",
377
       "  </thead>\n",
378
       "  <tbody>\n",
379
       "    <tr>\n",
380
       "      <th>0</th>\n",
381
       "      <td>0000</td>\n",
382
       "      <td>1</td>\n",
383
       "      <td>0</td>\n",
384
       "      <td>0</td>\n",
385
       "    </tr>\n",
386
       "    <tr>\n",
387
       "      <th>1</th>\n",
388
       "      <td>0001</td>\n",
389
       "      <td>1</td>\n",
390
       "      <td>1</td>\n",
391
       "      <td>1</td>\n",
392
       "    </tr>\n",
393
       "    <tr>\n",
394
       "      <th>2</th>\n",
395
       "      <td>0002</td>\n",
396
       "      <td>1</td>\n",
397
       "      <td>0</td>\n",
398
       "      <td>0</td>\n",
399
       "    </tr>\n",
400
       "    <tr>\n",
401
       "      <th>3</th>\n",
402
       "      <td>0003</td>\n",
403
       "      <td>1</td>\n",
404
       "      <td>0</td>\n",
405
       "      <td>1</td>\n",
406
       "    </tr>\n",
407
       "    <tr>\n",
408
       "      <th>4</th>\n",
409
       "      <td>0004</td>\n",
410
       "      <td>1</td>\n",
411
       "      <td>0</td>\n",
412
       "      <td>0</td>\n",
413
       "    </tr>\n",
414
       "  </tbody>\n",
415
       "</table>\n",
416
       "</div>"
417
      ],
418
      "text/plain": [
419
       "   Case  Abnormal  ACL_tear  Meniscus_tear\n",
420
       "0  0000         1         0              0\n",
421
       "1  0001         1         1              1\n",
422
       "2  0002         1         0              0\n",
423
       "3  0003         1         0              1\n",
424
       "4  0004         1         0              0"
425
      ]
426
     },
427
     "metadata": {},
428
     "output_type": "display_data"
429
    },
430
    {
431
     "data": {
432
      "text/html": [
433
       "<div>\n",
434
       "<style scoped>\n",
435
       "    .dataframe tbody tr th:only-of-type {\n",
436
       "        vertical-align: middle;\n",
437
       "    }\n",
438
       "\n",
439
       "    .dataframe tbody tr th {\n",
440
       "        vertical-align: top;\n",
441
       "    }\n",
442
       "\n",
443
       "    .dataframe thead th {\n",
444
       "        text-align: right;\n",
445
       "    }\n",
446
       "</style>\n",
447
       "<table border=\"1\" class=\"dataframe\">\n",
448
       "  <thead>\n",
449
       "    <tr style=\"text-align: right;\">\n",
450
       "      <th></th>\n",
451
       "      <th></th>\n",
452
       "      <th></th>\n",
453
       "      <th>Case</th>\n",
454
       "    </tr>\n",
455
       "    <tr>\n",
456
       "      <th>Abnormal</th>\n",
457
       "      <th>ACL_tear</th>\n",
458
       "      <th>Meniscus_tear</th>\n",
459
       "      <th></th>\n",
460
       "    </tr>\n",
461
       "  </thead>\n",
462
       "  <tbody>\n",
463
       "    <tr>\n",
464
       "      <th>0</th>\n",
465
       "      <th>0</th>\n",
466
       "      <th>0</th>\n",
467
       "      <td>217</td>\n",
468
       "    </tr>\n",
469
       "    <tr>\n",
470
       "      <th rowspan=\"4\" valign=\"top\">1</th>\n",
471
       "      <th rowspan=\"2\" valign=\"top\">0</th>\n",
472
       "      <th>0</th>\n",
473
       "      <td>433</td>\n",
474
       "    </tr>\n",
475
       "    <tr>\n",
476
       "      <th>1</th>\n",
477
       "      <td>272</td>\n",
478
       "    </tr>\n",
479
       "    <tr>\n",
480
       "      <th rowspan=\"2\" valign=\"top\">1</th>\n",
481
       "      <th>0</th>\n",
482
       "      <td>83</td>\n",
483
       "    </tr>\n",
484
       "    <tr>\n",
485
       "      <th>1</th>\n",
486
       "      <td>125</td>\n",
487
       "    </tr>\n",
488
       "  </tbody>\n",
489
       "</table>\n",
490
       "</div>"
491
      ],
492
      "text/plain": [
493
       "                                 Case\n",
494
       "Abnormal ACL_tear Meniscus_tear      \n",
495
       "0        0        0               217\n",
496
       "1        0        0               433\n",
497
       "                  1               272\n",
498
       "         1        0                83\n",
499
       "                  1               125"
500
      ]
501
     },
502
     "metadata": {},
503
     "output_type": "display_data"
504
    }
505
   ],
506
   "source": [
507
    "display(train.head())\n",
508
    "display(train.groupby(['Abnormal','ACL_tear','Meniscus_tear']).count())"
509
   ]
510
  },
511
  {
512
   "cell_type": "markdown",
513
   "metadata": {},
514
   "source": [
515
    "Note that cases considered Abnormal but without either ACL or Meniscus tear are the most common category, and ACL tears without Meniscus tear is the least common case in the training sample."
516
   ]
517
  },
518
  {
519
   "cell_type": "markdown",
520
   "metadata": {},
521
   "source": [
522
    "## Load stacks/sequences of images from each plane\n",
523
    "Files are saved as NumPy arrays. Scans were taken from each of three planes, axial, coronal, and sagittal. For each plane, the scan results in a set of images.  \n",
524
    "\n",
525
    "First, let's check for variation in the number of images per sequence, and in the image dimensions."
526
   ]
527
  },
528
  {
529
   "cell_type": "code",
530
   "execution_count": 67,
531
   "metadata": {},
532
   "outputs": [],
533
   "source": [
534
    "def collect_stack_dims(case_df, data_path=train_path):\n",
535
    "    cases = list(case_df.Case)\n",
536
    "    data = []\n",
537
    "    for case in cases:\n",
538
    "        row = [case]\n",
539
    "        for plane in ['axial', 'coronal', 'sagittal']:\n",
540
    "            fpath = data_path/plane/'{}.npy'.format(case)\n",
541
    "            try: \n",
542
    "                s,w,h = np.load(fpath).shape \n",
543
    "                row.extend([s,w,h])\n",
544
    "            except FileNotFoundError:\n",
545
    "                continue\n",
546
    "#        print('{}: {}'.format(case,row))\n",
547
    "        if len(row)==10: data.append(row)\n",
548
    "    columns=['Case',\n",
549
    "             'axial_s','axial_w','axial_h',\n",
550
    "             'coronal_s','coronal_w','coronal_h',\n",
551
    "             'sagittal_s','sagittal_w','sagittal_h',\n",
552
    "            ]\n",
553
    "    data_dict = {}\n",
554
    "    for i,k in enumerate(columns): data_dict[k] = [row[i] for row in data]\n",
555
    "    return pd.DataFrame(data_dict)"
556
   ]
557
  },
558
  {
559
   "cell_type": "code",
560
   "execution_count": 68,
561
   "metadata": {},
562
   "outputs": [],
563
   "source": [
564
    "dimdf = collect_stack_dims(train)"
565
   ]
566
  },
567
  {
568
   "cell_type": "code",
569
   "execution_count": 70,
570
   "metadata": {},
571
   "outputs": [
572
    {
573
     "data": {
574
      "text/html": [
575
       "<div>\n",
576
       "<style scoped>\n",
577
       "    .dataframe tbody tr th:only-of-type {\n",
578
       "        vertical-align: middle;\n",
579
       "    }\n",
580
       "\n",
581
       "    .dataframe tbody tr th {\n",
582
       "        vertical-align: top;\n",
583
       "    }\n",
584
       "\n",
585
       "    .dataframe thead th {\n",
586
       "        text-align: right;\n",
587
       "    }\n",
588
       "</style>\n",
589
       "<table border=\"1\" class=\"dataframe\">\n",
590
       "  <thead>\n",
591
       "    <tr style=\"text-align: right;\">\n",
592
       "      <th></th>\n",
593
       "      <th>axial_s</th>\n",
594
       "      <th>axial_w</th>\n",
595
       "      <th>axial_h</th>\n",
596
       "      <th>coronal_s</th>\n",
597
       "      <th>coronal_w</th>\n",
598
       "      <th>coronal_h</th>\n",
599
       "      <th>sagittal_s</th>\n",
600
       "      <th>sagittal_w</th>\n",
601
       "      <th>sagittal_h</th>\n",
602
       "    </tr>\n",
603
       "  </thead>\n",
604
       "  <tbody>\n",
605
       "    <tr>\n",
606
       "      <th>count</th>\n",
607
       "      <td>50.000000</td>\n",
608
       "      <td>50.0</td>\n",
609
       "      <td>50.0</td>\n",
610
       "      <td>50.000000</td>\n",
611
       "      <td>50.0</td>\n",
612
       "      <td>50.0</td>\n",
613
       "      <td>50.00000</td>\n",
614
       "      <td>50.0</td>\n",
615
       "      <td>50.0</td>\n",
616
       "    </tr>\n",
617
       "    <tr>\n",
618
       "      <th>mean</th>\n",
619
       "      <td>35.860000</td>\n",
620
       "      <td>256.0</td>\n",
621
       "      <td>256.0</td>\n",
622
       "      <td>31.360000</td>\n",
623
       "      <td>256.0</td>\n",
624
       "      <td>256.0</td>\n",
625
       "      <td>31.72000</td>\n",
626
       "      <td>256.0</td>\n",
627
       "      <td>256.0</td>\n",
628
       "    </tr>\n",
629
       "    <tr>\n",
630
       "      <th>std</th>\n",
631
       "      <td>7.050865</td>\n",
632
       "      <td>0.0</td>\n",
633
       "      <td>0.0</td>\n",
634
       "      <td>7.899264</td>\n",
635
       "      <td>0.0</td>\n",
636
       "      <td>0.0</td>\n",
637
       "      <td>6.35687</td>\n",
638
       "      <td>0.0</td>\n",
639
       "      <td>0.0</td>\n",
640
       "    </tr>\n",
641
       "    <tr>\n",
642
       "      <th>min</th>\n",
643
       "      <td>22.000000</td>\n",
644
       "      <td>256.0</td>\n",
645
       "      <td>256.0</td>\n",
646
       "      <td>18.000000</td>\n",
647
       "      <td>256.0</td>\n",
648
       "      <td>256.0</td>\n",
649
       "      <td>19.00000</td>\n",
650
       "      <td>256.0</td>\n",
651
       "      <td>256.0</td>\n",
652
       "    </tr>\n",
653
       "    <tr>\n",
654
       "      <th>25%</th>\n",
655
       "      <td>32.000000</td>\n",
656
       "      <td>256.0</td>\n",
657
       "      <td>256.0</td>\n",
658
       "      <td>24.000000</td>\n",
659
       "      <td>256.0</td>\n",
660
       "      <td>256.0</td>\n",
661
       "      <td>26.25000</td>\n",
662
       "      <td>256.0</td>\n",
663
       "      <td>256.0</td>\n",
664
       "    </tr>\n",
665
       "    <tr>\n",
666
       "      <th>50%</th>\n",
667
       "      <td>37.500000</td>\n",
668
       "      <td>256.0</td>\n",
669
       "      <td>256.0</td>\n",
670
       "      <td>32.000000</td>\n",
671
       "      <td>256.0</td>\n",
672
       "      <td>256.0</td>\n",
673
       "      <td>32.00000</td>\n",
674
       "      <td>256.0</td>\n",
675
       "      <td>256.0</td>\n",
676
       "    </tr>\n",
677
       "    <tr>\n",
678
       "      <th>75%</th>\n",
679
       "      <td>40.000000</td>\n",
680
       "      <td>256.0</td>\n",
681
       "      <td>256.0</td>\n",
682
       "      <td>37.750000</td>\n",
683
       "      <td>256.0</td>\n",
684
       "      <td>256.0</td>\n",
685
       "      <td>36.00000</td>\n",
686
       "      <td>256.0</td>\n",
687
       "      <td>256.0</td>\n",
688
       "    </tr>\n",
689
       "    <tr>\n",
690
       "      <th>max</th>\n",
691
       "      <td>51.000000</td>\n",
692
       "      <td>256.0</td>\n",
693
       "      <td>256.0</td>\n",
694
       "      <td>46.000000</td>\n",
695
       "      <td>256.0</td>\n",
696
       "      <td>256.0</td>\n",
697
       "      <td>46.00000</td>\n",
698
       "      <td>256.0</td>\n",
699
       "      <td>256.0</td>\n",
700
       "    </tr>\n",
701
       "  </tbody>\n",
702
       "</table>\n",
703
       "</div>"
704
      ],
705
      "text/plain": [
706
       "         axial_s  axial_w  axial_h  coronal_s  coronal_w  coronal_h  \\\n",
707
       "count  50.000000     50.0     50.0  50.000000       50.0       50.0   \n",
708
       "mean   35.860000    256.0    256.0  31.360000      256.0      256.0   \n",
709
       "std     7.050865      0.0      0.0   7.899264        0.0        0.0   \n",
710
       "min    22.000000    256.0    256.0  18.000000      256.0      256.0   \n",
711
       "25%    32.000000    256.0    256.0  24.000000      256.0      256.0   \n",
712
       "50%    37.500000    256.0    256.0  32.000000      256.0      256.0   \n",
713
       "75%    40.000000    256.0    256.0  37.750000      256.0      256.0   \n",
714
       "max    51.000000    256.0    256.0  46.000000      256.0      256.0   \n",
715
       "\n",
716
       "       sagittal_s  sagittal_w  sagittal_h  \n",
717
       "count    50.00000        50.0        50.0  \n",
718
       "mean     31.72000       256.0       256.0  \n",
719
       "std       6.35687         0.0         0.0  \n",
720
       "min      19.00000       256.0       256.0  \n",
721
       "25%      26.25000       256.0       256.0  \n",
722
       "50%      32.00000       256.0       256.0  \n",
723
       "75%      36.00000       256.0       256.0  \n",
724
       "max      46.00000       256.0       256.0  "
725
      ]
726
     },
727
     "execution_count": 70,
728
     "metadata": {},
729
     "output_type": "execute_result"
730
    }
731
   ],
732
   "source": [
733
    "dimdf.describe()"
734
   ]
735
  },
736
  {
737
   "cell_type": "markdown",
738
   "metadata": {},
739
   "source": [
740
    "The number of images in a set varies from case (patient) to case, and the dimensions of each image is the same, 256x256. In the sample of data collected here, axial sequences range in length from 22 to 51; coronal, from 18 to 46; sagittal, from 19 to 46."
741
   ]
742
  },
743
  {
744
   "cell_type": "code",
745
   "execution_count": 20,
746
   "metadata": {},
747
   "outputs": [],
748
   "source": [
749
    "def load_one_stack(case, data_path=train_path, plane='coronal'):\n",
750
    "    fpath = data_path/plane/'{}.npy'.format(case)\n",
751
    "    return np.load(fpath)\n",
752
    "\n",
753
    "def load_stacks(case):\n",
754
    "    x = {}\n",
755
    "    planes = ['axial', 'coronal', 'sagittal']\n",
756
    "    for i, plane in enumerate(planes):\n",
757
    "        x[plane] = load_one_stack(case, plane=plane)\n",
758
    "    return x"
759
   ]
760
  },
761
  {
762
   "cell_type": "code",
763
   "execution_count": 29,
764
   "metadata": {},
765
   "outputs": [
766
    {
767
     "name": "stdout",
768
     "output_type": "stream",
769
     "text": [
770
      "(36, 256, 256)\n",
771
      "255\n"
772
     ]
773
    }
774
   ],
775
   "source": [
776
    "case = train_abnl.Case[0]\n",
777
    "x = load_one_stack(case, plane='coronal')\n",
778
    "print(x.shape)\n",
779
    "print(x.max())"
780
   ]
781
  },
782
  {
783
   "cell_type": "code",
784
   "execution_count": 9,
785
   "metadata": {},
786
   "outputs": [
787
    {
788
     "data": {
789
      "text/plain": [
790
       "{'axial': array([[[ 0,  0,  0,  0, ...,  4,  5,  4,  3],\n",
791
       "         [ 0,  0,  0,  0, ...,  8,  8,  6,  8],\n",
792
       "         [ 0,  0,  0,  0, ..., 14, 14, 11, 11],\n",
793
       "         [ 0,  0,  0,  0, ..., 16, 16, 14, 15],\n",
794
       "         ...,\n",
795
       "         [ 0,  0,  0,  0, ..., 14, 15, 18, 16],\n",
796
       "         [ 0,  0,  0,  0, ..., 15, 16, 15, 12],\n",
797
       "         [ 0,  0,  0,  0, ..., 11, 12, 13, 12],\n",
798
       "         [ 0,  0,  0,  0, ...,  8, 11,  7,  9]],\n",
799
       " \n",
800
       "        [[ 0,  0,  0,  0, ...,  4,  3,  2,  2],\n",
801
       "         [ 0,  0,  0,  0, ...,  5,  9,  7,  7],\n",
802
       "         [ 0,  0,  0,  0, ..., 10, 13, 10, 10],\n",
803
       "         [ 0,  0,  0,  0, ..., 14, 14, 19, 17],\n",
804
       "         ...,\n",
805
       "         [ 0,  0,  0,  0, ..., 18, 16, 16, 17],\n",
806
       "         [ 0,  0,  0,  0, ..., 13, 12, 15, 13],\n",
807
       "         [ 0,  0,  0,  0, ..., 16, 14, 12, 12],\n",
808
       "         [ 0,  0,  0,  0, ...,  8,  6,  5,  7]],\n",
809
       " \n",
810
       "        [[ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
811
       "         [ 0,  0,  0,  0, ...,  7,  8,  6,  6],\n",
812
       "         [ 0,  0,  0,  0, ..., 12, 11, 13, 10],\n",
813
       "         [ 0,  0,  0,  0, ..., 12, 18, 18, 16],\n",
814
       "         ...,\n",
815
       "         [ 0,  0,  0,  0, ..., 16, 18, 16, 17],\n",
816
       "         [ 0,  0,  0,  0, ..., 15, 13, 13, 16],\n",
817
       "         [ 0,  0,  0,  0, ..., 10, 10, 10, 12],\n",
818
       "         [ 0,  0,  0,  0, ...,  6,  6,  6,  5]],\n",
819
       " \n",
820
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
821
       "         [ 0,  0,  0,  0, ...,  5,  7,  5,  4],\n",
822
       "         [ 0,  0,  0,  0, ..., 11, 10, 11, 12],\n",
823
       "         [ 0,  0,  0,  0, ..., 16, 16, 15, 14],\n",
824
       "         ...,\n",
825
       "         [ 0,  0,  0,  0, ..., 17, 21, 20, 18],\n",
826
       "         [ 0,  0,  0,  0, ..., 14, 15, 18, 14],\n",
827
       "         [ 0,  0,  0,  0, ..., 11,  9,  8, 10],\n",
828
       "         [ 0,  0,  0,  0, ...,  5,  5,  5,  5]],\n",
829
       " \n",
830
       "        ...,\n",
831
       " \n",
832
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
833
       "         [ 0,  0,  0,  0, ...,  4,  4,  4,  4],\n",
834
       "         [ 0,  0,  0,  0, ..., 10,  9,  9, 10],\n",
835
       "         [ 0,  0,  0,  0, ..., 12, 13, 16, 14],\n",
836
       "         ...,\n",
837
       "         [ 0,  0,  0,  0, ..., 16, 12, 14, 15],\n",
838
       "         [ 0,  0,  0,  0, ..., 11, 10, 13, 10],\n",
839
       "         [ 0,  0,  0,  0, ...,  9, 12,  9,  9],\n",
840
       "         [ 0,  0,  0,  0, ...,  6,  6,  4,  5]],\n",
841
       " \n",
842
       "        [[ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
843
       "         [ 0,  0,  0,  0, ...,  5,  5,  4,  4],\n",
844
       "         [ 0,  0,  0,  0, ..., 12,  8,  9, 11],\n",
845
       "         [ 0,  0,  0,  0, ..., 16, 17, 12, 13],\n",
846
       "         ...,\n",
847
       "         [ 0,  0,  0,  0, ..., 13, 14, 14, 17],\n",
848
       "         [ 0,  0,  0,  0, ..., 12, 14, 16, 13],\n",
849
       "         [ 0,  0,  0,  0, ...,  9, 12, 12,  8],\n",
850
       "         [ 0,  0,  0,  0, ...,  6,  5,  4,  5]],\n",
851
       " \n",
852
       "        [[ 0,  0,  0,  0, ...,  3,  2,  3,  2],\n",
853
       "         [ 0,  0,  0,  0, ...,  5,  4,  5,  6],\n",
854
       "         [ 0,  0,  0,  0, ..., 13, 11, 11,  9],\n",
855
       "         [ 0,  0,  0,  0, ..., 13, 13, 16, 16],\n",
856
       "         ...,\n",
857
       "         [ 0,  0,  0,  0, ..., 17, 17, 14, 16],\n",
858
       "         [ 0,  0,  0,  0, ..., 16, 14, 16, 15],\n",
859
       "         [ 0,  0,  0,  0, ...,  9,  8, 10,  8],\n",
860
       "         [ 0,  0,  0,  0, ...,  6,  6,  6,  6]],\n",
861
       " \n",
862
       "        [[ 0,  0,  0,  0, ...,  5,  4,  4,  4],\n",
863
       "         [ 0,  0,  0,  0, ...,  9,  8,  7,  6],\n",
864
       "         [ 0,  0,  0,  0, ..., 11,  9, 10, 11],\n",
865
       "         [ 0,  0,  0,  0, ..., 17, 15, 12, 13],\n",
866
       "         ...,\n",
867
       "         [ 0,  0,  0,  0, ..., 12, 16, 16, 15],\n",
868
       "         [ 0,  0,  0,  0, ..., 16, 12, 15, 16],\n",
869
       "         [ 0,  0,  0,  0, ..., 10, 14, 12, 12],\n",
870
       "         [ 0,  0,  0,  0, ...,  6,  8,  7,  7]]], dtype=uint8),\n",
871
       " 'coronal': array([[[ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
872
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
873
       "         [ 0,  0,  0,  0, ...,  2,  2,  1,  2],\n",
874
       "         [ 0,  0,  0,  0, ...,  3,  3,  1,  2],\n",
875
       "         ...,\n",
876
       "         [ 0,  0,  0,  0, ...,  2,  3,  2,  2],\n",
877
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
878
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
879
       "         [ 0,  0,  0,  0, ...,  0,  0,  1,  0]],\n",
880
       " \n",
881
       "        [[ 0,  0,  0,  0, ...,  1,  1,  1,  0],\n",
882
       "         [ 0,  0,  0,  0, ...,  2,  1,  1,  1],\n",
883
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
884
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
885
       "         ...,\n",
886
       "         [ 0,  0,  0,  0, ...,  3,  3,  2,  2],\n",
887
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
888
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  2],\n",
889
       "         [ 0,  0,  0,  0, ...,  0,  1,  1,  1]],\n",
890
       " \n",
891
       "        [[ 0,  0,  0,  0, ...,  1,  1,  0,  1],\n",
892
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
893
       "         [ 0,  0,  0,  0, ...,  2,  2,  1,  2],\n",
894
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
895
       "         ...,\n",
896
       "         [ 0,  0,  0,  0, ...,  3,  2,  1,  2],\n",
897
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
898
       "         [ 0,  0,  0,  0, ...,  2,  1,  1,  2],\n",
899
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1]],\n",
900
       " \n",
901
       "        [[ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
902
       "         [ 0,  0,  0,  0, ...,  2,  1,  1,  1],\n",
903
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
904
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
905
       "         ...,\n",
906
       "         [ 0,  0,  0,  0, ...,  2,  3,  2,  2],\n",
907
       "         [ 0,  0,  0,  0, ...,  3,  2,  1,  2],\n",
908
       "         [ 0,  0,  0,  0, ...,  2,  1,  1,  1],\n",
909
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1]],\n",
910
       " \n",
911
       "        ...,\n",
912
       " \n",
913
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
914
       "         [ 0,  0,  0,  0, ...,  1,  1,  2,  2],\n",
915
       "         [ 0,  0,  0,  0, ...,  3,  3,  3,  3],\n",
916
       "         [ 0,  0,  0,  0, ...,  5,  4,  4,  5],\n",
917
       "         ...,\n",
918
       "         [ 0,  0,  0,  0, ...,  5,  5,  4,  5],\n",
919
       "         [ 0,  0,  0,  0, ...,  3,  3,  3,  3],\n",
920
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  1],\n",
921
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
922
       " \n",
923
       "        [[ 0,  0,  0,  0, ...,  1,  0,  0,  1],\n",
924
       "         [ 0,  0,  0,  0, ...,  2,  2,  1,  2],\n",
925
       "         [ 0,  0,  0,  0, ...,  4,  5,  4,  4],\n",
926
       "         [ 0,  0,  0,  0, ...,  7,  8,  8,  7],\n",
927
       "         ...,\n",
928
       "         [ 0,  0,  0,  0, ...,  7,  7,  6,  6],\n",
929
       "         [ 0,  0,  0,  0, ...,  5,  3,  3,  5],\n",
930
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
931
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
932
       " \n",
933
       "        [[ 0,  0,  0,  0, ...,  0,  0,  1,  1],\n",
934
       "         [ 0,  0,  0,  0, ...,  2,  3,  3,  3],\n",
935
       "         [ 0,  0,  0,  0, ...,  4,  4,  5,  4],\n",
936
       "         [ 0,  0,  0,  0, ...,  9,  6,  8, 10],\n",
937
       "         ...,\n",
938
       "         [ 0,  0,  0,  0, ...,  8,  9,  8, 11],\n",
939
       "         [ 0,  0,  0,  0, ...,  5,  6,  3,  4],\n",
940
       "         [ 0,  0,  0,  0, ...,  0,  0,  1,  0],\n",
941
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
942
       " \n",
943
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
944
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  2],\n",
945
       "         [ 0,  0,  0,  0, ...,  3,  3,  5,  5],\n",
946
       "         [ 0,  0,  0,  0, ...,  9, 10,  7,  6],\n",
947
       "         ...,\n",
948
       "         [ 0,  0,  0,  0, ...,  9,  9,  9,  8],\n",
949
       "         [ 0,  0,  0,  0, ...,  5,  5,  4,  4],\n",
950
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
951
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]]], dtype=uint8),\n",
952
       " 'sagittal': array([[[ 0,  0,  0,  0, ...,  2,  1,  1,  2],\n",
953
       "         [ 0,  0,  0,  0, ...,  8,  8,  8,  7],\n",
954
       "         [ 0,  0,  0,  0, ..., 11, 15, 16, 14],\n",
955
       "         [ 7,  5,  5,  7, ..., 15, 14, 19, 16],\n",
956
       "         ...,\n",
957
       "         [ 7,  5,  7,  5, ..., 15, 17, 15, 14],\n",
958
       "         [ 0,  1,  1,  1, ..., 10, 15, 12, 11],\n",
959
       "         [ 0,  0,  0,  0, ...,  7,  7,  6,  5],\n",
960
       "         [ 0,  0,  0,  0, ...,  3,  2,  2,  1]],\n",
961
       " \n",
962
       "        [[ 0,  0,  0,  0, ...,  4,  3,  1,  1],\n",
963
       "         [ 0,  0,  0,  0, ..., 10,  9,  7,  8],\n",
964
       "         [ 0,  0,  0,  0, ..., 17, 17, 15, 17],\n",
965
       "         [ 7,  5,  5,  6, ..., 20, 22, 21, 17],\n",
966
       "         ...,\n",
967
       "         [ 6,  7,  6,  4, ..., 18, 14, 19, 19],\n",
968
       "         [ 0,  1,  0,  2, ..., 16, 13, 15, 16],\n",
969
       "         [ 0,  0,  0,  0, ..., 10,  8,  5,  6],\n",
970
       "         [ 0,  0,  0,  0, ...,  2,  1,  1,  4]],\n",
971
       " \n",
972
       "        [[ 0,  0,  0,  0, ...,  3,  4,  3,  4],\n",
973
       "         [ 0,  0,  0,  0, ...,  9, 12,  9, 10],\n",
974
       "         [ 0,  0,  0,  0, ..., 19, 20, 12, 11],\n",
975
       "         [ 7,  6,  7,  6, ..., 19, 17, 13, 16],\n",
976
       "         ...,\n",
977
       "         [ 3,  6,  6,  4, ..., 27, 27, 14, 19],\n",
978
       "         [ 2,  3,  1,  1, ..., 23, 19, 14, 15],\n",
979
       "         [ 0,  0,  0,  0, ...,  9,  4, 11,  9],\n",
980
       "         [ 0,  0,  0,  0, ...,  2,  1,  2,  3]],\n",
981
       " \n",
982
       "        [[ 0,  0,  0,  0, ...,  3,  6,  4,  3],\n",
983
       "         [ 0,  0,  0,  0, ...,  8, 10,  8, 10],\n",
984
       "         [ 0,  0,  0,  0, ..., 21, 18, 14, 18],\n",
985
       "         [ 5,  9,  6,  6, ..., 26, 19, 18, 22],\n",
986
       "         ...,\n",
987
       "         [ 6,  4,  6,  9, ..., 22, 23, 24, 25],\n",
988
       "         [ 1,  1,  2,  1, ..., 14, 17, 19, 17],\n",
989
       "         [ 0,  0,  0,  0, ..., 10,  8,  7,  9],\n",
990
       "         [ 0,  0,  0,  0, ...,  2,  2,  2,  3]],\n",
991
       " \n",
992
       "        ...,\n",
993
       " \n",
994
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
995
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
996
       "         [ 0,  0,  0,  0, ...,  3,  3,  3,  2],\n",
997
       "         [ 0,  0,  0,  0, ..., 10,  8, 10, 10],\n",
998
       "         ...,\n",
999
       "         [ 0,  0,  0,  3, ...,  7, 10,  8,  7],\n",
1000
       "         [ 0,  0,  0,  0, ...,  4,  4,  5,  4],\n",
1001
       "         [ 0,  0,  0,  0, ...,  2,  2,  1,  1],\n",
1002
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
1003
       " \n",
1004
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
1005
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
1006
       "         [ 0,  0,  0,  0, ...,  3,  3,  3,  2],\n",
1007
       "         [ 0,  0,  0,  0, ..., 10, 11,  9, 11],\n",
1008
       "         ...,\n",
1009
       "         [ 0,  0,  0,  4, ..., 10,  9, 10,  9],\n",
1010
       "         [ 0,  0,  0,  1, ...,  5,  6,  5,  6],\n",
1011
       "         [ 0,  0,  0,  0, ...,  3,  2,  1,  2],\n",
1012
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
1013
       " \n",
1014
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
1015
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
1016
       "         [ 0,  0,  0,  0, ...,  3,  2,  2,  2],\n",
1017
       "         [ 0,  0,  0,  1, ..., 11, 12, 11,  9],\n",
1018
       "         ...,\n",
1019
       "         [ 0,  0,  0,  3, ...,  9, 11, 10,  9],\n",
1020
       "         [ 0,  0,  0,  1, ...,  5,  6,  6,  5],\n",
1021
       "         [ 0,  0,  0,  0, ...,  2,  3,  2,  1],\n",
1022
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]],\n",
1023
       " \n",
1024
       "        [[ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
1025
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0],\n",
1026
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
1027
       "         [ 0,  0,  0,  0, ...,  8, 10,  8,  9],\n",
1028
       "         ...,\n",
1029
       "         [ 0,  0,  0,  2, ...,  7,  9,  8,  8],\n",
1030
       "         [ 0,  0,  0,  0, ...,  4,  4,  4,  5],\n",
1031
       "         [ 0,  0,  0,  0, ...,  1,  1,  1,  1],\n",
1032
       "         [ 0,  0,  0,  0, ...,  0,  0,  0,  0]]], dtype=uint8)}"
1033
      ]
1034
     },
1035
     "execution_count": 9,
1036
     "metadata": {},
1037
     "output_type": "execute_result"
1038
    }
1039
   ],
1040
   "source": [
1041
    "x_multi = load_stacks(case)\n",
1042
    "x_multi"
1043
   ]
1044
  },
1045
  {
1046
   "cell_type": "code",
1047
   "execution_count": 30,
1048
   "metadata": {},
1049
   "outputs": [],
1050
   "source": [
1051
    "from ipywidgets import interactive\n",
1052
    "from IPython.display import display\n",
1053
    "\n",
1054
    "plt.style.use('grayscale')\n",
1055
    "\n",
1056
    "class KneePlot():\n",
1057
    "    def __init__(self, x, figsize=(10, 10)):\n",
1058
    "        self.x = x\n",
1059
    "        self.slice_range = (0, self.x.shape[0] - 1)\n",
1060
    "        self.resize(figsize)\n",
1061
    "    \n",
1062
    "    def _plot_slice(self, im_slice):\n",
1063
    "        fig, ax = plt.subplots(1, 1, figsize=self.figsize)\n",
1064
    "        ax.imshow(self.x[im_slice, :, :])\n",
1065
    "        plt.show()\n",
1066
    "\n",
1067
    "    def resize(self, figsize):\n",
1068
    "        self.figsize = figsize\n",
1069
    "        self.interactive_plot = interactive(self._plot_slice, im_slice=self.slice_range)\n",
1070
    "        self.output = self.interactive_plot.children[-1]\n",
1071
    "        self.output.layout.height = '{}px'.format(60 * self.figsize[1])\n",
1072
    "\n",
1073
    "    def show(self):\n",
1074
    "        display(self.interactive_plot)\n"
1075
   ]
1076
  },
1077
  {
1078
   "cell_type": "code",
1079
   "execution_count": 31,
1080
   "metadata": {},
1081
   "outputs": [
1082
    {
1083
     "data": {
1084
      "application/vnd.jupyter.widget-view+json": {
1085
       "model_id": "bb94b9e5a31b44c5abd287a8cdb12fe9",
1086
       "version_major": 2,
1087
       "version_minor": 0
1088
      },
1089
      "text/plain": [
1090
       "interactive(children=(IntSlider(value=17, description='im_slice', max=35), Output(layout=Layout(height='600px'…"
1091
      ]
1092
     },
1093
     "metadata": {},
1094
     "output_type": "display_data"
1095
    }
1096
   ],
1097
   "source": [
1098
    "plot = KneePlot(x)\n",
1099
    "plot.show()\n"
1100
   ]
1101
  },
1102
  {
1103
   "cell_type": "code",
1104
   "execution_count": 12,
1105
   "metadata": {},
1106
   "outputs": [
1107
    {
1108
     "data": {
1109
      "application/vnd.jupyter.widget-view+json": {
1110
       "model_id": "4d55829c45294dbd90f82665f799a8ca",
1111
       "version_major": 2,
1112
       "version_minor": 0
1113
      },
1114
      "text/plain": [
1115
       "interactive(children=(IntSlider(value=17, description='im_slice', max=35), Output(layout=Layout(height='720px'…"
1116
      ]
1117
     },
1118
     "metadata": {},
1119
     "output_type": "display_data"
1120
    }
1121
   ],
1122
   "source": [
1123
    "plot.resize(figsize=(12, 12))\n",
1124
    "plot.show()\n"
1125
   ]
1126
  },
1127
  {
1128
   "cell_type": "code",
1129
   "execution_count": 15,
1130
   "metadata": {},
1131
   "outputs": [],
1132
   "source": [
1133
    "from ipywidgets import interact, Dropdown, IntSlider\n",
1134
    "\n",
1135
    "class MultiKneePlot():\n",
1136
    "    def __init__(self, x_multi, figsize=(10, 10)):\n",
1137
    "        self.x = x_multi\n",
1138
    "        self.planes = ['coronal', 'sagittal', 'axial']\n",
1139
    "        self.slice_nums = {plane: self.x[plane].shape[0] for plane in self.planes}\n",
1140
    "        self.figsize = figsize\n",
1141
    "    \n",
1142
    "    def _plot_slices(self, plane, im_slice): \n",
1143
    "        fig, ax = plt.subplots(1, 1, figsize=self.figsize)\n",
1144
    "        ax.imshow(self.x[plane][im_slice, :, :])\n",
1145
    "        plt.show()\n",
1146
    "    \n",
1147
    "    def draw(self):\n",
1148
    "        planes_widget = Dropdown(options=self.planes)\n",
1149
    "        plane_init = self.planes[0]\n",
1150
    "        slice_init = self.slice_nums[plane_init] - 1\n",
1151
    "        slices_widget = IntSlider(min=0, max=slice_init, value=slice_init//2)\n",
1152
    "        def update_slices_widget(*args):\n",
1153
    "            slices_widget.max = self.slice_nums[planes_widget.value] - 1\n",
1154
    "            slices_widget.value = slices_widget.max // 2\n",
1155
    "        planes_widget.observe(update_slices_widget, 'value')\n",
1156
    "        interact(self._plot_slices, plane=planes_widget, im_slice=slices_widget)\n",
1157
    "    \n",
1158
    "    def resize(self, figsize): self.figsize = figsize\n"
1159
   ]
1160
  },
1161
  {
1162
   "cell_type": "code",
1163
   "execution_count": 16,
1164
   "metadata": {},
1165
   "outputs": [
1166
    {
1167
     "data": {
1168
      "application/vnd.jupyter.widget-view+json": {
1169
       "model_id": "0d7a1c151c3440c59bfa0a2b55e6180e",
1170
       "version_major": 2,
1171
       "version_minor": 0
1172
      },
1173
      "text/plain": [
1174
       "interactive(children=(Dropdown(description='plane', options=('coronal', 'sagittal', 'axial'), value='coronal')…"
1175
      ]
1176
     },
1177
     "metadata": {},
1178
     "output_type": "display_data"
1179
    }
1180
   ],
1181
   "source": [
1182
    "plot_multi = MultiKneePlot(x_multi)\n",
1183
    "plot_multi.draw()"
1184
   ]
1185
  },
1186
  {
1187
   "cell_type": "code",
1188
   "execution_count": null,
1189
   "metadata": {},
1190
   "outputs": [],
1191
   "source": []
1192
  }
1193
 ],
1194
 "metadata": {
1195
  "kernelspec": {
1196
   "display_name": "Python 3",
1197
   "language": "python",
1198
   "name": "python3"
1199
  },
1200
  "language_info": {
1201
   "codemirror_mode": {
1202
    "name": "ipython",
1203
    "version": 3
1204
   },
1205
   "file_extension": ".py",
1206
   "mimetype": "text/x-python",
1207
   "name": "python",
1208
   "nbconvert_exporter": "python",
1209
   "pygments_lexer": "ipython3",
1210
   "version": "3.7.2"
1211
  }
1212
 },
1213
 "nbformat": 4,
1214
 "nbformat_minor": 2
1215
}