a b/ipynb/classifier.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": 1,
6
   "metadata": {},
7
   "outputs": [
8
    {
9
     "data": {
10
      "text/plain": [
11
       "True"
12
      ]
13
     },
14
     "execution_count": 1,
15
     "metadata": {},
16
     "output_type": "execute_result"
17
    }
18
   ],
19
   "source": [
20
    "import numpy as np\n",
21
    "import pandas as pd\n",
22
    "from sklearn.model_selection import train_test_split\n",
23
    "from fastai import *\n",
24
    "from fastai.basic_data import *\n",
25
    "from fastai.basic_train import *\n",
26
    "from fastai.tabular import *\n",
27
    "from torch import nn\n",
28
    "import torch.nn.functional as F\n",
29
    "from torch.utils.data import Dataset\n",
30
    "\n",
31
    "torch.cuda.is_available()"
32
   ]
33
  },
34
  {
35
   "cell_type": "markdown",
36
   "metadata": {},
37
   "source": [
38
    "### Load ksent vectors of random genome samples of 16kb"
39
   ]
40
  },
41
  {
42
   "cell_type": "code",
43
   "execution_count": null,
44
   "metadata": {},
45
   "outputs": [],
46
   "source": [
47
    "df = pd.read_pickle(\"../data/random-ksent.pkl\")"
48
   ]
49
  },
50
  {
51
   "cell_type": "markdown",
52
   "metadata": {},
53
   "source": [
54
    "#### Look at the data"
55
   ]
56
  },
57
  {
58
   "cell_type": "code",
59
   "execution_count": 2,
60
   "metadata": {
61
    "scrolled": false
62
   },
63
   "outputs": [
64
    {
65
     "data": {
66
      "text/html": [
67
       "<div>\n",
68
       "<style scoped>\n",
69
       "    .dataframe tbody tr th:only-of-type {\n",
70
       "        vertical-align: middle;\n",
71
       "    }\n",
72
       "\n",
73
       "    .dataframe tbody tr th {\n",
74
       "        vertical-align: top;\n",
75
       "    }\n",
76
       "\n",
77
       "    .dataframe thead th {\n",
78
       "        text-align: right;\n",
79
       "    }\n",
80
       "</style>\n",
81
       "<table border=\"1\" class=\"dataframe\">\n",
82
       "  <thead>\n",
83
       "    <tr style=\"text-align: right;\">\n",
84
       "      <th></th>\n",
85
       "      <th>spicies</th>\n",
86
       "      <th>ksent</th>\n",
87
       "    </tr>\n",
88
       "  </thead>\n",
89
       "  <tbody>\n",
90
       "    <tr>\n",
91
       "      <th>0</th>\n",
92
       "      <td>subtilis</td>\n",
93
       "      <td>[-0.14380797743797302, 0.042256396263837814, 0...</td>\n",
94
       "    </tr>\n",
95
       "    <tr>\n",
96
       "      <th>1</th>\n",
97
       "      <td>subtilis</td>\n",
98
       "      <td>[-0.05331391096115112, 0.06979767978191376, 0....</td>\n",
99
       "    </tr>\n",
100
       "    <tr>\n",
101
       "      <th>2</th>\n",
102
       "      <td>velezensis</td>\n",
103
       "      <td>[-0.0810142382979393, 0.0927421897649765, 0.15...</td>\n",
104
       "    </tr>\n",
105
       "    <tr>\n",
106
       "      <th>3</th>\n",
107
       "      <td>velezensis</td>\n",
108
       "      <td>[-0.11013756692409515, 0.08658633381128311, 0....</td>\n",
109
       "    </tr>\n",
110
       "    <tr>\n",
111
       "      <th>4</th>\n",
112
       "      <td>anthracis</td>\n",
113
       "      <td>[-0.01958618313074112, 0.11950772255659103, 0....</td>\n",
114
       "    </tr>\n",
115
       "  </tbody>\n",
116
       "</table>\n",
117
       "</div>"
118
      ],
119
      "text/plain": [
120
       "      spicies                                              ksent\n",
121
       "0    subtilis  [-0.14380797743797302, 0.042256396263837814, 0...\n",
122
       "1    subtilis  [-0.05331391096115112, 0.06979767978191376, 0....\n",
123
       "2  velezensis  [-0.0810142382979393, 0.0927421897649765, 0.15...\n",
124
       "3  velezensis  [-0.11013756692409515, 0.08658633381128311, 0....\n",
125
       "4   anthracis  [-0.01958618313074112, 0.11950772255659103, 0...."
126
      ]
127
     },
128
     "execution_count": 2,
129
     "metadata": {},
130
     "output_type": "execute_result"
131
    }
132
   ],
133
   "source": [
134
    "df.head()"
135
   ]
136
  },
137
  {
138
   "cell_type": "code",
139
   "execution_count": 19,
140
   "metadata": {},
141
   "outputs": [
142
    {
143
     "data": {
144
      "text/plain": [
145
       "(690000, 2)"
146
      ]
147
     },
148
     "execution_count": 19,
149
     "metadata": {},
150
     "output_type": "execute_result"
151
    }
152
   ],
153
   "source": [
154
    "df.shape"
155
   ]
156
  },
157
  {
158
   "cell_type": "code",
159
   "execution_count": 18,
160
   "metadata": {},
161
   "outputs": [
162
    {
163
     "data": {
164
      "text/plain": [
165
       "subtilis      321000\n",
166
       "velezensis    216000\n",
167
       "anthracis     153000\n",
168
       "Name: spicies, dtype: int64"
169
      ]
170
     },
171
     "execution_count": 18,
172
     "metadata": {},
173
     "output_type": "execute_result"
174
    }
175
   ],
176
   "source": [
177
    "df.spicies.value_counts()"
178
   ]
179
  },
180
  {
181
   "cell_type": "markdown",
182
   "metadata": {},
183
   "source": [
184
    "###  Databunch"
185
   ]
186
  },
187
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out a random 20% of the rows for validation. A dedicated, seeded RNG\n",
    "# keeps the split reproducible without touching the global `random` state.\n",
    "VALID_PCT = 0.2\n",
    "SPLIT_SEED = 42\n",
    "\n",
    "n_rows = len(df)\n",
    "valid_idx = random.Random(SPLIT_SEED).sample(range(n_rows), int(n_rows * VALID_PCT))\n",
    "\n",
    "# Items come from the `ksent` column, labels from the `spicies` column.\n",
    "db = (ItemList.from_df(df, cols=\"ksent\")\n",
    "      .split_by_idx(valid_idx)\n",
    "      .label_from_df(cols=\"spicies\")\n",
    "      .databunch())"
   ]
  },
201
  {
202
   "cell_type": "markdown",
203
   "metadata": {},
204
   "source": [
205
    "###  Model"
206
   ]
207
  },
208
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def submodel(dims, bias=False):\n",
    "    \"\"\"Build a fully connected stack: Linear -> ReLU -> ... -> Linear.\n",
    "\n",
    "    Args:\n",
    "        dims: layer widths; each consecutive pair becomes `nn.Linear(n_in, n_out)`.\n",
    "        bias: if True, every Linear layer gets a bias initialised from N(0, 1).\n",
    "\n",
    "    Returns:\n",
    "        An `nn.Sequential` with a ReLU between Linear layers and none after the\n",
    "        last one; an empty `nn.Sequential` when `dims` has fewer than two entries.\n",
    "    \"\"\"\n",
    "    linears = [nn.Linear(n_in, n_out, bias=bias)\n",
    "               for n_in, n_out in zip(dims[:-1], dims[1:])]\n",
    "    for linear in linears:\n",
    "        nn.init.xavier_uniform_(linear.weight)\n",
    "    if bias:\n",
    "        for linear in linears:\n",
    "            nn.init.normal_(linear.bias, mean=0.0, std=1.0)\n",
    "\n",
    "    # Interleave Linear and ReLU in a plain list, then drop the trailing ReLU so\n",
    "    # the stack ends on a Linear layer.\n",
    "    layers = []\n",
    "    for linear in linears:\n",
    "        layers += [linear, nn.ReLU()]\n",
    "    return nn.Sequential(*layers[:-1])\n"
   ]
  },
225
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Classifier(nn.Module):\n",
    "    \"\"\"Two-stage network: an encoder followed by a classifier head.\n",
    "\n",
    "    Both stages are built by `submodel` from a list of layer widths. The encoder\n",
    "    output is fed straight into the classifier, so the last entry of\n",
    "    `encoder_dims` must equal the first entry of `classifier_dims`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, encoder_dims, classifier_dims):\n",
    "        super().__init__()\n",
    "        self.encoder = submodel(encoder_dims, bias=True)\n",
    "        self.classifier = submodel(classifier_dims, bias=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return raw class scores (logits), one row per sample.\n",
    "\n",
    "        No softmax is applied here: the learner's CrossEntropyLoss performs\n",
    "        log-softmax itself, so normalising inside the model would squash the\n",
    "        scores twice and cap how low the loss can go. Apply\n",
    "        `F.softmax(logits, dim=1)` on the result when probabilities are needed.\n",
    "        \"\"\"\n",
    "        return self.classifier(self.encoder(x))\n",
    "\n",
    "    def save_encoder(self, file:PathOrStr):\n",
    "        \"\"\"Write the encoder's `state_dict` to `file` with `torch.save`.\"\"\"\n",
    "        torch.save(self.encoder.state_dict(), file)\n"
   ]
  },
247
  {
248
   "cell_type": "code",
249
   "execution_count": 8,
250
   "metadata": {
251
    "scrolled": true
252
   },
253
   "outputs": [],
254
   "source": [
255
    "model = Classifier([100,50,3], [3,20,3]).double()"
256
   ]
257
  },
258
  {
259
   "cell_type": "code",
260
   "execution_count": 9,
261
   "metadata": {},
262
   "outputs": [
263
    {
264
     "data": {
265
      "text/plain": [
266
       "Classifier(\n",
267
       "  (encoder): Sequential(\n",
268
       "    (0): Linear(in_features=100, out_features=50, bias=True)\n",
269
       "    (1): ReLU()\n",
270
       "    (2): Linear(in_features=50, out_features=3, bias=True)\n",
271
       "  )\n",
272
       "  (classifier): Sequential(\n",
273
       "    (0): Linear(in_features=3, out_features=20, bias=True)\n",
274
       "    (1): ReLU()\n",
275
       "    (2): Linear(in_features=20, out_features=3, bias=True)\n",
276
       "  )\n",
277
       ")"
278
      ]
279
     },
280
     "execution_count": 9,
281
     "metadata": {},
282
     "output_type": "execute_result"
283
    }
284
   ],
285
   "source": [
286
    "model"
287
   ]
288
  },
289
  {
290
   "cell_type": "markdown",
291
   "metadata": {},
292
   "source": [
293
    "### Learner"
294
   ]
295
  },
296
  {
297
   "cell_type": "code",
298
   "execution_count": 12,
299
   "metadata": {},
300
   "outputs": [
301
    {
302
     "data": {
303
      "text/plain": [
304
       "FlattenedLoss of CrossEntropyLoss()"
305
      ]
306
     },
307
     "execution_count": 12,
308
     "metadata": {},
309
     "output_type": "execute_result"
310
    }
311
   ],
312
   "source": [
313
    "learn = Learner(db, model,metrics=[accuracy])\n",
314
    "learn.loss_func"
315
   ]
316
  },
317
  {
318
   "cell_type": "code",
319
   "execution_count": 13,
320
   "metadata": {
321
    "scrolled": true
322
   },
323
   "outputs": [
324
    {
325
     "data": {
326
      "text/html": [],
327
      "text/plain": [
328
       "<IPython.core.display.HTML object>"
329
      ]
330
     },
331
     "metadata": {},
332
     "output_type": "display_data"
333
    },
334
    {
335
     "name": "stdout",
336
     "output_type": "stream",
337
     "text": [
338
      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
339
     ]
340
    },
341
    {
342
     "data": {
343
      "image/png": "\n",
344
      "text/plain": [
345
       "<Figure size 432x288 with 1 Axes>"
346
      ]
347
     },
348
     "metadata": {
349
      "needs_background": "light"
350
     },
351
     "output_type": "display_data"
352
    }
353
   ],
354
   "source": [
355
    "learn.lr_find()\n",
356
    "learn.recorder.plot()"
357
   ]
358
  },
359
  {
360
   "cell_type": "code",
361
   "execution_count": 14,
362
   "metadata": {
363
    "scrolled": true
364
   },
365
   "outputs": [
366
    {
367
     "data": {
368
      "text/html": [
369
       "<table border=\"1\" class=\"dataframe\">\n",
370
       "  <thead>\n",
371
       "    <tr style=\"text-align: left;\">\n",
372
       "      <th>epoch</th>\n",
373
       "      <th>train_loss</th>\n",
374
       "      <th>valid_loss</th>\n",
375
       "      <th>accuracy</th>\n",
376
       "      <th>time</th>\n",
377
       "    </tr>\n",
378
       "  </thead>\n",
379
       "  <tbody>\n",
380
       "    <tr>\n",
381
       "      <td>0</td>\n",
382
       "      <td>0.588498</td>\n",
383
       "      <td>0.583153</td>\n",
384
       "      <td>0.968529</td>\n",
385
       "      <td>00:35</td>\n",
386
       "    </tr>\n",
387
       "    <tr>\n",
388
       "      <td>1</td>\n",
389
       "      <td>0.593730</td>\n",
390
       "      <td>0.606674</td>\n",
391
       "      <td>0.943275</td>\n",
392
       "      <td>00:35</td>\n",
393
       "    </tr>\n",
394
       "    <tr>\n",
395
       "      <td>2</td>\n",
396
       "      <td>0.594737</td>\n",
397
       "      <td>0.589081</td>\n",
398
       "      <td>0.962116</td>\n",
399
       "      <td>00:35</td>\n",
400
       "    </tr>\n",
401
       "    <tr>\n",
402
       "      <td>3</td>\n",
403
       "      <td>0.600526</td>\n",
404
       "      <td>0.583841</td>\n",
405
       "      <td>0.967457</td>\n",
406
       "      <td>00:36</td>\n",
407
       "    </tr>\n",
408
       "    <tr>\n",
409
       "      <td>4</td>\n",
410
       "      <td>0.591394</td>\n",
411
       "      <td>0.605085</td>\n",
412
       "      <td>0.945500</td>\n",
413
       "      <td>00:35</td>\n",
414
       "    </tr>\n",
415
       "    <tr>\n",
416
       "      <td>5</td>\n",
417
       "      <td>0.586290</td>\n",
418
       "      <td>0.582825</td>\n",
419
       "      <td>0.968377</td>\n",
420
       "      <td>00:36</td>\n",
421
       "    </tr>\n",
422
       "    <tr>\n",
423
       "      <td>6</td>\n",
424
       "      <td>0.584365</td>\n",
425
       "      <td>0.580657</td>\n",
426
       "      <td>0.970710</td>\n",
427
       "      <td>00:36</td>\n",
428
       "    </tr>\n",
429
       "    <tr>\n",
430
       "      <td>7</td>\n",
431
       "      <td>0.586121</td>\n",
432
       "      <td>0.583098</td>\n",
433
       "      <td>0.968109</td>\n",
434
       "      <td>00:35</td>\n",
435
       "    </tr>\n",
436
       "    <tr>\n",
437
       "      <td>8</td>\n",
438
       "      <td>0.580480</td>\n",
439
       "      <td>0.580342</td>\n",
440
       "      <td>0.970928</td>\n",
441
       "      <td>00:36</td>\n",
442
       "    </tr>\n",
443
       "    <tr>\n",
444
       "      <td>9</td>\n",
445
       "      <td>0.578384</td>\n",
446
       "      <td>0.579312</td>\n",
447
       "      <td>0.972043</td>\n",
448
       "      <td>00:37</td>\n",
449
       "    </tr>\n",
450
       "  </tbody>\n",
451
       "</table>"
452
      ],
453
      "text/plain": [
454
       "<IPython.core.display.HTML object>"
455
      ]
456
     },
457
     "metadata": {},
458
     "output_type": "display_data"
459
    }
460
   ],
461
   "source": [
462
    "learn.fit_one_cycle(10,1e-2)"
463
   ]
464
  },
465
  {
466
   "cell_type": "code",
467
   "execution_count": 16,
468
   "metadata": {
469
    "scrolled": true
470
   },
471
   "outputs": [
472
    {
473
     "data": {
474
      "image/png": "\n",
475
      "text/plain": [
476
       "<Figure size 432x288 with 1 Axes>"
477
      ]
478
     },
479
     "metadata": {
480
      "needs_background": "light"
481
     },
482
     "output_type": "display_data"
483
    }
484
   ],
485
   "source": [
486
    "learn.recorder.plot_metrics()"
487
   ]
488
  },
489
  {
490
   "cell_type": "code",
491
   "execution_count": 17,
492
   "metadata": {},
493
   "outputs": [],
494
   "source": [
495
    "interpretation = learn.interpret()"
496
   ]
497
  },
498
  {
499
   "cell_type": "code",
500
   "execution_count": 30,
501
   "metadata": {},
502
   "outputs": [
503
    {
504
     "data": {
505
      "image/png": "\n",
506
      "text/plain": [
507
       "<Figure size 432x288 with 1 Axes>"
508
      ]
509
     },
510
     "metadata": {
511
      "needs_background": "light"
512
     },
513
     "output_type": "display_data"
514
    }
515
   ],
516
   "source": [
517
    "interpretation.plot_confusion_matrix()"
518
   ]
519
  },
520
  {
521
   "cell_type": "code",
522
   "execution_count": null,
523
   "metadata": {},
524
   "outputs": [],
525
   "source": []
526
  }
527
 ],
528
 "metadata": {
529
  "kernelspec": {
530
   "display_name": "Python 3",
531
   "language": "python",
532
   "name": "python3"
533
  },
534
  "language_info": {
535
   "codemirror_mode": {
536
    "name": "ipython",
537
    "version": 3
538
   },
539
   "file_extension": ".py",
540
   "mimetype": "text/x-python",
541
   "name": "python",
542
   "nbconvert_exporter": "python",
543
   "pygments_lexer": "ipython3",
544
   "version": "3.6.9"
545
  },
546
  "latex_envs": {
547
   "LaTeX_envs_menu_present": true,
548
   "autoclose": false,
549
   "autocomplete": true,
550
   "bibliofile": "biblio.bib",
551
   "cite_by": "apalike",
552
   "current_citInitial": 1,
553
   "eqLabelWithNumbers": true,
554
   "eqNumInitial": 1,
555
   "hotkeys": {
556
    "equation": "Ctrl-E",
557
    "itemize": "Ctrl-I"
558
   },
559
   "labels_anchors": false,
560
   "latex_user_defs": false,
561
   "report_style_numbering": false,
562
   "user_envs_cfg": false
563
  },
564
  "notify_time": "30",
565
  "toc": {
566
   "base_numbering": 1,
567
   "nav_menu": {},
568
   "number_sections": true,
569
   "sideBar": true,
570
   "skip_h1_title": false,
571
   "title_cell": "Table of Contents",
572
   "title_sidebar": "Contents",
573
   "toc_cell": false,
574
   "toc_position": {},
575
   "toc_section_display": true,
576
   "toc_window_display": false
577
  }
578
 },
579
 "nbformat": 4,
580
 "nbformat_minor": 2
581
}