--- a +++ b/Serialized/train_base_models.ipynb @@ -0,0 +1,984 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/reina/anaconda3/envs/RSNA/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n", + " return f(*args, **kwds)\n", + "/home/reina/anaconda3/envs/RSNA/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n", + " return f(*args, **kwds)\n" + ] + } + ], + "source": [ + "from __future__ import absolute_import\n", + "from __future__ import division\n", + "from __future__ import print_function\n", + "\n", + "\n", + "import numpy as np # linear algebra\n", + "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n", + "import os\n", + "import datetime\n", + "import seaborn as sns\n", + "import pydicom\n", + "import time\n", + "import gc\n", + "import operator \n", + "from apex import amp \n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.utils.data as D\n", + "import torch.nn.functional as F\n", + "from sklearn.model_selection import KFold\n", + "from tqdm import tqdm, tqdm_notebook\n", + "from IPython.core.interactiveshell import InteractiveShell\n", + "InteractiveShell.ast_node_interactivity = \"all\"\n", + "import warnings\n", + "warnings.filterwarnings(action='once')\n", + "import pickle\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "from skimage.io import imread,imshow\n", + "from helper import *\n", + "from apex import amp\n", + "import helper\n", + "import torchvision.models as models\n", + "import pretrainedmodels\n", + "from torch.optim import Adam\n", + "from functools import partial\n", + "from defenitions import *" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set parameters below" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# here you should set which model parameters you want to choose (see definitions.py) and what GPU to use\n", + "params=parameters['se_resnet101_5'] # se_resnet101_5, se_resnext101_32x4d_3, se_resnext101_32x4d_5\n", + "\n", + "device=device_by_name(\"Tesla\") # RTX , cpu\n", + "torch.cuda.set_device(device)\n", + "sendmeemail=Email_Progress(my_gmail,my_pass,to_email,'{} results'.format(params['model_name']))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'model_name': 'se_resnet101',\n", + " 'SEED': 432,\n", + " 'n_splits': 5,\n", + " 'Pre_version': None,\n", + " 'focal': False,\n", + " 'version': 'new_splits',\n", + " 'train_prediction': 'predictions_train_tta',\n", + " 'train_features': 'features_train_tta',\n", + " 'test_prediction': 'predictions_test',\n", + " 'test_features': 'features_test',\n", + " 'num_epochs': 5,\n", + " 'num_pool': 8}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "params" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "SEED = params['SEED']\n", + "n_splits=params['n_splits']" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(674252, 15)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "(674252, 15)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>PatientID</th>\n", + " <th>epidural</th>\n", + " <th>intraparenchymal</th>\n", + " <th>intraventricular</th>\n", + " <th>subarachnoid</th>\n", + " <th>subdural</th>\n", + " <th>any</th>\n", + " <th>PID</th>\n", + " <th>StudyI</th>\n", + " <th>SeriesI</th>\n", + " <th>WindowCenter</th>\n", + " <th>WindowWidth</th>\n", + " <th>ImagePositionZ</th>\n", + " <th>ImagePositionX</th>\n", + " <th>ImagePositionY</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>63eb1e259</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>a449357f</td>\n", + " <td>62d125e5b2</td>\n", + " <td>0be5c0d1b3</td>\n", + " <td>['00036', '00036']</td>\n", + " <td>['00080', '00080']</td>\n", + " <td>180.199951</td>\n", + " <td>-125.0</td>\n", + " <td>-8.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>2669954a7</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>363d5865</td>\n", + " <td>a20b80c7bf</td>\n", + " <td>3564d584db</td>\n", + " <td>['00047', '00047']</td>\n", + " <td>['00080', '00080']</td>\n", + " <td>922.530821</td>\n", + " <td>-156.0</td>\n", + " <td>45.572849</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>52c9913b1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>9c2b4bd7</td>\n", + " <td>3e3634f8cf</td>\n", + " <td>973274ffc9</td>\n", + " <td>40</td>\n", + " <td>150</td>\n", + " <td>4.455000</td>\n", + " <td>-125.0</td>\n", + " <td>-115.063000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>4e6ff6126</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>3ae81c2d</td>\n", + " <td>a1390c15c2</td>\n", + " <td>e5ccad8244</td>\n", + " <td>['00036', '00036']</td>\n", + " <td>['00080', '00080']</td>\n", + " <td>100.000000</td>\n", + " <td>-99.5</td>\n", + " <td>28.500000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>7858edd88</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>c1867feb</td>\n", + " <td>c73e81ed3a</td>\n", + " <td>28e0531b3a</td>\n", + " <td>40</td>\n", + " <td>100</td>\n", + " <td>145.793000</td>\n", + " <td>-125.0</td>\n", + " <td>-132.190000</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " PatientID epidural intraparenchymal intraventricular subarachnoid \\\n", + "0 63eb1e259 0 0 0 0 \n", + "1 2669954a7 0 0 0 0 \n", + "2 52c9913b1 0 0 0 0 \n", + "3 4e6ff6126 0 0 0 0 \n", + "4 7858edd88 0 0 0 0 \n", + "\n", + " subdural any PID StudyI SeriesI WindowCenter \\\n", + "0 0 0 a449357f 62d125e5b2 0be5c0d1b3 ['00036', '00036'] \n", + "1 0 0 363d5865 a20b80c7bf 3564d584db ['00047', '00047'] \n", + "2 0 0 9c2b4bd7 3e3634f8cf 973274ffc9 40 \n", + "3 0 0 3ae81c2d a1390c15c2 e5ccad8244 ['00036', '00036'] \n", + "4 0 0 c1867feb c73e81ed3a 28e0531b3a 40 \n", + "\n", + " WindowWidth ImagePositionZ ImagePositionX ImagePositionY \n", + "0 ['00080', '00080'] 180.199951 -125.0 -8.000000 \n", + "1 ['00080', '00080'] 922.530821 -156.0 45.572849 \n", + "2 150 4.455000 -125.0 -115.063000 \n", + "3 ['00080', '00080'] 100.000000 -99.5 28.500000 \n", + "4 100 145.793000 -125.0 -132.190000 " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_df = pd.read_csv(data_dir+'train.csv')\n", + "train_df.shape\n", + "train_df=train_df[~train_df.PatientID.isin(bad_images)].reset_index(drop=True)\n", + "train_df=train_df.drop_duplicates().reset_index(drop=True)\n", + "train_df.shape\n", + "train_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>PatientID</th>\n", + " <th>epidural</th>\n", + " <th>intraparenchymal</th>\n", + " <th>intraventricular</th>\n", + " <th>subarachnoid</th>\n", + " <th>subdural</th>\n", + " <th>any</th>\n", + " <th>SeriesI</th>\n", + " <th>PID</th>\n", + " <th>StudyI</th>\n", + " <th>WindowCenter</th>\n", + " <th>WindowWidth</th>\n", + " <th>ImagePositionZ</th>\n", + " <th>ImagePositionX</th>\n", + " <th>ImagePositionY</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>28fbab7eb</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>ebfd7e4506</td>\n", + " <td>cf1b6b11</td>\n", + " <td>93407cadbb</td>\n", + " <td>30</td>\n", + " <td>80</td>\n", + " <td>158.458000</td>\n", + " <td>-125.0</td>\n", + " <td>-135.598000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>877923b8b</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>6d95084e15</td>\n", + " <td>ad8ea58f</td>\n", + " <td>a337baa067</td>\n", + " <td>30</td>\n", + " <td>80</td>\n", + " <td>138.729050</td>\n", + " <td>-125.0</td>\n", + " <td>-101.797981</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>a591477cb</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>8e06b2c9e0</td>\n", + " <td>ecfb278b</td>\n", + " <td>0cfe838d54</td>\n", + " <td>30</td>\n", + " <td>80</td>\n", + " <td>60.830002</td>\n", + " <td>-125.0</td>\n", + " <td>-133.300003</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>42217c898</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>e800f419cf</td>\n", + " <td>e96e31f4</td>\n", + " <td>c497ac5bad</td>\n", + " <td>30</td>\n", + " <td>80</td>\n", + " <td>55.388000</td>\n", + " <td>-125.0</td>\n", + " <td>-146.081000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>a130c4d2f</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>faeb7454f3</td>\n", + " <td>69affa42</td>\n", + " <td>854e4fbc01</td>\n", + " <td>30</td>\n", + " <td>80</td>\n", + " <td>33.516888</td>\n", + " <td>-125.0</td>\n", + " <td>-118.689819</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " PatientID epidural intraparenchymal intraventricular subarachnoid \\\n", + "0 28fbab7eb 0.5 0.5 0.5 0.5 \n", + "1 877923b8b 0.5 0.5 0.5 0.5 \n", + "2 a591477cb 0.5 0.5 0.5 0.5 \n", + "3 42217c898 0.5 0.5 0.5 0.5 \n", + "4 a130c4d2f 0.5 0.5 0.5 0.5 \n", + "\n", + " subdural any SeriesI PID StudyI WindowCenter WindowWidth \\\n", + "0 0.5 0.5 ebfd7e4506 cf1b6b11 93407cadbb 30 80 \n", + "1 0.5 0.5 6d95084e15 ad8ea58f a337baa067 30 80 \n", + "2 0.5 0.5 8e06b2c9e0 ecfb278b 0cfe838d54 30 80 \n", + "3 0.5 0.5 e800f419cf e96e31f4 c497ac5bad 30 80 \n", + "4 0.5 0.5 faeb7454f3 69affa42 854e4fbc01 30 80 \n", + "\n", + " ImagePositionZ ImagePositionX ImagePositionY \n", + "0 158.458000 -125.0 -135.598000 \n", + "1 138.729050 -125.0 -101.797981 \n", + "2 60.830002 -125.0 -133.300003 \n", + "3 55.388000 -125.0 -146.081000 \n", + "4 33.516888 -125.0 -118.689819 " + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_df = pd.read_csv(data_dir+'test.csv')\n", + "test_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "split_sid = train_df.PID.unique()\n", + "splits=list(KFold(n_splits=n_splits,shuffle=True, random_state=SEED).split(split_sid))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "pickle_file=open(outputs_dir+\"PID_splits_{}.pkl\".format(n_splits),'wb')\n", + "pickle.dump((split_sid,splits),pickle_file,protocol=4)\n", + "pickle_file.close()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def my_loss(y_pred,y_true,weights):\n", + " if len(y_pred.shape)==len(y_true.shape): \n", + " # Normal loss\n", + " loss = F.binary_cross_entropy_with_logits(y_pred,y_true,weights.expand_as(y_pred))\n", + " else:\n", + " # Mixup loss (not used here)\n", + " loss0 = F.binary_cross_entropy_with_logits(y_pred,y_true[...,0],weights.repeat(y_pred.shape[0],1),reduction='none')\n", + " loss1 = F.binary_cross_entropy_with_logits(y_pred,y_true[...,1],weights.repeat(y_pred.shape[0],1),reduction='none')\n", + " loss = (y_true[...,2]*loss0+(1.0-y_true[...,2])*loss1).mean() \n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "class FocalLoss(nn.Module):\n", + " def __init__(self, alpha=1, gamma=2, logits=True, reduce=True):\n", + " super(FocalLoss, self).__init__()\n", + " self.alpha = alpha\n", + " self.gamma = gamma\n", + " self.logits = logits\n", + " self.reduce = reduce\n", + "\n", + " def forward(self, y_pred,y_true,weights):\n", + " if self.logits:\n", + " BCE_loss = F.binary_cross_entropy_with_logits(y_pred,y_true,weights.expand_as(y_pred), reduction='none')\n", + " else:\n", + " BCE_loss = F.binary_cross_entropy(y_pred,y_true,weights.expand_as(y_pred), reduction='none')\n", + " pt = torch.exp(-BCE_loss)\n", + " F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss\n", + "\n", + " if self.reduce:\n", + " return torch.mean(F_loss)\n", + " else:\n", + " return F_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "class parameter_scheduler():\n", + " def __init__(self,model,do_first=['classifier'],num_epoch=1):\n", + " self.model=model\n", + " self.do_first = do_first\n", + " self.num_epoch=num_epoch\n", + " def __call__(self,epoch):\n", + " if epoch>=self.num_epoch:\n", + " for n,p in self.model.named_parameters():\n", + " p.requires_grad=True\n", + " else:\n", + " for n,p in self.model.named_parameters():\n", + " p.requires_grad= any(nd in n for nd in self.do_first)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def get_model(model_name):\n", + " if params['model_name'].startswith('se'):\n", + " return MySENet, pretrainedmodels.__dict__[params['model_name']](num_classes=1000, pretrained='imagenet')\n", + " elif 'Densenet161' in params['model_name']:\n", + " return partial(MyDenseNet, strategy='none'),models.densenet161(pretrained=True)\n", + " elif 'Densenet169' in params['model_name']:\n", + " return partial(MyDenseNet, strategy='none'),models.densenet169(pretrained=True)\n", + " else:\n", + " raise" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "%matplotlib nbagg\n", + "for num_split in range(params['n_splits']):\n", + " np.random.seed(SEED+num_split)\n", + " torch.manual_seed(SEED+num_split)\n", + " torch.cuda.manual_seed(SEED+num_split)\n", + " #torch.backends.cudnn.deterministic = True\n", + " idx_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].index.values\n", + " idx_validate = train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].index.values\n", + " idx_train.shape\n", + " idx_validate.shape\n", + "\n", + " klr=1\n", + " batch_size=32\n", + " num_workers=12\n", + " num_epochs=params['num_epochs']\n", + " model_name,version = params['model_name'] , params['version']\n", + " new_model,base_model=get_model(params['model_name'])\n", + " model = new_model(base_model,\n", + " len(hemorrhage_types),\n", + " num_channels=3,\n", + " dropout=0.2,\n", + " wso=((40,80),(80,200),(40,400)),\n", + " dont_do_grad=[],\n", + " extra_pool=params['num_pool'],\n", + " )\n", + " if params['Pre_version'] is not None:\n", + " model.load_state_dict(torch.load(models_dir+models_format.format(model_name,params['Pre_version'],\n", + " num_split),map_location=torch.device(device)))\n", + "\n", + " _=model.to(device)\n", + " weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n", + " loss_func=my_loss if not params['focal'] else FocalLoss()\n", + " targets_dataset=D.TensorDataset(torch.tensor(train_df[hemorrhage_types].values,dtype=torch.float))\n", + " transform=MyTransform(mean_change=15,\n", + " std_change=0,\n", + " flip=True,\n", + " zoom=(0.2,0.2),\n", + " rotate=30,\n", + " out_size=512,\n", + " shift=10,\n", + " normal=False)\n", + " imagedataset = ImageDataset(train_df,transform=transform.random,base_path=train_images_dir,\n", + " window_eq=False,equalize=False,rescale=True)\n", + " transform_val=MyTransform(out_size=512)\n", + " imagedataset_val = ImageDataset(train_df,transform=transform_val.random,base_path=train_images_dir,\n", + " window_eq=False,equalize=False,rescale=True)\n", + " combined_dataset=DatasetCat([imagedataset,targets_dataset])\n", + " combined_dataset_val=DatasetCat([imagedataset_val,targets_dataset])\n", + " optimizer_grouped_parameters=model.get_optimizer_parameters(klr)\n", + " sampling=sampler(train_df[hemorrhage_types].values[idx_train],0.5,[0,0,0,0,0,1])\n", + " sample_ratio=1.02*float(sampling().shape[0])/idx_train.shape[0]\n", + " train_dataset=D.Subset(combined_dataset,idx_train)\n", + " validate_dataset=D.Subset(combined_dataset_val,idx_validate)\n", + " num_train_optimization_steps = num_epochs*(sample_ratio*len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n", + " fig,ax = plt.subplots(figsize=(10,7))\n", + " gr=loss_graph(fig,ax,num_epochs,int(num_train_optimization_steps/num_epochs)+1,limits=(0.05,0.2))\n", + " sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=num_epochs,tau=1)\n", + " optimizer = BertAdam(optimizer_grouped_parameters,lr=klr*1e-3,schedule=sched)\n", + " model, optimizer = amp.initialize(model, optimizer, opt_level=\"O1\",verbosity=0)\n", + " history,best_model= model_train(model,\n", + " optimizer,\n", + " train_dataset,\n", + " batch_size,\n", + " num_epochs,\n", + " loss_func,\n", + " weights=weights,\n", + " do_apex=False,\n", + " model_apexed=True,\n", + " validate_dataset=validate_dataset,\n", + " param_schedualer=None,\n", + " weights_data=None,\n", + " metric=None,\n", + " return_model=True,\n", + " num_workers=num_workers,\n", + " sampler=None,\n", + " pre_process = None,\n", + " graph=gr,\n", + " call_progress=sendmeemail)\n", + "\n", + " torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for num_split in range(params['n_splits']):\n", + " idx_validate = train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].index.values\n", + " model_name,version =params['model_name'] , params['version']\n", + " new_model,base_model=get_model(params['model_name'])\n", + " model = new_model(base_model,\n", + " len(hemorrhage_types),\n", + " num_channels=3,\n", + " dropout=0.2,\n", + " wso=((40,80),(80,200),(40,400)),\n", + " dont_do_grad=[],\n", + " extra_pool=params['num_pool'],\n", + " )\n", + " model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n", + " _=model.to(device)\n", + " transform=MyTransform(mean_change=15,\n", + " std_change=0,\n", + " flip=True,\n", + " zoom=(0.2,0.2),\n", + " rotate=30,\n", + " out_size=512,\n", + " shift=0,\n", + " normal=False)\n", + " indexes=np.arange(train_df.shape[0]).repeat(4)\n", + " train_dataset=D.Subset(ImageDataset(train_df,transform=transform.random,base_path=train_images_dir,\n", + " window_eq=False,equalize=False,rescale=True),indexes)\n", + " pred,features = model_run(model,train_dataset,do_apex=True,batch_size=96,num_workers=14)\n", + "\n", + " pickle_file=open(outputs_dir+outputs_format.format(model_name,version,params['train_features'],num_split),'wb')\n", + " pickle.dump(features,pickle_file,protocol=4)\n", + " pickle_file.close()\n", + "\n", + " pickle_file=open(outputs_dir+outputs_format.format(model_name,version,params['train_prediction'],num_split),'wb')\n", + " pickle.dump(pred,pickle_file,protocol=4)\n", + " pickle_file.close()\n", + "\n", + "\n", + " my_loss(pred[(idx_validate*4+np.arange(4)[:,None]).transpose(1,0)].mean(1),\n", + " torch.tensor(train_df[hemorrhage_types].values[idx_validate],dtype=torch.float),\n", + " torch.tensor([1.,1.,1.,1.,1.,2.]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for num_split in range(params['n_splits']):\n", + " idx_validate = train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].index.values\n", + " model_name,version =params['model_name'] , params['version']\n", + " new_model,base_model=get_model(params['model_name'])\n", + " model = new_model(base_model,\n", + " len(hemorrhage_types),\n", + " num_channels=3,\n", + " dropout=0.2,\n", + " wso=((40,80),(80,200),(40,400)),\n", + " dont_do_grad=[],\n", + " extra_pool=params['num_pool'],\n", + " )\n", + " model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n", + " _=model.to(device)\n", + " transform=MyTransform(mean_change=15,\n", + " std_change=0,\n", + " flip=True,\n", + " zoom=(0.2,0.2),\n", + " rotate=30,\n", + " out_size=512,\n", + " shift=0,\n", + " normal=False)\n", + " indexes=np.arange(test_df.shape[0]).repeat(8)\n", + " imagedataset_test=D.Subset(ImageDataset(test_df,transform=transform.random,base_path=test_images_dir,\n", + " window_eq=False,equalize=False,rescale=True),indexes)\n", + " pred,features = model_run(model,imagedataset_test,do_apex=True,batch_size=96,num_workers=18)\n", + " pickle_file=open(outputs_dir+outputs_format.format(model_name,version,params['test_features'],num_split),'wb')\n", + " pickle.dump(features,pickle_file,protocol=4)\n", + " pickle_file.close()\n", + "\n", + " pickle_file=open(outputs_dir+outputs_format.format(model_name,version,params['test_prediction'],num_split),'wb')\n", + " pickle.dump(pred,pickle_file,protocol=4)\n", + " pickle_file.close()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## create submission file - for reference" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3ae2e5d2ec4d4919b47ab1bf3482807c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, max=3), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "torch.Size([78545, 24, 6])" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "preds=[]\n", + "for i in tqdm_notebook(range(params['n_splits'])):\n", + " model_name,version, num_split = params['model_name'] , params['version'],i\n", + " pickle_file=open(outputs_dir+outputs_format.format(model_name,version,params['test_prediction'],num_split),'rb')\n", + " pred=pickle.load(pickle_file)\n", + " pickle_file.close()\n", + " preds.append(pred[(np.arange(pred.shape[0]).reshape(pred.shape[0]//8,8))])\n", + "predss = torch.cat(preds,1)\n", + "predss.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>ID</th>\n", + " <th>Label</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>ID_000012eaf_any</td>\n", + " <td>0.012404</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>ID_000012eaf_epidural</td>\n", + " <td>0.000464</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>ID_000012eaf_intraparenchymal</td>\n", + " <td>0.001818</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>ID_000012eaf_intraventricular</td>\n", + " <td>0.000576</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>ID_000012eaf_subarachnoid</td>\n", + " <td>0.001655</td>\n", + " </tr>\n", + " <tr>\n", + " <th>5</th>\n", + " <td>ID_000012eaf_subdural</td>\n", + " <td>0.010707</td>\n", + " </tr>\n", + " <tr>\n", + " <th>6</th>\n", + " <td>ID_0000ca2f6_any</td>\n", + " <td>0.002507</td>\n", + " </tr>\n", + " <tr>\n", + " <th>7</th>\n", + " <td>ID_0000ca2f6_epidural</td>\n", + " <td>0.000038</td>\n", + " </tr>\n", + " <tr>\n", + " <th>8</th>\n", + " <td>ID_0000ca2f6_intraparenchymal</td>\n", + " <td>0.000540</td>\n", + " </tr>\n", + " <tr>\n", + " <th>9</th>\n", + " <td>ID_0000ca2f6_intraventricular</td>\n", + " <td>0.000080</td>\n", + " </tr>\n", + " <tr>\n", + " <th>10</th>\n", + " <td>ID_0000ca2f6_subarachnoid</td>\n", + " <td>0.000490</td>\n", + " </tr>\n", + " <tr>\n", + " <th>11</th>\n", + " <td>ID_0000ca2f6_subdural</td>\n", + " <td>0.001157</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " ID Label\n", + "0 ID_000012eaf_any 0.012404\n", + "1 ID_000012eaf_epidural 0.000464\n", + "2 ID_000012eaf_intraparenchymal 0.001818\n", + "3 ID_000012eaf_intraventricular 0.000576\n", + "4 ID_000012eaf_subarachnoid 0.001655\n", + "5 ID_000012eaf_subdural 0.010707\n", + "6 ID_0000ca2f6_any 0.002507\n", + "7 ID_0000ca2f6_epidural 0.000038\n", + "8 ID_0000ca2f6_intraparenchymal 0.000540\n", + "9 ID_0000ca2f6_intraventricular 0.000080\n", + "10 ID_0000ca2f6_subarachnoid 0.000490\n", + "11 ID_0000ca2f6_subdural 0.001157" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "(471270, 2)" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submission_df=get_submission(test_df,torch.sigmoid(predss).mean(1),False)\n", + "submission_df.head(12)\n", + "submission_df.shape\n", + "sub_num=999\n", + "submission_df.to_csv('/media/hd/notebooks/data/RSNA/submissions/submission{}.csv'.format(sub_num),\n", + " index=False, columns=['ID','Label'])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}