Diff of /notebooks/Code.ipynb [000000] .. [53d15f]

a b/notebooks/Code.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": 28,
6
   "metadata": {},
7
   "outputs": [],
8
   "source": [
9
    "#conda activate torch-xla-nightly\n",
10
    "#export XRT_TPU_CONFIG=\"tpu_worker;0;$10.0.101.2:8470\"\n",
11
    "#git init\n",
12
    "#git remote add origin https://github.com/nosound2/RSNA-Hemorrhage\n",
13
    "#git config remote.origin.push HEAD\n",
14
    "#git config credential.helper store\n",
15
    "#git pull origin master\n",
16
    "#git config --global branch.master.remote origin\n",
17
    "#git config --global branch.master.merge refs/heads/master\n",
18
    "#gcloud config set compute/zone europe-west4-a\n",
19
    "#gcloud config set compute/zone us-central1-c\n",
20
    "#gcloud auth login\n",
21
    "#gcloud config set project endless-empire-239015\n",
22
    "#pip install kaggle --user\n",
23
    "#PATH=$PATH:~/.local/bin\n",
24
    "#mkdir .kaggle\n",
25
    "#gsutil cp gs://recursion-double-strand/kaggle-keys/kaggle.json ~/.kaggle\n",
26
    "#chmod 600 /home/zahar_chikishev/.kaggle/kaggle.json\n",
27
    "#kaggle competitions download rsna-intracranial-hemorrhage-detection -f stage_2_train.csv\n",
28
    "#sudo apt install unzip\n",
29
    "#unzip stage_1_train.csv.zip\n",
30
    "#kaggle kernels output xhlulu/rsna-generate-metadata-csvs -p .\n",
31
    "#kaggle kernels output zaharch/dicom-test-metadata-to-csv -p .\n",
32
    "#kaggle kernels output zaharch/dicom-metadata-to-csv -p .\n",
33
    "#gsutil -m cp gs://rsna-hemorrhage/yuvals/* .\n",
34
    "\n",
35
    "#export XRT_TPU_CONFIG=\"tpu_worker;0;10.0.101.2:8470\"; conda activate torch-xla-nightly; jupyter notebook\n",
36
    "\n",
37
    "# 35.188.114.109\n",
38
    "\n",
39
    "# lsblk\n",
40
    "# sudo mkfs.ext4 -m 0 -F -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb\n",
41
    "# sudo mkdir /mnt/edisk\n",
42
    "# sudo mount -o discard,defaults /dev/sdb /mnt/edisk\n",
43
    "# sudo chmod a+w /mnt/edisk\n",
44
    "# df -H\n",
45
    "# echo UUID=`sudo blkid -s UUID -o value /dev/sdb` /mnt/edisk ext4 discard,defaults,nofail 0 2 | sudo tee -a /etc/fstab\n",
46
    "# cat /etc/fstab\n",
47
    "\n",
48
    "#jupyter notebook --generate-config\n",
49
    "#vi ~/.jupyter/jupyter_notebook_config.py\n",
50
    "\n",
51
    "#c = get_config()\n",
52
    "#c.NotebookApp.ip = '*'\n",
53
    "#c.NotebookApp.open_browser = False\n",
54
    "#c.NotebookApp.port = 5000"
55
   ]
56
  },
57
  {
58
   "cell_type": "code",
59
   "execution_count": null,
60
   "metadata": {},
61
   "outputs": [],
62
   "source": []
63
  },
64
  {
65
   "cell_type": "code",
66
   "execution_count": 2,
67
   "metadata": {},
68
   "outputs": [],
69
   "source": [
70
    "import sys\n",
71
    "\n",
72
    "from pathlib import Path\n",
73
    "from PIL import ImageDraw, ImageFont, Image\n",
74
    "from matplotlib import patches, patheffects\n",
75
    "import time\n",
76
    "from random import randint\n",
77
    "import numpy as np\n",
78
    "import pandas as pd\n",
79
    "import pickle\n",
80
    "\n",
81
    "from sklearn.model_selection import KFold, StratifiedKFold, GroupKFold\n",
82
    "from sklearn.preprocessing import LabelEncoder, OneHotEncoder, LabelBinarizer\n",
83
    "from sklearn.preprocessing import StandardScaler\n",
84
    "from sklearn.metrics import mean_squared_error,log_loss,roc_auc_score\n",
85
    "from scipy.stats import ks_2samp\n",
86
    "\n",
87
    "import pdb\n",
88
    "\n",
89
    "import scipy as sp\n",
90
    "from tqdm import tqdm, tqdm_notebook\n",
91
    "\n",
92
    "import os\n",
93
    "import glob\n",
94
    "\n",
95
    "import torch\n",
96
    "\n",
97
    "#CLOUD = not torch.cuda.is_available()\n",
98
    "CLOUD = (torch.cuda.get_device_name(0) in ['Tesla P4', 'Tesla K80', 'Tesla P100-PCIE-16GB'])\n",
99
    "CLOUD_SINGLE = True\n",
100
    "TPU = False\n",
101
    "\n",
102
    "if not CLOUD:\n",
103
    "    torch.cuda.current_device()\n",
104
    "\n",
105
    "import torch.nn as nn\n",
106
    "import torch.utils.data as D\n",
107
    "import torch.nn.functional as F\n",
108
    "import torch.utils as U\n",
109
    "\n",
110
    "import torchvision\n",
111
    "from torchvision import transforms as T\n",
112
    "from torchvision import models as M\n",
113
    "\n",
114
    "import matplotlib.pyplot as plt\n",
115
    "\n",
116
    "if CLOUD:\n",
117
    "    PATH = Path('/home/zahar_chikishev/Hemorrhage')\n",
118
    "    PATH_WORK = Path('/mnt/edisk/running')\n",
119
    "    PATH_DISK = Path('/mnt/edisk/running')\n",
120
    "else:\n",
121
    "    PATH = Path('C:/StudioProjects/Hemorrhage')\n",
122
    "    PATH_WORK = Path('C:/StudioProjects/Hemorrhage/running')\n",
123
    "    PATH_DISK = PATH_WORK\n",
124
    "\n",
125
    "from collections import defaultdict, Counter\n",
126
    "import random\n",
127
    "import seaborn as sn\n",
128
    "\n",
129
    "pd.set_option(\"display.max_columns\", 100)\n",
130
    "\n",
131
    "all_ich = ['any','epidural','intraparenchymal','intraventricular','subarachnoid','subdural']\n",
132
    "class_weights = 6.0*np.array([2,1,1,1,1,1])/7.0\n",
133
    "\n",
134
    "if CLOUD and TPU:\n",
135
    "    import torch_xla\n",
136
    "    import torch_xla.distributed.data_parallel as dp\n",
137
    "    import torch_xla.utils as xu\n",
138
    "    import torch_xla.core.xla_model as xm\n",
139
    "\n",
140
    "from typing import Collection"
141
   ]
142
  },
143
  {
144
   "cell_type": "code",
145
   "execution_count": 3,
146
   "metadata": {},
147
   "outputs": [],
148
   "source": [
149
    "all_black = '006d4432e'\n",
150
    "all_black = '00bd6c59c'\n",
151
    "\n",
152
    "if CLOUD:\n",
153
    "    if TPU:\n",
154
    "        device = xm.xla_device()\n",
155
    "    else:\n",
156
    "        device = 'cuda'\n",
157
    "    #device = 'cpu'\n",
158
    "    MAX_DEVICES = 1 if CLOUD_SINGLE else 8\n",
159
    "    NUM_WORKERS = 16\n",
160
    "    bs = 32\n",
161
    "else:\n",
162
    "    device = 'cuda'\n",
163
    "    #device = 'cpu'\n",
164
    "    MAX_DEVICES = 1\n",
165
    "    NUM_WORKERS = 0\n",
166
    "    bs = 16\n",
167
    "\n",
168
    "#if CLOUD and (not CLOUD_SINGLE):\n",
169
    "#    devices = xm.get_xla_supported_devices(max_devices=MAX_DEVICES)"
170
   ]
171
  },
172
  {
173
   "cell_type": "code",
174
   "execution_count": 4,
175
   "metadata": {},
176
   "outputs": [],
177
   "source": [
178
    "SEED = 2351\n",
179
    "\n",
180
    "def setSeeds(seed):\n",
181
    "    np.random.seed(seed)\n",
182
    "    torch.manual_seed(seed)\n",
183
    "    torch.cuda.manual_seed(seed)\n",
184
    "\n",
185
    "setSeeds(SEED)\n",
186
    "torch.backends.cudnn.deterministic = True"
187
   ]
188
  },
189
  {
190
   "cell_type": "code",
191
   "execution_count": 5,
192
   "metadata": {},
193
   "outputs": [],
194
   "source": [
195
    "def getDSName(dataset):\n",
196
    "    if dataset == 6:\n",
197
    "        return 'Densenet201_F3'\n",
198
    "    elif dataset == 7:\n",
199
    "        return 'Densenet161_F3'\n",
200
    "    elif dataset == 8:\n",
201
    "        return 'Densenet169_F3'\n",
202
    "    elif dataset == 9:\n",
203
    "        return 'se_resnext101_32x4d_F3'\n",
204
    "    elif dataset == 10:\n",
205
    "        return 'se_resnet101_F3'\n",
206
    "    elif dataset == 11:\n",
207
    "        return 'se_resnext101_32x4d_F5'\n",
208
    "    elif dataset == 12:\n",
209
    "        return 'se_resnet101_F5'\n",
210
    "    elif dataset == 13:\n",
211
    "        return 'se_resnet101_focal_F5'\n",
212
    "    elif dataset == 14:\n",
213
    "        return 'se_resnet101_F5'\n",
214
    "    else: assert False"
215
   ]
216
  },
217
  {
218
   "cell_type": "code",
219
   "execution_count": 6,
220
   "metadata": {},
221
   "outputs": [],
222
   "source": [
223
    "def getDSParams(dataset):\n",
224
    "    if dataset == 6:\n",
225
    "        return 'Densenet201','_3','_v5',240,'','','classifier',''\n",
226
    "    elif dataset == 7:\n",
227
    "        return 'Densenet161','_3','_v4',552,'2','','classifier',''\n",
228
    "    elif dataset == 8:\n",
229
    "        return 'Densenet169','_3','_v3',208,'2','','classifier',''\n",
230
    "    elif dataset == 9:\n",
231
    "        return 'se_resnext101_32x4d','','',256,'','_tta','classifier',''\n",
232
    "    elif dataset == 10:\n",
233
    "        return 'se_resnet101','','',256,'','','classifier',''\n",
234
    "    elif dataset == 11:\n",
235
    "        return 'se_resnext101_32x4d_5','','',256,'','','new',''\n",
236
    "    elif dataset == 12:\n",
237
    "        return 'se_resnet101_5','','',256,'','','new',''\n",
238
    "    elif dataset == 13:\n",
239
    "        return 'se_resnet101_5f','','',256,'','','new','focal_'\n",
240
    "    elif dataset == 14:\n",
241
    "        return 'se_resnet101_5n','','',256,'','','new','stage2_'\n",
242
    "    else: assert False\n",
243
    "\n",
244
    "def getNFolds(dataset):\n",
245
    "    if dataset <= 10:\n",
246
    "        return 3\n",
247
    "    elif dataset <= 14:\n",
248
    "        return 5\n",
249
    "    else: assert False\n",
250
    "\n",
251
    "my_datasets3 = [7,9]\n",
252
    "my_datasets5 = [11,12,13]\n",
253
    "my_len = len(my_datasets3) + len(my_datasets5)"
254
   ]
255
  },
256
  {
257
   "cell_type": "code",
258
   "execution_count": 7,
259
   "metadata": {},
260
   "outputs": [],
261
   "source": [
262
    "cols_le,cols_float,cols_bool = pickle.load(open(PATH_WORK/'covs','rb'))\n",
263
    "meta_cols = cols_bool + cols_float\n",
264
    "\n",
265
    "#meta_cols = [c for c in meta_cols if c not in ['SeriesPP']]\n",
266
    "#cols_le = cols_le[:-1]"
267
   ]
268
  },
269
  {
270
   "cell_type": "markdown",
271
   "metadata": {},
272
   "source": [
273
    "# Pre-processing"
274
   ]
275
  },
276
  {
277
   "cell_type": "code",
278
   "execution_count": 34,
279
   "metadata": {},
280
   "outputs": [],
281
   "source": [
282
    "def loadMetadata(rerun=False):\n",
283
    "    if rerun:\n",
284
    "        train_dedup = pd.read_csv(PATH_WORK/'yuval'/'train_dedup.csv')\n",
285
    "        pids, folding = pickle.load(open(PATH_WORK/'yuval'/'PID_splits.pkl','rb'))\n",
286
    "\n",
287
    "        assert len(pids) == 17079\n",
288
    "        assert len(np.unique(pids)) == 17079\n",
289
    "\n",
290
    "        for fol in folding:\n",
291
    "            assert len(fol[0]) + len(fol[1]) == 17079\n",
292
    "\n",
293
    "        assert len(folding[0][1]) + len(folding[1][1]) + len(folding[2][1]) == 17079\n",
294
    "\n",
295
    "        assert len(train_dedup.PID.unique()) == 17079\n",
296
    "\n",
297
    "        train_dedup['fold'] = np.nan\n",
298
    "\n",
299
    "        for fold in range(3):\n",
300
    "            train_dedup.loc[train_dedup.PID.isin(pids[folding[fold][1]]),'fold'] = fold\n",
301
    "\n",
302
    "        assert train_dedup.fold.isnull().sum() == 0\n",
303
    "        \n",
304
    "        oof5 = pickle.load(open(PATH_DISK/'yuval'/'OOF_validation_image_ids_5.pkl','rb'))\n",
305
    "        for fold in range(5):\n",
306
    "            train_dedup.loc[train_dedup.PatientID.isin(oof5[fold]),'fold5'] = fold\n",
307
    "        assert train_dedup.fold5.isnull().sum() == 0\n",
308
    "        \n",
309
    "        train_md = pd.read_csv(PATH_WORK/'train_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n",
310
    "        train_md['img_id'] = train_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n",
311
    "\n",
312
    "        ids_df = train_dedup[['fold','PatientID','fold5']]\n",
313
    "        ids_df.columns = ['fold','img_id','fold5']\n",
314
    "\n",
315
    "        train_md = ids_df.join(train_md.set_index('img_id'), on = 'img_id')\n",
316
    "\n",
317
    "        pickle.dump(train_md, open(PATH_WORK/'train.ids.df','wb'))\n",
318
    "\n",
319
    "        #test_md = pickle.load(open(PATH_WORK/'test.post.processed.1','rb'))\n",
320
    "        test_md = pd.read_csv(PATH_WORK/'test_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n",
321
    "        test_md['img_id'] = test_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n",
322
    "\n",
323
    "        filename = PATH_WORK/'yuval'/'test_indexes.pkl'\n",
324
    "        test_ids = pickle.load(open(filename,'rb'))\n",
325
    "\n",
326
    "        test_ids_df = pd.DataFrame(test_ids, columns = ['img_id'])\n",
327
    "        test_md = test_ids_df.join(test_md.set_index('img_id'), on = 'img_id')\n",
328
    "        \n",
329
    "        assert len(test_md.SeriesInstanceUID.unique()) == 2214\n",
330
    "        \n",
331
    "        #train_md = pd.concat([train_md, test_md], axis = 0, sort=False).reset_index(drop=True)\n",
332
    "        \n",
333
    "        #pickle.dump(train_md, open(PATH_WORK/'train.ids.df','wb'))\n",
334
    "        pickle.dump(test_md, open(PATH_WORK/'test.ids.df','wb'))\n",
335
    "    else:\n",
336
    "        train_md = pickle.load(open(PATH_WORK/'train.ids.df','rb'))\n",
337
    "        test_md = pickle.load(open(PATH_WORK/'test.ids.df','rb'))\n",
338
    "    \n",
339
    "    return train_md, test_md"
340
   ]
341
  },
342
  {
343
   "cell_type": "code",
344
   "execution_count": 17,
345
   "metadata": {},
346
   "outputs": [],
347
   "source": [
348
    "def loadMetadata2(rerun=False):\n",
349
    "    if rerun:\n",
350
    "        train_dedup2 = pd.read_csv(PATH_WORK/'yuval'/'train_stage2.csv')\n",
351
    "\n",
352
    "        assert train_dedup2.PatientID.nunique() == 752797\n",
353
    "        assert train_dedup2.PID.nunique() == 18938\n",
354
    "\n",
355
    "        pids, folding = pickle.load(open(PATH_WORK/'yuval'/'PID_splits.pkl','rb'))\n",
356
    "        pids3, folding3 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_3_stage_2.pkl','rb'))\n",
357
    "        pids5, folding5 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_5_stage_2.pkl','rb'))\n",
358
    "\n",
359
    "        assert np.array([len(folding3[i][1]) for i in range(3)]).sum() == 18938\n",
360
    "        assert np.array([len(folding3[i][0]) for i in range(3)]).sum() == 2*18938\n",
361
    "        \n",
362
    "        for i in range(3):\n",
363
    "            assert set(pids[folding[i][1]]).issubset(set(pids3[folding3[i][1]]))\n",
364
    "\n",
365
    "        for i in range(3):\n",
366
    "            train_dedup2.loc[train_dedup2.PID.isin(pids3[folding3[i][1]]),'fold'] = i\n",
367
    "        assert train_dedup2.fold.isnull().sum() == 0\n",
368
    "\n",
369
    "        for i in range(5):\n",
370
    "            train_dedup2.loc[train_dedup2.PID.isin(pids5[folding5[i][1]]),'fold5'] = i\n",
371
    "        assert train_dedup2.fold5.isnull().sum() == 0\n",
372
    "        \n",
373
    "        oof5 = pickle.load(open(PATH_DISK/'yuval'/'OOF_validation_image_ids_5.pkl','rb'))\n",
374
    "        for i in range(5):\n",
375
    "            assert set(oof5[i]).issubset(set(train_dedup2.loc[train_dedup2.fold5==i,'PatientID']))\n",
376
    "        \n",
377
    "        train_csv = pd.read_csv(PATH/'stage_2_train.csv')\n",
378
    "        train_csv = train_csv.loc[~train_csv.ID.duplicated()].sort_values('ID').reset_index(drop=True)        \n",
379
    "        all_sop_ids = train_csv.ID.str.split('_').apply(lambda x: x[0]+'_'+x[1]).unique()\n",
380
    "        train_df = pd.DataFrame(train_csv.Label.values.reshape((-1,6)), columns = all_ich)\n",
381
    "        train_df['sop_id'] = all_sop_ids\n",
382
    "        \n",
383
    "        old_test = pd.read_csv(PATH_WORK/'test_md.csv')\n",
384
    "        old_test = old_test.drop(all_ich,axis=1)\n",
385
    "        old_test = old_test.join(train_df.set_index('sop_id'), on = 'SOPInstanceUID')\n",
386
    "        \n",
387
    "        train_md2 = pd.concat([pd.read_csv(PATH_WORK/'train_md.csv'), old_test],\\\n",
388
    "                              axis=0,sort=False)\n",
389
    "        train_md2['img_id'] = train_md2.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n",
390
    "        \n",
391
    "        ids_df = train_dedup2[['fold','PatientID','fold5']]\n",
392
    "        ids_df.columns = ['fold','img_id','fold5']\n",
393
    "\n",
394
    "        train_md2 = ids_df.join(train_md2.set_index('img_id'), on = 'img_id')\n",
395
    "        assert train_md2.test.sum() == 78545\n",
396
    "        \n",
397
    "        pickle.dump(train_md2, open(PATH_WORK/'train2.ids.df','wb'))\n",
398
    "        \n",
399
    "        test_md = pd.read_csv(PATH_WORK/'test2_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n",
400
    "        test_md['img_id'] = test_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n",
401
    "\n",
402
    "        filename = PATH_WORK/'yuval/test_indexes_stage2.pkl'\n",
403
    "        test_ids = pickle.load(open(filename,'rb'))\n",
404
    "\n",
405
    "        test_ids_df = pd.DataFrame(test_ids, columns = ['img_id'])\n",
406
    "        test_md = test_ids_df.join(test_md.set_index('img_id'), on = 'img_id')\n",
407
    "\n",
408
    "        assert len(test_md.SeriesInstanceUID.unique()) == 3518\n",
409
    "\n",
410
    "        pickle.dump(test_md, open(PATH_WORK/'test2.ids.df','wb'))\n",
411
    "\n",
412
    "    else:\n",
413
    "        train_md2 = pickle.load(open(PATH_WORK/'train2.ids.df','rb'))\n",
414
    "        test_md = pickle.load(open(PATH_WORK/'test2.ids.df','rb'))\n",
415
    "    \n",
416
    "    return train_md2, test_md"
417
   ]
418
  },
419
  {
420
   "cell_type": "code",
421
   "execution_count": 34,
422
   "metadata": {},
423
   "outputs": [],
424
   "source": [
425
    "def loadMetadata3(rerun=False):\n",
426
    "    if rerun:\n",
427
    "        train_md = pickle.load(open(PATH_WORK/'train.ids.df','rb'))\n",
428
    "        test_md = pickle.load(open(PATH_WORK/'test.ids.df','rb'))\n",
429
    "        train_md = pd.concat([train_md, test_md], axis=0,sort=False).reset_index(drop=True)\n",
430
    "        \n",
431
    "        train_md.fold = np.nan\n",
432
    "        train_md['PID'] = train_md.PatientID.str.split('_').apply(lambda x: x[1])\n",
433
    "        \n",
434
    "        pids, folding = pickle.load(open(PATH_WORK/'yuval'/'PID_splits.pkl','rb'))\n",
435
    "        pids3, folding3 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_3_stage_2.pkl','rb'))\n",
436
    "        pids5, folding5 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_5_stage_2.pkl','rb'))\n",
437
    "\n",
438
    "        for i in range(3):\n",
439
    "            train_md.loc[train_md.PID.isin(pids3[folding3[i][1]]),'fold'] = i\n",
440
    "        assert train_md.fold.isnull().sum() == 0\n",
441
    "\n",
442
    "        for i in range(5):\n",
443
    "            train_md.loc[train_md.PID.isin(pids5[folding5[i][1]]),'fold5'] = i\n",
444
    "        assert train_md.fold5.isnull().sum() == 0\n",
445
    "        \n",
446
    "        train_md.weights = train_md.weights.fillna(1)\n",
447
    "        \n",
448
    "        pickle.dump(train_md, open(PATH_WORK/'train3.ids.df','wb'))\n",
449
    "        \n",
450
    "        test_md = pd.read_csv(PATH_WORK/'test2_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n",
451
    "        test_md['img_id'] = test_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n",
452
    "\n",
453
    "        filename = PATH_WORK/'yuval/test_indexes_stage2.pkl'\n",
454
    "        test_ids = pickle.load(open(filename,'rb'))\n",
455
    "\n",
456
    "        test_ids_df = pd.DataFrame(test_ids, columns = ['img_id'])\n",
457
    "        test_md = test_ids_df.join(test_md.set_index('img_id'), on = 'img_id')\n",
458
    "\n",
459
    "        assert len(test_md.SeriesInstanceUID.unique()) == 3518\n",
460
    "\n",
461
    "        pickle.dump(test_md, open(PATH_WORK/'test3.ids.df','wb'))\n",
462
    "    else:\n",
463
    "        train_md = pickle.load(open(PATH_WORK/'train3.ids.df','rb'))\n",
464
    "        test_md = pickle.load(open(PATH_WORK/'test3.ids.df','rb'))\n",
465
    "    \n",
466
    "    return train_md, test_md"
467
   ]
468
  },
469
  {
470
   "cell_type": "code",
471
   "execution_count": 66,
472
   "metadata": {},
473
   "outputs": [],
474
   "source": [
475
    "#train_md, test_md = loadMetadata(True)"
476
   ]
477
  },
478
  {
479
   "cell_type": "code",
480
   "execution_count": 18,
481
   "metadata": {},
482
   "outputs": [],
483
   "source": [
484
    "#train_md = loadMetadata2(True)"
485
   ]
486
  },
487
  {
488
   "cell_type": "code",
489
   "execution_count": null,
490
   "metadata": {},
491
   "outputs": [],
492
   "source": []
493
  },
494
  {
495
   "cell_type": "code",
496
   "execution_count": 7,
497
   "metadata": {},
498
   "outputs": [],
499
   "source": [
500
    "def preprocessedData(dataset, folds = range(3), fold_col=None, do_test=True, do_test2=True, do_train=True):\n",
501
    "    assert dataset >= 6\n",
502
    "    \n",
503
    "    dataset_name, filename_add, filename_add2, feat_sz, ds_num, test_fix, dsft, focal = getDSParams(dataset)\n",
504
    "    \n",
505
    "    #if 'train_md' not in globals() or 'test_md' not in globals():\n",
506
    "    #    train_md, test_md = loadMetadata(False)\n",
507
    "    \n",
508
    "    PATH_DS = PATH_DISK/'features/{}{}'.format(dataset_name,filename_add2)\n",
509
    "    if not PATH_DS.is_dir():\n",
510
    "        PATH_DS.mkdir()\n",
511
    "    if not (PATH_DS/'train').is_dir():\n",
512
    "        (PATH_DS/'train').mkdir()\n",
513
    "    if not (PATH_DS/'test').is_dir():\n",
514
    "        (PATH_DS/'test').mkdir()\n",
515
    "    if not (PATH_DS/'test2').is_dir():\n",
516
    "        (PATH_DS/'test2').mkdir()\n",
517
    "    \n",
518
    "    if fold_col is None:\n",
519
    "        fold_col = 'fold'\n",
520
    "        if len(folds) == 5:\n",
521
    "            fold_col = 'fold5'\n",
522
    "    \n",
523
    "    for fold in folds:\n",
524
    "        \n",
525
    "        if do_train:\n",
526
    "            filename = PATH_DISK/'yuval'/\\\n",
527
    "                'model_{}{}_version_{}_splits_{}type_features_train_tta{}_split_{}.pkl'\\\n",
528
    "                .format(dataset_name.replace('_5n','').replace('_5f','').replace('_5',''), \\\n",
529
    "                        filename_add, dsft, focal, ds_num, fold)\n",
530
    "            feats = pickle.load(open(filename,'rb'))\n",
531
    "\n",
532
    "            print('dataset',dataset,'fold',fold,'feats size', feats.shape)\n",
533
    "            if dataset <= 13:\n",
534
    "                train_md_loc = train_md.loc[~train_md.test]\n",
535
    "            else:\n",
536
    "                train_md_loc = train_md\n",
537
    "            \n",
538
    "            assert len(feats) == 4*len(train_md_loc)\n",
539
    "            assert feats.shape[1] == feat_sz\n",
540
    "            means = feats.mean(0,keepdim=True)\n",
541
    "            stds = feats.std(0,keepdim=True)\n",
542
    "\n",
543
    "            feats = feats - means\n",
544
    "            feats = torch.where(stds > 0, feats/stds, feats)\n",
545
    "\n",
546
    "            for i in range(4):\n",
547
    "                feats_sub1 = feats[torch.BoolTensor(np.arange(len(feats))%4 == i)]\n",
548
    "                feats_sub2 = feats_sub1[torch.BoolTensor(train_md_loc[fold_col] != fold)]\n",
549
    "                pickle.dump(feats_sub2, open(PATH_DS/'train/train.f{}.a{}'.format(fold,i),'wb'))\n",
550
    "\n",
551
    "                feats_sub2 = feats_sub1[torch.BoolTensor(train_md_loc[fold_col] == fold)]\n",
552
    "                pickle.dump(feats_sub2, open(PATH_DS/'train/valid.f{}.a{}'.format(fold,i),'wb'))\n",
553
    "\n",
554
    "                if i==0:\n",
555
    "                    black_feats = feats_sub1[torch.BoolTensor(train_md_loc.img_id == all_black)].squeeze()\n",
556
    "                    pickle.dump(black_feats, open(PATH_DS/'train/black.f{}'.format(fold),'wb'))\n",
557
    "        \n",
558
    "        if do_test:\n",
559
    "            filename = PATH_DISK/'yuval'/\\\n",
560
    "                'model_{}{}_version_{}_splits_{}type_features_test{}{}_split_{}.pkl'\\\n",
561
    "                .format(dataset_name.replace('_5n','').replace('_5f','').replace('_5',''),\n",
562
    "                        filename_add,dsft,focal,ds_num,test_fix,fold)\n",
563
    "            feats = pickle.load(open(filename,'rb'))\n",
564
    "\n",
565
    "            assert len(feats) == 8*len(test_md)\n",
566
    "            assert feats.shape[1] == feat_sz\n",
567
    "\n",
568
    "            feats = feats - means\n",
569
    "            feats = torch.where(stds > 0, feats/stds, feats)\n",
570
    "\n",
571
    "            for i in range(8):\n",
572
    "                feats_sub = feats[torch.BoolTensor(np.arange(len(feats))%8 == i)]\n",
573
    "                pickle.dump(feats_sub, open(PATH_DS/'test/test.f{}.a{}'.format(fold,i),'wb'))\n",
574
    "                assert len(feats_sub) == len(test_md)\n",
575
    "        \n",
576
    "        if do_test2:\n",
577
    "            filename = PATH_DISK/'yuval'/\\\n",
578
    "                'model_{}{}_version_{}_splits_{}type_features_test_stage2_split_{}.pkl'\\\n",
579
    "                .format(dataset_name.replace('_5n','').replace('_5f','').replace('_5',''),\n",
580
    "                        filename_add,dsft,focal,fold)\n",
581
    "            feats = pickle.load(open(filename,'rb'))\n",
582
    "\n",
583
    "            assert len(feats) == 8*len(test_md)\n",
584
    "            assert feats.shape[1] == feat_sz\n",
585
    "\n",
586
    "            feats = feats - means\n",
587
    "            feats = torch.where(stds > 0, feats/stds, feats)\n",
588
    "\n",
589
    "            for i in range(8):\n",
590
    "                feats_sub = feats[torch.BoolTensor(np.arange(len(feats))%8 == i)]\n",
591
    "                pickle.dump(feats_sub, open(PATH_DS/'test2/test.f{}.a{}'.format(fold,i),'wb'))\n",
592
    "                assert len(feats_sub) == len(test_md)"
593
   ]
594
  },
595
  {
596
   "cell_type": "code",
597
   "execution_count": null,
598
   "metadata": {},
599
   "outputs": [],
600
   "source": []
601
  },
602
  {
603
   "cell_type": "code",
604
   "execution_count": null,
605
   "metadata": {},
606
   "outputs": [],
607
   "source": []
608
  },
609
  {
610
   "cell_type": "markdown",
611
   "metadata": {},
612
   "source": [
613
    "# Dataset"
614
   ]
615
  },
616
  {
617
   "cell_type": "code",
618
   "execution_count": 86,
619
   "metadata": {},
620
   "outputs": [],
621
   "source": [
622
    "if False:\n",
623
    "    path = PATH_WORK/'features/densenet161_v3/train/ID_992b567eb6'\n",
624
    "    black_feats = pickle.load(open(path,'rb'))[41]"
625
   ]
626
  },
627
  {
628
   "cell_type": "code",
629
   "execution_count": 87,
630
   "metadata": {},
631
   "outputs": [],
632
   "source": [
633
    "class RSNA_DataSet(D.Dataset):\n",
634
    "    def __init__(self, metadata, dataset, mode='train', bs=None, fold=0):\n",
635
    "        \n",
636
    "        super(RSNA_DataSet, self).__init__()\n",
637
    "        \n",
638
    "        dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n",
639
    "        num_folds = getNFolds(dataset)\n",
640
    "        \n",
641
    "        self.dataset_name = dataset_name\n",
642
    "        self.filename_add2 = filename_add2\n",
643
    "        \n",
644
    "        folds_col = 'fold'\n",
645
    "        if num_folds == 5:\n",
646
    "            folds_col = 'fold5'\n",
647
    "        \n",
648
    "        if (dataset <= 13) and (mode in ['train','valid']) and TRAIN_ON_STAGE_1:\n",
649
    "            if mode == 'train':\n",
650
    "                self.test_mask = torch.BoolTensor((metadata.loc[metadata.test, folds_col] != fold).values)\n",
651
    "            else:\n",
652
    "                self.test_mask = torch.BoolTensor((metadata.loc[metadata.test, folds_col] == fold).values)\n",
653
    "        \n",
654
    "        if mode == 'train':\n",
655
    "            md = metadata.loc[metadata[folds_col] != fold].copy().reset_index(drop=True)\n",
656
    "        elif mode == 'valid':\n",
657
    "            md = metadata.loc[metadata[folds_col] == fold].copy().reset_index(drop=True)\n",
658
    "        else:\n",
659
    "            md = metadata.copy().reset_index(drop=True)\n",
660
    "        \n",
661
    "        series = np.sort(md.SeriesInstanceUID.unique())\n",
662
    "        md = md.set_index('SeriesInstanceUID', drop=True)\n",
663
    "        \n",
664
    "        samples_add = 0\n",
665
    "        if (mode != 'train') and not DATA_SMALL:\n",
666
    "            batch_num = -((-len(series))//(bs*MAX_DEVICES))\n",
667
    "            samples_add = batch_num*bs*MAX_DEVICES - len(series)\n",
668
    "            print('adding dummy serieses', samples_add)\n",
669
    "        \n",
670
    "        #self.records = df.to_records(index=False)\n",
671
    "        self.mode = mode\n",
672
    "        self.real = np.concatenate([np.repeat(True,len(series)),np.repeat(False,samples_add)])\n",
673
    "        self.series = np.concatenate([series, random.sample(list(series),samples_add)])\n",
674
    "        self.metadata = md\n",
675
    "        self.dataset = dataset\n",
676
    "        self.fold = fold\n",
677
    "        \n",
678
    "        print('DataSet', dataset, mode, 'size', len(self.series), 'fold', fold)\n",
679
    "        \n",
680
    "        path = PATH_DISK/'features/{}{}/train/black.f{}'.format(self.dataset_name, self.filename_add2, fold)\n",
681
    "        self.black_feats = pickle.load(open(path,'rb')).squeeze()\n",
682
    "        \n",
683
    "        if WEIGHTED and (self.mode == 'train'):\n",
684
    "            tt = self.metadata['weights'].groupby(self.metadata.index).mean()\n",
685
    "            self.weights = tt.loc[self.series].values\n",
686
    "            print(pd.value_counts(self.weights))\n",
687
    "    \n",
688
    "    def setFeats(self, anum, epoch=0):\n",
689
    "        def getAPath(an,mode):\n",
690
    "            folder = 'test2' if mode == 'test' else 'train'\n",
691
    "            return PATH_DISK/'features/{}{}/{}/{}.f{}.a{}'\\\n",
692
    "                .format(self.dataset_name,self.filename_add2,folder,mode,self.fold,an)\n",
693
    "        \n",
694
    "        def getAPathFeats(mode, sz, mask=None):\n",
695
    "            max_a = 8 if mode == 'test' else 4\n",
696
    "            feats2 = torch.stack([pickle.load(open(getAPath(an,mode),'rb')) for an in range(max_a)])\n",
697
    "            if mask is not None:\n",
698
    "                feats2 = feats2[:,mask]\n",
699
    "            assert feats2.shape[1] == sz\n",
700
    "            feats = feats2[torch.randint(max_a,(sz,)), torch.arange(sz, dtype=torch.long)].squeeze()\n",
701
    "            return feats\n",
702
    "        \n",
703
    "        if self.dataset == 1: return\n",
704
    "        print('setFeats, augmentation', anum)\n",
705
    "        self.anum = anum\n",
706
    "        sz = len(self.metadata)\n",
707
    "        \n",
708
    "        assert anum <= 0\n",
709
    "        \n",
710
    "        if anum == -1:\n",
711
    "            if (self.dataset <= 13) and (self.mode in ['train','valid']) and TRAIN_ON_STAGE_1:\n",
712
    "                feats = torch.cat([getAPathFeats(self.mode, sz - self.metadata.test.sum()), \n",
713
    "                                   getAPathFeats('test', self.metadata.test.sum(), self.test_mask)] ,axis=0)\n",
714
    "            else:\n",
715
    "                feats = getAPathFeats(self.mode, sz)\n",
716
    "        else:\n",
717
    "            if (self.dataset <= 13) and (self.mode in ['train','valid']) and TRAIN_ON_STAGE_1:\n",
718
    "                feats = torch.cat([pickle.load(open(getAPath(anum,self.mode),'rb')),\n",
719
    "                                   pickle.load(open(getAPath(anum,'test'),'rb'))[self.test_mask]], axis=0)\n",
720
    "            else:\n",
721
    "                feats = pickle.load(open(getAPath(anum,self.mode),'rb'))\n",
722
    "        \n",
723
    "        self.feats = feats\n",
724
    "        self.epoch = epoch\n",
725
    "        assert len(feats) == sz\n",
726
    "    \n",
727
    "    def __getitem__(self, index):\n",
728
    "        \n",
729
    "        series_id = self.series[index]\n",
730
    "        #df = self.metadata.loc[self.metadata.SeriesInstanceUID == series_id].reset_index(drop=True)\n",
731
    "        df = self.metadata.loc[series_id].reset_index(drop=True)\n",
732
    "        \n",
733
    "        if self.dataset >= 6:\n",
734
    "            feats = self.feats[torch.BoolTensor(self.metadata.index.values == series_id)]\n",
735
    "        else: assert False\n",
736
    "        \n",
737
    "        order = np.argsort(df.pos_idx1.values)\n",
738
    "        df = df.sort_values(['pos_idx1'])\n",
739
    "        feats = feats[torch.LongTensor(order)]\n",
740
    "        \n",
741
    "        #if WEIGHTED:\n",
742
    "        #    non_black = torch.Tensor(df['weights'])\n",
743
    "        #else:\n",
744
    "        non_black = torch.ones(len(feats))\n",
745
    "        \n",
746
    "        feats = torch.cat([feats, torch.Tensor(df[meta_cols].values)], dim=1)\n",
747
    "        feats_le = torch.LongTensor(df[cols_le].values)\n",
748
    "        \n",
749
    "        target = torch.Tensor(df[all_ich].values)\n",
750
    "        \n",
751
    "        PAD = 12\n",
752
    "        \n",
753
    "        np.random.seed(1234 + index + 10000*self.epoch)\n",
754
    "        offset = np.random.randint(0, 61 - feats.shape[0])\n",
755
    "        \n",
756
    "        #offset = 0\n",
757
    "        bk_add = 0\n",
758
    "        top_pad = PAD + offset\n",
759
    "        if top_pad > 0:\n",
760
    "            dummy_row = torch.cat([self.black_feats, torch.Tensor(df.head(1)[meta_cols].values).squeeze()])\n",
761
    "            feats = torch.cat([dummy_row.repeat(top_pad,1), feats], dim=0)\n",
762
    "            feats_le = torch.cat([torch.LongTensor(df.head(1)[cols_le].values).squeeze().repeat(top_pad,1), feats_le])\n",
763
    "            if offset > 0:\n",
764
    "                non_black = torch.cat([bk_add + torch.zeros(offset), non_black])\n",
765
    "                target = torch.cat([torch.zeros((offset, len(all_ich))), target], dim=0)\n",
766
    "        bot_pad = 60 - len(df) - offset + PAD\n",
767
    "        if bot_pad > 0:\n",
768
    "            dummy_row = torch.cat([self.black_feats, torch.Tensor(df.tail(1)[meta_cols].values).squeeze()])\n",
769
    "            feats = torch.cat([feats, dummy_row.repeat(bot_pad,1)], dim=0)\n",
770
    "            feats_le = torch.cat([feats_le, torch.LongTensor(df.tail(1)[cols_le].values).squeeze().repeat(bot_pad,1)])\n",
771
    "            if (60 - len(df) - offset) > 0:\n",
772
    "                non_black = torch.cat([non_black, bk_add + torch.zeros(60 - len(df) - offset)])\n",
773
    "                target = torch.cat([target, torch.zeros((60 - len(df) - offset, len(all_ich)))], dim=0)\n",
774
    "        \n",
775
    "        assert feats_le.shape[0] == (60 + 2*PAD)\n",
776
    "        assert feats.shape[0] == (60 + 2*PAD)\n",
777
    "        assert target.shape[0] == 60\n",
778
    "        \n",
779
    "        idx = index\n",
780
    "        if not self.real[index]: idx = -1\n",
781
    "        \n",
782
    "        if self.mode == 'train':\n",
783
    "            return feats, feats_le, target, non_black\n",
784
    "        else:\n",
785
    "            return feats, feats_le, target, idx, offset\n",
786
    "    \n",
787
    "    def __len__(self):\n",
788
    "        return len(self.series) if not DATA_SMALL else int(0.01*len(self.series))"
789
   ]
790
  },
791
  {
792
   "cell_type": "code",
793
   "execution_count": 88,
794
   "metadata": {},
795
   "outputs": [],
796
   "source": [
797
    "def getCurrentBatch(dataset, fold=0, ver=VERSION):\n",
798
    "    sel_batch = None\n",
799
    "    for filename in os.listdir(PATH_DISK/'models'):\n",
800
    "        splits = filename.split('.')\n",
801
    "        if int(splits[2][1]) != fold: continue\n",
802
    "        if int(splits[3][1:]) != dataset: continue\n",
803
    "        if int(splits[4][1:]) != ver: continue\n",
804
    "        if sel_batch is None:\n",
805
    "            sel_batch = int(splits[1][1:])\n",
806
    "        else:\n",
807
    "            sel_batch = max(sel_batch, int(splits[1][1:]))\n",
808
    "    return sel_batch\n",
809
    "\n",
810
    "def modelFileName(dataset, fold=0, batch = 1, return_last = False, return_next = False, ver=VERSION):\n",
811
    "    sel_batch = batch\n",
812
    "    if return_last or return_next:\n",
813
    "        sel_batch = getCurrentBatch(fold=fold, dataset=dataset, ver=ver)\n",
814
    "        if return_last and sel_batch is None:\n",
815
    "            return None\n",
816
    "        if return_next:\n",
817
    "            if sel_batch is None: sel_batch = 1\n",
818
    "            else: sel_batch += 1\n",
819
    "    \n",
820
    "    return 'model.b{}.f{}.d{}.v{}'.format(sel_batch, fold, dataset, ver)"
821
   ]
822
  },
823
  {
824
   "cell_type": "markdown",
825
   "metadata": {},
826
   "source": [
827
    "# Model"
828
   ]
829
  },
830
  {
831
   "cell_type": "code",
832
   "execution_count": 6,
833
   "metadata": {},
834
   "outputs": [],
835
   "source": [
836
    "class BCEWithLogitsLoss(nn.Module):\n",
837
    "    def __init__(self, weight=None):\n",
838
    "        super().__init__()\n",
839
    "        self.weight = weight\n",
840
    "    \n",
841
    "    def forward(self, input, target, batch_weights = None, focal=0):\n",
842
    "        if focal == 0:\n",
843
    "            loss = (torch.log(1+torch.exp(input)) - target*input)*self.weight\n",
844
    "        else:\n",
845
    "            loss = torch.pow(1+torch.exp(input), -focal) * \\\n",
846
    "                (((1-target)*torch.exp(focal*input) + target) * torch.log(1+torch.exp(input)) - target*input)\n",
847
    "            loss = loss*self.weight\n",
848
    "        if batch_weights is not None:\n",
849
    "            loss = batch_weights*loss\n",
850
    "        return loss.mean()"
851
   ]
852
  },
853
  {
854
   "cell_type": "code",
855
   "execution_count": 90,
856
   "metadata": {},
857
   "outputs": [],
858
   "source": [
859
    "def bn_drop_lin(n_in:int, n_out:int, bn:bool=True, p:float=0., actn=None):\n",
860
    "    \"Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`.\"\n",
861
    "    layers = [nn.BatchNorm1d(n_in)] if bn else []\n",
862
    "    if p != 0: layers.append(nn.Dropout(p))\n",
863
    "    layers.append(nn.Linear(n_in, n_out))\n",
864
    "    if actn is not None: layers.append(actn)\n",
865
    "    return layers"
866
   ]
867
  },
868
  {
869
   "cell_type": "code",
870
   "execution_count": 91,
871
   "metadata": {},
872
   "outputs": [],
873
   "source": [
874
    "def noop(x): return x\n",
875
    "act_fun = nn.ReLU(inplace=True)\n",
876
    "\n",
877
    "def conv_layer(ni, nf, ks=3, act=True):\n",
878
    "    bn = nn.BatchNorm1d(nf)\n",
879
    "    layers = [nn.Conv1d(ni, nf, ks), bn]\n",
880
    "    if act: layers.append(act_fun)\n",
881
    "    return nn.Sequential(*layers)\n",
882
    "\n",
883
    "class ResBlock(nn.Module):\n",
884
    "    def __init__(self, ni, nh):\n",
885
    "        super().__init__()\n",
886
    "        layers  = [conv_layer(ni, nh, 1),\n",
887
    "                   conv_layer(nh, nh, 5, act=False)]\n",
888
    "        self.convs = nn.Sequential(*layers)\n",
889
    "        self.idconv = noop if (ni == nh) else conv_layer(ni, nh, 1, act=False)\n",
890
    "    \n",
891
    "    def forward(self, x): return act_fun(self.convs(x) + self.idconv(x[:,:,2:-2]))"
892
   ]
893
  },
894
  {
895
   "cell_type": "code",
896
   "execution_count": 77,
897
   "metadata": {},
898
   "outputs": [],
899
   "source": [
900
    "class ResNetModel(nn.Module):\n",
901
    "    def __init__(self, n_cont:int, feat_sz=2208):\n",
902
    "        super().__init__()\n",
903
    "        \n",
904
    "        self.le_sz = 10\n",
905
    "        assert self.le_sz == (len(cols_le))\n",
906
    "        le_in_sizes = np.array([5,5,7,4,4,11,4,6,3,3])\n",
907
    "        le_out_sizes = np.array([3,3,4,2,2,6,2,4,2,2])\n",
908
    "        le_out_sz = le_out_sizes.sum()\n",
909
    "        self.embeddings = nn.ModuleList([embedding(le_in_sizes[i], le_out_sizes[i]) for i in range(self.le_sz)])\n",
910
    "        \n",
911
    "        self.feat_sz = feat_sz\n",
912
    "        \n",
913
    "        self.n_cont = n_cont\n",
914
    "        scale = 4\n",
915
    "        \n",
916
    "        self.conv2D = nn.Conv2d(1,scale*16,(feat_sz + n_cont + le_out_sz,1))\n",
917
    "        self.bn1 = nn.BatchNorm1d(scale*16)\n",
918
    "        \n",
919
    "        self.res1 = ResBlock(scale*16,scale*16)\n",
920
    "        self.res2 = ResBlock(scale*16,scale*8)\n",
921
    "        \n",
922
    "        self.res3 = ResBlock(scale*24,scale*16)\n",
923
    "        self.res4 = ResBlock(scale*16,scale*8)\n",
924
    "        \n",
925
    "        self.res5 = ResBlock(scale*32,scale*16)\n",
926
    "        self.res6 = ResBlock(scale*16,scale*8)\n",
927
    "        \n",
928
    "        self.conv1D = nn.Conv1d(scale*40,6,1)\n",
929
    "    \n",
930
    "    def forward(self, x, x_le, x_le_mix = None, lambd = None) -> torch.Tensor:\n",
931
    "        x_le = [e(x_le[:,:,i]) for i,e in enumerate(self.embeddings)]\n",
932
    "        x_le = torch.cat(x_le, 2)\n",
933
    "        \n",
934
    "        x = torch.cat([x, x_le], 2)\n",
935
    "        x = x.transpose(1,2)\n",
936
    "        \n",
937
    "        #x = torch.cat([x[:,:self.feat_sz],self.bn_cont(x[:,self.feat_sz:])], dim=1)\n",
938
    "        x = x.reshape(x.shape[0],1,x.shape[1],x.shape[2])\n",
939
    "        \n",
940
    "        x = self.conv2D(x).squeeze()\n",
941
    "        x = self.bn1(x)\n",
942
    "        x = act_fun(x)\n",
943
    "        \n",
944
    "        x2 = self.res1(x)\n",
945
    "        x2 = self.res2(x2)\n",
946
    "        \n",
947
    "        x3 = torch.cat([x[:,:,4:-4], x2], 1)\n",
948
    "        \n",
949
    "        x3 = self.res3(x3)\n",
950
    "        x3 = self.res4(x3)\n",
951
    "        \n",
952
    "        x4 = torch.cat([x[:,:,8:-8], x2[:,:,4:-4], x3], 1)\n",
953
    "        \n",
954
    "        x4 = self.res5(x4)\n",
955
    "        x4 = self.res6(x4)\n",
956
    "        \n",
957
    "        x5 = torch.cat([x[:,:,12:-12], x2[:,:,8:-8], x3[:,:,4:-4], x4], 1)\n",
958
    "        x5 = self.conv1D(x5)\n",
959
    "        x5 = x5.transpose(1,2)\n",
960
    "        \n",
961
    "        return x5"
962
   ]
963
  },
964
  {
965
   "cell_type": "code",
966
   "execution_count": 93,
967
   "metadata": {},
968
   "outputs": [],
969
   "source": [
970
    "def trunc_normal_(x, mean:float=0., std:float=1.):\n",
971
    "    \"Truncated normal initialization.\"\n",
972
    "    # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12\n",
973
    "    return x.normal_().fmod_(2).mul_(std).add_(mean)"
974
   ]
975
  },
976
  {
977
   "cell_type": "code",
978
   "execution_count": 94,
979
   "metadata": {},
980
   "outputs": [],
981
   "source": [
982
    "def embedding(ni:int,nf:int) -> nn.Module:\n",
983
    "    \"Create an embedding layer.\"\n",
984
    "    emb = nn.Embedding(ni, nf)\n",
985
    "    # See https://arxiv.org/abs/1711.09160\n",
986
    "    with torch.no_grad(): trunc_normal_(emb.weight, std=0.01)\n",
987
    "    return emb"
988
   ]
989
  },
990
  {
991
   "cell_type": "code",
992
   "execution_count": 95,
993
   "metadata": {},
994
   "outputs": [],
995
   "source": [
996
    "class TabularModel(nn.Module):\n",
997
    "    \"Basic model for tabular data.\"\n",
998
    "    def __init__(self, n_cont:int, feat_sz=2208, fc_drop_p=0.3):\n",
999
    "        super().__init__()\n",
1000
    "        self.le_sz = 10\n",
1001
    "        assert len(cols_le) == self.le_sz\n",
1002
    "        le_in_sizes = np.array([5,5,7,4,4,11,4,6,3,3])\n",
1003
    "        le_out_sizes = np.array([3,3,4,2,2,6,2,4,2,2])\n",
1004
    "        le_out_sz = le_out_sizes.sum()\n",
1005
    "        self.embeddings = nn.ModuleList([embedding(le_in_sizes[i], le_out_sizes[i]) for i in range(self.le_sz)])\n",
1006
    "        #self.bn_cont = nn.BatchNorm1d(feat_sz + n_cont)\n",
1007
    "        \n",
1008
    "        self.feat_sz = feat_sz\n",
1009
    "        self.bn_cont = nn.BatchNorm1d(n_cont + le_out_sz)\n",
1010
    "        self.n_cont = n_cont\n",
1011
    "        self.fc_drop = nn.Dropout(p=fc_drop_p)\n",
1012
    "        self.relu = nn.ReLU(inplace=True)\n",
1013
    "        \n",
1014
    "        scale = 4\n",
1015
    "        \n",
1016
    "        self.conv2D_1 = nn.Conv2d(1,16*scale,(feat_sz + n_cont + le_out_sz,1))\n",
1017
    "        self.conv2D_2 = nn.Conv2d(1,16*scale,(feat_sz + n_cont + le_out_sz,5))\n",
1018
    "        self.bn_cont1 = nn.BatchNorm1d(32*scale)\n",
1019
    "        self.conv1D_1 = nn.Conv1d(32*scale,16*scale,3)\n",
1020
    "        self.conv1D_3 = nn.Conv1d(32*scale,16*scale,5,dilation=5)\n",
1021
    "        self.conv1D_2 = nn.Conv1d(32*scale,6,3)\n",
1022
    "        self.bn_cont2 = nn.BatchNorm1d(32*scale)\n",
1023
    "        self.bn_cont3 = nn.BatchNorm1d(6)\n",
1024
    "\n",
1025
    "        self.conv1D_4 = nn.Conv1d(32*scale,32*scale,3)\n",
1026
    "        self.bn_cont4 = nn.BatchNorm1d(32*scale)\n",
1027
    "\n",
1028
    "    def forward(self, x, x_le, x_le_mix = None, lambd = None) -> torch.Tensor:\n",
1029
    "        x_le = [e(x_le[:,:,i]) for i,e in enumerate(self.embeddings)]\n",
1030
    "        x_le = torch.cat(x_le, 2)\n",
1031
    "        \n",
1032
    "        if MIXUP and x_le_mix is not None:\n",
1033
    "            x_le_mix = [e(x_le_mix[:,:,i]) for i,e in enumerate(self.embeddings)]\n",
1034
    "            x_le_mix = torch.cat(x_le_mix, 2)\n",
1035
    "            x_le = lambd * x_le + (1-lambd) * x_le_mix\n",
1036
    "        \n",
1037
    "        #assert torch.isnan(x_le).any().cpu() == False\n",
1038
    "        x = torch.cat([x, x_le], 2)\n",
1039
    "        x = x.transpose(1,2)\n",
1040
    "        \n",
1041
    "        #x = torch.cat([x[:,:self.feat_sz],self.bn_cont(x[:,self.feat_sz:])], dim=1)\n",
1042
    "        \n",
1043
    "        x = x.reshape(x.shape[0],1,x.shape[1],x.shape[2])\n",
1044
    "        x = self.fc_drop(x)\n",
1045
    "        x = torch.cat([self.conv2D_1(x[:,:,:,2:(-2)]).squeeze(), \n",
1046
    "                       self.conv2D_2(x).squeeze()], dim=1)\n",
1047
    "        x = self.relu(x)\n",
1048
    "        x = self.bn_cont1(x)\n",
1049
    "        x = self.fc_drop(x)\n",
1050
    "        \n",
1051
    "        x = torch.cat([self.conv1D_1(x[:,:,9:(-9)]),\n",
1052
    "                       self.conv1D_3(x)], dim=1)\n",
1053
    "        x = self.relu(x)\n",
1054
    "        x = self.bn_cont2(x)\n",
1055
    "        x = self.fc_drop(x)\n",
1056
    "        \n",
1057
    "        x = self.conv1D_4(x)\n",
1058
    "        x = self.relu(x)\n",
1059
    "        x = self.bn_cont4(x)\n",
1060
    "        \n",
1061
    "        x = self.conv1D_2(x)\n",
1062
    "        x = x.transpose(1,2)\n",
1063
    "        \n",
1064
    "        return x"
1065
   ]
1066
  },
1067
  {
1068
   "cell_type": "markdown",
1069
   "metadata": {},
1070
   "source": [
1071
    "# Training"
1072
   ]
1073
  },
1074
  {
1075
   "cell_type": "code",
1076
   "execution_count": 96,
1077
   "metadata": {},
1078
   "outputs": [],
1079
   "source": [
1080
    "def train_loop_fn(model, loader, device, context = None):\n",
1081
    "    \n",
1082
    "    if CLOUD and (not CLOUD_SINGLE):\n",
1083
    "        tlen = len(loader._loader._loader)\n",
1084
    "        OUT_LOSS = 1000\n",
1085
    "        OUT_TIME = 4\n",
1086
    "        generator = loader\n",
1087
    "        device_num = int(str(device)[-1])\n",
1088
    "        dataset = loader._loader._loader.dataset\n",
1089
    "    else:\n",
1090
    "        tlen = len(loader)\n",
1091
    "        OUT_LOSS = 50\n",
1092
    "        OUT_TIME = 50\n",
1093
    "        generator = enumerate(loader)\n",
1094
    "        device_num = 1\n",
1095
    "        dataset = loader.dataset\n",
1096
    "    \n",
1097
    "    #print('Start training {}'.format(device), 'batches', tlen)\n",
1098
    "    \n",
1099
    "    criterion = BCEWithLogitsLoss(weight = torch.Tensor(class_weights).to(device))\n",
1100
    "    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(0.9, 0.99))\n",
1101
    "    \n",
1102
    "    model.train()\n",
1103
    "    \n",
1104
    "    if CLOUD and TPU:\n",
1105
    "        tracker = xm.RateTracker()\n",
1106
    "\n",
1107
    "    tloss = 0\n",
1108
    "    tloss_count = 0\n",
1109
    "    \n",
1110
    "    st = time.time()\n",
1111
    "    mixup_collected = False\n",
1112
    "    x_le_mix = None\n",
1113
    "    lambd = None\n",
1114
    "    for i, (x, x_le, y, non_black) in generator:\n",
1115
    "        if (not CLOUD) or CLOUD_SINGLE:\n",
1116
    "            x = x.to(device)\n",
1117
    "            x_le = x_le.to(device)\n",
1118
    "            y = y.to(device)\n",
1119
    "            non_black = non_black.to(device)\n",
1120
    "        \n",
1121
    "        if MIXUP:\n",
1122
    "            if mixup_collected:\n",
1123
    "                lambd = np.random.beta(0.4, 0.4, y.size(0))\n",
1124
    "                lambd = torch.Tensor(lambd).to(device)[:,None,None]\n",
1125
    "                #shuffle = torch.randperm(y.size(0)).to(device)\n",
1126
    "                x = lambd * x + (1-lambd) * x_mix #x[shuffle]\n",
1127
    "                #x_le = lambd * x_le + (1-lambd) * x_le_mix #x[shuffle]\n",
1128
    "                mixup_collected = False\n",
1129
    "            else:\n",
1130
    "                x_mix = x\n",
1131
    "                x_le_mix = x_le\n",
1132
    "                y_mix = y\n",
1133
    "                mixup_collected = True\n",
1134
    "                continue\n",
1135
    "        \n",
1136
    "        optimizer.zero_grad()\n",
1137
    "        output = model(x, x_le, x_le_mix, lambd)\n",
1138
    "        \n",
1139
    "        if MIXUP:\n",
1140
    "            if NO_BLACK_LOSS:\n",
1141
    "                loss = criterion(output, y, lambd*non_black[:,:,None]) \\\n",
1142
    "                     + criterion(output, y_mix, (1-lambd)*non_black[:,:,None])\n",
1143
    "            else:\n",
1144
    "                loss = criterion(output, y, lambd) + criterion(output, y_mix, 1-lambd)\n",
1145
    "            del x_mix, y_mix\n",
1146
    "        else:\n",
1147
    "            if NO_BLACK_LOSS:\n",
1148
    "                loss = criterion(output, y, non_black[:,:,None])\n",
1149
    "            else:\n",
1150
    "                loss = criterion(output, y)\n",
1151
    "        \n",
1152
    "        loss.backward()\n",
1153
    "        \n",
1154
    "        tloss += len(y)*loss.cpu().detach().item()\n",
1155
    "        tloss_count += len(y)\n",
1156
    "        \n",
1157
    "        if (CLOUD or CLOUD_SINGLE) and TPU:\n",
1158
    "            xm.optimizer_step(optimizer)\n",
1159
    "            if CLOUD_SINGLE:\n",
1160
    "                xm.mark_step()\n",
1161
    "        else:\n",
1162
    "            optimizer.step()\n",
1163
    "        \n",
1164
    "        if CLOUD and TPU:\n",
1165
    "            tracker.add(len(y))\n",
1166
    "        \n",
1167
    "        st_passed = time.time() - st\n",
1168
    "        if (i+1)%OUT_TIME == 0 and device_num == 1:\n",
1169
    "            #print(torch_xla._XLAC._xla_metrics_report())\n",
1170
    "            print('Batch {} device: {} time passed: {:.3f} time per batch: {:.3f}'\n",
1171
    "                .format(i+1, device, st_passed, st_passed/(i+1)))\n",
1172
    "        \n",
1173
    "        del loss, output, y, x, x_le\n",
1174
    "    \n",
1175
    "    return tloss, tloss_count"
1176
   ]
1177
  },
1178
  {
1179
   "cell_type": "code",
1180
   "execution_count": 97,
1181
   "metadata": {},
1182
   "outputs": [],
1183
   "source": [
1184
    "@torch.no_grad()\n",
1185
    "def val_loop_fn(model, loader, device, context = None):\n",
1186
    "    \n",
1187
    "    if CLOUD and (not CLOUD_SINGLE):\n",
1188
    "        tlen = len(loader._loader._loader)\n",
1189
    "        OUT_LOSS = 1000\n",
1190
    "        OUT_TIME = 4\n",
1191
    "        generator = loader\n",
1192
    "        device_num = int(str(device)[-1])\n",
1193
    "    else:\n",
1194
    "        tlen = len(loader)\n",
1195
    "        OUT_LOSS = 1000\n",
1196
    "        OUT_TIME = 50\n",
1197
    "        generator = enumerate(loader)\n",
1198
    "        device_num = 1\n",
1199
    "    \n",
1200
    "    #print('Start validating {}'.format(device), 'batches', tlen)\n",
1201
    "    \n",
1202
    "    st = time.time()\n",
1203
    "    model.eval()\n",
1204
    "    \n",
1205
    "    results = []\n",
1206
    "    indices = []\n",
1207
    "    offsets = []\n",
1208
    "    \n",
1209
    "    for i, (x, x_le, y, idx, offset) in generator:\n",
1210
    "        \n",
1211
    "        if (not CLOUD) or CLOUD_SINGLE:\n",
1212
    "            x = x.to(device)\n",
1213
    "            x_le = x_le.to(device)\n",
1214
    "        \n",
1215
    "        output = model(x, x_le)\n",
1216
    "        assert torch.isnan(output).any().cpu() == False\n",
1217
    "        output = torch.sigmoid(output)\n",
1218
    "        assert torch.isnan(output).any().cpu() == False\n",
1219
    "        \n",
1220
    "        mask = (idx >= 0)\n",
1221
    "        results.append(output[mask].cpu().detach().numpy())\n",
1222
    "        indices.append(idx[mask].cpu().detach().numpy())\n",
1223
    "        offsets.append(offset[mask].cpu().detach().numpy())\n",
1224
    "        \n",
1225
    "        st_passed = time.time() - st\n",
1226
    "        if (i+1)%OUT_TIME == 0 and device_num == 1:\n",
1227
    "            print('Batch {} device: {} time passed: {:.3f} time per batch: {:.3f}'\n",
1228
    "                  .format(i+1, device, st_passed, st_passed/(i+1)))\n",
1229
    "        \n",
1230
    "        del output, y, x, x_le, idx, offset\n",
1231
    "    \n",
1232
    "    results = np.concatenate(results)\n",
1233
    "    indices = np.concatenate(indices)\n",
1234
    "    offsets = np.concatenate(offsets)\n",
1235
    "    \n",
1236
    "    return results, indices, offsets"
1237
   ]
1238
  },
1239
  {
1240
   "cell_type": "code",
1241
   "execution_count": 98,
1242
   "metadata": {},
1243
   "outputs": [],
1244
   "source": [
1245
    "@torch.no_grad()\n",
1246
    "def test_loop_fn(model, loader, device, context = None):\n",
1247
    "    \n",
1248
    "    if CLOUD and (not CLOUD_SINGLE):\n",
1249
    "        tlen = len(loader._loader._loader)\n",
1250
    "        OUT_LOSS = 1000\n",
1251
    "        OUT_TIME = 100\n",
1252
    "        generator = loader\n",
1253
    "        device_num = int(str(device)[-1])\n",
1254
    "    else:\n",
1255
    "        tlen = len(loader)\n",
1256
    "        OUT_LOSS = 1000\n",
1257
    "        OUT_TIME = 10\n",
1258
    "        generator = enumerate(loader)\n",
1259
    "        device_num = 1\n",
1260
    "    \n",
1261
    "    #print('Start testing {}'.format(device), 'batches', tlen)\n",
1262
    "    \n",
1263
    "    st = time.time()\n",
1264
    "    model.eval()\n",
1265
    "    \n",
1266
    "    results = []\n",
1267
    "    indices = []\n",
1268
    "    offsets = []\n",
1269
    "    \n",
1270
    "    for i, (x, x_le, y, idx, offset) in generator:\n",
1271
    "        \n",
1272
    "        if (not CLOUD) or CLOUD_SINGLE:\n",
1273
    "            x = x.to(device)\n",
1274
    "            x_le = x_le.to(device)\n",
1275
    "        \n",
1276
    "        output = torch.sigmoid(model(x, x_le))\n",
1277
    "        \n",
1278
    "        mask = (idx >= 0)\n",
1279
    "        results.append(output[mask].cpu().detach().numpy())\n",
1280
    "        indices.append(idx[mask].cpu().detach().numpy())\n",
1281
    "        offsets.append(offset[mask].cpu().detach().numpy())\n",
1282
    "        \n",
1283
    "        st_passed = time.time() - st\n",
1284
    "        if (i+1)%OUT_TIME == 0 and device_num == 1:\n",
1285
    "            print('B{} -> time passed: {:.3f} time per batch: {:.3f}'.format(i+1, st_passed, st_passed/(i+1)))\n",
1286
    "        \n",
1287
    "        del output, x, y, idx, offset\n",
1288
    "    \n",
1289
    "    return np.concatenate(results), np.concatenate(indices), np.concatenate(offsets)"
1290
   ]
1291
  },
1292
  {
1293
   "cell_type": "code",
1294
   "execution_count": 76,
1295
   "metadata": {},
1296
   "outputs": [],
1297
   "source": [
1298
    "def train_one(dataset, weight=None, load_model=True, epochs=1, bs=100, fold=0, init_ver=None):\n",
1299
    "    \n",
1300
    "    st0 = time.time()\n",
1301
    "    dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n",
1302
    "    \n",
1303
    "    cur_epoch = getCurrentBatch(fold=fold, dataset=dataset)\n",
1304
    "    if cur_epoch is None: cur_epoch = 0\n",
1305
    "    print('completed epochs:', cur_epoch, 'starting now:', epochs)\n",
1306
    "    \n",
1307
    "    setSeeds(SEED + cur_epoch)\n",
1308
    "    \n",
1309
    "    assert train_md.weights.isnull().sum() == 0\n",
1310
    "    \n",
1311
    "    if dataset >= 6:\n",
1312
    "        trn_ds = RSNA_DataSet(train_md, mode='train', bs=bs, fold=fold, dataset=dataset)\n",
1313
    "        val_ds = RSNA_DataSet(train_md, mode='valid', bs=bs, fold=fold, dataset=dataset)\n",
1314
    "    else: assert False\n",
1315
    "    val_ds.setFeats(0)\n",
1316
    "    \n",
1317
    "    if WEIGHTED:\n",
1318
    "        print('WeightedRandomSampler')\n",
1319
    "        sampler = D.sampler.WeightedRandomSampler(trn_ds.weights, len(trn_ds))\n",
1320
    "        loader = D.DataLoader(trn_ds, num_workers=NUM_WORKERS, batch_size=bs, \n",
1321
    "                              shuffle=False, drop_last=True, sampler=sampler)\n",
1322
    "    else:\n",
1323
    "        loader = D.DataLoader(trn_ds, num_workers=NUM_WORKERS, batch_size=bs, \n",
1324
    "                              shuffle=True, drop_last=True)\n",
1325
    "    loader_val = D.DataLoader(val_ds, num_workers=NUM_WORKERS, batch_size=bs, \n",
1326
    "                              shuffle=True)\n",
1327
    "    print('dataset train:', len(trn_ds), 'valid:', len(val_ds), 'loader train:', len(loader), 'valid:', len(loader_val))\n",
1328
    "    \n",
1329
    "    #model = TabularModel(n_cont = len(meta_cols), feat_sz=feat_sz, fc_drop_p=0)\n",
1330
    "    model = ResNetModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n",
1331
    "    \n",
1332
    "    model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset)\n",
1333
    "    \n",
1334
    "    if (model_file_name is None) and (init_ver is not None):\n",
1335
    "        model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset, ver=init_ver)\n",
1336
    "    \n",
1337
    "    if model_file_name is not None:\n",
1338
    "        print('loading model', model_file_name)\n",
1339
    "        state_dict = torch.load(PATH_DISK/'models'/model_file_name)\n",
1340
    "        model.load_state_dict(state_dict)\n",
1341
    "    else:\n",
1342
    "        print('starting from scratch')\n",
1343
    "    \n",
1344
    "    if (not CLOUD) or CLOUD_SINGLE:\n",
1345
    "        model = model.to(device)\n",
1346
    "    else:\n",
1347
    "        model_parallel = dp.DataParallel(model, device_ids=devices)\n",
1348
    "    \n",
1349
    "    loc_data = val_ds.metadata.copy()\n",
1350
    "    \n",
1351
    "    if DATA_SMALL:\n",
1352
    "        val_sz = len(val_ds)\n",
1353
    "        val_series = val_ds.series[:val_sz]\n",
1354
    "        loc_data = loc_data.loc[loc_data.index.isin(val_series)]\n",
1355
    "    \n",
1356
    "    series_counts = loc_data.index.value_counts()\n",
1357
    "    \n",
1358
    "    loc_data['orig_idx'] = np.arange(len(loc_data))\n",
1359
    "    loc_data = loc_data.sort_values(['SeriesInstanceUID','pos_idx1'])\n",
1360
    "    loc_data['my_order'] = np.arange(len(loc_data))\n",
1361
    "    loc_data = loc_data.sort_values(['orig_idx'])\n",
1362
    "    \n",
1363
    "    ww_val = loc_data.weights\n",
1364
    "    \n",
1365
    "    for i in range(cur_epoch+1, cur_epoch+epochs+1):\n",
1366
    "        st = time.time()\n",
1367
    "        \n",
1368
    "        #trn_ds.setFeats((i-1) % 4)\n",
1369
    "        trn_ds.setFeats(-1, epoch=i)\n",
1370
    "        \n",
1371
    "        if CLOUD and (not CLOUD_SINGLE):\n",
1372
    "            results = model_parallel(train_loop_fn, loader)\n",
1373
    "            tloss, tloss_count = np.stack(results).sum(0)\n",
1374
    "            state_dict = model_parallel._models[0].state_dict()\n",
1375
    "        else:\n",
1376
    "            tloss, tloss_count = train_loop_fn(model, loader, device)\n",
1377
    "            state_dict = model.state_dict()\n",
1378
    "        \n",
1379
    "        state_dict = {k:v.to('cpu') for k,v in state_dict.items()}\n",
1380
    "        tr_ll = tloss / tloss_count\n",
1381
    "        \n",
1382
    "        train_time = time.time()-st\n",
1383
    "        \n",
1384
    "        model_file_name = modelFileName(return_next=True, fold=fold, dataset=dataset)\n",
1385
    "        if not DATA_SMALL:\n",
1386
    "            torch.save(state_dict, PATH_DISK/'models'/model_file_name)\n",
1387
    "        \n",
1388
    "        st = time.time()\n",
1389
    "        if CLOUD and (not CLOUD_SINGLE):\n",
1390
    "            results = model_parallel(val_loop_fn, loader_val)\n",
1391
    "            predictions = np.concatenate([results[i][0] for i in range(MAX_DEVICES)])\n",
1392
    "            indices = np.concatenate([results[i][1] for i in range(MAX_DEVICES)])\n",
1393
    "            offsets = np.concatenate([results[i][2] for i in range(MAX_DEVICES)])\n",
1394
    "        else:\n",
1395
    "            predictions, indices, offsets = val_loop_fn(model, loader_val, device)\n",
1396
    "        \n",
1397
    "        predictions = predictions[np.argsort(indices)]\n",
1398
    "        offsets = offsets[np.argsort(indices)]\n",
1399
    "        assert len(predictions) == len(loc_data.index.unique())\n",
1400
    "        assert len(predictions) == len(offsets)\n",
1401
    "        assert np.all(indices[np.argsort(indices)] == np.array(range(len(predictions))))\n",
1402
    "        \n",
1403
    "        #val_results = np.zeros((len(loc_data),6))\n",
1404
    "        val_results = []\n",
1405
    "        for k, series in enumerate(np.sort(loc_data.index.unique())):\n",
1406
    "            cnt = series_counts[series]\n",
1407
    "            #mask = loc_data.SeriesInstanceUID == series\n",
1408
    "            assert (offsets[k] + cnt) <= 60\n",
1409
    "            #val_results[mask] = predictions[k,offsets[k]:(offsets[k] + cnt)]\n",
1410
    "            val_results.append(predictions[k,offsets[k]:(offsets[k] + cnt)])\n",
1411
    "        \n",
1412
    "        val_results = np.concatenate(val_results)\n",
1413
    "        assert np.isnan(val_results).sum() == 0\n",
1414
    "        val_results = val_results[loc_data.my_order]\n",
1415
    "        assert np.isnan(val_results).sum() == 0\n",
1416
    "        assert len(val_results) == len(loc_data)\n",
1417
    "        \n",
1418
    "        lls = [log_loss(loc_data.loc[~loc_data.test, all_ich[k]].values, val_results[~loc_data.test,k], \\\n",
1419
    "                        eps=1e-7, labels=[0,1]) for k in range(6)]\n",
1420
    "        ll = (class_weights * np.array(lls)).mean()\n",
1421
    "        cor = np.corrcoef(loc_data.loc[~loc_data.test,all_ich].values.reshape(-1), \\\n",
1422
    "                          val_results[~loc_data.test].reshape(-1))[0,1]\n",
1423
    "        auc = roc_auc_score(loc_data.loc[~loc_data.test, all_ich].values.reshape(-1), \\\n",
1424
    "                            val_results[~loc_data.test].reshape(-1))\n",
1425
    "        \n",
1426
    "        lls_w = [log_loss(loc_data.loc[~loc_data.test, all_ich[k]].values, val_results[~loc_data.test,k], eps=1e-7, \n",
1427
    "                          labels=[0,1], sample_weight=ww_val[~loc_data.test]) for k in range(6)]\n",
1428
    "        ll_w = (class_weights * np.array(lls_w)).mean()\n",
1429
    "        \n",
1430
    "        if TRAIN_ON_STAGE_1:\n",
1431
    "            lls = [log_loss(loc_data.loc[loc_data.test, all_ich[k]].values, val_results[loc_data.test,k], \\\n",
1432
    "                            eps=1e-7, labels=[0,1]) for k in range(6)]\n",
1433
    "            ll2 = (class_weights * np.array(lls)).mean()\n",
1434
    "\n",
1435
    "            lls_w = [log_loss(loc_data.loc[loc_data.test, all_ich[k]].values, val_results[loc_data.test,k], eps=1e-7, \n",
1436
    "                              labels=[0,1], sample_weight=ww_val[loc_data.test]) for k in range(6)]\n",
1437
    "            ll2_w = (class_weights * np.array(lls_w)).mean()\n",
1438
    "        else:\n",
1439
    "            ll2 = 0\n",
1440
    "            ll2_w = 0\n",
1441
    "        \n",
1442
    "        print('v{}, d{}, e{}, f{}, trn ll: {:.4f}, val ll: {:.4f}, ll_w: {:.4f}, cor: {:.4f}, auc: {:.4f}, lr: {}'\\\n",
1443
    "              .format(VERSION, dataset, i, fold, tr_ll, ll, ll_w, cor, auc, learning_rate))\n",
1444
    "        valid_time = time.time()-st\n",
1445
    "        \n",
1446
    "        epoch_stats = pd.DataFrame([[VERSION, dataset, i, fold, tr_ll, ll, ll_w, ll2, ll2_w, cor, \n",
1447
    "                                     lls[0], lls[1], lls[2], lls[3], \n",
1448
    "                                     lls[4], lls[5], len(trn_ds), len(val_ds), bs, train_time, valid_time,\n",
1449
    "                                     learning_rate, weight_decay]],\n",
1450
    "                                   columns = \n",
1451
    "                                    ['ver','dataset','epoch','fold','train_loss','val_loss','val_w_loss',\n",
1452
    "                                     'val_loss2','val_w_loss2','cor',\n",
1453
    "                                     'any','epidural','intraparenchymal','intraventricular','subarachnoid','subdural',\n",
1454
    "                                     'train_sz','val_sz','bs','train_time','valid_time','lr','wd'\n",
1455
    "                                     ])\n",
1456
    "\n",
1457
    "        stats_filename = PATH_WORK/'stats.f{}.v{}'.format(fold,VERSION)\n",
1458
    "        if stats_filename.is_file():\n",
1459
    "            epoch_stats = pd.concat([pd.read_csv(stats_filename), epoch_stats], sort=False)\n",
1460
    "        #if not DATA_SMALL:\n",
1461
    "        epoch_stats.to_csv(stats_filename, index=False)\n",
1462
    "    \n",
1463
    "    print('total running time', time.time() - st0)\n",
1464
    "    \n",
1465
    "    return model, predictions, val_results"
1466
   ]
1467
  },
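  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch, not part of the original runs: train one fold of one dataset for a\n",
    "# single epoch. Assumes the globals set up earlier in the notebook (train_md, device,\n",
    "# VERSION, paths, etc.); the dataset id, batch size and fold below are placeholder values.\n",
    "model, val_preds, val_results = train_one(dataset=6, epochs=1, bs=100, fold=0)"
   ]
  },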
1468
  {
1469
   "cell_type": "markdown",
1470
   "metadata": {},
1471
   "source": [
1472
    "# OOF"
1473
   ]
1474
  },
1475
  {
1476
   "cell_type": "code",
1477
   "execution_count": 100,
1478
   "metadata": {},
1479
   "outputs": [],
1480
   "source": [
1481
    "def oof_one(dataset, num_iter=1, bs=100, fold=0):\n",
1482
    "    \n",
1483
    "    st0 = time.time()\n",
1484
    "    dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n",
1485
    "    \n",
1486
    "    cur_epoch = getCurrentBatch(fold=fold, dataset=dataset)\n",
1487
    "    if cur_epoch is None: cur_epoch = 0\n",
1488
    "    print('completed epochs:', cur_epoch, 'iters starting now:', num_iter)\n",
1489
    "    \n",
1490
    "    setSeeds(SEED + cur_epoch)\n",
1491
    "    \n",
1492
    "    val_ds = RSNA_DataSet(train_md, mode='valid', bs=bs, fold=fold, dataset=dataset)\n",
1493
    "    \n",
1494
    "    loader_val = D.DataLoader(val_ds, num_workers=NUM_WORKERS, batch_size=bs, \n",
1495
    "                              shuffle=True)\n",
1496
    "    print('dataset valid:', len(val_ds), 'loader valid:', len(loader_val))\n",
1497
    "    \n",
1498
    "    #model = TabularModel(n_cont = len(meta_cols), feat_sz=feat_sz, fc_drop_p=0)\n",
1499
    "    model = ResNetModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n",
1500
    "    \n",
1501
    "    model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset)\n",
1502
    "    if model_file_name is not None:\n",
1503
    "        print('loading model', model_file_name)\n",
1504
    "        state_dict = torch.load(PATH_DISK/'models'/model_file_name)\n",
1505
    "        model.load_state_dict(state_dict)\n",
1506
    "    else:\n",
1507
    "        print('starting from scratch')\n",
1508
    "    \n",
1509
    "    if (not CLOUD) or CLOUD_SINGLE:\n",
1510
    "        model = model.to(device)\n",
1511
    "    else:\n",
1512
    "        model_parallel = dp.DataParallel(model, device_ids=devices)\n",
1513
    "    \n",
1514
    "    loc_data = val_ds.metadata.copy()\n",
1515
    "    series_counts = loc_data.index.value_counts()\n",
1516
    "    \n",
1517
    "    loc_data['orig_idx'] = np.arange(len(loc_data))\n",
1518
    "    loc_data = loc_data.sort_values(['SeriesInstanceUID','pos_idx1'])\n",
1519
    "    loc_data['my_order'] = np.arange(len(loc_data))\n",
1520
    "    loc_data = loc_data.sort_values(['orig_idx'])\n",
1521
    "    \n",
1522
    "    preds = []\n",
1523
    "    \n",
1524
    "    for i in range(num_iter):\n",
1525
    "        \n",
1526
    "        val_ds.setFeats(-1, i+100)\n",
1527
    "        \n",
1528
    "        if CLOUD and (not CLOUD_SINGLE):\n",
1529
    "            results = model_parallel(val_loop_fn, loader_val)\n",
1530
    "            predictions = np.concatenate([results[i][0] for i in range(MAX_DEVICES)])\n",
1531
    "            indices = np.concatenate([results[i][1] for i in range(MAX_DEVICES)])\n",
1532
    "            offsets = np.concatenate([results[i][2] for i in range(MAX_DEVICES)])\n",
1533
    "        else:\n",
1534
    "            predictions, indices, offsets = val_loop_fn(model, loader_val, device)\n",
1535
    "\n",
1536
    "        predictions = predictions[np.argsort(indices)]\n",
1537
    "        offsets = offsets[np.argsort(indices)]\n",
1538
    "        assert len(predictions) == len(loc_data.index.unique())\n",
1539
    "        assert len(predictions) == len(offsets)\n",
1540
    "        assert np.all(indices[np.argsort(indices)] == np.array(range(len(predictions))))\n",
1541
    "\n",
1542
    "        val_results = []\n",
1543
    "        for k, series in enumerate(np.sort(loc_data.index.unique())):\n",
1544
    "            cnt = series_counts[series]\n",
1545
    "            assert (offsets[k] + cnt) <= 60\n",
1546
    "            val_results.append(predictions[k,offsets[k]:(offsets[k] + cnt)])\n",
1547
    "\n",
1548
    "        val_results = np.concatenate(val_results)\n",
1549
    "        assert np.isnan(val_results).sum() == 0\n",
1550
    "        val_results = val_results[loc_data.my_order]\n",
1551
    "        assert np.isnan(val_results).sum() == 0\n",
1552
    "        assert len(val_results) == len(loc_data)\n",
1553
    "        \n",
1554
    "        preds.append(val_results)\n",
1555
    "\n",
1556
    "        lls = [log_loss(loc_data[all_ich[k]].values, val_results[:,k], eps=1e-7, labels=[0,1])\\\n",
1557
    "               for k in range(6)]\n",
1558
    "        ll = (class_weights * np.array(lls)).mean()\n",
1559
    "        cor = np.corrcoef(loc_data.loc[:,all_ich].values.reshape(-1), val_results.reshape(-1))[0,1]\n",
1560
    "        auc = roc_auc_score(loc_data.loc[:,all_ich].values.reshape(-1), val_results.reshape(-1))\n",
1561
    "\n",
1562
    "        print('ver {}, iter {}, fold {}, val ll: {:.4f}, cor: {:.4f}, auc: {:.4f}'\n",
1563
    "              .format(VERSION, i, fold, ll, cor, auc))\n",
1564
    "    \n",
1565
    "    print('total running time', time.time() - st0)\n",
1566
    "    \n",
1567
    "    return np.stack(preds)"
1568
   ]
1569
  },
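  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch, not part of the original runs: generate augmented OOF predictions\n",
    "# for one dataset and store them under the name that getPredsOOF (below) loads. The\n",
    "# dataset id and the number of augmentation passes are placeholder values.\n",
    "for fold in range(3):\n",
    "    oof_preds = oof_one(dataset=6, num_iter=32, bs=100, fold=fold)\n",
    "    pickle.dump(oof_preds, open(PATH_DISK/'ensemble/oof_d{}_f{}_v{}'.format(6, fold, VERSION), 'wb'))"
   ]
  },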
1570
  {
1571
   "cell_type": "code",
1572
   "execution_count": null,
1573
   "metadata": {},
1574
   "outputs": [],
1575
   "source": [
1576
    "def predBounding(pp, target=None):\n",
1577
    "    if target is not None:\n",
1578
    "        ll = ((- target * np.log(pp.mean(0)) - (1 - target) * np.log(np.clip(1 - pp.mean(0),1e-15,1-1e-15)))\n",
1579
    "            * class_weights).mean()\n",
1580
    "        print('initial score', ll)\n",
1581
    "    \n",
1582
    "    print('any too low inconsistencies')\n",
1583
    "    for i in range(1,6):\n",
1584
    "        print(i, 'class:', (pp[...,0] < pp[...,i]).mean())\n",
1585
    "    print('total', (pp[...,0] < pp[...,1:].max(-1)).mean())\n",
1586
    "    \n",
1587
    "    max_vals = pp[...,1:].max(-1)\n",
1588
    "    mask = pp[...,0] < max_vals\n",
1589
    "    pp[mask,0] = max_vals[mask]\n",
1590
    "    #mask_vals = 0.5*(preds_all[:,:,:,0] + max_vals)[mask]\n",
1591
    "    #preds_all[mask,0] = mask_vals\n",
1592
    "    #preds_all[mask] = np.clip(preds_all[mask],0,np.expand_dims(mask_vals,1))\n",
1593
    "\n",
1594
    "    assert (pp[...,0] < pp[...,1:].max(-1)).sum() == 0\n",
1595
    "\n",
1596
    "    if target is not None:\n",
1597
    "        ll = ((- target * np.log(pp.mean(0)) - (1 - target) * np.log(np.clip(1 - pp.mean(0),1e-15,1-1e-15)))\n",
1598
    "            * class_weights).mean()\n",
1599
    "        print('any too low corrected score', ll)\n",
1600
    "    \n",
1601
    "    print('any too high inconsistencies')\n",
1602
    "    mask = pp[...,0] > pp[...,1:].sum(-1)\n",
1603
    "    print('total', mask.mean())\n",
1604
    "\n",
1605
    "    mask_val = 0.5*(pp[mask,0] + pp[...,1:].sum(-1)[mask])\n",
1606
    "    scaler = mask_val / pp[...,1:].sum(-1)[mask]\n",
1607
    "    pp[mask,1:] = pp[mask,1:] * np.expand_dims(scaler,1)\n",
1608
    "    pp[mask,0] = mask_val\n",
1609
    "\n",
1610
    "    if target is not None:\n",
1611
    "        ll = ((- target * np.log(pp.mean(0)) - (1 - target) * np.log(np.clip(1 - pp.mean(0),1e-15,1-1e-15)))\n",
1612
    "            * class_weights).mean()\n",
1613
    "        print('any too high corrected score', ll)\n",
1614
    "    \n",
1615
    "    return pp"
1616
   ]
1617
  },
1618
  {
1619
   "cell_type": "markdown",
1620
   "metadata": {},
1621
   "source": [
1622
    "## Selecting runs aggregation"
1623
   ]
1624
  },
1625
  {
1626
   "cell_type": "code",
1627
   "execution_count": 26,
1628
   "metadata": {},
1629
   "outputs": [],
1630
   "source": [
1631
    "def scalePreds(x, power = 2, center=0.5):\n",
1632
    "    res = x.copy()\n",
1633
    "    res[x > center] = center + (1 - center) * ((res[x > center] - center)/(1 - center))**power\n",
1634
    "    res[x < center] = center - center * ((center - res[x < center])/center)**power\n",
1635
    "    return res\n",
1636
    "\n",
1637
    "def getPredsOOF(ver,aug=32,datasets=range(6,11),datasets5=range(11,13)):\n",
1638
    "    preds_all = np.zeros((len(datasets) + len(datasets5),aug,len(train_md),6))\n",
1639
    "    \n",
1640
    "    if len(datasets) > 0:\n",
1641
    "        for fold in range(3):\n",
1642
    "            preds = np.stack([pickle.load(open(PATH_DISK/'ensemble/oof_d{}_f{}_v{}'\n",
1643
    "                                               .format(ds, fold, ver),'rb')) for ds in datasets])\n",
1644
    "            preds_all[:len(datasets),:,train_md.fold == fold,:] = preds\n",
1645
    "    \n",
1646
    "    if len(datasets5) > 0:\n",
1647
    "        for fold in range(5):\n",
1648
    "            preds = np.stack([pickle.load(open(PATH_DISK/'ensemble/oof_d{}_f{}_v{}'\n",
1649
    "                                               .format(ds, fold, ver),'rb')) for ds in datasets5])\n",
1650
    "            preds_all[len(datasets):,:,train_md.fold5 == fold,:] = preds\n",
1651
    "    \n",
1652
    "    preds_all = np.clip(preds_all, 1e-15, 1-1e-15)\n",
1653
    "    return preds_all"
1654
   ]
1655
  },
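  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch, not part of the original runs: load the stored OOF predictions for\n",
    "# this VERSION and apply the consistency corrections from predBounding (any >= max of the\n",
    "# subtypes, any <= their sum); passing the labels as target prints the score before/after.\n",
    "# scalePreds could optionally be applied afterwards (power > 1 pulls predictions towards\n",
    "# the center, power < 1 pushes them away from it).\n",
    "preds_oof = getPredsOOF(VERSION)\n",
    "preds_oof = predBounding(preds_oof, target=train_md[all_ich].values)"
   ]
  },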
1656
  {
1657
   "cell_type": "code",
1658
   "execution_count": null,
1659
   "metadata": {},
1660
   "outputs": [],
1661
   "source": [
1662
    "def getYuvalOOF(names, train_md, names5=None, aug=32):\n",
1663
    "    \n",
1664
    "    for num_folds in [3,5]:\n",
1665
    "        if ('yuval_idx' not in train_md.columns) or ('yuval_idx5' not in train_md.columns):\n",
1666
    "            print('adding yuval_idx')\n",
1667
    "            fn = 'OOF_validation_image_ids'\n",
1668
    "            if num_folds == 5:\n",
1669
    "                fn += '_5_stage2'\n",
1670
    "            if num_folds == 3:\n",
1671
    "                fn += '_3_stage2'\n",
1672
    "            yuval_oof = pickle.load(open(PATH_DISK/'yuval/OOF_stage2/{}.pkl'.format(fn),'rb'))\n",
1673
    "            \n",
1674
    "            df2 = pd.DataFrame()\n",
1675
    "            col_name = 'yuval_idx'\n",
1676
    "            fold_col = 'fold'\n",
1677
    "            if num_folds == 5:\n",
1678
    "                col_name += '5'\n",
1679
    "                fold_col += '5'\n",
1680
    "            \n",
1681
    "            for fold in range(num_folds):\n",
1682
    "                \n",
1683
    "                assert np.all(train_md.loc[train_md.img_id.isin(yuval_oof[fold])][fold_col].values == fold)\n",
1684
    "                \n",
1685
    "                df = pd.DataFrame(np.arange(len(yuval_oof[fold])), columns=[col_name])\n",
1686
    "                df.index = yuval_oof[fold]\n",
1687
    "                df2 = pd.concat([df2,df],sort=False)\n",
1688
    "            \n",
1689
    "            train_md = train_md.join(df2, on = 'img_id')\n",
1690
    "            assert train_md[col_name].isnull().sum() == 0\n",
1691
    "    \n",
1692
    "    preds_y = np.zeros((len(names),len(train_md),6))\n",
1693
    "    \n",
1694
    "    for fold in range(3):\n",
1695
    "        preds = np.stack([torch.sigmoid(torch.stack(\n",
1696
    "            pickle.load(open(PATH_DISK/'yuval/OOF_stage2'/name.format(fold),'rb')))).numpy() for name in names])\n",
1697
    "        preds = preds[:,:,train_md.loc[train_md.fold == fold, 'yuval_idx']]\n",
1698
    "        \n",
1699
    "        preds_y[:,train_md.fold == fold] = preds.mean(1)\n",
1700
    "        \n",
1701
    "    if names5 is not None:\n",
1702
    "        preds_y5 = np.zeros((len(names5),len(train_md),6))\n",
1703
    "        for fold in range(5):\n",
1704
    "            preds = np.stack([torch.sigmoid(torch.stack(\n",
1705
    "                pickle.load(open(PATH_DISK/'yuval/OOF_stage2'/name.format(fold),'rb')))).numpy() for name in names5])\n",
1706
    "            preds = preds[:,:,train_md.loc[train_md.fold5 == fold, 'yuval_idx5']]\n",
1707
    "\n",
1708
    "            preds_y5[:,train_md.fold5 == fold] = preds.mean(1)\n",
1709
    "        \n",
1710
    "        preds_y = np.concatenate([preds_y, preds_y5], axis=0)\n",
1711
    "    \n",
1712
    "    \n",
1713
    "    preds_y = np.clip(preds_y, 1e-15, 1-1e-15)[:,:,np.array([5,0,1,2,3,4])]\n",
1714
    "    return preds_y"
1715
   ]
1716
  },
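  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch, not part of the original runs: the filename templates passed to\n",
    "# getYuvalOOF are placeholders; the real pickles under yuval/OOF_stage2 are expected to\n",
    "# contain a {} that is filled with the fold index.\n",
    "#preds_yuval = getYuvalOOF(['some_model_fold{}.pkl'], train_md, names5=['other_model_fold{}.pkl'])"
   ]
  },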
1717
  {
1718
   "cell_type": "markdown",
1719
   "metadata": {},
1720
   "source": [
1721
    "# Inference"
1722
   ]
1723
  },
1724
  {
1725
   "cell_type": "code",
1726
   "execution_count": 102,
1727
   "metadata": {},
1728
   "outputs": [],
1729
   "source": [
1730
    "def inference_one(dataset, bs = 100, add_seed = 0, fold = 0, anum = 0):\n",
1731
    "    \n",
1732
    "    st = time.time()\n",
1733
    "    dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n",
1734
    "\n",
1735
    "    cur_epoch = getCurrentBatch(fold=fold, dataset=dataset)\n",
1736
    "    if cur_epoch is None: cur_epoch = 0\n",
1737
    "    print('completed epochs:', cur_epoch)\n",
1738
    "\n",
1739
    "    #model = TabularModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n",
1740
    "    model = ResNetModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n",
1741
    "    \n",
1742
    "    model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset)\n",
1743
    "    if model_file_name is not None:\n",
1744
    "        print('loading model', model_file_name)\n",
1745
    "        state_dict = torch.load(PATH_DISK/'models'/model_file_name)\n",
1746
    "        model.load_state_dict(state_dict)\n",
1747
    "    \n",
1748
    "    if (not CLOUD) or CLOUD_SINGLE:\n",
1749
    "        model = model.to(device)\n",
1750
    "    else:\n",
1751
    "        model_parallel = dp.DataParallel(model, device_ids=devices)\n",
1752
    "\n",
1753
    "    setSeeds(SEED + cur_epoch + anum + 100*fold)\n",
1754
    "\n",
1755
    "    tst_ds = RSNA_DataSet(test_md, mode='test', bs=bs, fold=fold, dataset=dataset)\n",
1756
    "    loader_tst = D.DataLoader(tst_ds, num_workers=NUM_WORKERS, batch_size=bs, shuffle=False)\n",
1757
    "    print('dataset test:', len(tst_ds), 'loader test:', len(loader_tst), 'anum:', anum)\n",
1758
    "    \n",
1759
    "    tst_ds.setFeats(-1, epoch = anum+100)\n",
1760
    "\n",
1761
    "    loc_data = tst_ds.metadata.copy()\n",
1762
    "    series_counts = loc_data.index.value_counts()\n",
1763
    "\n",
1764
    "    loc_data['orig_idx'] = np.arange(len(loc_data))\n",
1765
    "    loc_data = loc_data.sort_values(['SeriesInstanceUID','pos_idx1'])\n",
1766
    "    loc_data['my_order'] = np.arange(len(loc_data))\n",
1767
    "    loc_data = loc_data.sort_values(['orig_idx'])\n",
1768
    "    \n",
1769
    "    if CLOUD and (not CLOUD_SINGLE):\n",
1770
    "        results = model_parallel(test_loop_fn, loader_tst)\n",
1771
    "        predictions = np.concatenate([results[i][0] for i in range(MAX_DEVICES)])\n",
1772
    "        indices = np.concatenate([results[i][1] for i in range(MAX_DEVICES)])\n",
1773
    "        offsets = np.concatenate([results[i][2] for i in range(MAX_DEVICES)])\n",
1774
    "    else:\n",
1775
    "        predictions, indices, offsets = test_loop_fn(model, loader_tst, device)\n",
1776
    "\n",
1777
    "    predictions = predictions[np.argsort(indices)]\n",
1778
    "    offsets = offsets[np.argsort(indices)]\n",
1779
    "    assert len(predictions) == len(test_md.SeriesInstanceUID.unique())\n",
1780
    "    assert np.all(indices[np.argsort(indices)] == np.array(range(len(predictions))))\n",
1781
    "    \n",
1782
    "    val_results = []\n",
1783
    "    for k, series in enumerate(np.sort(loc_data.index.unique())):\n",
1784
    "        cnt = series_counts[series]\n",
1785
    "        assert (offsets[k] + cnt) <= 60\n",
1786
    "        val_results.append(predictions[k,offsets[k]:(offsets[k] + cnt)])\n",
1787
    "\n",
1788
    "    val_results = np.concatenate(val_results)\n",
1789
    "    assert np.isnan(val_results).sum() == 0\n",
1790
    "    val_results = val_results[loc_data.my_order]\n",
1791
    "    assert len(val_results) == len(loc_data)\n",
1792
    "\n",
1793
    "    print('test processing time:', time.time() - st)\n",
1794
    "    \n",
1795
    "    return val_results"
1796
   ]
1797
  },
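  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch, not part of the original runs: average test predictions over the\n",
    "# folds and over a few augmented passes for one dataset; the dataset id and the number of\n",
    "# passes are placeholder values.\n",
    "test_preds = np.mean([inference_one(dataset=6, bs=100, fold=fold, anum=anum)\n",
    "                      for fold in range(3) for anum in range(2)], axis=0)"
   ]
  },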
1798
  {
1799
   "cell_type": "markdown",
1800
   "metadata": {},
1801
   "source": [
1802
    "# Ensembling"
1803
   ]
1804
  },
1805
  {
1806
   "cell_type": "code",
1807
   "execution_count": 1,
1808
   "metadata": {
1809
    "scrolled": true
1810
   },
1811
   "outputs": [],
1812
   "source": [
1813
    "def getStepX(train_md, preds_all, fold=0, target=0, mode='train'):\n",
1814
    "    \n",
1815
    "    print('my_len', my_len)\n",
1816
    "    X = np.stack([preds_all[:my_len,:,target].mean(0), \n",
1817
    "                  preds_all[my_len:,:,target].mean(0)], axis=0)\n",
1818
    "    \n",
1819
    "    if mode == 'train':\n",
1820
    "        X = X[:,train_md.fold != fold]\n",
1821
    "        y = train_md.loc[train_md.fold != fold, all_ich[target]].values\n",
1822
    "        ww = train_md.loc[train_md.fold != fold, 'weights'].values\n",
1823
    "    elif mode == 'valid':\n",
1824
    "        X = X[:,train_md.fold == fold]\n",
1825
    "        y = train_md.loc[train_md.fold == fold, all_ich[target]].values\n",
1826
    "        ww = train_md.loc[train_md.fold == fold, 'weights'].values\n",
1827
    "    else:\n",
1828
    "        X = X\n",
1829
    "        y = None\n",
1830
    "    \n",
1831
    "    ll = None\n",
1832
    "    llw = None\n",
1833
    "    auc = None\n",
1834
    "    if y is not None:\n",
1835
    "        ll = log_loss(y, X.mean(0), eps=1e-7, labels=[0,1])\n",
1836
    "        llw = log_loss(y, X.mean(0), eps=1e-7, labels=[0,1], sample_weight=ww)\n",
1837
    "        auc = roc_auc_score(y, X.mean(0))\n",
1838
    "    \n",
1839
    "    return X, y, ww, ll, llw, auc\n",
1840
    "\n",
1841
    "\n",
1842
    "def train_ensemble(train_md, preds_all, fold = 0, target = 0, weighted = False):\n",
1843
    "    \n",
1844
    "    print('starting fold',fold,'target',target)\n",
1845
    "    \n",
1846
    "    st = time.time()\n",
1847
    "    \n",
1848
    "    limit_low = 1e-15\n",
1849
    "    limit_high = 1 - 1e-5\n",
1850
    "    \n",
1851
    "    prior = train_md.loc[train_md.fold != fold, all_ich[target]].mean()\n",
1852
    "    \n",
1853
    "    def my_objective(x,preds,vals,ww):\n",
1854
    "        preds_sum = np.clip((preds * x[:,None]).sum(0), limit_low, limit_high)\n",
1855
    "        res = np.average(- vals * np.log(preds_sum) - (1 - vals) * np.log(1 - preds_sum), weights=ww)\n",
1856
    "        #print('x   ',x, x.sum())\n",
1857
    "        print('obj ',res)\n",
1858
    "        return res\n",
1859
    "\n",
1860
    "    def my_grad(x,preds,vals,ww):\n",
1861
    "        preds_sum = np.clip((preds * x[:,None]).sum(0), limit_low, limit_high)\n",
1862
    "        res = np.average(- vals * preds / preds_sum + (1 - vals) * preds / (1 - preds_sum), weights=ww, axis=1)\n",
1863
    "        #print('grad',res)\n",
1864
    "        return res\n",
1865
    "\n",
1866
    "    def my_hess(x,preds,vals,ww):\n",
1867
    "        preds_sum = np.clip((preds * x[:,None]).sum(0), limit_low, limit_high)\n",
1868
    "        res = np.average(preds * np.expand_dims(preds, axis=1) * \n",
1869
    "                         (vals / preds_sum**2 + (1 - vals) / (1 - preds_sum)**2), weights=ww, axis=2)\n",
1870
    "        return res\n",
1871
    "    \n",
1872
    "    X,y,ww,ll_train,llw_train,auc_train =  getStepX(train_md, preds_all, fold=fold, target=target)\n",
1873
    "    \n",
1874
    "    bnds_low = np.zeros(X.shape[0])\n",
1875
    "    bnds_high = np.ones(X.shape[0])\n",
1876
    "    \n",
1877
    "    initial_sol = np.ones(X.shape[0])/X.shape[0]\n",
1878
    "    \n",
1879
    "    bnds = sp.optimize.Bounds(bnds_low, bnds_high)\n",
1880
    "    cons = sp.optimize.LinearConstraint(np.ones((1,X.shape[0])), 0.98, 1.00)\n",
1881
    "    \n",
1882
    "    if weighted:\n",
1883
    "        ww_in = ww\n",
1884
    "    else:\n",
1885
    "        ww_in = np.ones(len(y))\n",
1886
    "    \n",
1887
    "    model = sp.optimize.minimize(my_objective, initial_sol, jac=my_grad, hess=my_hess, args=(X, y, ww_in),\n",
1888
    "                                 bounds=bnds, method='trust-constr', constraints=cons,\n",
1889
    "                                 options={'gtol': 1e-11, 'initial_tr_radius': 0.1, 'initial_barrier_parameter': 0.01})\n",
1890
    "    \n",
1891
    "    pickle.dump(model, open(PATH_DISK/'ensemble'/'model.f{}.t{}.v{}'\n",
1892
    "                            .format(fold,target,VERSION),'wb'))\n",
1893
    "\n",
1894
    "    print('model', model.x, 'sum', model.x.sum())\n",
1895
    "    \n",
1896
    "    train_preds = (X*np.expand_dims(model.x, axis=1)).sum(0)\n",
1897
    "    ll_train2 = log_loss(y, train_preds, eps=1e-7, labels=[0,1])\n",
1898
    "    llw_train2 = log_loss(y, train_preds, eps=1e-7, labels=[0,1], sample_weight=ww)\n",
1899
    "    auc_train2 = roc_auc_score(y, train_preds)\n",
1900
    "    \n",
1901
    "    X,y,ww,ll_val,llw_val,auc_val =  getStepX(train_md, preds_all, fold=fold, target=target, mode='valid')\n",
1902
    "    \n",
1903
    "    val_preds = (X*np.expand_dims(model.x, axis=1)).sum(0)\n",
1904
    "    \n",
1905
    "    ll_val2 = log_loss(y, val_preds, eps=1e-7, labels=[0,1])\n",
1906
    "    llw_val2 = log_loss(y, val_preds, eps=1e-7, labels=[0,1], sample_weight=ww)\n",
1907
    "    auc_val2 = roc_auc_score(y, val_preds)\n",
1908
    "    \n",
1909
    "    print('v{} f{} t{}: original ll {:.4f}/{:.4f}, ensemble ll {:.4f}/{:.4f}'\n",
1910
    "          .format(VERSION,fold,target,ll_val,llw_val,ll_val2,llw_val2))\n",
1911
    "    \n",
1912
    "    run_time = time.time() - st\n",
1913
    "    print('running time', run_time)\n",
1914
    "    \n",
1915
    "    stats = pd.DataFrame([[VERSION,fold,target,\n",
1916
    "                           ll_train,auc_train,ll_train2,auc_train2,\n",
1917
    "                           ll_val,auc_val,ll_val2,auc_val2,\n",
1918
    "                           llw_train,llw_train2,llw_val,llw_val2,\n",
1919
    "                           run_time,weighted]],\n",
1920
    "                           columns = \n",
1921
    "                            ['version','fold','target',\n",
1922
    "                             'train_loss','train_auc','train_loss_ens','train_auc_ens', \n",
1923
    "                             'valid_loss','valid_auc','valid_loss_ens','valid_auc_ens',\n",
1924
    "                             'train_w_loss','train_w_loss_ens','valid_w_loss','valid_w_loss_ens',\n",
1925
    "                             'run_time','weighted'\n",
1926
    "                             ])\n",
1927
    "    \n",
1928
    "    stats_filename = PATH_DISK/'ensemble'/'stats.v{}'.format(VERSION)\n",
1929
    "    if stats_filename.is_file():\n",
1930
    "        stats = pd.concat([pd.read_csv(stats_filename), stats], sort=False)\n",
1931
    "    stats.to_csv(stats_filename, index=False)\n",
1932
    "\n",
1933
    "#model.cols = Xt.columns\n",
1934
    "#predictions[data_filt['fold'] == i] = (Xv*model.x).sum(1)"
1935
   ]
1936
  },
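  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch, not part of the original runs: fit the constrained blending weights\n",
    "# per fold and per target class. getStepX relies on the globals my_len and preds_all, so a\n",
    "# plausible (hypothetical) setup would stack the augmentation-averaged OOF predictions of\n",
    "# the models above with Yuval's, e.g.:\n",
    "#my_len = preds_oof.shape[0]\n",
    "#preds_all = np.concatenate([preds_oof.mean(1), preds_yuval], axis=0)\n",
    "for fold in range(3):\n",
    "    for target in range(6):\n",
    "        train_ensemble(train_md, preds_all, fold=fold, target=target, weighted=False)"
   ]
  },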
1937
  {
1938
   "cell_type": "code",
1939
   "execution_count": null,
1940
   "metadata": {},
1941
   "outputs": [],
1942
   "source": []
1943
  }
1944
 ],
1945
 "metadata": {
1946
  "kernelspec": {
1947
   "display_name": "Python 3",
1948
   "language": "python",
1949
   "name": "python3"
1950
  },
1951
  "language_info": {
1952
   "codemirror_mode": {
1953
    "name": "ipython",
1954
    "version": 3
1955
   },
1956
   "file_extension": ".py",
1957
   "mimetype": "text/x-python",
1958
   "name": "python",
1959
   "nbconvert_exporter": "python",
1960
   "pygments_lexer": "ipython3",
1961
   "version": "3.7.4"
1962
  }
1963
 },
1964
 "nbformat": 4,
1965
 "nbformat_minor": 2
1966
}