--- a
+++ b/Production/Post Full Head Models Train.ipynb
@@ -0,0 +1,15765 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/reina/anaconda3/envs/RSNA/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
+      "  return f(*args, **kwds)\n",
+      "/home/reina/anaconda3/envs/RSNA/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n",
+      "  return f(*args, **kwds)\n"
+     ]
+    }
+   ],
+   "source": [
+    "from __future__ import absolute_import\n",
+    "from __future__ import division\n",
+    "from __future__ import print_function\n",
+    "\n",
+    "\n",
+    "import numpy as np # linear algebra\n",
+    "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
+    "import os\n",
+    "import datetime\n",
+    "import seaborn as sns\n",
+    "\n",
+    "#import pydicom\n",
+    "import time\n",
+    "from functools import partial\n",
+    "import gc\n",
+    "import operator \n",
+    "import matplotlib.pyplot as plt\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.utils.data as D\n",
+    "import torch.nn.functional as F\n",
+    "from sklearn.model_selection import KFold\n",
+    "from tqdm import tqdm, tqdm_notebook\n",
+    "from IPython.core.interactiveshell import InteractiveShell\n",
+    "InteractiveShell.ast_node_interactivity = \"all\"\n",
+    "import warnings\n",
+    "warnings.filterwarnings(action='once')\n",
+    "import pickle\n",
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "%matplotlib inline\n",
+    "from skimage.io import imread,imshow\n",
+    "from helper import *\n",
+    "import helper\n",
+    "import torchvision.models as models\n",
+    "from torch.optim import Adam\n",
+    "from defenitions import *"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SEED = 8153\n",
+    "#device=device_by_name(\"Tesla\")\n",
+    "device=device_by_name(\"RTX\")\n",
+    "#device=device_by_name(\"1060\")\n",
+    "torch.cuda.set_device(device)\n",
+    "#device = \"cpu\"\n",
+    "sendmeemail=Email_Progress(my_gmail,my_pass,to_email,'Densenet161-Copy2-2 results')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_submission(test_df,pred):\n",
+    "    epidural_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_epidural','Label':torch.sigmoid(pred[:,0])})\n",
+    "    intraparenchymal_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_intraparenchymal','Label':torch.sigmoid(pred[:,1])})\n",
+    "    intraventricular_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_intraventricular','Label':torch.sigmoid(pred[:,2])})\n",
+    "    subarachnoid_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_subarachnoid','Label':torch.sigmoid(pred[:,3])})\n",
+    "    subdural_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_subdural','Label':torch.sigmoid(pred[:,4])})\n",
+    "    any_df=pd.DataFrame(data={'ID':'ID_'+test_df.PatientID.values+'_any','Label':torch.sigmoid(pred[:,5])}) \n",
+    "    return pd.concat([epidural_df,\n",
+    "                        intraparenchymal_df,\n",
+    "                        intraventricular_df,\n",
+    "                        subarachnoid_df,\n",
+    "                        subdural_df,\n",
+    "                        any_df]).sort_values('ID').reset_index(drop=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_submission_ids(image_ids,pred,do_sigmoid=True):\n",
+    "    if do_sigmoid:\n",
+    "        func = lambda x:torch.sigmoid(x)\n",
+    "    else:\n",
+    "        func = lambda x:x\n",
+    "    epidural_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_epidural','Label':func(pred[:,0])})\n",
+    "    intraparenchymal_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_intraparenchymal','Label':func(pred[:,1])})\n",
+    "    intraventricular_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_intraventricular','Label':func(pred[:,2])})\n",
+    "    subarachnoid_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_subarachnoid','Label':func(pred[:,3])})\n",
+    "    subdural_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_subdural','Label':func(pred[:,4])})\n",
+    "    any_df=pd.DataFrame(data={'ID':'ID_'+image_ids+'_any','Label':func(pred[:,5])}) \n",
+    "    return pd.concat([epidural_df,\n",
+    "                        intraparenchymal_df,\n",
+    "                        intraventricular_df,\n",
+    "                        subarachnoid_df,\n",
+    "                        subdural_df,\n",
+    "                        any_df]).sort_values('ID').reset_index(drop=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(674510, 15)"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(674252, 15)"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>PatientID</th>\n",
+       "      <th>epidural</th>\n",
+       "      <th>intraparenchymal</th>\n",
+       "      <th>intraventricular</th>\n",
+       "      <th>subarachnoid</th>\n",
+       "      <th>subdural</th>\n",
+       "      <th>any</th>\n",
+       "      <th>PID</th>\n",
+       "      <th>StudyI</th>\n",
+       "      <th>SeriesI</th>\n",
+       "      <th>WindowCenter</th>\n",
+       "      <th>WindowWidth</th>\n",
+       "      <th>ImagePositionZ</th>\n",
+       "      <th>ImagePositionX</th>\n",
+       "      <th>ImagePositionY</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>63eb1e259</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>a449357f</td>\n",
+       "      <td>62d125e5b2</td>\n",
+       "      <td>0be5c0d1b3</td>\n",
+       "      <td>['00036', '00036']</td>\n",
+       "      <td>['00080', '00080']</td>\n",
+       "      <td>180.199951</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-8.000000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>2669954a7</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>363d5865</td>\n",
+       "      <td>a20b80c7bf</td>\n",
+       "      <td>3564d584db</td>\n",
+       "      <td>['00047', '00047']</td>\n",
+       "      <td>['00080', '00080']</td>\n",
+       "      <td>922.530821</td>\n",
+       "      <td>-156.0</td>\n",
+       "      <td>45.572849</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>52c9913b1</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>9c2b4bd7</td>\n",
+       "      <td>3e3634f8cf</td>\n",
+       "      <td>973274ffc9</td>\n",
+       "      <td>40</td>\n",
+       "      <td>150</td>\n",
+       "      <td>4.455000</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-115.063000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>4e6ff6126</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>3ae81c2d</td>\n",
+       "      <td>a1390c15c2</td>\n",
+       "      <td>e5ccad8244</td>\n",
+       "      <td>['00036', '00036']</td>\n",
+       "      <td>['00080', '00080']</td>\n",
+       "      <td>100.000000</td>\n",
+       "      <td>-99.5</td>\n",
+       "      <td>28.500000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>7858edd88</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>c1867feb</td>\n",
+       "      <td>c73e81ed3a</td>\n",
+       "      <td>28e0531b3a</td>\n",
+       "      <td>40</td>\n",
+       "      <td>100</td>\n",
+       "      <td>145.793000</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-132.190000</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "   PatientID  epidural  intraparenchymal  intraventricular  subarachnoid  \\\n",
+       "0  63eb1e259         0                 0                 0             0   \n",
+       "1  2669954a7         0                 0                 0             0   \n",
+       "2  52c9913b1         0                 0                 0             0   \n",
+       "3  4e6ff6126         0                 0                 0             0   \n",
+       "4  7858edd88         0                 0                 0             0   \n",
+       "\n",
+       "   subdural  any       PID      StudyI     SeriesI        WindowCenter  \\\n",
+       "0         0    0  a449357f  62d125e5b2  0be5c0d1b3  ['00036', '00036']   \n",
+       "1         0    0  363d5865  a20b80c7bf  3564d584db  ['00047', '00047']   \n",
+       "2         0    0  9c2b4bd7  3e3634f8cf  973274ffc9                  40   \n",
+       "3         0    0  3ae81c2d  a1390c15c2  e5ccad8244  ['00036', '00036']   \n",
+       "4         0    0  c1867feb  c73e81ed3a  28e0531b3a                  40   \n",
+       "\n",
+       "          WindowWidth  ImagePositionZ  ImagePositionX  ImagePositionY  \n",
+       "0  ['00080', '00080']      180.199951          -125.0       -8.000000  \n",
+       "1  ['00080', '00080']      922.530821          -156.0       45.572849  \n",
+       "2                 150        4.455000          -125.0     -115.063000  \n",
+       "3  ['00080', '00080']      100.000000           -99.5       28.500000  \n",
+       "4                 100      145.793000          -125.0     -132.190000  "
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "train_df = pd.read_csv(data_dir+'train.csv')\n",
+    "train_df.shape\n",
+    "train_df=train_df[~train_df.PatientID.isin(bad_images)].reset_index(drop=True)\n",
+    "train_df=train_df.drop_duplicates().reset_index(drop=True)\n",
+    "train_df.shape\n",
+    "train_df.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>PatientID</th>\n",
+       "      <th>epidural</th>\n",
+       "      <th>intraparenchymal</th>\n",
+       "      <th>intraventricular</th>\n",
+       "      <th>subarachnoid</th>\n",
+       "      <th>subdural</th>\n",
+       "      <th>any</th>\n",
+       "      <th>SeriesI</th>\n",
+       "      <th>PID</th>\n",
+       "      <th>StudyI</th>\n",
+       "      <th>WindowCenter</th>\n",
+       "      <th>WindowWidth</th>\n",
+       "      <th>ImagePositionZ</th>\n",
+       "      <th>ImagePositionX</th>\n",
+       "      <th>ImagePositionY</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>28fbab7eb</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>ebfd7e4506</td>\n",
+       "      <td>cf1b6b11</td>\n",
+       "      <td>93407cadbb</td>\n",
+       "      <td>30</td>\n",
+       "      <td>80</td>\n",
+       "      <td>158.458000</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-135.598000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>877923b8b</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>6d95084e15</td>\n",
+       "      <td>ad8ea58f</td>\n",
+       "      <td>a337baa067</td>\n",
+       "      <td>30</td>\n",
+       "      <td>80</td>\n",
+       "      <td>138.729050</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-101.797981</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>a591477cb</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>8e06b2c9e0</td>\n",
+       "      <td>ecfb278b</td>\n",
+       "      <td>0cfe838d54</td>\n",
+       "      <td>30</td>\n",
+       "      <td>80</td>\n",
+       "      <td>60.830002</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-133.300003</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>42217c898</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>e800f419cf</td>\n",
+       "      <td>e96e31f4</td>\n",
+       "      <td>c497ac5bad</td>\n",
+       "      <td>30</td>\n",
+       "      <td>80</td>\n",
+       "      <td>55.388000</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-146.081000</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>a130c4d2f</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>0.5</td>\n",
+       "      <td>faeb7454f3</td>\n",
+       "      <td>69affa42</td>\n",
+       "      <td>854e4fbc01</td>\n",
+       "      <td>30</td>\n",
+       "      <td>80</td>\n",
+       "      <td>33.516888</td>\n",
+       "      <td>-125.0</td>\n",
+       "      <td>-118.689819</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "   PatientID  epidural  intraparenchymal  intraventricular  subarachnoid  \\\n",
+       "0  28fbab7eb       0.5               0.5               0.5           0.5   \n",
+       "1  877923b8b       0.5               0.5               0.5           0.5   \n",
+       "2  a591477cb       0.5               0.5               0.5           0.5   \n",
+       "3  42217c898       0.5               0.5               0.5           0.5   \n",
+       "4  a130c4d2f       0.5               0.5               0.5           0.5   \n",
+       "\n",
+       "   subdural  any     SeriesI       PID      StudyI WindowCenter WindowWidth  \\\n",
+       "0       0.5  0.5  ebfd7e4506  cf1b6b11  93407cadbb           30          80   \n",
+       "1       0.5  0.5  6d95084e15  ad8ea58f  a337baa067           30          80   \n",
+       "2       0.5  0.5  8e06b2c9e0  ecfb278b  0cfe838d54           30          80   \n",
+       "3       0.5  0.5  e800f419cf  e96e31f4  c497ac5bad           30          80   \n",
+       "4       0.5  0.5  faeb7454f3  69affa42  854e4fbc01           30          80   \n",
+       "\n",
+       "   ImagePositionZ  ImagePositionX  ImagePositionY  \n",
+       "0      158.458000          -125.0     -135.598000  \n",
+       "1      138.729050          -125.0     -101.797981  \n",
+       "2       60.830002          -125.0     -133.300003  \n",
+       "3       55.388000          -125.0     -146.081000  \n",
+       "4       33.516888          -125.0     -118.689819  "
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "test_df = pd.read_csv(data_dir+'test.csv')\n",
+    "test_df.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "split_sid = train_df.PID.unique()\n",
+    "splits=list(KFold(n_splits=3,shuffle=True, random_state=SEED).split(split_sid))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def my_loss(y_pred,y_true,weights):\n",
+    "    window=(y_true>=0).to(torch.float)\n",
+    "    loss = (F.binary_cross_entropy_with_logits(y_pred,y_true,reduction='none')*window*weights.expand_as(y_true)).mean()/(window.mean()+1e-7)\n",
+    "    return loss"
+   ]
+  },
+  {
+   "cell_type": "raw",
+   "metadata": {},
+   "source": [
+    "def my_loss(y_pred,y_true,weights):\n",
+    "    if len(y_pred.shape)==len(y_true.shape):\n",
+    "        window=(y_true>=0).to(torch.float)\n",
+    "        loss = (F.binary_cross_entropy_with_logits(y_pred,y_true,reduction='none')*window*weights.expand_as(y_true)).mean()/(window.mean()+1e-7)\n",
+    "    else:\n",
+    "        window0=(y_true[...,0]>=0).to(torch.float)\n",
+    "        window1=(y_true[...,1]>=0).to(torch.float)\n",
+    "        loss0 = F.binary_cross_entropy_with_logits(y_pred,y_true[...,0],reduction='none')*window0*weights.expand_as(y_true[...,0])/(window0.mean()+1e-7)\n",
+    "        loss1 = F.binary_cross_entropy_with_logits(y_pred,y_true[...,1],reduction='none')*window1*weights.expand_as(y_true[...,1])/(window1.mean()+1e-7)\n",
+    "        loss = (y_true[...,2]*loss0+(1.0-y_true[...,2])*loss1).mean() \n",
+    "    return loss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Metric():\n",
+    "    def __init__(self,weights,k=0.03):\n",
+    "        self.weights=weights\n",
+    "        self.k=k\n",
+    "        self.zero()\n",
+    "        \n",
+    "    def zero(self):\n",
+    "        self.loss_sum=0.\n",
+    "        self.loss_count=0.\n",
+    "        self.lossf=0.\n",
+    "        \n",
+    "    def calc(self,y_pred,y_true,prefix=\"\"):\n",
+    "        window=(y_true>=0).to(torch.float)\n",
+    "        loss = (F.binary_cross_entropy_with_logits(y_pred,y_true,reduction='none')*window*self.weights.expand_as(y_true)).mean()/(window.mean()+1e-5)\n",
+    "        self.lossf=self.lossf*(1-self.k)+loss*self.k\n",
+    "        self.loss_sum=self.loss_sum+loss*window.sum()\n",
+    "        self.loss_count=self.loss_count+window.sum()\n",
+    "        return({prefix+'mloss':self.lossf})    \n",
+    "        \n",
+    "    def calc_sums(self,prefix=\"\"):\n",
+    "        return({prefix+'mloss_tot':self.loss_sum/self.loss_count})    \n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#features=(features-features.mean())/features.std()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class SimpleModel(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(SimpleModel, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 128, (7,in_size), padding=(3,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(128, 128, 5, padding=2)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 3, padding=1)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(64, 6, 3, padding=1)\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x = self.conv2d1(x.unsqueeze(1)).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x = self.relu0(x)\n",
+    "        x = self.conv1d1(x)\n",
+    "        x = self.bn1(x)\n",
+    "        x = self.relu1(x)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x = self.relu2(x)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class SimpleModel2(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(SimpleModel2, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 128, (9,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(128, 128, 7, padding=3)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 5, padding=2)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(64, 6, 3, padding=1)\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x = self.conv2d1(x.unsqueeze(1)).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x = self.relu0(x)\n",
+    "        x = self.conv1d1(x)\n",
+    "        x = self.bn1(x)\n",
+    "        x = self.relu1(x)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x = self.relu2(x)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ClassModel(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(ClassModel, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 128, (9,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(128, 128, 7, padding=3)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(128)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 5, padding=2)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(128, 6, 3, padding=1)\n",
+    "        \n",
+    "        self.conv2d1class=torch.nn.Conv2d(1, 128, (9,in_size), padding=(4,0))\n",
+    "        self.bn0class=torch.nn.BatchNorm1d(128)\n",
+    "        self.maxpool1class=torch.nn.MaxPool1d(3)\n",
+    "        self.conv1d1class=torch.nn.Conv1d(128, 128, 3, padding=1)\n",
+    "        self.bn1class=torch.nn.BatchNorm1d(128)\n",
+    "        self.maxpool2class=torch.nn.MaxPool1d(3)\n",
+    "        self.conv1d2class=torch.nn.Conv1d(128, 64, 2, padding=1)\n",
+    "        self.bn2class=torch.nn.BatchNorm1d(64)\n",
+    "\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        z=x\n",
+    "        x = self.conv2d1(x.unsqueeze(1)).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x = self.relu0(x)\n",
+    "        x = self.conv1d1(x)\n",
+    "        x = self.bn1(x)\n",
+    "        x = self.relu1(x)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x = self.relu2(x)\n",
+    "        z=self.conv2d1class(z.unsqueeze(1)).squeeze(-1)\n",
+    "        z=self.bn0class(z)\n",
+    "        z=self.maxpool1class(z)\n",
+    "        z=self.conv1d1class(z)\n",
+    "        z=self.maxpool2class(z)\n",
+    "        z=self.conv1d2class(z)\n",
+    "        z=self.bn2class(z)\n",
+    "        z=F.max_pool1d(z,kernel_size=z.shape[-1])\n",
+    "        z=z.expand_as(x)\n",
+    "        x=torch.cat([x,z],1)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ResModel(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(ResModel, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 64, (9,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(64, 64, 7, padding=3)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 5, padding=2)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(192, 6, 3, padding=1)\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x=x.unsqueeze(1)\n",
+    "        x = self.conv2d1(x).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x0 = self.relu0(x)\n",
+    "        x = self.conv1d1(x0)\n",
+    "        x = self.bn1(x)\n",
+    "        x1 = self.relu1(x)\n",
+    "        x = torch.cat([x0,x1],1)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x2 = self.relu2(x)\n",
+    "        x = torch.cat([x0,x1,x2],1)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ResModelPool(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(ResModelPool, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 64, (9,in_size),stride=(1,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(64)\n",
+    "#        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(64, 64, 7, padding=3)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 5, padding=2)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(192, 6, 3, padding=1)\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x=x.unsqueeze(1)\n",
+    "        x = self.conv2d1(x)\n",
+    "        x=F.max_pool2d(x,kernel_size=(1,x.shape[-1])).squeeze(-1)        \n",
+    "        x0 = self.bn0(x)\n",
+    "#        x0 = self.relu0(x)\n",
+    "        x = self.conv1d1(x0)\n",
+    "        x = self.bn1(x)\n",
+    "        x1 = self.relu1(x)\n",
+    "        x = torch.cat([x0,x1],1)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x2 = self.relu2(x)\n",
+    "        x = torch.cat([x0,x1,x2],1)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ResDropModel(nn.Module):\n",
+    "    def __init__(self,in_size,dropout=0.2):\n",
+    "        super(ResDropModel, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 64, (9,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(64, 64, 7, padding=3)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 5, padding=2)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(192, 6, 3, padding=1)\n",
+    "        self.dropout=dropout\n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x = self.conv2d1(x.unsqueeze(1)).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x = F.dropout(x,self.dropout)\n",
+    "        x0 = self.relu0(x) \n",
+    "        x = self.conv1d1(x0)\n",
+    "#        x = self.bn1(x)\n",
+    "        x = F.dropout(x,self.dropout)\n",
+    "        x1 = self.relu1(x)\n",
+    "        x = torch.cat([x0,x1],1)\n",
+    "        x = self.conv1d2(x)\n",
+    "#        x = self.bn2(x)\n",
+    "        x = F.dropout(x,self.dropout)\n",
+    "        x2 = self.relu2(x)\n",
+    "        x = torch.cat([x0,x1,x2],1)\n",
+    "        x = F.dropout(x,self.dropout)\n",
+    "        out = self.conv1d3(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fn = partial(torch.clamp,min=0,max=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class GenReLU(nn.Module):\n",
+    "    def __init__(self,leak=0,add=0,clamp=None):\n",
+    "        super(GenReLU, self).__init__()\n",
+    "        self.leak,self.add=leak,add\n",
+    "        if isinstance(clamp,tuple):\n",
+    "            self.clamp = partial(torch.clamp,min=clamp[0],max=clamp[1])\n",
+    "        elif clamp:\n",
+    "            self.clamp = partial(torch.clamp,min=-clamp,max=clamp)\n",
+    "        else:\n",
+    "            self.clamp=None\n",
+    "            \n",
+    "        \n",
+    "    def forward(self,x):\n",
+    "        x = F.leaky_relu(x,self.leak)\n",
+    "        if self.add:\n",
+    "            x=x+self.add\n",
+    "        if self.clamp:\n",
+    "            x = self.clamp(x)\n",
+    "        return x\n",
+    "        \n",
+    "        "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "class ResModelIn(nn.Module):\n",
+    "    def __init__(self,in_size):\n",
+    "        super(ResModelIn, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 64, (9,in_size), padding=(4,0))\n",
+    "        self.bn0=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(64, 64, 3, padding=1)\n",
+    "        self.bn1=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu1=torch.nn.ReLU()\n",
+    "        self.conv1d2=torch.nn.Conv1d(128, 64, 3, padding=1)\n",
+    "        self.bn2=torch.nn.BatchNorm1d(64)\n",
+    "        self.relu2=torch.nn.ReLU()\n",
+    "        self.conv1d3=torch.nn.Conv1d(192, 6, 3, padding=1)\n",
+    "\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x = x.unsqueeze(1)\n",
+    "        x = self.conv2d1(x).squeeze(-1)\n",
+    "        x = self.bn0(x)\n",
+    "        x0 = self.relu0(x)\n",
+    "        x = self.conv1d1(x0)\n",
+    "        x = self.bn1(x)\n",
+    "        x1 = self.relu1(x)\n",
+    "        x = torch.cat([x0,x1],1)\n",
+    "        x = self.conv1d2(x)\n",
+    "        x = self.bn2(x)\n",
+    "        x2 = self.relu2(x)\n",
+    "        x = torch.cat([x0,x1,x2],1)\n",
+    "        x = self.conv1d3(x)\n",
+    "        out = x.transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "class ResModelInR(nn.Module):\n",
+    "    \"\"\"Densely connected 1D conv head that uses GenReLU activations in place of plain ReLU.\n",
+    "\n",
+    "    in_size: width of the first convolution kernel (size of the input's last axis).\n",
+    "    \"\"\"\n",
+    "    def __init__(self, in_size):\n",
+    "        super().__init__()\n",
+    "        # Name fragments of parameters that do_grad() must leave frozen.\n",
+    "        self.dont_do_grad = []\n",
+    "        # Stage 0: 2D convolution that spans the full input width.\n",
+    "        self.conv2d1 = nn.Conv2d(1, 64, (9, in_size), padding=(4, 0))\n",
+    "        self.bn0 = nn.BatchNorm1d(64)\n",
+    "        self.relu0 = GenReLU(leak=0.01, add=-0.3, clamp=(-1, 3))\n",
+    "        # Stage 1 reads the 64 channels of stage 0.\n",
+    "        self.conv1d1 = nn.Conv1d(64, 64, 3, padding=1)\n",
+    "        self.bn1 = nn.BatchNorm1d(64)\n",
+    "        self.relu1 = GenReLU(leak=0.01, add=-0.3, clamp=(-1, 3))\n",
+    "        # Stage 2 reads stages 0 and 1 concatenated (128 channels).\n",
+    "        self.conv1d2 = nn.Conv1d(128, 64, 3, padding=1)\n",
+    "        self.bn2 = nn.BatchNorm1d(64)\n",
+    "        self.relu2 = GenReLU(leak=0, add=0, clamp=(0, 2))\n",
+    "        # Output convolution reads all three stages (192 channels) and emits 6.\n",
+    "        self.conv1d3 = nn.Conv1d(192, 6, 3, padding=1)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        \"\"\"Add a channel axis, run the stages, and return the last conv with its final two axes swapped.\"\"\"\n",
+    "        hidden = self.conv2d1(x.unsqueeze(1)).squeeze(-1)\n",
+    "        stage0 = self.relu0(self.bn0(hidden))\n",
+    "        stage1 = self.relu1(self.bn1(self.conv1d1(stage0)))\n",
+    "        stage2 = self.relu2(self.bn2(self.conv1d2(torch.cat([stage0, stage1], 1))))\n",
+    "        head = self.conv1d3(torch.cat([stage0, stage1, stage2], 1))\n",
+    "        return head.transpose(-1, -2)\n",
+    "\n",
+    "    def no_grad(self):\n",
+    "        \"\"\"Turn off gradient tracking for every parameter.\"\"\"\n",
+    "        for weight in self.parameters():\n",
+    "            weight.requires_grad_(False)\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        \"\"\"Enable gradients except where the parameter name contains an entry of dont_do_grad.\"\"\"\n",
+    "        for name, weight in self.named_parameters():\n",
+    "            excluded = any(pattern in name for pattern in self.dont_do_grad)\n",
+    "            weight.requires_grad = not excluded\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class BaseModel(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super(BaseModel, self).__init__()\n",
+    "        self.dont_do_grad=[]\n",
+    "        self.conv2d1=torch.nn.Conv2d(1, 128, (5,2208), padding=(2,0))\n",
+    "        self.relu0=torch.nn.ReLU()\n",
+    "        self.conv1d1=torch.nn.Conv1d(128, 6, 1)\n",
+    "\n",
+    "        \n",
+    "        \n",
+    "    def forward(self, x):\n",
+    "        x = self.conv2d1(x.unsqueeze(1)).squeeze(-1)\n",
+    "        x = self.relu0(x)\n",
+    "        out = self.conv1d1(x).transpose(-1,-2)\n",
+    "        return out \n",
+    "                    \n",
+    "    def no_grad(self):\n",
+    "        for param in self.parameters():\n",
+    "            param.requires_grad=False\n",
+    "\n",
+    "    def do_grad(self):\n",
+    "        for n,p in self.named_parameters():\n",
+    "            p.requires_grad=  not any(nd in n for nd in self.dont_do_grad)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnet101 classifier_splits 0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fd275b627d0>"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06555730362437655, 'mloss': tensor(0.0640), 'val_loss': 0.07757792993691828, 'val_mloss_tot': tensor(0.0774)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06296035056021557, 'mloss': tensor(0.0615), 'val_loss': 0.07537376433794428, 'val_mloss_tot': tensor(0.0752)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06050681098069266, 'mloss': tensor(0.0578), 'val_loss': 0.07501300821869905, 'val_mloss_tot': tensor(0.0749)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06265970817210643, 'mloss': tensor(0.0630), 'val_loss': 0.07554149213751658, 'val_mloss_tot': tensor(0.0754)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.062267441321720234, 'mloss': tensor(0.0641), 'val_loss': 0.07397680654273062, 'val_mloss_tot': tensor(0.0739)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06265147281145228, 'mloss': tensor(0.0617), 'val_loss': 0.07435487403896789, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054534237629796284, 'mloss': tensor(0.0518), 'val_loss': 0.07417201365371753, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060638634779098304, 'mloss': tensor(0.0622), 'val_loss': 0.07360166200468767, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05774219324163882, 'mloss': tensor(0.0562), 'val_loss': 0.07313101735938653, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05876641325094529, 'mloss': tensor(0.0577), 'val_loss': 0.07348821095989219, 'val_mloss_tot': tensor(0.0734)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05925534488528531, 'mloss': tensor(0.0593), 'val_loss': 0.07329174085965225, 'val_mloss_tot': tensor(0.0732)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05893109856834247, 'mloss': tensor(0.0602), 'val_loss': 0.07320053409670338, 'val_mloss_tot': tensor(0.0731)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06035749195134699, 'mloss': tensor(0.0583), 'val_loss': 0.07416119949153528, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06102024913711339, 'mloss': tensor(0.0615), 'val_loss': 0.07507847905218638, 'val_mloss_tot': tensor(0.0750)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05763011426523271, 'mloss': tensor(0.0562), 'val_loss': 0.0745501829909959, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05899938234327748, 'mloss': tensor(0.0591), 'val_loss': 0.07331534775244669, 'val_mloss_tot': tensor(0.0732)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06051041598786941, 'mloss': tensor(0.0658), 'val_loss': 0.07324502129896815, 'val_mloss_tot': tensor(0.0731)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06032100399283676, 'mloss': tensor(0.0587), 'val_loss': 0.0731057038019168, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05567247802150111, 'mloss': tensor(0.0584), 'val_loss': 0.0740427240795855, 'val_mloss_tot': tensor(0.0739)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05695896455363199, 'mloss': tensor(0.0571), 'val_loss': 0.07355904062806165, 'val_mloss_tot': tensor(0.0734)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05467678778340032, 'mloss': tensor(0.0527), 'val_loss': 0.0733836851853492, 'val_mloss_tot': tensor(0.0733)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056968818392427734, 'mloss': tensor(0.0548), 'val_loss': 0.07343852231875429, 'val_mloss_tot': tensor(0.0733)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058187501158805395, 'mloss': tensor(0.0587), 'val_loss': 0.0734037181805232, 'val_mloss_tot': tensor(0.0733)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057780056129017016, 'mloss': tensor(0.0580), 'val_loss': 0.07378869249113822, 'val_mloss_tot': tensor(0.0737)}\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "(0.07295001492826406, {'val_loss': 0.07295001492826406, 'val_mloss_tot': tensor(0.0728)})\n",
+      "se_resnet101 classifier_splits 1\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fd275b627d0>"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061442185502940005, 'mloss': tensor(0.0569), 'val_loss': 0.07534000505579681, 'val_mloss_tot': tensor(0.0753)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06268717308426996, 'mloss': tensor(0.0616), 'val_loss': 0.0740244067160458, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0600582621217358, 'mloss': tensor(0.0608), 'val_loss': 0.07332578248616944, 'val_mloss_tot': tensor(0.0733)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057682981415727023, 'mloss': tensor(0.0554), 'val_loss': 0.07309666391144075, 'val_mloss_tot': tensor(0.0731)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057434267426733916, 'mloss': tensor(0.0589), 'val_loss': 0.07305602922172445, 'val_mloss_tot': tensor(0.0731)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057352926399604924, 'mloss': tensor(0.0554), 'val_loss': 0.07402778348953622, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0559268925074828, 'mloss': tensor(0.0563), 'val_loss': 0.07218927149206582, 'val_mloss_tot': tensor(0.0722)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05900301626537821, 'mloss': tensor(0.0601), 'val_loss': 0.07129640428508383, 'val_mloss_tot': tensor(0.0713)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05546906743859351, 'mloss': tensor(0.0563), 'val_loss': 0.0717661980437342, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05864951623984079, 'mloss': tensor(0.0593), 'val_loss': 0.07154358268897162, 'val_mloss_tot': tensor(0.0715)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055549188418039616, 'mloss': tensor(0.0560), 'val_loss': 0.07143378485584768, 'val_mloss_tot': tensor(0.0714)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05482709637848754, 'mloss': tensor(0.0567), 'val_loss': 0.07158134866778444, 'val_mloss_tot': tensor(0.0716)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060115974892814575, 'mloss': tensor(0.0643), 'val_loss': 0.07221458361075236, 'val_mloss_tot': tensor(0.0722)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05541222030704629, 'mloss': tensor(0.0526), 'val_loss': 0.07315580130032287, 'val_mloss_tot': tensor(0.0732)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05617269399288325, 'mloss': tensor(0.0556), 'val_loss': 0.07371268777585611, 'val_mloss_tot': tensor(0.0737)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056898097927381035, 'mloss': tensor(0.0556), 'val_loss': 0.07432223392037175, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06131808400595225, 'mloss': tensor(0.0657), 'val_loss': 0.07209026963881603, 'val_mloss_tot': tensor(0.0721)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05469984534978693, 'mloss': tensor(0.0542), 'val_loss': 0.07170687715136786, 'val_mloss_tot': tensor(0.0717)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05560388172674825, 'mloss': tensor(0.0561), 'val_loss': 0.07151761781992164, 'val_mloss_tot': tensor(0.0715)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055510112107333795, 'mloss': tensor(0.0558), 'val_loss': 0.07174188567370904, 'val_mloss_tot': tensor(0.0717)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055448861368395365, 'mloss': tensor(0.0573), 'val_loss': 0.07321729264903541, 'val_mloss_tot': tensor(0.0732)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05563457613127352, 'mloss': tensor(0.0553), 'val_loss': 0.0718658136930771, 'val_mloss_tot': tensor(0.0719)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05625772819692903, 'mloss': tensor(0.0579), 'val_loss': 0.0721532035607663, 'val_mloss_tot': tensor(0.0722)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05539568782743258, 'mloss': tensor(0.0547), 'val_loss': 0.07197411119506308, 'val_mloss_tot': tensor(0.0720)}\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "0.07129640428508383\n",
+      "se_resnet101 classifier_splits 2\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fd275b627d0>"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0667317779524751, 'mloss': tensor(0.0663), 'val_loss': 0.0726123980861249, 'val_mloss_tot': tensor(0.0725)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06425560013165965, 'mloss': tensor(0.0683), 'val_loss': 0.0712925546224518, 'val_mloss_tot': tensor(0.0712)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06299294143906597, 'mloss': tensor(0.0619), 'val_loss': 0.07263222281486163, 'val_mloss_tot': tensor(0.0726)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0627187252048082, 'mloss': tensor(0.0628), 'val_loss': 0.0701368740195491, 'val_mloss_tot': tensor(0.0701)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06340366102533591, 'mloss': tensor(0.0661), 'val_loss': 0.07123000746629582, 'val_mloss_tot': tensor(0.0712)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060680815633264, 'mloss': tensor(0.0597), 'val_loss': 0.06930766187222867, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058831156943332005, 'mloss': tensor(0.0557), 'val_loss': 0.06968927842431835, 'val_mloss_tot': tensor(0.0697)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0616733712603347, 'mloss': tensor(0.0632), 'val_loss': 0.06874918728938391, 'val_mloss_tot': tensor(0.0687)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0608313118593633, 'mloss': tensor(0.0615), 'val_loss': 0.06882584853822668, 'val_mloss_tot': tensor(0.0688)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061575952143859566, 'mloss': tensor(0.0625), 'val_loss': 0.06886688132628138, 'val_mloss_tot': tensor(0.0688)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05778591000079733, 'mloss': tensor(0.0565), 'val_loss': 0.06893600475529364, 'val_mloss_tot': tensor(0.0689)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05847443306806373, 'mloss': tensor(0.0563), 'val_loss': 0.06909122793554452, 'val_mloss_tot': tensor(0.0690)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06244718162474923, 'mloss': tensor(0.0618), 'val_loss': 0.06928921967579667, 'val_mloss_tot': tensor(0.0692)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06321549526102133, 'mloss': tensor(0.0653), 'val_loss': 0.06980274368095868, 'val_mloss_tot': tensor(0.0697)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06200456986987789, 'mloss': tensor(0.0648), 'val_loss': 0.06964234677212212, 'val_mloss_tot': tensor(0.0696)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057814249272109416, 'mloss': tensor(0.0575), 'val_loss': 0.06942267575773772, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061415286063921694, 'mloss': tensor(0.0639), 'val_loss': 0.06899209853417694, 'val_mloss_tot': tensor(0.0689)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05906533155065254, 'mloss': tensor(0.0586), 'val_loss': 0.06911359763396858, 'val_mloss_tot': tensor(0.0691)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05704393460994453, 'mloss': tensor(0.0567), 'val_loss': 0.06939367681576554, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05884749197825831, 'mloss': tensor(0.0580), 'val_loss': 0.06895865832129723, 'val_mloss_tot': tensor(0.0689)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05948965643986464, 'mloss': tensor(0.0598), 'val_loss': 0.06905965190701957, 'val_mloss_tot': tensor(0.0690)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057317303613035095, 'mloss': tensor(0.0563), 'val_loss': 0.06889244928277038, 'val_mloss_tot': tensor(0.0688)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0570795804969143, 'mloss': tensor(0.0567), 'val_loss': 0.06898040257008939, 'val_mloss_tot': tensor(0.0689)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0564644203909734, 'mloss': tensor(0.0556), 'val_loss': 0.0687852939269666, 'val_mloss_tot': tensor(0.0687)}\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "0.06874918728938391\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "# For each of the 3 cross-validation splits: load the features pickled under the\n",
+    "# 'se_resnet101' / 'classifier_splits' name, train a ResModelPool on them with\n",
+    "# model_train, and save the state dict of the model it returns.\n",
+    "for num_split in range(3):\n",
+    "    multi=3\n",
+    "    model_name,version = 'se_resnet101' , 'classifier_splits'\n",
+    "    print (model_name,version,num_split)\n",
+    "    # `with` guarantees the file handle is closed even if unpickling raises.\n",
+    "    # NOTE(review): pickle.load can execute arbitrary code - only load feature files written by this project.\n",
+    "    with open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta',num_split),'rb') as pickle_file:\n",
+    "        features=pickle.load(pickle_file)\n",
+    "    features.shape\n",
+    "\n",
+    "    # Regroup every 4 consecutive feature rows: (4*N, F) -> (N, 4, F).\n",
+    "    # presumably the 4 rows are the TTA variants of one image - TODO confirm against the feature extraction code\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    # Unique SeriesI values of the rows whose PID belongs to partition 0 (train) / partition 1 (validate) of this split.\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate = train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    # Seed numpy and torch (CPU + CUDA) with a split-dependent value so every split is reproducible on its own.\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1  # multiplier applied to the base learning rate of 1e-3 below\n",
+    "    # Per-class loss weights; the last class counts double.\n",
+    "    # presumably the order follows hemorrhage_types - TODO confirm\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    # The training set receives the 3-D (N, 4, F) features together with `multi`.\n",
+    "    # NOTE(review): how FullHeadDataset uses `multi` is defined in helper, not visible here.\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)\n",
+    "    # The validation set receives all variants of an image concatenated along the last axis: (N, 4*F).\n",
+    "    # The variant count is read from features.shape[1] so it always matches the reshape above.\n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(features.shape[1])],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)\n",
+    "\n",
+    "    # The model is sized with F, the last dimension of the 3-D feature tensor.\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    # From here on `version` names the full-head model (it was needed unmodified above to locate the features file).\n",
+    "    version=version+'_fullhead_resmodel_pool2_{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "    # Total optimizer steps = num_epochs * ceil(len(train_dataset) / batch_size).\n",
+    "    num_train_optimization_steps = num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    # One checkpoint per split, named from model_name, the extended version string and the split index.\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "se_resnext101_32x4d classifier_splits 0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 39,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 39,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 39,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06104212981418981, 'mloss': tensor(0.0597), 'val_loss': 0.07772794400520733, 'val_mloss_tot': tensor(0.0776)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05878340472343995, 'mloss': tensor(0.0574), 'val_loss': 0.07512017121777086, 'val_mloss_tot': tensor(0.0750)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05655648395510176, 'mloss': tensor(0.0542), 'val_loss': 0.07546429423124443, 'val_mloss_tot': tensor(0.0753)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059097223141805454, 'mloss': tensor(0.0593), 'val_loss': 0.07441392245592339, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05837889894047478, 'mloss': tensor(0.0594), 'val_loss': 0.07449079317664682, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05928927639013696, 'mloss': tensor(0.0591), 'val_loss': 0.07495181764742094, 'val_mloss_tot': tensor(0.0748)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.050645298009680255, 'mloss': tensor(0.0473), 'val_loss': 0.07507487035842768, 'val_mloss_tot': tensor(0.0749)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05648079982672484, 'mloss': tensor(0.0573), 'val_loss': 0.07394013861469521, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054698379029048796, 'mloss': tensor(0.0532), 'val_loss': 0.07337973213965633, 'val_mloss_tot': tensor(0.0733)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05531915319713568, 'mloss': tensor(0.0544), 'val_loss': 0.07383127088814496, 'val_mloss_tot': tensor(0.0737)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055614093287722856, 'mloss': tensor(0.0556), 'val_loss': 0.07364389244834044, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05494674375668656, 'mloss': tensor(0.0550), 'val_loss': 0.07360090959812583, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05731814445107101, 'mloss': tensor(0.0557), 'val_loss': 0.07408064657941534, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05787915691263612, 'mloss': tensor(0.0583), 'val_loss': 0.07470676409805496, 'val_mloss_tot': tensor(0.0746)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05448834917186254, 'mloss': tensor(0.0534), 'val_loss': 0.07410673999108301, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05592269570923443, 'mloss': tensor(0.0564), 'val_loss': 0.07380271401357254, 'val_mloss_tot': tensor(0.0737)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0572446217923124, 'mloss': tensor(0.0624), 'val_loss': 0.07315992065498908, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05687102900158551, 'mloss': tensor(0.0556), 'val_loss': 0.07355240966867783, 'val_mloss_tot': tensor(0.0734)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.052319963189129834, 'mloss': tensor(0.0547), 'val_loss': 0.07404590190131345, 'val_mloss_tot': tensor(0.0739)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.053723292559217185, 'mloss': tensor(0.0539), 'val_loss': 0.07411296213426503, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.051564192577278234, 'mloss': tensor(0.0502), 'val_loss': 0.07405690763494358, 'val_mloss_tot': tensor(0.0739)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.053935549561808335, 'mloss': tensor(0.0519), 'val_loss': 0.07374668115259062, 'val_mloss_tot': tensor(0.0736)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054132682967620145, 'mloss': tensor(0.0532), 'val_loss': 0.07362750274005342, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.053648569763990524, 'mloss': tensor(0.0537), 'val_loss': 0.0743206257259776, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(0.07315366809013221, {'val_loss': 0.07315366809013221, 'val_mloss_tot': tensor(0.0730)})\n",
+      "se_resnext101_32x4d classifier_splits 1\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 39,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 39,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 39,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06103116649936224, 'mloss': tensor(0.0566), 'val_loss': 0.07621094500645995, 'val_mloss_tot': tensor(0.0762)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06141884588324707, 'mloss': tensor(0.0599), 'val_loss': 0.07421453151874607, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05937302091545622, 'mloss': tensor(0.0599), 'val_loss': 0.07298415893938665, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05653851122561366, 'mloss': tensor(0.0539), 'val_loss': 0.07359400842361516, 'val_mloss_tot': tensor(0.0736)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0563969561115315, 'mloss': tensor(0.0568), 'val_loss': 0.07279874784417632, 'val_mloss_tot': tensor(0.0728)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05658438120329072, 'mloss': tensor(0.0548), 'val_loss': 0.0732334375347397, 'val_mloss_tot': tensor(0.0732)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05520078283242818, 'mloss': tensor(0.0558), 'val_loss': 0.0726179627928792, 'val_mloss_tot': tensor(0.0726)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0575513950087407, 'mloss': tensor(0.0586), 'val_loss': 0.07155612922364438, 'val_mloss_tot': tensor(0.0716)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05446073473628828, 'mloss': tensor(0.0553), 'val_loss': 0.07224392154163159, 'val_mloss_tot': tensor(0.0723)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057818598298651984, 'mloss': tensor(0.0585), 'val_loss': 0.0717785329035506, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05476841177750922, 'mloss': tensor(0.0553), 'val_loss': 0.0718532491384483, 'val_mloss_tot': tensor(0.0719)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05425994278145948, 'mloss': tensor(0.0561), 'val_loss': 0.07197370230470125, 'val_mloss_tot': tensor(0.0720)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05889216983362692, 'mloss': tensor(0.0630), 'val_loss': 0.07261958274218004, 'val_mloss_tot': tensor(0.0726)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05406497755344249, 'mloss': tensor(0.0510), 'val_loss': 0.07325117803383165, 'val_mloss_tot': tensor(0.0733)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05522595671697223, 'mloss': tensor(0.0548), 'val_loss': 0.07367304334780428, 'val_mloss_tot': tensor(0.0737)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05617977560425682, 'mloss': tensor(0.0551), 'val_loss': 0.0741670508403331, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061041019817783385, 'mloss': tensor(0.0660), 'val_loss': 0.07202263277296612, 'val_mloss_tot': tensor(0.0720)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05373041651438471, 'mloss': tensor(0.0532), 'val_loss': 0.07225684941541857, 'val_mloss_tot': tensor(0.0723)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0542387858677823, 'mloss': tensor(0.0538), 'val_loss': 0.07238784394476835, 'val_mloss_tot': tensor(0.0724)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05465374149548144, 'mloss': tensor(0.0549), 'val_loss': 0.07244390736038729, 'val_mloss_tot': tensor(0.0725)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05478347135008695, 'mloss': tensor(0.0572), 'val_loss': 0.07341553525794751, 'val_mloss_tot': tensor(0.0734)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05454276966923935, 'mloss': tensor(0.0545), 'val_loss': 0.07211499356482996, 'val_mloss_tot': tensor(0.0721)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05539196863152449, 'mloss': tensor(0.0565), 'val_loss': 0.07266912088138847, 'val_mloss_tot': tensor(0.0727)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05480816241905843, 'mloss': tensor(0.0541), 'val_loss': 0.07231134713127664, 'val_mloss_tot': tensor(0.0723)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(0.07124041325140108, {'val_loss': 0.07124041325140108, 'val_mloss_tot': tensor(0.0712)})\n",
+      "se_resnext101_32x4d classifier_splits 2\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 256])"
+      ]
+     },
+     "execution_count": 39,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 256])"
+      ]
+     },
+     "execution_count": 39,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 39,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06686010671958453, 'mloss': tensor(0.0671), 'val_loss': 0.07256089428926042, 'val_mloss_tot': tensor(0.0725)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06197450007287932, 'mloss': tensor(0.0649), 'val_loss': 0.07176381888629621, 'val_mloss_tot': tensor(0.0717)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06055403412935137, 'mloss': tensor(0.0595), 'val_loss': 0.07173318491120056, 'val_mloss_tot': tensor(0.0717)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060360971420687644, 'mloss': tensor(0.0610), 'val_loss': 0.07037732802019478, 'val_mloss_tot': tensor(0.0703)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06102236676318482, 'mloss': tensor(0.0634), 'val_loss': 0.0702071183136311, 'val_mloss_tot': tensor(0.0701)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057586846994762284, 'mloss': tensor(0.0570), 'val_loss': 0.06973422538911181, 'val_mloss_tot': tensor(0.0697)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05654831233703238, 'mloss': tensor(0.0543), 'val_loss': 0.0701734111254365, 'val_mloss_tot': tensor(0.0701)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058726070560402376, 'mloss': tensor(0.0602), 'val_loss': 0.0694997781988582, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05784556189496553, 'mloss': tensor(0.0580), 'val_loss': 0.0694955213825801, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05872945495157308, 'mloss': tensor(0.0600), 'val_loss': 0.06933456590500077, 'val_mloss_tot': tensor(0.0693)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0548658382573716, 'mloss': tensor(0.0536), 'val_loss': 0.06967747714278733, 'val_mloss_tot': tensor(0.0696)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05587831730391975, 'mloss': tensor(0.0541), 'val_loss': 0.06954347830623608, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059391494039898726, 'mloss': tensor(0.0586), 'val_loss': 0.06954381119212126, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06010475669847284, 'mloss': tensor(0.0620), 'val_loss': 0.06953739299865559, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05960235242072818, 'mloss': tensor(0.0628), 'val_loss': 0.06947992608786144, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055124038187571354, 'mloss': tensor(0.0543), 'val_loss': 0.06947753373349976, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0586589330860823, 'mloss': tensor(0.0609), 'val_loss': 0.0694346692472733, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05600552231078165, 'mloss': tensor(0.0549), 'val_loss': 0.06912566470945704, 'val_mloss_tot': tensor(0.0691)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05422179659744494, 'mloss': tensor(0.0538), 'val_loss': 0.06992734589251538, 'val_mloss_tot': tensor(0.0699)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05635119022882728, 'mloss': tensor(0.0560), 'val_loss': 0.06953046211440649, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057621935416480106, 'mloss': tensor(0.0589), 'val_loss': 0.06944633296374324, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05468315922469657, 'mloss': tensor(0.0539), 'val_loss': 0.0694805622889931, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0546185314255567, 'mloss': tensor(0.0546), 'val_loss': 0.06947591298906675, 'val_mloss_tot': tensor(0.0694)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05429768024717186, 'mloss': tensor(0.0536), 'val_loss': 0.06954654063937699, 'val_mloss_tot': tensor(0.0695)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(0.06881347418363605, {'val_loss': 0.06881347418363605, 'val_mloss_tot': tensor(0.0688)})\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "# For each of the three splits: load the pre-extracted features, build the\n",
+    "# train/validation FullHeadDataset objects, train a ResModelPool and save the\n",
+    "# state dict of the model returned by model_train.\n",
+    "# NOTE(review): the nbagg backend is selected above, but no figure is created\n",
+    "# while the loss_graph lines below stay commented out and graph=None is passed.\n",
+    "for num_split in range(3):\n",
+    "    multi=3\n",
+    "    # version is reset on every iteration because a suffix is appended to it below.\n",
+    "    model_name,version = 'se_resnext101_32x4d' , 'classifier_splits'\n",
+    "    print (model_name,version,num_split)\n",
+    "    # The context manager closes the file even if unpickling raises.\n",
+    "    # NOTE(review): pickle.load can execute arbitrary code - only load trusted files.\n",
+    "    with open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta',num_split),'rb') as pickle_file:\n",
+    "        features=pickle.load(pickle_file)\n",
+    "    features.shape\n",
+    "\n",
+    "    # Regroup the rows into blocks of 4 consecutive feature vectors\n",
+    "    # (presumably the 4 TTA variants of each image - TODO confirm).\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    # Series ids whose PID belongs to this split's first / second index group\n",
+    "    # (presumably train / validation patients - verify against how splits is built).\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    # Seed numpy and torch (CPU and CUDA) with a split-dependent seed.\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    # Per-class loss weights; the last of the six entries counts double.\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)\n",
+    "    # Validation receives the 4 feature slices concatenated along the last axis.\n",
+    "    # NOTE(review): the model is built with the per-slice width, so the dataset/model\n",
+    "    # presumably split this back up - confirm in helper.\n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)\n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    # The suffixed version string is used in the saved model file name.\n",
+    "    version=version+'_fullhead_resmodel_pool2_{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "    # Total optimizer steps = epochs * ceil(len(train_dataset) / batch_size).\n",
+    "    num_train_optimization_steps = num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 40,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Densenet161_3 classifier_splits 0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 552])"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 552])"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "35c2607d583c4f06b31c2099709b55ce",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "21243de8e0d24a22af25df9b1ecb9199",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4f84b2f383864fc794a0a93ef0b5241a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06883620352841167, 'mloss': tensor(0.0687), 'val_loss': 0.0794665542422, 'val_mloss_tot': tensor(0.0793)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b2165078af394902966b067237b8ff55",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0b64635156df4694b352aba2126682a1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06416736874526195, 'mloss': tensor(0.0649), 'val_loss': 0.07710900913529428, 'val_mloss_tot': tensor(0.0769)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "49abd4b2ffc94035bce0cbec633a7d0c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bbfdc3024eb94e1981c30aab94f70623",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06082077204627799, 'mloss': tensor(0.0592), 'val_loss': 0.07777202231671863, 'val_mloss_tot': tensor(0.0776)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f8eafb9d666340da891c9a1679330d4e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8e25f09498ec4bd984fcee5427c4432d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060287136157171836, 'mloss': tensor(0.0600), 'val_loss': 0.07642053298335615, 'val_mloss_tot': tensor(0.0763)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c5925e1f7b6141b58df3204789012320",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5d452f90b0ee49ce81a4828baa28127f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05850181421551266, 'mloss': tensor(0.0573), 'val_loss': 0.07645996998573465, 'val_mloss_tot': tensor(0.0763)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1e507cb0f2494c738564fac5c1bd3913",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6a1380c78cfe45a8ac2be46da4ebde9c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06009318433020893, 'mloss': tensor(0.0621), 'val_loss': 0.07619008562513833, 'val_mloss_tot': tensor(0.0760)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d9f9e5d5b43040bfba87f6ee9f55672e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c2a8064614884ca480a4ecab59e79a48",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05502318868498425, 'mloss': tensor(0.0522), 'val_loss': 0.0770497980591809, 'val_mloss_tot': tensor(0.0769)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "26067cf59883495e9d3544d672afb95d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9128be24fe7248008bc48a7a5aa3395f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05467229249862472, 'mloss': tensor(0.0558), 'val_loss': 0.07456049822004777, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2a6aa786f2fd49cabbbceb5ea3d14521",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6a692fcd68b64c37bd23bce6215ba977",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05836357350776538, 'mloss': tensor(0.0634), 'val_loss': 0.07459651590256121, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3b70c08237b54ee6a63f74df10df6fbd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7eee383eb82442639e14e94f9feebe2e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05315592215511916, 'mloss': tensor(0.0497), 'val_loss': 0.07476671059212184, 'val_mloss_tot': tensor(0.0746)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e1a001f73dbf4085bcedcaee0e85c747",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "327b85dd25d84afba2485d0202b74db1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05203685330720608, 'mloss': tensor(0.0496), 'val_loss': 0.07597205400113659, 'val_mloss_tot': tensor(0.0758)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "047be71545974f96ac77e53a66154169",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cd72d55ed99947aa8f9f960c9bbaf50d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.051829502360524145, 'mloss': tensor(0.0491), 'val_loss': 0.07442193356539889, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4bc6e46a68954c30981f8e5f129c0cc5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "95876f26d3f14459b6ec96dfa271e250",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058099859528654534, 'mloss': tensor(0.0574), 'val_loss': 0.07512590011210167, 'val_mloss_tot': tensor(0.0750)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5dd629b9a38a4c2db9a4408013745ed6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "db0e0ab6e96e4625abbdae909ff44c1f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0592473319786098, 'mloss': tensor(0.0653), 'val_loss': 0.07459551666225558, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d426aa710ef54cb6be7f7405aeaa28c3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "813b68f41da64bedb99b3a775c817eb9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05968444386825604, 'mloss': tensor(0.0610), 'val_loss': 0.07698274446362309, 'val_mloss_tot': tensor(0.0768)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "59eacc15aa1e49e7ad5ebcd3c7800729",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "48ad11a1c7f1434097b2b54789b56944",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05863282249038244, 'mloss': tensor(0.0628), 'val_loss': 0.07422352496997954, 'val_mloss_tot': tensor(0.0741)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "eb6a03b8d6a843c6ab7d4c55ff573fe9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b71336a3f30340d7b5fea2ebd976a40f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05542632519146895, 'mloss': tensor(0.0567), 'val_loss': 0.07491980900694364, 'val_mloss_tot': tensor(0.0748)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e1f20cd31ede4e8d989816e4eeed9ba9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5267e006936c4ef6b456dd4753d19233",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055169499381633964, 'mloss': tensor(0.0549), 'val_loss': 0.07648079449957779, 'val_mloss_tot': tensor(0.0763)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6dbf9278d2ea4f1499a52a368629d351",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "485a65b517324e8ba508e982f59ff392",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05669092520281455, 'mloss': tensor(0.0572), 'val_loss': 0.07377110467191438, 'val_mloss_tot': tensor(0.0736)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ed2a7f4bdf6c49c9bc57886c88a07f34",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7017f89378d64e49bd923adc9d6b08a0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055804943687637124, 'mloss': tensor(0.0546), 'val_loss': 0.07434750396845487, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4aea5cd060454cc4957e93d47c3a5dcb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5387a858ff7e4f619a4bf5bdac32424d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054883343126150036, 'mloss': tensor(0.0541), 'val_loss': 0.08067470978431661, 'val_mloss_tot': tensor(0.0805)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "da207c13b4da44feba62fd7369b04882",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f03761ec053a40948c2d1564058c1973",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.052586791399687895, 'mloss': tensor(0.0526), 'val_loss': 0.07430791517540702, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cc2e7e96a19d4c11b3e8a6bf26d31409",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6932b6e1b7544bc9ab008ae1a97ce322",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05461565242117025, 'mloss': tensor(0.0555), 'val_loss': 0.07453368853415211, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3c1170bca9ba47efb274c8a34f7a2a86",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "871d983b2e504dcb88d4ea8444e91970",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.052522394821599636, 'mloss': tensor(0.0508), 'val_loss': 0.07458092422646659, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6f0d3e75eecd497faa960a8f187c8a83",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(0.07357416845619531, {'val_loss': 0.07357416845619531, 'val_mloss_tot': tensor(0.0734)})\n",
+      "Densenet161_3 classifier_splits 1\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 552])"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 552])"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b6451bc365f044bdb03f8d95f1441ce9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "605202ef89b74cdcbbeb7fb8adae8ea0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "84ab9c38c8424b768c18ac475b570ea6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06597305889025278, 'mloss': tensor(0.0698), 'val_loss': 0.07734269903200429, 'val_mloss_tot': tensor(0.0773)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "73389d2a93e04c3f8aaa6705ff2eb093",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1e26164af9f34e9fa2c674ec055f01f4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05800519974185489, 'mloss': tensor(0.0546), 'val_loss': 0.07655769771919017, 'val_mloss_tot': tensor(0.0766)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9bed9c3e3d924716a21309934f41ca51",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cc3e863c4df3430793636272f531f384",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05817255458790379, 'mloss': tensor(0.0602), 'val_loss': 0.07479523917342104, 'val_mloss_tot': tensor(0.0748)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f0cfcc83eccd4facb6baf101c124af70",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3e57864422f44743b079a6d422f998ad",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05827131210172841, 'mloss': tensor(0.0572), 'val_loss': 0.07374378701035933, 'val_mloss_tot': tensor(0.0737)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "36d31308481a446b8ad7ba3a929c35af",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5034b2e4e990457bb83f66a13a5ec3b0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05880998499775261, 'mloss': tensor(0.0561), 'val_loss': 0.07450412283878682, 'val_mloss_tot': tensor(0.0745)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7b3979ada6f34274a7f254c2ae336354",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "eb8aa55821b24f83bfec4d56f4d9487c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05557613094818353, 'mloss': tensor(0.0536), 'val_loss': 0.0730999577449771, 'val_mloss_tot': tensor(0.0731)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4b1231894bb640fc97734010304de76f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fa43963a6aa54cd0bdc961f9ed811b2c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05673775954337136, 'mloss': tensor(0.0602), 'val_loss': 0.07232062728257804, 'val_mloss_tot': tensor(0.0723)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7498bc4249bb4321a77f7ad59669dd9f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b859fc7e2ef24f22ba9610b04805668c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05483517504432824, 'mloss': tensor(0.0540), 'val_loss': 0.07306869049874566, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "556ddd88b3d24c5b94ff0a7e0235b74e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "dd773ee14c824ace92cb5f35f0fe562f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05568622114955888, 'mloss': tensor(0.0534), 'val_loss': 0.07420246506950295, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b2ac5834caa24663962fdffc8aa17b99",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6b39ddde087646b0b6b40cee3f8d49ee",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05650527030797528, 'mloss': tensor(0.0573), 'val_loss': 0.07335103402232251, 'val_mloss_tot': tensor(0.0733)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bebecb877ac84366bd35be9ef0eb9e31",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4a8c260d683b44ca9974a5e90505c6ff",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05611171696093718, 'mloss': tensor(0.0573), 'val_loss': 0.07339669959506065, 'val_mloss_tot': tensor(0.0734)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c7aca6cee8aa45a1a47045cf40bb8b4d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a473f2664dfc451782901112cc248628",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0527351923110712, 'mloss': tensor(0.0519), 'val_loss': 0.07278166824906337, 'val_mloss_tot': tensor(0.0728)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3f441e42d64b4d409efe20621ac6b9ff",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5fa439d893ba4e72a34b8b378edbe2f4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05858057229745605, 'mloss': tensor(0.0620), 'val_loss': 0.07353354573885842, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6fad8f89f8ad4951911e6298f75959d4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a9073cd5f1634e20851ebad1e01cc933",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055208331343125354, 'mloss': tensor(0.0544), 'val_loss': 0.07463426654590521, 'val_mloss_tot': tensor(0.0746)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c1accfe3ee364d75873c00dab2373b46",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6eb7078899e34713b8fe9e756e87b926",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056743860937820176, 'mloss': tensor(0.0557), 'val_loss': 0.07429377118170988, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a4a986d39ea84bba91d011fae6c22e1c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "36a75847182242ce8b13c23d1a9ec7e9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05462875622006605, 'mloss': tensor(0.0537), 'val_loss': 0.07426565117429851, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9b2eb88687344c95a50ada4301b99ee9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e03cb28b02db4fcc87f5e62b531f6ef3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056647627921595514, 'mloss': tensor(0.0588), 'val_loss': 0.0732670552540207, 'val_mloss_tot': tensor(0.0732)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "dfd892cc22ff4a37a12f3e5ed63c77ae",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "137e159b4ec347b890cc986a7221b124",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05815912819217005, 'mloss': tensor(0.0614), 'val_loss': 0.07249528116251273, 'val_mloss_tot': tensor(0.0725)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "40b4d8d273034bdda4a34fcc0d6be16a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b7ef2cc4579d44e2840ddaec3c7b669f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05462893982440514, 'mloss': tensor(0.0535), 'val_loss': 0.0734268520140975, 'val_mloss_tot': tensor(0.0734)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7f737177fc1f42d6bd78b3e77705cfbb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ea75b4909f074657bd95cd42f3f19352",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055117786650665655, 'mloss': tensor(0.0551), 'val_loss': 0.07301591209635684, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b0482c62bdaa432698b940931680ae97",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cea6742473cf4793a6ebbc41c3446c86",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05351639222484179, 'mloss': tensor(0.0561), 'val_loss': 0.07357417480202347, 'val_mloss_tot': tensor(0.0736)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2d7ea61b50bc4bdbb421dce50d76190e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e3ce1b955f2f48578009174863e13571",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.053038085350086105, 'mloss': tensor(0.0568), 'val_loss': 0.07385229097633828, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7eb1834b138f4bd3905b6363c4f7222d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8e97013b40ae41cfa7ca0a098b3e164b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05152956944253799, 'mloss': tensor(0.0481), 'val_loss': 0.0734710993744978, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "badf68800076400e9f7138c75a03ede8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "682e739be0f74fb0a1a2049e86019b83",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05369598445909617, 'mloss': tensor(0.0550), 'val_loss': 0.07353208534555827, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "84fba94c276448a7bfa6c523d1cd2fb4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.07232062728257804\n",
+      "Densenet161_3 classifier_splits 2\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 552])"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 552])"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d149b7deb4c74ff68b8354ee6ceb93dd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "41003621879c4845a601706cd9699e5d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3cdab4e1dd40474ab5d30549af416e6e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06396980442396705, 'mloss': tensor(0.0617), 'val_loss': 0.07483250640672195, 'val_mloss_tot': tensor(0.0747)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d611f8adcc23430486145eeed863ce29",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "949368907b4c437fa6b85f4687a06a99",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06150023859619457, 'mloss': tensor(0.0568), 'val_loss': 0.07280183078152279, 'val_mloss_tot': tensor(0.0727)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b84d35cf54464c24acaca997d2e9ddfb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "27411fabcc7b439ab628e0d8416fb771",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06212534932548155, 'mloss': tensor(0.0655), 'val_loss': 0.07222140052131024, 'val_mloss_tot': tensor(0.0721)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a5aa76f7c4304c35a359afef9a170d69",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "23ddb452d5484215847bf7c2bb73106f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05893728310311425, 'mloss': tensor(0.0604), 'val_loss': 0.07128777329722795, 'val_mloss_tot': tensor(0.0712)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "083e4932b2404148b6831268a41b2a88",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "49f2cbc05aff455e88972f547c71da8b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05935209288500168, 'mloss': tensor(0.0585), 'val_loss': 0.07321278059974386, 'val_mloss_tot': tensor(0.0731)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6ed3670f9731424e8b583e98147fca1f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bad314f5d3e943dfbe61c01104a8ed21",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05713498165682518, 'mloss': tensor(0.0561), 'val_loss': 0.07160992384085367, 'val_mloss_tot': tensor(0.0715)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "863a3750d21d427f97177d8863eb4f65",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c55c6c5437114b26bc6cee59e2d147a7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05786703755169134, 'mloss': tensor(0.0576), 'val_loss': 0.07172710372116692, 'val_mloss_tot': tensor(0.0716)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d1e242bd318b49c4badf4bf2dcb52c47",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f08cc7694cde466cb42102bcc815ade8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05497932628919575, 'mloss': tensor(0.0559), 'val_loss': 0.07069015124436523, 'val_mloss_tot': tensor(0.0706)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1795b066f61449d7aa528868bd45395d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9c1ee97f438d4e6ea9249773a8455479",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05435640478131613, 'mloss': tensor(0.0516), 'val_loss': 0.07069991626249159, 'val_mloss_tot': tensor(0.0706)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "31b40edd1aa5435cb5b819acf081b48b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ba9a5f256ef44bbbaf8f137c701002b0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05272874757137747, 'mloss': tensor(0.0500), 'val_loss': 0.07074785957874483, 'val_mloss_tot': tensor(0.0707)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a451e46ef4bd481abc66622984622436",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f751a7c9717841d9a26eeefbd6748091",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055229884896864395, 'mloss': tensor(0.0543), 'val_loss': 0.07041004422841034, 'val_mloss_tot': tensor(0.0703)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "22d94774aed44b53b137c75375dd2a8c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bdc21fe733a54f68a9aa93f46dd24fec",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05809856235926404, 'mloss': tensor(0.0580), 'val_loss': 0.07050350384575686, 'val_mloss_tot': tensor(0.0704)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "49d352c6cda744e79959c4b8736cfc74",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0f1d9e9c779340be8bd86057b6b6fdd0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060566935267364605, 'mloss': tensor(0.0614), 'val_loss': 0.07163872729684112, 'val_mloss_tot': tensor(0.0716)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "926e56d0cb684ce0a423bbed0c30c89d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b159b6532ff14af6acdbed1b13ea4a54",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059582604315986304, 'mloss': tensor(0.0602), 'val_loss': 0.07057821930711859, 'val_mloss_tot': tensor(0.0705)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ae69dfae8677405f9257bdb1195285ba",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3f9b66b18a4e42bdb5e2983d7b16f78f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057584260244539553, 'mloss': tensor(0.0572), 'val_loss': 0.0712421777580172, 'val_mloss_tot': tensor(0.0712)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "58dca83236b54a62879455e68bbe9324",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c0b02ad0dcd7406d9558f3ec95be1a41",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055009298370109386, 'mloss': tensor(0.0525), 'val_loss': 0.07084355141572984, 'val_mloss_tot': tensor(0.0708)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f0b709e2b3f34934b50b625bbadd7a20",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a63f7d8c4eb84ce5aeb0d26bb14e427f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0575199417837798, 'mloss': tensor(0.0583), 'val_loss': 0.07171990342548328, 'val_mloss_tot': tensor(0.0716)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8669c51326b74f0aa9dc267972dd17e1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e7c39607f18641e487eb2d5532c916ca",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05522955483853907, 'mloss': tensor(0.0552), 'val_loss': 0.07078133166304289, 'val_mloss_tot': tensor(0.0707)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ed484c3d978f4bd1a56614fdebf71b64",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9432ca212118419c848d0ef6a9b56f51",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05328754366875874, 'mloss': tensor(0.0510), 'val_loss': 0.0706175214294436, 'val_mloss_tot': tensor(0.0705)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f6d4738fb6a343b9b1790c56c587e6cb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "afe794eabdf3406788f659a205c44f4b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05372766405214235, 'mloss': tensor(0.0526), 'val_loss': 0.0706223899204963, 'val_mloss_tot': tensor(0.0705)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d64caa6f9bf64add96cbe740a394db78",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d1a74357406c41bbb6ed01050ccec9ff",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05502442548708901, 'mloss': tensor(0.0557), 'val_loss': 0.07082850365839861, 'val_mloss_tot': tensor(0.0708)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cba9c7a09a034f129b8ae382eeedd77e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "706f4d7628e44b978fb485e9d0cd3d3d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05528873147857635, 'mloss': tensor(0.0540), 'val_loss': 0.07059003036983054, 'val_mloss_tot': tensor(0.0705)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "383a5bba71ca4edb8e6d17e8be4ffb22",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bee42b5965ab4b2eb4c8b17b199f1481",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05648660532989741, 'mloss': tensor(0.0555), 'val_loss': 0.07073680185506423, 'val_mloss_tot': tensor(0.0707)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "30a6a505fc9e4eb4a2e2df55923b99f2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0d401a289b2d451f9abe1c69ff1fc02d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05495973177602834, 'mloss': tensor(0.0534), 'val_loss': 0.07096810642116029, 'val_mloss_tot': tensor(0.0709)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a63ee81f24e1443b8aa02aaaa586043c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.07041004422841034\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "# Train one ResModelPool full-head model per cross-validation split on the cached\n",
+    "# Densenet161_3 features, then save the best weights of each split to models_dir.\n",
+    "for num_split in range(3):\n",
+    "    multi=3\n",
+    "    # version is re-initialised every iteration because it gets a suffix appended below.\n",
+    "    model_name,version = 'Densenet161_3' , 'classifier_splits'\n",
+    "    print (model_name,version,num_split)\n",
+    "    # Context manager closes the file even if unpickling raises.\n",
+    "    # NOTE(review): pickle.load can execute arbitrary code - only load trusted feature files.\n",
+    "    with open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta2',num_split),'rb') as pickle_file:\n",
+    "        features=pickle.load(pickle_file)\n",
+    "\n",
+    "    # Regroup every 4 consecutive rows into one entry: (rows//4, 4, feature_dim).\n",
+    "    # NOTE(review): presumably the 4 rows are TTA variants of the same image - confirm.\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    # Series ids whose patient (PID) falls in the train / validation part of this split.\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    # Seed every RNG with a split-dependent value so each split is reproducible on its own.\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    # Per-class loss weights: the last class counts double.\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)                \n",
+    "    # Validation gets the 4 slices concatenated along the last axis (4*feature_dim wide).\n",
+    "    # NOTE(review): presumably the model/dataset handles this wider input - verify in helper.\n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)                \n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "    # Total optimizer steps = epochs * ceil(len(train_dataset) / batch_size).\n",
+    "    num_train_optimization_steps = num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    # The metric uses the same class weights as the loss, but built without device=.\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    # Persist the best model of this split under the suffixed version name.\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Densenet169_3 classifier_splits 0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 208])"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 208])"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9142f460f35e46fe96a0e467ab138703",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9650eee1b33641cfa7e5f3883f6ef14b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cebb196980ff42b3b627eabab0c64d6c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0645742393287307, 'mloss': tensor(0.0658), 'val_loss': 0.07755747350653074, 'val_mloss_tot': tensor(0.0774)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "00f3ed462717426aa48cbba6309e896d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e46a6b3798574b21876b9407ccda8efe",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06093022012470684, 'mloss': tensor(0.0576), 'val_loss': 0.07730007135960747, 'val_mloss_tot': tensor(0.0772)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2769c9c49ff643a1adf9b1bc14216681",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c597620676e14313be97a9d3a4738e1e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06096925592521475, 'mloss': tensor(0.0617), 'val_loss': 0.07525878793816886, 'val_mloss_tot': tensor(0.0751)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "198547160d4a411987c87e9c29d452f9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c27b9177ffb14a2ca14833716b90db5a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06232639657852247, 'mloss': tensor(0.0620), 'val_loss': 0.07516170733291748, 'val_mloss_tot': tensor(0.0750)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7b7f4a1bdaee4132ba8f396c81483672",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8ea64f076ba04ebd9d0b7a8f9ee007af",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0543787243659503, 'mloss': tensor(0.0495), 'val_loss': 0.07635244332813644, 'val_mloss_tot': tensor(0.0762)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "66948c979f0b4dc5a9f87a47345167d2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7ad2f29472a144c3a2a137c9ff93c6e3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05921730220720486, 'mloss': tensor(0.0595), 'val_loss': 0.07481655244499885, 'val_mloss_tot': tensor(0.0747)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "608d4ef3d0064462b41879a27f99a027",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "43c1950720f246e282bd172340255c95",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059993486182169244, 'mloss': tensor(0.0623), 'val_loss': 0.0743512842320181, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "44bfce27669a44fba40afd6166c443dd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8e1237accf0e42d2be83b4744abc1ec2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06037190207026579, 'mloss': tensor(0.0621), 'val_loss': 0.07422575583277738, 'val_mloss_tot': tensor(0.0741)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "85cb1d55cf644078bf3ceee771da915e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "26e7d2cf2e7b40b0aa287e08e46a7e9e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05532560140138827, 'mloss': tensor(0.0568), 'val_loss': 0.07520591900936269, 'val_mloss_tot': tensor(0.0750)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cea278c2721148afbf51feba7db31f7e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3cf3eb68d04b4e10859f1c05b7d7772b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0595299218338653, 'mloss': tensor(0.0608), 'val_loss': 0.07417779426283107, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ecfe6313ec84426299e6982a74cebef2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4a118b53a2154dfd979235c07894c40d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05790153664021978, 'mloss': tensor(0.0602), 'val_loss': 0.07420345913421528, 'val_mloss_tot': tensor(0.0741)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a28e8465a01d4063ae2a6fb413e7478c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "82711d47d0d84ef4b7401b7be779d2f5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05484075078884827, 'mloss': tensor(0.0539), 'val_loss': 0.07440526835423128, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "69a3aa328cf24cd6809139eac10b9ebe",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9227d880f6c44aaba89832bacb4d9b6f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06346585310683533, 'mloss': tensor(0.0680), 'val_loss': 0.07456420548145518, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5f0a7c75f943411b8e8560ff911a68e4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bbe2a815dd6345f7b1d45664d98d90dc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06026519884663737, 'mloss': tensor(0.0583), 'val_loss': 0.07494014723965063, 'val_mloss_tot': tensor(0.0748)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "93918c0cba2342bfabe3e7959111a224",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ba51d3e840a642cd89221d35dd96aa68",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05911734149366065, 'mloss': tensor(0.0563), 'val_loss': 0.07699440680490119, 'val_mloss_tot': tensor(0.0768)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e4db0240beb4431dadd2c204f694fe80",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "89837d1e137d4188aa2f20f10aad46ce",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057378733322550855, 'mloss': tensor(0.0547), 'val_loss': 0.07416969217965609, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b00044cb40a24b6cb6bc5f37932aaa0f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "13760b3619164b2f89e8d999a1837c9e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05760934769717428, 'mloss': tensor(0.0575), 'val_loss': 0.07488289184460022, 'val_mloss_tot': tensor(0.0747)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "07bf9ea65a2740d28a6d0a6fc303a678",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e356611c0e78424f957446d0bc63ed8d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056431262550411566, 'mloss': tensor(0.0557), 'val_loss': 0.07423439399761336, 'val_mloss_tot': tensor(0.0741)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ec8954d5d89c408896e1693b77956d1d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b07b3c34e0884691bbb194538cb4e903",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057571087611157236, 'mloss': tensor(0.0584), 'val_loss': 0.0742358320869799, 'val_mloss_tot': tensor(0.0741)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7b96b635e15748d495abaa11ab1ae73a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4eba0e07e9b54d76a5611e95d29300d6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05849957679851013, 'mloss': tensor(0.0569), 'val_loss': 0.0740265137610012, 'val_mloss_tot': tensor(0.0739)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9588ddaef0e54f2391f4625fbde2c7e7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "415a0aa38d1e4935b853984c1362555f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056775434195440615, 'mloss': tensor(0.0575), 'val_loss': 0.07392963684815455, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "dbd4f3864129446a8303e3571b958152",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "facc0fe7d906490eb6326a418ca3625f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055825571514041765, 'mloss': tensor(0.0530), 'val_loss': 0.07438085537192783, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "19283909dfb44bf2bfb6ddb04685c8c5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7a8ab4c6059b4e5f9542ca33c965cfa7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05437830455313461, 'mloss': tensor(0.0543), 'val_loss': 0.07443589000191925, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f09ef1c1afda4dcba2bd5fc67c2a5843",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2bcf7926ab5848d3a78d0e15a8108c8f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058227417368495016, 'mloss': tensor(0.0607), 'val_loss': 0.07416833246131359, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ee55addecf3b41d29ee890371c9f2f1b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(0.07373551576824978, {'val_loss': 0.07373551576824978, 'val_mloss_tot': tensor(0.0736)})\n",
+      "Densenet169_3 classifier_splits 1\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 208])"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 208])"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a05e9bd04b804401a1a612decf891fcd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cf6a3032630740ca9784018eafec26c0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5377a6fe14ad479d98aab1897bd6dfe3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.07105362750797016, 'mloss': tensor(0.0694), 'val_loss': 0.07651247644978688, 'val_mloss_tot': tensor(0.0765)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "afaa9f4fbebb4eefa2f0b07398301f67",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9af77fab59f844898ccb8638dbe18124",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06730027550889715, 'mloss': tensor(0.0670), 'val_loss': 0.07467023952659674, 'val_mloss_tot': tensor(0.0746)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1b1e4f59b3f2455a8c096a8a7d78d3c4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "988ca6779ba9460e94fbc3ee68f9ec92",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.062364376632078986, 'mloss': tensor(0.0570), 'val_loss': 0.07519180190908473, 'val_mloss_tot': tensor(0.0752)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a875c6be4816461fb101461f4e3a92e6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4edca19e991d4fef9f00831d2de690d7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06356795765666555, 'mloss': tensor(0.0614), 'val_loss': 0.07417154427754079, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "02ac41ae710b4eb5bfce8827ee16d264",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f4511d078def41a78ccafe0139d0c4fd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06607736408630102, 'mloss': tensor(0.0700), 'val_loss': 0.07352933139744691, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ac3493a4dbdf433b8128ea556a9dac5b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a4fc4f73f01a45728a342ae35895858f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06547841501166639, 'mloss': tensor(0.0651), 'val_loss': 0.0729903318987387, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "62dfb3761dbb424cb9c3e94001fad02c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3035d361d69e4494800187a606949930",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06559646791577056, 'mloss': tensor(0.0660), 'val_loss': 0.07289133834221014, 'val_mloss_tot': tensor(0.0729)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f824896e08a6498398d4ecabd927fbee",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8ed3d04d30da442dafd0e958ccd7efad",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060793095184360324, 'mloss': tensor(0.0587), 'val_loss': 0.07347987681718135, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3419e3db88e54cd7a48e8989feec61fd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c0fb9e3f014e48d2a36d94b5f28aa429",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06248185380692323, 'mloss': tensor(0.0626), 'val_loss': 0.0731690767857178, 'val_mloss_tot': tensor(0.0731)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "894a06e432c449079b6a7c69746d3482",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6630e1988a53440689b531024a030ccd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06359621590530079, 'mloss': tensor(0.0620), 'val_loss': 0.0727032814106745, 'val_mloss_tot': tensor(0.0727)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c27ab2c9997d4d2499f4cb01680cf59b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e3d64156c8744b19b48a946a8b760d0c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06467542100695214, 'mloss': tensor(0.0685), 'val_loss': 0.07284138554298296, 'val_mloss_tot': tensor(0.0728)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4e3d8e47aea6405ab8f2c18655f69693",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1c8e338f802044c78f8a2ae94cffa3da",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06017261932859889, 'mloss': tensor(0.0554), 'val_loss': 0.07273818657892506, 'val_mloss_tot': tensor(0.0727)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "69dbb02d04514dfebf12bb7e4c220985",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5c62b758957b4ba0acd01dfa8d82518c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06111985380811217, 'mloss': tensor(0.0588), 'val_loss': 0.07491131747459492, 'val_mloss_tot': tensor(0.0749)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e81d32aed4bc4c09a73129fbad4f07ca",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0f8c70ac553f4173aee539a61d2a3b92",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06308568776948033, 'mloss': tensor(0.0618), 'val_loss': 0.07406483579699587, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fa2c21520b8c4a58a4c60f9a9d6bb338",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5147ba03f8fd4bbea1bf0c86c603a115",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06343747297994655, 'mloss': tensor(0.0628), 'val_loss': 0.07385099867540525, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "df852120bc1b417b8fe890a42b025461",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8f02d187b6fa458a84b67985845b7591",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06216585118609157, 'mloss': tensor(0.0580), 'val_loss': 0.07382478498649306, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bce2482b988e4087a72beb6b3a85cd0b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "42887ebb9bbe4d7583fafd8d1deba595",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06340876707310748, 'mloss': tensor(0.0663), 'val_loss': 0.07296157216471506, 'val_mloss_tot': tensor(0.0729)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "316b6021a9174219a96b45644086ca9f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "37e9c6dfba9c44a593e2b720b8f89164",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06183836952827733, 'mloss': tensor(0.0618), 'val_loss': 0.07253359901014625, 'val_mloss_tot': tensor(0.0725)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5308420208b74999aeee55a92a7bd339",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "47f381c4d46a4540b58989f283feb8b8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.062132474005875765, 'mloss': tensor(0.0634), 'val_loss': 0.07294169472212472, 'val_mloss_tot': tensor(0.0729)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ca53502c131046fe92ee6f6383f3c5fd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "189ee481ded44999b76e5f06be9271fd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0624612848134798, 'mloss': tensor(0.0643), 'val_loss': 0.07336700662925112, 'val_mloss_tot': tensor(0.0733)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d88ae2f8938e488d888fc64a7eed175f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "98d0f1f95aa4448b9591972ed6e9169c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.064248607895717, 'mloss': tensor(0.0654), 'val_loss': 0.07245329433552376, 'val_mloss_tot': tensor(0.0724)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "39f8311c12f44a35b518ec82000bb8a9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bd46145387fa497389330a60aaf5a3df",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057301209458834386, 'mloss': tensor(0.0538), 'val_loss': 0.07283733674756637, 'val_mloss_tot': tensor(0.0728)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3880ef3e9ef64ae5847f3d6fb27a5a8b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c60894793ead46ca8f4aea468af97dae",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061330595898874414, 'mloss': tensor(0.0647), 'val_loss': 0.07254788946905515, 'val_mloss_tot': tensor(0.0725)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a3746d7d5e9941df9b2c97d96796499f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "28ecaa30f4cc48119cdc7effeace16a6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05807849469376063, 'mloss': tensor(0.0546), 'val_loss': 0.07257210482593353, 'val_mloss_tot': tensor(0.0726)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6da8d496c63b4692b2ddb9f0df8a7156",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.07245329433552376\n",
+      "Densenet169_3 classifier_splits 2\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 208])"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 208])"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c51542e64ae5431aa3096e7eff48a6bb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2c042e4c9bf94a5cbb4c0b74bf31be85",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "38cba772b35c4005ae4f3725b856b16e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.07344124343409475, 'mloss': tensor(0.0766), 'val_loss': 0.07448363920049787, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e6aae041578546a89e8b696657a1cc4d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9138409cb4f14c04b0e326248ef60b1e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06662308269806132, 'mloss': tensor(0.0670), 'val_loss': 0.07363454110718566, 'val_mloss_tot': tensor(0.0736)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a33a91cc8064417697eeb2b3a0b5d6b8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "64870529d3b34d579f729680e7d923bd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06420842338318493, 'mloss': tensor(0.0611), 'val_loss': 0.07345911426460508, 'val_mloss_tot': tensor(0.0734)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "38b171aa3ff54e89b8b6286b65fad10f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fb6aac9d4bbc47169b5433135e5934e4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06312026742897688, 'mloss': tensor(0.0620), 'val_loss': 0.07326044370947873, 'val_mloss_tot': tensor(0.0732)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7505e7b37fc0446a90a98f1594a6790a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7e18a1c99bdb4c0aabd21700f1151e9e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06341106770851614, 'mloss': tensor(0.0641), 'val_loss': 0.07308658373913741, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5f5cfc908b264870a45c15580c90ae12",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1233da7c26414591b6a54c40555a60a1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0639191241871452, 'mloss': tensor(0.0654), 'val_loss': 0.07150505649458085, 'val_mloss_tot': tensor(0.0714)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "181294e4e6264984a01fc5c085817153",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b8fab4ff2a1a43a89955b5c84181d122",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059116884816944555, 'mloss': tensor(0.0584), 'val_loss': 0.07185135577837455, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "98057bfb3c5d4f94b0614eab13038ac0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4f69ddd1b2f141a996ece32cb34c3bf8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059205026079564266, 'mloss': tensor(0.0585), 'val_loss': 0.0720476537479519, 'val_mloss_tot': tensor(0.0720)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "21a0824c8b1245d88ce1ed1ffd37e8de",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "64738792ee314a738098ff02cf2191e2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060264352965870474, 'mloss': tensor(0.0613), 'val_loss': 0.07094660936051915, 'val_mloss_tot': tensor(0.0709)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d56d67d4494e4f9babcecb61a3423f4e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "daa27cdb5b724661b33e6bde63ba5cf6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05827179999620854, 'mloss': tensor(0.0577), 'val_loss': 0.07089367524027972, 'val_mloss_tot': tensor(0.0708)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "057f079329474a1a9547a6f6f04e46fb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "356d4f6e9e5e43838fae8b457e727f84",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05913852160548659, 'mloss': tensor(0.0615), 'val_loss': 0.07101326242887548, 'val_mloss_tot': tensor(0.0709)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6b9d5e6ebc63401f9ad50eee7d112c90",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e3f4b06ec424409cad10bb1564809658",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06186301728520109, 'mloss': tensor(0.0636), 'val_loss': 0.07099969723156374, 'val_mloss_tot': tensor(0.0709)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d487ff827558431eb4bf10bc8108c8e6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "81871be048d74abc9dfe88aa0e724fd5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059128033597373716, 'mloss': tensor(0.0594), 'val_loss': 0.07200815327685749, 'val_mloss_tot': tensor(0.0719)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "29777169882a4fa9acefc6dfc32a0e3f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c4daa894d7d0434aa4085d77a5b7e1bb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06222716455325623, 'mloss': tensor(0.0644), 'val_loss': 0.07229821337616356, 'val_mloss_tot': tensor(0.0722)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "44fe24de06124cd69b68f37621a6f45b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "504035e32b4e4fdb80320a3f615e2aae",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06301096527071909, 'mloss': tensor(0.0626), 'val_loss': 0.07194839478493324, 'val_mloss_tot': tensor(0.0719)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b310b4af107a441f87aeec36c62ab303",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7572d39de12e48c1852ba8d9584ec712",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058587837861350724, 'mloss': tensor(0.0597), 'val_loss': 0.07158679085935057, 'val_mloss_tot': tensor(0.0715)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0d6cc005b0fb4af8b09b323e83971a92",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "eed4dd997fa644deae10c94a96f8af7c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0577751585205559, 'mloss': tensor(0.0561), 'val_loss': 0.07104657423852259, 'val_mloss_tot': tensor(0.0710)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b2c2208e38874fb4ac5f43033caefc93",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1ca031e7924e4c7480bb0868bbc262f2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06172878918529777, 'mloss': tensor(0.0645), 'val_loss': 0.07120306732756972, 'val_mloss_tot': tensor(0.0711)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c30bc057ac7c49d198e86089d6674843",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4b90303727844b9c8aefd0650049609a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.062416932213858106, 'mloss': tensor(0.0657), 'val_loss': 0.0708561903439242, 'val_mloss_tot': tensor(0.0708)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2a56561581724c57987322448992618f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "27b4507392f34adc867c0c677997ecd1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05927560621089667, 'mloss': tensor(0.0571), 'val_loss': 0.07088182524553147, 'val_mloss_tot': tensor(0.0708)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "faa3fcf1bc2a4bab9a7f0712d1619136",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0feb084e9ff048fe9f45370134122cfc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05872135664114991, 'mloss': tensor(0.0601), 'val_loss': 0.07057941667139163, 'val_mloss_tot': tensor(0.0705)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8ff0f4bca0804ba3846c217698124963",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "64e1ceaf00c246f68d0ca2a2ec8a3098",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05921057431195189, 'mloss': tensor(0.0589), 'val_loss': 0.07125711803705308, 'val_mloss_tot': tensor(0.0712)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d7f9997744544ae3b5e12092df7bd36d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bf5267949f8541199c5d6c95e769f3bd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05644550363422832, 'mloss': tensor(0.0554), 'val_loss': 0.0708550826389471, 'val_mloss_tot': tensor(0.0708)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7a0056e7b730450ebbd7d4300614bad6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d9793576e5784a8c99735edcdae6cfa3",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05949881642259972, 'mloss': tensor(0.0622), 'val_loss': 0.07083668288192138, 'val_mloss_tot': tensor(0.0708)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b5520c31975640e08b6c249582ef3b2c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.07057941667139163\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "# For each of the 3 splits: load the pickled Densenet169_3 features, train a\n",
+    "# ResModelPool on them with model_train and save the best model's state_dict.\n",
+    "for num_split in range(3):\n",
+    "    multi=3\n",
+    "    model_name,version = 'Densenet169_3' , 'classifier_splits'\n",
+    "    print (model_name,version,num_split)\n",
+    "    # Context manager closes the file even when unpickling raises.\n",
+    "    # NOTE(review): pickle.load can execute arbitrary code - only load files written by this project.\n",
+    "    with open(outputs_dir+outputs_format.format(model_name,version,'features_train_tta2',num_split),'rb') as pickle_file:\n",
+    "        features=pickle.load(pickle_file)\n",
+    "    features.shape\n",
+    "\n",
+    "    # Group every 4 consecutive rows into one entry: (N*4, F) -> (N, 4, F).\n",
+    "    # Presumably the 4 rows are augmented variants of the same image (file name says tta) - TODO confirm.\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    features.shape\n",
+    "    # Unique SeriesI values of the rows whose PID falls in the train / validation part of this split.\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    # Seed numpy and torch (CPU + CUDA) with a split-dependent value.\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    batch_size=16\n",
+    "    num_workers=18\n",
+    "    num_epochs=24\n",
+    "    klr=1\n",
+    "    # Six loss weights handed to model_train; the last entry is weighted 2x.\n",
+    "    weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)                \n",
+    "    # Validation gets the 4 feature variants concatenated along the last axis and no multi argument.\n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)                \n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    # Suffix the version so the saved file name identifies this head model and multi value.\n",
+    "    version=version+'_fullhead_resmodel_pool2_{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "    # Total optimizer steps = num_epochs * ceil(len(train_dataset) / batch_size).\n",
+    "    num_train_optimization_steps = num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor([1.,1.,1.,1.,1.,2.])),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    # Persist only the weights of the model returned as best by model_train.\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Densenet201_3 classifier_splits 0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 240])"
+      ]
+     },
+     "execution_count": 42,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 240])"
+      ]
+     },
+     "execution_count": 42,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 42,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e3782c0d60164186b00e1186f3b857ba",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fbe60d94f0044bf5a86493083ec9a56d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c8f9efc25ef24695b090a95532091467",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06714545603811183, 'mloss': tensor(0.0683), 'val_loss': 0.08121688730711935, 'val_mloss_tot': tensor(0.0811)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9f14d501339243909e0857860d6c374f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "268663b75bc94846bb32c2ca90cd45b4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06639542785040772, 'mloss': tensor(0.0713), 'val_loss': 0.07579567481584737, 'val_mloss_tot': tensor(0.0757)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cbb2b3988c2a4504be78f79d19b754ea",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0636a80ba14649e1b4efdc070ca2549b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06434485560830698, 'mloss': tensor(0.0675), 'val_loss': 0.0756614805644222, 'val_mloss_tot': tensor(0.0755)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "41880dfadaa143349a86b6ffce24570a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "990a0208ce4d43408988b65a16a924f0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06501446233093454, 'mloss': tensor(0.0635), 'val_loss': 0.07803845300862171, 'val_mloss_tot': tensor(0.0779)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a5814a7ac0e24ce0af917127d117dab5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2d3610c1f3964ee1867db0e70d0c9e23",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06126078489662714, 'mloss': tensor(0.0619), 'val_loss': 0.07449463286822469, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ffac3c705b014dc4b50415f1ae32a51c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b8c883efa87547e6bd2b2f9ba385e8e4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05816467110965581, 'mloss': tensor(0.0603), 'val_loss': 0.07454177092191915, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3d22f6c7ea5f47689e5792c88f300ad2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f9274051ec784255aedbf0446bcfcfab",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0565368522008708, 'mloss': tensor(0.0552), 'val_loss': 0.07493635092683026, 'val_mloss_tot': tensor(0.0748)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fc43d5782cfb4c93b25e06d31b2ceb38",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "782cc809e7dd4a80844c5998f15bff29",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05618237833735686, 'mloss': tensor(0.0553), 'val_loss': 0.07491349378541075, 'val_mloss_tot': tensor(0.0748)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "14e9cbb9ef484e619b82d24deb98f434",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "36ae8fd7bcbb4bf280cfe9e3afd02807",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.058813441245742296, 'mloss': tensor(0.0579), 'val_loss': 0.0744921877243891, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2a4e6529fc7a42dd858e103d0af71141",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4e9f1bb4ab644fd683bf0af9b3971696",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05915064492343875, 'mloss': tensor(0.0609), 'val_loss': 0.07393334815141016, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4273a998573b45c58184f08d748690ef",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7da8658a5cbb4fe39ba1f2b18d6120f5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05860633592110833, 'mloss': tensor(0.0621), 'val_loss': 0.07418951835813634, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6335fd1b6145437f986241761250b706",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "44180707a9034f4f9b9833210fa5d6b4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05656608644875216, 'mloss': tensor(0.0558), 'val_loss': 0.07398575787395789, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5d649863095741b0aee9f4d477eda23c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "baba06405c764814ba31a1c4accef238",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05672952532644869, 'mloss': tensor(0.0546), 'val_loss': 0.07657721815768427, 'val_mloss_tot': tensor(0.0764)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7383b13ea9614a5bbb3ec18b058cd328",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "232e31fd05f04769b862b4e765e582bc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060239316070715726, 'mloss': tensor(0.0623), 'val_loss': 0.07445015643466392, 'val_mloss_tot': tensor(0.0743)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c032db7d29dd42b6aadb594d359bde77",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f82c51e8eddb4b729dc8c1250407f709",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05522729377563791, 'mloss': tensor(0.0508), 'val_loss': 0.0774409837763885, 'val_mloss_tot': tensor(0.0773)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ea059da14ad3497b9b3509fd740b9b29",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "284ec595d6204ff7a196c4260f2a5c32",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05642039656611328, 'mloss': tensor(0.0546), 'val_loss': 0.07505481231388787, 'val_mloss_tot': tensor(0.0749)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0e0304408d5b45b685ede23ef15e59d1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3534393306ee4d5bb50fc78be636939d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054118776829053424, 'mloss': tensor(0.0526), 'val_loss': 0.07521575793802848, 'val_mloss_tot': tensor(0.0751)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "57b4d3f9571c46f9ad6458f80de7764f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "97c447f982244d8cb7db6541cdba406b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056658329746729554, 'mloss': tensor(0.0577), 'val_loss': 0.07453208685086425, 'val_mloss_tot': tensor(0.0744)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "142a99e26a2847f1bb06bcc84bc47ad4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "61a06c84a0bc43879e8c2142cefd645c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05610415092997552, 'mloss': tensor(0.0531), 'val_loss': 0.074837266302351, 'val_mloss_tot': tensor(0.0747)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e967780a1fa549c4b9bed162faf2648d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0a43cb5564c34a4b8d66c88a385d775a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056013785436985555, 'mloss': tensor(0.0572), 'val_loss': 0.07416649999583295, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "95f209b708084e93b1b5552928346a73",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "affe5fb6a82547e18e12e639c02de536",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056454736975423886, 'mloss': tensor(0.0573), 'val_loss': 0.0743936681053303, 'val_mloss_tot': tensor(0.0742)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d03b9e830d7947489d780bae92187229",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4d46ae2036e14df49a0a39c6633f079b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05601119608737044, 'mloss': tensor(0.0556), 'val_loss': 0.074086194048209, 'val_mloss_tot': tensor(0.0739)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "95ca72de444e49389e550d2dce28f0af",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5ff7d94454364d6bb376d66b20915478",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054652329200176826, 'mloss': tensor(0.0549), 'val_loss': 0.07469926805354728, 'val_mloss_tot': tensor(0.0745)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "79f0496772064fbd8c51627eda3dd440",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=816), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d60e2a0850134142b1ac3deeedd75307",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05461070586356285, 'mloss': tensor(0.0547), 'val_loss': 0.07415003531837258, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5429aae2b3e849a6ae6642ddf6b5a124",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.07393334815141016\n",
+      "Densenet201_3 classifier_splits 1\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 240])"
+      ]
+     },
+     "execution_count": 42,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 240])"
+      ]
+     },
+     "execution_count": 42,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 42,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "38248e3e5ad64e98a403963549c96424",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "723c1eafd77d40c4898c08e5538b34cd",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "18c776ba0b33466284dee1c6d38c78c0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06439249240249831, 'mloss': tensor(0.0629), 'val_loss': 0.07567418642432952, 'val_mloss_tot': tensor(0.0756)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e30600c0c7d54780af043cc9c084acc2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "bd8f4c7542da49f2b5d58ce673e06df8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06245851952491162, 'mloss': tensor(0.0626), 'val_loss': 0.07639425312873067, 'val_mloss_tot': tensor(0.0764)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "123b5f6fb3464e45a170f8051c2ac0dc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6d7e4e9f778846919136d3659c32dfed",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05954967619186729, 'mloss': tensor(0.0560), 'val_loss': 0.07668027484153465, 'val_mloss_tot': tensor(0.0766)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3300dc0047d44b68a14edc78da3e29e5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fc74b4d72a314ae7ae26652888fda8e0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055818062697264, 'mloss': tensor(0.0521), 'val_loss': 0.07476054853191827, 'val_mloss_tot': tensor(0.0747)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9498b82369d1486cac4f3e3ad27ec9e6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "027d49b8c0334d17bc5d67c53f3438e2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05982575788926721, 'mloss': tensor(0.0567), 'val_loss': 0.07415224632657155, 'val_mloss_tot': tensor(0.0741)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7759af080be94779b8779c7e68dba9d4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a1c311f4fa0f467e8bbfa90057a22eb7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05983107311968058, 'mloss': tensor(0.0627), 'val_loss': 0.07323539868769485, 'val_mloss_tot': tensor(0.0732)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2a51a522ed31440baa5708d69fe76015",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4607cc8af4f74e3684711d17fce4bf2e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057032653698383784, 'mloss': tensor(0.0553), 'val_loss': 0.0757853529631819, 'val_mloss_tot': tensor(0.0758)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2787abfefac841c69a1afb26fb526e5d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a67c5444e6684565b233560a53b47e6f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05466835569775774, 'mloss': tensor(0.0524), 'val_loss': 0.07333007520519015, 'val_mloss_tot': tensor(0.0733)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5f3b800ea2474dd8ab1cf3bd107f23b6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3cf0c3b1c8c2460c93aacdf13b2a06ae",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056447752398912296, 'mloss': tensor(0.0567), 'val_loss': 0.07357249872463688, 'val_mloss_tot': tensor(0.0735)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "30448d970e5f44faabf18359f4ee65df",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d0b2877972d04de98dc2733fabcaf90a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.059043416484971474, 'mloss': tensor(0.0611), 'val_loss': 0.07333431949657275, 'val_mloss_tot': tensor(0.0733)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "016ed6ed2cad4454a3bd011998b14e52",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "32075bb860ea4b8788d308d4933511ce",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05655126065467185, 'mloss': tensor(0.0582), 'val_loss': 0.0731954549393821, 'val_mloss_tot': tensor(0.0732)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "eea1dc20284d421ba2757ab5086ad860",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "47c11ad27d464a18ab03c8a260c63a57",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05590760359164867, 'mloss': tensor(0.0546), 'val_loss': 0.07314455393792653, 'val_mloss_tot': tensor(0.0731)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "49b79eca09bc48898851304246ddf6d7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "db48ba5188b94afaa47a63edc984546b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05987552155081188, 'mloss': tensor(0.0615), 'val_loss': 0.0730639233231181, 'val_mloss_tot': tensor(0.0730)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7ac4126686ac41c4bfc172d2a3686b40",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7061daf2b0164765913bec1f9ec2a165",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060542488508901864, 'mloss': tensor(0.0655), 'val_loss': 0.0732335521953135, 'val_mloss_tot': tensor(0.0732)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7322a78fb2b84cf79bb38e6b03aa1386",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5b2b8178f832488d87fcbe2b94701c1c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057100788540008364, 'mloss': tensor(0.0551), 'val_loss': 0.07452967399731278, 'val_mloss_tot': tensor(0.0745)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3da8ac664fc741d7a4aa76a7d5360f47",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4711d2f5469c479c876d15fea6c30c7c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05559522358732739, 'mloss': tensor(0.0558), 'val_loss': 0.07359342729990802, 'val_mloss_tot': tensor(0.0736)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5e6a7dbba9d44581896c139597512415",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "062e72d449c3405688cab4cc0c407554",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057572012485259005, 'mloss': tensor(0.0589), 'val_loss': 0.07407739510896003, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "76b9741ade7b4e1789c4fa63be62bff6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "02071f08b3b348bfa5ba67c7824a753d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054505876520617315, 'mloss': tensor(0.0531), 'val_loss': 0.07467929229662731, 'val_mloss_tot': tensor(0.0747)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "29333b9470604bc78483cf409bf28e41",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ee3fc72d04e2463895840dcd6aa1e924",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05522564051416929, 'mloss': tensor(0.0522), 'val_loss': 0.07378836331502875, 'val_mloss_tot': tensor(0.0738)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "555d338bc7f84959b8d3241e17b78f4a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "618791befb5242d195a38c4ee5743b9b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.053743294800223516, 'mloss': tensor(0.0494), 'val_loss': 0.07376038509012177, 'val_mloss_tot': tensor(0.0737)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a8430dc2b2d7433da20d31b1fb93af78",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e66068ae6caa41148bef02b15065cdee",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057698717112048714, 'mloss': tensor(0.0581), 'val_loss': 0.07365814052045164, 'val_mloss_tot': tensor(0.0736)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6f53433b1115400abfe96e40126b4117",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "86f95df11f2a4c6db5481ae046614cfc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05388140794255837, 'mloss': tensor(0.0541), 'val_loss': 0.07403294937344404, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ae236af88c8d429e90750a97119e8222",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9947af1c989046d988b4f89790a0a9f9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.054101768064043994, 'mloss': tensor(0.0529), 'val_loss': 0.0739659377873489, 'val_mloss_tot': tensor(0.0739)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "cadb9f8384f44c3f93338f5b67872ec4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=812), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "651a40d6776845ef8250a7305fa73379",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.053977602862547366, 'mloss': tensor(0.0538), 'val_loss': 0.07404084446521976, 'val_mloss_tot': tensor(0.0740)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9a9ce9fd64c647a5920538fe9687d6f4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=410), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(0.0730125388789286, {'val_loss': 0.0730125388789286, 'val_mloss_tot': tensor(0.0730)})\n",
+      "Densenet201_3 classifier_splits 2\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2697008, 240])"
+      ]
+     },
+     "execution_count": 42,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([674252, 4, 240])"
+      ]
+     },
+     "execution_count": 42,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fab07694750>"
+      ]
+     },
+     "execution_count": 42,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "764277eba34c450e93583b7ee41d1b49",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=24), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a03344b3c1124e55a689554ea7f21a3b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "576b3805a9d9437aa4ec6c455e6ae82f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.062269154649983866, 'mloss': tensor(0.0597), 'val_loss': 0.0757758161695353, 'val_mloss_tot': tensor(0.0757)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "09ef2978a2cb4287a57d3dd1320bd15f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7b6f7afd801a44f29524ea54f6dd154f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05969917109269863, 'mloss': tensor(0.0590), 'val_loss': 0.07204735779153008, 'val_mloss_tot': tensor(0.0720)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "86d446e438d44542b735deef696626f2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f66dfcf8998e4bada7f8ea4087dbf55e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06275483938021034, 'mloss': tensor(0.0638), 'val_loss': 0.07184433003456281, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "aa97510f7bb147f09869e2160a73f4a9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5c667e3048b2411f90b2272da464dd2f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05853381524892932, 'mloss': tensor(0.0555), 'val_loss': 0.0728481894845299, 'val_mloss_tot': tensor(0.0728)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2c697b449faf48c594891ae5df2539e8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "faacfbc8e8424e109a7845909f67c5aa",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.061119686382003054, 'mloss': tensor(0.0634), 'val_loss': 0.07158966145014821, 'val_mloss_tot': tensor(0.0715)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1e24b55627194208a292b928afd26e05",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "34d10ee428824ab69b03e0cb64c3c24c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05808279283155957, 'mloss': tensor(0.0576), 'val_loss': 0.07141277121191872, 'val_mloss_tot': tensor(0.0713)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7073dd73fea64df992577a786ad118c6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7efb71bc09434170b59336e3b1eee143",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05983334281095374, 'mloss': tensor(0.0617), 'val_loss': 0.07198681046709125, 'val_mloss_tot': tensor(0.0719)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b64948ed76714a5b81bb3f2d6c947232",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "38857f5610934ce18738b56430243b79",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057874279286490034, 'mloss': tensor(0.0589), 'val_loss': 0.07095855512665447, 'val_mloss_tot': tensor(0.0709)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "dab76d6a55e24d4da67d010826f46c3b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "357a7ea2b8b24d86b097d52252c71aa8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.056846124484918194, 'mloss': tensor(0.0572), 'val_loss': 0.07051182455690623, 'val_mloss_tot': tensor(0.0704)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6181ccb0e9f24285b038ede6ab468e02",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b291b9c3d0f54628a17b6951a0991bcf",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.057616001578078294, 'mloss': tensor(0.0604), 'val_loss': 0.07093305741046817, 'val_mloss_tot': tensor(0.0709)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d56dcee70d164d31979f3bfb3c673da9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0c929bd544074cea801562bce955e578",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05637851310015766, 'mloss': tensor(0.0562), 'val_loss': 0.07083558652715143, 'val_mloss_tot': tensor(0.0708)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "86c1c22718f74b698263cee0e19a0c4c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "dc6dbe25e66b4494a062bf428e770dd8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05814679219216533, 'mloss': tensor(0.0569), 'val_loss': 0.0712252360300661, 'val_mloss_tot': tensor(0.0711)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0c089a1b0b6640a689d4ab87869d864d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4e5f364929b047d89729ae510cbd8350",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.06040912248347267, 'mloss': tensor(0.0625), 'val_loss': 0.0716813850386316, 'val_mloss_tot': tensor(0.0716)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "dd4e4ee2675a400c948c3748f2954f4f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "844bf6c0292843568876f013cb3a91d8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05905177180961919, 'mloss': tensor(0.0603), 'val_loss': 0.07110046382690666, 'val_mloss_tot': tensor(0.0710)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "06cb26ff44db413ba2b01f5a5b81259f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b427af0b90d94f9eb4c4e730963243a5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.0572230759251108, 'mloss': tensor(0.0555), 'val_loss': 0.07142122062221352, 'val_mloss_tot': tensor(0.0713)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "101e5d27e39743e38e471fbd23a2d71a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "05fe0bc1753d45b189482e5a9c6a36a5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05545345721461106, 'mloss': tensor(0.0534), 'val_loss': 0.0718750449119989, 'val_mloss_tot': tensor(0.0718)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4c9c60d27ebd4e0889c034c0d3663dfa",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c9b49339105a486ca650fb621dc264bf",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05736496876334686, 'mloss': tensor(0.0585), 'val_loss': 0.0711389344226881, 'val_mloss_tot': tensor(0.0711)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f4fb0b1a4d974cd8a6e5e02869a03dec",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "13c2049e33c4498287c37920b75d631c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.060103811381215816, 'mloss': tensor(0.0619), 'val_loss': 0.07067023349118423, 'val_mloss_tot': tensor(0.0706)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8910625f93d44de19f905d9f760e5a36",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "eb3e9588088f4577aa20d4377bf91fdf",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05747771910879523, 'mloss': tensor(0.0562), 'val_loss': 0.071265555529401, 'val_mloss_tot': tensor(0.0712)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1018f819e31d4bb788015006149852be",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3b529602322f4ceab2de41b8d2eea1e6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05607276650632766, 'mloss': tensor(0.0577), 'val_loss': 0.07131397596496189, 'val_mloss_tot': tensor(0.0712)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ec621e4dc3d94974b1f54904cd07f184",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b4049ea0fa3b4570b124c8d2a92c8913",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05368155506728136, 'mloss': tensor(0.0517), 'val_loss': 0.0711224768608992, 'val_mloss_tot': tensor(0.0710)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1ca32b1446a945fe97016e0cc269a460",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0f2fb1942290492dbde1a81e41619a56",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.055210776814632045, 'mloss': tensor(0.0557), 'val_loss': 0.07130228077047711, 'val_mloss_tot': tensor(0.0712)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e339f8345f2d4f08ad773da9bdbcbf67",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c983dec254c54f38934931efd77f9bbc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.053513703152002816, 'mloss': tensor(0.0510), 'val_loss': 0.07124504791412117, 'val_mloss_tot': tensor(0.0712)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0dbdd0c3bfcd4abaae52171de6493b67",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=815), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "17ebd719dafe41c1b88417260c5cd3d1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'loss': 0.05625786708254494, 'mloss': tensor(0.0582), 'val_loss': 0.07147224175197811, 'val_mloss_tot': tensor(0.0714)}\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ba10e62097c24fbdb179a4c67ebb312c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=406), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.07051182455690623\n"
+     ]
+    }
+   ],
+   "source": [
+    "%matplotlib nbagg\n",
+    "# Train one ResModelPool head per split on the cached 'Densenet201_3' features and\n",
+    "# save the best state_dict that model_train returns for that split.\n",
+    "multi=3\n",
+    "model_name,base_version = 'Densenet201_3' , 'classifier_splits'\n",
+    "batch_size=16\n",
+    "num_workers=18\n",
+    "num_epochs=24\n",
+    "klr=1\n",
+    "# One weight list feeds both the loss weights and the validation Metric.\n",
+    "class_weights=[1.,1.,1.,1.,1.,2.]\n",
+    "for num_split in range(3):\n",
+    "    print (model_name,base_version,num_split)\n",
+    "    # with-block closes the file even when unpickling raises.\n",
+    "    # NOTE(review): pickle.load can execute arbitrary code - only load files produced by this project.\n",
+    "    with open(outputs_dir+outputs_format.format(model_name,base_version,'features_train_tta',num_split),'rb') as pickle_file:\n",
+    "        features=pickle.load(pickle_file)\n",
+    "\n",
+    "    # Group consecutive rows in fours -> (rows//4, 4, -1); presumably 4 TTA copies per image - TODO confirm.\n",
+    "    features=features.reshape(features.shape[0]//4,4,-1)\n",
+    "    # Series are selected through the patient ids (PID) listed for this split in splits/split_sid.\n",
+    "    split_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].SeriesI.unique()\n",
+    "    split_validate =  train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].SeriesI.unique()\n",
+    "\n",
+    "    # Different but reproducible seed for every split.\n",
+    "    np.random.seed(SEED+num_split)\n",
+    "    torch.manual_seed(SEED+num_split)\n",
+    "    torch.cuda.manual_seed(SEED+num_split)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "    weights = torch.tensor(class_weights,device=device)\n",
+    "    train_dataset=FullHeadDataset(train_df,\n",
+    "                                  split_train,\n",
+    "                                  features,\n",
+    "                                  'SeriesI',\n",
+    "                                  'ImagePositionZ',\n",
+    "                                  hemorrhage_types,\n",
+    "                                  multi=multi)\n",
+    "    # Validation receives the 4 feature groups concatenated along the last axis and no multi argument.\n",
+    "    validate_dataset=FullHeadDataset(train_df,\n",
+    "                                     split_validate,\n",
+    "                                     torch.cat([features[:,i,:] for i in range(4)],-1),\n",
+    "                                     'SeriesI',\n",
+    "                                     'ImagePositionZ',\n",
+    "                                     hemorrhage_types)\n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    # Checkpoint version = feature version + head suffix; rebuilt from base_version on every split.\n",
+    "    version=base_version+'_fullhead_resmodel_pool2_{}'.format(multi)\n",
+    "    _=model.to(device)\n",
+    "    #mixup=Mixup(device=device)\n",
+    "    loss_func=my_loss\n",
+    "    #fig,ax = plt.subplots(figsize=(10,7))\n",
+    "    #gr=loss_graph(fig,ax,num_epochs,len(train_dataset)//batch_size+1,limits=[0.02,0.06])\n",
+    "    # Total optimizer steps = num_epochs * ceil(len(train_dataset)/batch_size).\n",
+    "    num_train_optimization_steps = num_epochs*(len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n",
+    "    sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=2,tau=1)\n",
+    "    optimizer = BertAdam(model.parameters(),lr=klr*1e-3,schedule=sched)\n",
+    "    history,best_model= model_train(model,\n",
+    "                                    optimizer,\n",
+    "                                    train_dataset,\n",
+    "                                    batch_size,\n",
+    "                                    num_epochs,\n",
+    "                                    loss_func,\n",
+    "                                    weights=weights,\n",
+    "                                    do_apex=False,\n",
+    "                                    validate_dataset=validate_dataset,\n",
+    "                                    param_schedualer=None,\n",
+    "                                    weights_data=None,\n",
+    "                                    metric=Metric(torch.tensor(class_weights)),\n",
+    "                                    return_model=True,\n",
+    "                                    best_average=3,\n",
+    "                                    num_workers=num_workers,\n",
+    "                                    sampler=None,\n",
+    "                                    graph=None)\n",
+    "    torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version,num_split))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Intermediate inference (sanity check)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# Test inference for the 'Densenet169_3' heads: 3 splits x 32 model_run passes each.\n",
+    "# pred_list is reset only here and keeps growing in the model cells that follow, so\n",
+    "# re-running a later model cell on its own appends duplicates - rerun from this cell instead.\n",
+    "pred_list=[]\n",
+    "pred_list_tmp=[]\n",
+    "for num_split in tqdm_notebook(range(3)):\n",
+    "    model_name,version = 'Densenet169_3' , 'classifier_splits'\n",
+    "    # with-block closes the feature file even when unpickling raises.\n",
+    "    with open(outputs_dir+outputs_format.format(model_name,version,'features_test',num_split),'rb') as pickle_file:\n",
+    "        features=pickle.load(pickle_file)\n",
+    "    # Group consecutive rows in eights -> (rows//8, 8, -1); presumably 8 TTA copies per image - TODO confirm.\n",
+    "    features=features.reshape(features.shape[0]//8,8,-1)\n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_3'\n",
+    "\n",
+    "    model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n",
+    "    # Bind test_dataset only; train_dataset must not be aliased to test data.\n",
+    "    test_dataset=FullHeadDataset(test_df,\n",
+    "                                 test_df.SeriesI.unique(),\n",
+    "                                 features,\n",
+    "                                 'SeriesI',\n",
+    "                                 'ImagePositionZ',multi=4)\n",
+    "    # The 32 passes are averaged below; presumably each pass draws the feature copies\n",
+    "    # differently (multi=4) - TODO confirm in FullHeadDataset.\n",
+    "    for i in tqdm_notebook(range(32),leave=False):\n",
+    "        pass_pred=torch.sigmoid(model_run(model,test_dataset,do_apex=False,batch_size=128))[...,None]\n",
+    "        pred_list.append(pass_pred)\n",
+    "        pred_list_tmp.append(pass_pred)\n",
+    "# Mean of the sigmoid outputs of this model over all splits and passes.\n",
+    "pred169=torch.cat(pred_list_tmp,-1).mean(-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# Test inference for the 'Densenet161_3' heads: 3 splits x 32 model_run passes each.\n",
+    "# Appends to the notebook-level pred_list, which must already exist; re-running this\n",
+    "# cell on its own appends duplicate entries.\n",
+    "pred_list_tmp=[]\n",
+    "for num_split in tqdm_notebook(range(3)):\n",
+    "    model_name,version = 'Densenet161_3' , 'classifier_splits'\n",
+    "    # with-block closes the feature file even when unpickling raises.\n",
+    "    with open(outputs_dir+outputs_format.format(model_name,version,'features_test',num_split),'rb') as pickle_file:\n",
+    "        features=pickle.load(pickle_file)\n",
+    "    # Group consecutive rows in eights -> (rows//8, 8, -1); presumably 8 TTA copies per image - TODO confirm.\n",
+    "    features=features.reshape(features.shape[0]//8,8,-1)\n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_3'\n",
+    "\n",
+    "    model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n",
+    "    # Bind test_dataset only; train_dataset must not be aliased to test data.\n",
+    "    test_dataset=FullHeadDataset(test_df,\n",
+    "                                 test_df.SeriesI.unique(),\n",
+    "                                 features,\n",
+    "                                 'SeriesI',\n",
+    "                                 'ImagePositionZ',multi=4)\n",
+    "    # The 32 passes are averaged below; presumably each pass draws the feature copies\n",
+    "    # differently (multi=4) - TODO confirm in FullHeadDataset.\n",
+    "    for i in tqdm_notebook(range(32),leave=False):\n",
+    "        pass_pred=torch.sigmoid(model_run(model,test_dataset,do_apex=False,batch_size=128))[...,None]\n",
+    "        pred_list.append(pass_pred)\n",
+    "        pred_list_tmp.append(pass_pred)\n",
+    "# Mean of the sigmoid outputs of this model over all splits and passes.\n",
+    "pred161=torch.cat(pred_list_tmp,-1).mean(-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Test inference for the 'Densenet201_3' heads: 3 splits x 32 model_run passes each.\n",
+    "# Appends to the notebook-level pred_list, which must already exist; re-running this\n",
+    "# cell on its own appends duplicate entries.\n",
+    "pred_list_tmp=[]\n",
+    "for num_split in tqdm_notebook(range(3)):\n",
+    "    model_name,version = 'Densenet201_3' , 'classifier_splits'\n",
+    "    # with-block closes the feature file even when unpickling raises.\n",
+    "    with open(outputs_dir+outputs_format.format(model_name,version,'features_test',num_split),'rb') as pickle_file:\n",
+    "        features=pickle.load(pickle_file)\n",
+    "    # Group consecutive rows in eights -> (rows//8, 8, -1); presumably 8 TTA copies per image - TODO confirm.\n",
+    "    features=features.reshape(features.shape[0]//8,8,-1)\n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_3'\n",
+    "\n",
+    "    model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n",
+    "    # Bind test_dataset only; train_dataset must not be aliased to test data.\n",
+    "    test_dataset=FullHeadDataset(test_df,\n",
+    "                                 test_df.SeriesI.unique(),\n",
+    "                                 features,\n",
+    "                                 'SeriesI',\n",
+    "                                 'ImagePositionZ',multi=4)\n",
+    "    # The 32 passes are averaged below; presumably each pass draws the feature copies\n",
+    "    # differently (multi=4) - TODO confirm in FullHeadDataset.\n",
+    "    for i in tqdm_notebook(range(32),leave=False):\n",
+    "        pass_pred=torch.sigmoid(model_run(model,test_dataset,do_apex=False,batch_size=128))[...,None]\n",
+    "        pred_list.append(pass_pred)\n",
+    "        pred_list_tmp.append(pass_pred)\n",
+    "# Mean of the sigmoid outputs of this model over all splits and passes.\n",
+    "pred201=torch.cat(pred_list_tmp,-1).mean(-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Test inference for the 'se_resnext101_32x4d' heads (classifier_splits): 3 splits x 32 model_run passes each.\n",
+    "# Appends to the notebook-level pred_list, which must already exist; re-running this\n",
+    "# cell on its own appends duplicate entries.\n",
+    "pred_list_tmp=[]\n",
+    "for num_split in tqdm_notebook(range(3)):\n",
+    "    model_name,version = 'se_resnext101_32x4d' , 'classifier_splits'\n",
+    "    # with-block closes the feature file even when unpickling raises.\n",
+    "    with open(outputs_dir+outputs_format.format(model_name,version,'features_test',num_split),'rb') as pickle_file:\n",
+    "        features=pickle.load(pickle_file)\n",
+    "    # Group consecutive rows in eights -> (rows//8, 8, -1); presumably 8 TTA copies per image - TODO confirm.\n",
+    "    features=features.reshape(features.shape[0]//8,8,-1)\n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_3'\n",
+    "\n",
+    "    model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n",
+    "    # Bind test_dataset only; train_dataset must not be aliased to test data.\n",
+    "    test_dataset=FullHeadDataset(test_df,\n",
+    "                                 test_df.SeriesI.unique(),\n",
+    "                                 features,\n",
+    "                                 'SeriesI',\n",
+    "                                 'ImagePositionZ',multi=4)\n",
+    "    # The 32 passes are averaged below; presumably each pass draws the feature copies\n",
+    "    # differently (multi=4) - TODO confirm in FullHeadDataset.\n",
+    "    for i in tqdm_notebook(range(32),leave=False):\n",
+    "        pass_pred=torch.sigmoid(model_run(model,test_dataset,do_apex=False,batch_size=128))[...,None]\n",
+    "        pred_list.append(pass_pred)\n",
+    "        pred_list_tmp.append(pass_pred)\n",
+    "# Mean of the sigmoid outputs of this model over all splits and passes.\n",
+    "predse=torch.cat(pred_list_tmp,-1).mean(-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Test inference for the 'se_resnext101_32x4d' heads trained on 'new_splits': 5 splits x 32 model_run passes each.\n",
+    "# Appends to the notebook-level pred_list, which must already exist; re-running this\n",
+    "# cell on its own appends duplicate entries.\n",
+    "pred_list_tmp=[]\n",
+    "for num_split in tqdm_notebook(range(5)):\n",
+    "    model_name,version = 'se_resnext101_32x4d' , 'new_splits'\n",
+    "    # with-block closes the feature file even when unpickling raises.\n",
+    "    with open(outputs_dir+outputs_format.format(model_name,version,'features_test',num_split),'rb') as pickle_file:\n",
+    "        features=pickle.load(pickle_file)\n",
+    "    # Group consecutive rows in eights -> (rows//8, 8, -1); presumably 8 TTA copies per image - TODO confirm.\n",
+    "    features=features.reshape(features.shape[0]//8,8,-1)\n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_3'\n",
+    "\n",
+    "    model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n",
+    "    # Bind test_dataset only; train_dataset must not be aliased to test data.\n",
+    "    test_dataset=FullHeadDataset(test_df,\n",
+    "                                 test_df.SeriesI.unique(),\n",
+    "                                 features,\n",
+    "                                 'SeriesI',\n",
+    "                                 'ImagePositionZ',multi=4)\n",
+    "    # The 32 passes are averaged below; presumably each pass draws the feature copies\n",
+    "    # differently (multi=4) - TODO confirm in FullHeadDataset.\n",
+    "    for i in tqdm_notebook(range(32),leave=False):\n",
+    "        pass_pred=torch.sigmoid(model_run(model,test_dataset,do_apex=False,batch_size=128))[...,None]\n",
+    "        pred_list.append(pass_pred)\n",
+    "        pred_list_tmp.append(pass_pred)\n",
+    "# NOTE(review): this rebinds predse, replacing the 'classifier_splits' se_resnext101_32x4d\n",
+    "# average computed by the previous cell; pred_list still holds the passes of both runs.\n",
+    "predse=torch.cat(pred_list_tmp,-1).mean(-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Test inference for the 'se_resnet101' heads: 3 splits x 32 model_run passes each.\n",
+    "# Appends to the notebook-level pred_list, which must already exist; re-running this\n",
+    "# cell on its own appends duplicate entries.\n",
+    "pred_list_tmp=[]\n",
+    "for num_split in tqdm_notebook(range(3)):\n",
+    "    model_name,version = 'se_resnet101' , 'classifier_splits'\n",
+    "    # with-block closes the feature file even when unpickling raises.\n",
+    "    with open(outputs_dir+outputs_format.format(model_name,version,'features_test',num_split),'rb') as pickle_file:\n",
+    "        features=pickle.load(pickle_file)\n",
+    "    # Group consecutive rows in eights -> (rows//8, 8, -1); presumably 8 TTA copies per image - TODO confirm.\n",
+    "    features=features.reshape(features.shape[0]//8,8,-1)\n",
+    "\n",
+    "    model=ResModelPool(features.shape[-1])\n",
+    "    version=version+'_fullhead_resmodel_pool2_3'\n",
+    "\n",
+    "    model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n",
+    "    # Bind test_dataset only; train_dataset must not be aliased to test data.\n",
+    "    test_dataset=FullHeadDataset(test_df,\n",
+    "                                 test_df.SeriesI.unique(),\n",
+    "                                 features,\n",
+    "                                 'SeriesI',\n",
+    "                                 'ImagePositionZ',multi=4)\n",
+    "    # The 32 passes are averaged below; presumably each pass draws the feature copies\n",
+    "    # differently (multi=4) - TODO confirm in FullHeadDataset.\n",
+    "    for i in tqdm_notebook(range(32),leave=False):\n",
+    "        pass_pred=torch.sigmoid(model_run(model,test_dataset,do_apex=False,batch_size=128))[...,None]\n",
+    "        pred_list.append(pass_pred)\n",
+    "        pred_list_tmp.append(pass_pred)\n",
+    "# Mean of the sigmoid outputs of this model over all splits and passes.\n",
+    "predseres=torch.cat(pred_list_tmp,-1).mean(-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Average only the passes held in pred_list_tmp, i.e. those of the most recently run model cell.\n",
+    "pred=torch.cat(pred_list_tmp,dim=-1).mean(dim=-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2214, 60, 6])"
+      ]
+     },
+     "execution_count": 31,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pred.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "672"
+      ]
+     },
+     "execution_count": 38,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "672"
+      ]
+     },
+     "execution_count": 38,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(pred_list)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2214, 60, 6, 1])"
+      ]
+     },
+     "execution_count": 34,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pred_list[0].shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Final ensemble: mean over every sigmoid prediction accumulated in pred_list.\n",
+    "pred=torch.cat(pred_list,dim=-1).mean(dim=-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 40,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "53b42cb038484769ad6e1ff8fc16d405",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=2214), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "53b42cb038484769ad6e1ff8fc16d405",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(IntProgress(value=0, max=2214), HTML(value='')))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(78545,)"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(78545,)"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(78545, 6)"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(78545, 6)"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Flatten the padded per-series predictions in pred back to one row per test image.\n",
+    "# Assumes each series fills the first len(series) slots of its row in ascending\n",
+    "# ImagePositionZ order - TODO confirm against FullHeadDataset.\n",
+    "max_slices=pred.shape[1]   # padded slices per series (previously hard-coded as 60)\n",
+    "n_classes=pred.shape[-1]   # previously hard-coded as 6\n",
+    "images_id_list=[]\n",
+    "dummeys=[]\n",
+    "image_arr=test_df.PatientID.values\n",
+    "ref_arr=test_df.SeriesI.values\n",
+    "order_arr=test_df.ImagePositionZ.values\n",
+    "for s in tqdm_notebook(test_df.SeriesI.unique()):\n",
+    "    head_idx = np.where(ref_arr==s)[0]\n",
+    "    # A series longer than the padding would silently misalign ids and predictions.\n",
+    "    assert head_idx.shape[0]<=max_slices, 'series {} has {} slices but only {} were predicted'.format(s,head_idx.shape[0],max_slices)\n",
+    "    sorted_head_idx=head_idx[np.argsort(order_arr[head_idx])]\n",
+    "    images_id_list.append(image_arr[sorted_head_idx])\n",
+    "    # 1 marks a real slice, 0 marks padding.\n",
+    "    dumm=np.zeros(max_slices)\n",
+    "    dumm[0:head_idx.shape[0]]=1\n",
+    "    dummeys.append(dumm)\n",
+    "image_ids=np.concatenate(images_id_list)\n",
+    "preds=pred.reshape(pred.shape[0]*max_slices,n_classes).numpy()[np.concatenate(dummeys)==1]\n",
+    "assert image_ids.shape[0]==preds.shape[0], 'image ids and predictions are misaligned'\n",
+    "\n",
+    "image_ids.shape\n",
+    "\n",
+    "preds.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "len(pred_list)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NOTE(review): this cell repeats the earlier id/prediction alignment cell; consider deleting one copy.\n",
+    "# Flatten the padded per-series predictions in pred back to one row per test image.\n",
+    "# Assumes each series fills the first len(series) slots of its row in ascending\n",
+    "# ImagePositionZ order - TODO confirm against FullHeadDataset.\n",
+    "max_slices=pred.shape[1]   # padded slices per series (previously hard-coded as 60)\n",
+    "n_classes=pred.shape[-1]   # previously hard-coded as 6\n",
+    "images_id_list=[]\n",
+    "dummeys=[]\n",
+    "image_arr=test_df.PatientID.values\n",
+    "ref_arr=test_df.SeriesI.values\n",
+    "order_arr=test_df.ImagePositionZ.values\n",
+    "for s in tqdm_notebook(test_df.SeriesI.unique()):\n",
+    "    head_idx = np.where(ref_arr==s)[0]\n",
+    "    # A series longer than the padding would silently misalign ids and predictions.\n",
+    "    assert head_idx.shape[0]<=max_slices, 'series {} has {} slices but only {} were predicted'.format(s,head_idx.shape[0],max_slices)\n",
+    "    sorted_head_idx=head_idx[np.argsort(order_arr[head_idx])]\n",
+    "    images_id_list.append(image_arr[sorted_head_idx])\n",
+    "    # 1 marks a real slice, 0 marks padding.\n",
+    "    dumm=np.zeros(max_slices)\n",
+    "    dumm[0:head_idx.shape[0]]=1\n",
+    "    dummeys.append(dumm)\n",
+    "image_ids=np.concatenate(images_id_list)\n",
+    "preds=pred.reshape(pred.shape[0]*max_slices,n_classes).numpy()[np.concatenate(dummeys)==1]\n",
+    "assert image_ids.shape[0]==preds.shape[0], 'image ids and predictions are misaligned'\n",
+    "\n",
+    "image_ids.shape\n",
+    "\n",
+    "preds.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>ID</th>\n",
+       "      <th>Label</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>ID_000012eaf_any</td>\n",
+       "      <td>0.001279</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>ID_000012eaf_epidural</td>\n",
+       "      <td>0.000084</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>ID_000012eaf_intraparenchymal</td>\n",
+       "      <td>0.000194</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>ID_000012eaf_intraventricular</td>\n",
+       "      <td>0.000031</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>ID_000012eaf_subarachnoid</td>\n",
+       "      <td>0.000170</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5</th>\n",
+       "      <td>ID_000012eaf_subdural</td>\n",
+       "      <td>0.000943</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>6</th>\n",
+       "      <td>ID_0000ca2f6_any</td>\n",
+       "      <td>0.000943</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>7</th>\n",
+       "      <td>ID_0000ca2f6_epidural</td>\n",
+       "      <td>0.000053</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>8</th>\n",
+       "      <td>ID_0000ca2f6_intraparenchymal</td>\n",
+       "      <td>0.000198</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9</th>\n",
+       "      <td>ID_0000ca2f6_intraventricular</td>\n",
+       "      <td>0.000020</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>10</th>\n",
+       "      <td>ID_0000ca2f6_subarachnoid</td>\n",
+       "      <td>0.000135</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>11</th>\n",
+       "      <td>ID_0000ca2f6_subdural</td>\n",
+       "      <td>0.000603</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "                               ID     Label\n",
+       "0                ID_000012eaf_any  0.001279\n",
+       "1           ID_000012eaf_epidural  0.000084\n",
+       "2   ID_000012eaf_intraparenchymal  0.000194\n",
+       "3   ID_000012eaf_intraventricular  0.000031\n",
+       "4       ID_000012eaf_subarachnoid  0.000170\n",
+       "5           ID_000012eaf_subdural  0.000943\n",
+       "6                ID_0000ca2f6_any  0.000943\n",
+       "7           ID_0000ca2f6_epidural  0.000053\n",
+       "8   ID_0000ca2f6_intraparenchymal  0.000198\n",
+       "9   ID_0000ca2f6_intraventricular  0.000020\n",
+       "10      ID_0000ca2f6_subarachnoid  0.000135\n",
+       "11          ID_0000ca2f6_subdural  0.000603"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>ID</th>\n",
+       "      <th>Label</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>ID_000012eaf_any</td>\n",
+       "      <td>0.001279</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>ID_000012eaf_epidural</td>\n",
+       "      <td>0.000084</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>ID_000012eaf_intraparenchymal</td>\n",
+       "      <td>0.000194</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>ID_000012eaf_intraventricular</td>\n",
+       "      <td>0.000031</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>ID_000012eaf_subarachnoid</td>\n",
+       "      <td>0.000170</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5</th>\n",
+       "      <td>ID_000012eaf_subdural</td>\n",
+       "      <td>0.000943</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>6</th>\n",
+       "      <td>ID_0000ca2f6_any</td>\n",
+       "      <td>0.000943</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>7</th>\n",
+       "      <td>ID_0000ca2f6_epidural</td>\n",
+       "      <td>0.000053</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>8</th>\n",
+       "      <td>ID_0000ca2f6_intraparenchymal</td>\n",
+       "      <td>0.000198</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9</th>\n",
+       "      <td>ID_0000ca2f6_intraventricular</td>\n",
+       "      <td>0.000020</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>10</th>\n",
+       "      <td>ID_0000ca2f6_subarachnoid</td>\n",
+       "      <td>0.000135</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>11</th>\n",
+       "      <td>ID_0000ca2f6_subdural</td>\n",
+       "      <td>0.000603</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "                               ID     Label\n",
+       "0                ID_000012eaf_any  0.001279\n",
+       "1           ID_000012eaf_epidural  0.000084\n",
+       "2   ID_000012eaf_intraparenchymal  0.000194\n",
+       "3   ID_000012eaf_intraventricular  0.000031\n",
+       "4       ID_000012eaf_subarachnoid  0.000170\n",
+       "5           ID_000012eaf_subdural  0.000943\n",
+       "6                ID_0000ca2f6_any  0.000943\n",
+       "7           ID_0000ca2f6_epidural  0.000053\n",
+       "8   ID_0000ca2f6_intraparenchymal  0.000198\n",
+       "9   ID_0000ca2f6_intraventricular  0.000020\n",
+       "10      ID_0000ca2f6_subarachnoid  0.000135\n",
+       "11          ID_0000ca2f6_subdural  0.000603"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(471270, 2)"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(471270, 2)"
+      ]
+     },
+     "execution_count": 41,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Assemble the submission table from the image ids and the predictions;\n",
+    "# do_sigmoid=False is handed to get_submission_ids as-is.\n",
+    "submission_df = get_submission_ids(\n",
+    "    image_ids,\n",
+    "    torch.tensor(preds),\n",
+    "    do_sigmoid=False\n",
+    ")\n",
+    "\n",
+    "# Bare expressions: rendered as cell outputs because the notebook sets\n",
+    "# ast_node_interactivity to 'all' (first 12 rows, then the frame shape).\n",
+    "submission_df.head(12)\n",
+    "submission_df.shape\n",
+    "\n",
+    "# Write only the ID and Label columns, without the index, to the CSV\n",
+    "# whose file name carries the submission number.\n",
+    "sub_num = 51\n",
+    "submission_df.to_csv(\n",
+    "    '/media/hd/notebooks/data/RSNA/submissions/submission{}.csv'.format(sub_num),\n",
+    "    index=False,\n",
+    "    columns=['ID', 'Label']\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "100%|███████████████████████████████████████| 16.9M/16.9M [01:52<00:00, 157kB/s]\n",
+      "100%|███████████████████████████████████████| 16.9M/16.9M [01:52<00:00, 157kB/s]\n",
+      "Successfully submitted to RSNA Intracranial Hemorrhage DetectionSuccessfully submitted to RSNA Intracranial Hemorrhage Detection"
+     ]
+    }
+   ],
+   "source": [
+    "#!/home/reina/anaconda3/bin/kaggle competitions submit rsna-intracranial-hemorrhage-detection -f /media/hd/notebooks/data/RSNA/submissions/submission51.csv -m \"all models, se weight*2, with post pool_3 tta 64, mean after sigmoid\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}