--- a
+++ b/Production/Post Full Head Models Train-focal features.ipynb
@@ -0,0 +1,20676 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/reina/anaconda3/envs/RSNA/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
+      "  return f(*args, **kwds)\n",
+      "/home/reina/anaconda3/envs/RSNA/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
+      "  return f(*args, **kwds)\n"
+     ]
+    }
+   ],
+   "source": [
+    "from __future__ import absolute_import\n",
+    "from __future__ import division\n",
+    "from __future__ import print_function\n",
+    "\n",
+    "\n",
+    "import numpy as np # linear algebra\n",
+    "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
+    "import os\n",
+    "import datetime\n",
+    "import seaborn as sns\n",
+    "\n",
+    "#import pydicom\n",
+    "import time\n",
+    "from functools import partial\n",
+    "import gc\n",
+    "import operator \n",
+    "import matplotlib.pyplot as plt\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.utils.data as D\n",
+    "import torch.nn.functional as F\n",
+    "from sklearn.model_selection import KFold\n",
+    "from tqdm import tqdm, tqdm_notebook\n",
+    "from IPython.core.interactiveshell import InteractiveShell\n",
+    "InteractiveShell.ast_node_interactivity = \"all\"\n",
+    "import warnings\n",
+    "warnings.filterwarnings(action='once')\n",
+    "import pickle\n",
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "%matplotlib inline\n",
+    "from skimage.io import imread,imshow\n",
+    "from helper import *\n",
+    "import helper\n",
+    "import torchvision.models as models\n",
+    "from torch.optim import Adam\n",
+    "from defenitions import *"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SEED = 432\n",
+    "device=device_by_name(\"Tesla\")\n",
+    "#device=device_by_name(\"RTX\")\n",
+    "#device=device_by_name(\"1060\")\n",
+    "torch.cuda.set_device(device)\n",
+    "#device = \"cpu\"\n",
+    "sendmeemail=Email_Progress(my_gmail,my_pass,to_email,'Densenet161-Copy2-2 results')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_submission(test_df,pred):\n",
+    "    epidural_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_epidural','Label':torch.sigmoid(pred[:,0])})\n",
+    "    intraparenchymal_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_intraparenchymal','Label':torch.sigmoid(pred[:,1])})\n",
+    "    intraventricular_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_intraventricular','Label':torch.sigmoid(pred[:,2])})\n",
+    "    subarachnoid_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_subarachnoid','Label':torch.sigmoid(pred[:,3])})\n",
+    "    subdural_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_subdural','Label':torch.sigmoid(pred[:,4])})\n",
+    "    any_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_any','Label':torch.sigmoid(pred[:,5])}) \n",
+    "    return pd.concat([epidural_df,\n",
+    "                        intraparenchymal_df,\n",
+    "                        intraventricular_df,\n",
+    "                        subarachnoid_df,\n",
+    "                        subdural_df,\n",
+    "                        any_df]).sort_values('ID').reset_index(drop=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_submission_ids(image_ids,pred,do_sigmoid=True):\n",
+    "    if do_sigmoid:\n",
+    "        func = lambda x:torch.sigmoid(x)\n",
+    "    else:\n",
+    "        func = lambda x:x\n",
+    "    epidural_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_epidural','Label':func(pred[:,0])})\n",
+    "    intraparenchymal_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_intraparenchymal','Label':func(pred[:,1])})\n",
+    "    intraventricular_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_intraventricular','Label':func(pred[:,2])})\n",
+    "    subarachnoid_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_subarachnoid','Label':func(pred[:,3])})\n",
+    "    subdural_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_subdural','Label':func(pred[:,4])})\n",
+    "    any_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_any','Label':func(pred[:,5])}) \n",
+    "    return pd.concat([epidural_df,\n",
+    "                        intraparenchymal_df,\n",
+    "                        intraventricular_df,\n",
+    "                        subarachnoid_df,\n",
+    "                        subdural_df,\n",
+    "                        any_df]).sort_values('ID').reset_index(drop=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(674510, 15)"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(674252, 15)"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>PatientID</th>\n",
+       "      <th>epidural</th>\n",
+       "      <th>intraparenchymal</th>\n",
+       "      <th>intraventricular</th>\n",
+       "      <th>subarachnoid</th>\n",
+       "      <th>subdural</th>\n",
+       "      <th>any</th>\n",
+       "      <th>PID</th>\n",
+       "      <th>StudyI</th>\n",
+       "      <th>SeriesI</th>\n",
+       "      <th>WindowCenter</th>\n",
+       "      <th>WindowWidth</th>\n",
+       "      <th>ImagePositionZ</th>\n",
+       "      <th>ImagePositionX</th>\n",
+       "      <th>ImagePositionY</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>63eb1e259</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>a449357f</td>\n",
+       "      <td>62d125e5b2</td>\n",
+       "      <td>0be5c0d1b3</td>\n",
+       "      <td>['00036', '00036']</td>\n",
+       "      <td>['00080', '00080']</td>\n",
+       "      <td>180.199951</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-8.000000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>2669954a7</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>363d5865</td>\n",
+       "      <td>a20b80c7bf</td>\n",
+       "      <td>3564d584db</td>\n",
+       "      <td>['00047', '00047']</td>\n",
+       "      <td>['00080', '00080']</td>\n",
+       "      <td>922.530821</td>\n",
+       "      <td>-156.0</td>\n",
+       "      <td>45.572849</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>52c9913b1</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>9c2b4bd7</td>\n",
+       "      <td>3e3634f8cf</td>\n",
+       "      <td>973274ffc9</td>\n",
+       "      <td>40</td>\n",
+       "      <td>150</td>\n",
+       "      <td>4.455000</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-115.063000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>4e6ff6126</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>3ae81c2d</td>\n",
+       "      <td>a1390c15c2</td>\n",
+       "      <td>e5ccad8244</td>\n",
+       "      <td>['00036', '00036']</td>\n",
+       "      <td>['00080', '00080']</td>\n",
+       "      <td>100.000000</td>\n",
+       "      <td>-99.5</td>\n",
+       "      <td>28.500000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>7858edd88</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>c1867feb</td>\n",
+       "      <td>c73e81ed3a</td>\n",
+       "      <td>28e0531b3a</td>\n",
+       "      <td>40</td>\n",
+       "      <td>100</td>\n",
+       "      <td>145.793000</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-132.190000</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "   PatientID  epidural  intraparenchymal  intraventricular  subarachnoid  \\\n",
+       "0  63eb1e259         0                 0                 0             0   \n",
+       "1  2669954a7         0                 0                 0             0   \n",
+       "2  52c9913b1         0                 0                 0             0   \n",
+       "3  4e6ff6126         0                 0                 0             0   \n",
+       "4  7858edd88         0                 0                 0             0   \n",
+       "\n",
+       "   subdural  any       PID      StudyI     SeriesI        WindowCenter  \\\n",
+       "0         0    0  a449357f  62d125e5b2  0be5c0d1b3  ['00036', '00036']   \n",
+       "1         0    0  363d5865  a20b80c7bf  3564d584db  ['00047', '00047']   \n",
+       "2         0    0  9c2b4bd7  3e3634f8cf  973274ffc9                  40   \n",
+       "3         0    0  3ae81c2d  a1390c15c2  e5ccad8244  ['00036', '00036']   \n",
+       "4         0    0  c1867feb  c73e81ed3a  28e0531b3a                  40   \n",
+       "\n",
+       "          WindowWidth  ImagePositionZ  ImagePositionX  ImagePositionY  \n",
+       "0  ['00080', '00080']      180.199951          -125.0       -8.000000  \n",
+       "1  ['00080', '00080']      922.530821          -156.0       45.572849  \n",
+       "2                 150        4.455000          -125.0     -115.063000  \n",
+       "3  ['00080', '00080']      100.000000           -99.5       28.500000  \n",
+       "4                 100      145.793000          -125.0     -132.190000  "
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "train_df = pd.read_csv(data_dir+'train.csv')\n",
+    "train_df.shape\n",
+    "train_df=train_df[~train_df.PatientID.isin(bad_images)].reset_index(drop=True)\n",
+    "train_df=train_df.drop_duplicates().reset_index(drop=True)\n",
+    "train_df.shape\n",
+    "train_df.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>PatientID</th>\n",
+       "      <th>epidural</th>\n",
+       "      <th>intraparenchymal</th>\n",
+       "      <th>intraventricular</th>\n",
+       "      <th>subarachnoid</th>\n",
+       "      <th>subdural</th>\n",
+       "      <th>any</th>\n",
+       "      <th>SeriesI</th>\n",
+       "      <th>PID</th>\n",
+       "      <th>StudyI</th>\n",
+       "      <th>WindowCenter</th>\n",
+       "      <th>WindowWidth</th>\n",
+       "      <th>ImagePositionZ</th>\n",
+       "      <th>ImagePositionX</th>\n",
+       "      <th>ImagePositionY</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>28fbab7eb</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>ebfd7e4506</td>\n",
+       "      <td>cf1b6b11</td>\n",
+       "      <td>93407cadbb</td>\n",
+       "      <td>30</td>\n",
+       "      <td>80</td>\n",
+       "      <td>158.458000</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-135.598000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>877923b8b</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>6d95084e15</td>\n",
+       "      <td>ad8ea58f</td>\n",
+       "      <td>a337baa067</td>\n",
+       "      <td>30</td>\n",
+       "      <td>80</td>\n",
+       "      <td>138.729050</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-101.797981</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>a591477cb</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>8e06b2c9e0</td>\n",
+       "      <td>ecfb278b</td>\n",
+       "      <td>0cfe838d54</td>\n",
+       "      <td>30</td>\n",
+       "      <td>80</td>\n",
+       "      <td>60.830002</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-133.300003</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>42217c898</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>e800f419cf</td>\n",
+       "      <td>e96e31f4</td>\n",
+       "      <td>c497ac5bad</td>\n",
+       "      <td>30</td>\n",
+       "      <td>80</td>\n",
+       "      <td>55.388000</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-146.081000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>a130c4d2f</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>faeb7454f3</td>\n",
+       "      <td>69affa42</td>\n",
+       "      <td>854e4fbc01</td>\n",
+       "      <td>30</td>\n",
+       "      <td>80</td>\n",
+       "      <td>33.516888</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-118.689819</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "   PatientID  epidural  intraparenchymal  intraventricular  subarachnoid  \\\n",
+       "0  28fbab7eb       0.5               0.5               0.5           0.5   \n",
+       "1  877923b8b       0.5               0.5               0.5           0.5   \n",
+       "2  a591477cb       0.5               0.5               0.5           0.5   \n",
+       "3  42217c898       0.5               0.5               0.5           0.5   \n",
+       "4  a130c4d2f       0.5               0.5               0.5           0.5   \n",
+       "\n",
+       "   subdural  any     SeriesI       PID      StudyI WindowCenter WindowWidth  \\\n",
+       "0       0.5  0.5  ebfd7e4506  cf1b6b11  93407cadbb           30          80   \n",
+       "1       0.5  0.5  6d95084e15  ad8ea58f  a337baa067           30          80   \n",
+       "2       0.5  0.5  8e06b2c9e0  ecfb278b  0cfe838d54           30          80   \n",
+       "3       0.5  0.5  e800f419cf  e96e31f4  c497ac5bad           30          80   \n",
+       "4       0.5  0.5  faeb7454f3  69affa42  854e4fbc01           30          80   \n",
+       "\n",
+       "   ImagePositionZ  ImagePositionX  ImagePositionY  \n",
+       "0      158.458000          -125.0     -135.598000  \n",
+       "1      138.729050          -125.0     -101.797981  \n",
+       "2       60.830002          -125.0     -133.300003  \n",
+       "3       55.388000          -125.0     -146.081000  \n",
+       "4       33.516888          -125.0     -118.689819  "
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "test_df = pd.read_csv(data_dir+'test.csv')\n",
+    "test_df.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class PosSeriesSampler():\n",
+    "    def __init__(self,df,split,pos_r,neg_r):\n",
+    "        self.df=df[df.SeriesI.isin(split)]\n",
+    "        self.pos_r=max(min(pos_r,1.0),0)\n",
+    "        self.neg_r=max(min(neg_r,1.0),0)\n",
+    "        cols=['SeriesI','any']\n",
+    "        self.series_df=self.df[cols].groupby('SeriesI').sum().reset_index()\n",
+    "        self.series_df['any']=self.series_df['any']>0\n",
+    "        self.sids=self.series_df.SeriesI.values\n",
+    "        self.any_col=self.series_df['any'].values\n",
+    "        self.any_col=self.align(self.any_col,split,self.sids)\n",
+    "        self.pos=np.where(self.any_col)[0]\n",
+    "        self.neg=np.where(~self.any_col)[0]\n",
+    "\n",
+    "    def align(self,arr,index1,index2):\n",
+    "        return arr[np.argsort(index2)[np.argsort(np.argsort(index1))]]\n",
+    "    def __call__(self):\n",
+    "        tmpp=self.pos.copy()\n",
+    "        np.random.shuffle(tmpp)\n",
+    "        tmpn=self.neg.copy()\n",
+    "        np.random.shuffle(tmpn)\n",
+    "        return np.concatenate([tmpp[:int(self.pos_r*tmpp.shape[0])],tmpn[:int(self.neg_r*tmpn.shape[0])]])\n",
+    "        \n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "split_sid = train_df.PID.unique()\n",
+    "splits=list(KFold(n_splits=5,shuffle=True, random_state=SEED).split(split_sid))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def my_loss(y_pred,y_true,weights):\n",
+    "    window=(y_true>=0).to(torch.float)\n",
+    "    loss = (F.binary_cross_entropy_with_logits(y_pred,y_true,reduction='none')*window*weights.expand_as(y_true)).mean()/(window.mean()+1e-7)\n",
+    "    return loss"
+   ]
+  },
+  {
+   "cell_type": "raw",
+   "metadata": {},
+   "source": [
+    "def my_loss(y_pred,y_true,weights):\n",
+    "    if len(y_pred.shape)==len(y_true.shape):\n",
+    "        window=(y_true>=0).to(torch.float)\n",
+    "        loss = (F.binary_cross_entropy_with_logits(y_pred,y_true,reduction='none')*window*weights.expand_as(y_true)).mean()/(window.mean()+1e-7)\n",
+    "    else:\n",
+    "        window0=(y_true[...,0]>=0).to(torch.float)\n",
+    "        window1=(y_true[...,1]>=0).to(torch.float)\n",
+    "        loss0 = F.binary_cross_entropy_with_logits(y_pred,y_true[...,0],reduction='none')*window0*weights.expand_as(y_true[...,0])/(window0.mean()+1e-7)\n",
+    "        loss1 = F.binary_cross_entropy_with_logits(y_pred,y_true[...,1],reduction='none')*window1*weights.expand_as(y_true[...,1])/(window1.mean()+1e-7)\n",
+    "        loss = (y_true[...,2]*loss0+(1.0-y_true[...,2])*loss1).mean() \n",
+    "    return loss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Metric():\n",
+    "    def __init__(self,weights,k=0.03):\n",
+    "        self.weights=weights\n",
+    "        self.k=k\n",
+    "        self.zero()\n",
+    "        \n",
+    "    def zero(self):\n",
+    "        self.loss_sum=0.\n",
+    "        self.loss_count=0.\n",
+    "        self.lossf=0.\n",
+    "        \n",
+    "    def calc(self,y_pred,y_true,prefix=\"\"):\n",
+    "        window=(y_true>=0).to(torch.float)\n",
+    "        loss = (F.binary_cross_entropy_with_logits(y_pred,y_true,reduction='none')*window*self.weights.expand_as(y_true)).mean()/(window.mean()+1e-5)\n",
+    "        self.lossf=self.lossf*(1-self.k)+loss*self.k\n",
+    "        self.loss_sum=self.loss_sum+loss*window.sum()\n",
+    "        self.loss_count=self.loss_count+window.sum()\n",
+    "        return({prefix+'mloss':self.lossf})    \n",
+    "        \n",
+    "    def calc_sums(self,prefix=\"\"):\n",
+    "        return({prefix+'mloss_tot':self.loss_sum/self.loss_count})    \n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#features=(features-features.mean())/features.std()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class SimpleModel(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(SimpleModel, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 128, (7,in_size), padding=(3,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(128, 128, 5, padding=2)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 3, padding=1)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(64, 6, 3, padding=1)\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x = self.conv2d1(x.unsqueeze(1)).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x = self.relu0(x)\n",
+    "        x = self.conv1d1(x)\n",
+    "        x = self.bn1(x)\n",
+    "        x = self.relu1(x)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x = self.relu2(x)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class SimpleModel2(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(SimpleModel2, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 128, (9,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(128, 128, 7, padding=3)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 5, padding=2)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(64, 6, 3, padding=1)\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x = self.conv2d1(x.unsqueeze(1)).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x = self.relu0(x)\n",
+    "        x = self.conv1d1(x)\n",
+    "        x = self.bn1(x)\n",
+    "        x = self.relu1(x)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x = self.relu2(x)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ClassModel(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(ClassModel, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 128, (9,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(128, 128, 7, padding=3)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 5, padding=2)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(128, 6, 3, padding=1)\n",
+    "        \n",
+    "        self.conv2d1class=torch.nn.Conv2d(1, 128, (9,in_size), padding=(4,0))\n",
+    "        self.bn0class=torch.nn.BatchNorm1d(128)\n",
+    "        self.maxpool1class=torch.nn.MaxPool1d(3)\n",
+    "        self.conv1d1class=torch.nn.Conv1d(128, 128, 3, padding=1)\n",
+    "        self.bn1class=torch.nn.BatchNorm1d(128)\n",
+    "        self.maxpool2class=torch.nn.MaxPool1d(3)\n",
+    "        self.conv1d2class=torch.nn.Conv1d(128, 64, 2, padding=1)\n",
+    "        self.bn2class=torch.nn.BatchNorm1d(64)\n",
+    "\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        z=x\n",
+    "        x = self.conv2d1(x.unsqueeze(1)).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x = self.relu0(x)\n",
+    "        x = self.conv1d1(x)\n",
+    "        x = self.bn1(x)\n",
+    "        x = self.relu1(x)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x = self.relu2(x)\n",
+    "        z=self.conv2d1class(z.unsqueeze(1)).squeeze(-1)\n",
+    "        z=self.bn0class(z)\n",
+    "        z=self.maxpool1class(z)\n",
+    "        z=self.conv1d1class(z)\n",
+    "        z=self.maxpool2class(z)\n",
+    "        z=self.conv1d2class(z)\n",
+    "        z=self.bn2class(z)\n",
+    "        z=F.max_pool1d(z,kernel_size=z.shape[-1])\n",
+    "        z=z.expand_as(x)\n",
+    "        x=torch.cat([x,z],1)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ResModel(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(ResModel, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 64, (9,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(64, 64, 7, padding=3)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 5, padding=2)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(192, 6, 3, padding=1)\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x=x.unsqueeze(1)\n",
+    "        x = self.conv2d1(x).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x0 = self.relu0(x)\n",
+    "        x = self.conv1d1(x0)\n",
+    "        x = self.bn1(x)\n",
+    "        x1 = self.relu1(x)\n",
+    "        x = torch.cat([x0,x1],1)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x2 = self.relu2(x)\n",
+    "        x = torch.cat([x0,x1,x2],1)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ResModelPool(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(ResModelPool, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 64, (9,in_size),stride=(1,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(64)\n",
+    "#        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(64, 64, 7, padding=3)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 5, padding=2)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(192, 6, 3, padding=1)\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x=x.unsqueeze(1)\n",
+    "        x = self.conv2d1(x)\n",
+    "        x=F.max_pool2d(x,kernel_size=(1,x.shape[-1])).squeeze(-1)        \n",
+    "        x0 = self.bn0(x)\n",
+    "#        x0 = self.relu0(x)\n",
+    "        x = self.conv1d1(x0)\n",
+    "        x = self.bn1(x)\n",
+    "        x1 = self.relu1(x)\n",
+    "        x = torch.cat([x0,x1],1)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x2 = self.relu2(x)\n",
+    "        x = torch.cat([x0,x1,x2],1)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ResDropModel(nn.Module):\n",
+    "    def __init__(self,in_size,dropout=0.2):\n",
+    "        super(ResDropModel, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 64, (9,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(64, 64, 7, padding=3)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 5, padding=2)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(192, 6, 3, padding=1)\n",
+    "        self.dropout=dropout\n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x = self.conv2d1(x.unsqueeze(1)).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x = F.dropout(x,self.dropout)\n",
+    "        x0 = self.relu0(x) \n",
+    "        x = self.conv1d1(x0)\n",
+    "#        x = self.bn1(x)\n",
+    "        x = F.dropout(x,self.dropout)\n",
+    "        x1 = self.relu1(x)\n",
+    "        x = torch.cat([x0,x1],1)\n",
+    "        x = self.conv1d2(x)\n",
+    "#        x = self.bn2(x)\n",
+    "        x = F.dropout(x,self.dropout)\n",
+    "        x2 = self.relu2(x)\n",
+    "        x = torch.cat([x0,x1,x2],1)\n",
+    "        x = F.dropout(x,self.dropout)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fn = partial(torch.clamp,min=0,max=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class GenReLU(nn.Module):\n",
+    "    def __init__(self,leak=0,add=0,clamp=None):\n",
+    "        super(GenReLU, self).__init__()\n",
+    "        self.leak,self.add=leak,add\n",
+    "        if isinstance(clamp,tuple):\n",
+    "            self.clamp = partial(torch.clamp,min=clamp[0],max=clamp[1])\n",
+    "        elif clamp:\n",
+    "            self.clamp = partial(torch.clamp,min=-clamp,max=clamp)\n",
+    "        else:\n",
+    "            self.clamp=None\n",
+    "            \n",
+    "        \n",
+    "    def forward(self,x):\n",
+    "        x = F.leaky_relu(x,self.leak)\n",
+    "        if self.add:\n",
+    "            x=x+self.add\n",
+    "        if self.clamp:\n",
+    "            x = self.clamp(x)\n",
+    "        return x\n",
+    "        \n",
+    "        "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "class ResModelIn(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(ResModelIn, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 64, (9,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(64, 64, 3, padding=1)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 3, padding=1)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(192, 6, 3, padding=1)\n",
+    "\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x = x.unsqueeze(1)\n",
+    "        x = self.conv2d1(x).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x0 = self.relu0(x)\n",
+    "        x = self.conv1d1(x0)\n",
+    "        x = self.bn1(x)\n",
+    "        x1 = self.relu1(x)\n",
+    "        x = torch.cat([x0,x1],1)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x2 = self.relu2(x)\n",
+    "        x = torch.cat([x0,x1,x2],1)\n",
+    "        x = self.conv1d3(x)\n",
+    "        out = x.transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "class ResModelInR(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(ResModelInR, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 64, (9,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu0= GenReLU(leak=0.01,add=-0.3,clamp=(-1,3))\n",
+    "        self.conv1d1=torch.nn.Conv1d(64, 64, 3, padding=1)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu1=GenReLU(leak=0.01,add=-0.3,clamp=(-1,3))\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 3, padding=1)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=GenReLU(leak=0,add=0,clamp=(0,2))\n",
+    "        self.conv1d3=torch.nn.Conv1d(192, 6, 3, padding=1)\n",
+    "\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x = x.unsqueeze(1)\n",
+    "        x = self.conv2d1(x).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x0 = self.relu0(x)\n",
+    "        x = self.conv1d1(x0)\n",
+    "        x = self.bn1(x)\n",
+    "        x1 = self.relu1(x)\n",
+    "        x = torch.cat([x0,x1],1)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x2 = self.relu2(x)\n",
+    "        x = torch.cat([x0,x1,x2],1)\n",
+    "        x = self.conv1d3(x)\n",
+    "        out = x.transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class BaseModel(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super(BaseModel, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 128, (5,2208), padding=(2,0))\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(128, 6, 1)\n",
+    "\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x = self.conv2d1(x.unsqueeze(1)).squeeze(-1)\n",
+    "        x = self.relu0(x)\n",
+    "        out = self.conv1d1(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnext101_32x4d new_splits_focal 0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f82c9b2e7b0>"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.067308146525303, 'mloss': tensor(0.0673), 'val_loss': 0.07460971067906642, 'val_mloss_tot': tensor(0.0745)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06447429880488798, 'mloss': tensor(0.0628), 'val_loss': 0.07257415679644565, 'val_mloss_tot': tensor(0.0725)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060776570781693275, 'mloss': tensor(0.0581), 'val_loss': 0.07166337267476686, 'val_mloss_tot': tensor(0.0716)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06255646532615573, 'mloss': tensor(0.0640), 'val_loss': 0.06990248672664165, 'val_mloss_tot': tensor(0.0699)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06148951418209343, 'mloss': tensor(0.0646), 'val_loss': 0.0703263329874192, 'val_mloss_tot': tensor(0.0703)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05986658480128078, 'mloss': tensor(0.0607), 'val_loss': 0.070958069392613, 'val_mloss_tot': tensor(0.0709)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055580468322294796, 'mloss': tensor(0.0550), 'val_loss': 0.0701158703038735, 'val_mloss_tot': tensor(0.0701)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059968966227883905, 'mloss': tensor(0.0623), 'val_loss': 0.07003260549461963, 'val_mloss_tot': tensor(0.0700)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05783948200028065, 'mloss': tensor(0.0573), 'val_loss': 0.06955109893211296, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05674921379642664, 'mloss': tensor(0.0564), 'val_loss': 0.06888324899827035, 'val_mloss_tot': tensor(0.0688)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05789361624488136, 'mloss': tensor(0.0612), 'val_loss': 0.06906770919544661, 'val_mloss_tot': tensor(0.0690)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06456862591634674, 'mloss': tensor(0.0735), 'val_loss': 0.06873089746210952, 'val_mloss_tot': tensor(0.0687)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05913774677925313, 'mloss': tensor(0.0626), 'val_loss': 0.06975699934956371, 'val_mloss_tot': tensor(0.0697)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056922164049918886, 'mloss': tensor(0.0534), 'val_loss': 0.07042677809427283, 'val_mloss_tot': tensor(0.0704)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05989361404293323, 'mloss': tensor(0.0580), 'val_loss': 0.06937360586978647, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06046491604754402, 'mloss': tensor(0.0613), 'val_loss': 0.06932684786579743, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05552499091811479, 'mloss': tensor(0.0536), 'val_loss': 0.06890473218862804, 'val_mloss_tot': tensor(0.0689)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05803742831475861, 'mloss': tensor(0.0587), 'val_loss': 0.06864881016009924, 'val_mloss_tot': tensor(0.0686)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056701664379929524, 'mloss': tensor(0.0559), 'val_loss': 0.06993526953025436, 'val_mloss_tot': tensor(0.0699)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05597085406286244, 'mloss': tensor(0.0553), 'val_loss': 0.06879630652921541, 'val_mloss_tot': tensor(0.0688)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05543938604505698, 'mloss': tensor(0.0556), 'val_loss': 0.068675661115546, 'val_mloss_tot': tensor(0.0686)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055983229030143716, 'mloss': tensor(0.0528), 'val_loss': 0.06827947101460732, 'val_mloss_tot': tensor(0.0682)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055272456923741504, 'mloss': tensor(0.0555), 'val_loss': 0.06822937811363716, 'val_mloss_tot': tensor(0.0682)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057854696961629044, 'mloss': tensor(0.0576), 'val_loss': 0.06831360294730687, 'val_mloss_tot': tensor(0.0683)}\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "(0.06820877346837399, {'val_loss': 0.06820877346837399, 'val_mloss_tot': tensor(0.0682)})\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "for num_split in [0]: #range(5):\n",
+    "    multi=3\n",
+    "    model_name,version = 'se_resnext101_32x4d' , 'new_splits_focal'\n",
+    "    print (model_name,version,num_split)\n",
+    "    pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta',num_split),'rb')\n",
+    "    features=pickle.load(pickle_file)\n",
+    "    pickle_file.close()\n",
+    "    features.shape\n",
+    "\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)                \n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)                \n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "#    sampling=PosSeriesSampler(train_df,split_train,1,0.75)\n",
+    "    sample_ratio= 1.0 #1.01*float(sampling().shape[0])/split_train.shape[0]\n",
+    "    num_train_optimization_steps = sample_ratio*num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnext101_32x4d new_splits_focal 1\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fc17d832bb0>"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "151b551e5bd8486ebc8c003ec810c640",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "45b6a07b15bf47dba1862c2a48d54d7c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e11f3ec29b374f748a7252f526c0bf76",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06608332269690445, 'mloss': tensor(0.0675), 'val_loss': 0.07562760766595603, 'val_mloss_tot': tensor(0.0755)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5063b439b2a74c26abd484615f68e9df",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bd45279772d8474eb7ad02aaf353bfa5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06423676105824865, 'mloss': tensor(0.0646), 'val_loss': 0.07186336961652463, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "290c1a8698f44e91bd24452ee387a596",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0cda1650e07446fa9b3ef660d95e866b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06675321332334092, 'mloss': tensor(0.0683), 'val_loss': 0.07225884847575799, 'val_mloss_tot': tensor(0.0721)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1b85340ed3484402aaf24874342715b5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "de0b50181c8d48b5a1eb59e2129668b7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06393128845837232, 'mloss': tensor(0.0677), 'val_loss': 0.07105442780690889, 'val_mloss_tot': tensor(0.0709)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fd35dd89e50044638b6214a03bca338b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c7487ffd743745acada18c2c38b09b7d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06005979180662095, 'mloss': tensor(0.0626), 'val_loss': 0.06986814809885497, 'val_mloss_tot': tensor(0.0698)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a23a697440494261a9890197b5a11bb3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3514ff8c0c6a4e81b5f9476c56745e0d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06153353077037748, 'mloss': tensor(0.0604), 'val_loss': 0.07069476890998581, 'val_mloss_tot': tensor(0.0706)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "85ffa59141674f64997b793c1390a409",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "743d3c2b925f46e0a8fda435926012b8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06221874830854907, 'mloss': tensor(0.0633), 'val_loss': 0.07006131876648093, 'val_mloss_tot': tensor(0.0700)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4f699a110cb44b8cbe6922d6accf8c88",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "24ec4969c0c241be91f0608be2318ea0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05960000938803774, 'mloss': tensor(0.0590), 'val_loss': 0.07034418839029968, 'val_mloss_tot': tensor(0.0702)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "05b4c9fbfc7f4b6cb877b64bfd7320e9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "755861a35d814cffa5d2a7e49f12c856",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06002905661236823, 'mloss': tensor(0.0602), 'val_loss': 0.06938419877551497, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3a899bbee92e48e688fafd435ba5186f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f2298bed472c42878125ed0499f0c325",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06009711871682274, 'mloss': tensor(0.0593), 'val_loss': 0.06914957696959997, 'val_mloss_tot': tensor(0.0690)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "69dcabc65c04485d9d85eb613fe2e0df",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4a536206405e43ae938594fc58cf7808",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060048284353773074, 'mloss': tensor(0.0598), 'val_loss': 0.06921839565814783, 'val_mloss_tot': tensor(0.0691)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1c1d4c5fdd704a12b048205bcbe5d8ac",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3b35c38593f04d7db2f1097503e30580",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05864350766645002, 'mloss': tensor(0.0573), 'val_loss': 0.06952748630816738, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9189cf8d596c4ae19b6b72e2650e3a1d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "14c9215c4436486f86a83c04174b08e0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05971179006537074, 'mloss': tensor(0.0593), 'val_loss': 0.07106425389647483, 'val_mloss_tot': tensor(0.0710)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "de9e7f11c2514e21999a4963690c3807",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e7d01da554af462f9e28315a0aa027e0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060222424369697256, 'mloss': tensor(0.0577), 'val_loss': 0.070545250926322, 'val_mloss_tot': tensor(0.0705)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "950ecf0d9ae945918d5d5c1f3514aa78",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c6f85075f2b241eea9d0b0f7feff0d68",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060113432986098475, 'mloss': tensor(0.0549), 'val_loss': 0.07056683046976105, 'val_mloss_tot': tensor(0.0705)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "619527a3dd954300923ddc96bd59db5d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "31fae09da7c64044a17ddf44f8adbf83",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060484712198020284, 'mloss': tensor(0.0637), 'val_loss': 0.06944521171196054, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a56764337e4c4ccdad9a19e7ca6fc3f0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8e25f7edf76b4b9285a4687da8ead522",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06327848210394939, 'mloss': tensor(0.0642), 'val_loss': 0.07001245824697738, 'val_mloss_tot': tensor(0.0699)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "153b81663ea64ccdb2bf215226b40e96",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f75f2bc6097b4434a07d97db04b065d9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06144748066222469, 'mloss': tensor(0.0615), 'val_loss': 0.0695960230116422, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "997716f1b28546da8b8223d84689e798",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e4ae0825df934278b6d1433af86e2a72",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0593364982521064, 'mloss': tensor(0.0605), 'val_loss': 0.06964275708111624, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "60944f9053ea40218169c47f60baae0a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "344d63f34f4c409d938f19a185a80773",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059294644648837144, 'mloss': tensor(0.0599), 'val_loss': 0.06983057542626435, 'val_mloss_tot': tensor(0.0697)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "567af8d8ff6f40b98c228457b1ab11fe",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5325c953b1eb4ba29f1c7afb6803b12b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05712647444272771, 'mloss': tensor(0.0580), 'val_loss': 0.06951607308195283, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f818e26ab12146e4af06caffe4ee32dd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d2c3e0b52dea4e6bad973ef1208d2d86",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06183372723698266, 'mloss': tensor(0.0638), 'val_loss': 0.06949221545364707, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2858c2295ab4404d84be220c8080e898",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7f1ef429112140f9a7dca6736efd608f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058854192375465825, 'mloss': tensor(0.0561), 'val_loss': 0.06922740353426586, 'val_mloss_tot': tensor(0.0691)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "375e3ab8a4b64683baa1bb8a9fcd895f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "36f42b7d319a4c62a5b23c02de74b2a8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058280917904031275, 'mloss': tensor(0.0580), 'val_loss': 0.06973961847058187, 'val_mloss_tot': tensor(0.0696)}\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1b127b8d9aa0410f8390bb61de978934",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "0.06914957696959997\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "for num_split in [1]: #range(5):\n",
+    "    multi=3\n",
+    "    model_name,version = 'se_resnext101_32x4d' , 'new_splits_focal'\n",
+    "    print (model_name,version,num_split)\n",
+    "    pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta',num_split),'rb')\n",
+    "    features=pickle.load(pickle_file)\n",
+    "    pickle_file.close()\n",
+    "    features.shape\n",
+    "\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)                \n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)                \n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "#    sampling=PosSeriesSampler(train_df,split_train,1,0.75)\n",
+    "    sample_ratio= 1.0 #1.01*float(sampling().shape[0])/split_train.shape[0]\n",
+    "    num_train_optimization_steps = sample_ratio*num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnext101_32x4d new_splits_focal 2\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f82c9b2e7b0>"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06811236130900705, 'mloss': tensor(0.0645), 'val_loss': 0.07296392299717495, 'val_mloss_tot': tensor(0.0729)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06484684772493185, 'mloss': tensor(0.0617), 'val_loss': 0.07182952273365578, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06794251192230603, 'mloss': tensor(0.0716), 'val_loss': 0.07096840750788025, 'val_mloss_tot': tensor(0.0709)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06221648544032464, 'mloss': tensor(0.0616), 'val_loss': 0.07048706298973034, 'val_mloss_tot': tensor(0.0704)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06506258486547185, 'mloss': tensor(0.0633), 'val_loss': 0.07060831906990363, 'val_mloss_tot': tensor(0.0706)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06387764076858017, 'mloss': tensor(0.0631), 'val_loss': 0.06990973167272233, 'val_mloss_tot': tensor(0.0699)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061879498025504805, 'mloss': tensor(0.0605), 'val_loss': 0.07004213785971346, 'val_mloss_tot': tensor(0.0700)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06117829148018609, 'mloss': tensor(0.0626), 'val_loss': 0.06998244177565159, 'val_mloss_tot': tensor(0.0700)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06260664986052528, 'mloss': tensor(0.0603), 'val_loss': 0.06963591108679289, 'val_mloss_tot': tensor(0.0696)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06240710515919824, 'mloss': tensor(0.0639), 'val_loss': 0.06943066220594804, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06785973891998759, 'mloss': tensor(0.0759), 'val_loss': 0.06950196877747834, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061713547811765246, 'mloss': tensor(0.0610), 'val_loss': 0.06934147141119729, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06595809476685713, 'mloss': tensor(0.0685), 'val_loss': 0.07012372287480455, 'val_mloss_tot': tensor(0.0701)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06549197154344466, 'mloss': tensor(0.0669), 'val_loss': 0.0700294252198476, 'val_mloss_tot': tensor(0.0700)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06389820492204888, 'mloss': tensor(0.0645), 'val_loss': 0.07042309583206227, 'val_mloss_tot': tensor(0.0704)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06161852050980389, 'mloss': tensor(0.0608), 'val_loss': 0.06996223300710501, 'val_mloss_tot': tensor(0.0699)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06161852050980389, 'mloss': tensor(0.0608), 'val_loss': 0.06996223300710501, 'val_mloss_tot': tensor(0.0699)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05957129002817653, 'mloss': tensor(0.0593), 'val_loss': 0.0694331979597749, 'val_mloss_tot': tensor(0.0694)}\n",
+      "{'loss': 0.05957129002817653, 'mloss': tensor(0.0593), 'val_loss': 0.0694331979597749, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06426514238477328, 'mloss': tensor(0.0648), 'val_loss': 0.0692660001555254, 'val_mloss_tot': tensor(0.0692)}\n",
+      "{'loss': 0.06426514238477328, 'mloss': tensor(0.0648), 'val_loss': 0.0692660001555254, 'val_mloss_tot': tensor(0.0692)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05897115044344752, 'mloss': tensor(0.0572), 'val_loss': 0.06915576889100465, 'val_mloss_tot': tensor(0.0691)}\n",
+      "{'loss': 0.05897115044344752, 'mloss': tensor(0.0572), 'val_loss': 0.06915576889100465, 'val_mloss_tot': tensor(0.0691)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06303324305739905, 'mloss': tensor(0.0640), 'val_loss': 0.06928441191454104, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06303324305739905, 'mloss': tensor(0.0640), 'val_loss': 0.06928441191454104, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05982911799751764, 'mloss': tensor(0.0626), 'val_loss': 0.06928381115680764, 'val_mloss_tot': tensor(0.0693)}\n",
+      "{'loss': 0.05982911799751764, 'mloss': tensor(0.0626), 'val_loss': 0.06928381115680764, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06297458258909464, 'mloss': tensor(0.0649), 'val_loss': 0.06922990236401196, 'val_mloss_tot': tensor(0.0692)}\n",
+      "{'loss': 0.06297458258909464, 'mloss': tensor(0.0649), 'val_loss': 0.06922990236401196, 'val_mloss_tot': tensor(0.0692)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06105891367840735, 'mloss': tensor(0.0599), 'val_loss': 0.06914063720296511, 'val_mloss_tot': tensor(0.0691)}\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06105891367840735, 'mloss': tensor(0.0599), 'val_loss': 0.06914063720296511, 'val_mloss_tot': tensor(0.0691)}\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "(0.06907161608383602, {'val_loss': 0.06907161608383602, 'val_mloss_tot': tensor(0.0690)})\n",
+      "se_resnext101_32x4d new_splits_focal 3\n",
+      "\n",
+      "(0.06907161608383602, {'val_loss': 0.06907161608383602, 'val_mloss_tot': tensor(0.0690)})\n",
+      "se_resnext101_32x4d new_splits_focal 3\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f82c9b2e7b0>"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f82c9b2e7b0>"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06256559370040231, 'mloss': tensor(0.0611), 'val_loss': 0.07610935135362823, 'val_mloss_tot': tensor(0.0760)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06256559370040231, 'mloss': tensor(0.0611), 'val_loss': 0.07610935135362823, 'val_mloss_tot': tensor(0.0760)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060436919679122895, 'mloss': tensor(0.0599), 'val_loss': 0.07404250497967921, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060436919679122895, 'mloss': tensor(0.0599), 'val_loss': 0.07404250497967921, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05892525780124182, 'mloss': tensor(0.0590), 'val_loss': 0.07430134960602908, 'val_mloss_tot': tensor(0.0742)}\n",
+      "{'loss': 0.05892525780124182, 'mloss': tensor(0.0590), 'val_loss': 0.07430134960602908, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059684565029891966, 'mloss': tensor(0.0573), 'val_loss': 0.07428747899334145, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059684565029891966, 'mloss': tensor(0.0573), 'val_loss': 0.07428747899334145, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05794791039229029, 'mloss': tensor(0.0566), 'val_loss': 0.07421989964710579, 'val_mloss_tot': tensor(0.0742)}\n",
+      "{'loss': 0.05794791039229029, 'mloss': tensor(0.0566), 'val_loss': 0.07421989964710579, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05629156700258014, 'mloss': tensor(0.0548), 'val_loss': 0.07526278103602531, 'val_mloss_tot': tensor(0.0752)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05629156700258014, 'mloss': tensor(0.0548), 'val_loss': 0.07526278103602531, 'val_mloss_tot': tensor(0.0752)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05727883602370432, 'mloss': tensor(0.0588), 'val_loss': 0.07382143096670081, 'val_mloss_tot': tensor(0.0738)}\n",
+      "{'loss': 0.05727883602370432, 'mloss': tensor(0.0588), 'val_loss': 0.07382143096670081, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056193441329850005, 'mloss': tensor(0.0577), 'val_loss': 0.07381229711215577, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056193441329850005, 'mloss': tensor(0.0577), 'val_loss': 0.07381229711215577, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05566271497102168, 'mloss': tensor(0.0557), 'val_loss': 0.07362644695371512, 'val_mloss_tot': tensor(0.0736)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05566271497102168, 'mloss': tensor(0.0557), 'val_loss': 0.07362644695371512, 'val_mloss_tot': tensor(0.0736)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05699799664221157, 'mloss': tensor(0.0579), 'val_loss': 0.07320393128778602, 'val_mloss_tot': tensor(0.0731)}\n",
+      "{'loss': 0.05699799664221157, 'mloss': tensor(0.0579), 'val_loss': 0.07320393128778602, 'val_mloss_tot': tensor(0.0731)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05484745116440305, 'mloss': tensor(0.0546), 'val_loss': 0.073520051290617, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05484745116440305, 'mloss': tensor(0.0546), 'val_loss': 0.073520051290617, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05722927713211164, 'mloss': tensor(0.0565), 'val_loss': 0.07346819397863612, 'val_mloss_tot': tensor(0.0734)}\n",
+      "{'loss': 0.05722927713211164, 'mloss': tensor(0.0565), 'val_loss': 0.07346819397863612, 'val_mloss_tot': tensor(0.0734)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05590619426693891, 'mloss': tensor(0.0549), 'val_loss': 0.07642889548413699, 'val_mloss_tot': tensor(0.0764)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05590619426693891, 'mloss': tensor(0.0549), 'val_loss': 0.07642889548413699, 'val_mloss_tot': tensor(0.0764)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058468750980138914, 'mloss': tensor(0.0608), 'val_loss': 0.07517170486399202, 'val_mloss_tot': tensor(0.0751)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058468750980138914, 'mloss': tensor(0.0608), 'val_loss': 0.07517170486399202, 'val_mloss_tot': tensor(0.0751)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05905847136177575, 'mloss': tensor(0.0574), 'val_loss': 0.07365074729524003, 'val_mloss_tot': tensor(0.0736)}\n",
+      "{'loss': 0.05905847136177575, 'mloss': tensor(0.0574), 'val_loss': 0.07365074729524003, 'val_mloss_tot': tensor(0.0736)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05954825335870781, 'mloss': tensor(0.0605), 'val_loss': 0.07438604836924703, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05954825335870781, 'mloss': tensor(0.0605), 'val_loss': 0.07438604836924703, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05907009247290204, 'mloss': tensor(0.0620), 'val_loss': 0.0740199738857718, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05907009247290204, 'mloss': tensor(0.0620), 'val_loss': 0.0740199738857718, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055304533789174985, 'mloss': tensor(0.0548), 'val_loss': 0.07353268412216643, 'val_mloss_tot': tensor(0.0735)}\n",
+      "{'loss': 0.055304533789174985, 'mloss': tensor(0.0548), 'val_loss': 0.07353268412216643, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05557909771779207, 'mloss': tensor(0.0574), 'val_loss': 0.07400266126706646, 'val_mloss_tot': tensor(0.0739)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05557909771779207, 'mloss': tensor(0.0574), 'val_loss': 0.07400266126706646, 'val_mloss_tot': tensor(0.0739)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.053877604987457785, 'mloss': tensor(0.0521), 'val_loss': 0.07405381055632759, 'val_mloss_tot': tensor(0.0740)}\n",
+      "{'loss': 0.053877604987457785, 'mloss': tensor(0.0521), 'val_loss': 0.07405381055632759, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05681332851133147, 'mloss': tensor(0.0572), 'val_loss': 0.07383451637140186, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05681332851133147, 'mloss': tensor(0.0572), 'val_loss': 0.07383451637140186, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05557124688829585, 'mloss': tensor(0.0576), 'val_loss': 0.07446255968104987, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05557124688829585, 'mloss': tensor(0.0576), 'val_loss': 0.07446255968104987, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054036270286534786, 'mloss': tensor(0.0531), 'val_loss': 0.07390088248486462, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054036270286534786, 'mloss': tensor(0.0531), 'val_loss': 0.07390088248486462, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054714143508980885, 'mloss': tensor(0.0551), 'val_loss': 0.07393516586009474, 'val_mloss_tot': tensor(0.0739)}\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054714143508980885, 'mloss': tensor(0.0551), 'val_loss': 0.07393516586009474, 'val_mloss_tot': tensor(0.0739)}\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "0.07320393128778602\n",
+      "\n",
+      "0.07320393128778602\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "for num_split in [2,3]: #range(5):\n",
+    "    multi=3\n",
+    "    model_name,version = 'se_resnext101_32x4d' , 'new_splits_focal'\n",
+    "    print (model_name,version,num_split)\n",
+    "    pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta',num_split),'rb')\n",
+    "    features=pickle.load(pickle_file)\n",
+    "    pickle_file.close()\n",
+    "    features.shape\n",
+    "\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)                \n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)                \n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "#    sampling=PosSeriesSampler(train_df,split_train,1,0.75)\n",
+    "    sample_ratio= 1.0 #1.01*float(sampling().shape[0])/split_train.shape[0]\n",
+    "    num_train_optimization_steps = sample_ratio*num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnext101_32x4d new_splits_focal 4\n",
+      "se_resnext101_32x4d new_splits_focal 4\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f82c9b2e7b0>"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f82c9b2e7b0>"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "eb8fdd39608c4f1799ef40ee6f3bd345",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "eb8fdd39608c4f1799ef40ee6f3bd345",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1659724fe1f64386afeb8b2d5b0457b6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1659724fe1f64386afeb8b2d5b0457b6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "346f25002b1941d693730f9b50bac587",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "346f25002b1941d693730f9b50bac587",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06799117253421788, 'mloss': tensor(0.0699), 'val_loss': 0.07957097498873468, 'val_mloss_tot': tensor(0.0795)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9b1e4a82fe3f42e7b168b92f2bc6e2a7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06799117253421788, 'mloss': tensor(0.0699), 'val_loss': 0.07957097498873468, 'val_mloss_tot': tensor(0.0795)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9b1e4a82fe3f42e7b168b92f2bc6e2a7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "254b1b689f904da69c17f66acece1019",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "254b1b689f904da69c17f66acece1019",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06200766211658852, 'mloss': tensor(0.0633), 'val_loss': 0.07399804443347188, 'val_mloss_tot': tensor(0.0739)}\n",
+      "{'loss': 0.06200766211658852, 'mloss': tensor(0.0633), 'val_loss': 0.07399804443347188, 'val_mloss_tot': tensor(0.0739)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c24af84c26ce437498deb7faf9bfbc72",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c24af84c26ce437498deb7faf9bfbc72",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f0d642fad3484b4ea255a70a9ba9fd6e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f0d642fad3484b4ea255a70a9ba9fd6e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06278177597006018, 'mloss': tensor(0.0632), 'val_loss': 0.07242505058178655, 'val_mloss_tot': tensor(0.0724)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "305d96c882a74ab38384376c31141290",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06278177597006018, 'mloss': tensor(0.0632), 'val_loss': 0.07242505058178655, 'val_mloss_tot': tensor(0.0724)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "305d96c882a74ab38384376c31141290",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6e8ff21bd037443dbce1fb9f4d7afdef",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6e8ff21bd037443dbce1fb9f4d7afdef",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06102479354704403, 'mloss': tensor(0.0603), 'val_loss': 0.07268230391191205, 'val_mloss_tot': tensor(0.0727)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ef0737917a0b4f44b47fcb71a6c47c3a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06102479354704403, 'mloss': tensor(0.0603), 'val_loss': 0.07268230391191205, 'val_mloss_tot': tensor(0.0727)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ef0737917a0b4f44b47fcb71a6c47c3a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "234a67f9efc44744850f6223269d7067",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "234a67f9efc44744850f6223269d7067",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05735043145398115, 'mloss': tensor(0.0560), 'val_loss': 0.07298917406302714, 'val_mloss_tot': tensor(0.0730)}\n",
+      "{'loss': 0.05735043145398115, 'mloss': tensor(0.0560), 'val_loss': 0.07298917406302714, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c76a15ff2bea455282678ca88ea2b9db",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c76a15ff2bea455282678ca88ea2b9db",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4291cc5dc2644a6db94f552eb1288300",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4291cc5dc2644a6db94f552eb1288300",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05788960416695958, 'mloss': tensor(0.0587), 'val_loss': 0.07227114161951581, 'val_mloss_tot': tensor(0.0722)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "807c608dcfea49deb5bd8e8a892b607b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05788960416695958, 'mloss': tensor(0.0587), 'val_loss': 0.07227114161951581, 'val_mloss_tot': tensor(0.0722)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "807c608dcfea49deb5bd8e8a892b607b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f42ecc9c4bf94409bc20892da4a6b0b6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f42ecc9c4bf94409bc20892da4a6b0b6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055881445187220706, 'mloss': tensor(0.0515), 'val_loss': 0.07193080599416002, 'val_mloss_tot': tensor(0.0719)}\n",
+      "{'loss': 0.055881445187220706, 'mloss': tensor(0.0515), 'val_loss': 0.07193080599416002, 'val_mloss_tot': tensor(0.0719)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0c4b38d607734fc8b6641d680e445cd3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0c4b38d607734fc8b6641d680e445cd3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1fde8a4a045b4dcf90570a19d074c721",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1fde8a4a045b4dcf90570a19d074c721",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05803228518772075, 'mloss': tensor(0.0589), 'val_loss': 0.07329496418159863, 'val_mloss_tot': tensor(0.0733)}\n",
+      "{'loss': 0.05803228518772075, 'mloss': tensor(0.0589), 'val_loss': 0.07329496418159863, 'val_mloss_tot': tensor(0.0733)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "abd5d2ec585648a99ae4509c214157d6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "abd5d2ec585648a99ae4509c214157d6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6e74dc7a37564a46845f6e5c48183496",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6e74dc7a37564a46845f6e5c48183496",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057731436834648917, 'mloss': tensor(0.0571), 'val_loss': 0.07216293245237246, 'val_mloss_tot': tensor(0.0721)}\n",
+      "{'loss': 0.057731436834648917, 'mloss': tensor(0.0571), 'val_loss': 0.07216293245237246, 'val_mloss_tot': tensor(0.0721)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "da84cd1fa22c45499c264de2bf63d4fc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "da84cd1fa22c45499c264de2bf63d4fc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6a170475b3ac4362912afa6a42f47472",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6a170475b3ac4362912afa6a42f47472",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05696376784139452, 'mloss': tensor(0.0580), 'val_loss': 0.07230398032874592, 'val_mloss_tot': tensor(0.0723)}\n",
+      "{'loss': 0.05696376784139452, 'mloss': tensor(0.0580), 'val_loss': 0.07230398032874592, 'val_mloss_tot': tensor(0.0723)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "70fb01b2d63f43da9514f228fb855c76",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "70fb01b2d63f43da9514f228fb855c76",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "de3e670ac76e400b927bad7e6c071519",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "de3e670ac76e400b927bad7e6c071519",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06086215436857173, 'mloss': tensor(0.0655), 'val_loss': 0.07201771383449013, 'val_mloss_tot': tensor(0.0720)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6e9a3fcc718d409385c02db507e80dd0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06086215436857173, 'mloss': tensor(0.0655), 'val_loss': 0.07201771383449013, 'val_mloss_tot': tensor(0.0720)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6e9a3fcc718d409385c02db507e80dd0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cb374774fee74f27829259bc22d1e51c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cb374774fee74f27829259bc22d1e51c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05501155299962578, 'mloss': tensor(0.0547), 'val_loss': 0.07210505726984759, 'val_mloss_tot': tensor(0.0721)}\n",
+      "{'loss': 0.05501155299962578, 'mloss': tensor(0.0547), 'val_loss': 0.07210505726984759, 'val_mloss_tot': tensor(0.0721)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0aafb47ef06545378025f0189375514c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0aafb47ef06545378025f0189375514c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8993c496d28a4e3780e765b05f63bf92",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8993c496d28a4e3780e765b05f63bf92",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05803203116808863, 'mloss': tensor(0.0555), 'val_loss': 0.07404777356366879, 'val_mloss_tot': tensor(0.0740)}\n",
+      "{'loss': 0.05803203116808863, 'mloss': tensor(0.0555), 'val_loss': 0.07404777356366879, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b031fdcd980540ab980599419ab2569a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b031fdcd980540ab980599419ab2569a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "15031815cabc44d29988073e0703740c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "15031815cabc44d29988073e0703740c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05975140749898141, 'mloss': tensor(0.0617), 'val_loss': 0.07301857568670501, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "acd6c98428f046c89734bb0a4334e3d3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05975140749898141, 'mloss': tensor(0.0617), 'val_loss': 0.07301857568670501, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "acd6c98428f046c89734bb0a4334e3d3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bb063d6d039d4022915014fa8b2f237a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bb063d6d039d4022915014fa8b2f237a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059001187086733194, 'mloss': tensor(0.0638), 'val_loss': 0.07365117994276618, 'val_mloss_tot': tensor(0.0736)}\n",
+      "{'loss': 0.059001187086733194, 'mloss': tensor(0.0638), 'val_loss': 0.07365117994276618, 'val_mloss_tot': tensor(0.0736)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "423b3c19d56b4b72a12387a17c334ec6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "423b3c19d56b4b72a12387a17c334ec6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bd19a302f52341ddad8ff46d80bbcea6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bd19a302f52341ddad8ff46d80bbcea6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05959516691341214, 'mloss': tensor(0.0597), 'val_loss': 0.07345079880976967, 'val_mloss_tot': tensor(0.0734)}\n",
+      "{'loss': 0.05959516691341214, 'mloss': tensor(0.0597), 'val_loss': 0.07345079880976967, 'val_mloss_tot': tensor(0.0734)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3355570eec5b4c78b4f94005e1b6c702",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3355570eec5b4c78b4f94005e1b6c702",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4fd52799dec2442eb967e50aa7143fa4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4fd52799dec2442eb967e50aa7143fa4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05663001460106719, 'mloss': tensor(0.0569), 'val_loss': 0.07301431294032919, 'val_mloss_tot': tensor(0.0730)}\n",
+      "{'loss': 0.05663001460106719, 'mloss': tensor(0.0569), 'val_loss': 0.07301431294032919, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4096592276214175b98458f38e88b07b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4096592276214175b98458f38e88b07b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f984f719777c4864b64268099847934f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f984f719777c4864b64268099847934f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056310181319974936, 'mloss': tensor(0.0559), 'val_loss': 0.07183515317813467, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fdb0468fa97f4f3f8746578bdb059a52",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056310181319974936, 'mloss': tensor(0.0559), 'val_loss': 0.07183515317813467, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fdb0468fa97f4f3f8746578bdb059a52",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "18f13203215f42939fc3ff7112abc731",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "18f13203215f42939fc3ff7112abc731",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05257013759244393, 'mloss': tensor(0.0482), 'val_loss': 0.07403652549453592, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1698feab15e14156acb9ca56339a486d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05257013759244393, 'mloss': tensor(0.0482), 'val_loss': 0.07403652549453592, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1698feab15e14156acb9ca56339a486d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "825b9647ac7b4ba59d7aaab1cbfbe1ee",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "825b9647ac7b4ba59d7aaab1cbfbe1ee",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05776624110454715, 'mloss': tensor(0.0568), 'val_loss': 0.07294346797110339, 'val_mloss_tot': tensor(0.0729)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e9e697902e4047caa64b50dd64c9bf32",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05776624110454715, 'mloss': tensor(0.0568), 'val_loss': 0.07294346797110339, 'val_mloss_tot': tensor(0.0729)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e9e697902e4047caa64b50dd64c9bf32",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bd3f4b58af7c464ca44a9c33a1df1030",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bd3f4b58af7c464ca44a9c33a1df1030",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054283337441584284, 'mloss': tensor(0.0521), 'val_loss': 0.07240734455978823, 'val_mloss_tot': tensor(0.0724)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "28c3fbde8e3242baad95622cb2fe959a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054283337441584284, 'mloss': tensor(0.0521), 'val_loss': 0.07240734455978823, 'val_mloss_tot': tensor(0.0724)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "28c3fbde8e3242baad95622cb2fe959a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4c086563edf74a11b28e840105078013",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4c086563edf74a11b28e840105078013",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05544162225223156, 'mloss': tensor(0.0521), 'val_loss': 0.0725896455543606, 'val_mloss_tot': tensor(0.0725)}\n",
+      "{'loss': 0.05544162225223156, 'mloss': tensor(0.0521), 'val_loss': 0.0725896455543606, 'val_mloss_tot': tensor(0.0725)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "62918abef86d4bafa4fbb3df334ad861",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "62918abef86d4bafa4fbb3df334ad861",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a5ca096e4475425e88df69b794055570",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a5ca096e4475425e88df69b794055570",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05750886167725116, 'mloss': tensor(0.0591), 'val_loss': 0.07257047332521634, 'val_mloss_tot': tensor(0.0725)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4338efa827b5425aad6748cf6823de68",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05750886167725116, 'mloss': tensor(0.0591), 'val_loss': 0.07257047332521634, 'val_mloss_tot': tensor(0.0725)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4338efa827b5425aad6748cf6823de68",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a382e5885f3f43e08c84ad7ca99d4a8d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a382e5885f3f43e08c84ad7ca99d4a8d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05547570323233913, 'mloss': tensor(0.0540), 'val_loss': 0.0725123643053145, 'val_mloss_tot': tensor(0.0725)}\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5a50d3e788344b42b9339c1bd5464feb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05547570323233913, 'mloss': tensor(0.0540), 'val_loss': 0.0725123643053145, 'val_mloss_tot': tensor(0.0725)}\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5a50d3e788344b42b9339c1bd5464feb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "(0.07146647619649105, {'val_loss': 0.07146647619649105, 'val_mloss_tot': tensor(0.0714)})\n",
+      "\n",
+      "(0.07146647619649105, {'val_loss': 0.07146647619649105, 'val_mloss_tot': tensor(0.0714)})\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "for num_split in [4]: #range(5):\n",
+    "    multi=3\n",
+    "    model_name,version = 'se_resnext101_32x4d' , 'new_splits_focal'\n",
+    "    print (model_name,version,num_split)\n",
+    "    pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta',num_split),'rb')\n",
+    "    features=pickle.load(pickle_file)\n",
+    "    pickle_file.close()\n",
+    "    features.shape\n",
+    "\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)                \n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)                \n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "#    sampling=PosSeriesSampler(train_df,split_train,1,0.75)\n",
+    "    sample_ratio= 1.0 #1.01*float(sampling().shape[0])/split_train.shape[0]\n",
+    "    num_train_optimization_steps = sample_ratio*num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal 0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f7d8609a790>"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06801814952788739, 'mloss': tensor(0.0687), 'val_loss': 0.07283909177901793, 'val_mloss_tot': tensor(0.0728)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06713587075366854, 'mloss': tensor(0.0654), 'val_loss': 0.07061037322012137, 'val_mloss_tot': tensor(0.0705)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06425997749901873, 'mloss': tensor(0.0616), 'val_loss': 0.06953470255153216, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06607530081518949, 'mloss': tensor(0.0674), 'val_loss': 0.06936297067155947, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06531331479201585, 'mloss': tensor(0.0692), 'val_loss': 0.06891423660349481, 'val_mloss_tot': tensor(0.0689)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06329029732242795, 'mloss': tensor(0.0644), 'val_loss': 0.06916797859784292, 'val_mloss_tot': tensor(0.0691)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05990499216438457, 'mloss': tensor(0.0586), 'val_loss': 0.06880586169355986, 'val_mloss_tot': tensor(0.0688)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06398531669358798, 'mloss': tensor(0.0665), 'val_loss': 0.06930466414410241, 'val_mloss_tot': tensor(0.0692)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06175779544215915, 'mloss': tensor(0.0617), 'val_loss': 0.06855630030741497, 'val_mloss_tot': tensor(0.0685)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060580648818581675, 'mloss': tensor(0.0596), 'val_loss': 0.06814609903043935, 'val_mloss_tot': tensor(0.0681)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06176182649278729, 'mloss': tensor(0.0647), 'val_loss': 0.06819111633171536, 'val_mloss_tot': tensor(0.0681)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06841749242982949, 'mloss': tensor(0.0779), 'val_loss': 0.06822031848322677, 'val_mloss_tot': tensor(0.0682)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06295308192498555, 'mloss': tensor(0.0662), 'val_loss': 0.06890888082775838, 'val_mloss_tot': tensor(0.0688)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05993890128465238, 'mloss': tensor(0.0555), 'val_loss': 0.06957911991191154, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06348086497780833, 'mloss': tensor(0.0609), 'val_loss': 0.06896918778478796, 'val_mloss_tot': tensor(0.0689)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06467006555587114, 'mloss': tensor(0.0655), 'val_loss': 0.06863239274113155, 'val_mloss_tot': tensor(0.0686)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059530082159649944, 'mloss': tensor(0.0577), 'val_loss': 0.06846679936790345, 'val_mloss_tot': tensor(0.0684)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06176354525914924, 'mloss': tensor(0.0622), 'val_loss': 0.06890479189780901, 'val_mloss_tot': tensor(0.0688)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060365150610968976, 'mloss': tensor(0.0599), 'val_loss': 0.06855630791575021, 'val_mloss_tot': tensor(0.0685)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0592349039480856, 'mloss': tensor(0.0580), 'val_loss': 0.06847410744094119, 'val_mloss_tot': tensor(0.0684)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05948680389709346, 'mloss': tensor(0.0594), 'val_loss': 0.06839205849611638, 'val_mloss_tot': tensor(0.0683)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06009998340068295, 'mloss': tensor(0.0568), 'val_loss': 0.06836220631185844, 'val_mloss_tot': tensor(0.0683)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058883878685634106, 'mloss': tensor(0.0587), 'val_loss': 0.06831089384207616, 'val_mloss_tot': tensor(0.0682)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06160811806986899, 'mloss': tensor(0.0609), 'val_loss': 0.06832941194844185, 'val_mloss_tot': tensor(0.0683)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=245), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.06814609903043935\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "for num_split in [0]: #range(5):\n",
+    "    multi=3\n",
+    "    model_name,version = 'se_resnet101' , 'new_splits_focal'\n",
+    "    print (model_name,version,num_split)\n",
+    "    pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta',num_split),'rb')\n",
+    "    features=pickle.load(pickle_file)\n",
+    "    pickle_file.close()\n",
+    "    features.shape\n",
+    "\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)                \n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)                \n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "#    sampling=PosSeriesSampler(train_df,split_train,1,0.75)\n",
+    "    sample_ratio= 1.0 #1.01*float(sampling().shape[0])/split_train.shape[0]\n",
+    "    num_train_optimization_steps = sample_ratio*num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal 1\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f7d8609a790>"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06706305390403391, 'mloss': tensor(0.0679), 'val_loss': 0.0753338036360219, 'val_mloss_tot': tensor(0.0752)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06582605521661761, 'mloss': tensor(0.0664), 'val_loss': 0.07156634539132938, 'val_mloss_tot': tensor(0.0715)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06860142615312698, 'mloss': tensor(0.0703), 'val_loss': 0.07097313453365738, 'val_mloss_tot': tensor(0.0709)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06528153796833137, 'mloss': tensor(0.0681), 'val_loss': 0.0703708862963443, 'val_mloss_tot': tensor(0.0703)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06120083370721801, 'mloss': tensor(0.0629), 'val_loss': 0.06973114719536776, 'val_mloss_tot': tensor(0.0696)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06337633186302069, 'mloss': tensor(0.0623), 'val_loss': 0.07033007722347975, 'val_mloss_tot': tensor(0.0702)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06427232268276552, 'mloss': tensor(0.0658), 'val_loss': 0.06967794288648292, 'val_mloss_tot': tensor(0.0696)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06179677971778461, 'mloss': tensor(0.0612), 'val_loss': 0.06956295766867697, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06247514259205962, 'mloss': tensor(0.0627), 'val_loss': 0.06899789812741801, 'val_mloss_tot': tensor(0.0689)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06231222968290985, 'mloss': tensor(0.0618), 'val_loss': 0.0688553408292743, 'val_mloss_tot': tensor(0.0687)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06183718991156706, 'mloss': tensor(0.0612), 'val_loss': 0.06882287827320396, 'val_mloss_tot': tensor(0.0687)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060959088810676396, 'mloss': tensor(0.0603), 'val_loss': 0.06888944027402127, 'val_mloss_tot': tensor(0.0688)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06159196596112574, 'mloss': tensor(0.0612), 'val_loss': 0.0707330879328462, 'val_mloss_tot': tensor(0.0706)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "dfc3c0641a414996980276e5da0036c8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06268981381400358, 'mloss': tensor(0.0606), 'val_loss': 0.06999726186040789, 'val_mloss_tot': tensor(0.0699)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ce95e8a507ad4ea49967c842ded1a5b1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b54d55d04b5f45c4ab8b7bf9d17cdb90",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.062062416725145605, 'mloss': tensor(0.0566), 'val_loss': 0.07001044309775656, 'val_mloss_tot': tensor(0.0699)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3ed5c778babb4bafb82f90880fcb37fa",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a09e35a5b85c46e18074f82dd047d867",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.062336952298358256, 'mloss': tensor(0.0656), 'val_loss': 0.06903052127842481, 'val_mloss_tot': tensor(0.0689)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "79a6a57bfd3f40599c5c443f7d2a9cdd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3b7019b04e594af9b013d6c0d01f88c9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06487696107365519, 'mloss': tensor(0.0658), 'val_loss': 0.06963334473160405, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c67075d87aba49cbb2b4c32e40b88c0f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2fb644ca91a44946a61924312091c702",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0638985172978266, 'mloss': tensor(0.0644), 'val_loss': 0.06882744966230045, 'val_mloss_tot': tensor(0.0687)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "913c04c1b22944f68351ba87634c4fc4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1b9b3769c9b54ec69c17f07072fb2f5b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061455932627082495, 'mloss': tensor(0.0627), 'val_loss': 0.06872678914417824, 'val_mloss_tot': tensor(0.0686)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a970501495ff4d06bd9bbad7dd0cb382",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "43e410688446486c8a4f1e23e67de2eb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060658901023823195, 'mloss': tensor(0.0610), 'val_loss': 0.0690681101831918, 'val_mloss_tot': tensor(0.0689)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4a62d5c3c35a4a38867e88c950197a35",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ca11453944ab45398cb884e9c75f25e7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05841321709674182, 'mloss': tensor(0.0592), 'val_loss': 0.06885637667728588, 'val_mloss_tot': tensor(0.0688)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1ed2ebf4b2f34d19ac918d8a9a05f938",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1f55db4f2f104a3bb3f89a050790c4bf",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06359648862342149, 'mloss': tensor(0.0650), 'val_loss': 0.06877168237309282, 'val_mloss_tot': tensor(0.0687)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "57abcb5d31e7455a978c42a0431061b4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bfbbde2ea78a4602b3a0f58aeb1d9fd6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060309677044462016, 'mloss': tensor(0.0571), 'val_loss': 0.06850194214687993, 'val_mloss_tot': tensor(0.0684)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8136b1745c9c49f6bc1e908bc24fe388",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=982), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "50f2b949c8ca4226b1752242e30c189a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06016276608588028, 'mloss': tensor(0.0603), 'val_loss': 0.06883820214619239, 'val_mloss_tot': tensor(0.0687)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "10ba909ed9814f04918af41e7370cab5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=240), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.06850194214687993\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "for num_split in [1]: #range(5):\n",
+    "    multi=3\n",
+    "    model_name,version = 'se_resnet101' , 'new_splits_focal'\n",
+    "    print (model_name,version,num_split)\n",
+    "    pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta',num_split),'rb')\n",
+    "    features=pickle.load(pickle_file)\n",
+    "    pickle_file.close()\n",
+    "    features.shape\n",
+    "\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)                \n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)                \n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "#    sampling=PosSeriesSampler(train_df,split_train,1,0.75)\n",
+    "    sample_ratio= 1.0 #1.01*float(sampling().shape[0])/split_train.shape[0]\n",
+    "    num_train_optimization_steps = sample_ratio*num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal 2\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f7d8609a790>"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06642876009938492, 'mloss': tensor(0.0636), 'val_loss': 0.072511745141706, 'val_mloss_tot': tensor(0.0725)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06366300585780155, 'mloss': tensor(0.0610), 'val_loss': 0.07134100720889655, 'val_mloss_tot': tensor(0.0713)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06673878450100637, 'mloss': tensor(0.0703), 'val_loss': 0.07100178783208977, 'val_mloss_tot': tensor(0.0710)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060346572511635735, 'mloss': tensor(0.0591), 'val_loss': 0.07054926062130976, 'val_mloss_tot': tensor(0.0705)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06367121056713963, 'mloss': tensor(0.0621), 'val_loss': 0.07070935500959154, 'val_mloss_tot': tensor(0.0707)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06302450786947127, 'mloss': tensor(0.0626), 'val_loss': 0.06980195504270102, 'val_mloss_tot': tensor(0.0698)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061013097087315483, 'mloss': tensor(0.0599), 'val_loss': 0.06992431138434753, 'val_mloss_tot': tensor(0.0699)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05968484094438953, 'mloss': tensor(0.0610), 'val_loss': 0.06988293575009837, 'val_mloss_tot': tensor(0.0699)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06063475550829022, 'mloss': tensor(0.0583), 'val_loss': 0.07002946159421553, 'val_mloss_tot': tensor(0.0700)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06066792534185903, 'mloss': tensor(0.0620), 'val_loss': 0.06975773179730181, 'val_mloss_tot': tensor(0.0697)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06653188573258909, 'mloss': tensor(0.0744), 'val_loss': 0.06959410336034501, 'val_mloss_tot': tensor(0.0696)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060091046433399846, 'mloss': tensor(0.0586), 'val_loss': 0.06941451203937714, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06524702591013883, 'mloss': tensor(0.0680), 'val_loss': 0.07024231006349871, 'val_mloss_tot': tensor(0.0702)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06453835472863666, 'mloss': tensor(0.0663), 'val_loss': 0.07012349858065607, 'val_mloss_tot': tensor(0.0701)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06245861472701878, 'mloss': tensor(0.0631), 'val_loss': 0.0701827489118646, 'val_mloss_tot': tensor(0.0702)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06308804128224449, 'mloss': tensor(0.0616), 'val_loss': 0.07006279917185003, 'val_mloss_tot': tensor(0.0700)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059855410874735414, 'mloss': tensor(0.0587), 'val_loss': 0.07055743700006471, 'val_mloss_tot': tensor(0.0705)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05801896215887184, 'mloss': tensor(0.0574), 'val_loss': 0.06963524807561264, 'val_mloss_tot': tensor(0.0696)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06259847604033456, 'mloss': tensor(0.0636), 'val_loss': 0.06957239694968771, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056891917504811235, 'mloss': tensor(0.0550), 'val_loss': 0.06958192928132378, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0608769425321196, 'mloss': tensor(0.0617), 'val_loss': 0.06982114527992873, 'val_mloss_tot': tensor(0.0698)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05789883858367275, 'mloss': tensor(0.0598), 'val_loss': 0.06970915781628144, 'val_mloss_tot': tensor(0.0697)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06090872075104362, 'mloss': tensor(0.0627), 'val_loss': 0.06969711087598854, 'val_mloss_tot': tensor(0.0697)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=974), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059125143559363255, 'mloss': tensor(0.0586), 'val_loss': 0.06957406156624739, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.06941451203937714\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "for num_split in [2]: #range(5):\n",
+    "    multi=3\n",
+    "    model_name,version = 'se_resnet101' , 'new_splits_focal'\n",
+    "    print (model_name,version,num_split)\n",
+    "    pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta',num_split),'rb')\n",
+    "    features=pickle.load(pickle_file)\n",
+    "    pickle_file.close()\n",
+    "    features.shape\n",
+    "\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)                \n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)                \n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "#    sampling=PosSeriesSampler(train_df,split_train,1,0.75)\n",
+    "    sample_ratio= 1.0 #1.01*float(sampling().shape[0])/split_train.shape[0]\n",
+    "    num_train_optimization_steps = sample_ratio*num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal 3\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f7d8609a790>"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06795507064264683, 'mloss': tensor(0.0664), 'val_loss': 0.07484717968255893, 'val_mloss_tot': tensor(0.0748)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06469026769797935, 'mloss': tensor(0.0645), 'val_loss': 0.07318627575916223, 'val_mloss_tot': tensor(0.0731)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06352702081886957, 'mloss': tensor(0.0639), 'val_loss': 0.0729283714651695, 'val_mloss_tot': tensor(0.0728)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06340994861053613, 'mloss': tensor(0.0603), 'val_loss': 0.07237432615022313, 'val_mloss_tot': tensor(0.0723)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.062144562581749045, 'mloss': tensor(0.0612), 'val_loss': 0.07279423628284855, 'val_mloss_tot': tensor(0.0727)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05987148049207686, 'mloss': tensor(0.0579), 'val_loss': 0.07325074485251222, 'val_mloss_tot': tensor(0.0732)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06129221278543854, 'mloss': tensor(0.0633), 'val_loss': 0.07186072752009466, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05999444709855952, 'mloss': tensor(0.0617), 'val_loss': 0.07201866069465082, 'val_mloss_tot': tensor(0.0719)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058762111584181605, 'mloss': tensor(0.0587), 'val_loss': 0.07181323329979157, 'val_mloss_tot': tensor(0.0717)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06002566909838612, 'mloss': tensor(0.0604), 'val_loss': 0.07136683014496306, 'val_mloss_tot': tensor(0.0713)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05842553575623436, 'mloss': tensor(0.0578), 'val_loss': 0.07143153050593788, 'val_mloss_tot': tensor(0.0714)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061793036670504314, 'mloss': tensor(0.0611), 'val_loss': 0.07143326022005717, 'val_mloss_tot': tensor(0.0714)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059324955313279804, 'mloss': tensor(0.0585), 'val_loss': 0.07413294031017567, 'val_mloss_tot': tensor(0.0741)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06170147087446787, 'mloss': tensor(0.0634), 'val_loss': 0.07230502058996162, 'val_mloss_tot': tensor(0.0722)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06274846450562306, 'mloss': tensor(0.0608), 'val_loss': 0.07213753208395887, 'val_mloss_tot': tensor(0.0721)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06388790179251697, 'mloss': tensor(0.0654), 'val_loss': 0.07249575732352181, 'val_mloss_tot': tensor(0.0724)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06288017048053764, 'mloss': tensor(0.0658), 'val_loss': 0.07167686430401489, 'val_mloss_tot': tensor(0.0716)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05876169492238801, 'mloss': tensor(0.0583), 'val_loss': 0.07158819205494078, 'val_mloss_tot': tensor(0.0715)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05935071261533107, 'mloss': tensor(0.0614), 'val_loss': 0.07162750098419177, 'val_mloss_tot': tensor(0.0716)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05741542338425759, 'mloss': tensor(0.0553), 'val_loss': 0.07180500715459529, 'val_mloss_tot': tensor(0.0717)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06110199735701955, 'mloss': tensor(0.0621), 'val_loss': 0.07118308365146522, 'val_mloss_tot': tensor(0.0711)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059016008194910985, 'mloss': tensor(0.0607), 'val_loss': 0.07190553739392122, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05790989192603154, 'mloss': tensor(0.0569), 'val_loss': 0.07156969215835399, 'val_mloss_tot': tensor(0.0715)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=977), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05835562291724042, 'mloss': tensor(0.0585), 'val_loss': 0.07158829269884917, 'val_mloss_tot': tensor(0.0715)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=244), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.07118308365146522\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "for num_split in [3]: #range(5):\n",
+    "    multi=3\n",
+    "    model_name,version = 'se_resnet101' , 'new_splits_focal'\n",
+    "    print (model_name,version,num_split)\n",
+    "    pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta',num_split),'rb')\n",
+    "    features=pickle.load(pickle_file)\n",
+    "    pickle_file.close()\n",
+    "    features.shape\n",
+    "\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)                \n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)                \n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "#    sampling=PosSeriesSampler(train_df,split_train,1,0.75)\n",
+    "    sample_ratio= 1.0 #1.01*float(sampling().shape[0])/split_train.shape[0]\n",
+    "    num_train_optimization_steps = sample_ratio*num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal 4\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 44,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 44,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f7d8609a790>"
+      ]
+     },
+     "execution_count": 44,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "10161438203649c1bc03aacc370629f0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f82f5cf8cd2f445b9da55a36ff86454a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a2caa9e2396b436e8695f19747ef5dd0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.07015490436772923, 'mloss': tensor(0.0726), 'val_loss': 0.07512244625160328, 'val_mloss_tot': tensor(0.0751)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8062deb24ede4d89b4ea945b89ca6a3f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9e9451b09cde4aeaa6abf51811bcfe6d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0653958427481258, 'mloss': tensor(0.0676), 'val_loss': 0.07276976575343473, 'val_mloss_tot': tensor(0.0727)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8efff9ba667e4c92a4aed1fa70c6a08c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d94246e864764b5b84a32cab0e2519db",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06489033201858484, 'mloss': tensor(0.0655), 'val_loss': 0.0710169263221776, 'val_mloss_tot': tensor(0.0710)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "29e15bf616964f9caf62117c2e71e561",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "08b8473c9549460fa6a7c61ed34fa10d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06400799257209767, 'mloss': tensor(0.0632), 'val_loss': 0.07157295975864296, 'val_mloss_tot': tensor(0.0715)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "77e89510b1f045ae811ba4ed824bd140",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "79d35682b6b640099b98c5c7cb8c8211",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060340242052691215, 'mloss': tensor(0.0595), 'val_loss': 0.0719826279004455, 'val_mloss_tot': tensor(0.0719)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8127f4e7261841b7af3758891440ea59",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c95367a5f358497d99a65077514105cd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06047398680038099, 'mloss': tensor(0.0608), 'val_loss': 0.07088910720050938, 'val_mloss_tot': tensor(0.0708)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1f3e0a9b0ca5463c8fb28348a61279d1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2b814f6309fe46ce88bf790d56a5d4dd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05890102664341058, 'mloss': tensor(0.0550), 'val_loss': 0.07075476742572026, 'val_mloss_tot': tensor(0.0707)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7e186c7b9a8242a492db0f0393cb2734",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "474d6e312d7046f5ac4abf9d8f3d447e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061000964191458826, 'mloss': tensor(0.0623), 'val_loss': 0.07117423560726739, 'val_mloss_tot': tensor(0.0711)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e0368541633742539c5657c11d7d992c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "93ea089de3cb48a19af522d798d8488c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06013221482477172, 'mloss': tensor(0.0591), 'val_loss': 0.07059903822929753, 'val_mloss_tot': tensor(0.0705)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9bdfd1f0c8354a389fb9b056ba3b0147",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "96b84789a54b44c09eacfb79b6482954",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059454674310153166, 'mloss': tensor(0.0607), 'val_loss': 0.0707068234561426, 'val_mloss_tot': tensor(0.0706)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3d9213a8b4c84e69b7b531ac8c0885d6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "49d7443855db41c19124e47ed3709c80",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06383416286507557, 'mloss': tensor(0.0687), 'val_loss': 0.07039150315137045, 'val_mloss_tot': tensor(0.0703)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "dc90d4ef65c245e6bca6d66a8662f2cd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b47c6e5d59b546afa73e1a0aa152baa1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057146727435353206, 'mloss': tensor(0.0561), 'val_loss': 0.07043666631993857, 'val_mloss_tot': tensor(0.0704)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7454e6095c4543f9aabea54c7498742d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3516623873cd490eb3c24d979a6f1843",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06096357998306074, 'mloss': tensor(0.0587), 'val_loss': 0.07214031475200224, 'val_mloss_tot': tensor(0.0721)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bc6c2b06892f484f90568976db2b1acf",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1f2c2707261542c88c7f554a11caa1a9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06184994533020218, 'mloss': tensor(0.0633), 'val_loss': 0.07149467728127111, 'val_mloss_tot': tensor(0.0714)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c3960cdfb4e7406a9f0f8b355b09cc52",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fc0e8a851e4a4168b084b417461183ce",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06131159870290688, 'mloss': tensor(0.0657), 'val_loss': 0.07184541405287952, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "784da236e9a342c4868f05cb8cf66178",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7a75c46914124f449150604c3448af57",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0625149542190651, 'mloss': tensor(0.0627), 'val_loss': 0.07091063708594815, 'val_mloss_tot': tensor(0.0708)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ec831e3c5aa64a5fa6fad0bd21aa4da1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "55da618d867e4423917f3eb992d0bd63",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05930726220701768, 'mloss': tensor(0.0594), 'val_loss': 0.07102426019335083, 'val_mloss_tot': tensor(0.0710)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7b6578e96373459e9e5de480827c4a8c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d427cde1bfbd43fd87fda6eac7ac538a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05897989988945656, 'mloss': tensor(0.0586), 'val_loss': 0.0702291799007881, 'val_mloss_tot': tensor(0.0702)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f28c4b3617494b798fbd9d1718aa81f4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8c706cc5f96b425a89e34b785174972a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05527941090788124, 'mloss': tensor(0.0510), 'val_loss': 0.07152184016761268, 'val_mloss_tot': tensor(0.0715)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d643170cc5c5468eb5d0f7a26da42011",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "885f6058282742028589c114427f8107",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060419764565040876, 'mloss': tensor(0.0595), 'val_loss': 0.07105222300660273, 'val_mloss_tot': tensor(0.0710)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cbeb196011c94a05a0441049a021109f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "52bdf908a5d94614b9dfb046a04996a4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05662640915423674, 'mloss': tensor(0.0543), 'val_loss': 0.07070589078748636, 'val_mloss_tot': tensor(0.0706)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b61ed5e183834e3ab41e070087fcd03f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "24f5d3becd2c4ed9b7fea1149453471c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05799873370762432, 'mloss': tensor(0.0545), 'val_loss': 0.07046750419384797, 'val_mloss_tot': tensor(0.0704)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "178be3496dc5435badcb55dcb9052060",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "717f703b10594274b48f932730222435",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05963707518961883, 'mloss': tensor(0.0607), 'val_loss': 0.07042344067215557, 'val_mloss_tot': tensor(0.0704)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0883df5cff7d467b94448167f5cea054",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=975), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d993bca6f8ff4e9a97cfed9e08297d98",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05862137302747463, 'mloss': tensor(0.0576), 'val_loss': 0.07037128592173943, 'val_mloss_tot': tensor(0.0703)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7f3ab6b3b4344d02af67f5caa71b7a7a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=247), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(0.06990519024220555, {'val_loss': 0.06990519024220555, 'val_mloss_tot': tensor(0.0698)})\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "for num_split in [4]: #range(5):\n",
+    "    multi=3\n",
+    "    model_name,version = 'se_resnet101' , 'new_splits_focal'\n",
+    "    print (model_name,version,num_split)\n",
+    "    pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta',num_split),'rb')\n",
+    "    features=pickle.load(pickle_file)\n",
+    "    pickle_file.close()\n",
+    "    features.shape\n",
+    "\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)                \n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)                \n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "#    sampling=PosSeriesSampler(train_df,split_train,1,0.75)\n",
+    "    sample_ratio= 1.0 #1.01*float(sampling().shape[0])/split_train.shape[0]\n",
+    "    num_train_optimization_steps = sample_ratio*num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cd1b5b50d05d441fa0998850a3d9dc40",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnext101_32x4d new_splits_focal features_train_tta 0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "17011529951e47458a21a8dcca7ca7cd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(134734)"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "35210c31a58e478da5b60bad141692df",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6bd321717d5245bab4187400fbe65c4c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8b54375a333c497fabe69842bea07bf1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a1fd41b7c5b34604981130ca3d0768d4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "12cb92f2dfd34db18b58b20ba06dde1f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9f66ae195db74cdba6035e63bc81d5e9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8dc95506c405435ca7767ba4755b0049",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4416498fb70e41f082c015d67f183722",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9310e2a95d204c0b9d88e1b0adee26a5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a97a385930374cec8a1082912480d206",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "92ff4174c891469099073afca0938843",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a521da0b16b44100a17381a46cd1b312",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7156774051f348c1a317c407f97f345e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "631dcbe04bfa4975b6a8f8b5606ac816",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d9f230f246b94d9a8e221048fbf77de2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "641bc6a055224f1091d2b913315bc749",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b51f742aab7b4f45b6e28928bf2a071e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0c0e8936c23a46ebbe3a75a368bd3fad",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ced768aa65c84c6b8ad7d5370b7a846d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b599f8523f634dc6a2aad858c78126fb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "53c69b6528534a57a2d9543bb792e36e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5bec92a5d55846b298eaa0aaf46d8955",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2e2bb4d0f8124cec872f0decc08f10ff",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cfd6b21a2e1448a69dd175ed27a79092",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f80cd6f27954413486a5618225f642ed",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e7a35900d518465fb5d805416da7e89a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0c843eed021b4ea38cdd7598c0ba8dea",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "969fd3ad3d8e40e681bb2c163fc119ff",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d254d5b7fda044c597fbe2034ba76074",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8ae89f085e164e23bbfa937267798f7f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e0ee3c24a91f4ed2bdb96a0731544579",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "97b550cff6c642b6b9ea885ce1e5a0a7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnext101_32x4d new_splits_focal features_train_tta 1\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ef8f02406395452e8a606810358c42ae",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(131494)"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0f77ea43d5354bb6b01cd9363813734b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "298df3b3a5d243ae8dd8eba03def2bb1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "29b0f5e3050e41ad91b145eed6d92a30",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c807b9fd901d43c59750ac15c40b8097",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "12d9dc52070943e0be65a108daecce82",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "acf19cfff3f44b978acac4ee1fdb75f7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b07394ec4c9a4ad1beec2b5b4a8fc12a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "507ae5dd96864ca58677817f4d7e029a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f4e8696ca3ec46b0ba39a3b960d97671",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cf06a5c4e9ac449a9e44dd1f474ff699",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e1f491ab01f74d5ea0dc212dcc528804",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fa4c7b6f21da4de7901348223330f627",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8ec0e684ba20466bbc8f488926a28a9c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3d26cd7b2d7943afb17eff1c31766ff6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4e547013c2694109ae8dfb9d2040d168",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d43ca27d7ace4c7fb35adb8cc6bfe750",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c7e8182b95d64fa1aabf0f645a60bb30",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c60f603a10244182a338b0c203fa4697",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "10aa5e458d98416790ea38b7d464af65",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2f6103974ed841ef9b2af8c6d41f3546",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f15a88d7f34546fdbe065facbac634eb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "dd63b103e8aa4702803684a070f95ca1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e85b06944e764981a0f759c054f749bb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "03cf23c4fec74058a24ed26cd7421ffa",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3adc185f8a594086951834173e623b95",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "be5ce69b3d984a34bfdf816d5a6a7049",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "db4fa8a9cdfd4d3ea7d022d0a996eace",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c0fb2fc3cf0747dd98700cd727f70f8d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "37fd9a8865f94a1f9b6c6a835455c6cb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "dd237fd6dc654ea39a705af529dfc724",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a8c7fc1823b54baf9dda842651a9f564",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0c4a2c3d605841368e5a2e9ce03c4b16",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnext101_32x4d new_splits_focal features_train_tta 2\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "af6ffb6fefcb481582f2e85e852c41a3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(136434)"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3d27d152d4954709b711d5872c578dee",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e34e057b3c574c4384ca4c8caa6eb32d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "ename": "KeyboardInterrupt",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-25-91dc370428ee>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     42\u001b[0m         \u001b[0mwins\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm_notebook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mleave\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m             \u001b[0mpr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mvalid_dataset\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdo_apex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     45\u001b[0m             \u001b[0mpred_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mpr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mwins\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m         \u001b[0mpickle_file\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs_dir\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0moutputs_format\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mversion\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'OOF_pred'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnum_split\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/kaggle/RSNA/Production/helper/mytraining.py\u001b[0m in \u001b[0;36mmodel_run\u001b[0;34m(model, dataset, do_apex, batch_size, num_workers)\u001b[0m\n\u001b[1;32m    183\u001b[0m     \u001b[0mres_list\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    184\u001b[0m     \u001b[0mdata_loader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mD\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnum_workers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_workers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m     \u001b[0;32mfor\u001b[0m \u001b[0mbatchs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm_notebook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    186\u001b[0m         \u001b[0my_preds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbatchs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatchs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatchs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    187\u001b[0m         \u001b[0mres_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0my\u001b[0m \u001b[0;32min\u001b[0m \u001b[0my_preds\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_preds\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0my_preds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/RSNA/lib/python3.6/site-packages/tqdm/_tqdm_notebook.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    219\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    220\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 221\u001b[0;31m             \u001b[0;32mfor\u001b[0m \u001b[0mobj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtqdm_notebook\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    222\u001b[0m                 \u001b[0;31m# return super(tqdm...) will not catch exception\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    223\u001b[0m                 \u001b[0;32myield\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/RSNA/lib/python3.6/site-packages/tqdm/_tqdm.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    977\u001b[0m \"\"\", fp_write=getattr(self.fp, 'write', sys.stderr.write))\n\u001b[1;32m    978\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 979\u001b[0;31m             \u001b[0;32mfor\u001b[0m \u001b[0mobj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    980\u001b[0m                 \u001b[0;32myield\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    981\u001b[0m                 \u001b[0;31m# Update and possibly print the progressbar.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/RSNA/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    802\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    803\u001b[0m             \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshutdown\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtasks_outstanding\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 804\u001b[0;31m             \u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    805\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtasks_outstanding\u001b[0m \u001b[0;34m-=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    806\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/RSNA/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    769\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    770\u001b[0m             \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 771\u001b[0;31m                 \u001b[0msuccess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_try_get_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    772\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0msuccess\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    773\u001b[0m                     \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/RSNA/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    722\u001b[0m         \u001b[0;31m#   (bool: whether successfully get data, any: data if successful else None)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    723\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 724\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_queue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    725\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/queues.py\u001b[0m in \u001b[0;36mget\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m    102\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    103\u001b[0m                     \u001b[0mtimeout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeadline\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 104\u001b[0;31m                     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_poll\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    105\u001b[0m                         \u001b[0;32mraise\u001b[0m \u001b[0mEmpty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    106\u001b[0m                 \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_poll\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\u001b[0m in \u001b[0;36mpoll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    255\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_closed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    256\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_readable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 257\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_poll\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    258\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    259\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__enter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\u001b[0m in \u001b[0;36m_poll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    412\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    413\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_poll\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 414\u001b[0;31m         \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    415\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mbool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    416\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(object_list, timeout)\u001b[0m\n\u001b[1;32m    909\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    910\u001b[0m             \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 911\u001b[0;31m                 \u001b[0mready\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mselector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mselect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    912\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mready\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    913\u001b[0m                     \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfileobj\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevents\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mready\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/RSNA/lib/python3.6/selectors.py\u001b[0m in \u001b[0;36mselect\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    374\u001b[0m             \u001b[0mready\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    375\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 376\u001b[0;31m                 \u001b[0mfd_event_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_poll\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpoll\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    377\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mInterruptedError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    378\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mready\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+     ]
+    }
+   ],
+   "source": [
+    "multi=3\n",
+    "model_names=['se_resnext101_32x4d']\n",
+    "types=['features_train_tta']\n",
+    "versions=['new_splits_focal']\n",
+    "num_splits=[5]\n",
+    "seeds=[432]\n",
+    "for model_name,type_,version_,n,SEED in zip(model_names,types,versions,num_splits,seeds):\n",
+    "    for num_split in tqdm_notebook(range(n)):\n",
+    "        split_sid = train_df.PID.unique()\n",
+    "        splits=list(KFold(n_splits=n,shuffle=True, random_state=SEED).split(split_sid))\n",
+    "\n",
+    "        pred_list=[]\n",
+    "        print(model_name,version_,type_,num_split) \n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version_,type_,num_split),'rb')\n",
+    "        features=pickle.load(pickle_file)\n",
+    "        pickle_file.close()\n",
+    "        features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "        split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "        model=ResModelPool(features.shape[-1])\n",
+    "        version=version_+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "\n",
+    "        model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n",
+    "\n",
+    "        valid_dataset=FullHeadDataset(train_df,\n",
+    "                                      split_validate,\n",
+    "                                      features,\n",
+    "                                      'SeriesI',\n",
+    "                                      'ImagePositionZ',\n",
+    "                                      multi =3)\n",
+    "\n",
+    "        win_dataset=FullHeadDataset(train_df,\n",
+    "                                      split_validate,\n",
+    "                                      features,\n",
+    "                                      'SeriesI',\n",
+    "                                      'ImagePositionZ',\n",
+    "                                       target_columns=hemorrhage_types)\n",
+    "        win_list=[]\n",
+    "        dl = D.DataLoader(win_dataset,batch_size=128,num_workers=16)\n",
+    "        for _,win in tqdm_notebook(dl):\n",
+    "            win_list.append(win.reshape(win.shape[0]*win.shape[1],-1))    \n",
+    "        wins = torch.cat(win_list,0).sum(1)>=0\n",
+    "        wins.sum()\n",
+    "        for i in tqdm_notebook(range(32),leave=False):\n",
+    "            pr = model_run(model,valid_dataset,do_apex=False,batch_size=128)\n",
+    "            pred_list.append(pr.reshape(pr.shape[0]*pr.shape[1],-1)[wins])\n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'OOF_pred',num_split),'wb')\n",
+    "        pickle.dump(pred_list,pickle_file,protocol=4)\n",
+    "        pickle_file.close()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 47,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal features_train_tta 0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(134734)"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal features_train_tta 1\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(131494)"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal features_train_tta 2\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(136434)"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal features_train_tta 3\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(135582)"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal features_train_tta 4\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(136008)"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Traceback (most recent call last):\n",
+      "Traceback (most recent call last):\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/queues.py\", line 230, in _feed\n",
+      "    close()\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\", line 177, in close\n",
+      "    self._close()\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\", line 361, in _close\n",
+      "    _close(self._handle)\n",
+      "OSError: [Errno 9] Bad file descriptor\n",
+      "Traceback (most recent call last):\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/queues.py\", line 230, in _feed\n",
+      "    close()\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\", line 177, in close\n",
+      "    self._close()\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\", line 361, in _close\n",
+      "    _close(self._handle)\n",
+      "OSError: [Errno 9] Bad file descriptor\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/queues.py\", line 240, in _feed\n",
+      "    send_bytes(obj)\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\", line 200, in send_bytes\n",
+      "    self._send_bytes(m[offset:offset + size])\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\", line 404, in _send_bytes\n",
+      "    self._send(header + buf)\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\", line 368, in _send\n",
+      "    n = write(self._handle, buf)\n",
+      "OSError: [Errno 9] Bad file descriptor\n",
+      "Traceback (most recent call last):\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/queues.py\", line 230, in _feed\n",
+      "    close()\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\", line 177, in close\n",
+      "    self._close()\n",
+      "  File \"/home/reina/anaconda3/envs/RSNA/lib/python3.6/multiprocessing/connection.py\", line 361, in _close\n",
+      "    _close(self._handle)\n",
+      "OSError: [Errno 9] Bad file descriptor\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=31), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "multi=3\n",
+    "model_names=['se_resnet101']\n",
+    "types=['features_train_tta']\n",
+    "versions=['new_splits_focal']\n",
+    "num_splits=[5]\n",
+    "seeds=[432]\n",
+    "for model_name,type_,version_,n,SEED in zip(model_names,types,versions,num_splits,seeds):\n",
+    "    for num_split in tqdm_notebook(range(n)):\n",
+    "        split_sid = train_df.PID.unique()\n",
+    "        splits=list(KFold(n_splits=n,shuffle=True, random_state=SEED).split(split_sid))\n",
+    "\n",
+    "        pred_list=[]\n",
+    "        print(model_name,version_,type_,num_split) \n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version_,type_,num_split),'rb')\n",
+    "        features=pickle.load(pickle_file)\n",
+    "        pickle_file.close()\n",
+    "        features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "        split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "        model=ResModelPool(features.shape[-1])\n",
+    "        version=version_+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "\n",
+    "        model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n",
+    "\n",
+    "        valid_dataset=FullHeadDataset(train_df,\n",
+    "                                      split_validate,\n",
+    "                                      features,\n",
+    "                                      'SeriesI',\n",
+    "                                      'ImagePositionZ',\n",
+    "                                      multi =3)\n",
+    "\n",
+    "        win_dataset=FullHeadDataset(train_df,\n",
+    "                                      split_validate,\n",
+    "                                      features,\n",
+    "                                      'SeriesI',\n",
+    "                                      'ImagePositionZ',\n",
+    "                                       target_columns=hemorrhage_types)\n",
+    "        win_list=[]\n",
+    "        dl = D.DataLoader(win_dataset,batch_size=128,num_workers=16)\n",
+    "        for _,win in tqdm_notebook(dl):\n",
+    "            win_list.append(win.reshape(win.shape[0]*win.shape[1],-1))    \n",
+    "        wins = torch.cat(win_list,0).sum(1)>=0\n",
+    "        wins.sum()\n",
+    "        for i in tqdm_notebook(range(32),leave=False):\n",
+    "            pr = model_run(model,valid_dataset,do_apex=False,batch_size=128)\n",
+    "            pred_list.append(pr.reshape(pr.shape[0]*pr.shape[1],-1)[wins])\n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'OOF_pred',num_split),'wb')\n",
+    "        pickle.dump(pred_list,pickle_file,protocol=4)\n",
+    "        pickle_file.close()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pickle_file=open(outputs_dir+'OOF_validation_image_ids_5.pkl','rb')\n",
+    "image_id_folds=pickle.load(pickle_file)\n",
+    "pickle_file.close()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "targets=res_df.loc[image_id_folds[0]]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(134734, 6)"
+      ]
+     },
+     "execution_count": 29,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "targets.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 102,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0680)"
+      ]
+     },
+     "execution_count": 102,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0683)"
+      ]
+     },
+     "execution_count": 102,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0693)"
+      ]
+     },
+     "execution_count": 102,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0710)"
+      ]
+     },
+     "execution_count": 102,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0698)"
+      ]
+     },
+     "execution_count": 102,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0693)"
+      ]
+     },
+     "execution_count": 102,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "multi=3\n",
+    "cols=['PatientID']\n",
+    "cols.extend(hemorrhage_types)\n",
+    "res_df=train_df[cols].set_index('PatientID')\n",
+    "model_names=['se_resnet101']\n",
+    "types=['OOF_pred']\n",
+    "versions=['new_splits_focal_fullhead_resmodel_pool2_over3']\n",
+    "num_splits=[5]\n",
+    "#model_names=['se_resnext101_32x4d','se_resnet101']\n",
+    "#types=['test_pred_ensamble','test_pred_ensamble']\n",
+    "#versions=['new_splits_fullhead_resmodel_pool2_3','new_splits_fullhead_resmodel_pool2_3']\n",
+    "#num_splits=[5,5]\n",
+    "#model_names=['se_resnet101']\n",
+    "#types=['OOF_pred']\n",
+    "#versions=['new_splits_focal_fullhead_resmodel_pool2_over3']\n",
+    "#num_splits=[5]\n",
+    "l=0.\n",
+    "for num_split in range(n):\n",
+    "    pred_list=[]\n",
+    "    for model_name,type_,version_,n in zip(model_names,types,versions,num_splits):\n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version_,type_,num_split),'rb')\n",
+    "        pred_list.extend(pickle.load(pickle_file))\n",
+    "        pickle_file.close()\n",
+    "    preds=torch.cat([p[None] for p in pred_list],0).mean(0)\n",
+    "    targets=torch.tensor(res_df.loc[image_id_folds[num_split]][hemorrhage_types].values,dtype=torch.float)\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],dtype=torch.float)\n",
+    "    my_loss(preds,targets,weights)\n",
+    "    l=l+my_loss(preds,targets,weights)\n",
+    "l/5\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 103,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0675)"
+      ]
+     },
+     "execution_count": 103,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0683)"
+      ]
+     },
+     "execution_count": 103,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0690)"
+      ]
+     },
+     "execution_count": 103,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0716)"
+      ]
+     },
+     "execution_count": 103,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0701)"
+      ]
+     },
+     "execution_count": 103,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0693)"
+      ]
+     },
+     "execution_count": 103,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "multi=3\n",
+    "cols=['PatientID']\n",
+    "cols.extend(hemorrhage_types)\n",
+    "res_df=train_df[cols].set_index('PatientID')\n",
+    "model_names=['se_resnet101']\n",
+    "types=['OOF_pred']\n",
+    "versions=['new_splits_fullhead_resmodel_pool2_3']\n",
+    "num_splits=[5]\n",
+    "#model_names=['se_resnext101_32x4d','se_resnet101']\n",
+    "#types=['test_pred_ensamble','test_pred_ensamble']\n",
+    "#versions=['new_splits_fullhead_resmodel_pool2_3','new_splits_fullhead_resmodel_pool2_3']\n",
+    "#num_splits=[5,5]\n",
+    "#model_names=['se_resnet101']\n",
+    "#types=['OOF_pred']\n",
+    "#versions=['new_splits_focal_fullhead_resmodel_pool2_over3']\n",
+    "#num_splits=[5]\n",
+    "l=0.\n",
+    "for num_split in range(n):\n",
+    "    pred_list=[]\n",
+    "    for model_name,type_,version_,n in zip(model_names,types,versions,num_splits):\n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version_,type_,num_split),'rb')\n",
+    "        pred_list.extend(pickle.load(pickle_file))\n",
+    "        pickle_file.close()\n",
+    "    preds=torch.cat([p[None] for p in pred_list],0).mean(0)\n",
+    "    targets=torch.tensor(res_df.loc[image_id_folds[num_split]][hemorrhage_types].values,dtype=torch.float)\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],dtype=torch.float)\n",
+    "    my_loss(preds,targets,weights)\n",
+    "    l=l+my_loss(preds,targets,weights)\n",
+    "l/5\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 100,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0664)"
+      ]
+     },
+     "execution_count": 100,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0672)"
+      ]
+     },
+     "execution_count": 100,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0678)"
+      ]
+     },
+     "execution_count": 100,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0701)"
+      ]
+     },
+     "execution_count": 100,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0687)"
+      ]
+     },
+     "execution_count": 100,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0680)"
+      ]
+     },
+     "execution_count": 100,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "multi=3\n",
+    "cols=['PatientID']\n",
+    "cols.extend(hemorrhage_types)\n",
+    "res_df=train_df[cols].set_index('PatientID')\n",
+    "model_names=['se_resnext101_32x4d','se_resnet101','se_resnet101']\n",
+    "types=['OOF_pred','OOF_pred','OOF_pred']\n",
+    "versions=['new_splits_fullhead_resmodel_pool2_3','new_splits_fullhead_resmodel_pool2_3','new_splits_focal_fullhead_resmodel_pool2_over3']\n",
+    "num_splits=[5,5,5]\n",
+    "#model_names=['se_resnext101_32x4d','se_resnet101']\n",
+    "#types=['test_pred_ensamble','test_pred_ensamble']\n",
+    "#versions=['new_splits_fullhead_resmodel_pool2_3','new_splits_fullhead_resmodel_pool2_3']\n",
+    "#num_splits=[5,5]\n",
+    "#model_names=['se_resnet101']\n",
+    "#types=['OOF_pred']\n",
+    "#versions=['new_splits_focal_fullhead_resmodel_pool2_over3']\n",
+    "#num_splits=[5]\n",
+    "l=0.\n",
+    "for num_split in range(n):\n",
+    "    pred_list=[]\n",
+    "    for model_name,type_,version_,n in zip(model_names,types,versions,num_splits):\n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version_,type_,num_split),'rb')\n",
+    "        pred_list.extend(pickle.load(pickle_file))\n",
+    "        pickle_file.close()\n",
+    "    preds=torch.cat([p[None] for p in pred_list],0).mean(0)\n",
+    "    targets=torch.tensor(res_df.loc[image_id_folds[num_split]][hemorrhage_types].values,dtype=torch.float)\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],dtype=torch.float)\n",
+    "    my_loss(preds,targets,weights)\n",
+    "    l=l+my_loss(preds,targets,weights)\n",
+    "l/5\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0662)"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0672)"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0675)"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0702)"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0687)"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0680)"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "multi=3\n",
+    "cols=['PatientID']\n",
+    "cols.extend(hemorrhage_types)\n",
+    "res_df=train_df[cols].set_index('PatientID')\n",
+    "model_names=['se_resnext101_32x4d','se_resnext101_32x4d','se_resnet101','se_resnet101']\n",
+    "types=['OOF_pred','OOF_pred','OOF_pred','OOF_pred']\n",
+    "versions=['new_splits_fullhead_resmodel_pool2_3','new_splits_focal_fullhead_resmodel_pool2_over3','new_splits_fullhead_resmodel_pool2_3','new_splits_focal_fullhead_resmodel_pool2_over3']\n",
+    "num_splits=[5,5,5,5]\n",
+    "#model_names=['se_resnext101_32x4d','se_resnet101']\n",
+    "#types=['test_pred_ensamble','test_pred_ensamble']\n",
+    "#versions=['new_splits_fullhead_resmodel_pool2_3','new_splits_fullhead_resmodel_pool2_3']\n",
+    "#num_splits=[5,5]\n",
+    "#model_names=['se_resnet101']\n",
+    "#types=['OOF_pred']\n",
+    "#versions=['new_splits_focal_fullhead_resmodel_pool2_over3']\n",
+    "#num_splits=[5]\n",
+    "l=0.\n",
+    "for num_split in range(n):\n",
+    "    pred_list=[]\n",
+    "    for model_name,type_,version_,n in zip(model_names,types,versions,num_splits):\n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version_,type_,num_split),'rb')\n",
+    "        pred_list.extend(pickle.load(pickle_file))\n",
+    "        pickle_file.close()\n",
+    "    preds=torch.cat([p[None] for p in pred_list],0).mean(0)\n",
+    "    targets=torch.tensor(res_df.loc[image_id_folds[num_split]][hemorrhage_types].values,dtype=torch.float)\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],dtype=torch.float)\n",
+    "    my_loss(preds,targets,weights)\n",
+    "    l=l+my_loss(preds,targets,weights)\n",
+    "l/5\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 101,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0664)"
+      ]
+     },
+     "execution_count": 101,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0675)"
+      ]
+     },
+     "execution_count": 101,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0678)"
+      ]
+     },
+     "execution_count": 101,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0705)"
+      ]
+     },
+     "execution_count": 101,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0692)"
+      ]
+     },
+     "execution_count": 101,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0683)"
+      ]
+     },
+     "execution_count": 101,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "multi=3\n",
+    "cols=['PatientID']\n",
+    "cols.extend(hemorrhage_types)\n",
+    "res_df=train_df[cols].set_index('PatientID')\n",
+    "#model_names=['se_resnext101_32x4d','se_resnet101','se_resnet101']\n",
+    "#types=['OOF_pred','OOF_pred','OOF_pred']\n",
+    "#versions=['new_splits_fullhead_resmodel_pool2_3','new_splits_fullhead_resmodel_pool2_3','new_splits_focal_fullhead_resmodel_pool2_over3']\n",
+    "#num_splits=[5,5,5]\n",
+    "model_names=['se_resnext101_32x4d','se_resnet101']\n",
+    "types=['OOF_pred','OOF_pred']\n",
+    "versions=['new_splits_fullhead_resmodel_pool2_3','new_splits_fullhead_resmodel_pool2_3']\n",
+    "num_splits=[5,5]\n",
+    "l=0\n",
+    "#model_names=['se_resnet101']\n",
+    "#types=['OOF_pred']\n",
+    "#versions=['new_splits_focal_fullhead_resmodel_pool2_over3']\n",
+    "#num_splits=[5]\n",
+    "for num_split in range(n):\n",
+    "    pred_list=[]\n",
+    "    for model_name,type_,version_,n in zip(model_names,types,versions,num_splits):\n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version_,type_,num_split),'rb')\n",
+    "        pred_list.extend(pickle.load(pickle_file))\n",
+    "        pickle_file.close()\n",
+    "    preds=torch.cat([p[None] for p in pred_list],0).mean(0)\n",
+    "    targets=torch.tensor(res_df.loc[image_id_folds[num_split]][hemorrhage_types].values,dtype=torch.float)\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],dtype=torch.float)\n",
+    "    my_loss(preds,targets,weights)\n",
+    "    l=l+my_loss(preds,targets,weights)\n",
+    "l/5\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 105,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=5), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal features_test 0\n",
+      "torch.Size([78545, 8, 256])\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 105,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(78545)"
+      ]
+     },
+     "execution_count": 105,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal features_test 1\n",
+      "torch.Size([78545, 8, 256])\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 105,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(78545)"
+      ]
+     },
+     "execution_count": 105,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal features_test 2\n",
+      "torch.Size([78545, 8, 256])\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 105,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(78545)"
+      ]
+     },
+     "execution_count": 105,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal features_test 3\n",
+      "torch.Size([78545, 8, 256])\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 105,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(78545)"
+      ]
+     },
+     "execution_count": 105,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 new_splits_focal features_test 4\n",
+      "torch.Size([78545, 8, 256])\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 105,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(78545)"
+      ]
+     },
+     "execution_count": 105,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "multi=3\n",
+    "model_names=['se_resnet101']\n",
+    "types=['features_test']\n",
+    "versions=['new_splits_focal']\n",
+    "num_splits=[5]\n",
+    "seeds=[432]\n",
+    "for model_name,type_,version_,n,SEED in zip(model_names,types,versions,num_splits,seeds):\n",
+    "    for num_split in tqdm_notebook(range(n)):\n",
+    "        pred_list=[]\n",
+    "        print(model_name,version_,type_,num_split) \n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version_,type_,num_split),'rb')\n",
+    "        features=pickle.load(pickle_file)\n",
+    "        pickle_file.close()\n",
+    "        features=features.reshape(features.shape[0]//8,8,-1)\n",
+    "        print(features.shape)\n",
+    "        model=ResModelPool(features.shape[-1])\n",
+    "        version=version_+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "\n",
+    "        model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n",
+    "\n",
+    "        valid_dataset=FullHeadDataset(test_df,\n",
+    "                                      test_df.SeriesI.unique(),\n",
+    "                                      features,\n",
+    "                                      'SeriesI',\n",
+    "                                      'ImagePositionZ',\n",
+    "                                      multi =4)\n",
+    "\n",
+    "        win_dataset=FullHeadDataset(test_df,\n",
+    "                                      test_df.SeriesI.unique(),\n",
+    "                                      features,\n",
+    "                                      'SeriesI',\n",
+    "                                      'ImagePositionZ',\n",
+    "                                       target_columns=hemorrhage_types)\n",
+    "        win_list=[]\n",
+    "        dl = D.DataLoader(win_dataset,batch_size=128,num_workers=16)\n",
+    "        for _,win in tqdm_notebook(dl):\n",
+    "            win_list.append(win.reshape(win.shape[0]*win.shape[1],-1))    \n",
+    "        wins = torch.cat(win_list,0).sum(1)>=0\n",
+    "        wins.sum()\n",
+    "        for i in tqdm_notebook(range(32),leave=False):\n",
+    "            pr = model_run(model,valid_dataset,do_apex=False,batch_size=128)\n",
+    "            pred_list.append(pr.reshape(pr.shape[0]*pr.shape[1],-1)[wins])\n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'test_pred_ensamble',num_split),'wb')\n",
+    "        pickle.dump(pred_list,pickle_file,protocol=4)\n",
+    "        pickle_file.close()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b2a0549fe0654f979d071740abe95615",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=1), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnext101_32x4d new_splits_focal features_test 1\n",
+      "torch.Size([78545, 8, 256])\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 31,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a12fe4bc588148da9a9b20948e6c860a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor(78545)"
+      ]
+     },
+     "execution_count": 31,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fb54f171d80d459aa73c2a902103853d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b97ba32ae4eb436fa171fe5140cded9e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "76df6d9189cc4591b000f72b89c62445",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5488a1374a924e9993ddd0b1e8da1f3c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fb37357b017041ba86a4a3f2adede09e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6cd4894b9b4c424ebfe606583c14e288",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9effeb034879473495cfe97e0d70db06",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1afcb4da27174e0abce2370b4c355897",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e25b5752e3a24d818682b89adafefd7f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "51859c2e59e64a63ad7ee03bfea1a04e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a3b2886c7afc4425a220df6d6e6f0014",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "043e938e89ec49d78b7bbae31c3655f6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cd3df26e86bc4f0ab88f2067ed46ad7f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "985526e5569744829809f2b5c775a5d4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5f6f7c8d70aa4344a5c2e8051c518336",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "205c96d891924b20b7db4840b38e3db3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "23912a3c5b48492c915fa6fe1e128e85",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4867580e072e4f5fb6ef23670d00f8ec",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "abe6c975a20141cb8133681f19337cd9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "074ded45715b42b7a7c87319a1b9bdd1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7b2110cf63314090b7cc0b91a1f9e10a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ad18f47dd6f0404eb1750dea56f7eec8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0b4db68962f24b33add43afed875f866",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a642ff660c134aa88b50502cd0e478cc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "43541346e088456782bc6f47465eab40",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "140a6fcb3ace4cc08886f9e9783f8e11",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f1a162a742754fa48551b89a0567344a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5027476649424e79876a39829afe6b3b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c16248b690654a28b295ce3b77195eaf",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e7c7a1bb0ed841db91d0c4faa7045ef1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "49f00c7afa8b4edc810a7fd20241b1e1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ab060477367b4702854b84d322469a50",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=18), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "multi=3\n",
+    "model_names=['se_resnext101_32x4d']\n",
+    "types=['features_test']\n",
+    "versions=['new_splits_focal']\n",
+    "num_splits=[5]\n",
+    "seeds=[432]\n",
+    "for model_name,type_,version_,n,SEED in zip(model_names,types,versions,num_splits,seeds):\n",
+    "    for num_split in tqdm_notebook(range(5)):\n",
+    "        pred_list=[]\n",
+    "        print(model_name,version_,type_,num_split) \n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version_,type_,num_split),'rb')\n",
+    "        features=pickle.load(pickle_file)\n",
+    "        pickle_file.close()\n",
+    "        features=features.reshape(features.shape[0]//8,8,-1)\n",
+    "        print(features.shape)\n",
+    "        model=ResModelPool(features.shape[-1])\n",
+    "        version=version_+'_fullhead_resmodel_pool2_over{}'.format(multi)\n",
+    "\n",
+    "        model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n",
+    "\n",
+    "        valid_dataset=FullHeadDataset(test_df,\n",
+    "                                      test_df.SeriesI.unique(),\n",
+    "                                      features,\n",
+    "                                      'SeriesI',\n",
+    "                                      'ImagePositionZ',\n",
+    "                                      multi =4)\n",
+    "\n",
+    "        win_dataset=FullHeadDataset(test_df,\n",
+    "                                      test_df.SeriesI.unique(),\n",
+    "                                      features,\n",
+    "                                      'SeriesI',\n",
+    "                                      'ImagePositionZ',\n",
+    "                                       target_columns=hemorrhage_types)\n",
+    "        win_list=[]\n",
+    "        dl = D.DataLoader(win_dataset,batch_size=128,num_workers=16)\n",
+    "        for _,win in tqdm_notebook(dl):\n",
+    "            win_list.append(win.reshape(win.shape[0]*win.shape[1],-1))    \n",
+    "        wins = torch.cat(win_list,0).sum(1)>=0\n",
+    "        wins.sum()\n",
+    "        for i in tqdm_notebook(range(32),leave=False):\n",
+    "            pr = model_run(model,valid_dataset,do_apex=False,batch_size=128)\n",
+    "            pred_list.append(pr.reshape(pr.shape[0]*pr.shape[1],-1)[wins])\n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'test_pred_ensamble',num_split),'wb')\n",
+    "        pickle.dump(pred_list,pickle_file,protocol=4)\n",
+    "        pickle_file.close()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "multi=3\n",
+    "pred_list=[]\n",
+    "model_names=['se_resnext101_32x4d','se_resnext101_32x4d','se_resnet101','se_resnet101']\n",
+    "types=['test_pred_ensamble','test_pred_ensamble','test_pred_ensamble','test_pred_ensamble']\n",
+    "versions=['new_splits_fullhead_resmodel_pool2_3','new_splits_focal_fullhead_resmodel_pool2_over3','new_splits_fullhead_resmodel_pool2_3','new_splits_focal_fullhead_resmodel_pool2_over3']\n",
+    "num_splits=[5,5,5,5]\n",
+    "#model_names=['se_resnext101_32x4d','se_resnet101']\n",
+    "#types=['test_pred_ensamble','test_pred_ensamble']\n",
+    "#versions=['new_splits_fullhead_resmodel_pool2_3','new_splits_fullhead_resmodel_pool2_3']\n",
+    "#num_splits=[5,5]\n",
+    "#model_names=['se_resnet101']\n",
+    "#types=['test_pred_ensamble']\n",
+    "#versions=['new_splits_focal_fullhead_resmodel_pool2_over3']\n",
+    "#num_splits=[5]\n",
+    "\n",
+    "for model_name,type_,version_,n in zip(model_names,types,versions,num_splits):\n",
+    "    for num_split in range(n):\n",
+    "        pickle_file=open(outputs_dir+outputs_format.format(model_name,version_,type_,num_split),'rb')\n",
+    "        pred_list.extend(pickle.load(pickle_file))\n",
+    "        pickle_file.close()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "preds=torch.sigmoid(torch.cat([p[None] for p in pred_list],0)).mean(0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([78545, 6])"
+      ]
+     },
+     "execution_count": 34,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[3.5711e-05, 2.7395e-05, 5.7499e-06, 7.0699e-05, 7.5931e-05, 1.5110e-04],\n",
+       "        [3.9274e-06, 3.3226e-06, 3.7224e-07, 1.2136e-05, 1.4754e-05, 3.8673e-05],\n",
+       "        [3.3621e-05, 1.0076e-04, 2.5102e-06, 1.9387e-04, 1.3010e-04, 5.5842e-04],\n",
+       "        [8.2830e-05, 2.8591e-04, 5.9830e-06, 2.6872e-04, 2.9643e-04, 8.5356e-04],\n",
+       "        [7.1110e-05, 2.6773e-04, 9.9516e-06, 1.9567e-04, 2.3143e-04, 6.7518e-04],\n",
+       "        [5.4225e-05, 1.7603e-04, 8.6872e-06, 1.5172e-04, 2.3249e-04, 5.0158e-04],\n",
+       "        [4.0195e-05, 1.5081e-04, 7.9949e-06, 1.3344e-04, 2.8025e-04, 5.1623e-04],\n",
+       "        [3.5982e-05, 1.6879e-04, 1.0185e-05, 1.6445e-04, 3.8864e-04, 7.3439e-04],\n",
+       "        [2.9345e-05, 1.2291e-04, 6.0371e-06, 1.4451e-04, 3.6284e-04, 5.9658e-04],\n",
+       "        [3.0346e-05, 2.1097e-04, 8.2453e-06, 2.0855e-04, 3.9730e-04, 8.7336e-04]])"
+      ]
+     },
+     "execution_count": 34,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "preds.shape\n",
+    "preds[:10]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pickle_file=open(outputs_dir+'ensemble_test_image_ids.pkl','rb')\n",
+    "image_ids=pickle.load(pickle_file)\n",
+    "pickle_file.close()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array(['37874f3c6', '2e16b8f05', '0b758b9c0', 'b8eba0e41', 'acf5e1ac4',\n",
+       "       'ae3c642b0', '7edd07ece', '2cbe4682d', '7ed61df92', 'df94f381c'],\n",
+       "      dtype=object)"
+      ]
+     },
+     "execution_count": 36,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "image_ids[:10]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>ID</th>\n",
+       "      <th>Label</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>ID_000012eaf_any</td>\n",
+       "      <td>0.001188</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>ID_000012eaf_epidural</td>\n",
+       "      <td>0.000104</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>ID_000012eaf_intraparenchymal</td>\n",
+       "      <td>0.000135</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>ID_000012eaf_intraventricular</td>\n",
+       "      <td>0.000024</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>ID_000012eaf_subarachnoid</td>\n",
+       "      <td>0.000148</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5</th>\n",
+       "      <td>ID_000012eaf_subdural</td>\n",
+       "      <td>0.000900</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>6</th>\n",
+       "      <td>ID_0000ca2f6_any</td>\n",
+       "      <td>0.001744</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>7</th>\n",
+       "      <td>ID_0000ca2f6_epidural</td>\n",
+       "      <td>0.000062</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>8</th>\n",
+       "      <td>ID_0000ca2f6_intraparenchymal</td>\n",
+       "      <td>0.000291</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9</th>\n",
+       "      <td>ID_0000ca2f6_intraventricular</td>\n",
+       "      <td>0.000027</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>10</th>\n",
+       "      <td>ID_0000ca2f6_subarachnoid</td>\n",
+       "      <td>0.000251</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>11</th>\n",
+       "      <td>ID_0000ca2f6_subdural</td>\n",
+       "      <td>0.001218</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "                               ID     Label\n",
+       "0                ID_000012eaf_any  0.001188\n",
+       "1           ID_000012eaf_epidural  0.000104\n",
+       "2   ID_000012eaf_intraparenchymal  0.000135\n",
+       "3   ID_000012eaf_intraventricular  0.000024\n",
+       "4       ID_000012eaf_subarachnoid  0.000148\n",
+       "5           ID_000012eaf_subdural  0.000900\n",
+       "6                ID_0000ca2f6_any  0.001744\n",
+       "7           ID_0000ca2f6_epidural  0.000062\n",
+       "8   ID_0000ca2f6_intraparenchymal  0.000291\n",
+       "9   ID_0000ca2f6_intraventricular  0.000027\n",
+       "10      ID_0000ca2f6_subarachnoid  0.000251\n",
+       "11          ID_0000ca2f6_subdural  0.001218"
+      ]
+     },
+     "execution_count": 38,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(471270, 2)"
+      ]
+     },
+     "execution_count": 38,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "submission_df=get_submission_ids(image_ids,preds,do_sigmoid=False)\n",
+    "submission_df.head(12)\n",
+    "submission_df.shape\n",
+    "sub_num=58\n",
+    "submission_df.to_csv('/media/hd/notebooks/data/RSNA/submissions/submission{}.csv'.format(sub_num),\n",
+    "                                                                  index=False, columns=['ID','Label'])\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "100%|███████████████████████████████████████| 16.9M/16.9M [01:16<00:00, 206kB/s]\n",
+      "Successfully submitted to RSNA Intracranial Hemorrhage Detection"
+     ]
+    }
+   ],
+   "source": [
+    "#!/home/reina/anaconda3/bin/kaggle competitions submit rsna-intracranial-hemorrhage-detection -f /media/hd/notebooks/data/RSNA/submissions/submission58.csv -m \"all the 5 folds models including focal, mean after sigmoid\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sub=pd.read_csv('/media/hd/notebooks/data/RSNA/submissions/submission55.csv')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 64,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sub0=pd.read_csv('/media/hd/notebooks/data/RSNA/submissions/submission54.csv')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0487, dtype=torch.float64)"
+      ]
+     },
+     "execution_count": 66,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "F.binary_cross_entropy(torch.tensor(sub.Label.values),torch.tensor(sub0.Label.values))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 70,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.1125)"
+      ]
+     },
+     "execution_count": 70,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "F.binary_cross_entropy(torch.tensor(sub.Label.values,dtype=torch.float),torch.tensor(submission_df.Label.values,dtype=torch.float))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 78,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0484)"
+      ]
+     },
+     "execution_count": 78,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "F.binary_cross_entropy(torch.tensor(sub.Label.values,dtype=torch.float),torch.tensor(submission_df.Label.values,dtype=torch.float))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 40,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0486)"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "F.binary_cross_entropy(torch.tensor(sub.Label.values,dtype=torch.float),torch.tensor(submission_df.Label.values,dtype=torch.float))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}