--- a +++ b/notebooks/Code.ipynb @@ -0,0 +1,1966 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "#conda activate torch-xla-nightly\n", + "#export XRT_TPU_CONFIG=\"tpu_worker;0;$10.0.101.2:8470\"\n", + "#git init\n", + "#git remote add origin https://github.com/nosound2/RSNA-Hemorrhage\n", + "#git config remote.origin.push HEAD\n", + "#git config credential.helper store\n", + "#git pull origin master\n", + "#git config --global branch.master.remote origin\n", + "#git config --global branch.master.merge refs/heads/master\n", + "#gcloud config set compute/zone europe-west4-a\n", + "#gcloud config set compute/zone us-central1-c\n", + "#gcloud auth login\n", + "#gcloud config set project endless-empire-239015\n", + "#pip install kaggle --user\n", + "#PATH=$PATH:~/.local/bin\n", + "#mkdir .kaggle\n", + "#gsutil cp gs://recursion-double-strand/kaggle-keys/kaggle.json ~/.kaggle\n", + "#chmod 600 /home/zahar_chikishev/.kaggle/kaggle.json\n", + "#kaggle competitions download rsna-intracranial-hemorrhage-detection -f stage_2_train.csv\n", + "#sudo apt install unzip\n", + "#unzip stage_1_train.csv.zip\n", + "#kaggle kernels output xhlulu/rsna-generate-metadata-csvs -p .\n", + "#kaggle kernels output zaharch/dicom-test-metadata-to-csv -p .\n", + "#kaggle kernels output zaharch/dicom-metadata-to-csv -p .\n", + "#gsutil -m cp gs://rsna-hemorrhage/yuvals/* .\n", + "\n", + "#export XRT_TPU_CONFIG=\"tpu_worker;0;10.0.101.2:8470\"; conda activate torch-xla-nightly; jupyter notebook\n", + "\n", + "# 35.188.114.109\n", + "\n", + "# lsblk\n", + "# sudo mkfs.ext4 -m 0 -F -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb\n", + "# sudo mkdir /mnt/edisk\n", + "# sudo mount -o discard,defaults /dev/sdb /mnt/edisk\n", + "# sudo chmod a+w /mnt/edisk\n", + "# df -H\n", + "# echo UUID=`sudo blkid -s UUID -o value /dev/sdb` /mnt/edisk ext4 discard,defaults,nofail 0 2 | sudo tee -a /etc/fstab\n", + "# cat /etc/fstab\n", + "\n", + "#jupyter notebook --generate-config\n", + "#vi ~/.jupyter/jupyter_notebook_config.py\n", + "\n", + "#c = get_config()\n", + "#c.NotebookApp.ip = '*'\n", + "#c.NotebookApp.open_browser = False\n", + "#c.NotebookApp.port = 5000" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "from pathlib import Path\n", + "from PIL import ImageDraw, ImageFont, Image\n", + "from matplotlib import patches, patheffects\n", + "import time\n", + "from random import randint\n", + "import numpy as np\n", + "import pandas as pd\n", + "import pickle\n", + "\n", + "from sklearn.model_selection import KFold, StratifiedKFold, GroupKFold\n", + "from sklearn.preprocessing import LabelEncoder, OneHotEncoder, LabelBinarizer\n", + "from sklearn.preprocessing import StandardScaler\n", + "from sklearn.metrics import mean_squared_error,log_loss,roc_auc_score\n", + "from scipy.stats import ks_2samp\n", + "\n", + "import pdb\n", + "\n", + "import scipy as sp\n", + "from tqdm import tqdm, tqdm_notebook\n", + "\n", + "import os\n", + "import glob\n", + "\n", + "import torch\n", + "\n", + "#CLOUD = not torch.cuda.is_available()\n", + "CLOUD = (torch.cuda.get_device_name(0) in ['Tesla P4', 'Tesla K80', 'Tesla P100-PCIE-16GB'])\n", + "CLOUD_SINGLE = True\n", + "TPU = False\n", + "\n", + "if not CLOUD:\n", + " torch.cuda.current_device()\n", + "\n", + "import 
torch.nn as nn\n", + "import torch.utils.data as D\n", + "import torch.nn.functional as F\n", + "import torch.utils as U\n", + "\n", + "import torchvision\n", + "from torchvision import transforms as T\n", + "from torchvision import models as M\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "if CLOUD:\n", + " PATH = Path('/home/zahar_chikishev/Hemorrhage')\n", + " PATH_WORK = Path('/mnt/edisk/running')\n", + " PATH_DISK = Path('/mnt/edisk/running')\n", + "else:\n", + " PATH = Path('C:/StudioProjects/Hemorrhage')\n", + " PATH_WORK = Path('C:/StudioProjects/Hemorrhage/running')\n", + " PATH_DISK = PATH_WORK\n", + "\n", + "from collections import defaultdict, Counter\n", + "import random\n", + "import seaborn as sn\n", + "\n", + "pd.set_option(\"display.max_columns\", 100)\n", + "\n", + "all_ich = ['any','epidural','intraparenchymal','intraventricular','subarachnoid','subdural']\n", + "class_weights = 6.0*np.array([2,1,1,1,1,1])/7.0\n", + "\n", + "if CLOUD and TPU:\n", + " import torch_xla\n", + " import torch_xla.distributed.data_parallel as dp\n", + " import torch_xla.utils as xu\n", + " import torch_xla.core.xla_model as xm\n", + "\n", + "from typing import Collection" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "all_black = '006d4432e'\n", + "all_black = '00bd6c59c'\n", + "\n", + "if CLOUD:\n", + " if TPU:\n", + " device = xm.xla_device()\n", + " else:\n", + " device = 'cuda'\n", + " #device = 'cpu'\n", + " MAX_DEVICES = 1 if CLOUD_SINGLE else 8\n", + " NUM_WORKERS = 16\n", + " bs = 32\n", + "else:\n", + " device = 'cuda'\n", + " #device = 'cpu'\n", + " MAX_DEVICES = 1\n", + " NUM_WORKERS = 0\n", + " bs = 16\n", + "\n", + "#if CLOUD and (not CLOUD_SINGLE):\n", + "# devices = xm.get_xla_supported_devices(max_devices=MAX_DEVICES)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "SEED = 2351\n", + "\n", + "def setSeeds(seed):\n", + " np.random.seed(seed)\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + "\n", + "setSeeds(SEED)\n", + "torch.backends.cudnn.deterministic = True" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def getDSName(dataset):\n", + " if dataset == 6:\n", + " return 'Densenet201_F3'\n", + " elif dataset == 7:\n", + " return 'Densenet161_F3'\n", + " elif dataset == 8:\n", + " return 'Densenet169_F3'\n", + " elif dataset == 9:\n", + " return 'se_resnext101_32x4d_F3'\n", + " elif dataset == 10:\n", + " return 'se_resnet101_F3'\n", + " elif dataset == 11:\n", + " return 'se_resnext101_32x4d_F5'\n", + " elif dataset == 12:\n", + " return 'se_resnet101_F5'\n", + " elif dataset == 13:\n", + " return 'se_resnet101_focal_F5'\n", + " elif dataset == 14:\n", + " return 'se_resnet101_F5'\n", + " else: assert False" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def getDSParams(dataset):\n", + " if dataset == 6:\n", + " return 'Densenet201','_3','_v5',240,'','','classifier',''\n", + " elif dataset == 7:\n", + " return 'Densenet161','_3','_v4',552,'2','','classifier',''\n", + " elif dataset == 8:\n", + " return 'Densenet169','_3','_v3',208,'2','','classifier',''\n", + " elif dataset == 9:\n", + " return 'se_resnext101_32x4d','','',256,'','_tta','classifier',''\n", + " elif dataset == 10:\n", + " return 'se_resnet101','','',256,'','','classifier',''\n", + " elif dataset == 11:\n", + " return 
'se_resnext101_32x4d_5','','',256,'','','new',''\n", + " elif dataset == 12:\n", + " return 'se_resnet101_5','','',256,'','','new',''\n", + " elif dataset == 13:\n", + " return 'se_resnet101_5f','','',256,'','','new','focal_'\n", + " elif dataset == 14:\n", + " return 'se_resnet101_5n','','',256,'','','new','stage2_'\n", + " else: assert False\n", + "\n", + "def getNFolds(dataset):\n", + " if dataset <= 10:\n", + " return 3\n", + " elif dataset <= 14:\n", + " return 5\n", + " else: assert False\n", + "\n", + "my_datasets3 = [7,9]\n", + "my_datasets5 = [11,12,13]\n", + "my_len = len(my_datasets3) + len(my_datasets5)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "cols_le,cols_float,cols_bool = pickle.load(open(PATH_WORK/'covs','rb'))\n", + "meta_cols = cols_bool + cols_float\n", + "\n", + "#meta_cols = [c for c in meta_cols if c not in ['SeriesPP']]\n", + "#cols_le = cols_le[:-1]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pre-processing" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "def loadMetadata(rerun=False):\n", + " if rerun:\n", + " train_dedup = pd.read_csv(PATH_WORK/'yuval'/'train_dedup.csv')\n", + " pids, folding = pickle.load(open(PATH_WORK/'yuval'/'PID_splits.pkl','rb'))\n", + "\n", + " assert len(pids) == 17079\n", + " assert len(np.unique(pids)) == 17079\n", + "\n", + " for fol in folding:\n", + " assert len(fol[0]) + len(fol[1]) == 17079\n", + "\n", + " assert len(folding[0][1]) + len(folding[1][1]) + len(folding[2][1]) == 17079\n", + "\n", + " assert len(train_dedup.PID.unique()) == 17079\n", + "\n", + " train_dedup['fold'] = np.nan\n", + "\n", + " for fold in range(3):\n", + " train_dedup.loc[train_dedup.PID.isin(pids[folding[fold][1]]),'fold'] = fold\n", + "\n", + " assert train_dedup.fold.isnull().sum() == 0\n", + " \n", + " oof5 = pickle.load(open(PATH_DISK/'yuval'/'OOF_validation_image_ids_5.pkl','rb'))\n", + " for fold in range(5):\n", + " train_dedup.loc[train_dedup.PatientID.isin(oof5[fold]),'fold5'] = fold\n", + " assert train_dedup.fold5.isnull().sum() == 0\n", + " \n", + " train_md = pd.read_csv(PATH_WORK/'train_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n", + " train_md['img_id'] = train_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n", + "\n", + " ids_df = train_dedup[['fold','PatientID','fold5']]\n", + " ids_df.columns = ['fold','img_id','fold5']\n", + "\n", + " train_md = ids_df.join(train_md.set_index('img_id'), on = 'img_id')\n", + "\n", + " pickle.dump(train_md, open(PATH_WORK/'train.ids.df','wb'))\n", + "\n", + " #test_md = pickle.load(open(PATH_WORK/'test.post.processed.1','rb'))\n", + " test_md = pd.read_csv(PATH_WORK/'test_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n", + " test_md['img_id'] = test_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n", + "\n", + " filename = PATH_WORK/'yuval'/'test_indexes.pkl'\n", + " test_ids = pickle.load(open(filename,'rb'))\n", + "\n", + " test_ids_df = pd.DataFrame(test_ids, columns = ['img_id'])\n", + " test_md = test_ids_df.join(test_md.set_index('img_id'), on = 'img_id')\n", + " \n", + " assert len(test_md.SeriesInstanceUID.unique()) == 2214\n", + " \n", + " #train_md = pd.concat([train_md, test_md], axis = 0, sort=False).reset_index(drop=True)\n", + " \n", + " #pickle.dump(train_md, open(PATH_WORK/'train.ids.df','wb'))\n", + " pickle.dump(test_md, open(PATH_WORK/'test.ids.df','wb'))\n", + " else:\n", + " 
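# rerun=False: just load the metadata pickles written by a previous rerun=True pass\n",
+ "        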
train_md = pickle.load(open(PATH_WORK/'train.ids.df','rb'))\n", + " test_md = pickle.load(open(PATH_WORK/'test.ids.df','rb'))\n", + " \n", + " return train_md, test_md" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def loadMetadata2(rerun=False):\n", + " if rerun:\n", + " train_dedup2 = pd.read_csv(PATH_WORK/'yuval'/'train_stage2.csv')\n", + "\n", + " assert train_dedup2.PatientID.nunique() == 752797\n", + " assert train_dedup2.PID.nunique() == 18938\n", + "\n", + " pids, folding = pickle.load(open(PATH_WORK/'yuval'/'PID_splits.pkl','rb'))\n", + " pids3, folding3 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_3_stage_2.pkl','rb'))\n", + " pids5, folding5 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_5_stage_2.pkl','rb'))\n", + "\n", + " assert np.array([len(folding3[i][1]) for i in range(3)]).sum() == 18938\n", + " assert np.array([len(folding3[i][0]) for i in range(3)]).sum() == 2*18938\n", + " \n", + " for i in range(3):\n", + " assert set(pids[folding[i][1]]).issubset(set(pids3[folding3[i][1]]))\n", + "\n", + " for i in range(3):\n", + " train_dedup2.loc[train_dedup2.PID.isin(pids3[folding3[i][1]]),'fold'] = i\n", + " assert train_dedup2.fold.isnull().sum() == 0\n", + "\n", + " for i in range(5):\n", + " train_dedup2.loc[train_dedup2.PID.isin(pids5[folding5[i][1]]),'fold5'] = i\n", + " assert train_dedup2.fold5.isnull().sum() == 0\n", + " \n", + " oof5 = pickle.load(open(PATH_DISK/'yuval'/'OOF_validation_image_ids_5.pkl','rb'))\n", + " for i in range(5):\n", + " assert set(oof5[i]).issubset(set(train_dedup2.loc[train_dedup2.fold5==i,'PatientID']))\n", + " \n", + " train_csv = pd.read_csv(PATH/'stage_2_train.csv')\n", + " train_csv = train_csv.loc[~train_csv.ID.duplicated()].sort_values('ID').reset_index(drop=True) \n", + " all_sop_ids = train_csv.ID.str.split('_').apply(lambda x: x[0]+'_'+x[1]).unique()\n", + " train_df = pd.DataFrame(train_csv.Label.values.reshape((-1,6)), columns = all_ich)\n", + " train_df['sop_id'] = all_sop_ids\n", + " \n", + " old_test = pd.read_csv(PATH_WORK/'test_md.csv')\n", + " old_test = old_test.drop(all_ich,axis=1)\n", + " old_test = old_test.join(train_df.set_index('sop_id'), on = 'SOPInstanceUID')\n", + " \n", + " train_md2 = pd.concat([pd.read_csv(PATH_WORK/'train_md.csv'), old_test],\\\n", + " axis=0,sort=False)\n", + " train_md2['img_id'] = train_md2.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n", + " \n", + " ids_df = train_dedup2[['fold','PatientID','fold5']]\n", + " ids_df.columns = ['fold','img_id','fold5']\n", + "\n", + " train_md2 = ids_df.join(train_md2.set_index('img_id'), on = 'img_id')\n", + " assert train_md2.test.sum() == 78545\n", + " \n", + " pickle.dump(train_md2, open(PATH_WORK/'train2.ids.df','wb'))\n", + " \n", + " test_md = pd.read_csv(PATH_WORK/'test2_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n", + " test_md['img_id'] = test_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n", + "\n", + " filename = PATH_WORK/'yuval/test_indexes_stage2.pkl'\n", + " test_ids = pickle.load(open(filename,'rb'))\n", + "\n", + " test_ids_df = pd.DataFrame(test_ids, columns = ['img_id'])\n", + " test_md = test_ids_df.join(test_md.set_index('img_id'), on = 'img_id')\n", + "\n", + " assert len(test_md.SeriesInstanceUID.unique()) == 3518\n", + "\n", + " pickle.dump(test_md, open(PATH_WORK/'test2.ids.df','wb'))\n", + "\n", + " else:\n", + " train_md2 = pickle.load(open(PATH_WORK/'train2.ids.df','rb'))\n", + " test_md = 
pickle.load(open(PATH_WORK/'test2.ids.df','rb'))\n", + " \n", + " return train_md2, test_md" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "def loadMetadata3(rerun=False):\n", + " if rerun:\n", + " train_md = pickle.load(open(PATH_WORK/'train.ids.df','rb'))\n", + " test_md = pickle.load(open(PATH_WORK/'test.ids.df','rb'))\n", + " train_md = pd.concat([train_md, test_md], axis=0,sort=False).reset_index(drop=True)\n", + " \n", + " train_md.fold = np.nan\n", + " train_md['PID'] = train_md.PatientID.str.split('_').apply(lambda x: x[1])\n", + " \n", + " pids, folding = pickle.load(open(PATH_WORK/'yuval'/'PID_splits.pkl','rb'))\n", + " pids3, folding3 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_3_stage_2.pkl','rb'))\n", + " pids5, folding5 = pickle.load(open(PATH_WORK/'yuval'/'PID_splits_5_stage_2.pkl','rb'))\n", + "\n", + " for i in range(3):\n", + " train_md.loc[train_md.PID.isin(pids3[folding3[i][1]]),'fold'] = i\n", + " assert train_md.fold.isnull().sum() == 0\n", + "\n", + " for i in range(5):\n", + " train_md.loc[train_md.PID.isin(pids5[folding5[i][1]]),'fold5'] = i\n", + " assert train_md.fold5.isnull().sum() == 0\n", + " \n", + " train_md.weights = train_md.weights.fillna(1)\n", + " \n", + " pickle.dump(train_md, open(PATH_WORK/'train3.ids.df','wb'))\n", + " \n", + " test_md = pd.read_csv(PATH_WORK/'test2_md.csv').sort_values(['SeriesInstanceUID','pos_idx1'])\n", + " test_md['img_id'] = test_md.SOPInstanceUID.str.split('_').apply(lambda x: x[1])\n", + "\n", + " filename = PATH_WORK/'yuval/test_indexes_stage2.pkl'\n", + " test_ids = pickle.load(open(filename,'rb'))\n", + "\n", + " test_ids_df = pd.DataFrame(test_ids, columns = ['img_id'])\n", + " test_md = test_ids_df.join(test_md.set_index('img_id'), on = 'img_id')\n", + "\n", + " assert len(test_md.SeriesInstanceUID.unique()) == 3518\n", + "\n", + " pickle.dump(test_md, open(PATH_WORK/'test3.ids.df','wb'))\n", + " else:\n", + " train_md = pickle.load(open(PATH_WORK/'train3.ids.df','rb'))\n", + " test_md = pickle.load(open(PATH_WORK/'test3.ids.df','rb'))\n", + " \n", + " return train_md, test_md" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [], + "source": [ + "#train_md, test_md = loadMetadata(True)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "#train_md = loadMetadata2(True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocessedData(dataset, folds = range(3), fold_col=None, do_test=True, do_test2=True, do_train=True):\n", + " assert dataset >= 6\n", + " \n", + " dataset_name, filename_add, filename_add2, feat_sz, ds_num, test_fix, dsft, focal = getDSParams(dataset)\n", + " \n", + " #if 'train_md' not in globals() or 'test_md' not in globals():\n", + " # train_md, test_md = loadMetadata(False)\n", + " \n", + " PATH_DS = PATH_DISK/'features/{}{}'.format(dataset_name,filename_add2)\n", + " if not PATH_DS.is_dir():\n", + " PATH_DS.mkdir()\n", + " if not (PATH_DS/'train').is_dir():\n", + " (PATH_DS/'train').mkdir()\n", + " if not (PATH_DS/'test').is_dir():\n", + " (PATH_DS/'test').mkdir()\n", + " if not (PATH_DS/'test2').is_dir():\n", + " (PATH_DS/'test2').mkdir()\n", + " \n", + " if fold_col is None:\n", + " fold_col = 'fold'\n", + " if len(folds) == 5:\n", + " fold_col = 
'fold5'\n", + " \n", + " for fold in folds:\n", + " \n", + " if do_train:\n", + " filename = PATH_DISK/'yuval'/\\\n", + " 'model_{}{}_version_{}_splits_{}type_features_train_tta{}_split_{}.pkl'\\\n", + " .format(dataset_name.replace('_5n','').replace('_5f','').replace('_5',''), \\\n", + " filename_add, dsft, focal, ds_num, fold)\n", + " feats = pickle.load(open(filename,'rb'))\n", + "\n", + " print('dataset',dataset,'fold',fold,'feats size', feats.shape)\n", + " if dataset <= 13:\n", + " train_md_loc = train_md.loc[~train_md.test]\n", + " else:\n", + " train_md_loc = train_md\n", + " \n", + " assert len(feats) == 4*len(train_md_loc)\n", + " assert feats.shape[1] == feat_sz\n", + " means = feats.mean(0,keepdim=True)\n", + " stds = feats.std(0,keepdim=True)\n", + "\n", + " feats = feats - means\n", + " feats = torch.where(stds > 0, feats/stds, feats)\n", + "\n", + " for i in range(4):\n", + " feats_sub1 = feats[torch.BoolTensor(np.arange(len(feats))%4 == i)]\n", + " feats_sub2 = feats_sub1[torch.BoolTensor(train_md_loc[fold_col] != fold)]\n", + " pickle.dump(feats_sub2, open(PATH_DS/'train/train.f{}.a{}'.format(fold,i),'wb'))\n", + "\n", + " feats_sub2 = feats_sub1[torch.BoolTensor(train_md_loc[fold_col] == fold)]\n", + " pickle.dump(feats_sub2, open(PATH_DS/'train/valid.f{}.a{}'.format(fold,i),'wb'))\n", + "\n", + " if i==0:\n", + " black_feats = feats_sub1[torch.BoolTensor(train_md_loc.img_id == all_black)].squeeze()\n", + " pickle.dump(black_feats, open(PATH_DS/'train/black.f{}'.format(fold),'wb'))\n", + " \n", + " if do_test:\n", + " filename = PATH_DISK/'yuval'/\\\n", + " 'model_{}{}_version_{}_splits_{}type_features_test{}{}_split_{}.pkl'\\\n", + " .format(dataset_name.replace('_5n','').replace('_5f','').replace('_5',''),\n", + " filename_add,dsft,focal,ds_num,test_fix,fold)\n", + " feats = pickle.load(open(filename,'rb'))\n", + "\n", + " assert len(feats) == 8*len(test_md)\n", + " assert feats.shape[1] == feat_sz\n", + "\n", + " feats = feats - means\n", + " feats = torch.where(stds > 0, feats/stds, feats)\n", + "\n", + " for i in range(8):\n", + " feats_sub = feats[torch.BoolTensor(np.arange(len(feats))%8 == i)]\n", + " pickle.dump(feats_sub, open(PATH_DS/'test/test.f{}.a{}'.format(fold,i),'wb'))\n", + " assert len(feats_sub) == len(test_md)\n", + " \n", + " if do_test2:\n", + " filename = PATH_DISK/'yuval'/\\\n", + " 'model_{}{}_version_{}_splits_{}type_features_test_stage2_split_{}.pkl'\\\n", + " .format(dataset_name.replace('_5n','').replace('_5f','').replace('_5',''),\n", + " filename_add,dsft,focal,fold)\n", + " feats = pickle.load(open(filename,'rb'))\n", + "\n", + " assert len(feats) == 8*len(test_md)\n", + " assert feats.shape[1] == feat_sz\n", + "\n", + " feats = feats - means\n", + " feats = torch.where(stds > 0, feats/stds, feats)\n", + "\n", + " for i in range(8):\n", + " feats_sub = feats[torch.BoolTensor(np.arange(len(feats))%8 == i)]\n", + " pickle.dump(feats_sub, open(PATH_DS/'test2/test.f{}.a{}'.format(fold,i),'wb'))\n", + " assert len(feats_sub) == len(test_md)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [], + "source": [ + "if False:\n", + " path = PATH_WORK/'features/densenet161_v3/train/ID_992b567eb6'\n", + " 
black_feats = pickle.load(open(path,'rb'))[41]" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [], + "source": [ + "class RSNA_DataSet(D.Dataset):\n", + " def __init__(self, metadata, dataset, mode='train', bs=None, fold=0):\n", + " \n", + " super(RSNA_DataSet, self).__init__()\n", + " \n", + " dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n", + " num_folds = getNFolds(dataset)\n", + " \n", + " self.dataset_name = dataset_name\n", + " self.filename_add2 = filename_add2\n", + " \n", + " folds_col = 'fold'\n", + " if num_folds == 5:\n", + " folds_col = 'fold5'\n", + " \n", + " if (dataset <= 13) and (mode in ['train','valid']) and TRAIN_ON_STAGE_1:\n", + " if mode == 'train':\n", + " self.test_mask = torch.BoolTensor((metadata.loc[metadata.test, folds_col] != fold).values)\n", + " else:\n", + " self.test_mask = torch.BoolTensor((metadata.loc[metadata.test, folds_col] == fold).values)\n", + " \n", + " if mode == 'train':\n", + " md = metadata.loc[metadata[folds_col] != fold].copy().reset_index(drop=True)\n", + " elif mode == 'valid':\n", + " md = metadata.loc[metadata[folds_col] == fold].copy().reset_index(drop=True)\n", + " else:\n", + " md = metadata.copy().reset_index(drop=True)\n", + " \n", + " series = np.sort(md.SeriesInstanceUID.unique())\n", + " md = md.set_index('SeriesInstanceUID', drop=True)\n", + " \n", + " samples_add = 0\n", + " if (mode != 'train') and not DATA_SMALL:\n", + " batch_num = -((-len(series))//(bs*MAX_DEVICES))\n", + " samples_add = batch_num*bs*MAX_DEVICES - len(series)\n", + " print('adding dummy serieses', samples_add)\n", + " \n", + " #self.records = df.to_records(index=False)\n", + " self.mode = mode\n", + " self.real = np.concatenate([np.repeat(True,len(series)),np.repeat(False,samples_add)])\n", + " self.series = np.concatenate([series, random.sample(list(series),samples_add)])\n", + " self.metadata = md\n", + " self.dataset = dataset\n", + " self.fold = fold\n", + " \n", + " print('DataSet', dataset, mode, 'size', len(self.series), 'fold', fold)\n", + " \n", + " path = PATH_DISK/'features/{}{}/train/black.f{}'.format(self.dataset_name, self.filename_add2, fold)\n", + " self.black_feats = pickle.load(open(path,'rb')).squeeze()\n", + " \n", + " if WEIGHTED and (self.mode == 'train'):\n", + " tt = self.metadata['weights'].groupby(self.metadata.index).mean()\n", + " self.weights = tt.loc[self.series].values\n", + " print(pd.value_counts(self.weights))\n", + " \n", + " def setFeats(self, anum, epoch=0):\n", + " def getAPath(an,mode):\n", + " folder = 'test2' if mode == 'test' else 'train'\n", + " return PATH_DISK/'features/{}{}/{}/{}.f{}.a{}'\\\n", + " .format(self.dataset_name,self.filename_add2,folder,mode,self.fold,an)\n", + " \n", + " def getAPathFeats(mode, sz, mask=None):\n", + " max_a = 8 if mode == 'test' else 4\n", + " feats2 = torch.stack([pickle.load(open(getAPath(an,mode),'rb')) for an in range(max_a)])\n", + " if mask is not None:\n", + " feats2 = feats2[:,mask]\n", + " assert feats2.shape[1] == sz\n", + " feats = feats2[torch.randint(max_a,(sz,)), torch.arange(sz, dtype=torch.long)].squeeze()\n", + " return feats\n", + " \n", + " if self.dataset == 1: return\n", + " print('setFeats, augmentation', anum)\n", + " self.anum = anum\n", + " sz = len(self.metadata)\n", + " \n", + " assert anum <= 0\n", + " \n", + " if anum == -1:\n", + " if (self.dataset <= 13) and (self.mode in ['train','valid']) and TRAIN_ON_STAGE_1:\n", + " feats = torch.cat([getAPathFeats(self.mode, sz - 
self.metadata.test.sum()), \n", + " getAPathFeats('test', self.metadata.test.sum(), self.test_mask)] ,axis=0)\n", + " else:\n", + " feats = getAPathFeats(self.mode, sz)\n", + " else:\n", + " if (self.dataset <= 13) and (self.mode in ['train','valid']) and TRAIN_ON_STAGE_1:\n", + " feats = torch.cat([pickle.load(open(getAPath(anum,self.mode),'rb')),\n", + " pickle.load(open(getAPath(anum,'test'),'rb'))[self.test_mask]], axis=0)\n", + " else:\n", + " feats = pickle.load(open(getAPath(anum,self.mode),'rb'))\n", + " \n", + " self.feats = feats\n", + " self.epoch = epoch\n", + " assert len(feats) == sz\n", + " \n", + " def __getitem__(self, index):\n", + " \n", + " series_id = self.series[index]\n", + " #df = self.metadata.loc[self.metadata.SeriesInstanceUID == series_id].reset_index(drop=True)\n", + " df = self.metadata.loc[series_id].reset_index(drop=True)\n", + " \n", + " if self.dataset >= 6:\n", + " feats = self.feats[torch.BoolTensor(self.metadata.index.values == series_id)]\n", + " else: assert False\n", + " \n", + " order = np.argsort(df.pos_idx1.values)\n", + " df = df.sort_values(['pos_idx1'])\n", + " feats = feats[torch.LongTensor(order)]\n", + " \n", + " #if WEIGHTED:\n", + " # non_black = torch.Tensor(df['weights'])\n", + " #else:\n", + " non_black = torch.ones(len(feats))\n", + " \n", + " feats = torch.cat([feats, torch.Tensor(df[meta_cols].values)], dim=1)\n", + " feats_le = torch.LongTensor(df[cols_le].values)\n", + " \n", + " target = torch.Tensor(df[all_ich].values)\n", + " \n", + " PAD = 12\n", + " \n", + " np.random.seed(1234 + index + 10000*self.epoch)\n", + " offset = np.random.randint(0, 61 - feats.shape[0])\n", + " \n", + " #offset = 0\n", + " bk_add = 0\n", + " top_pad = PAD + offset\n", + " if top_pad > 0:\n", + " dummy_row = torch.cat([self.black_feats, torch.Tensor(df.head(1)[meta_cols].values).squeeze()])\n", + " feats = torch.cat([dummy_row.repeat(top_pad,1), feats], dim=0)\n", + " feats_le = torch.cat([torch.LongTensor(df.head(1)[cols_le].values).squeeze().repeat(top_pad,1), feats_le])\n", + " if offset > 0:\n", + " non_black = torch.cat([bk_add + torch.zeros(offset), non_black])\n", + " target = torch.cat([torch.zeros((offset, len(all_ich))), target], dim=0)\n", + " bot_pad = 60 - len(df) - offset + PAD\n", + " if bot_pad > 0:\n", + " dummy_row = torch.cat([self.black_feats, torch.Tensor(df.tail(1)[meta_cols].values).squeeze()])\n", + " feats = torch.cat([feats, dummy_row.repeat(bot_pad,1)], dim=0)\n", + " feats_le = torch.cat([feats_le, torch.LongTensor(df.tail(1)[cols_le].values).squeeze().repeat(bot_pad,1)])\n", + " if (60 - len(df) - offset) > 0:\n", + " non_black = torch.cat([non_black, bk_add + torch.zeros(60 - len(df) - offset)])\n", + " target = torch.cat([target, torch.zeros((60 - len(df) - offset, len(all_ich)))], dim=0)\n", + " \n", + " assert feats_le.shape[0] == (60 + 2*PAD)\n", + " assert feats.shape[0] == (60 + 2*PAD)\n", + " assert target.shape[0] == 60\n", + " \n", + " idx = index\n", + " if not self.real[index]: idx = -1\n", + " \n", + " if self.mode == 'train':\n", + " return feats, feats_le, target, non_black\n", + " else:\n", + " return feats, feats_le, target, idx, offset\n", + " \n", + " def __len__(self):\n", + " return len(self.series) if not DATA_SMALL else int(0.01*len(self.series))" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [], + "source": [ + "def getCurrentBatch(dataset, fold=0, ver=VERSION):\n", + " sel_batch = None\n", + " for filename in os.listdir(PATH_DISK/'models'):\n", + " 
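# checkpoint files are named model.b{batch}.f{fold}.d{dataset}.v{version} (see modelFileName below)\n",
+ "        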
splits = filename.split('.')\n", + " if int(splits[2][1]) != fold: continue\n", + " if int(splits[3][1:]) != dataset: continue\n", + " if int(splits[4][1:]) != ver: continue\n", + " if sel_batch is None:\n", + " sel_batch = int(splits[1][1:])\n", + " else:\n", + " sel_batch = max(sel_batch, int(splits[1][1:]))\n", + " return sel_batch\n", + "\n", + "def modelFileName(dataset, fold=0, batch = 1, return_last = False, return_next = False, ver=VERSION):\n", + " sel_batch = batch\n", + " if return_last or return_next:\n", + " sel_batch = getCurrentBatch(fold=fold, dataset=dataset, ver=ver)\n", + " if return_last and sel_batch is None:\n", + " return None\n", + " if return_next:\n", + " if sel_batch is None: sel_batch = 1\n", + " else: sel_batch += 1\n", + " \n", + " return 'model.b{}.f{}.d{}.v{}'.format(sel_batch, fold, dataset, ver)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class BCEWithLogitsLoss(nn.Module):\n", + " def __init__(self, weight=None):\n", + " super().__init__()\n", + " self.weight = weight\n", + " \n", + " def forward(self, input, target, batch_weights = None, focal=0):\n", + " if focal == 0:\n", + " loss = (torch.log(1+torch.exp(input)) - target*input)*self.weight\n", + " else:\n", + " loss = torch.pow(1+torch.exp(input), -focal) * \\\n", + " (((1-target)*torch.exp(focal*input) + target) * torch.log(1+torch.exp(input)) - target*input)\n", + " loss = loss*self.weight\n", + " if batch_weights is not None:\n", + " loss = batch_weights*loss\n", + " return loss.mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [], + "source": [ + "def bn_drop_lin(n_in:int, n_out:int, bn:bool=True, p:float=0., actn=None):\n", + " \"Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`.\"\n", + " layers = [nn.BatchNorm1d(n_in)] if bn else []\n", + " if p != 0: layers.append(nn.Dropout(p))\n", + " layers.append(nn.Linear(n_in, n_out))\n", + " if actn is not None: layers.append(actn)\n", + " return layers" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [], + "source": [ + "def noop(x): return x\n", + "act_fun = nn.ReLU(inplace=True)\n", + "\n", + "def conv_layer(ni, nf, ks=3, act=True):\n", + " bn = nn.BatchNorm1d(nf)\n", + " layers = [nn.Conv1d(ni, nf, ks), bn]\n", + " if act: layers.append(act_fun)\n", + " return nn.Sequential(*layers)\n", + "\n", + "class ResBlock(nn.Module):\n", + " def __init__(self, ni, nh):\n", + " super().__init__()\n", + " layers = [conv_layer(ni, nh, 1),\n", + " conv_layer(nh, nh, 5, act=False)]\n", + " self.convs = nn.Sequential(*layers)\n", + " self.idconv = noop if (ni == nh) else conv_layer(ni, nh, 1, act=False)\n", + " \n", + " def forward(self, x): return act_fun(self.convs(x) + self.idconv(x[:,:,2:-2]))" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [], + "source": [ + "class ResNetModel(nn.Module):\n", + " def __init__(self, n_cont:int, feat_sz=2208):\n", + " super().__init__()\n", + " \n", + " self.le_sz = 10\n", + " assert self.le_sz == (len(cols_le))\n", + " le_in_sizes = np.array([5,5,7,4,4,11,4,6,3,3])\n", + " le_out_sizes = np.array([3,3,4,2,2,6,2,4,2,2])\n", + " le_out_sz = le_out_sizes.sum()\n", + " self.embeddings = nn.ModuleList([embedding(le_in_sizes[i], le_out_sizes[i]) for i in range(self.le_sz)])\n", + " \n", + " 
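# the 2D conv below collapses the per-slice (CNN features + covariates + embeddings) axis; the 1D ResBlocks then share context across neighbouring slices\n",
+ "        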
self.feat_sz = feat_sz\n", + " \n", + " self.n_cont = n_cont\n", + " scale = 4\n", + " \n", + " self.conv2D = nn.Conv2d(1,scale*16,(feat_sz + n_cont + le_out_sz,1))\n", + " self.bn1 = nn.BatchNorm1d(scale*16)\n", + " \n", + " self.res1 = ResBlock(scale*16,scale*16)\n", + " self.res2 = ResBlock(scale*16,scale*8)\n", + " \n", + " self.res3 = ResBlock(scale*24,scale*16)\n", + " self.res4 = ResBlock(scale*16,scale*8)\n", + " \n", + " self.res5 = ResBlock(scale*32,scale*16)\n", + " self.res6 = ResBlock(scale*16,scale*8)\n", + " \n", + " self.conv1D = nn.Conv1d(scale*40,6,1)\n", + " \n", + " def forward(self, x, x_le, x_le_mix = None, lambd = None) -> torch.Tensor:\n", + " x_le = [e(x_le[:,:,i]) for i,e in enumerate(self.embeddings)]\n", + " x_le = torch.cat(x_le, 2)\n", + " \n", + " x = torch.cat([x, x_le], 2)\n", + " x = x.transpose(1,2)\n", + " \n", + " #x = torch.cat([x[:,:self.feat_sz],self.bn_cont(x[:,self.feat_sz:])], dim=1)\n", + " x = x.reshape(x.shape[0],1,x.shape[1],x.shape[2])\n", + " \n", + " x = self.conv2D(x).squeeze()\n", + " x = self.bn1(x)\n", + " x = act_fun(x)\n", + " \n", + " x2 = self.res1(x)\n", + " x2 = self.res2(x2)\n", + " \n", + " x3 = torch.cat([x[:,:,4:-4], x2], 1)\n", + " \n", + " x3 = self.res3(x3)\n", + " x3 = self.res4(x3)\n", + " \n", + " x4 = torch.cat([x[:,:,8:-8], x2[:,:,4:-4], x3], 1)\n", + " \n", + " x4 = self.res5(x4)\n", + " x4 = self.res6(x4)\n", + " \n", + " x5 = torch.cat([x[:,:,12:-12], x2[:,:,8:-8], x3[:,:,4:-4], x4], 1)\n", + " x5 = self.conv1D(x5)\n", + " x5 = x5.transpose(1,2)\n", + " \n", + " return x5" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [], + "source": [ + "def trunc_normal_(x, mean:float=0., std:float=1.):\n", + " \"Truncated normal initialization.\"\n", + " # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12\n", + " return x.normal_().fmod_(2).mul_(std).add_(mean)" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [], + "source": [ + "def embedding(ni:int,nf:int) -> nn.Module:\n", + " \"Create an embedding layer.\"\n", + " emb = nn.Embedding(ni, nf)\n", + " # See https://arxiv.org/abs/1711.09160\n", + " with torch.no_grad(): trunc_normal_(emb.weight, std=0.01)\n", + " return emb" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": {}, + "outputs": [], + "source": [ + "class TabularModel(nn.Module):\n", + " \"Basic model for tabular data.\"\n", + " def __init__(self, n_cont:int, feat_sz=2208, fc_drop_p=0.3):\n", + " super().__init__()\n", + " self.le_sz = 10\n", + " assert len(cols_le) == self.le_sz\n", + " le_in_sizes = np.array([5,5,7,4,4,11,4,6,3,3])\n", + " le_out_sizes = np.array([3,3,4,2,2,6,2,4,2,2])\n", + " le_out_sz = le_out_sizes.sum()\n", + " self.embeddings = nn.ModuleList([embedding(le_in_sizes[i], le_out_sizes[i]) for i in range(self.le_sz)])\n", + " #self.bn_cont = nn.BatchNorm1d(feat_sz + n_cont)\n", + " \n", + " self.feat_sz = feat_sz\n", + " self.bn_cont = nn.BatchNorm1d(n_cont + le_out_sz)\n", + " self.n_cont = n_cont\n", + " self.fc_drop = nn.Dropout(p=fc_drop_p)\n", + " self.relu = nn.ReLU(inplace=True)\n", + " \n", + " scale = 4\n", + " \n", + " self.conv2D_1 = nn.Conv2d(1,16*scale,(feat_sz + n_cont + le_out_sz,1))\n", + " self.conv2D_2 = nn.Conv2d(1,16*scale,(feat_sz + n_cont + le_out_sz,5))\n", + " self.bn_cont1 = nn.BatchNorm1d(32*scale)\n", + " self.conv1D_1 = nn.Conv1d(32*scale,16*scale,3)\n", + " self.conv1D_3 = nn.Conv1d(32*scale,16*scale,5,dilation=5)\n", + " 
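# conv1D_2 maps down to 6 outputs per slice, one logit per class in all_ich\n",
+ "        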
self.conv1D_2 = nn.Conv1d(32*scale,6,3)\n", + " self.bn_cont2 = nn.BatchNorm1d(32*scale)\n", + " self.bn_cont3 = nn.BatchNorm1d(6)\n", + "\n", + " self.conv1D_4 = nn.Conv1d(32*scale,32*scale,3)\n", + " self.bn_cont4 = nn.BatchNorm1d(32*scale)\n", + "\n", + " def forward(self, x, x_le, x_le_mix = None, lambd = None) -> torch.Tensor:\n", + " x_le = [e(x_le[:,:,i]) for i,e in enumerate(self.embeddings)]\n", + " x_le = torch.cat(x_le, 2)\n", + " \n", + " if MIXUP and x_le_mix is not None:\n", + " x_le_mix = [e(x_le_mix[:,:,i]) for i,e in enumerate(self.embeddings)]\n", + " x_le_mix = torch.cat(x_le_mix, 2)\n", + " x_le = lambd * x_le + (1-lambd) * x_le_mix\n", + " \n", + " #assert torch.isnan(x_le).any().cpu() == False\n", + " x = torch.cat([x, x_le], 2)\n", + " x = x.transpose(1,2)\n", + " \n", + " #x = torch.cat([x[:,:self.feat_sz],self.bn_cont(x[:,self.feat_sz:])], dim=1)\n", + " \n", + " x = x.reshape(x.shape[0],1,x.shape[1],x.shape[2])\n", + " x = self.fc_drop(x)\n", + " x = torch.cat([self.conv2D_1(x[:,:,:,2:(-2)]).squeeze(), \n", + " self.conv2D_2(x).squeeze()], dim=1)\n", + " x = self.relu(x)\n", + " x = self.bn_cont1(x)\n", + " x = self.fc_drop(x)\n", + " \n", + " x = torch.cat([self.conv1D_1(x[:,:,9:(-9)]),\n", + " self.conv1D_3(x)], dim=1)\n", + " x = self.relu(x)\n", + " x = self.bn_cont2(x)\n", + " x = self.fc_drop(x)\n", + " \n", + " x = self.conv1D_4(x)\n", + " x = self.relu(x)\n", + " x = self.bn_cont4(x)\n", + " \n", + " x = self.conv1D_2(x)\n", + " x = x.transpose(1,2)\n", + " \n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [], + "source": [ + "def train_loop_fn(model, loader, device, context = None):\n", + " \n", + " if CLOUD and (not CLOUD_SINGLE):\n", + " tlen = len(loader._loader._loader)\n", + " OUT_LOSS = 1000\n", + " OUT_TIME = 4\n", + " generator = loader\n", + " device_num = int(str(device)[-1])\n", + " dataset = loader._loader._loader.dataset\n", + " else:\n", + " tlen = len(loader)\n", + " OUT_LOSS = 50\n", + " OUT_TIME = 50\n", + " generator = enumerate(loader)\n", + " device_num = 1\n", + " dataset = loader.dataset\n", + " \n", + " #print('Start training {}'.format(device), 'batches', tlen)\n", + " \n", + " criterion = BCEWithLogitsLoss(weight = torch.Tensor(class_weights).to(device))\n", + " optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(0.9, 0.99))\n", + " \n", + " model.train()\n", + " \n", + " if CLOUD and TPU:\n", + " tracker = xm.RateTracker()\n", + "\n", + " tloss = 0\n", + " tloss_count = 0\n", + " \n", + " st = time.time()\n", + " mixup_collected = False\n", + " x_le_mix = None\n", + " lambd = None\n", + " for i, (x, x_le, y, non_black) in generator:\n", + " if (not CLOUD) or CLOUD_SINGLE:\n", + " x = x.to(device)\n", + " x_le = x_le.to(device)\n", + " y = y.to(device)\n", + " non_black = non_black.to(device)\n", + " \n", + " if MIXUP:\n", + " if mixup_collected:\n", + " lambd = np.random.beta(0.4, 0.4, y.size(0))\n", + " lambd = torch.Tensor(lambd).to(device)[:,None,None]\n", + " #shuffle = torch.randperm(y.size(0)).to(device)\n", + " x = lambd * x + (1-lambd) * x_mix #x[shuffle]\n", + " #x_le = lambd * x_le + (1-lambd) * x_le_mix #x[shuffle]\n", + " mixup_collected = False\n", + " else:\n", + " x_mix = x\n", + " x_le_mix = x_le\n", + " y_mix = y\n", + " mixup_collected = True\n", + " continue\n", + " \n", + " optimizer.zero_grad()\n", + " output = 
model(x, x_le, x_le_mix, lambd)\n", + " \n", + " if MIXUP:\n", + " if NO_BLACK_LOSS:\n", + " loss = criterion(output, y, lambd*non_black[:,:,None]) \\\n", + " + criterion(output, y_mix, (1-lambd)*non_black[:,:,None])\n", + " else:\n", + " loss = criterion(output, y, lambd) + criterion(output, y_mix, 1-lambd)\n", + " del x_mix, y_mix\n", + " else:\n", + " if NO_BLACK_LOSS:\n", + " loss = criterion(output, y, non_black[:,:,None])\n", + " else:\n", + " loss = criterion(output, y)\n", + " \n", + " loss.backward()\n", + " \n", + " tloss += len(y)*loss.cpu().detach().item()\n", + " tloss_count += len(y)\n", + " \n", + " if (CLOUD or CLOUD_SINGLE) and TPU:\n", + " xm.optimizer_step(optimizer)\n", + " if CLOUD_SINGLE:\n", + " xm.mark_step()\n", + " else:\n", + " optimizer.step()\n", + " \n", + " if CLOUD and TPU:\n", + " tracker.add(len(y))\n", + " \n", + " st_passed = time.time() - st\n", + " if (i+1)%OUT_TIME == 0 and device_num == 1:\n", + " #print(torch_xla._XLAC._xla_metrics_report())\n", + " print('Batch {} device: {} time passed: {:.3f} time per batch: {:.3f}'\n", + " .format(i+1, device, st_passed, st_passed/(i+1)))\n", + " \n", + " del loss, output, y, x, x_le\n", + " \n", + " return tloss, tloss_count" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [], + "source": [ + "@torch.no_grad()\n", + "def val_loop_fn(model, loader, device, context = None):\n", + " \n", + " if CLOUD and (not CLOUD_SINGLE):\n", + " tlen = len(loader._loader._loader)\n", + " OUT_LOSS = 1000\n", + " OUT_TIME = 4\n", + " generator = loader\n", + " device_num = int(str(device)[-1])\n", + " else:\n", + " tlen = len(loader)\n", + " OUT_LOSS = 1000\n", + " OUT_TIME = 50\n", + " generator = enumerate(loader)\n", + " device_num = 1\n", + " \n", + " #print('Start validating {}'.format(device), 'batches', tlen)\n", + " \n", + " st = time.time()\n", + " model.eval()\n", + " \n", + " results = []\n", + " indices = []\n", + " offsets = []\n", + " \n", + " for i, (x, x_le, y, idx, offset) in generator:\n", + " \n", + " if (not CLOUD) or CLOUD_SINGLE:\n", + " x = x.to(device)\n", + " x_le = x_le.to(device)\n", + " \n", + " output = model(x, x_le)\n", + " assert torch.isnan(output).any().cpu() == False\n", + " output = torch.sigmoid(output)\n", + " assert torch.isnan(output).any().cpu() == False\n", + " \n", + " mask = (idx >= 0)\n", + " results.append(output[mask].cpu().detach().numpy())\n", + " indices.append(idx[mask].cpu().detach().numpy())\n", + " offsets.append(offset[mask].cpu().detach().numpy())\n", + " \n", + " st_passed = time.time() - st\n", + " if (i+1)%OUT_TIME == 0 and device_num == 1:\n", + " print('Batch {} device: {} time passed: {:.3f} time per batch: {:.3f}'\n", + " .format(i+1, device, st_passed, st_passed/(i+1)))\n", + " \n", + " del output, y, x, x_le, idx, offset\n", + " \n", + " results = np.concatenate(results)\n", + " indices = np.concatenate(indices)\n", + " offsets = np.concatenate(offsets)\n", + " \n", + " return results, indices, offsets" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [], + "source": [ + "@torch.no_grad()\n", + "def test_loop_fn(model, loader, device, context = None):\n", + " \n", + " if CLOUD and (not CLOUD_SINGLE):\n", + " tlen = len(loader._loader._loader)\n", + " OUT_LOSS = 1000\n", + " OUT_TIME = 100\n", + " generator = loader\n", + " device_num = int(str(device)[-1])\n", + " else:\n", + " tlen = len(loader)\n", + " OUT_LOSS = 1000\n", + " OUT_TIME = 10\n", + " generator = 
enumerate(loader)\n", + " device_num = 1\n", + " \n", + " #print('Start testing {}'.format(device), 'batches', tlen)\n", + " \n", + " st = time.time()\n", + " model.eval()\n", + " \n", + " results = []\n", + " indices = []\n", + " offsets = []\n", + " \n", + " for i, (x, x_le, y, idx, offset) in generator:\n", + " \n", + " if (not CLOUD) or CLOUD_SINGLE:\n", + " x = x.to(device)\n", + " x_le = x_le.to(device)\n", + " \n", + " output = torch.sigmoid(model(x, x_le))\n", + " \n", + " mask = (idx >= 0)\n", + " results.append(output[mask].cpu().detach().numpy())\n", + " indices.append(idx[mask].cpu().detach().numpy())\n", + " offsets.append(offset[mask].cpu().detach().numpy())\n", + " \n", + " st_passed = time.time() - st\n", + " if (i+1)%OUT_TIME == 0 and device_num == 1:\n", + " print('B{} -> time passed: {:.3f} time per batch: {:.3f}'.format(i+1, st_passed, st_passed/(i+1)))\n", + " \n", + " del output, x, y, idx, offset\n", + " \n", + " return np.concatenate(results), np.concatenate(indices), np.concatenate(offsets)" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [], + "source": [ + "def train_one(dataset, weight=None, load_model=True, epochs=1, bs=100, fold=0, init_ver=None):\n", + " \n", + " st0 = time.time()\n", + " dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n", + " \n", + " cur_epoch = getCurrentBatch(fold=fold, dataset=dataset)\n", + " if cur_epoch is None: cur_epoch = 0\n", + " print('completed epochs:', cur_epoch, 'starting now:', epochs)\n", + " \n", + " setSeeds(SEED + cur_epoch)\n", + " \n", + " assert train_md.weights.isnull().sum() == 0\n", + " \n", + " if dataset >= 6:\n", + " trn_ds = RSNA_DataSet(train_md, mode='train', bs=bs, fold=fold, dataset=dataset)\n", + " val_ds = RSNA_DataSet(train_md, mode='valid', bs=bs, fold=fold, dataset=dataset)\n", + " else: assert False\n", + " val_ds.setFeats(0)\n", + " \n", + " if WEIGHTED:\n", + " print('WeightedRandomSampler')\n", + " sampler = D.sampler.WeightedRandomSampler(trn_ds.weights, len(trn_ds))\n", + " loader = D.DataLoader(trn_ds, num_workers=NUM_WORKERS, batch_size=bs, \n", + " shuffle=False, drop_last=True, sampler=sampler)\n", + " else:\n", + " loader = D.DataLoader(trn_ds, num_workers=NUM_WORKERS, batch_size=bs, \n", + " shuffle=True, drop_last=True)\n", + " loader_val = D.DataLoader(val_ds, num_workers=NUM_WORKERS, batch_size=bs, \n", + " shuffle=True)\n", + " print('dataset train:', len(trn_ds), 'valid:', len(val_ds), 'loader train:', len(loader), 'valid:', len(loader_val))\n", + " \n", + " #model = TabularModel(n_cont = len(meta_cols), feat_sz=feat_sz, fc_drop_p=0)\n", + " model = ResNetModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n", + " \n", + " model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset)\n", + " \n", + " if (model_file_name is None) and (init_ver is not None):\n", + " model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset, ver=init_ver)\n", + " \n", + " if model_file_name is not None:\n", + " print('loading model', model_file_name)\n", + " state_dict = torch.load(PATH_DISK/'models'/model_file_name)\n", + " model.load_state_dict(state_dict)\n", + " else:\n", + " print('starting from scratch')\n", + " \n", + " if (not CLOUD) or CLOUD_SINGLE:\n", + " model = model.to(device)\n", + " else:\n", + " model_parallel = dp.DataParallel(model, device_ids=devices)\n", + " \n", + " loc_data = val_ds.metadata.copy()\n", + " \n", + " if DATA_SMALL:\n", + " val_sz = len(val_ds)\n", + " 
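# DATA_SMALL truncates __len__, so keep only the series the loader will actually serve\n",
+ "        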
val_series = val_ds.series[:val_sz]\n", + " loc_data = loc_data.loc[loc_data.index.isin(val_series)]\n", + " \n", + " series_counts = loc_data.index.value_counts()\n", + " \n", + " loc_data['orig_idx'] = np.arange(len(loc_data))\n", + " loc_data = loc_data.sort_values(['SeriesInstanceUID','pos_idx1'])\n", + " loc_data['my_order'] = np.arange(len(loc_data))\n", + " loc_data = loc_data.sort_values(['orig_idx'])\n", + " \n", + " ww_val = loc_data.weights\n", + " \n", + " for i in range(cur_epoch+1, cur_epoch+epochs+1):\n", + " st = time.time()\n", + " \n", + " #trn_ds.setFeats((i-1) % 4)\n", + " trn_ds.setFeats(-1, epoch=i)\n", + " \n", + " if CLOUD and (not CLOUD_SINGLE):\n", + " results = model_parallel(train_loop_fn, loader)\n", + " tloss, tloss_count = np.stack(results).sum(0)\n", + " state_dict = model_parallel._models[0].state_dict()\n", + " else:\n", + " tloss, tloss_count = train_loop_fn(model, loader, device)\n", + " state_dict = model.state_dict()\n", + " \n", + " state_dict = {k:v.to('cpu') for k,v in state_dict.items()}\n", + " tr_ll = tloss / tloss_count\n", + " \n", + " train_time = time.time()-st\n", + " \n", + " model_file_name = modelFileName(return_next=True, fold=fold, dataset=dataset)\n", + " if not DATA_SMALL:\n", + " torch.save(state_dict, PATH_DISK/'models'/model_file_name)\n", + " \n", + " st = time.time()\n", + " if CLOUD and (not CLOUD_SINGLE):\n", + " results = model_parallel(val_loop_fn, loader_val)\n", + " predictions = np.concatenate([results[i][0] for i in range(MAX_DEVICES)])\n", + " indices = np.concatenate([results[i][1] for i in range(MAX_DEVICES)])\n", + " offsets = np.concatenate([results[i][2] for i in range(MAX_DEVICES)])\n", + " else:\n", + " predictions, indices, offsets = val_loop_fn(model, loader_val, device)\n", + " \n", + " predictions = predictions[np.argsort(indices)]\n", + " offsets = offsets[np.argsort(indices)]\n", + " assert len(predictions) == len(loc_data.index.unique())\n", + " assert len(predictions) == len(offsets)\n", + " assert np.all(indices[np.argsort(indices)] == np.array(range(len(predictions))))\n", + " \n", + " #val_results = np.zeros((len(loc_data),6))\n", + " val_results = []\n", + " for k, series in enumerate(np.sort(loc_data.index.unique())):\n", + " cnt = series_counts[series]\n", + " #mask = loc_data.SeriesInstanceUID == series\n", + " assert (offsets[k] + cnt) <= 60\n", + " #val_results[mask] = predictions[k,offsets[k]:(offsets[k] + cnt)]\n", + " val_results.append(predictions[k,offsets[k]:(offsets[k] + cnt)])\n", + " \n", + " val_results = np.concatenate(val_results)\n", + " assert np.isnan(val_results).sum() == 0\n", + " val_results = val_results[loc_data.my_order]\n", + " assert np.isnan(val_results).sum() == 0\n", + " assert len(val_results) == len(loc_data)\n", + " \n", + " lls = [log_loss(loc_data.loc[~loc_data.test, all_ich[k]].values, val_results[~loc_data.test,k], \\\n", + " eps=1e-7, labels=[0,1]) for k in range(6)]\n", + " ll = (class_weights * np.array(lls)).mean()\n", + " cor = np.corrcoef(loc_data.loc[~loc_data.test,all_ich].values.reshape(-1), \\\n", + " val_results[~loc_data.test].reshape(-1))[0,1]\n", + " auc = roc_auc_score(loc_data.loc[~loc_data.test, all_ich].values.reshape(-1), \\\n", + " val_results[~loc_data.test].reshape(-1))\n", + " \n", + " lls_w = [log_loss(loc_data.loc[~loc_data.test, all_ich[k]].values, val_results[~loc_data.test,k], eps=1e-7, \n", + " labels=[0,1], sample_weight=ww_val[~loc_data.test]) for k in range(6)]\n", + " ll_w = (class_weights * np.array(lls_w)).mean()\n", + " \n", + " 
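# when stage-1 test data is folded into training (TRAIN_ON_STAGE_1), score it separately as ll2 / ll2_w\n",
+ "        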
if TRAIN_ON_STAGE_1:\n", + " lls = [log_loss(loc_data.loc[loc_data.test, all_ich[k]].values, val_results[loc_data.test,k], \\\n", + " eps=1e-7, labels=[0,1]) for k in range(6)]\n", + " ll2 = (class_weights * np.array(lls)).mean()\n", + "\n", + " lls_w = [log_loss(loc_data.loc[loc_data.test, all_ich[k]].values, val_results[loc_data.test,k], eps=1e-7, \n", + " labels=[0,1], sample_weight=ww_val[loc_data.test]) for k in range(6)]\n", + " ll2_w = (class_weights * np.array(lls_w)).mean()\n", + " else:\n", + " ll2 = 0\n", + " ll2_w = 0\n", + " \n", + " print('v{}, d{}, e{}, f{}, trn ll: {:.4f}, val ll: {:.4f}, ll_w: {:.4f}, cor: {:.4f}, auc: {:.4f}, lr: {}'\\\n", + " .format(VERSION, dataset, i, fold, tr_ll, ll, ll_w, cor, auc, learning_rate))\n", + " valid_time = time.time()-st\n", + " \n", + " epoch_stats = pd.DataFrame([[VERSION, dataset, i, fold, tr_ll, ll, ll_w, ll2, ll2_w, cor, \n", + " lls[0], lls[1], lls[2], lls[3], \n", + " lls[4], lls[5], len(trn_ds), len(val_ds), bs, train_time, valid_time,\n", + " learning_rate, weight_decay]],\n", + " columns = \n", + " ['ver','dataset','epoch','fold','train_loss','val_loss','val_w_loss',\n", + " 'val_loss2','val_w_loss2','cor',\n", + " 'any','epidural','intraparenchymal','intraventricular','subarachnoid','subdural',\n", + " 'train_sz','val_sz','bs','train_time','valid_time','lr','wd'\n", + " ])\n", + "\n", + " stats_filename = PATH_WORK/'stats.f{}.v{}'.format(fold,VERSION)\n", + " if stats_filename.is_file():\n", + " epoch_stats = pd.concat([pd.read_csv(stats_filename), epoch_stats], sort=False)\n", + " #if not DATA_SMALL:\n", + " epoch_stats.to_csv(stats_filename, index=False)\n", + " \n", + " print('total running time', time.time() - st0)\n", + " \n", + " return model, predictions, val_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# OOF" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [], + "source": [ + "def oof_one(dataset, num_iter=1, bs=100, fold=0):\n", + " \n", + " st0 = time.time()\n", + " dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n", + " \n", + " cur_epoch = getCurrentBatch(fold=fold, dataset=dataset)\n", + " if cur_epoch is None: cur_epoch = 0\n", + " print('completed epochs:', cur_epoch, 'iters starting now:', num_iter)\n", + " \n", + " setSeeds(SEED + cur_epoch)\n", + " \n", + " val_ds = RSNA_DataSet(train_md, mode='valid', bs=bs, fold=fold, dataset=dataset)\n", + " \n", + " loader_val = D.DataLoader(val_ds, num_workers=NUM_WORKERS, batch_size=bs, \n", + " shuffle=True)\n", + " print('dataset valid:', len(val_ds), 'loader valid:', len(loader_val))\n", + " \n", + " #model = TabularModel(n_cont = len(meta_cols), feat_sz=feat_sz, fc_drop_p=0)\n", + " model = ResNetModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n", + " \n", + " model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset)\n", + " if model_file_name is not None:\n", + " print('loading model', model_file_name)\n", + " state_dict = torch.load(PATH_DISK/'models'/model_file_name)\n", + " model.load_state_dict(state_dict)\n", + " else:\n", + " print('starting from scratch')\n", + " \n", + " if (not CLOUD) or CLOUD_SINGLE:\n", + " model = model.to(device)\n", + " else:\n", + " model_parallel = dp.DataParallel(model, device_ids=devices)\n", + " \n", + " loc_data = val_ds.metadata.copy()\n", + " series_counts = loc_data.index.value_counts()\n", + " \n", + " loc_data['orig_idx'] = np.arange(len(loc_data))\n", + " loc_data = 
loc_data.sort_values(['SeriesInstanceUID','pos_idx1'])\n", + " loc_data['my_order'] = np.arange(len(loc_data))\n", + " loc_data = loc_data.sort_values(['orig_idx'])\n", + " \n", + " preds = []\n", + " \n", + " for i in range(num_iter):\n", + " \n", + " val_ds.setFeats(-1, i+100)\n", + " \n", + " if CLOUD and (not CLOUD_SINGLE):\n", + " results = model_parallel(val_loop_fn, loader_val)\n", + " predictions = np.concatenate([results[i][0] for i in range(MAX_DEVICES)])\n", + " indices = np.concatenate([results[i][1] for i in range(MAX_DEVICES)])\n", + " offsets = np.concatenate([results[i][2] for i in range(MAX_DEVICES)])\n", + " else:\n", + " predictions, indices, offsets = val_loop_fn(model, loader_val, device)\n", + "\n", + " predictions = predictions[np.argsort(indices)]\n", + " offsets = offsets[np.argsort(indices)]\n", + " assert len(predictions) == len(loc_data.index.unique())\n", + " assert len(predictions) == len(offsets)\n", + " assert np.all(indices[np.argsort(indices)] == np.array(range(len(predictions))))\n", + "\n", + " val_results = []\n", + " for k, series in enumerate(np.sort(loc_data.index.unique())):\n", + " cnt = series_counts[series]\n", + " assert (offsets[k] + cnt) <= 60\n", + " val_results.append(predictions[k,offsets[k]:(offsets[k] + cnt)])\n", + "\n", + " val_results = np.concatenate(val_results)\n", + " assert np.isnan(val_results).sum() == 0\n", + " val_results = val_results[loc_data.my_order]\n", + " assert np.isnan(val_results).sum() == 0\n", + " assert len(val_results) == len(loc_data)\n", + " \n", + " preds.append(val_results)\n", + "\n", + " lls = [log_loss(loc_data[all_ich[k]].values, val_results[:,k], eps=1e-7, labels=[0,1])\\\n", + " for k in range(6)]\n", + " ll = (class_weights * np.array(lls)).mean()\n", + " cor = np.corrcoef(loc_data.loc[:,all_ich].values.reshape(-1), val_results.reshape(-1))[0,1]\n", + " auc = roc_auc_score(loc_data.loc[:,all_ich].values.reshape(-1), val_results.reshape(-1))\n", + "\n", + " print('ver {}, iter {}, fold {}, val ll: {:.4f}, cor: {:.4f}, auc: {:.4f}'\n", + " .format(VERSION, i, fold, ll, cor, auc))\n", + " \n", + " print('total running time', time.time() - st0)\n", + " \n", + " return np.stack(preds)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def predBounding(pp, target=None):\n", + " if target is not None:\n", + " ll = ((- target * np.log(pp.mean(0)) - (1 - target) * np.log(np.clip(1 - pp.mean(0),1e-15,1-1e-15)))\n", + " * class_weights).mean()\n", + " print('initial score', ll)\n", + " \n", + " print('any too low inconsistencies')\n", + " for i in range(1,6):\n", + " print(i, 'class:', (pp[...,0] < pp[...,i]).mean())\n", + " print('total', (pp[...,0] < pp[...,1:].max(-1)).mean())\n", + " \n", + " max_vals = pp[...,1:].max(-1)\n", + " mask = pp[...,0] < max_vals\n", + " pp[mask,0] = max_vals[mask]\n", + " #mask_vals = 0.5*(preds_all[:,:,:,0] + max_vals)[mask]\n", + " #preds_all[mask,0] = mask_vals\n", + " #preds_all[mask] = np.clip(preds_all[mask],0,np.expand_dims(mask_vals,1))\n", + "\n", + " assert (pp[...,0] < pp[...,1:].max(-1)).sum() == 0\n", + "\n", + " if target is not None:\n", + " ll = ((- target * np.log(pp.mean(0)) - (1 - target) * np.log(np.clip(1 - pp.mean(0),1e-15,1-1e-15)))\n", + " * class_weights).mean()\n", + " print('any too low corrected score', ll)\n", + " \n", + " print('any too high inconsistencies')\n", + " mask = pp[...,0] > pp[...,1:].sum(-1)\n", + " print('total', mask.mean())\n", + "\n", + " mask_val = 0.5*(pp[mask,0] + 
pp[...,1:].sum(-1)[mask])\n", + " scaler = mask_val / pp[...,1:].sum(-1)[mask]\n", + " pp[mask,1:] = pp[mask,1:] * np.expand_dims(scaler,1)\n", + " pp[mask,0] = mask_val\n", + "\n", + " if target is not None:\n", + " ll = ((- target * np.log(pp.mean(0)) - (1 - target) * np.log(np.clip(1 - pp.mean(0),1e-15,1-1e-15)))\n", + " * class_weights).mean()\n", + " print('any too high corrected score', ll)\n", + " \n", + " return pp" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Selecting runs aggregation" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "def scalePreds(x, power = 2, center=0.5):\n", + " res = x.copy()\n", + " res[x > center] = center + (1 - center) * ((res[x > center] - center)/(1 - center))**power\n", + " res[x < center] = center - center * ((center - res[x < center])/center)**power\n", + " return res\n", + "\n", + "def getPredsOOF(ver,aug=32,datasets=range(6,11),datasets5=range(11,13)):\n", + " preds_all = np.zeros((len(datasets) + len(datasets5),aug,len(train_md),6))\n", + " \n", + " if len(datasets) > 0:\n", + " for fold in range(3):\n", + " preds = np.stack([pickle.load(open(PATH_DISK/'ensemble/oof_d{}_f{}_v{}'\n", + " .format(ds, fold, ver),'rb')) for ds in datasets])\n", + " preds_all[:len(datasets),:,train_md.fold == fold,:] = preds\n", + " \n", + " if len(datasets5) > 0:\n", + " for fold in range(5):\n", + " preds = np.stack([pickle.load(open(PATH_DISK/'ensemble/oof_d{}_f{}_v{}'\n", + " .format(ds, fold, ver),'rb')) for ds in datasets5])\n", + " preds_all[len(datasets):,:,train_md.fold5 == fold,:] = preds\n", + " \n", + " preds_all = np.clip(preds_all, 1e-15, 1-1e-15)\n", + " return preds_all" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def getYuvalOOF(names, train_md, names5=None, aug=32):\n", + " \n", + " for num_folds in [3,5]:\n", + " if ('yuval_idx' not in train_md.columns) or ('yuval_idx5' not in train_md.columns):\n", + " print('adding yuval_idx')\n", + " fn = 'OOF_validation_image_ids'\n", + " if num_folds == 5:\n", + " fn += '_5_stage2'\n", + " if num_folds == 3:\n", + " fn += '_3_stage2'\n", + " yuval_oof = pickle.load(open(PATH_DISK/'yuval/OOF_stage2/{}.pkl'.format(fn),'rb'))\n", + " \n", + " df2 = pd.DataFrame()\n", + " col_name = 'yuval_idx'\n", + " fold_col = 'fold'\n", + " if num_folds == 5:\n", + " col_name += '5'\n", + " fold_col += '5'\n", + " \n", + " for fold in range(num_folds):\n", + " \n", + " assert np.all(train_md.loc[train_md.img_id.isin(yuval_oof[fold])][fold_col].values == fold)\n", + " \n", + " df = pd.DataFrame(np.arange(len(yuval_oof[fold])), columns=[col_name])\n", + " df.index = yuval_oof[fold]\n", + " df2 = pd.concat([df2,df],sort=False)\n", + " \n", + " train_md = train_md.join(df2, on = 'img_id')\n", + " assert train_md[col_name].isnull().sum() == 0\n", + " \n", + " preds_y = np.zeros((len(names),len(train_md),6))\n", + " \n", + " for fold in range(3):\n", + " preds = np.stack([torch.sigmoid(torch.stack(\n", + " pickle.load(open(PATH_DISK/'yuval/OOF_stage2'/name.format(fold),'rb')))).numpy() for name in names])\n", + " preds = preds[:,:,train_md.loc[train_md.fold == fold, 'yuval_idx']]\n", + " \n", + " preds_y[:,train_md.fold == fold] = preds.mean(1)\n", + " \n", + " if names5 is not None:\n", + " preds_y5 = np.zeros((len(names5),len(train_md),6))\n", + " for fold in range(5):\n", + " preds = np.stack([torch.sigmoid(torch.stack(\n", + " 
pickle.load(open(PATH_DISK/'yuval/OOF_stage2'/name.format(fold),'rb')))).numpy() for name in names5])\n", + " preds = preds[:,:,train_md.loc[train_md.fold5 == fold, 'yuval_idx5']]\n", + "\n", + " preds_y5[:,train_md.fold5 == fold] = preds.mean(1)\n", + " \n", + " preds_y = np.concatenate([preds_y, preds_y5], axis=0)\n", + " \n", + " \n", + " preds_y = np.clip(preds_y, 1e-15, 1-1e-15)[:,:,np.array([5,0,1,2,3,4])]\n", + " return preds_y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": {}, + "outputs": [], + "source": [ + "def inference_one(dataset, bs = 100, add_seed = 0, fold = 0, anum = 0):\n", + " \n", + " st = time.time()\n", + " dataset_name, filename_add, filename_add2, feat_sz,_,_,_,_ = getDSParams(dataset)\n", + "\n", + " cur_epoch = getCurrentBatch(fold=fold, dataset=dataset)\n", + " if cur_epoch is None: cur_epoch = 0\n", + " print('completed epochs:', cur_epoch)\n", + "\n", + " #model = TabularModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n", + " model = ResNetModel(n_cont = len(meta_cols), feat_sz=feat_sz)\n", + " \n", + " model_file_name = modelFileName(return_last=True, fold=fold, dataset=dataset)\n", + " if model_file_name is not None:\n", + " print('loading model', model_file_name)\n", + " state_dict = torch.load(PATH_DISK/'models'/model_file_name)\n", + " model.load_state_dict(state_dict)\n", + " \n", + " if (not CLOUD) or CLOUD_SINGLE:\n", + " model = model.to(device)\n", + " else:\n", + " model_parallel = dp.DataParallel(model, device_ids=devices)\n", + "\n", + " setSeeds(SEED + cur_epoch + anum + 100*fold)\n", + "\n", + " tst_ds = RSNA_DataSet(test_md, mode='test', bs=bs, fold=fold, dataset=dataset)\n", + " loader_tst = D.DataLoader(tst_ds, num_workers=NUM_WORKERS, batch_size=bs, shuffle=False)\n", + " print('dataset test:', len(tst_ds), 'loader test:', len(loader_tst), 'anum:', anum)\n", + " \n", + " tst_ds.setFeats(-1, epoch = anum+100)\n", + "\n", + " loc_data = tst_ds.metadata.copy()\n", + " series_counts = loc_data.index.value_counts()\n", + "\n", + " loc_data['orig_idx'] = np.arange(len(loc_data))\n", + " loc_data = loc_data.sort_values(['SeriesInstanceUID','pos_idx1'])\n", + " loc_data['my_order'] = np.arange(len(loc_data))\n", + " loc_data = loc_data.sort_values(['orig_idx'])\n", + " \n", + " if CLOUD and (not CLOUD_SINGLE):\n", + " results = model_parallel(test_loop_fn, loader_tst)\n", + " predictions = np.concatenate([results[i][0] for i in range(MAX_DEVICES)])\n", + " indices = np.concatenate([results[i][1] for i in range(MAX_DEVICES)])\n", + " offsets = np.concatenate([results[i][2] for i in range(MAX_DEVICES)])\n", + " else:\n", + " predictions, indices, offsets = test_loop_fn(model, loader_tst, device)\n", + "\n", + " predictions = predictions[np.argsort(indices)]\n", + " offsets = offsets[np.argsort(indices)]\n", + " assert len(predictions) == len(test_md.SeriesInstanceUID.unique())\n", + " assert np.all(indices[np.argsort(indices)] == np.array(range(len(predictions))))\n", + " \n", + " val_results = []\n", + " for k, series in enumerate(np.sort(loc_data.index.unique())):\n", + " cnt = series_counts[series]\n", + " assert (offsets[k] + cnt) <= 60\n", + " val_results.append(predictions[k,offsets[k]:(offsets[k] + cnt)])\n", + "\n", + " val_results = np.concatenate(val_results)\n", + " assert np.isnan(val_results).sum() == 0\n", + " val_results = val_results[loc_data.my_order]\n", + " assert len(val_results) == len(loc_data)\n", + "\n", + " 
print('test processing time:', time.time() - st)\n", + " \n", + " return val_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Ensembling" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "def getStepX(train_md, preds_all, fold=0, target=0, mode='train'):\n", + " \n", + " print('my_len', my_len)\n", + " X = np.stack([preds_all[:my_len,:,target].mean(0), \n", + " preds_all[my_len:,:,target].mean(0)], axis=0)\n", + " \n", + " if mode == 'train':\n", + " X = X[:,train_md.fold != fold]\n", + " y = train_md.loc[train_md.fold != fold, all_ich[target]].values\n", + " ww = train_md.loc[train_md.fold != fold, 'weights'].values\n", + " elif mode == 'valid':\n", + " X = X[:,train_md.fold == fold]\n", + " y = train_md.loc[train_md.fold == fold, all_ich[target]].values\n", + " ww = train_md.loc[train_md.fold == fold, 'weights'].values\n", + " else:\n", + " X = X\n", + " y = None\n", + " \n", + " ll = None\n", + " llw = None\n", + " auc = None\n", + " if y is not None:\n", + " ll = log_loss(y, X.mean(0), eps=1e-7, labels=[0,1])\n", + " llw = log_loss(y, X.mean(0), eps=1e-7, labels=[0,1], sample_weight=ww)\n", + " auc = roc_auc_score(y, X.mean(0))\n", + " \n", + " return X, y, ww, ll, llw, auc\n", + "\n", + "\n", + "def train_ensemble(train_md, preds_all, fold = 0, target = 0, weighted = False):\n", + " \n", + " print('starting fold',fold,'target',target)\n", + " \n", + " st = time.time()\n", + " \n", + " limit_low = 1e-15\n", + " limit_high = 1 - 1e-5\n", + " \n", + " prior = train_md.loc[train_md.fold != fold, all_ich[target]].mean()\n", + " \n", + " def my_objective(x,preds,vals,ww):\n", + " preds_sum = np.clip((preds * x[:,None]).sum(0), limit_low, limit_high)\n", + " res = np.average(- vals * np.log(preds_sum) - (1 - vals) * np.log(1 - preds_sum), weights=ww)\n", + " #print('x ',x, x.sum())\n", + " print('obj ',res)\n", + " return res\n", + "\n", + " def my_grad(x,preds,vals,ww):\n", + " preds_sum = np.clip((preds * x[:,None]).sum(0), limit_low, limit_high)\n", + " res = np.average(- vals * preds / preds_sum + (1 - vals) * preds / (1 - preds_sum), weights=ww, axis=1)\n", + " #print('grad',res)\n", + " return res\n", + "\n", + " def my_hess(x,preds,vals,ww):\n", + " preds_sum = np.clip((preds * x[:,None]).sum(0), limit_low, limit_high)\n", + " res = np.average(preds * np.expand_dims(preds, axis=1) * \n", + " (vals / preds_sum**2 + (1 - vals) / (1 - preds_sum)**2), weights=ww, axis=2)\n", + " return res\n", + " \n", + " X,y,ww,ll_train,llw_train,auc_train = getStepX(train_md, preds_all, fold=fold, target=target)\n", + " \n", + " bnds_low = np.zeros(X.shape[0])\n", + " bnds_high = np.ones(X.shape[0])\n", + " \n", + " initial_sol = np.ones(X.shape[0])/X.shape[0]\n", + " \n", + " bnds = sp.optimize.Bounds(bnds_low, bnds_high)\n", + " cons = sp.optimize.LinearConstraint(np.ones((1,X.shape[0])), 0.98, 1.00)\n", + " \n", + " if weighted:\n", + " ww_in = ww\n", + " else:\n", + " ww_in = np.ones(len(y))\n", + " \n", + " model = sp.optimize.minimize(my_objective, initial_sol, jac=my_grad, hess=my_hess, args=(X, y, ww_in),\n", + " bounds=bnds, method='trust-constr', constraints=cons,\n", + " options={'gtol': 1e-11, 'initial_tr_radius': 0.1, 'initial_barrier_parameter': 0.01})\n", + " \n", + " pickle.dump(model, open(PATH_DISK/'ensemble'/'model.f{}.t{}.v{}'\n", + " .format(fold,target,VERSION),'wb'))\n", + "\n", + " print('model', model.x, 'sum', model.x.sum())\n", + " \n", + " train_preds = 
(X*np.expand_dims(model.x, axis=1)).sum(0)\n", + " ll_train2 = log_loss(y, train_preds, eps=1e-7, labels=[0,1])\n", + " llw_train2 = log_loss(y, train_preds, eps=1e-7, labels=[0,1], sample_weight=ww)\n", + " auc_train2 = roc_auc_score(y, train_preds)\n", + " \n", + " X,y,ww,ll_val,llw_val,auc_val = getStepX(train_md, preds_all, fold=fold, target=target, mode='valid')\n", + " \n", + " val_preds = (X*np.expand_dims(model.x, axis=1)).sum(0)\n", + " \n", + " ll_val2 = log_loss(y, val_preds, eps=1e-7, labels=[0,1])\n", + " llw_val2 = log_loss(y, val_preds, eps=1e-7, labels=[0,1], sample_weight=ww)\n", + " auc_val2 = roc_auc_score(y, val_preds)\n", + " \n", + " print('v{} f{} t{}: original ll {:.4f}/{:.4f}, ensemble ll {:.4f}/{:.4f}'\n", + " .format(VERSION,fold,target,ll_val,llw_val,ll_val2,llw_val2))\n", + " \n", + " run_time = time.time() - st\n", + " print('running time', run_time)\n", + " \n", + " stats = pd.DataFrame([[VERSION,fold,target,\n", + " ll_train,auc_train,ll_train2,auc_train2,\n", + " ll_val,auc_val,ll_val2,auc_val2,\n", + " llw_train,llw_train2,llw_val,llw_val2,\n", + " run_time,weighted]],\n", + " columns = \n", + " ['version','fold','target',\n", + " 'train_loss','train_auc','train_loss_ens','train_auc_ens', \n", + " 'valid_loss','valid_auc','valid_loss_ens','valid_auc_ens',\n", + " 'train_w_loss','train_w_loss_ens','valid_w_loss','valid_w_loss_ens',\n", + " 'run_time','weighted'\n", + " ])\n", + " \n", + " stats_filename = PATH_DISK/'ensemble'/'stats.v{}'.format(VERSION)\n", + " if stats_filename.is_file():\n", + " stats = pd.concat([pd.read_csv(stats_filename), stats], sort=False)\n", + " stats.to_csv(stats_filename, index=False)\n", + "\n", + "#model.cols = Xt.columns\n", + "#predictions[data_filt['fold'] == i] = (Xv*model.x).sum(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
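The validation loop above scores every TTA iteration with the competition-style weighted multi-label log loss: a per-class `log_loss` over the six labels, averaged with `class_weights = 6*[2,1,1,1,1,1]/7`, which is equivalent to weighting 'any' 2/7 and each subtype 1/7. Below is a minimal sketch of that scoring on a hypothetical 4-slice batch; the notebook passes `eps=1e-7` to `log_loss`, while the sketch clips the predictions instead so it also runs on newer scikit-learn releases that dropped the `eps` argument.

```python
import numpy as np
from sklearn.metrics import log_loss

class_weights = 6.0 * np.array([2, 1, 1, 1, 1, 1]) / 7.0   # 'any' counts double

def weighted_ll(y_true, y_pred):
    """Per-class log loss over the 6 labels, combined with the 2/7-1/7 weights."""
    y_pred = np.clip(y_pred, 1e-7, 1 - 1e-7)   # stands in for the notebook's eps=1e-7
    lls = [log_loss(y_true[:, k], y_pred[:, k], labels=[0, 1]) for k in range(6)]
    return (class_weights * np.array(lls)).mean()

# hypothetical toy batch: 4 slices x 6 labels (any, epidural, ..., subdural)
y_true = np.array([[1, 1, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 0],
                   [1, 0, 0, 1, 0, 1],
                   [0, 0, 0, 0, 0, 0]], dtype=float)
y_pred = np.array([[0.9, 0.6, 0.1, 0.1, 0.1, 0.2],
                   [0.1, 0.0, 0.0, 0.1, 0.0, 0.1],
                   [0.8, 0.1, 0.1, 0.7, 0.2, 0.6],
                   [0.2, 0.1, 0.0, 0.1, 0.1, 0.0]])
print('weighted log loss:', weighted_ll(y_true, y_pred))
```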
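Both the validation loop and `inference_one` receive one padded prediction row per series (up to 60 slice slots in the notebook) plus an `offsets` array giving the first real slot of each series; they unpack these into per-slice rows in sorted `(SeriesInstanceUID, pos_idx1)` order and then use the precomputed `my_order` column to restore the original row order. A minimal sketch of that unpacking, assuming two hypothetical series and a padded length of 5 instead of 60:

```python
import numpy as np
import pandas as pd

# hypothetical per-slice metadata in its original (unsorted) row order
loc_data = pd.DataFrame({
    'SeriesInstanceUID': ['s_b', 's_a', 's_a', 's_b', 's_a'],
    'pos_idx1':          [1,     2,     0,     0,     1],
}).set_index('SeriesInstanceUID')
series_counts = loc_data.index.value_counts()          # slices per series

loc_data['orig_idx'] = np.arange(len(loc_data))
loc_data = loc_data.sort_values(['SeriesInstanceUID', 'pos_idx1'])
loc_data['my_order'] = np.arange(len(loc_data))        # rank in sorted order
loc_data = loc_data.sort_values('orig_idx')

# hypothetical model output: one padded row per series, 6 labels each
max_len, n_labels = 5, 6
predictions = np.random.rand(2, max_len, n_labels)
offsets = np.array([1, 0])                             # first real slot per series

# unpack: series come back in sorted-UID order; offsets skip the padding
val_results = []
for k, series in enumerate(np.sort(loc_data.index.unique())):
    cnt = series_counts[series]
    val_results.append(predictions[k, offsets[k]:offsets[k] + cnt])
val_results = np.concatenate(val_results)              # sorted (series, pos) order

# restore the original per-slice row order
val_results = val_results[loc_data.my_order]
assert len(val_results) == len(loc_data)
```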
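`predBounding` enforces two label-hierarchy constraints on the stacked predictions (axes: augmentation x row x 6 labels): the 'any' column must be at least the largest subtype probability, and must not exceed the sum of the subtype probabilities. Violations of the first kind are fixed by raising 'any'; violations of the second by lowering 'any' halfway toward the subtype sum and scaling the subtypes up to meet it. The sketch below condenses that logic into a hypothetical `bound_predictions` helper and drops the inconsistency-rate and before/after-score printouts of the original:

```python
import numpy as np

def bound_predictions(pp):
    """Make the 'any' column (index 0) consistent with the 5 subtype columns."""
    pp = pp.copy()

    # (1) 'any' must be at least the largest subtype probability: raise it
    max_vals = pp[..., 1:].max(-1)
    mask = pp[..., 0] < max_vals
    pp[mask, 0] = max_vals[mask]

    # (2) 'any' must not exceed the sum of the subtypes:
    #     meet halfway, then scale the subtypes up so their sum matches
    sums = pp[..., 1:].sum(-1)
    mask = pp[..., 0] > sums
    mask_val = 0.5 * (pp[mask, 0] + sums[mask])
    pp[mask, 1:] *= np.expand_dims(mask_val / sums[mask], 1)
    pp[mask, 0] = mask_val
    return pp

pp = np.random.rand(2, 4, 6) * 0.3            # hypothetical stacked predictions
pp_fixed = bound_predictions(pp)
assert (pp_fixed[..., 0] < pp_fixed[..., 1:].max(-1)).sum() == 0
assert (pp_fixed[..., 0] > pp_fixed[..., 1:].sum(-1) + 1e-12).sum() == 0
```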
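`scalePreds` rescales probabilities symmetrically around a chosen center: each half-interval `[0, center]` and `[center, 1]` is mapped onto itself by a power curve, so `power > 1` pulls predictions toward the center (softer) and `power < 1` pushes them out toward 0 and 1 (sharper), while the center and the endpoints stay fixed. A short usage sketch:

```python
import numpy as np

def scale_preds(x, power=2, center=0.5):
    """Soften (power>1) or sharpen (power<1) probabilities around `center`."""
    res = x.copy()
    hi = x > center
    lo = x < center
    res[hi] = center + (1 - center) * ((res[hi] - center) / (1 - center)) ** power
    res[lo] = center - center * ((center - res[lo]) / center) ** power
    return res

p = np.array([0.1, 0.4, 0.5, 0.6, 0.9])
print(scale_preds(p, power=2))    # -> [0.18, 0.48, 0.5, 0.52, 0.82]
```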
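`getPredsOOF` and `getYuvalOOF` both rebuild full-length out-of-fold prediction arrays by scattering each fold's saved validation predictions into the rows where `train_md.fold` (or `fold5`) equals that fold; `getYuvalOOF` additionally applies a sigmoid and reorders the teammate's columns to the `['any', ...]` order via the index array `[5,0,1,2,3,4]`. A minimal sketch of the scatter pattern, using random stand-in data instead of the pickled per-fold files and a hypothetical 10-row `train_md`:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# hypothetical setup: 10 training rows, 3 folds, 2 models, 6 labels
train_md = pd.DataFrame({'fold': np.arange(10) % 3})
n_models, n_rows, n_labels = 2, len(train_md), 6

oof = np.zeros((n_models, n_rows, n_labels))
filled = np.zeros(n_rows, dtype=bool)

for fold in range(3):
    fold_mask = (train_md.fold == fold).values
    # stand-in for pickle.load of that fold's saved validation predictions
    fold_preds = rng.random((n_models, fold_mask.sum(), n_labels))
    oof[:, fold_mask, :] = fold_preds
    filled |= fold_mask

assert filled.all()                       # every row covered by exactly one fold
oof = np.clip(oof, 1e-15, 1 - 1e-15)      # same clipping as getPredsOOF
```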
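`train_ensemble` fits one blend weight per model group, separately for each fold and each of the six targets: `getStepX` stacks the averaged predictions of the notebook's own models (the first `my_len` entries, where `my_len` is a module-level count) and of the teammate's models, and the weights are found with SciPy's `trust-constr` solver, bounded to `[0, 1]` and linearly constrained to sum to between 0.98 and 1.00, minimizing the (optionally sample-weighted) binary cross-entropy with an analytic gradient and Hessian that, as in the notebook, ignore the clipping. A self-contained sketch on synthetic data, with hypothetical `objective`/`grad`/`hess` names:

```python
import numpy as np
import scipy.optimize as opt

# synthetic stand-in: 3 model groups' OOF probabilities for one target, plus labels
rng = np.random.default_rng(1)
n_models, n_rows = 3, 500
y = (rng.random(n_rows) < 0.3).astype(float)
preds = np.clip(rng.random((n_models, n_rows)) * 0.5 + y * 0.4, 1e-6, 1 - 1e-6)
ww = np.ones(n_rows)                         # per-row weights (all equal here)

lo, hi = 1e-15, 1 - 1e-5                     # clipping limits as in the notebook

def objective(x, preds, y, ww):
    p = np.clip((preds * x[:, None]).sum(0), lo, hi)
    return np.average(-y * np.log(p) - (1 - y) * np.log(1 - p), weights=ww)

def grad(x, preds, y, ww):
    p = np.clip((preds * x[:, None]).sum(0), lo, hi)
    return np.average(-y * preds / p + (1 - y) * preds / (1 - p), weights=ww, axis=1)

def hess(x, preds, y, ww):
    p = np.clip((preds * x[:, None]).sum(0), lo, hi)
    return np.average(preds * preds[:, None] * (y / p**2 + (1 - y) / (1 - p)**2),
                      weights=ww, axis=2)

bounds = opt.Bounds(np.zeros(n_models), np.ones(n_models))
cons = opt.LinearConstraint(np.ones((1, n_models)), 0.98, 1.00)  # weights ~ sum to 1
x0 = np.ones(n_models) / n_models

res = opt.minimize(objective, x0, args=(preds, y, ww), jac=grad, hess=hess,
                   method='trust-constr', bounds=bounds, constraints=cons,
                   options={'gtol': 1e-11, 'initial_tr_radius': 0.1,
                            'initial_barrier_parameter': 0.01})
print('blend weights', res.x, 'sum', res.x.sum())
```

The sum-to-roughly-one constraint keeps the blended output on the probability scale while still allowing the optimizer to shrink slightly overconfident inputs.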