[fc9ccf]: / 3a-L1-train-and-generate-predictions-fastai_v1.ipynb

Download this file

816 lines (815 with data), 27.8 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "stage = \"stage_2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_name = 'NONAMED'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fastai\n",
    "from fastai.vision import *\n",
    "from fastai.callbacks import *\n",
    "from fastai.distributed import *\n",
    "from fastai.utils.mem import *\n",
    "import re\n",
    "import pydicom\n",
    "import pdb\n",
    "import pickle\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "densenet121-53_train_cpu_window_uint8_sz512_cv0.0386_subdural_loss_fold5_of_5.pth\r\n",
      "densenet121-53_train_cpu_window_uint8_sz512_cv0.0399_subdural_loss_fold4_of_5.pth\r\n",
      "densenet121-53_train_cpu_window_uint8_sz512_cv0.0403_subdural_loss_fold3_of_5.pth\r\n",
      "densenet121-53_train_cpu_window_uint8_sz512_cv0.0406_subdural_loss_fold2_of_5.pth\r\n",
      "densenet121-53_train_cpu_window_uint8_sz512_cv0.0413_subdural_loss_fold1_of_5.pth\r\n",
      "densenet121_sz512_cv0.0738_weighted_loss_fold4_of_5.pth\r\n",
      "densenet121_sz512_cv0.0739_weighted_loss_fold5_of_5.pth\r\n",
      "densenet121_sz512_cv0.0743_weighted_loss_fold2_of_5.pth\r\n",
      "densenet121_sz512_cv0.0748_weighted_loss_fold3_of_5.pth\r\n",
      "densenet121_sz512_cv0.0749_weighted_loss_fold1_of_5.pth\r\n",
      "resnet101-53_train_cpu_window_uint8_sz512_cv0.0365_subdural_loss_fold5_of_5.pth\r\n",
      "resnet101-53_train_cpu_window_uint8_sz512_cv0.0366_subdural_loss_fold4_of_5.pth\r\n",
      "resnet101-53_train_cpu_window_uint8_sz512_cv0.0369_subdural_loss_fold2_of_5.pth\r\n",
      "resnet101-53_train_cpu_window_uint8_sz512_cv0.0371_subdural_loss_fold3_of_5.pth\r\n",
      "resnet101-53_train_cpu_window_uint8_sz512_cv0.0375_subdural_loss_fold1_of_5.pth\r\n",
      "resnet101_sz512_cv0.0749_weighted_loss_fold5_of_5.pth\r\n",
      "resnet101_sz512_cv0.0751_weighted_loss_fold4_of_5.pth\r\n",
      "resnet101_sz512_cv0.0754_weighted_loss_fold2_of_5.pth\r\n",
      "resnet101_sz512_cv0.0758_weighted_loss_fold1_of_5.pth\r\n",
      "resnet101_sz512_cv0.0761_weighted_loss_fold3_of_5.pth\r\n",
      "resnet18-53_train_cpu_window_uint8_sz512_cv0.0522_subdural_loss_fold5_of_5.pth\r\n",
      "resnet18-53_train_cpu_window_uint8_sz512_cv0.0528_subdural_loss_fold3_of_5.pth\r\n",
      "resnet18-53_train_cpu_window_uint8_sz512_cv0.0530_subdural_loss_fold4_of_5.pth\r\n",
      "resnet18-53_train_cpu_window_uint8_sz512_cv0.0550_subdural_loss_fold2_of_5.pth\r\n",
      "resnet18-53_train_cpu_window_uint8_sz512_cv0.0556_subdural_loss_fold1_of_5.pth\r\n",
      "resnet18_sz512_cv0.0762_weighted_loss_fold4_of_5.pth\r\n",
      "resnet18_sz512_cv0.0768_weighted_loss_fold5_of_5.pth\r\n",
      "resnet18_sz512_cv0.0773_weighted_loss_fold2_of_5.pth\r\n",
      "resnet18_sz512_cv0.0786_weighted_loss_fold1_of_5.pth\r\n",
      "resnet18_sz512_cv0.0786_weighted_loss_fold3_of_5.pth\r\n",
      "resnet34-53_train_cpu_window_uint8_sz512_cv0.0453_subdural_loss_fold5_of_5.pth\r\n",
      "resnet34-53_train_cpu_window_uint8_sz512_cv0.0466_subdural_loss_fold4_of_5.pth\r\n",
      "resnet34-53_train_cpu_window_uint8_sz512_cv0.0474_subdural_loss_fold3_of_5.pth\r\n",
      "resnet34-53_train_cpu_window_uint8_sz512_cv0.0484_subdural_loss_fold2_of_5.pth\r\n",
      "resnet34-53_train_cpu_window_uint8_sz512_cv0.0496_subdural_loss_fold1_of_5.pth\r\n",
      "resnet34_sz512_cv0.0750_weighted_loss_fold4_of_5.pth\r\n",
      "resnet34_sz512_cv0.0759_weighted_loss_fold5_of_5.pth\r\n",
      "resnet34_sz512_cv0.0763_weighted_loss_fold2_of_5.pth\r\n",
      "resnet34_sz512_cv0.0765_weighted_loss_fold1_of_5.pth\r\n",
      "resnet34_sz512_cv0.0782_weighted_loss_fold3_of_5.pth\r\n",
      "resnet50_sz512_cv0.0751_weighted_loss_fold4_of_5.pth\r\n",
      "resnet50_sz512_cv0.0751_weighted_loss_fold5_of_5.pth\r\n",
      "resnet50_sz512_cv0.0754_weighted_loss_fold2_of_5.pth\r\n",
      "resnet50_sz512_cv0.0761_weighted_loss_fold1_of_5.pth\r\n",
      "resnet50_sz512_cv0.0764_weighted_loss_fold3_of_5.pth\r\n",
      "resnext50_32x4d-53_train_cpu_window_uint8_sz512_cv0.0372_subdural_loss_fold5_of_5.pth\r\n",
      "resnext50_32x4d-53_train_cpu_window_uint8_sz512_cv0.0376_subdural_loss_fold4_of_5.pth\r\n",
      "resnext50_32x4d-53_train_cpu_window_uint8_sz512_cv0.0378_subdural_loss_fold2_of_5.pth\r\n",
      "resnext50_32x4d-53_train_cpu_window_uint8_sz512_cv0.0381_subdural_loss_fold3_of_5.pth\r\n",
      "resnext50_32x4d-53_train_cpu_window_uint8_sz512_cv0.0387_subdural_loss_fold1_of_5.pth\r\n"
     ]
    }
   ],
   "source": [
    "!ls models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "arch = 'resnet18'\n",
    "fold=4\n",
    "lr = 1e-3\n",
    "\n",
    "model_fn = None\n",
    "SZ = 512\n",
    "n_folds=5\n",
    "n_epochs = 15\n",
    "n_tta = 10\n",
    "\n",
    "#model_fn = 'resnet34_sz512_cv0.0821_weighted_loss_fold1_of_5'\n",
    "\n",
    "if model_fn is not None:\n",
    "    model_fn_fold = int(model_fn[-6])-1\n",
    "    assert model_fn_fold == fold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = Path('data/unzip')\n",
    "train_dir = data_dir / f'{stage}_train_images'\n",
    "test_dir = data_dir / f'{stage}_test_images'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: 'data/stage_2_fn_to_study_ix.pickle'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-8-6207314e915b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfn_to_study_ix\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'data/{stage}_fn_to_study_ix.pickle'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mstudy_ix_to_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'data/{stage}_study_ix_to_fn.pickle'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mfn_to_labels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m  \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'data/{stage}_train_fn_to_labels.pickle'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mstudy_to_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'data/{stage}_study_to_data.pickle'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'data/stage_2_fn_to_study_ix.pickle'"
     ]
    }
   ],
   "source": [
    "def _load_pickle(path):\n",
    "    \"\"\"Unpickle the file at `path`, closing the file handle afterwards.\"\"\"\n",
    "    # NOTE: unpickling can execute arbitrary code; only load files produced by this project.\n",
    "    with open(path, 'rb') as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "# Lookup tables, described by how the cells below index them:\n",
    "fn_to_study_ix = _load_pickle(f'data/{stage}_fn_to_study_ix.pickle')    # file stem -> (study, slice index)\n",
    "study_ix_to_fn = _load_pickle(f'data/{stage}_study_ix_to_fn.pickle')    # study -> file stems, indexable by slice index\n",
    "fn_to_labels = _load_pickle(f'data/{stage}_train_fn_to_labels.pickle')  # file stem -> iterable of label names\n",
    "study_to_data = _load_pickle(f'data/{stage}_study_to_data.pickle')      # study -> dict with (at least) a 'fold' entry\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class vocabulary: six label names, each suffixed with the slice position it refers to\n",
    "# (0 = previous slice, 1 = center slice, 2 = next slice).\n",
    "labels = [\n",
    "    'any0', 'epidural0', 'intraparenchymal0', 'intraventricular0', 'subarachnoid0', 'subdural0',\n",
    "    'any1', 'epidural1', 'intraparenchymal1', 'intraventricular1', 'subarachnoid1', 'subdural1',\n",
    "    'any2', 'epidural2', 'intraparenchymal2', 'intraventricular2', 'subarachnoid2', 'subdural2',\n",
    "]\n",
    "\n",
    "def get_labels(center_p):\n",
    "    \"\"\"Return the position-suffixed labels of a slice and of its two neighbours.\n",
    "\n",
    "    `center_p` is a pathlib.Path to the center slice; only its stem is used for the lookup.\n",
    "    The result is a flat list such as ['any0', 'any1', 'subdural1'], where the digit is the\n",
    "    position (0, 1, 2) of the slice the label was read from.\n",
    "    \"\"\"\n",
    "    study, center_ix = fn_to_study_ix[center_p.stem]\n",
    "    total = len(study_ix_to_fn[study])\n",
    "\n",
    "    # Neighbour indices are clamped to the study, so the first/last slice stands in for a missing neighbour.\n",
    "    ixs = [ max(0, min(ix, total-1)) for ix in range(center_ix-1, center_ix+2) ]\n",
    "    img_labels = []\n",
    "    for i, ix in enumerate(ixs):\n",
    "        fn = study_ix_to_fn[study][ix]\n",
    "        img_labels += [ x + str(i) for x in fn_to_labels[fn] ]\n",
    "\n",
    "    return img_labels\n",
    "\n",
    "get_labels(train_dir / 'ID_ac7c8fe8b.dcm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_submission = pd.read_csv(f'data/unzip/{stage}_sample_submission.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_submission['fn'] = sample_submission.ID.apply(lambda x: '_'.join(x.split('_')[:2]) + '.dcm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_submission.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_fns = sample_submission.fn.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wc = 50\n",
    "ww = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def open_dicom(p_or_fn): # when called by .from_folder it's a Path; when called from .add_test it's a string\n",
    "    \"\"\"Load a DICOM slice and its two neighbours as one 3-channel, windowed fastai `Image`.\n",
    "\n",
    "    Channels are (previous, center, next) slice of the same study, read from the same folder as\n",
    "    `p_or_fn`. Pixel values are rescaled with the center slice's RescaleSlope/RescaleIntercept,\n",
    "    clamped to the window defined by the module-level `wc`/`ww`, and mapped to [0, 1].\n",
    "    \"\"\"\n",
    "    center_p = Path(p_or_fn)\n",
    "\n",
    "    study, center_ix = fn_to_study_ix[center_p.stem]\n",
    "    total = len(study_ix_to_fn[study])\n",
    "    \n",
    "    # Only 1 study has 2 different RescaleIntercept values (0.0 and 1.0) in the same series, so assume they're all the\n",
    "    # same and do everything in the big tensor\n",
    "    \n",
    "    pixels = {}  # slice index -> pixel tensor; a clamped (repeated) index is read from disk only once\n",
    "    ixs = [ max(0, min(ix, total-1)) for ix in range(center_ix-1, center_ix+2) ]\n",
    "    for ix in ixs:\n",
    "        if ix not in pixels:\n",
    "            p = center_p.parent / f'{study_ix_to_fn[study][ix]}.dcm'\n",
    "            dcm = pydicom.dcmread(str(p))\n",
    "            if ix == center_ix:\n",
    "                rescale_slope, rescale_intercept = float(dcm.RescaleSlope), float(dcm.RescaleIntercept)\n",
    "            # np.float32 rather than the `np.float` alias (removed in NumPy 1.24); the tensor is float32 either way\n",
    "            pixels[ix] = torch.FloatTensor(dcm.pixel_array.astype(np.float32))\n",
    "        \n",
    "    t = torch.stack([pixels[ix] for ix in ixs], dim=0) # stack chans together\n",
    "    if (t.shape[1:] != (SZ,SZ)):\n",
    "        t = torch.nn.functional.interpolate(\n",
    "            t.unsqueeze_(0), size=SZ, mode='bilinear', align_corners=True).squeeze_(0) # resize\n",
    "    t = t * rescale_slope + rescale_intercept # rescale\n",
    "    t = torch.clamp(t, wc-ww/2, wc+ww/2) # window\n",
    "    t = (t - (wc-ww/2)) / ww # normalize\n",
    "    \n",
    "    return Image(t)\n",
    "\n",
    "\n",
    "class DicomList(ImageList):\n",
    "    \"\"\"ImageList whose items are opened with `open_dicom` instead of the default image loader.\"\"\"\n",
    "\n",
    "    def open(self, fn):\n",
    "        return open_dicom(fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_stats = ([0.45, 0.45, 0.45], [0.225, 0.225, 0.225])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OverSamplingCallback(LearnerCallback):\n",
    "    \"\"\"Callback that installs a weighted, with-replacement batch sampler on the training loader.\n",
    "\n",
    "    Each training sample is weighted by the inverse frequency of its label, and an epoch draws\n",
    "    `data.c * max(label count)` samples.\n",
    "    NOTE(review): `from fastai.callbacks import *` may already provide a class of this name; this\n",
    "    definition shadows it -- confirm that is intended.\n",
    "    \"\"\"\n",
    "    def __init__(self,learn:Learner):\n",
    "        super().__init__(learn)\n",
    "        train_ds = self.learn.data.train_dl.dataset\n",
    "        self.labels = train_ds.y.items\n",
    "        _, counts = np.unique(self.labels,return_counts=True)\n",
    "        # per-sample weight = 1 / (number of samples sharing that label)\n",
    "        self.weights = torch.DoubleTensor((1/counts)[self.labels])\n",
    "        self.label_counts = np.bincount([train_ds.y[i].data for i in range(len(train_ds))])\n",
    "        self.total_len_oversample = int(self.learn.data.c*np.max(self.label_counts))\n",
    "        \n",
    "    def on_train_begin(self, **kwargs):\n",
    "        # Must be `self.weights`: there is no module-level `weights`, so a bare name raises NameError.\n",
    "        sampler = torch.utils.data.WeightedRandomSampler(self.weights, self.total_len_oversample)\n",
    "        self.learn.data.train_dl.dl.batch_sampler = torch.utils.data.BatchSampler(\n",
    "            sampler, self.learn.data.train_dl.batch_size, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From: https://forums.fast.ai/t/is-there-any-built-in-method-to-oversample-minority-classes/46838/5\n",
    "class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):\n",
    "    def __init__(self, dataset, indices=None, num_samples=None):\n",
    "                \n",
    "        # if indices is not provided, \n",
    "        # all elements in the dataset will be considered\n",
    "        self.indices = list(range(len(dataset))) if indices is None else indices\n",
    "            \n",
    "        # if num_samples is not provided, \n",
    "        # draw `len(indices)` samples in each iteration\n",
    "        self.num_samples = len(self.indices) if num_samples is None else num_samples\n",
    "            \n",
    "        # distribution of classes in the dataset \n",
    "        label_to_count = defaultdict(int)\n",
    "        for idx in self.indices:\n",
    "            label = self._get_label(dataset, idx)\n",
    "            label_to_count[label] += 1\n",
    "                \n",
    "        # weight for each sample\n",
    "        weights = [1.0 / label_to_count[self._get_label(dataset, idx)] for idx in self.indices]\n",
    "        self.weights = torch.DoubleTensor(weights)\n",
    "\n",
    "    def _get_label(self, dataset, idx):\n",
    "        return 'any1' in dataset.y[idx].obj\n",
    "                \n",
    "    def __iter__(self):\n",
    "        return (self.indices[i] for i in torch.multinomial(\n",
    "            self.weights, self.num_samples, replacement=True))\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.num_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folds = np.array_split(list({v['fold'] for v in study_to_data.values()}), n_folds)\n",
    "folds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DICOM metadata table (presumably one row per slice -- TODO confirm); only the SeriesInstanceUID and\n",
    "# SOPInstanceUID columns are used here.\n",
    "df = pd.read_csv(f\"data/{stage}_train_dicom_diags_norm.csv\")\n",
    "# split train/val\n",
    "np.random.seed(666)\n",
    "# studies whose fold id belongs to the selected validation group `folds[fold]`\n",
    "studies_in_fold = [k for k,v in study_to_data.items() if np.isin(v['fold'],folds[fold])]\n",
    "# file names (SOPInstanceUID + '.dcm') of every row whose SeriesInstanceUID is one of those studies\n",
    "study_id_val_set = set(df[ df['SeriesInstanceUID'].isin(studies_in_fold)]['SOPInstanceUID'] + '.dcm')\n",
    "\n",
    "def is_val(p_or_fn):\n",
    "    \"\"\"Return True when the file name of `p_or_fn` (a path or a string) is in the validation set.\"\"\"\n",
    "    p = Path(p_or_fn)\n",
    "    return p.name in study_id_val_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augmentations: random horizontal flip plus a rotation by any angle for training, none for validation.\n",
    "tfms = (\n",
    "    [ flip_lr(p=0.5), rotate(degrees=(-180,180), p=1.) ], # train\n",
    "    [] # val\n",
    ")\n",
    "# Items come from the train image folder, are split with `is_val`, labelled with `get_labels`, and the\n",
    "# files named in the sample submission (`test_fns`) are attached as the test set.\n",
    "data = (DicomList.from_folder(f'data/unzip/{stage}_train_images/', extensions=['.dcm'], presort=True)\n",
    "    .split_by_valid_func(is_val)\n",
    "    .label_from_func(get_labels, classes=labels)\n",
    "    .transform(tfms, mode='nearest')\n",
    "    .add_test(f'data/unzip/{stage}_test_images/' + test_fns))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db = data.databunch().normalize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "balsampler = ImbalancedDatasetSampler(db.train_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): ImbalancedDatasetSampler yields single indices rather than lists of indices, and this\n",
    "# attribute is set on the training dataloader after it was constructed -- confirm that the balanced\n",
    "# sampling is actually picked up during training.\n",
    "db.train_dl.batch_sampler = balsampler\n",
    "db.train_dl.batch_sampler\n",
    "db"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db.show_batch() # AttributeError: 'list' object has no attribute 'pixel' <-- because using torch datasets or whatever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# yuval reina's loss\n",
    "yuval_weights = FloatTensor([ 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1 ]).cuda() \n",
    "\n",
    "def yuval_loss(y_pred,y_true):\n",
    "    return F.binary_cross_entropy_with_logits(y_pred,\n",
    "                                  y_true,\n",
    "                                  yuval_weights.repeat(y_pred.shape[0],1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights for the six center-slice classes; 'any' counts double.\n",
    "real_lb_weights = FloatTensor([ 2, 1, 1, 1, 1, 1 ])\n",
    "\n",
    "def real_lb_loss(pred:Tensor, targ:Tensor)->Rank0Tensor:\n",
    "    \"\"\"Weighted BCE-with-logits restricted to the center-slice outputs (columns 6..11 of 18).\"\"\"\n",
    "    flat_pred, flat_targ = flatten_check(pred,targ)\n",
    "    center = slice(6, 12)\n",
    "    center_pred = flat_pred.view(-1,18)[:,center]\n",
    "    center_targ = flat_targ.view(-1,18)[:,center]\n",
    "    weights = real_lb_weights.to(device=flat_pred.device)\n",
    "    return F.binary_cross_entropy_with_logits(center_pred, center_targ, weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w_loss = 0.1\n",
    "weighted_lb_weights = FloatTensor([ 2.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss, \n",
    "                                    2.,        1.,        1.,        1.,        1.,        1., \n",
    "                                    2.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss, 1.*w_loss, ])\n",
    "\n",
    "def weighted_loss(pred:Tensor,targ:Tensor)->Rank0Tensor:\n",
    "    return F.binary_cross_entropy_with_logits(pred,targ,weighted_lb_weights.to(device=pred.device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights for six classes ('any' counts double). Created on the CPU; `lb_loss` moves them to the\n",
    "# prediction's device, so no GPU is required just to define this cell.\n",
    "lb_weights = FloatTensor([ 2, 1, 1, 1, 1, 1 ])\n",
    "\n",
    "def lb_loss(pred:Tensor, targ:Tensor)->Rank0Tensor:\n",
    "    \"\"\"Weighted BCE-with-logits for outputs laid out as rows of six classes.\"\"\"\n",
    "    pred,targ = flatten_check(pred,targ)\n",
    "    tp = pred.view(-1,6)\n",
    "    tt = targ.view(-1,6)\n",
    "    return torch.nn.functional.binary_cross_entropy_with_logits(tp, tt, lb_weights.to(device=pred.device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(\n",
    "    db, getattr(models, arch), loss_func=weighted_loss, metrics=[real_lb_loss],\n",
    "    pretrained=True, lin_ftrs=[],ps=0.,\n",
    "    model_dir = Path('./models/').resolve())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally resume from a checkpoint. Loading stays best-effort (a failure is reported and training\n",
    "# continues), but the common `model_fn is None` case no longer relies on a swallowed exception.\n",
    "if model_fn is None:\n",
    "    print('No checkpoint requested (model_fn is None); keeping the weights cnn_learner created')\n",
    "else:\n",
    "    try:\n",
    "        learn.load(model_fn,strict=False)\n",
    "        print(f\"Loaded {model_fn}\")\n",
    "    except Exception as e:\n",
    "        print(f'Could not load {model_fn}: {e}')"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = learn.to_parallel().to_fp16()\n",
    "\n",
    "# Empirical per-architecture multipliers used to scale the batch size with the reported GPU memory.\n",
    "arch_to_batch_factor = {\n",
    "    'resnet18' : 4.3,\n",
    "    'resnet34' : 3.4,\n",
    "    'resnet50' : 1.5,\n",
    "    'resnet101' : 1.0,\n",
    "    'resnext50_32x4d': 1.2,\n",
    "    'densenet121': 1.1,\n",
    "    'squeezenet1_0': 2.6,\n",
    "    'vgg16': 1.5\n",
    "}\n",
    "\n",
    "uses_fp16 = any(isinstance(cb, MixedPrecision) for cb in learn.callbacks)\n",
    "uses_parallel = any(isinstance(cb, fastai.distributed.ParallelTrainer) for cb in learn.callbacks)\n",
    "\n",
    "# scale with GPU memory and inversely with the image area (relative to 512x512)\n",
    "bs = arch_to_batch_factor[arch] * 512*512 * gpu_mem_get()[0] * 1e-3 / (SZ*SZ)\n",
    "if uses_fp16:\n",
    "    bs *= 2 # 2x if fp16\n",
    "if uses_parallel:\n",
    "    n_gpus = torch.cuda.device_count()\n",
    "    bs *= n_gpus\n",
    "    bs = (bs // n_gpus) * n_gpus # round down to a multiple of the GPU count\n",
    "bs = int(bs)\n",
    "learn.data.batch_size = bs\n",
    "bs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.unfreeze()\n",
    "learn.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    if lr is None:\n",
    "        !rm {learn.model_dir}/tmp.pth\n",
    "        learn.lr_find()\n",
    "        learn.recorder.plot(suggestion=True)\n",
    "except Exception as e:\n",
    "    print(e)\n",
    "# set lr "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(n_epochs, lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "v = learn.validate()\n",
    "cv = float(v[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_fn = f'{arch}_sz{SZ}_cv{cv:0.4f}_{learn.loss_func.__name__}_fold{fold+1}_of_{n_folds}'\n",
    "learn.save(model_fn)\n",
    "model_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-time augmentation: reuse the training transforms on the test set and predict n_tta times.\n",
    "learn.data.test_ds.tfms = learn.data.train_ds.tfms\n",
    "test_preds = []\n",
    "# With a custom loss function fastai does not pick the output activation itself (see loss_func_name2activ),\n",
    "# so sigmoid is passed explicitly:\n",
    "# https://forums.fast.ai/t/shouldnt-we-able-to-pass-an-activ-function-to-learner-get-preds/50492\n",
    "for _ in progress_bar(range(n_tta)):\n",
    "    preds, _ = learn.get_preds(DatasetType.Test, activ=torch.sigmoid)\n",
    "    test_preds.append(preds)\n",
    "# the new leading dimension indexes the TTA pass\n",
    "tta_test_preds = torch.stack(test_preds, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-time augmentation on the validation set: reuse the training transforms and predict n_tta times.\n",
    "learn.data.valid_ds.tfms = learn.data.train_ds.tfms\n",
    "valid_preds = []\n",
    "# With a custom loss function fastai does not pick the output activation itself (see loss_func_name2activ),\n",
    "# so sigmoid is passed explicitly:\n",
    "# https://forums.fast.ai/t/shouldnt-we-able-to-pass-an-activ-function-to-learner-get-preds/50492\n",
    "for _ in progress_bar(range(n_tta)):\n",
    "    preds, _ = learn.get_preds(DatasetType.Valid, activ=torch.sigmoid)\n",
    "    valid_preds.append(preds)\n",
    "# the new leading dimension indexes the TTA pass\n",
    "tta_valid_preds = torch.stack(valid_preds, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PREDS_DIR = 'data/predictions'\n",
    "!mkdir -p {PREDS_DIR}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(tta_test_preds,  f'{PREDS_DIR}/{model_fn}_test.pth')\n",
    "torch.save(tta_valid_preds, f'{PREDS_DIR}/{model_fn}_valid.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame({'fn': [p.stem for p in learn.data.valid_ds.x.items]}).to_csv(\n",
    "    f'{PREDS_DIR}/{model_fn}_valid_fns.csv',index=False)\n",
    "pd.DataFrame({'fn': [Path(p).stem for p in learn.data.test_ds.x.items]}).to_csv(\n",
    "    f'{PREDS_DIR}/{model_fn}_test_fns.csv',index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Submit to Kaggle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tta_test_preds_mean = tta_test_preds.mean(dim=0)\n",
    "tta_test_preds_geomean = torch.expm1(torch.log1p(tta_test_preds).mean(dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ids = []\n",
    "# NOTE: this rebinds the name `labels`, which earlier held the class list passed to `label_from_func`;\n",
    "# re-run that cell before rebuilding `data`.\n",
    "labels = []\n",
    "\n",
    "# zip() would silently truncate on a mismatch and pair predictions with the wrong files\n",
    "assert len(test_fns) == len(tta_test_preds_mean), 'number of test files and predictions differ'\n",
    "\n",
    "# Only the center-slice classes (suffix '1') are submitted. Find their columns once, and drop exactly the\n",
    "# one-character suffix instead of stripping every leading/trailing '1'.\n",
    "center_classes = [(i, label[:-1]) for i, label in enumerate(db.train_ds.classes) if label.endswith('1')]\n",
    "\n",
    "for fn, pred in zip(test_fns, tta_test_preds_mean):\n",
    "    image_id = fn.split('.')[0]\n",
    "    for i, label_name in center_classes:\n",
    "        ids.append(f\"{image_id}_{label_name}\")\n",
    "        predicted_probability = '{0:1.10f}'.format(pred[i].item())\n",
    "        labels.append(predicted_probability)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mkdir -p data/submissions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_name = experiment_name + \"_\" + model_fn\n",
    "sub_path = f'data/submissions/{sub_name}.csv.zip'\n",
    "sub_name, sub_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame({'ID': ids, 'Label': labels}).to_csv(sub_path, compression='zip', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls data/submissions/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!echo {sub_path}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!kaggle competitions submit -c rsna-intracranial-hemorrhage-detection -f {sub_path} -m {model_fn}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}