3a-L1-train-and-generate-predictions-fastai_v1.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "stage = \"stage_2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_name = 'NONAMED'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fastai\n",
    "from fastai.vision import *\n",
    "from fastai.callbacks import *\n",
    "from fastai.distributed import *\n",
    "from fastai.utils.mem import *\n",
    "import re\n",
    "import pydicom\n",
    "import pdb\n",
    "import pickle\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "densenet121-53_train_cpu_window_uint8_sz512_cv0.0386_subdural_loss_fold5_of_5.pth\r\n",
      "densenet121-53_train_cpu_window_uint8_sz512_cv0.0399_subdural_loss_fold4_of_5.pth\r\n",
      "densenet121-53_train_cpu_window_uint8_sz512_cv0.0403_subdural_loss_fold3_of_5.pth\r\n",
      "densenet121-53_train_cpu_window_uint8_sz512_cv0.0406_subdural_loss_fold2_of_5.pth\r\n",
      "densenet121-53_train_cpu_window_uint8_sz512_cv0.0413_subdural_loss_fold1_of_5.pth\r\n",
      "densenet121_sz512_cv0.0738_weighted_loss_fold4_of_5.pth\r\n",
      "densenet121_sz512_cv0.0739_weighted_loss_fold5_of_5.pth\r\n",
      "densenet121_sz512_cv0.0743_weighted_loss_fold2_of_5.pth\r\n",
      "densenet121_sz512_cv0.0748_weighted_loss_fold3_of_5.pth\r\n",
      "densenet121_sz512_cv0.0749_weighted_loss_fold1_of_5.pth\r\n",
      "resnet101-53_train_cpu_window_uint8_sz512_cv0.0365_subdural_loss_fold5_of_5.pth\r\n",
      "resnet101-53_train_cpu_window_uint8_sz512_cv0.0366_subdural_loss_fold4_of_5.pth\r\n",
      "resnet101-53_train_cpu_window_uint8_sz512_cv0.0369_subdural_loss_fold2_of_5.pth\r\n",
      "resnet101-53_train_cpu_window_uint8_sz512_cv0.0371_subdural_loss_fold3_of_5.pth\r\n",
      "resnet101-53_train_cpu_window_uint8_sz512_cv0.0375_subdural_loss_fold1_of_5.pth\r\n",
      "resnet101_sz512_cv0.0749_weighted_loss_fold5_of_5.pth\r\n",
      "resnet101_sz512_cv0.0751_weighted_loss_fold4_of_5.pth\r\n",
      "resnet101_sz512_cv0.0754_weighted_loss_fold2_of_5.pth\r\n",
      "resnet101_sz512_cv0.0758_weighted_loss_fold1_of_5.pth\r\n",
      "resnet101_sz512_cv0.0761_weighted_loss_fold3_of_5.pth\r\n",
      "resnet18-53_train_cpu_window_uint8_sz512_cv0.0522_subdural_loss_fold5_of_5.pth\r\n",
      "resnet18-53_train_cpu_window_uint8_sz512_cv0.0528_subdural_loss_fold3_of_5.pth\r\n",
      "resnet18-53_train_cpu_window_uint8_sz512_cv0.0530_subdural_loss_fold4_of_5.pth\r\n",
      "resnet18-53_train_cpu_window_uint8_sz512_cv0.0550_subdural_loss_fold2_of_5.pth\r\n",
      "resnet18-53_train_cpu_window_uint8_sz512_cv0.0556_subdural_loss_fold1_of_5.pth\r\n",
      "resnet18_sz512_cv0.0762_weighted_loss_fold4_of_5.pth\r\n",
      "resnet18_sz512_cv0.0768_weighted_loss_fold5_of_5.pth\r\n",
      "resnet18_sz512_cv0.0773_weighted_loss_fold2_of_5.pth\r\n",
      "resnet18_sz512_cv0.0786_weighted_loss_fold1_of_5.pth\r\n",
      "resnet18_sz512_cv0.0786_weighted_loss_fold3_of_5.pth\r\n",
      "resnet34-53_train_cpu_window_uint8_sz512_cv0.0453_subdural_loss_fold5_of_5.pth\r\n",
      "resnet34-53_train_cpu_window_uint8_sz512_cv0.0466_subdural_loss_fold4_of_5.pth\r\n",
      "resnet34-53_train_cpu_window_uint8_sz512_cv0.0474_subdural_loss_fold3_of_5.pth\r\n",
      "resnet34-53_train_cpu_window_uint8_sz512_cv0.0484_subdural_loss_fold2_of_5.pth\r\n",
      "resnet34-53_train_cpu_window_uint8_sz512_cv0.0496_subdural_loss_fold1_of_5.pth\r\n",
      "resnet34_sz512_cv0.0750_weighted_loss_fold4_of_5.pth\r\n",
      "resnet34_sz512_cv0.0759_weighted_loss_fold5_of_5.pth\r\n",
      "resnet34_sz512_cv0.0763_weighted_loss_fold2_of_5.pth\r\n",
      "resnet34_sz512_cv0.0765_weighted_loss_fold1_of_5.pth\r\n",
      "resnet34_sz512_cv0.0782_weighted_loss_fold3_of_5.pth\r\n",
      "resnet50_sz512_cv0.0751_weighted_loss_fold4_of_5.pth\r\n",
      "resnet50_sz512_cv0.0751_weighted_loss_fold5_of_5.pth\r\n",
      "resnet50_sz512_cv0.0754_weighted_loss_fold2_of_5.pth\r\n",
      "resnet50_sz512_cv0.0761_weighted_loss_fold1_of_5.pth\r\n",
      "resnet50_sz512_cv0.0764_weighted_loss_fold3_of_5.pth\r\n",
      "resnext50_32x4d-53_train_cpu_window_uint8_sz512_cv0.0372_subdural_loss_fold5_of_5.pth\r\n",
      "resnext50_32x4d-53_train_cpu_window_uint8_sz512_cv0.0376_subdural_loss_fold4_of_5.pth\r\n",
      "resnext50_32x4d-53_train_cpu_window_uint8_sz512_cv0.0378_subdural_loss_fold2_of_5.pth\r\n",
      "resnext50_32x4d-53_train_cpu_window_uint8_sz512_cv0.0381_subdural_loss_fold3_of_5.pth\r\n",
      "resnext50_32x4d-53_train_cpu_window_uint8_sz512_cv0.0387_subdural_loss_fold1_of_5.pth\r\n"
     ]
    }
   ],
   "source": [
    "!ls models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "arch = 'resnet18'\n",
    "fold = 4\n",
    "lr = 1e-3\n",
    "\n",
    "model_fn = None\n",
    "SZ = 512\n",
    "n_folds = 5\n",
    "n_epochs = 15\n",
    "n_tta = 10\n",
    "\n",
    "#model_fn = 'resnet34_sz512_cv0.0821_weighted_loss_fold1_of_5'\n",
    "\n",
    "if model_fn is not None:\n",
    "    model_fn_fold = int(model_fn[-6])-1\n",
    "    assert model_fn_fold == fold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = Path('data/unzip')\n",
    "train_dir = data_dir / f'{stage}_train_images'\n",
    "test_dir = data_dir / f'{stage}_test_images'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: 'data/stage_2_fn_to_study_ix.pickle'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-8-6207314e915b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfn_to_study_ix\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'data/{stage}_fn_to_study_ix.pickle'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mstudy_ix_to_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'data/{stage}_study_ix_to_fn.pickle'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mfn_to_labels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m  \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'data/{stage}_train_fn_to_labels.pickle'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mstudy_to_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'data/{stage}_study_to_data.pickle'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'data/stage_2_fn_to_study_ix.pickle'"
     ]
    }
   ],
   "source": [
    "fn_to_study_ix = pickle.load(open(f'data/{stage}_fn_to_study_ix.pickle', 'rb'))\n",
    "study_ix_to_fn = pickle.load(open(f'data/{stage}_study_ix_to_fn.pickle', 'rb'))\n",
    "fn_to_labels = pickle.load(  open(f'data/{stage}_train_fn_to_labels.pickle', 'rb'))\n",
    "study_to_data = pickle.load (open(f'data/{stage}_study_to_data.pickle', 'rb'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gets a pathlib.Path, returns labels\n",
180
    "labels = [\n",
181
    "    'any0', 'epidural0', 'intraparenchymal0', 'intraventricular0', 'subarachnoid0', 'subdural0',\n",
182
    "    'any1', 'epidural1', 'intraparenchymal1', 'intraventricular1', 'subarachnoid1', 'subdural1',\n",
183
    "    'any2', 'epidural2', 'intraparenchymal2', 'intraventricular2', 'subarachnoid2', 'subdural2',\n",
184
    "]\n",
185
    "\n",
186
    "def get_labels(center_p):\n",
187
    "\n",
188
    "    study, center_ix = fn_to_study_ix[center_p.stem]\n",
189
    "    total = len(study_ix_to_fn[study])\n",
190
    "    \n",
191
    "    # Only 1 study has 2 different RescaleIntercept values (0.0 and 1.0) in the same series, so assume they're all the\n",
192
    "    # same and do everything in the big tensor\n",
193
    "    \n",
194
    "    pixels = {}\n",
195
    "    img_labels = []\n",
196
    "    ixs = [ max(0, min(ix, total-1)) for ix in range(center_ix-1, center_ix+2) ]\n",
197
    "    for i, ix in enumerate(ixs):\n",
198
    "        fn = study_ix_to_fn[study][ix]\n",
199
    "        img_labels += [ x + str(i) for x in fn_to_labels[fn] ]\n",
200
    "            \n",
201
    "    return img_labels\n",
202
    "\n",
203
    "get_labels(train_dir / 'ID_ac7c8fe8b.dcm')"
204
   ]
205
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_submission = pd.read_csv(f'data/unzip/{stage}_sample_submission.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_submission['fn'] = sample_submission.ID.apply(lambda x: '_'.join(x.split('_')[:2]) + '.dcm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_submission.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_fns = sample_submission.fn.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CT window center and width, in Hounsfield units\n",
    "wc = 50\n",
    "ww = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def open_dicom(p_or_fn): # when called by .from_folder it's a Path; when called from .add_test it's a string\n",
259
    "    #pdb.set_trace()\n",
260
    "    center_p = Path(p_or_fn)\n",
261
    "\n",
262
    "    study, center_ix = fn_to_study_ix[center_p.stem]\n",
263
    "    total = len(study_ix_to_fn[study])\n",
264
    "    \n",
265
    "    # Only 1 study has 2 different RescaleIntercept values (0.0 and 1.0) in the same series, so assume they're all the\n",
266
    "    # same and do everything in the big tensor\n",
267
    "    \n",
268
    "    pixels = {}\n",
269
    "    ixs = [ max(0, min(ix, total-1)) for ix in range(center_ix-1, center_ix+2) ]\n",
270
    "    for ix in ixs:\n",
271
    "        if ix not in pixels:\n",
272
    "            p = center_p.parent / f'{study_ix_to_fn[study][ix]}.dcm'\n",
273
    "            dcm = pydicom.dcmread(str(p))\n",
274
    "            if ix == center_ix:\n",
275
    "                rescale_slope, rescale_intercept = float(dcm.RescaleSlope), float(dcm.RescaleIntercept)\n",
276
    "            pixels[ix] = torch.FloatTensor(dcm.pixel_array.astype(np.float))\n",
277
    "        \n",
278
    "    t = torch.stack([pixels[ix] for ix in ixs], dim=0) # stack chans together\n",
279
    "    if (t.shape[1:] != (SZ,SZ)):\n",
280
    "        t = torch.nn.functional.interpolate(\n",
281
    "            t.unsqueeze_(0), size=SZ, mode='bilinear', align_corners=True).squeeze_(0) # resize\n",
282
    "    t = t * rescale_slope + rescale_intercept # rescale\n",
283
    "    t = torch.clamp(t, wc-ww/2, wc+ww/2) # window\n",
284
    "    t = (t - (wc-ww/2)) / ww # normalize\n",
285
    "    \n",
286
    "    return Image(t)\n",
287
    "\n",
288
    "\n",
289
    "class DicomList(ImageList):\n",
290
    "    def open(self, fn): return open_dicom(fn)"
291
   ]
292
  },
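  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of `open_dicom` (a sketch; it assumes the sample slice `ID_ac7c8fe8b.dcm` used above is present in `train_dir`): the result should be a 3-channel SZxSZ image with values windowed into [0, 1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = open_dicom(train_dir / 'ID_ac7c8fe8b.dcm')\n",
    "assert img.data.shape == (3, SZ, SZ)\n",
    "lo, hi = img.data.min().item(), img.data.max().item()\n",
    "assert 0.0 <= lo <= hi <= 1.0\n",
    "img.data.shape, lo, hi"
   ]
  },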
293
  {
294
   "cell_type": "code",
295
   "execution_count": null,
296
   "metadata": {},
297
   "outputs": [],
298
   "source": [
299
    "my_stats = ([0.45, 0.45, 0.45], [0.225, 0.225, 0.225])"
300
   ]
301
  },
302
  {
303
   "cell_type": "code",
304
   "execution_count": null,
305
   "metadata": {},
306
   "outputs": [],
307
   "source": [
308
    "class OverSamplingCallback(LearnerCallback):\n",
309
    "    def __init__(self,learn:Learner):\n",
310
    "        super().__init__(learn)\n",
311
    "        self.labels = self.learn.data.train_dl.dataset.y.items\n",
312
    "        _, counts = np.unique(self.labels,return_counts=True)\n",
313
    "        self.weights = torch.DoubleTensor((1/counts)[self.labels])\n",
314
    "        self.label_counts = np.bincount([self.learn.data.train_dl.dataset.y[i].data for i in range(len(self.learn.data.train_dl.dataset))])\n",
315
    "        self.total_len_oversample = int(self.learn.data.c*np.max(self.label_counts))\n",
316
    "        \n",
317
    "    def on_train_begin(self, **kwargs):\n",
318
    "        self.learn.data.train_dl.dl.batch_sampler = BatchSampler(WeightedRandomSampler(weights,self.total_len_oversample), self.learn.data.train_dl.batch_size,False)"
319
   ]
320
  },
321
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From: https://forums.fast.ai/t/is-there-any-built-in-method-to-oversample-minority-classes/46838/5\n",
    "class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):\n",
    "    def __init__(self, dataset, indices=None, num_samples=None):\n",
    "\n",
    "        # if indices is not provided,\n",
    "        # all elements in the dataset will be considered\n",
    "        self.indices = list(range(len(dataset))) if indices is None else indices\n",
    "\n",
    "        # if num_samples is not provided,\n",
    "        # draw `len(indices)` samples in each iteration\n",
    "        self.num_samples = len(self.indices) if num_samples is None else num_samples\n",
    "\n",
    "        # distribution of classes in the dataset\n",
    "        label_to_count = defaultdict(int)\n",
    "        for idx in self.indices:\n",
    "            label = self._get_label(dataset, idx)\n",
    "            label_to_count[label] += 1\n",
    "\n",
    "        # weight for each sample\n",
    "        weights = [1.0 / label_to_count[self._get_label(dataset, idx)] for idx in self.indices]\n",
    "        self.weights = torch.DoubleTensor(weights)\n",
    "\n",
    "    def _get_label(self, dataset, idx):\n",
    "        # binarize on whether the center slice carries any hemorrhage label\n",
    "        return 'any1' in dataset.y[idx].obj\n",
    "\n",
    "    def __iter__(self):\n",
    "        return (self.indices[i] for i in torch.multinomial(\n",
    "            self.weights, self.num_samples, replacement=True))\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.num_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folds = np.array_split(list({v['fold'] for v in study_to_data.values()}), n_folds)\n",
    "folds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(f\"data/{stage}_train_dicom_diags_norm.csv\")\n",
    "# split train/val\n",
    "np.random.seed(666)\n",
    "studies_in_fold = [k for k,v in study_to_data.items() if np.isin(v['fold'],folds[fold])]\n",
    "study_id_val_set = set(df[ df['SeriesInstanceUID'].isin(studies_in_fold)]['SOPInstanceUID'] + '.dcm')\n",
    "\n",
    "def is_val(p_or_fn):\n",
    "    p = Path(p_or_fn)\n",
    "    return p.name in study_id_val_set"
   ]
  },
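  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick look at the split (a sketch): with `n_folds = 5`, roughly a fifth of the studies, and hence of the slice filenames in `study_id_val_set`, should be routed to validation for the chosen `fold`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(studies_in_fold), len(study_to_data), len(study_id_val_set)"
   ]
  },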
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfms = (\n",
    "    [ flip_lr(p=0.5), rotate(degrees=(-180,180), p=1.) ], # train\n",
    "    [] # val\n",
    ")\n",
    "data = (DicomList.from_folder(f'data/unzip/{stage}_train_images/', extensions=['.dcm'], presort=True)\n",
    "    .split_by_valid_func(is_val)\n",
    "    .label_from_func(get_labels, classes=labels)\n",
    "    .transform(tfms, mode='nearest')\n",
    "    .add_test(f'data/unzip/{stage}_test_images/' + test_fns))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db = data.databunch().normalize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "balsampler = ImbalancedDatasetSampler(db.train_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "db.train_dl.batch_sampler = balsampler\n",
    "db.train_dl.batch_sampler\n",
    "db"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db.show_batch() # known issue: raises AttributeError: 'list' object has no attribute 'pixel', apparently because of the raw torch sampler swapped in above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Yuval Reina's loss\n",
    "yuval_weights = FloatTensor([ 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1 ]).cuda()\n",
    "\n",
    "def yuval_loss(y_pred,y_true):\n",
    "    return F.binary_cross_entropy_with_logits(y_pred,\n",
    "                                  y_true,\n",
    "                                  yuval_weights.repeat(y_pred.shape[0],1))"
   ]
  },
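  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A smoke-test sketch for the loss (it assumes a CUDA device is available, since `yuval_weights` lives on the GPU): random logits and binary targets of shape `(batch, 18)` should produce a finite scalar."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = torch.randn(4, 18, device='cuda')\n",
    "y_true = torch.randint(0, 2, (4, 18), device='cuda').float()\n",
    "yuval_loss(y_pred, y_true)"
   ]
  },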
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "real_lb_weights = FloatTensor([ 2, 1, 1, 1, 1, 1 ])\n",
    "\n",
    "# metric: competition-style weighted BCE on the center slice only (columns 6:12)\n",
    "def real_lb_loss(pred:Tensor, targ:Tensor)->Rank0Tensor:\n",
    "    pred,targ = flatten_check(pred,targ)\n",
    "    tp = pred.view(-1,18)[:,6:12]\n",
    "    tt = targ.view(-1,18)[:,6:12]\n",
    "    return F.binary_cross_entropy_with_logits(tp, tt, real_lb_weights.to(device=pred.device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training loss: full weight on the center slice, neighbor slices down-weighted by w_loss\n",
    "w_loss = 0.1\n",
    "weighted_lb_weights = FloatTensor([ 2.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss,\n",
    "                                    2.,        1.,        1.,        1.,        1.,        1.,\n",
    "                                    2.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss, ])\n",
    "\n",
    "def weighted_loss(pred:Tensor,targ:Tensor)->Rank0Tensor:\n",
    "    return F.binary_cross_entropy_with_logits(pred,targ,weighted_lb_weights.to(device=pred.device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lb_weights = FloatTensor([ 2, 1, 1, 1, 1, 1 ]).cuda()\n",
    "\n",
    "def lb_loss(pred:Tensor, targ:Tensor)->Rank0Tensor:\n",
    "    pred,targ = flatten_check(pred,targ)\n",
    "    tp = pred.view(-1,6)\n",
    "    tt = targ.view(-1,6)\n",
    "    return torch.nn.functional.binary_cross_entropy_with_logits(tp, tt, lb_weights.to(device=pred.device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(\n",
    "    db, getattr(models, arch), loss_func=weighted_loss, metrics=[real_lb_loss],\n",
    "    pretrained=True, lin_ftrs=[], ps=0.,\n",
    "    model_dir = Path('./models/').resolve())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_fn is None unless resuming from a checkpoint; on failure, report and train from scratch\n",
    "try:\n",
    "    learn.load(model_fn, strict=False)\n",
    "    print(f\"Loaded {model_fn}\")\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = learn.to_parallel().to_fp16()\n",
    "\n",
    "# per-arch batch-size factors, roughly images per GB of GPU memory at SZ=512\n",
    "arch_to_batch_factor = {\n",
    "    'resnet18' : 4.3,\n",
    "    'resnet34' : 3.4,\n",
    "    'resnet50' : 1.5,\n",
    "    'resnet101' : 1.0,\n",
    "    'resnext50_32x4d': 1.2,\n",
    "    'densenet121': 1.1,\n",
    "    'squeezenet1_0': 2.6,\n",
    "    'vgg16': 1.5\n",
    "}\n",
    "\n",
    "bs = arch_to_batch_factor[arch] * 512*512 * gpu_mem_get()[0] * 1e-3 / (SZ*SZ)\n",
    "bs *= 2 if any(isinstance(cb, MixedPrecision) for cb in learn.callbacks) else 1 # 2x if fp16\n",
    "if any(isinstance(cb, fastai.distributed.ParallelTrainer) for cb in learn.callbacks):\n",
    "    bs *= torch.cuda.device_count()\n",
    "    bs = (bs // torch.cuda.device_count()) * torch.cuda.device_count()\n",
    "bs = int(bs)\n",
    "learn.data.batch_size = bs\n",
    "bs"
   ]
  },
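  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hedged worked example of the batch-size heuristic above (the 11,000 MB figure is an assumed total for a typical 11 GB card, not a measured value): at `SZ = 512` the image-area terms cancel, leaving roughly `factor * total_MB * 1e-3` images, doubled again under fp16."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "factor = arch_to_batch_factor['resnet18']              # 4.3\n",
    "approx = factor * 512*512 * 11000 * 1e-3 / (512*512)   # assumed 11,000 MB card\n",
    "int(approx), int(approx) * 2                           # ~47 plain, ~94 with the fp16 doubling"
   ]
  },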
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.unfreeze()\n",
    "learn.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if no lr was set above, run the LR finder and pick one from the plot\n",
    "try:\n",
    "    if lr is None:\n",
    "        !rm {learn.model_dir}/tmp.pth\n",
    "        learn.lr_find()\n",
    "        learn.recorder.plot(suggestion=True)\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(n_epochs, lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "v = learn.validate()\n",
    "cv = float(v[-1]) # last entry is the real_lb_loss metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_fn = f'{arch}_sz{SZ}_cv{cv:0.4f}_{learn.loss_func.__name__}_fold{fold+1}_of_{n_folds}'\n",
    "learn.save(model_fn)\n",
    "model_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data.test_ds.tfms = learn.data.train_ds.tfms # apply the train-time augmentations for TTA\n",
    "test_preds = []\n",
    "# Fastai gotcha: outputs go through sigmoid automatically for standard loss functions\n",
    "# (see loss_func_name2activ and friends), but with a custom loss func the activation must be passed explicitly:\n",
    "# https://forums.fast.ai/t/shouldnt-we-able-to-pass-an-activ-function-to-learner-get-preds/50492\n",
    "for _ in progress_bar(range(n_tta)):\n",
    "    preds, _ = learn.get_preds(DatasetType.Test, activ=torch.sigmoid)\n",
    "    test_preds.append(preds)\n",
    "tta_test_preds = torch.cat([p.unsqueeze(0) for p in test_preds],dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data.valid_ds.tfms = learn.data.train_ds.tfms # apply the train-time augmentations for TTA\n",
    "valid_preds = []\n",
    "for _ in progress_bar(range(n_tta)):\n",
    "    preds, _ = learn.get_preds(DatasetType.Valid, activ=torch.sigmoid)\n",
    "    valid_preds.append(preds)\n",
    "tta_valid_preds = torch.cat([p.unsqueeze(0) for p in valid_preds],dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PREDS_DIR = 'data/predictions'\n",
    "!mkdir -p {PREDS_DIR}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(tta_test_preds,  f'{PREDS_DIR}/{model_fn}_test.pth')\n",
    "torch.save(tta_valid_preds, f'{PREDS_DIR}/{model_fn}_valid.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame({'fn': [p.stem for p in learn.data.valid_ds.x.items]}).to_csv(\n",
    "    f'{PREDS_DIR}/{model_fn}_valid_fns.csv',index=False)\n",
    "pd.DataFrame({'fn': [Path(p).stem for p in learn.data.test_ds.x.items]}).to_csv(\n",
    "    f'{PREDS_DIR}/{model_fn}_test_fns.csv',index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Submit to Kaggle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tta_test_preds_mean = tta_test_preds.mean(dim=0)\n",
    "# geometric mean taken in log1p space, i.e. a geometric mean of (1 + p) shifted back by 1\n",
    "tta_test_preds_geomean = torch.expm1(torch.log1p(tta_test_preds).mean(dim=0))"
   ]
  },
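  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy illustration (a sketch) of the two averages above: the `log1p`/`expm1` version is a geometric mean of `1 + p` shifted back by 1, so it sits at or below the arithmetic mean of the TTA runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy = torch.tensor([[0.9], [0.1]])\n",
    "toy.mean(dim=0), torch.expm1(torch.log1p(toy).mean(dim=0))  # ~0.5000 vs ~0.4457"
   ]
  },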
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ids = []\n",
    "labels = []\n",
    "\n",
    "# keep only the center-slice ('1'-suffixed) predictions for the submission rows\n",
    "for fn, pred in zip(test_fns, tta_test_preds_mean):\n",
    "    for i, label in enumerate(db.train_ds.classes):\n",
    "        if label.endswith('1'):\n",
    "            ids.append(f\"{fn.split('.')[0]}_{label.strip('1')}\")\n",
    "            predicted_probability = '{0:1.10f}'.format(pred[i].item())\n",
    "            labels.append(predicted_probability)"
   ]
  },
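  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A consistency check (a sketch): the loop above emits the six center-slice labels per test image, so the row count should match the sample submission exactly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(ids) == len(sample_submission)\n",
    "len(ids)"
   ]
  },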
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p data/submissions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_name = experiment_name + \"_\" + model_fn\n",
    "sub_path = f'data/submissions/{sub_name}.csv.zip'\n",
    "sub_name, sub_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame({'ID': ids, 'Label': labels}).to_csv(sub_path, compression='zip', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls data/submissions/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!echo {sub_path}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!kaggle competitions submit -c rsna-intracranial-hemorrhage-detection -f {sub_path} -m {model_fn}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}