[53d15f]: / notebooks / Code.ipynb

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "#conda activate torch-xla-nightly\n",
    "#export XRT_TPU_CONFIG=\"tpu_worker;0;$10.0.101.2:8470\"\n",
    "#git init\n",
    "#git remote add origin https://github.com/nosound2/RSNA-Hemorrhage\n",
    "#git config remote.origin.push HEAD\n",
    "#git config credential.helper store\n",
    "#git pull origin master\n",
    "#git config --global branch.master.remote origin\n",
    "#git config --global branch.master.merge refs/heads/master\n",
    "#gcloud config set compute/zone europe-west4-a\n",
    "#gcloud config set compute/zone us-central1-c\n",
    "#gcloud auth login\n",
    "#gcloud config set project endless-empire-239015\n",
    "#pip install kaggle --user\n",
    "#PATH=$PATH:~/.local/bin\n",
    "#mkdir .kaggle\n",
    "#gsutil cp gs://recursion-double-strand/kaggle-keys/kaggle.json ~/.kaggle\n",
    "#chmod 600 /home/zahar_chikishev/.kaggle/kaggle.json\n",
    "#kaggle competitions download rsna-intracranial-hemorrhage-detection -f stage_2_train.csv\n",
    "#sudo apt install unzip\n",
    "#unzip stage_1_train.csv.zip\n",
    "#kaggle kernels output xhlulu/rsna-generate-metadata-csvs -p .\n",
    "#kaggle kernels output zaharch/dicom-test-metadata-to-csv -p .\n",
    "#kaggle kernels output zaharch/dicom-metadata-to-csv -p .\n",
    "#gsutil -m cp gs://rsna-hemorrhage/yuvals/* .\n",
    "\n",
    "#export XRT_TPU_CONFIG=\"tpu_worker;0;10.0.101.2:8470\"; conda activate torch-xla-nightly; jupyter notebook\n",
    "\n",
    "# 35.188.114.109\n",
    "\n",
    "# lsblk\n",
    "# sudo mkfs.ext4 -m 0 -F -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb\n",
    "# sudo mkdir /mnt/edisk\n",
    "# sudo mount -o discard,defaults /dev/sdb /mnt/edisk\n",
    "# sudo chmod a+w /mnt/edisk\n",
    "# df -H\n",
    "# echo UUID=`sudo blkid -s UUID -o value /dev/sdb` /mnt/edisk ext4 discard,defaults,nofail 0 2 | sudo tee -a /etc/fstab\n",
    "# cat /etc/fstab\n",
    "\n",
    "#jupyter notebook --generate-config\n",
    "#vi ~/.jupyter/jupyter_notebook_config.py\n",
    "\n",
    "#c = get_config()\n",
    "#c.NotebookApp.ip = '*'\n",
    "#c.NotebookApp.open_browser = False\n",
    "#c.NotebookApp.port = 5000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "from pathlib import Path\n",
    "from PIL import ImageDraw, ImageFont, Image\n",
    "from matplotlib import patches, patheffects\n",
    "import time\n",
    "from random import randint\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle\n",
    "\n",
    "from sklearn.model_selection import KFold, StratifiedKFold, GroupKFold\n",
    "from sklearn.preprocessing import LabelEncoder, OneHotEncoder, LabelBinarizer\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.metrics import mean_squared_error,log_loss,roc_auc_score\n",
    "from scipy.stats import ks_2samp\n",
    "\n",
    "import pdb\n",
    "\n",
    "import scipy as sp\n",
    "from tqdm import tqdm, tqdm_notebook\n",
    "\n",
    "import os\n",
    "import glob\n",
    "\n",
    "import torch\n",
    "\n",
    "#CLOUD = not torch.cuda.is_available()\n",
    "CLOUD = (torch.cuda.get_device_name(0) in ['Tesla P4', 'Tesla K80', 'Tesla P100-PCIE-16GB'])\n",
    "CLOUD_SINGLE = True\n",
    "TPU = False\n",
    "\n",
    "if not CLOUD:\n",
    "    torch.cuda.current_device()\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as D\n",
    "import torch.nn.functional as F\n",
    "import torch.utils as U\n",
    "\n",
    "import torchvision\n",
    "from torchvision import transforms as T\n",
    "from torchvision import models as M\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "if CLOUD:\n",
    "    PATH = Path('/home/zahar_chikishev/Hemorrhage')\n",
    "    PATH_WORK = Path('/mnt/edisk/running')\n",
    "    PATH_DISK = Path('/mnt/edisk/running')\n",
    "else:\n",
    "    PATH = Path('C:/StudioProjects/Hemorrhage')\n",
    "    PATH_WORK = Path('C:/StudioProjects/Hemorrhage/running')\n",
    "    PATH_DISK = PATH_WORK\n",
    "\n",
    "from collections import defaultdict, Counter\n",
    "import random\n",
    "import seaborn as sn\n",
    "\n",
    "pd.set_option(\"display.max_columns\", 100)\n",
    "\n",
    "all_ich = ['any','epidural','intraparenchymal','intraventricular','subarachnoid','subdural']\n",
    "class_weights = 6.0*np.array([2,1,1,1,1,1])/7.0\n",
    "\n",
    "if CLOUD and TPU:\n",
    "    import torch_xla\n",
    "    import torch_xla.distributed.data_parallel as dp\n",
    "    import torch_xla.utils as xu\n",
    "    import torch_xla.core.xla_model as xm\n",
    "\n",
    "from typing import Collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_black = '006d4432e'\n",
    "all_black = '00bd6c59c'\n",
    "\n",
    "if CLOUD:\n",
    "    if TPU:\n",
    "        device = xm.xla_device()\n",
    "    else:\n",
    "        device = 'cuda'\n",
    "    #device = 'cpu'\n",
    "    MAX_DEVICES = 1 if CLOUD_SINGLE else 8\n",
    "    NUM_WORKERS = 16\n",
    "    bs = 32\n",
    "else:\n",
    "    device = 'cuda'\n",
    "    #device = 'cpu'\n",
    "    MAX_DEVICES = 1\n",
    "    NUM_WORKERS = 0\n",
    "    bs = 16\n",
    "\n",
    "#if CLOUD and (not CLOUD_SINGLE):\n",
    "#    devices = xm.get_xla_supported_devices(max_devices=MAX_DEVICES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 2351\n",
    "\n",
    "def setSeeds(seed):\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "\n",
    "setSeeds(SEED)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getDSName(dataset):\n",
    "    if dataset == 6:\n",
    "        return 'Densenet201_F3'\n",
    "    elif dataset == 7:\n",
    "        return 'Densenet161_F3'\n",
    "    elif dataset == 8:\n",
    "        return 'Densenet169_F3'\n",
    "    elif dataset == 9:\n",
    "        return 'se_resnext101_32x4d_F3'\n",
    "    elif dataset == 10:\n",
    "        return 'se_resnet101_F3'\n",
    "    elif dataset == 11:\n",
    "        return 'se_resnext101_32x4d_F5'\n",
    "    elif dataset == 12:\n",
    "        return 'se_resnet101_F5'\n",
    "    elif dataset == 13:\n",
    "        return 'se_resnet101_focal_F5'\n",
    "    elif dataset == 14:\n",
    "        return 'se_resnet101_F5'\n",
    "    else: assert False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getDSParams(dataset):\n",
    "    if dataset == 6:\n",
    "        return 'Densenet201','_3','_v5',240,'','','classifier',''\n",
    "    elif dataset == 7:\n",
    "        return 'Densenet161','_3','_v4',552,'2','','classifier',''\n",
    "    elif dataset == 8:\n",
    "        return 'Densenet169','_3','_v3',208,'2','','classifier',''\n",
    "    elif dataset == 9:\n",
    "        return 'se_resnext101_32x4d','','',256,'','_tta','classifier',''\n",
    "    elif dataset == 10:\n",
    "        return 'se_resnet101','','',256,'','','classifier',''\n",
    "    elif dataset == 11:\n",
    "        return 'se_resnext101_32x4d_5','','',256,'','','new',''\n",
    "    elif dataset == 12:\n",
    "        return 'se_resnet101_5','','',256,'','','new',''\n",
    "    elif dataset == 13:\n",
    "        return 'se_resnet101_5f','','',256,'','','new','focal_'\n",
    "    elif dataset == 14:\n",
    "        return 'se_resnet101_5n','','',256,'','','new','stage2_'\n",
    "    else: assert False\n",
    "\n",
    "def getNFolds(dataset):\n",
    "    if dataset <= 10:\n",
    "        return 3\n",
    "    elif dataset <= 14:\n",
    "        return 5\n",
    "    else: assert False\n",
    "\n",
    "my_datasets3 = [7,9]\n",
    "my_datasets5 = [11,12,13]\n",
    "my_len = len(my_datasets3) + len(my_datasets5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "cols_le,cols_float,cols_bool = pickle.load(open(PATH_WORK/'covs','rb'))\n",
    "meta_cols = cols_bool + cols_float\n",
    "\n",
    "#meta_cols = [c for c in meta_cols if c not in ['SeriesPP']]\n",
    "#cols_le = cols_le[:-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pre-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadMetadata(rerun=False):\n",
    "    if rerun:\n",
    "        train_dedup = pd.read_csv(PATH_WORK/'yuval'/'train_dedup.csv')\n",
    "        pids, folding = pickle.load(open(PATH_WORK/'yuval'/'PID_splits.pkl','rb'))\n",
    "\n",
    "        assert len(pids) == 17079\n",
    "        assert len(np.unique(pids)) == 17079\n",
    "\n",
    "        for fol in folding:\n",
    "            assert len(fol[0]) + len(fol[1]) == 17079\n",
    "\n",
    "        assert len(folding[0][1]) + len(folding[1][1]) + len(folding[2][1]) == 17079\n",
    "\n",
    "        assert len(train_dedup.PID.unique()) == 17079\n",
    "\n",
    "        train_dedup['fold'] = np.nan\n",
    "\n",
    "        for fold in range(3):\n",
    "            train_dedup.loc[train_dedup.PID.isin(pids[folding[fold][1]]),'fold'] = fold\n",
    "\n",
    "        assert train_dedup.fold.isnull().sum() == 0\n",
    "        \n",
    "        oof5 = pickle.load(open(PATH_DISK/'yuval'/'OOF_validation_image_ids_5.pkl','rb'))\n",
    "        for fold in range(5):\n",
    "            train_dedup.loc[train_dedup.PatientID.isin(oof5[fold]),'fold5'] = fold\n",
    "        assert train_dedup.fold5.isnull().sum() == 0\n",
    "        \n",
    "        train_md = pd.read_csv(PATH_WORK/'train_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n",
    "        train_md['img_id'] = train_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n",
    "\n",
    "        ids_df = train_dedup[['fold','PatientID','fold5']]\n",
    "        ids_df.columns = ['fold','img_id','fold5']\n",
    "\n",
    "        train_md = ids_df.join(train_md.set_index('img_id'), on = 'img_id')\n",
    "\n",
    "        pickle.dump(train_md, open(PATH_WORK/'train.ids.df','wb'))\n",
    "\n",
    "        #test_md = pickle.load(open(PATH_WORK/'test.post.processed.1','rb'))\n",
    "        test_md = pd.read_csv(PATH_WORK/'test_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n",
    "        test_md['img_id'] = test_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n",
    "\n",
    "        filename = PATH_WORK/'yuval'/'test_indexes.pkl'\n",
    "        test_ids = pickle.load(open(filename,'rb'))\n",
    "\n",
    "        test_ids_df = pd.DataFrame(test_ids, columns = ['img_id'])\n",
    "        test_md = test_ids_df.join(test_md.set_index('img_id'), on = 'img_id')\n",
    "        \n",
    "        assert len(test_md.SeriesInstanceUID.unique()) == 2214\n",
    "        \n",
    "        #train_md = pd.concat([train_md, test_md], axis = 0, sort=False).reset_index(drop=True)\n",
    "        \n",
    "        #pickle.dump(train_md, open(PATH_WORK/'train.ids.df','wb'))\n",
    "        pickle.dump(test_md, open(PATH_WORK/'test.ids.df','wb'))\n",
    "    else:\n",
    "        train_md = pickle.load(open(PATH_WORK/'train.ids.df','rb'))\n",
    "        test_md = pickle.load(open(PATH_WORK/'test.ids.df','rb'))\n",
    "    \n",
    "    return train_md, test_md"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadMetadata2(rerun=False):\n",
    "    if rerun:\n",
    "        train_dedup2 = pd.read_csv(PATH_WORK/'yuval'/'train_stage2.csv')\n",
    "\n",
    "        assert train_dedup2.PatientID.nunique() == 752797\n",
    "        assert train_dedup2.PID.nunique() == 18938\n",
    "\n",
    "        pids, folding = pickle.load(open(PATH_WORK/'yuval'/'PID_splits.pkl','rb'))\n",
    "        pids3, folding3 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_3_stage_2.pkl','rb'))\n",
    "        pids5, folding5 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_5_stage_2.pkl','rb'))\n",
    "\n",
    "        assert np.array([len(folding3[i][1]) for i in range(3)]).sum() == 18938\n",
    "        assert np.array([len(folding3[i][0]) for i in range(3)]).sum() == 2*18938\n",
    "        \n",
    "        for i in range(3):\n",
    "            assert set(pids[folding[i][1]]).issubset(set(pids3[folding3[i][1]]))\n",
    "\n",
    "        for i in range(3):\n",
    "            train_dedup2.loc[train_dedup2.PID.isin(pids3[folding3[i][1]]),'fold'] = i\n",
    "        assert train_dedup2.fold.isnull().sum() == 0\n",
    "\n",
    "        for i in range(5):\n",
    "            train_dedup2.loc[train_dedup2.PID.isin(pids5[folding5[i][1]]),'fold5'] = i\n",
    "        assert train_dedup2.fold5.isnull().sum() == 0\n",
    "        \n",
    "        oof5 = pickle.load(open(PATH_DISK/'yuval'/'OOF_validation_image_ids_5.pkl','rb'))\n",
    "        for i in range(5):\n",
    "            assert set(oof5[i]).issubset(set(train_dedup2.loc[train_dedup2.fold5==i,'PatientID']))\n",
    "        \n",
    "        train_csv = pd.read_csv(PATH/'stage_2_train.csv')\n",
    "        train_csv = train_csv.loc[~train_csv.ID.duplicated()].sort_values('ID').reset_index(drop=True)        \n",
    "        all_sop_ids = train_csv.ID.str.split('_').apply(lambda x: x[0]+'_'+x[1]).unique()\n",
    "        train_df = pd.DataFrame(train_csv.Label.values.reshape((-1,6)), columns = all_ich)\n",
    "        train_df['sop_id'] = all_sop_ids\n",
    "        \n",
    "        old_test = pd.read_csv(PATH_WORK/'test_md.csv')\n",
    "        old_test = old_test.drop(all_ich,axis=1)\n",
    "        old_test = old_test.join(train_df.set_index('sop_id'), on = 'SOPInstanceUID')\n",
    "        \n",
    "        train_md2 = pd.concat([pd.read_csv(PATH_WORK/'train_md.csv'), old_test],\\\n",
    "                              axis=0,sort=False)\n",
    "        train_md2['img_id'] = train_md2.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n",
    "        \n",
    "        ids_df = train_dedup2[['fold','PatientID','fold5']]\n",
    "        ids_df.columns = ['fold','img_id','fold5']\n",
    "\n",
    "        train_md2 = ids_df.join(train_md2.set_index('img_id'), on = 'img_id')\n",
    "        assert train_md2.test.sum() == 78545\n",
    "        \n",
    "        pickle.dump(train_md2, open(PATH_WORK/'train2.ids.df','wb'))\n",
    "        \n",
    "        test_md = pd.read_csv(PATH_WORK/'test2_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n",
    "        test_md['img_id'] = test_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n",
    "\n",
    "        filename = PATH_WORK/'yuval/test_indexes_stage2.pkl'\n",
    "        test_ids = pickle.load(open(filename,'rb'))\n",
    "\n",
    "        test_ids_df = pd.DataFrame(test_ids, columns = ['img_id'])\n",
    "        test_md = test_ids_df.join(test_md.set_index('img_id'), on = 'img_id')\n",
    "\n",
    "        assert len(test_md.SeriesInstanceUID.unique()) == 3518\n",
    "\n",
    "        pickle.dump(test_md, open(PATH_WORK/'test2.ids.df','wb'))\n",
    "\n",
    "    else:\n",
    "        train_md2 = pickle.load(open(PATH_WORK/'train2.ids.df','rb'))\n",
    "        test_md = pickle.load(open(PATH_WORK/'test2.ids.df','rb'))\n",
    "    \n",
    "    return train_md2, test_md"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadMetadata3(rerun=False):\n",
    "    if rerun:\n",
    "        train_md = pickle.load(open(PATH_WORK/'train.ids.df','rb'))\n",
    "        test_md = pickle.load(open(PATH_WORK/'test.ids.df','rb'))\n",
    "        train_md = pd.concat([train_md, test_md], axis=0,sort=False).reset_index(drop=True)\n",
    "        \n",
    "        train_md.fold = np.nan\n",
    "        train_md['PID'] = train_md.PatientID.str.split('_').apply(lambda x: x[1])\n",
    "        \n",
    "        pids, folding = pickle.load(open(PATH_WORK/'yuval'/'PID_splits.pkl','rb'))\n",
    "        pids3, folding3 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_3_stage_2.pkl','rb'))\n",
    "        pids5, folding5 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_5_stage_2.pkl','rb'))\n",
    "\n",
    "        for i in range(3):\n",
    "            train_md.loc[train_md.PID.isin(pids3[folding3[i][1]]),'fold'] = i\n",
    "        assert train_md.fold.isnull().sum() == 0\n",
    "\n",
    "        for i in range(5):\n",
    "            train_md.loc[train_md.PID.isin(pids5[folding5[i][1]]),'fold5'] = i\n",
    "        assert train_md.fold5.isnull().sum() == 0\n",
    "        \n",
    "        train_md.weights = train_md.weights.fillna(1)\n",
    "        \n",
    "        pickle.dump(train_md, open(PATH_WORK/'train3.ids.df','wb'))\n",
    "        \n",
    "        test_md = pd.read_csv(PATH_WORK/'test2_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n",
    "        test_md['img_id'] = test_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n",
    "\n",
    "        filename = PATH_WORK/'yuval/test_indexes_stage2.pkl'\n",
    "        test_ids = pickle.load(open(filename,'rb'))\n",
    "\n",
    "        test_ids_df = pd.DataFrame(test_ids, columns = ['img_id'])\n",
    "        test_md = test_ids_df.join(test_md.set_index('img_id'), on = 'img_id')\n",
    "\n",
    "        assert len(test_md.SeriesInstanceUID.unique()) == 3518\n",
    "\n",
    "        pickle.dump(test_md, open(PATH_WORK/'test3.ids.df','wb'))\n",
    "    else:\n",
    "        train_md = pickle.load(open(PATH_WORK/'train3.ids.df','rb'))\n",
    "        test_md = pickle.load(open(PATH_WORK/'test3.ids.df','rb'))\n",
    "    \n",
    "    return train_md, test_md"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "#train_md, test_md = loadMetadata(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "#train_md = loadMetadata2(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocessedData(dataset, folds = range(3), fold_col=None, do_test=True, do_test2=True, do_train=True):\n",
    "    assert dataset >= 6\n",
    "    \n",
    "    dataset_name, filename_add, filename_add2, feat_sz, ds_num, test_fix, dsft, focal = getDSParams(dataset)\n",
    "    \n",
    "    #if 'train_md' not in globals() or 'test_md' not in globals():\n",
    "    #    train_md, test_md = loadMetadata(False)\n",
    "    \n",
    "    PATH_DS = PATH_DISK/'features/{}{}'.format(dataset_name,filename_add2)\n",
    "    if not PATH_DS.is_dir():\n",
    "        PATH_DS.mkdir()\n",
    "    if not (PATH_DS/'train').is_dir():\n",
    "        (PATH_DS/'train').mkdir()\n",
    "    if not (PATH_DS/'test').is_dir():\n",
    "        (PATH_DS/'test').mkdir()\n",
    "    if not (PATH_DS/'test2').is_dir():\n",
    "        (PATH_DS/'test2').mkdir()\n",
    "    \n",
    "    if fold_col is None:\n",
    "        fold_col = 'fold'\n",
    "        if len(folds) == 5:\n",
    "            fold_col = 'fold5'\n",
    "    \n",
    "    for fold in folds:\n",
    "        \n",
    "        if do_train:\n",
    "            filename = PATH_DISK/'yuval'/\\\n",
    "                'model_{}{}_version_{}_splits_{}type_features_train_tta{}_split_{}.pkl'\\\n",
    "                .format(dataset_name.replace('_5n','').replace('_5f','').replace('_5',''), \\\n",
    "                        filename_add, dsft, focal, ds_num, fold)\n",
    "            feats = pickle.load(open(filename,'rb'))\n",
    "\n",
    "            print('dataset',dataset,'fold',fold,'feats size', feats.shape)\n",
    "            if dataset <= 13:\n",
    "                train_md_loc = train_md.loc[~train_md.test]\n",
    "            else:\n",
    "                train_md_loc = train_md\n",
    "            \n",
    "            assert len(feats) == 4*len(train_md_loc)\n",
    "            assert feats.shape[1] == feat_sz\n",
    "            means = feats.mean(0,keepdim=True)\n",
    "            stds = feats.std(0,keepdim=True)\n",
    "\n",
    "            feats = feats - means\n",
    "            feats = torch.where(stds > 0, feats/stds, feats)\n",
    "\n",
    "            for i in range(4):\n",
    "                feats_sub1 = feats[torch.BoolTensor(np.arange(len(feats))%4 == i)]\n",
    "                feats_sub2 = feats_sub1[torch.BoolTensor(train_md_loc[fold_col] != fold)]\n",
    "                pickle.dump(feats_sub2, open(PATH_DS/'train/train.f{}.a{}'.format(fold,i),'wb'))\n",
    "\n",
    "                feats_sub2 = feats_sub1[torch.BoolTensor(train_md_loc[fold_col] == fold)]\n",
    "                pickle.dump(feats_sub2, open(PATH_DS/'train/valid.f{}.a{}'.format(fold,i),'wb'))\n",
    "\n",
    "                if i==0:\n",
    "                    black_feats = feats_sub1[torch.BoolTensor(train_md_loc.img_id == all_black)].squeeze()\n",
    "                    pickle.dump(black_feats, open(PATH_DS/'train/black.f{}'.format(fold),'wb'))\n",
    "        \n",
    "        if do_test:\n",
    "            filename = PATH_DISK/'yuval'/\\\n",
    "                'model_{}{}_version_{}_splits_{}type_features_test{}{}_split_{}.pkl'\\\n",
    "                .format(dataset_name.replace('_5n','').replace('_5f','').replace('_5',''),\n",
    "                        filename_add,dsft,focal,ds_num,test_fix,fold)\n",
    "            feats = pickle.load(open(filename,'rb'))\n",
    "\n",
    "            assert len(feats) == 8*len(test_md)\n",
    "            assert feats.shape[1] == feat_sz\n",
    "\n",
    "            feats = feats - means\n",
    "            feats = torch.where(stds > 0, feats/stds, feats)\n",
    "\n",
    "            for i in range(8):\n",
    "                feats_sub = feats[torch.BoolTensor(np.arange(len(feats))%8 == i)]\n",
    "                pickle.dump(feats_sub, open(PATH_DS/'test/test.f{}.a{}'.format(fold,i),'wb'))\n",
    "                assert len(feats_sub) == len(test_md)\n",
    "        \n",
    "        if do_test2:\n",
    "            filename = PATH_DISK/'yuval'/\\\n",
    "                'model_{}{}_version_{}_splits_{}type_features_test_stage2_split_{}.pkl'\\\n",
    "                .format(dataset_name.replace('_5n','').replace('_5f','').replace('_5',''),\n",
    "                        filename_add,dsft,focal,fold)\n",
    "            feats = pickle.load(open(filename,'rb'))\n",
    "\n",
    "            assert len(feats) == 8*len(test_md)\n",
    "            assert feats.shape[1] == feat_sz\n",
    "\n",
    "            feats = feats - means\n",
    "            feats = torch.where(stds > 0, feats/stds, feats)\n",
    "\n",
    "            for i in range(8):\n",
    "                feats_sub = feats[torch.BoolTensor(np.arange(len(feats))%8 == i)]\n",
    "                pickle.dump(feats_sub, open(PATH_DS/'test2/test.f{}.a{}'.format(fold,i),'wb'))\n",
    "                assert len(feats_sub) == len(test_md)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "if False:\n",
    "    path = PATH_WORK/'features/densenet161_v3/train/ID_992b567eb6'\n",
    "    black_feats = pickle.load(open(path,'rb'))[41]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RSNA_DataSet(D.Dataset):\n",
    "    def __init__(self, metadata, dataset, mode='train', bs=None, fold=0):\n",
    "        \n",
    "        super(RSNA_DataSet, self).__init__()\n",
    "        \n",
    "        dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n",
    "        num_folds = getNFolds(dataset)\n",
    "        \n",
    "        self.dataset_name = dataset_name\n",
    "        self.filename_add2 = filename_add2\n",
    "        \n",
    "        folds_col = 'fold'\n",
    "        if num_folds == 5:\n",
    "            folds_col = 'fold5'\n",
    "        \n",
    "        if (dataset <= 13) and (mode in ['train','valid']) and TRAIN_ON_STAGE_1:\n",
    "            if mode == 'train':\n",
    "                self.test_mask = torch.BoolTensor((metadata.loc[metadata.test, folds_col] != fold).values)\n",
    "            else:\n",
    "                self.test_mask = torch.BoolTensor((metadata.loc[metadata.test, folds_col] == fold).values)\n",
    "        \n",
    "        if mode == 'train':\n",
    "            md = metadata.loc[metadata[folds_col] != fold].copy().reset_index(drop=True)\n",
    "        elif mode == 'valid':\n",
    "            md = metadata.loc[metadata[folds_col] == fold].copy().reset_index(drop=True)\n",
    "        else:\n",
    "            md = metadata.copy().reset_index(drop=True)\n",
    "        \n",
    "        series = np.sort(md.SeriesInstanceUID.unique())\n",
    "        md = md.set_index('SeriesInstanceUID', drop=True)\n",
    "        \n",
    "        samples_add = 0\n",
    "        if (mode != 'train') and not DATA_SMALL:\n",
    "            batch_num = -((-len(series))//(bs*MAX_DEVICES))\n",
    "            samples_add = batch_num*bs*MAX_DEVICES - len(series)\n",
    "            print('adding dummy serieses', samples_add)\n",
    "        \n",
    "        #self.records = df.to_records(index=False)\n",
    "        self.mode = mode\n",
    "        self.real = np.concatenate([np.repeat(True,len(series)),np.repeat(False,samples_add)])\n",
    "        self.series = np.concatenate([series, random.sample(list(series),samples_add)])\n",
    "        self.metadata = md\n",
    "        self.dataset = dataset\n",
    "        self.fold = fold\n",
    "        \n",
    "        print('DataSet', dataset, mode, 'size', len(self.series), 'fold', fold)\n",
    "        \n",
    "        path = PATH_DISK/'features/{}{}/train/black.f{}'.format(self.dataset_name, self.filename_add2, fold)\n",
    "        self.black_feats = pickle.load(open(path,'rb')).squeeze()\n",
    "        \n",
    "        if WEIGHTED and (self.mode == 'train'):\n",
    "            tt = self.metadata['weights'].groupby(self.metadata.index).mean()\n",
    "            self.weights = tt.loc[self.series].values\n",
    "            print(pd.value_counts(self.weights))\n",
    "    \n",
    "    def setFeats(self, anum, epoch=0):\n",
    "        def getAPath(an,mode):\n",
    "            folder = 'test2' if mode == 'test' else 'train'\n",
    "            return PATH_DISK/'features/{}{}/{}/{}.f{}.a{}'\\\n",
    "                .format(self.dataset_name,self.filename_add2,folder,mode,self.fold,an)\n",
    "        \n",
    "        def getAPathFeats(mode, sz, mask=None):\n",
    "            max_a = 8 if mode == 'test' else 4\n",
    "            feats2 = torch.stack([pickle.load(open(getAPath(an,mode),'rb')) for an in range(max_a)])\n",
    "            if mask is not None:\n",
    "                feats2 = feats2[:,mask]\n",
    "            assert feats2.shape[1] == sz\n",
    "            feats = feats2[torch.randint(max_a,(sz,)), torch.arange(sz, dtype=torch.long)].squeeze()\n",
    "            return feats\n",
    "        \n",
    "        if self.dataset == 1: return\n",
    "        print('setFeats, augmentation', anum)\n",
    "        self.anum = anum\n",
    "        sz = len(self.metadata)\n",
    "        \n",
    "        assert anum <= 0\n",
    "        \n",
    "        if anum == -1:\n",
    "            if (self.dataset <= 13) and (self.mode in ['train','valid']) and TRAIN_ON_STAGE_1:\n",
    "                feats = torch.cat([getAPathFeats(self.mode, sz - self.metadata.test.sum()), \n",
    "                                   getAPathFeats('test', self.metadata.test.sum(), self.test_mask)] ,axis=0)\n",
    "            else:\n",
    "                feats = getAPathFeats(self.mode, sz)\n",
    "        else:\n",
    "            if (self.dataset <= 13) and (self.mode in ['train','valid']) and TRAIN_ON_STAGE_1:\n",
    "                feats = torch.cat([pickle.load(open(getAPath(anum,self.mode),'rb')),\n",
    "                                   pickle.load(open(getAPath(anum,'test'),'rb'))[self.test_mask]], axis=0)\n",
    "            else:\n",
    "                feats = pickle.load(open(getAPath(anum,self.mode),'rb'))\n",
    "        \n",
    "        self.feats = feats\n",
    "        self.epoch = epoch\n",
    "        assert len(feats) == sz\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        \n",
    "        series_id = self.series[index]\n",
    "        #df = self.metadata.loc[self.metadata.SeriesInstanceUID == series_id].reset_index(drop=True)\n",
    "        df = self.metadata.loc[series_id].reset_index(drop=True)\n",
    "        \n",
    "        if self.dataset >= 6:\n",
    "            feats = self.feats[torch.BoolTensor(self.metadata.index.values == series_id)]\n",
    "        else: assert False\n",
    "        \n",
    "        order = np.argsort(df.pos_idx1.values)\n",
    "        df = df.sort_values(['pos_idx1'])\n",
    "        feats = feats[torch.LongTensor(order)]\n",
    "        \n",
    "        #if WEIGHTED:\n",
    "        #    non_black = torch.Tensor(df['weights'])\n",
    "        #else:\n",
    "        non_black = torch.ones(len(feats))\n",
    "        \n",
    "        feats = torch.cat([feats, torch.Tensor(df[meta_cols].values)], dim=1)\n",
    "        feats_le = torch.LongTensor(df[cols_le].values)\n",
    "        \n",
    "        target = torch.Tensor(df[all_ich].values)\n",
    "        \n",
    "        PAD = 12\n",
    "        \n",
    "        np.random.seed(1234 + index + 10000*self.epoch)\n",
    "        offset = np.random.randint(0, 61 - feats.shape[0])\n",
    "        \n",
    "        #offset = 0\n",
    "        bk_add = 0\n",
    "        top_pad = PAD + offset\n",
    "        if top_pad > 0:\n",
    "            dummy_row = torch.cat([self.black_feats, torch.Tensor(df.head(1)[meta_cols].values).squeeze()])\n",
    "            feats = torch.cat([dummy_row.repeat(top_pad,1), feats], dim=0)\n",
    "            feats_le = torch.cat([torch.LongTensor(df.head(1)[cols_le].values).squeeze().repeat(top_pad,1), feats_le])\n",
    "            if offset > 0:\n",
    "                non_black = torch.cat([bk_add + torch.zeros(offset), non_black])\n",
    "                target = torch.cat([torch.zeros((offset, len(all_ich))), target], dim=0)\n",
    "        bot_pad = 60 - len(df) - offset + PAD\n",
    "        if bot_pad > 0:\n",
    "            dummy_row = torch.cat([self.black_feats, torch.Tensor(df.tail(1)[meta_cols].values).squeeze()])\n",
    "            feats = torch.cat([feats, dummy_row.repeat(bot_pad,1)], dim=0)\n",
    "            feats_le = torch.cat([feats_le, torch.LongTensor(df.tail(1)[cols_le].values).squeeze().repeat(bot_pad,1)])\n",
    "            if (60 - len(df) - offset) > 0:\n",
    "                non_black = torch.cat([non_black, bk_add + torch.zeros(60 - len(df) - offset)])\n",
    "                target = torch.cat([target, torch.zeros((60 - len(df) - offset, len(all_ich)))], dim=0)\n",
    "        \n",
    "        assert feats_le.shape[0] == (60 + 2*PAD)\n",
    "        assert feats.shape[0] == (60 + 2*PAD)\n",
    "        assert target.shape[0] == 60\n",
    "        \n",
    "        idx = index\n",
    "        if not self.real[index]: idx = -1\n",
    "        \n",
    "        if self.mode == 'train':\n",
    "            return feats, feats_le, target, non_black\n",
    "        else:\n",
    "            return feats, feats_le, target, idx, offset\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.series) if not DATA_SMALL else int(0.01*len(self.series))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getCurrentBatch(dataset, fold=0, ver=VERSION):\n",
    "    sel_batch = None\n",
    "    for filename in os.listdir(PATH_DISK/'models'):\n",
    "        splits = filename.split('.')\n",
    "        if int(splits[2][1]) != fold: continue\n",
    "        if int(splits[3][1:]) != dataset: continue\n",
    "        if int(splits[4][1:]) != ver: continue\n",
    "        if sel_batch is None:\n",
    "            sel_batch = int(splits[1][1:])\n",
    "        else:\n",
    "            sel_batch = max(sel_batch, int(splits[1][1:]))\n",
    "    return sel_batch\n",
    "\n",
    "def modelFileName(dataset, fold=0, batch = 1, return_last = False, return_next = False, ver=VERSION):\n",
    "    sel_batch = batch\n",
    "    if return_last or return_next:\n",
    "        sel_batch = getCurrentBatch(fold=fold, dataset=dataset, ver=ver)\n",
    "        if return_last and sel_batch is None:\n",
    "            return None\n",
    "        if return_next:\n",
    "            if sel_batch is None: sel_batch = 1\n",
    "            else: sel_batch += 1\n",
    "    \n",
    "    return 'model.b{}.f{}.d{}.v{}'.format(sel_batch, fold, dataset, ver)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BCEWithLogitsLoss(nn.Module):\n",
    "    def __init__(self, weight=None):\n",
    "        super().__init__()\n",
    "        self.weight = weight\n",
    "    \n",
    "    def forward(self, input, target, batch_weights = None, focal=0):\n",
    "        if focal == 0:\n",
    "            loss = (torch.log(1+torch.exp(input)) - target*input)*self.weight\n",
    "        else:\n",
    "            loss = torch.pow(1+torch.exp(input), -focal) * \\\n",
    "                (((1-target)*torch.exp(focal*input) + target) * torch.log(1+torch.exp(input)) - target*input)\n",
    "            loss = loss*self.weight\n",
    "        if batch_weights is not None:\n",
    "            loss = batch_weights*loss\n",
    "        return loss.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bn_drop_lin(n_in:int, n_out:int, bn:bool=True, p:float=0., actn=None):\n",
    "    \"Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`.\"\n",
    "    layers = [nn.BatchNorm1d(n_in)] if bn else []\n",
    "    if p != 0: layers.append(nn.Dropout(p))\n",
    "    layers.append(nn.Linear(n_in, n_out))\n",
    "    if actn is not None: layers.append(actn)\n",
    "    return layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "def noop(x): return x\n",
    "act_fun = nn.ReLU(inplace=True)\n",
    "\n",
    "def conv_layer(ni, nf, ks=3, act=True):\n",
    "    bn = nn.BatchNorm1d(nf)\n",
    "    layers = [nn.Conv1d(ni, nf, ks), bn]\n",
    "    if act: layers.append(act_fun)\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "class ResBlock(nn.Module):\n",
    "    def __init__(self, ni, nh):\n",
    "        super().__init__()\n",
    "        layers  = [conv_layer(ni, nh, 1),\n",
    "                   conv_layer(nh, nh, 5, act=False)]\n",
    "        self.convs = nn.Sequential(*layers)\n",
    "        self.idconv = noop if (ni == nh) else conv_layer(ni, nh, 1, act=False)\n",
    "    \n",
    "    def forward(self, x): return act_fun(self.convs(x) + self.idconv(x[:,:,2:-2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResNetModel(nn.Module):\n",
    "    def __init__(self, n_cont:int, feat_sz=2208):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.le_sz = 10\n",
    "        assert self.le_sz == (len(cols_le))\n",
    "        le_in_sizes = np.array([5,5,7,4,4,11,4,6,3,3])\n",
    "        le_out_sizes = np.array([3,3,4,2,2,6,2,4,2,2])\n",
    "        le_out_sz = le_out_sizes.sum()\n",
    "        self.embeddings = nn.ModuleList([embedding(le_in_sizes[i], le_out_sizes[i]) for i in range(self.le_sz)])\n",
    "        \n",
    "        self.feat_sz = feat_sz\n",
    "        \n",
    "        self.n_cont = n_cont\n",
    "        scale = 4\n",
    "        \n",
    "        self.conv2D = nn.Conv2d(1,scale*16,(feat_sz + n_cont + le_out_sz,1))\n",
    "        self.bn1 = nn.BatchNorm1d(scale*16)\n",
    "        \n",
    "        self.res1 = ResBlock(scale*16,scale*16)\n",
    "        self.res2 = ResBlock(scale*16,scale*8)\n",
    "        \n",
    "        self.res3 = ResBlock(scale*24,scale*16)\n",
    "        self.res4 = ResBlock(scale*16,scale*8)\n",
    "        \n",
    "        self.res5 = ResBlock(scale*32,scale*16)\n",
    "        self.res6 = ResBlock(scale*16,scale*8)\n",
    "        \n",
    "        self.conv1D = nn.Conv1d(scale*40,6,1)\n",
    "    \n",
    "    def forward(self, x, x_le, x_le_mix = None, lambd = None) -> torch.Tensor:\n",
    "        x_le = [e(x_le[:,:,i]) for i,e in enumerate(self.embeddings)]\n",
    "        x_le = torch.cat(x_le, 2)\n",
    "        \n",
    "        x = torch.cat([x, x_le], 2)\n",
    "        x = x.transpose(1,2)\n",
    "        \n",
    "        #x = torch.cat([x[:,:self.feat_sz],self.bn_cont(x[:,self.feat_sz:])], dim=1)\n",
    "        x = x.reshape(x.shape[0],1,x.shape[1],x.shape[2])\n",
    "        \n",
    "        x = self.conv2D(x).squeeze()\n",
    "        x = self.bn1(x)\n",
    "        x = act_fun(x)\n",
    "        \n",
    "        x2 = self.res1(x)\n",
    "        x2 = self.res2(x2)\n",
    "        \n",
    "        x3 = torch.cat([x[:,:,4:-4], x2], 1)\n",
    "        \n",
    "        x3 = self.res3(x3)\n",
    "        x3 = self.res4(x3)\n",
    "        \n",
    "        x4 = torch.cat([x[:,:,8:-8], x2[:,:,4:-4], x3], 1)\n",
    "        \n",
    "        x4 = self.res5(x4)\n",
    "        x4 = self.res6(x4)\n",
    "        \n",
    "        x5 = torch.cat([x[:,:,12:-12], x2[:,:,8:-8], x3[:,:,4:-4], x4], 1)\n",
    "        x5 = self.conv1D(x5)\n",
    "        x5 = x5.transpose(1,2)\n",
    "        \n",
    "        return x5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trunc_normal_(x, mean:float=0., std:float=1.):\n",
    "    \"Truncated normal initialization.\"\n",
    "    # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12\n",
    "    return x.normal_().fmod_(2).mul_(std).add_(mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embedding(ni:int,nf:int) -> nn.Module:\n",
    "    \"Create an embedding layer.\"\n",
    "    emb = nn.Embedding(ni, nf)\n",
    "    # See https://arxiv.org/abs/1711.09160\n",
    "    with torch.no_grad(): trunc_normal_(emb.weight, std=0.01)\n",
    "    return emb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TabularModel(nn.Module):\n",
    "    \"Basic model for tabular data.\"\n",
    "    def __init__(self, n_cont:int, feat_sz=2208, fc_drop_p=0.3):\n",
    "        super().__init__()\n",
    "        self.le_sz = 10\n",
    "        assert len(cols_le) == self.le_sz\n",
    "        le_in_sizes = np.array([5,5,7,4,4,11,4,6,3,3])\n",
    "        le_out_sizes = np.array([3,3,4,2,2,6,2,4,2,2])\n",
    "        le_out_sz = le_out_sizes.sum()\n",
    "        self.embeddings = nn.ModuleList([embedding(le_in_sizes[i], le_out_sizes[i]) for i in range(self.le_sz)])\n",
    "        #self.bn_cont = nn.BatchNorm1d(feat_sz + n_cont)\n",
    "        \n",
    "        self.feat_sz = feat_sz\n",
    "        self.bn_cont = nn.BatchNorm1d(n_cont + le_out_sz)\n",
    "        self.n_cont = n_cont\n",
    "        self.fc_drop = nn.Dropout(p=fc_drop_p)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        \n",
    "        scale = 4\n",
    "        \n",
    "        self.conv2D_1 = nn.Conv2d(1,16*scale,(feat_sz + n_cont + le_out_sz,1))\n",
    "        self.conv2D_2 = nn.Conv2d(1,16*scale,(feat_sz + n_cont + le_out_sz,5))\n",
    "        self.bn_cont1 = nn.BatchNorm1d(32*scale)\n",
    "        self.conv1D_1 = nn.Conv1d(32*scale,16*scale,3)\n",
    "        self.conv1D_3 = nn.Conv1d(32*scale,16*scale,5,dilation=5)\n",
    "        self.conv1D_2 = nn.Conv1d(32*scale,6,3)\n",
    "        self.bn_cont2 = nn.BatchNorm1d(32*scale)\n",
    "        self.bn_cont3 = nn.BatchNorm1d(6)\n",
    "\n",
    "        self.conv1D_4 = nn.Conv1d(32*scale,32*scale,3)\n",
    "        self.bn_cont4 = nn.BatchNorm1d(32*scale)\n",
    "\n",
    "    def forward(self, x, x_le, x_le_mix = None, lambd = None) -> torch.Tensor:\n",
    "        x_le = [e(x_le[:,:,i]) for i,e in enumerate(self.embeddings)]\n",
    "        x_le = torch.cat(x_le, 2)\n",
    "        \n",
    "        if MIXUP and x_le_mix is not None:\n",
    "            x_le_mix = [e(x_le_mix[:,:,i]) for i,e in enumerate(self.embeddings)]\n",
    "            x_le_mix = torch.cat(x_le_mix, 2)\n",
    "            x_le = lambd * x_le + (1-lambd) * x_le_mix\n",
    "        \n",
    "        #assert torch.isnan(x_le).any().cpu() == False\n",
    "        x = torch.cat([x, x_le], 2)\n",
    "        x = x.transpose(1,2)\n",
    "        \n",
    "        #x = torch.cat([x[:,:self.feat_sz],self.bn_cont(x[:,self.feat_sz:])], dim=1)\n",
    "        \n",
    "        x = x.reshape(x.shape[0],1,x.shape[1],x.shape[2])\n",
    "        x = self.fc_drop(x)\n",
    "        x = torch.cat([self.conv2D_1(x[:,:,:,2:(-2)]).squeeze(), \n",
    "                       self.conv2D_2(x).squeeze()], dim=1)\n",
    "        x = self.relu(x)\n",
    "        x = self.bn_cont1(x)\n",
    "        x = self.fc_drop(x)\n",
    "        \n",
    "        x = torch.cat([self.conv1D_1(x[:,:,9:(-9)]),\n",
    "                       self.conv1D_3(x)], dim=1)\n",
    "        x = self.relu(x)\n",
    "        x = self.bn_cont2(x)\n",
    "        x = self.fc_drop(x)\n",
    "        \n",
    "        x = self.conv1D_4(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.bn_cont4(x)\n",
    "        \n",
    "        x = self.conv1D_2(x)\n",
    "        x = x.transpose(1,2)\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_loop_fn(model, loader, device, context = None):\n",
    "    \n",
    "    if CLOUD and (not CLOUD_SINGLE):\n",
    "        tlen = len(loader._loader._loader)\n",
    "        OUT_LOSS = 1000\n",
    "        OUT_TIME = 4\n",
    "        generator = loader\n",
    "        device_num = int(str(device)[-1])\n",
    "        dataset = loader._loader._loader.dataset\n",
    "    else:\n",
    "        tlen = len(loader)\n",
    "        OUT_LOSS = 50\n",
    "        OUT_TIME = 50\n",
    "        generator = enumerate(loader)\n",
    "        device_num = 1\n",
    "        dataset = loader.dataset\n",
    "    \n",
    "    #print('Start training {}'.format(device), 'batches', tlen)\n",
    "    \n",
    "    criterion = BCEWithLogitsLoss(weight = torch.Tensor(class_weights).to(device))\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(0.9, 0.99))\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    if CLOUD and TPU:\n",
    "        tracker = xm.RateTracker()\n",
    "\n",
    "    tloss = 0\n",
    "    tloss_count = 0\n",
    "    \n",
    "    st = time.time()\n",
    "    mixup_collected = False\n",
    "    x_le_mix = None\n",
    "    lambd = None\n",
    "    for i, (x, x_le, y, non_black) in generator:\n",
    "        if (not CLOUD) or CLOUD_SINGLE:\n",
    "            x = x.to(device)\n",
    "            x_le = x_le.to(device)\n",
    "            y = y.to(device)\n",
    "            non_black = non_black.to(device)\n",
    "        \n",
    "        if MIXUP:\n",
    "            if mixup_collected:\n",
    "                lambd = np.random.beta(0.4, 0.4, y.size(0))\n",
    "                lambd = torch.Tensor(lambd).to(device)[:,None,None]\n",
    "                #shuffle = torch.randperm(y.size(0)).to(device)\n",
    "                x = lambd * x + (1-lambd) * x_mix #x[shuffle]\n",
    "                #x_le = lambd * x_le + (1-lambd) * x_le_mix #x[shuffle]\n",
    "                mixup_collected = False\n",
    "            else:\n",
    "                x_mix = x\n",
    "                x_le_mix = x_le\n",
    "                y_mix = y\n",
    "                mixup_collected = True\n",
    "                continue\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        output = model(x, x_le, x_le_mix, lambd)\n",
    "        \n",
    "        if MIXUP:\n",
    "            if NO_BLACK_LOSS:\n",
    "                loss = criterion(output, y, lambd*non_black[:,:,None]) \\\n",
    "                     + criterion(output, y_mix, (1-lambd)*non_black[:,:,None])\n",
    "            else:\n",
    "                loss = criterion(output, y, lambd) + criterion(output, y_mix, 1-lambd)\n",
    "            del x_mix, y_mix\n",
    "        else:\n",
    "            if NO_BLACK_LOSS:\n",
    "                loss = criterion(output, y, non_black[:,:,None])\n",
    "            else:\n",
    "                loss = criterion(output, y)\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        tloss += len(y)*loss.cpu().detach().item()\n",
    "        tloss_count += len(y)\n",
    "        \n",
    "        if (CLOUD or CLOUD_SINGLE) and TPU:\n",
    "            xm.optimizer_step(optimizer)\n",
    "            if CLOUD_SINGLE:\n",
    "                xm.mark_step()\n",
    "        else:\n",
    "            optimizer.step()\n",
    "        \n",
    "        if CLOUD and TPU:\n",
    "            tracker.add(len(y))\n",
    "        \n",
    "        st_passed = time.time() - st\n",
    "        if (i+1)%OUT_TIME == 0 and device_num == 1:\n",
    "            #print(torch_xla._XLAC._xla_metrics_report())\n",
    "            print('Batch {} device: {} time passed: {:.3f} time per batch: {:.3f}'\n",
    "                .format(i+1, device, st_passed, st_passed/(i+1)))\n",
    "        \n",
    "        del loss, output, y, x, x_le\n",
    "    \n",
    "    return tloss, tloss_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def val_loop_fn(model, loader, device, context = None):\n",
    "    \n",
    "    if CLOUD and (not CLOUD_SINGLE):\n",
    "        tlen = len(loader._loader._loader)\n",
    "        OUT_LOSS = 1000\n",
    "        OUT_TIME = 4\n",
    "        generator = loader\n",
    "        device_num = int(str(device)[-1])\n",
    "    else:\n",
    "        tlen = len(loader)\n",
    "        OUT_LOSS = 1000\n",
    "        OUT_TIME = 50\n",
    "        generator = enumerate(loader)\n",
    "        device_num = 1\n",
    "    \n",
    "    #print('Start validating {}'.format(device), 'batches', tlen)\n",
    "    \n",
    "    st = time.time()\n",
    "    model.eval()\n",
    "    \n",
    "    results = []\n",
    "    indices = []\n",
    "    offsets = []\n",
    "    \n",
    "    for i, (x, x_le, y, idx, offset) in generator:\n",
    "        \n",
    "        if (not CLOUD) or CLOUD_SINGLE:\n",
    "            x = x.to(device)\n",
    "            x_le = x_le.to(device)\n",
    "        \n",
    "        output = model(x, x_le)\n",
    "        assert torch.isnan(output).any().cpu() == False\n",
    "        output = torch.sigmoid(output)\n",
    "        assert torch.isnan(output).any().cpu() == False\n",
    "        \n",
    "        mask = (idx >= 0)\n",
    "        results.append(output[mask].cpu().detach().numpy())\n",
    "        indices.append(idx[mask].cpu().detach().numpy())\n",
    "        offsets.append(offset[mask].cpu().detach().numpy())\n",
    "        \n",
    "        st_passed = time.time() - st\n",
    "        if (i+1)%OUT_TIME == 0 and device_num == 1:\n",
    "            print('Batch {} device: {} time passed: {:.3f} time per batch: {:.3f}'\n",
    "                  .format(i+1, device, st_passed, st_passed/(i+1)))\n",
    "        \n",
    "        del output, y, x, x_le, idx, offset\n",
    "    \n",
    "    results = np.concatenate(results)\n",
    "    indices = np.concatenate(indices)\n",
    "    offsets = np.concatenate(offsets)\n",
    "    \n",
    "    return results, indices, offsets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def test_loop_fn(model, loader, device, context = None):\n",
    "    \n",
    "    if CLOUD and (not CLOUD_SINGLE):\n",
    "        tlen = len(loader._loader._loader)\n",
    "        OUT_LOSS = 1000\n",
    "        OUT_TIME = 100\n",
    "        generator = loader\n",
    "        device_num = int(str(device)[-1])\n",
    "    else:\n",
    "        tlen = len(loader)\n",
    "        OUT_LOSS = 1000\n",
    "        OUT_TIME = 10\n",
    "        generator = enumerate(loader)\n",
    "        device_num = 1\n",
    "    \n",
    "    #print('Start testing {}'.format(device), 'batches', tlen)\n",
    "    \n",
    "    st = time.time()\n",
    "    model.eval()\n",
    "    \n",
    "    results = []\n",
    "    indices = []\n",
    "    offsets = []\n",
    "    \n",
    "    for i, (x, x_le, y, idx, offset) in generator:\n",
    "        \n",
    "        if (not CLOUD) or CLOUD_SINGLE:\n",
    "            x = x.to(device)\n",
    "            x_le = x_le.to(device)\n",
    "        \n",
    "        output = torch.sigmoid(model(x, x_le))\n",
    "        \n",
    "        mask = (idx >= 0)\n",
    "        results.append(output[mask].cpu().detach().numpy())\n",
    "        indices.append(idx[mask].cpu().detach().numpy())\n",
    "        offsets.append(offset[mask].cpu().detach().numpy())\n",
    "        \n",
    "        st_passed = time.time() - st\n",
    "        if (i+1)%OUT_TIME == 0 and device_num == 1:\n",
    "            print('B{} -> time passed: {:.3f} time per batch: {:.3f}'.format(i+1, st_passed, st_passed/(i+1)))\n",
    "        \n",
    "        del output, x, y, idx, offset\n",
    "    \n",
    "    return np.concatenate(results), np.concatenate(indices), np.concatenate(offsets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_one(dataset, weight=None, load_model=True, epochs=1, bs=100, fold=0, init_ver=None):\n",
    "    \n",
    "    st0 = time.time()\n",
    "    dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n",
    "    \n",
    "    cur_epoch = getCurrentBatch(fold=fold, dataset=dataset)\n",
    "    if cur_epoch is None: cur_epoch = 0\n",
    "    print('completed epochs:', cur_epoch, 'starting now:', epochs)\n",
    "    \n",
    "    setSeeds(SEED + cur_epoch)\n",
    "    \n",
    "    assert train_md.weights.isnull().sum() == 0\n",
    "    \n",
    "    if dataset >= 6:\n",
    "        trn_ds = RSNA_DataSet(train_md, mode='train', bs=bs, fold=fold, dataset=dataset)\n",
    "        val_ds = RSNA_DataSet(train_md, mode='valid', bs=bs, fold=fold, dataset=dataset)\n",
    "    else: assert False\n",
    "    val_ds.setFeats(0)\n",
    "    \n",
    "    if WEIGHTED:\n",
    "        print('WeightedRandomSampler')\n",
    "        sampler = D.sampler.WeightedRandomSampler(trn_ds.weights, len(trn_ds))\n",
    "        loader = D.DataLoader(trn_ds, num_workers=NUM_WORKERS, batch_size=bs, \n",
    "                              shuffle=False, drop_last=True, sampler=sampler)\n",
    "    else:\n",
    "        loader = D.DataLoader(trn_ds, num_workers=NUM_WORKERS, batch_size=bs, \n",
    "                              shuffle=True, drop_last=True)\n",
    "    loader_val = D.DataLoader(val_ds, num_workers=NUM_WORKERS, batch_size=bs, \n",
    "                              shuffle=True)\n",
    "    print('dataset train:', len(trn_ds), 'valid:', len(val_ds), 'loader train:', len(loader), 'valid:', len(loader_val))\n",
    "    \n",
    "    #model = TabularModel(n_cont = len(meta_cols), feat_sz=feat_sz, fc_drop_p=0)\n",
    "    model = ResNetModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n",
    "    \n",
    "    model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset)\n",
    "    \n",
    "    if (model_file_name is None) and (init_ver is not None):\n",
    "        model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset, ver=init_ver)\n",
    "    \n",
    "    if model_file_name is not None:\n",
    "        print('loading model', model_file_name)\n",
    "        state_dict = torch.load(PATH_DISK/'models'/model_file_name)\n",
    "        model.load_state_dict(state_dict)\n",
    "    else:\n",
    "        print('starting from scratch')\n",
    "    \n",
    "    if (not CLOUD) or CLOUD_SINGLE:\n",
    "        model = model.to(device)\n",
    "    else:\n",
    "        model_parallel = dp.DataParallel(model, device_ids=devices)\n",
    "    \n",
    "    loc_data = val_ds.metadata.copy()\n",
    "    \n",
    "    if DATA_SMALL:\n",
    "        val_sz = len(val_ds)\n",
    "        val_series = val_ds.series[:val_sz]\n",
    "        loc_data = loc_data.loc[loc_data.index.isin(val_series)]\n",
    "    \n",
    "    series_counts = loc_data.index.value_counts()\n",
    "    \n",
    "    loc_data['orig_idx'] = np.arange(len(loc_data))\n",
    "    loc_data = loc_data.sort_values(['SeriesInstanceUID','pos_idx1'])\n",
    "    loc_data['my_order'] = np.arange(len(loc_data))\n",
    "    loc_data = loc_data.sort_values(['orig_idx'])\n",
    "    \n",
    "    ww_val = loc_data.weights\n",
    "    \n",
    "    for i in range(cur_epoch+1, cur_epoch+epochs+1):\n",
    "        st = time.time()\n",
    "        \n",
    "        #trn_ds.setFeats((i-1) % 4)\n",
    "        trn_ds.setFeats(-1, epoch=i)\n",
    "        \n",
    "        if CLOUD and (not CLOUD_SINGLE):\n",
    "            results = model_parallel(train_loop_fn, loader)\n",
    "            tloss, tloss_count = np.stack(results).sum(0)\n",
    "            state_dict = model_parallel._models[0].state_dict()\n",
    "        else:\n",
    "            tloss, tloss_count = train_loop_fn(model, loader, device)\n",
    "            state_dict = model.state_dict()\n",
    "        \n",
    "        state_dict = {k:v.to('cpu') for k,v in state_dict.items()}\n",
    "        tr_ll = tloss / tloss_count\n",
    "        \n",
    "        train_time = time.time()-st\n",
    "        \n",
    "        model_file_name = modelFileName(return_next=True, fold=fold, dataset=dataset)\n",
    "        if not DATA_SMALL:\n",
    "            torch.save(state_dict, PATH_DISK/'models'/model_file_name)\n",
    "        \n",
    "        st = time.time()\n",
    "        if CLOUD and (not CLOUD_SINGLE):\n",
    "            results = model_parallel(val_loop_fn, loader_val)\n",
    "            predictions = np.concatenate([results[i][0] for i in range(MAX_DEVICES)])\n",
    "            indices = np.concatenate([results[i][1] for i in range(MAX_DEVICES)])\n",
    "            offsets = np.concatenate([results[i][2] for i in range(MAX_DEVICES)])\n",
    "        else:\n",
    "            predictions, indices, offsets = val_loop_fn(model, loader_val, device)\n",
    "        \n",
    "        predictions = predictions[np.argsort(indices)]\n",
    "        offsets = offsets[np.argsort(indices)]\n",
    "        assert len(predictions) == len(loc_data.index.unique())\n",
    "        assert len(predictions) == len(offsets)\n",
    "        assert np.all(indices[np.argsort(indices)] == np.array(range(len(predictions))))\n",
    "        \n",
    "        #val_results = np.zeros((len(loc_data),6))\n",
    "        val_results = []\n",
    "        for k, series in enumerate(np.sort(loc_data.index.unique())):\n",
    "            cnt = series_counts[series]\n",
    "            #mask = loc_data.SeriesInstanceUID == series\n",
    "            assert (offsets[k] + cnt) <= 60\n",
    "            #val_results[mask] = predictions[k,offsets[k]:(offsets[k] + cnt)]\n",
    "            val_results.append(predictions[k,offsets[k]:(offsets[k] + cnt)])\n",
    "        \n",
    "        val_results = np.concatenate(val_results)\n",
    "        assert np.isnan(val_results).sum() == 0\n",
    "        val_results = val_results[loc_data.my_order]\n",
    "        assert np.isnan(val_results).sum() == 0\n",
    "        assert len(val_results) == len(loc_data)\n",
    "        \n",
    "        lls = [log_loss(loc_data.loc[~loc_data.test, all_ich[k]].values, val_results[~loc_data.test,k], \\\n",
    "                        eps=1e-7, labels=[0,1]) for k in range(6)]\n",
    "        ll = (class_weights * np.array(lls)).mean()\n",
    "        cor = np.corrcoef(loc_data.loc[~loc_data.test,all_ich].values.reshape(-1), \\\n",
    "                          val_results[~loc_data.test].reshape(-1))[0,1]\n",
    "        auc = roc_auc_score(loc_data.loc[~loc_data.test, all_ich].values.reshape(-1), \\\n",
    "                            val_results[~loc_data.test].reshape(-1))\n",
    "        \n",
    "        lls_w = [log_loss(loc_data.loc[~loc_data.test, all_ich[k]].values, val_results[~loc_data.test,k], eps=1e-7, \n",
    "                          labels=[0,1], sample_weight=ww_val[~loc_data.test]) for k in range(6)]\n",
    "        ll_w = (class_weights * np.array(lls_w)).mean()\n",
    "        \n",
    "        if TRAIN_ON_STAGE_1:\n",
    "            lls = [log_loss(loc_data.loc[loc_data.test, all_ich[k]].values, val_results[loc_data.test,k], \\\n",
    "                            eps=1e-7, labels=[0,1]) for k in range(6)]\n",
    "            ll2 = (class_weights * np.array(lls)).mean()\n",
    "\n",
    "            lls_w = [log_loss(loc_data.loc[loc_data.test, all_ich[k]].values, val_results[loc_data.test,k], eps=1e-7, \n",
    "                              labels=[0,1], sample_weight=ww_val[loc_data.test]) for k in range(6)]\n",
    "            ll2_w = (class_weights * np.array(lls_w)).mean()\n",
    "        else:\n",
    "            ll2 = 0\n",
    "            ll2_w = 0\n",
    "        \n",
    "        print('v{}, d{}, e{}, f{}, trn ll: {:.4f}, val ll: {:.4f}, ll_w: {:.4f}, cor: {:.4f}, auc: {:.4f}, lr: {}'\\\n",
    "              .format(VERSION, dataset, i, fold, tr_ll, ll, ll_w, cor, auc, learning_rate))\n",
    "        valid_time = time.time()-st\n",
    "        \n",
    "        epoch_stats = pd.DataFrame([[VERSION, dataset, i, fold, tr_ll, ll, ll_w, ll2, ll2_w, cor, \n",
    "                                     lls[0], lls[1], lls[2], lls[3], \n",
    "                                     lls[4], lls[5], len(trn_ds), len(val_ds), bs, train_time, valid_time,\n",
    "                                     learning_rate, weight_decay]],\n",
    "                                   columns = \n",
    "                                    ['ver','dataset','epoch','fold','train_loss','val_loss','val_w_loss',\n",
    "                                     'val_loss2','val_w_loss2','cor',\n",
    "                                     'any','epidural','intraparenchymal','intraventricular','subarachnoid','subdural',\n",
    "                                     'train_sz','val_sz','bs','train_time','valid_time','lr','wd'\n",
    "                                     ])\n",
    "\n",
    "        stats_filename = PATH_WORK/'stats.f{}.v{}'.format(fold,VERSION)\n",
    "        if stats_filename.is_file():\n",
    "            epoch_stats = pd.concat([pd.read_csv(stats_filename), epoch_stats], sort=False)\n",
    "        #if not DATA_SMALL:\n",
    "        epoch_stats.to_csv(stats_filename, index=False)\n",
    "    \n",
    "    print('total running time', time.time() - st0)\n",
    "    \n",
    "    return model, predictions, val_results"
   ]
  },
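  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loops above return one padded prediction block per series (up to 60 slices) together with per-series offsets. Below is a minimal sketch of the unpacking step with synthetic shapes; `predictions`, `offsets` and `counts` are stand-ins for the real `val_loop_fn` outputs and `series_counts`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (synthetic data) of flattening padded per-series predictions back to\n",
    "# per-slice rows, as done after val_loop_fn/test_loop_fn above. Values are made up.\n",
    "import numpy as np\n",
    "\n",
    "n_series, max_len, n_classes = 3, 60, 6\n",
    "predictions = np.random.rand(n_series, max_len, n_classes)  # padded per-series block\n",
    "offsets = np.array([0, 5, 2])    # start of each series inside its padded block\n",
    "counts = np.array([30, 42, 28])  # actual number of slices per series\n",
    "\n",
    "rows = [predictions[k, offsets[k]:offsets[k] + counts[k]] for k in range(n_series)]\n",
    "flat = np.concatenate(rows)      # (counts.sum(), n_classes), in series-sorted order\n",
    "assert flat.shape == (counts.sum(), n_classes)"
   ]
  },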
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OOF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "def oof_one(dataset, num_iter=1, bs=100, fold=0):\n",
    "    \n",
    "    st0 = time.time()\n",
    "    dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n",
    "    \n",
    "    cur_epoch = getCurrentBatch(fold=fold, dataset=dataset)\n",
    "    if cur_epoch is None: cur_epoch = 0\n",
    "    print('completed epochs:', cur_epoch, 'iters starting now:', num_iter)\n",
    "    \n",
    "    setSeeds(SEED + cur_epoch)\n",
    "    \n",
    "    val_ds = RSNA_DataSet(train_md, mode='valid', bs=bs, fold=fold, dataset=dataset)\n",
    "    \n",
    "    loader_val = D.DataLoader(val_ds, num_workers=NUM_WORKERS, batch_size=bs, \n",
    "                              shuffle=True)\n",
    "    print('dataset valid:', len(val_ds), 'loader valid:', len(loader_val))\n",
    "    \n",
    "    #model = TabularModel(n_cont = len(meta_cols), feat_sz=feat_sz, fc_drop_p=0)\n",
    "    model = ResNetModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n",
    "    \n",
    "    model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset)\n",
    "    if model_file_name is not None:\n",
    "        print('loading model', model_file_name)\n",
    "        state_dict = torch.load(PATH_DISK/'models'/model_file_name)\n",
    "        model.load_state_dict(state_dict)\n",
    "    else:\n",
    "        print('starting from scratch')\n",
    "    \n",
    "    if (not CLOUD) or CLOUD_SINGLE:\n",
    "        model = model.to(device)\n",
    "    else:\n",
    "        model_parallel = dp.DataParallel(model, device_ids=devices)\n",
    "    \n",
    "    loc_data = val_ds.metadata.copy()\n",
    "    series_counts = loc_data.index.value_counts()\n",
    "    \n",
    "    loc_data['orig_idx'] = np.arange(len(loc_data))\n",
    "    loc_data = loc_data.sort_values(['SeriesInstanceUID','pos_idx1'])\n",
    "    loc_data['my_order'] = np.arange(len(loc_data))\n",
    "    loc_data = loc_data.sort_values(['orig_idx'])\n",
    "    \n",
    "    preds = []\n",
    "    \n",
    "    for i in range(num_iter):\n",
    "        \n",
    "        val_ds.setFeats(-1, i+100)\n",
    "        \n",
    "        if CLOUD and (not CLOUD_SINGLE):\n",
    "            results = model_parallel(val_loop_fn, loader_val)\n",
    "            predictions = np.concatenate([results[i][0] for i in range(MAX_DEVICES)])\n",
    "            indices = np.concatenate([results[i][1] for i in range(MAX_DEVICES)])\n",
    "            offsets = np.concatenate([results[i][2] for i in range(MAX_DEVICES)])\n",
    "        else:\n",
    "            predictions, indices, offsets = val_loop_fn(model, loader_val, device)\n",
    "\n",
    "        predictions = predictions[np.argsort(indices)]\n",
    "        offsets = offsets[np.argsort(indices)]\n",
    "        assert len(predictions) == len(loc_data.index.unique())\n",
    "        assert len(predictions) == len(offsets)\n",
    "        assert np.all(indices[np.argsort(indices)] == np.array(range(len(predictions))))\n",
    "\n",
    "        val_results = []\n",
    "        for k, series in enumerate(np.sort(loc_data.index.unique())):\n",
    "            cnt = series_counts[series]\n",
    "            assert (offsets[k] + cnt) <= 60\n",
    "            val_results.append(predictions[k,offsets[k]:(offsets[k] + cnt)])\n",
    "\n",
    "        val_results = np.concatenate(val_results)\n",
    "        assert np.isnan(val_results).sum() == 0\n",
    "        val_results = val_results[loc_data.my_order]\n",
    "        assert np.isnan(val_results).sum() == 0\n",
    "        assert len(val_results) == len(loc_data)\n",
    "        \n",
    "        preds.append(val_results)\n",
    "\n",
    "        lls = [log_loss(loc_data[all_ich[k]].values, val_results[:,k], eps=1e-7, labels=[0,1])\\\n",
    "               for k in range(6)]\n",
    "        ll = (class_weights * np.array(lls)).mean()\n",
    "        cor = np.corrcoef(loc_data.loc[:,all_ich].values.reshape(-1), val_results.reshape(-1))[0,1]\n",
    "        auc = roc_auc_score(loc_data.loc[:,all_ich].values.reshape(-1), val_results.reshape(-1))\n",
    "\n",
    "        print('ver {}, iter {}, fold {}, val ll: {:.4f}, cor: {:.4f}, auc: {:.4f}'\n",
    "              .format(VERSION, i, fold, ll, cor, auc))\n",
    "    \n",
    "    print('total running time', time.time() - st0)\n",
    "    \n",
    "    return np.stack(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predBounding(pp, target=None):\n",
    "    if target is not None:\n",
    "        ll = ((- target * np.log(pp.mean(0)) - (1 - target) * np.log(np.clip(1 - pp.mean(0),1e-15,1-1e-15)))\n",
    "            * class_weights).mean()\n",
    "        print('initial score', ll)\n",
    "    \n",
    "    print('any too low inconsistencies')\n",
    "    for i in range(1,6):\n",
    "        print(i, 'class:', (pp[...,0] < pp[...,i]).mean())\n",
    "    print('total', (pp[...,0] < pp[...,1:].max(-1)).mean())\n",
    "    \n",
    "    max_vals = pp[...,1:].max(-1)\n",
    "    mask = pp[...,0] < max_vals\n",
    "    pp[mask,0] = max_vals[mask]\n",
    "    #mask_vals = 0.5*(preds_all[:,:,:,0] + max_vals)[mask]\n",
    "    #preds_all[mask,0] = mask_vals\n",
    "    #preds_all[mask] = np.clip(preds_all[mask],0,np.expand_dims(mask_vals,1))\n",
    "\n",
    "    assert (pp[...,0] < pp[...,1:].max(-1)).sum() == 0\n",
    "\n",
    "    if target is not None:\n",
    "        ll = ((- target * np.log(pp.mean(0)) - (1 - target) * np.log(np.clip(1 - pp.mean(0),1e-15,1-1e-15)))\n",
    "            * class_weights).mean()\n",
    "        print('any too low corrected score', ll)\n",
    "    \n",
    "    print('any too high inconsistencies')\n",
    "    mask = pp[...,0] > pp[...,1:].sum(-1)\n",
    "    print('total', mask.mean())\n",
    "\n",
    "    mask_val = 0.5*(pp[mask,0] + pp[...,1:].sum(-1)[mask])\n",
    "    scaler = mask_val / pp[...,1:].sum(-1)[mask]\n",
    "    pp[mask,1:] = pp[mask,1:] * np.expand_dims(scaler,1)\n",
    "    pp[mask,0] = mask_val\n",
    "\n",
    "    if target is not None:\n",
    "        ll = ((- target * np.log(pp.mean(0)) - (1 - target) * np.log(np.clip(1 - pp.mean(0),1e-15,1-1e-15)))\n",
    "            * class_weights).mean()\n",
    "        print('any too high corrected score', ll)\n",
    "    \n",
    "    return pp"
   ]
  },
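  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small usage sketch of `predBounding` on synthetic probabilities (the real input is the stacked OOF/test prediction array). Column 0 is the `any` class; the function raises it to at least the largest subtype probability and reconciles it with the sum of the subtype probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic check of predBounding; values are made up. Row 1 has 'any' below the\n",
    "# max subtype, row 2 has 'any' above the sum of subtypes; both get corrected.\n",
    "import numpy as np\n",
    "\n",
    "pp = np.array([[[0.10, 0.40, 0.05, 0.02, 0.01, 0.03],\n",
    "                [0.90, 0.10, 0.05, 0.05, 0.02, 0.03]]])\n",
    "pp_fixed = predBounding(pp.copy())\n",
    "assert np.all(pp_fixed[..., 0] >= pp_fixed[..., 1:].max(-1))"
   ]
  },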
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Selecting runs aggregation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def scalePreds(x, power = 2, center=0.5):\n",
    "    res = x.copy()\n",
    "    res[x > center] = center + (1 - center) * ((res[x > center] - center)/(1 - center))**power\n",
    "    res[x < center] = center - center * ((center - res[x < center])/center)**power\n",
    "    return res\n",
    "\n",
    "def getPredsOOF(ver,aug=32,datasets=range(6,11),datasets5=range(11,13)):\n",
    "    preds_all = np.zeros((len(datasets) + len(datasets5),aug,len(train_md),6))\n",
    "    \n",
    "    if len(datasets) > 0:\n",
    "        for fold in range(3):\n",
    "            preds = np.stack([pickle.load(open(PATH_DISK/'ensemble/oof_d{}_f{}_v{}'\n",
    "                                               .format(ds, fold, ver),'rb')) for ds in datasets])\n",
    "            preds_all[:len(datasets),:,train_md.fold == fold,:] = preds\n",
    "    \n",
    "    if len(datasets5) > 0:\n",
    "        for fold in range(5):\n",
    "            preds = np.stack([pickle.load(open(PATH_DISK/'ensemble/oof_d{}_f{}_v{}'\n",
    "                                               .format(ds, fold, ver),'rb')) for ds in datasets5])\n",
    "            preds_all[len(datasets):,:,train_md.fold5 == fold,:] = preds\n",
    "    \n",
    "    preds_all = np.clip(preds_all, 1e-15, 1-1e-15)\n",
    "    return preds_all"
   ]
  },
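  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quick illustration of `scalePreds` on arbitrary toy values: a `power` above 1 shrinks predictions towards `center` (less confident), while a `power` below 1 pushes them out towards 0 and 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy demo of scalePreds; the probabilities below are arbitrary.\n",
    "import numpy as np\n",
    "\n",
    "x = np.array([0.05, 0.25, 0.50, 0.75, 0.95])\n",
    "print(scalePreds(x, power=2))    # e.g. 0.75 -> 0.625, pulled towards the center\n",
    "print(scalePreds(x, power=0.5))  # e.g. 0.75 -> ~0.854, pushed towards the extremes"
   ]
  },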
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getYuvalOOF(names, train_md, names5=None, aug=32):\n",
    "    \n",
    "    for num_folds in [3,5]:\n",
    "        if ('yuval_idx' not in train_md.columns) or ('yuval_idx5' not in train_md.columns):\n",
    "            print('adding yuval_idx')\n",
    "            fn = 'OOF_validation_image_ids'\n",
    "            if num_folds == 5:\n",
    "                fn += '_5_stage2'\n",
    "            if num_folds == 3:\n",
    "                fn += '_3_stage2'\n",
    "            yuval_oof = pickle.load(open(PATH_DISK/'yuval/OOF_stage2/{}.pkl'.format(fn),'rb'))\n",
    "            \n",
    "            df2 = pd.DataFrame()\n",
    "            col_name = 'yuval_idx'\n",
    "            fold_col = 'fold'\n",
    "            if num_folds == 5:\n",
    "                col_name += '5'\n",
    "                fold_col += '5'\n",
    "            \n",
    "            for fold in range(num_folds):\n",
    "                \n",
    "                assert np.all(train_md.loc[train_md.img_id.isin(yuval_oof[fold])][fold_col].values == fold)\n",
    "                \n",
    "                df = pd.DataFrame(np.arange(len(yuval_oof[fold])), columns=[col_name])\n",
    "                df.index = yuval_oof[fold]\n",
    "                df2 = pd.concat([df2,df],sort=False)\n",
    "            \n",
    "            train_md = train_md.join(df2, on = 'img_id')\n",
    "            assert train_md[col_name].isnull().sum() == 0\n",
    "    \n",
    "    preds_y = np.zeros((len(names),len(train_md),6))\n",
    "    \n",
    "    for fold in range(3):\n",
    "        preds = np.stack([torch.sigmoid(torch.stack(\n",
    "            pickle.load(open(PATH_DISK/'yuval/OOF_stage2'/name.format(fold),'rb')))).numpy() for name in names])\n",
    "        preds = preds[:,:,train_md.loc[train_md.fold == fold, 'yuval_idx']]\n",
    "        \n",
    "        preds_y[:,train_md.fold == fold] = preds.mean(1)\n",
    "        \n",
    "    if names5 is not None:\n",
    "        preds_y5 = np.zeros((len(names5),len(train_md),6))\n",
    "        for fold in range(5):\n",
    "            preds = np.stack([torch.sigmoid(torch.stack(\n",
    "                pickle.load(open(PATH_DISK/'yuval/OOF_stage2'/name.format(fold),'rb')))).numpy() for name in names5])\n",
    "            preds = preds[:,:,train_md.loc[train_md.fold5 == fold, 'yuval_idx5']]\n",
    "\n",
    "            preds_y5[:,train_md.fold5 == fold] = preds.mean(1)\n",
    "        \n",
    "        preds_y = np.concatenate([preds_y, preds_y5], axis=0)\n",
    "    \n",
    "    \n",
    "    preds_y = np.clip(preds_y, 1e-15, 1-1e-15)[:,:,np.array([5,0,1,2,3,4])]\n",
    "    return preds_y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference_one(dataset, bs = 100, add_seed = 0, fold = 0, anum = 0):\n",
    "    \n",
    "    st = time.time()\n",
    "    dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n",
    "\n",
    "    cur_epoch = getCurrentBatch(fold=fold, dataset=dataset)\n",
    "    if cur_epoch is None: cur_epoch = 0\n",
    "    print('completed epochs:', cur_epoch)\n",
    "\n",
    "    #model = TabularModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n",
    "    model = ResNetModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n",
    "    \n",
    "    model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset)\n",
    "    if model_file_name is not None:\n",
    "        print('loading model', model_file_name)\n",
    "        state_dict = torch.load(PATH_DISK/'models'/model_file_name)\n",
    "        model.load_state_dict(state_dict)\n",
    "    \n",
    "    if (not CLOUD) or CLOUD_SINGLE:\n",
    "        model = model.to(device)\n",
    "    else:\n",
    "        model_parallel = dp.DataParallel(model, device_ids=devices)\n",
    "\n",
    "    setSeeds(SEED + cur_epoch + anum + 100*fold)\n",
    "\n",
    "    tst_ds = RSNA_DataSet(test_md, mode='test', bs=bs, fold=fold, dataset=dataset)\n",
    "    loader_tst = D.DataLoader(tst_ds, num_workers=NUM_WORKERS, batch_size=bs, shuffle=False)\n",
    "    print('dataset test:', len(tst_ds), 'loader test:', len(loader_tst), 'anum:', anum)\n",
    "    \n",
    "    tst_ds.setFeats(-1, epoch = anum+100)\n",
    "\n",
    "    loc_data = tst_ds.metadata.copy()\n",
    "    series_counts = loc_data.index.value_counts()\n",
    "\n",
    "    loc_data['orig_idx'] = np.arange(len(loc_data))\n",
    "    loc_data = loc_data.sort_values(['SeriesInstanceUID','pos_idx1'])\n",
    "    loc_data['my_order'] = np.arange(len(loc_data))\n",
    "    loc_data = loc_data.sort_values(['orig_idx'])\n",
    "    \n",
    "    if CLOUD and (not CLOUD_SINGLE):\n",
    "        results = model_parallel(test_loop_fn, loader_tst)\n",
    "        predictions = np.concatenate([results[i][0] for i in range(MAX_DEVICES)])\n",
    "        indices = np.concatenate([results[i][1] for i in range(MAX_DEVICES)])\n",
    "        offsets = np.concatenate([results[i][2] for i in range(MAX_DEVICES)])\n",
    "    else:\n",
    "        predictions, indices, offsets = test_loop_fn(model, loader_tst, device)\n",
    "\n",
    "    predictions = predictions[np.argsort(indices)]\n",
    "    offsets = offsets[np.argsort(indices)]\n",
    "    assert len(predictions) == len(test_md.SeriesInstanceUID.unique())\n",
    "    assert np.all(indices[np.argsort(indices)] == np.array(range(len(predictions))))\n",
    "    \n",
    "    val_results = []\n",
    "    for k, series in enumerate(np.sort(loc_data.index.unique())):\n",
    "        cnt = series_counts[series]\n",
    "        assert (offsets[k] + cnt) <= 60\n",
    "        val_results.append(predictions[k,offsets[k]:(offsets[k] + cnt)])\n",
    "\n",
    "    val_results = np.concatenate(val_results)\n",
    "    assert np.isnan(val_results).sum() == 0\n",
    "    val_results = val_results[loc_data.my_order]\n",
    "    assert len(val_results) == len(loc_data)\n",
    "\n",
    "    print('test processing time:', time.time() - st)\n",
    "    \n",
    "    return val_results"
   ]
  },
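  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`inference_one` returns per-slice class probabilities aligned with the rows of `test_md`. The sketch below shows one way to turn an averaged prediction array into the competition submission format (`ID_<image>_<subtype>`, `Label`). It is an assumption, not the notebook's actual submission code: the averaging over folds/augmentations, the `img_id` column and the use of `all_ich` for the column order are all taken as given."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical submission assembly. preds is assumed to be the average of\n",
    "# inference_one outputs over folds and augmentation runs, row-aligned with test_md;\n",
    "# img_id is assumed to already carry the 'ID_...' prefix used by the competition.\n",
    "import pandas as pd\n",
    "\n",
    "def make_submission(preds, test_md, path='submission.csv'):\n",
    "    sub = pd.DataFrame(preds, columns=all_ich, index=test_md.img_id)\n",
    "    sub = sub.stack().reset_index()\n",
    "    sub.columns = ['img_id', 'subtype', 'Label']\n",
    "    sub['ID'] = sub.img_id + '_' + sub.subtype\n",
    "    sub[['ID', 'Label']].to_csv(path, index=False)\n",
    "    return sub"
   ]
  },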
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ensembling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def getStepX(train_md, preds_all, fold=0, target=0, mode='train'):\n",
    "    \n",
    "    print('my_len', my_len)\n",
    "    X = np.stack([preds_all[:my_len,:,target].mean(0), \n",
    "                  preds_all[my_len:,:,target].mean(0)], axis=0)\n",
    "    \n",
    "    if mode == 'train':\n",
    "        X = X[:,train_md.fold != fold]\n",
    "        y = train_md.loc[train_md.fold != fold, all_ich[target]].values\n",
    "        ww = train_md.loc[train_md.fold != fold, 'weights'].values\n",
    "    elif mode == 'valid':\n",
    "        X = X[:,train_md.fold == fold]\n",
    "        y = train_md.loc[train_md.fold == fold, all_ich[target]].values\n",
    "        ww = train_md.loc[train_md.fold == fold, 'weights'].values\n",
    "    else:\n",
    "        X = X\n",
    "        y = None\n",
    "    \n",
    "    ll = None\n",
    "    llw = None\n",
    "    auc = None\n",
    "    if y is not None:\n",
    "        ll = log_loss(y, X.mean(0), eps=1e-7, labels=[0,1])\n",
    "        llw = log_loss(y, X.mean(0), eps=1e-7, labels=[0,1], sample_weight=ww)\n",
    "        auc = roc_auc_score(y, X.mean(0))\n",
    "    \n",
    "    return X, y, ww, ll, llw, auc\n",
    "\n",
    "\n",
    "def train_ensemble(train_md, preds_all, fold = 0, target = 0, weighted = False):\n",
    "    \n",
    "    print('starting fold',fold,'target',target)\n",
    "    \n",
    "    st = time.time()\n",
    "    \n",
    "    limit_low = 1e-15\n",
    "    limit_high = 1 - 1e-5\n",
    "    \n",
    "    prior = train_md.loc[train_md.fold != fold, all_ich[target]].mean()\n",
    "    \n",
    "    def my_objective(x,preds,vals,ww):\n",
    "        preds_sum = np.clip((preds * x[:,None]).sum(0), limit_low, limit_high)\n",
    "        res = np.average(- vals * np.log(preds_sum) - (1 - vals) * np.log(1 - preds_sum), weights=ww)\n",
    "        #print('x   ',x, x.sum())\n",
    "        print('obj ',res)\n",
    "        return res\n",
    "\n",
    "    def my_grad(x,preds,vals,ww):\n",
    "        preds_sum = np.clip((preds * x[:,None]).sum(0), limit_low, limit_high)\n",
    "        res = np.average(- vals * preds / preds_sum + (1 - vals) * preds / (1 - preds_sum), weights=ww, axis=1)\n",
    "        #print('grad',res)\n",
    "        return res\n",
    "\n",
    "    def my_hess(x,preds,vals,ww):\n",
    "        preds_sum = np.clip((preds * x[:,None]).sum(0), limit_low, limit_high)\n",
    "        res = np.average(preds * np.expand_dims(preds, axis=1) * \n",
    "                         (vals / preds_sum**2 + (1 - vals) / (1 - preds_sum)**2), weights=ww, axis=2)\n",
    "        return res\n",
    "    \n",
    "    X,y,ww,ll_train,llw_train,auc_train =  getStepX(train_md, preds_all, fold=fold, target=target)\n",
    "    \n",
    "    bnds_low = np.zeros(X.shape[0])\n",
    "    bnds_high = np.ones(X.shape[0])\n",
    "    \n",
    "    initial_sol = np.ones(X.shape[0])/X.shape[0]\n",
    "    \n",
    "    bnds = sp.optimize.Bounds(bnds_low, bnds_high)\n",
    "    cons = sp.optimize.LinearConstraint(np.ones((1,X.shape[0])), 0.98, 1.00)\n",
    "    \n",
    "    if weighted:\n",
    "        ww_in = ww\n",
    "    else:\n",
    "        ww_in = np.ones(len(y))\n",
    "    \n",
    "    model = sp.optimize.minimize(my_objective, initial_sol, jac=my_grad, hess=my_hess, args=(X, y, ww_in),\n",
    "                                 bounds=bnds, method='trust-constr', constraints=cons,\n",
    "                                 options={'gtol': 1e-11, 'initial_tr_radius': 0.1, 'initial_barrier_parameter': 0.01})\n",
    "    \n",
    "    pickle.dump(model, open(PATH_DISK/'ensemble'/'model.f{}.t{}.v{}'\n",
    "                            .format(fold,target,VERSION),'wb'))\n",
    "\n",
    "    print('model', model.x, 'sum', model.x.sum())\n",
    "    \n",
    "    train_preds = (X*np.expand_dims(model.x, axis=1)).sum(0)\n",
    "    ll_train2 = log_loss(y, train_preds, eps=1e-7, labels=[0,1])\n",
    "    llw_train2 = log_loss(y, train_preds, eps=1e-7, labels=[0,1], sample_weight=ww)\n",
    "    auc_train2 = roc_auc_score(y, train_preds)\n",
    "    \n",
    "    X,y,ww,ll_val,llw_val,auc_val =  getStepX(train_md, preds_all, fold=fold, target=target, mode='valid')\n",
    "    \n",
    "    val_preds = (X*np.expand_dims(model.x, axis=1)).sum(0)\n",
    "    \n",
    "    ll_val2 = log_loss(y, val_preds, eps=1e-7, labels=[0,1])\n",
    "    llw_val2 = log_loss(y, val_preds, eps=1e-7, labels=[0,1], sample_weight=ww)\n",
    "    auc_val2 = roc_auc_score(y, val_preds)\n",
    "    \n",
    "    print('v{} f{} t{}: original ll {:.4f}/{:.4f}, ensemble ll {:.4f}/{:.4f}'\n",
    "          .format(VERSION,fold,target,ll_val,llw_val,ll_val2,llw_val2))\n",
    "    \n",
    "    run_time = time.time() - st\n",
    "    print('running time', run_time)\n",
    "    \n",
    "    stats = pd.DataFrame([[VERSION,fold,target,\n",
    "                           ll_train,auc_train,ll_train2,auc_train2,\n",
    "                           ll_val,auc_val,ll_val2,auc_val2,\n",
    "                           llw_train,llw_train2,llw_val,llw_val2,\n",
    "                           run_time,weighted]],\n",
    "                           columns = \n",
    "                            ['version','fold','target',\n",
    "                             'train_loss','train_auc','train_loss_ens','train_auc_ens', \n",
    "                             'valid_loss','valid_auc','valid_loss_ens','valid_auc_ens',\n",
    "                             'train_w_loss','train_w_loss_ens','valid_w_loss','valid_w_loss_ens',\n",
    "                             'run_time','weighted'\n",
    "                             ])\n",
    "    \n",
    "    stats_filename = PATH_DISK/'ensemble'/'stats.v{}'.format(VERSION)\n",
    "    if stats_filename.is_file():\n",
    "        stats = pd.concat([pd.read_csv(stats_filename), stats], sort=False)\n",
    "    stats.to_csv(stats_filename, index=False)\n",
    "\n",
    "#model.cols = Xt.columns\n",
    "#predictions[data_filt['fold'] == i] = (Xv*model.x).sum(1)"
   ]
  },
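  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, `train_ensemble` fits non-negative blend weights $w$ (constrained to sum to roughly 1) by minimizing the (optionally sample-weighted) log loss of the blended prediction $s_i = \\sum_k w_k\\, p_{ik}$, where the $p_{ik}$ are the two averaged prediction groups returned by `getStepX`:\n",
    "\n",
    "$$\\mathcal{L}(w) = \\frac{\\sum_i \\omega_i \\left[-y_i \\log s_i - (1-y_i)\\log(1-s_i)\\right]}{\\sum_i \\omega_i},\\qquad \\frac{\\partial \\mathcal{L}}{\\partial w_k} = \\frac{\\sum_i \\omega_i\\, p_{ik}\\left[\\frac{1-y_i}{1-s_i} - \\frac{y_i}{s_i}\\right]}{\\sum_i \\omega_i},$$\n",
    "\n",
    "with Hessian entries $\\frac{\\sum_i \\omega_i\\, p_{ik} p_{il}\\left[y_i/s_i^2 + (1-y_i)/(1-s_i)^2\\right]}{\\sum_i \\omega_i}$; these correspond to `my_objective`, `my_grad` and `my_hess` above, with $s_i$ clipped to $[10^{-15},\\, 1-10^{-5}]$ before use."
   ]
  },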
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}