{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/reina/anaconda3/envs/RSNA/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n", " return f(*args, **kwds)\n", "/home/reina/anaconda3/envs/RSNA/lib/python3.6/importlib/_bootstrap.py:219: ImportWarning: can't resolve package from __spec__ or __package__, falling back on __name__ and __path__\n", " return f(*args, **kwds)\n" ] } ], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "\n", "import numpy as np # linear algebra\n", "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n", "import os\n", "import datetime\n", "import seaborn as sns\n", "import pydicom\n", "import time\n", "import gc\n", "import operator \n", "from apex import amp \n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "import torch.utils.data as D\n", "import torch.nn.functional as F\n", "from sklearn.model_selection import KFold\n", "from tqdm import tqdm, tqdm_notebook\n", "from IPython.core.interactiveshell import InteractiveShell\n", "InteractiveShell.ast_node_interactivity = \"all\"\n", "import warnings\n", "warnings.filterwarnings(action='once')\n", "import pickle\n", "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "from skimage.io import imread,imshow\n", "from helper import *\n", "from apex import amp\n", "import helper\n", "import torchvision.models as models\n", "from torch.optim import Adam\n", "from defenitions import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "SEED = 8153\n", "device=device_by_name(\"Tesla\")\n", "#device=device_by_name(\"RTX\")\n", "#device = \"cpu\"\n", 
"sendmeemail=Email_Progress(my_gmail,my_pass,to_email,'Densenet161-folds0 results')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
 "def get_submission(test_df, pred,\n",
 "                   types=('epidural', 'intraparenchymal', 'intraventricular',\n",
 "                          'subarachnoid', 'subdural', 'any')):\n",
 "    \"\"\"Turn raw model logits into the long-format submission table.\n",
 "\n",
 "    test_df: frame with a PatientID column; presumably row-aligned with pred.\n",
 "    pred:    torch tensor of logits; column i holds the logit for types[i].\n",
 "    types:   label suffixes in pred's column order (the default is the order\n",
 "             previously hard-coded here).\n",
 "\n",
 "    Returns a DataFrame with columns ID ('ID_' + PatientID + '_' + type) and\n",
 "    Label (sigmoid of the logit), sorted by ID and re-indexed from 0.\n",
 "    \"\"\"\n",
 "    # Convert to a CPU numpy array once: pandas cannot convert CUDA or\n",
 "    # grad-tracking tensors by itself, and one sigmoid call covers all columns.\n",
 "    probs = torch.sigmoid(pred).detach().cpu().numpy()\n",
 "    id_prefix = 'ID_' + test_df.PatientID.values + '_'\n",
 "    frames = [pd.DataFrame(data={'ID': id_prefix + name, 'Label': probs[:, col]})\n",
 "              for col, name in enumerate(types)]\n",
 "    return pd.concat(frames).sort_values('ID').reset_index(drop=True)"
 ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(674510, 15)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "(674252, 15)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PatientIDepiduralintraparenchymalintraventricularsubarachnoidsubduralanyPIDStudyISeriesIWindowCenterWindowWidthImagePositionZImagePositionXImagePositionY
063eb1e259000000a449357f62d125e5b20be5c0d1b3['00036', '00036']['00080', '00080']180.199951-125.0-8.000000
12669954a7000000363d5865a20b80c7bf3564d584db['00047', '00047']['00080', '00080']922.530821-156.045.572849
252c9913b10000009c2b4bd73e3634f8cf973274ffc9401504.455000-125.0-115.063000
34e6ff61260000003ae81c2da1390c15c2e5ccad8244['00036', '00036']['00080', '00080']100.000000-99.528.500000
47858edd88000000c1867febc73e81ed3a28e0531b3a40100145.793000-125.0-132.190000
\n", "
" ], "text/plain": [ " PatientID epidural intraparenchymal intraventricular subarachnoid \\\n", "0 63eb1e259 0 0 0 0 \n", "1 2669954a7 0 0 0 0 \n", "2 52c9913b1 0 0 0 0 \n", "3 4e6ff6126 0 0 0 0 \n", "4 7858edd88 0 0 0 0 \n", "\n", " subdural any PID StudyI SeriesI WindowCenter \\\n", "0 0 0 a449357f 62d125e5b2 0be5c0d1b3 ['00036', '00036'] \n", "1 0 0 363d5865 a20b80c7bf 3564d584db ['00047', '00047'] \n", "2 0 0 9c2b4bd7 3e3634f8cf 973274ffc9 40 \n", "3 0 0 3ae81c2d a1390c15c2 e5ccad8244 ['00036', '00036'] \n", "4 0 0 c1867feb c73e81ed3a 28e0531b3a 40 \n", "\n", " WindowWidth ImagePositionZ ImagePositionX ImagePositionY \n", "0 ['00080', '00080'] 180.199951 -125.0 -8.000000 \n", "1 ['00080', '00080'] 922.530821 -156.0 45.572849 \n", "2 150 4.455000 -125.0 -115.063000 \n", "3 ['00080', '00080'] 100.000000 -99.5 28.500000 \n", "4 100 145.793000 -125.0 -132.190000 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "train_df = pd.read_csv(data_dir+'train.csv')\n",
 "train_df.shape\n",
 "# Drop the known-bad images and exact duplicate rows in one chain. No index\n",
 "# reset is needed between the two steps: drop_duplicates compares column\n",
 "# values only and keeps the first occurrence in row order.\n",
 "train_df = (train_df[~train_df.PatientID.isin(bad_images)]\n",
 "            .drop_duplicates()\n",
 "            .reset_index(drop=True))\n",
 "train_df.shape\n",
 "train_df.head()"
 ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PatientIDepiduralintraparenchymalintraventricularsubarachnoidsubduralanySeriesIPIDStudyIWindowCenterWindowWidthImagePositionZImagePositionXImagePositionY
028fbab7eb0.50.50.50.50.50.5ebfd7e4506cf1b6b1193407cadbb3080158.458000-125.0-135.598000
1877923b8b0.50.50.50.50.50.56d95084e15ad8ea58fa337baa0673080138.729050-125.0-101.797981
2a591477cb0.50.50.50.50.50.58e06b2c9e0ecfb278b0cfe838d54308060.830002-125.0-133.300003
342217c8980.50.50.50.50.50.5e800f419cfe96e31f4c497ac5bad308055.388000-125.0-146.081000
4a130c4d2f0.50.50.50.50.50.5faeb7454f369affa42854e4fbc01308033.516888-125.0-118.689819
\n", "
" ], "text/plain": [ " PatientID epidural intraparenchymal intraventricular subarachnoid \\\n", "0 28fbab7eb 0.5 0.5 0.5 0.5 \n", "1 877923b8b 0.5 0.5 0.5 0.5 \n", "2 a591477cb 0.5 0.5 0.5 0.5 \n", "3 42217c898 0.5 0.5 0.5 0.5 \n", "4 a130c4d2f 0.5 0.5 0.5 0.5 \n", "\n", " subdural any SeriesI PID StudyI WindowCenter WindowWidth \\\n", "0 0.5 0.5 ebfd7e4506 cf1b6b11 93407cadbb 30 80 \n", "1 0.5 0.5 6d95084e15 ad8ea58f a337baa067 30 80 \n", "2 0.5 0.5 8e06b2c9e0 ecfb278b 0cfe838d54 30 80 \n", "3 0.5 0.5 e800f419cf e96e31f4 c497ac5bad 30 80 \n", "4 0.5 0.5 faeb7454f3 69affa42 854e4fbc01 30 80 \n", "\n", " ImagePositionZ ImagePositionX ImagePositionY \n", "0 158.458000 -125.0 -135.598000 \n", "1 138.729050 -125.0 -101.797981 \n", "2 60.830002 -125.0 -133.300003 \n", "3 55.388000 -125.0 -146.081000 \n", "4 33.516888 -125.0 -118.689819 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_df = pd.read_csv(data_dir+'test.csv')\n", "test_df.head()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
 "# Folds are drawn over unique PID values, so all rows sharing a PID land in one fold.\n",
 "split_sid = train_df.PID.unique()\n",
 "splits=list(KFold(n_splits=3,shuffle=True, random_state=SEED).split(split_sid))\n"
 ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
 "def my_loss(y_pred, y_true, weights):\n",
 "    \"\"\"Class-weighted BCE-with-logits loss, optionally on blended targets.\n",
 "\n",
 "    y_pred:  logits, presumably (batch, n_classes).\n",
 "    y_true:  either the same rank as y_pred (plain targets), or one extra\n",
 "             trailing axis holding [target_a, target_b, lam]; the loss is then\n",
 "             lam*BCE(target_a) + (1-lam)*BCE(target_b) per element, as a\n",
 "             mixup-style pre-processor would need (TODO confirm vs Mixup).\n",
 "    weights: per-class weight vector, tiled once per batch row.\n",
 "    Returns the mean loss as a scalar tensor.\n",
 "    \"\"\"\n",
 "    # Tile the class weights to the batch once instead of once per BCE call.\n",
 "    batch_weights = weights.repeat(y_pred.shape[0], 1)\n",
 "    if len(y_pred.shape) == len(y_true.shape):\n",
 "        return F.binary_cross_entropy_with_logits(y_pred, y_true, batch_weights)\n",
 "    loss_a = F.binary_cross_entropy_with_logits(y_pred, y_true[..., 0], batch_weights, reduction='none')\n",
 "    loss_b = F.binary_cross_entropy_with_logits(y_pred, y_true[..., 1], batch_weights, reduction='none')\n",
 "    lam = y_true[..., 2]\n",
 "    return (lam * loss_a + (1.0 - lam) * loss_b).mean()"
 ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
 "class parameter_scheduler():\n",
 "    \"\"\"Callable that freezes all but selected layers during the first epochs.\n",
 "\n",
 "    scheduler(epoch) sets requires_grad on every parameter of `model`: while\n",
 "    epoch < num_epoch only parameters whose name contains one of the\n",
 "    substrings in do_first are trainable; from num_epoch on, all of them are.\n",
 "    \"\"\"\n",
 "    def __init__(self, model, do_first=('classifier',), num_epoch=1):\n",
 "        # Tuple default: a list default would be one object shared by all instances.\n",
 "        self.model = model\n",
 "        self.do_first = do_first\n",
 "        self.num_epoch = num_epoch\n",
 "\n",
 "    def __call__(self, epoch):\n",
 "        unfreeze_all = epoch >= self.num_epoch\n",
 "        for name, param in self.model.named_parameters():\n",
 "            param.requires_grad = unfreeze_all or any(key in name for key in self.do_first)\n"
 ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
 "def get_optimizer_parameters(model, klr, num_blocks=4):\n",
 "    \"\"\"Build layer-wise optimizer parameter groups for a DenseNet-style model.\n",
 "\n",
 "    Parameters are selected by substring match on their names. Each layer set\n",
 "    yields two groups with the same learning rate: non-bias parameters with\n",
 "    weight_decay=0.01, then bias parameters with weight_decay=0.0. Order:\n",
 "    stem (conv0/norm0/ws_norm, 2e-5), wso (1e-5), classifier (1e-3),\n",
 "    denseblock1..num_blocks (2e-5, doubling per block), norm5 (1e-4); every\n",
 "    learning rate is multiplied by klr.\n",
 "\n",
 "    NOTE(review): parameters matching none of these substrings get no group\n",
 "    and are never updated (with torchvision DenseNet naming that would include\n",
 "    the transitionN layers) - confirm this is intended. A name matching two\n",
 "    layer sets would land in two groups, which torch.optim.Optimizer rejects.\n",
 "    \"\"\"\n",
 "    named_params = list(model.named_parameters())\n",
 "    no_decay = ('bias',)\n",
 "\n",
 "    def matches(name, keys):\n",
 "        return any(key in name for key in keys)\n",
 "\n",
 "    def decay_split(keys, lr):\n",
 "        # One group with weight decay for weights, one without for biases.\n",
 "        decay = [p for n, p in named_params if matches(n, keys) and not matches(n, no_decay)]\n",
 "        plain = [p for n, p in named_params if matches(n, keys) and matches(n, no_decay)]\n",
 "        return [{'params': decay, 'lr': lr, 'weight_decay': 0.01},\n",
 "                {'params': plain, 'lr': lr, 'weight_decay': 0.0}]\n",
 "\n",
 "    groups = decay_split(('conv0', 'norm0', 'ws_norm'), klr * 2e-5)\n",
 "    groups += decay_split(('wso',), klr * 1e-5)\n",
 "    groups += decay_split(('classifier',), klr * 1e-3)\n",
 "    for i in range(num_blocks):\n",
 "        groups += decay_split(('denseblock{}'.format(i + 1),), klr * (2.0 ** i) * 2e-5)\n",
 "    groups += decay_split(('norm5',), klr * 1e-4)\n",
 "    return groups"
 ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "(449982,)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "(224270,)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an 
active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=14062), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=7009), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.07458001077127371, 'val_loss': 0.08623016369950948}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=14062), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=7009), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.07500519356493322, 'val_loss': 0.08447535222967352}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=14062), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=7009), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.06783495590024316, 'val_loss': 0.08535336141405052}\n", "\n", "0.08447535222967352\n" ] } ], "source": [
 "%matplotlib nbagg\n",
 "\n",
 "# ---- fold selection; every RNG is seeded with SEED offset by the fold index ----\n",
 "num_split = 0\n",
 "np.random.seed(SEED + num_split)\n",
 "torch.manual_seed(SEED + num_split)\n",
 "torch.cuda.manual_seed(SEED + num_split)\n",
 "# torch.backends.cudnn.deterministic = True  # optional: slower, more repeatable runs\n",
 "idx_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].index.values\n",
 "idx_validate = train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].index.values\n",
 "idx_train.shape\n",
 "idx_validate.shape\n",
 "\n",
 "# ---- run configuration ----\n",
 "klr = 1  # global multiplier applied to every learning rate\n",
 "batch_size = 32\n",
 "num_workers = 12\n",
 "num_epochs = 3\n",
 "model_name, version = 'Densenet161_3', 'classifier__splits'\n",
 "\n",
 "# ---- model: DenseNet-161 backbone, warm-started from this fold's 'basic_splits' checkpoint ----\n",
 "model = MyDenseNet(models.densenet161(pretrained=True),\n",
 "                   len(hemorrhage_types),\n",
 "                   num_channels=3,\n",
 "                   drop_out=0.2,\n",
 "                   wso=((40, 80), (80, 200), (600, 2800)),\n",
 "                   strategy='none',\n",
 "                   dont_do_grad=[],\n",
 "                   extra_pool=4,\n",
 "                   pool_type='max')\n",
 "model.load_state_dict(torch.load(models_dir + models_format.format(model_name, 'basic_splits', num_split),\n",
 "                                 map_location=torch.device(device)))\n",
 "_ = model.to(device)\n",
 "\n",
 "# ---- loss: per-class weights, the last class weighted 2x ----\n",
 "weights = torch.tensor([1., 1., 1., 1., 1., 2.], device=device)\n",
 "loss_func = my_loss\n",
 "\n",
 "# ---- data: training view with augmentation arguments, validation view with out_size only ----\n",
 "targets_dataset = D.TensorDataset(torch.tensor(train_df[hemorrhage_types].values, dtype=torch.float))\n",
 "transform = MyTransform(flip=True, zoom=0.05, rotate=15, out_size=512, shift=40)\n",
 "imagedataset = ImageDataset(train_df, transform=transform.random, base_path=train_images_dir,\n",
 "                            window_eq=False, equalize=False, rescale=True)\n",
 "transform_val = MyTransform(out_size=512)\n",
 "imagedataset_val = ImageDataset(train_df, transform=transform_val.random, base_path=train_images_dir,\n",
 "                                window_eq=False, equalize=False, rescale=True)\n",
 "combined_dataset = DatasetCat([imagedataset, targets_dataset])\n",
 "combined_dataset_val = DatasetCat([imagedataset_val, targets_dataset])\n",
 "# Built but not applied: model_train below receives pre_process=None.\n",
 "pre_proc = Mixup(0.4, device=device)\n",
 "optimizer_grouped_parameters = get_optimizer_parameters(model, klr)\n",
 "sample_ratio = 1.0  # no resampling; model_train below receives sampler=None\n",
 "train_dataset = D.Subset(combined_dataset, idx_train)\n",
 "validate_dataset = D.Subset(combined_dataset_val, idx_validate)\n",
 "\n",
 "# ---- schedule, optimizer, mixed precision ----\n",
 "num_train_optimization_steps = num_epochs * (sample_ratio * len(train_dataset) // batch_size\n",
 "                                             + int(len(train_dataset) % batch_size > 0))\n",
 "fig, ax = plt.subplots(figsize=(10, 7))\n",
 "gr = loss_graph(fig, ax, num_epochs, int(num_train_optimization_steps / num_epochs) + 1, limits=(0.05, 0.2))\n",
 "sched = WarmupExpCosineWithWarmupRestartsSchedule(t_total=num_train_optimization_steps, cycles=num_epochs, tau=1.5)\n",
 "optimizer = BertAdam(optimizer_grouped_parameters, lr=klr * 1e-3, schedule=sched)\n",
 "model, optimizer = amp.initialize(model, optimizer, opt_level=\"O1\", verbosity=0)\n",
 "\n",
 "# ---- train, then save the final and the best model of this fold ----\n",
 "history, best_model = model_train(model,\n",
 "                                  optimizer,\n",
 "                                  train_dataset,\n",
 "                                  batch_size,\n",
 "                                  num_epochs,\n",
 "                                  loss_func,\n",
 "                                  weights=weights,\n",
 "                                  do_apex=False,\n",
 "                                  model_apexed=True,\n",
 "                                  validate_dataset=validate_dataset,\n",
 "                                  param_schedualer=None,\n",
 "                                  weights_data=None,\n",
 "                                  metric=None,\n",
 "                                  return_model=True,\n",
 "                                  num_workers=num_workers,\n",
 "                                  sampler=None,\n",
 "                                  pre_process=None,\n",
 "                                  graph=gr,\n",
 "                                  call_progress=sendmeemail)\n",
 "\n",
 "torch.save(model.state_dict(), models_dir + models_format.format(model_name, version, num_split))\n",
 "torch.save(best_model.state_dict(), models_dir + models_format.format(model_name, version + '_best', num_split))"
 ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "(449019,)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "(225233,)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an 
active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=14032), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=7039), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.08000715392810209, 'val_loss': 0.0840423224571725}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=14032), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=7039), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.07516958098892379, 'val_loss': 0.08315527216200035}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=14032), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib nbagg\n", "\n", "num_split=1\n", "np.random.seed(SEED+num_split)\n", "torch.manual_seed(SEED+num_split)\n", "torch.cuda.manual_seed(SEED+num_split)\n", "#torch.backends.cudnn.deterministic = True\n", "idx_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].index.values\n", "idx_validate = train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].index.values\n", "idx_train.shape\n", "idx_validate.shape\n", "\n", "klr=1\n", "batch_size=32\n", "num_workers=12\n", "num_epochs=3\n", "model_name,version = 'Densenet161_3' , 'classifier__splits'\n", "#model1 = MyDenseNet(models.densenet161(pretrained=True),len(hemorrhage_types),num_channels=3)\n", "model = MyDenseNet(models.densenet161(pretrained=True),\n", " len(hemorrhage_types),\n", " num_channels=3,\n", " drop_out=0.2,\n", " wso=((40,80),(80,200),(600,2800)),\n", " strategy='none',\n", " dont_do_grad=[],\n", " extra_pool=4,\n", " pool_type='max'\n", " )\n", "model.load_state_dict(torch.load(models_dir+models_format.format(model_name,'basic_splits',num_split),map_location=torch.device(device)))\n", "_=model.to(device)\n", "weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n", "loss_func=my_loss\n", "targets_dataset=D.TensorDataset(torch.tensor(train_df[hemorrhage_types].values,dtype=torch.float))\n", 
"transform=MyTransform(flip=True,zoom=0.05,rotate=15,out_size=512,shift=40)\n", "imagedataset = ImageDataset(train_df,transform=transform.random,base_path=train_images_dir,\n", " window_eq=False,equalize=False,rescale=True)\n", "transform_val=MyTransform(out_size=512)\n", "imagedataset_val = ImageDataset(train_df,transform=transform_val.random,base_path=train_images_dir,\n", " window_eq=False,equalize=False,rescale=True)\n", "combined_dataset=DatasetCat([imagedataset,targets_dataset])\n", "combined_dataset_val=DatasetCat([imagedataset_val,targets_dataset])\n", "#param_s=parameter_scheduler(model,num_epoch=0)\n", "pre_proc = Mixup(0.4,device=device)\n", "optimizer_grouped_parameters=get_optimizer_parameters(model,klr)\n", "#sampling=simple_sampler(train_df[hemorrhage_types].values[idx_train],0.25)\n", "#sampling=sampler(train_df[hemorrhage_types].values[idx_train],5,[0,0,0,0,0,2],train_df.SeriesI.values[idx_train])\n", "sample_ratio= 1.0 #1.003*float(sampling().shape[0])/idx_train.shape[0]\n", "train_dataset=D.Subset(combined_dataset,idx_train)\n", "validate_dataset=D.Subset(combined_dataset_val,idx_validate)\n", "num_train_optimization_steps = num_epochs*(sample_ratio*len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n", "fig,ax = plt.subplots(figsize=(10,7))\n", "gr=loss_graph(fig,ax,num_epochs,int(num_train_optimization_steps/num_epochs)+1,limits=(0.05,0.2))\n", "sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=num_epochs,tau=1.5)\n", "optimizer = BertAdam(optimizer_grouped_parameters,lr=klr*1e-3,schedule=sched)\n", "model, optimizer = amp.initialize(model, optimizer, opt_level=\"O1\",verbosity=0)\n", "history,best_model= model_train(model,\n", " optimizer,\n", " train_dataset,\n", " batch_size,\n", " num_epochs,\n", " loss_func,\n", " weights=weights,\n", " do_apex=False,\n", " model_apexed=True,\n", " validate_dataset=validate_dataset,\n", " param_schedualer=None,\n", " weights_data=None,\n", " 
metric=None,\n", " return_model=True,\n", " num_workers=num_workers,\n", " sampler=None,\n", " pre_process = None,\n", " graph=gr,\n", " call_progress=sendmeemail)\n", "\n", "torch.save(model.state_dict(), models_dir+models_format.format(model_name,version,num_split))\n", "torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version+'_best',num_split))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "(449503,)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "(224749,)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. 
\" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an 
active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ee1c929344d84520a1f9739532f6f12f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1fd034336d064c10840656fc46f077a9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=14047), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "99cc6bb36f974341abe28b2d480aabf3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=7024), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.07909837087169497, 'val_loss': 0.08144222121488978}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ba18a7b165984084b58047ca4cfc6c82", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"HBox(children=(IntProgress(value=0, max=14047), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f204e8ad53ba45438547fdc7d8de013e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=7024), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.07260369182006554, 'val_loss': 0.08076523827359182}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6d81f0f8ab034a92aff069495e9fcd7b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=14047), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ea2203b9578c4023b7f76df27dde92b9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=7024), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.0691204885223411, 'val_loss': 0.080848135731886}\n", "0.08076523827359182\n" ] } ], "source": [ "%matplotlib nbagg\n", "\n", "num_split=2\n", "np.random.seed(SEED+num_split)\n", "torch.manual_seed(SEED+num_split)\n", "torch.cuda.manual_seed(SEED+num_split)\n", "#torch.backends.cudnn.deterministic = True\n", "idx_train = train_df[train_df.PID.isin(set(split_sid[splits[num_split][0]]))].index.values\n", "idx_validate = train_df[train_df.PID.isin(set(split_sid[splits[num_split][1]]))].index.values\n", "idx_train.shape\n", "idx_validate.shape\n", "\n", "klr=1\n", "batch_size=32\n", "num_workers=12\n", "num_epochs=3\n", "model_name,version = 'Densenet161_3' , 'classifier_splits'\n", "#model1 = MyDenseNet(models.densenet161(pretrained=True),len(hemorrhage_types),num_channels=3)\n", "model = 
MyDenseNet(models.densenet161(pretrained=True),\n", " len(hemorrhage_types),\n", " num_channels=3,\n", " drop_out=0.2,\n", " wso=((40,80),(80,200),(600,2800)),\n", " strategy='none',\n", " dont_do_grad=[],\n", " extra_pool=4,\n", " pool_type='max'\n", " )\n", "model.load_state_dict(torch.load(models_dir+models_format.format(model_name,'basic_splits',num_split),map_location=torch.device(device)))\n", "_=model.to(device)\n", "weights = torch.tensor([1.,1.,1.,1.,1.,2.],device=device)\n", "loss_func=my_loss\n", "targets_dataset=D.TensorDataset(torch.tensor(train_df[hemorrhage_types].values,dtype=torch.float))\n", "transform=MyTransform(flip=True,zoom=0.05,rotate=15,out_size=512,shift=40)\n", "imagedataset = ImageDataset(train_df,transform=transform.random,base_path=train_images_dir,\n", " window_eq=False,equalize=False,rescale=True)\n", "transform_val=MyTransform(out_size=512)\n", "imagedataset_val = ImageDataset(train_df,transform=transform_val.random,base_path=train_images_dir,\n", " window_eq=False,equalize=False,rescale=True)\n", "combined_dataset=DatasetCat([imagedataset,targets_dataset])\n", "combined_dataset_val=DatasetCat([imagedataset_val,targets_dataset])\n", "#param_s=parameter_scheduler(model,num_epoch=0)\n", "pre_proc = Mixup(0.4,device=device)\n", "optimizer_grouped_parameters=get_optimizer_parameters(model,klr)\n", "#sampling=simple_sampler(train_df[hemorrhage_types].values[idx_train],0.25)\n", "#sampling=sampler(train_df[hemorrhage_types].values[idx_train],5,[0,0,0,0,0,2],train_df.SeriesI.values[idx_train])\n", "sample_ratio= 1.0 #1.003*float(sampling().shape[0])/idx_train.shape[0]\n", "train_dataset=D.Subset(combined_dataset,idx_train)\n", "validate_dataset=D.Subset(combined_dataset_val,idx_validate)\n", "num_train_optimization_steps = num_epochs*(sample_ratio*len(train_dataset)//batch_size+int(len(train_dataset)%batch_size>0))\n", "fig,ax = plt.subplots(figsize=(10,7))\n", 
"gr=loss_graph(fig,ax,num_epochs,int(num_train_optimization_steps/num_epochs)+1,limits=(0.05,0.2))\n", "sched=WarmupExpCosineWithWarmupRestartsSchedule( t_total=num_train_optimization_steps, cycles=num_epochs,tau=1.5)\n", "optimizer = BertAdam(optimizer_grouped_parameters,lr=klr*1e-3,schedule=sched)\n", "model, optimizer = amp.initialize(model, optimizer, opt_level=\"O1\",verbosity=0)\n", "history,best_model= model_train(model,\n", " optimizer,\n", " train_dataset,\n", " batch_size,\n", " num_epochs,\n", " loss_func,\n", " weights=weights,\n", " do_apex=False,\n", " model_apexed=True,\n", " validate_dataset=validate_dataset,\n", " param_schedualer=None,\n", " weights_data=None,\n", " metric=None,\n", " return_model=True,\n", " num_workers=num_workers,\n", " sampler=None,\n", " pre_process = None,\n", " graph=gr,\n", " call_progress=sendmeemail)\n", "\n", "torch.save(model.state_dict(), models_dir+models_format.format(model_name,version,num_split))\n", "torch.save(best_model.state_dict(), models_dir+models_format.format(model_name,version+'_best',num_split))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b8cf249e3252435a9b9d69dfd788d336", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=819), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "model_name,version, num_split = 'Densenet161_3' , 'classifier_splits',0\n", "model = MyDenseNet(models.densenet161(pretrained=True),\n", " len(hemorrhage_types),\n", " num_channels=3,\n", " drop_out=0.2,\n", " wso=((40,80),(80,200),(600,2800)),\n", " strategy='none',\n", " dont_do_grad=[],\n", " extra_pool=4,\n", " pool_type='max',\n", " return_features=True\n", " 
)\n", "model.load_state_dict(torch.load(models_dir+models_format.format(model_name,version,num_split),map_location=torch.device(device)))\n", "_=model.to(device)\n", "transform=MyTransform(flip=True,zoom=0.05,rotate=15,out_size=512,shift=40)\n", "transform_val=MyTransform(out_size=512)\n", "indexes=np.arange(test_df.shape[0]).repeat(8)\n", "imagedataset_test=D.Subset(ImageDataset(test_df,transform=transform.random,base_path=test_images_dir,\n", " window_eq=False,equalize=False,rescale=True),indexes)\n", "pred,features = model_run(model,imagedataset_test,do_apex=True,batch_size=96,num_workers=18)\n", "pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'features_test',num_split),'wb')\n", "pickle.dump(features,pickle_file,protocol=4)\n", "pickle_file.close()\n", "pickle_file=open(outputs_dir+outputs_format.format(model_name,version,'predictions_test',num_split),'wb')\n", "pickle.dump(pred,pickle_file,protocol=4)\n", "pickle_file.close()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IDLabel
0ID_000012eaf_any0.002726
1ID_000012eaf_epidural0.000023
2ID_000012eaf_intraparenchymal0.000494
3ID_000012eaf_intraventricular0.000140
4ID_000012eaf_subarachnoid0.000619
5ID_000012eaf_subdural0.001694
6ID_0000ca2f6_any0.004521
7ID_0000ca2f6_epidural0.000045
8ID_0000ca2f6_intraparenchymal0.000900
9ID_0000ca2f6_intraventricular0.000160
10ID_0000ca2f6_subarachnoid0.001351
11ID_0000ca2f6_subdural0.001957
\n", "
" ], "text/plain": [ " ID Label\n", "0 ID_000012eaf_any 0.002726\n", "1 ID_000012eaf_epidural 0.000023\n", "2 ID_000012eaf_intraparenchymal 0.000494\n", "3 ID_000012eaf_intraventricular 0.000140\n", "4 ID_000012eaf_subarachnoid 0.000619\n", "5 ID_000012eaf_subdural 0.001694\n", "6 ID_0000ca2f6_any 0.004521\n", "7 ID_0000ca2f6_epidural 0.000045\n", "8 ID_0000ca2f6_intraparenchymal 0.000900\n", "9 ID_0000ca2f6_intraventricular 0.000160\n", "10 ID_0000ca2f6_subarachnoid 0.001351\n", "11 ID_0000ca2f6_subdural 0.001957" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "(471270, 2)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "submission_df=get_submission(test_df,pred[(np.arange(pred.shape[0]).reshape(pred.shape[0]//8,8))].mean(1))\n", "#submission_df=get_submission(test_df,pred)\n", "submission_df.head(12)\n", "submission_df.shape\n", "sub_num=25\n", "submission_df.to_csv('/media/hd/notebooks/data/RSNA/submissions/submission{}.csv'.format(sub_num),\n", " index=False, columns=['ID','Label'])\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }