[de45a9]: / DL_Genomics_v8_resnet-fastai.ipynb

Download this file

1470 lines (1469 with data), 119.8 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep learning in genomics - Custom ResNet model with PyTorch and fastai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is based on the [jupyter notebook](https://nbviewer.jupyter.org/github/abidlabs/deep-learning-genomics-primer/blob/master/A_Primer_on_Deep_Learning_in_Genomics_Public.ipynb) from the publication [\"A primer on deep learning in genomics\"](https://www.nature.com/articles/s41588-018-0295-5) but uses the [fastai](https://www.fast.ai) library based on [PyTorch](https://pytorch.org)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder, OneHotEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.0.38.dev0'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# fastai version\n",
    "__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loading data from the web, generating the data frame, and saving it to disk is carried out in the [Basic model with PyTorch jupyter notebook](https://nbviewer.jupyter.org/github/MicPie/genomics/blob/master/DL_Genomics_v8_basic-pytorch.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data frame setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_df = pd.read_csv('seq_df.csv', index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add column NotTarget = 1 where Target is 0, else 0 (vectorized form of int(not bool(Target)))\n",
    "seq_df['NotTarget'] = (seq_df['Target'] == 0).astype('int64')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Sequences</th>\n",
       "      <th>Target</th>\n",
       "      <th>NotTarget</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           Sequences  Target  NotTarget\n",
       "0  CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...       0          1\n",
       "1  GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...       0          1\n",
       "2  GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...       0          1\n",
       "3  GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...       1          0\n",
       "4  GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...       1          0"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fastai data object"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Data encoding test (incorporated into \"open_seq_image\" function in the next section)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# setup class instance to encode the four different bases to integer values (1D)\n",
    "int_enc = LabelEncoder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# setup one hot encoder to encode integer encoded classes (1D) to one hot encoded array (4D)\n",
    "one_hot_enc = OneHotEncoder(categories=[range(4)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "seq_enc = []\n",
    "\n",
    "# fit both encoders once on the full alphabet so every sequence uses the same mapping\n",
    "# (A=0, C=1, G=2, T=3); re-fitting per sequence would shift the indices of any\n",
    "# sequence that lacks one of the bases\n",
    "int_enc.fit(list('ACGT'))\n",
    "one_hot_enc.fit(np.arange(4).reshape(-1,1))\n",
    "\n",
    "for s in seq_df['Sequences']: # iterate over the sequences loaded into seq_df\n",
    "    enc = int_enc.transform(list(s)) # bases (ACGT) to int (0,1,2,3)\n",
    "    enc = enc.reshape(-1,1) # reshape to get rank 2 array (from rank 1 array)\n",
    "    enc = one_hot_enc.transform(enc) # one hot encode integer encoded bases (sparse matrix dtype)\n",
    "    seq_enc.append(enc.toarray()) # export sparse matrix to np array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(seq_enc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 462,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0., 0., 0., 1., ..., 0., 1., 0., 0.],\n",
       "        [1., 1., 0., 0., ..., 1., 0., 1., 1.],\n",
       "        [0., 0., 1., 0., ..., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., ..., 0., 0., 0., 0.]]), (4, 50))"
      ]
     },
     "execution_count": 462,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_enc[0].T, seq_enc[0].T.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAMAAAD7wpwzAAADAFBMVEUAAAABAQECAgIDAwMEBAQFBQUGBgYHBwcICAgJCQkKCgoLCwsMDAwNDQ0ODg4PDw8QEBARERESEhITExMUFBQVFRUWFhYXFxcYGBgZGRkaGhobGxscHBwdHR0eHh4fHx8gICAhISEiIiIjIyMkJCQlJSUmJiYnJycoKCgpKSkqKiorKyssLCwtLS0uLi4vLy8wMDAxMTEyMjIzMzM0NDQ1NTU2NjY3Nzc4ODg5OTk6Ojo7Ozs8PDw9PT0+Pj4/Pz9AQEBBQUFCQkJDQ0NERERFRUVGRkZHR0dISEhJSUlKSkpLS0tMTExNTU1OTk5PT09QUFBRUVFSUlJTU1NUVFRVVVVWVlZXV1dYWFhZWVlaWlpbW1tcXFxdXV1eXl5fX19gYGBhYWFiYmJjY2NkZGRlZWVmZmZnZ2doaGhpaWlqampra2tsbGxtbW1ubm5vb29wcHBxcXFycnJzc3N0dHR1dXV2dnZ3d3d4eHh5eXl6enp7e3t8fHx9fX1+fn5/f3+AgICBgYGCgoKDg4OEhISFhYWGhoaHh4eIiIiJiYmKioqLi4uMjIyNjY2Ojo6Pj4+QkJCRkZGSkpKTk5OUlJSVlZWWlpaXl5eYmJiZmZmampqbm5ucnJydnZ2enp6fn5+goKChoaGioqKjo6OkpKSlpaWmpqanp6eoqKipqamqqqqrq6usrKytra2urq6vr6+wsLCxsbGysrKzs7O0tLS1tbW2tra3t7e4uLi5ubm6urq7u7u8vLy9vb2+vr6/v7/AwMDBwcHCwsLDw8PExMTFxcXGxsbHx8fIyMjJycnKysrLy8vMzMzNzc3Ozs7Pz8/Q0NDR0dHS0tLT09PU1NTV1dXW1tbX19fY2NjZ2dna2trb29vc3Nzd3d3e3t7f39/g4ODh4eHi4uLj4+Pk5OTl5eXm5ubn5+fo6Ojp6enq6urr6+vs7Ozt7e3u7u7v7+/w8PDx8fHy8vLz8/P09PT19fX29vb39/f4+Pj5+fn6+vr7+/v8/Pz9/f3+/v7////isF19AAAAO0lEQVR4nGWNSQoAMAgD8/9PT6loIq0HMauShOR9DwphIkO5eALbEI1xp5DUoxYnNtf3t1rci9mGy30A2+0xz9q2+b0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.Image.Image image mode=P size=50x4 at 0x1A20462668>"
      ]
     },
     "execution_count": 311,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "PIL.Image.fromarray(seq_enc[0].T.astype('uint8')*255).convert('P')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup custom fastai data object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# open sequence image function\n",
    "def open_seq_image(seq:str, cls:type=Image, bases:str='ACGT')->Image:\n",
    "    \"Return `Image` of shape (1, len(bases), len(seq)) holding the one-hot encoded sequence string `seq`.\"\n",
    "    \n",
    "    # Use a fixed base -> row mapping. Fitting a LabelEncoder on each sequence only assigns\n",
    "    # indices to the bases present, so e.g. a sequence without 'C' would put G in row 1 and\n",
    "    # T in row 2, silently scrambling the channels between sequences.\n",
    "    base2idx = {b:i for i,b in enumerate(bases)}\n",
    "    seq = seq.upper() # accept lower-case (soft-masked) input; `bases` is expected upper-case\n",
    "    unknown = set(seq) - set(base2idx)\n",
    "    if unknown: raise ValueError(f'Unknown bases {sorted(unknown)} in sequence (expected only {bases}).')\n",
    "    \n",
    "    enc = np.zeros((len(bases), len(seq)), dtype=np.float32) # rows: bases, columns: sequence positions\n",
    "    enc[[base2idx[b] for b in seq], np.arange(len(seq))] = 1.\n",
    "    \n",
    "    # pil2tensor also accepts a 2D numpy array and adds the leading channel axis -> (1, len(bases), len(seq))\n",
    "    x = pil2tensor(enc, np.float32)\n",
    "    \n",
    "    return cls(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD80/8AhrrUr/8A0rxT8B/h1q95B/pGltNotzaWGm6iflfUYdIsrmDSvPkjjsopUeze3nTTrfzYZGadpvuj4Jf8EwP2bf2gf2wL39kD4j6n4uuYLDwj4g1u38e/8JEz68IdG8WXnhax0rMqvZ/YlsbG3faLUTCVcJMkAWBSivm/HbGYzgrIq2IyKrLDTjhcTUThKSanTdBQkuZztbnleKXLK95xqOMXD9sPnT9kmOD9rf8A4Wbb6rcav4Ls/gx8Ita+Inwp0/wV4s1VYfCuqaf5LmKyF/dXRhguriVbqcg+eZreIxTQqHR+R+CXxp8V/tC+Mr34X/FXTdI1fwnpnhHxBrWlaFqWlx3s2nWmjaPearZaJaapd+bqtnpYkso4TbwXkbeTLOFkWSZ5SUV9lXwmGhmfENJRVsJDCOj3pSq0HOpKEvjUpz9+cudylJuTbbk2Hrn7DnwE+DH7cXg3xd4t1X4aaR8PovBvi7wX4fsNJ8DWguIbqPxRrA0fUbm4l1v+0Ll50tdptsTLHazJ58UayvI7/On/AA2R480r/ibfDvwF4R8F+IW+SXxP4N0uXT7mWE/vXiaCOb7H/wAfuL+OUW4mtLmK3+ySW0VpaQwFFHCuFoZnx5n2W4xOrQw06KpQnKUoxU6VKUk+ab51KUm2qjqpt+SSA/4ba+Mn/QmfCL/xH/wf/wDKqiiiv0n/AFT4Y/6A6f3S/wDlgH//2Q==\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAHNJREFUKJFjdGEM+c8ABTufXWBwlzJgQAe4xNHVMDAwoKiD6cMmRwlANhdmJuPf5yr/cVlCjAdIdQAMELIP3bHobkEPJEZYjKBLYLMMV6gji+GykJDHcHkOm9uwqWNETlqEDEF2LD7LiUlK+JIxTC8pKQIA7jhk1J8wofAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "Image (1, 4, 50)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# test open sequence image function\n",
    "open_seq_image('CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGACACC')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SeqItemList(ImageItemList):\n",
    "    \"`ItemList` of sequence strings that are opened with `open_seq_image`.\"\n",
    "    _bunch,_square_show = ImageDataBunch,True\n",
    "    def __post_init__(self):\n",
    "        super().__post_init__()\n",
    "        self.sizes={}\n",
    "    \n",
    "    def open(self, seq): return open_seq_image(seq)\n",
    "    \n",
    "    def get(self, i):\n",
    "        \"Open the sequence in row `i` (each item is a row of the selected df columns).\"\n",
    "        seq = self.items[i][0]\n",
    "        return self.open(seq)\n",
    "    \n",
    "    @classmethod\n",
    "    def import_from_df(cls, df:DataFrame, cols:IntsOrStrs=0, **kwargs)->'ItemList':\n",
    "        \"Get the sequences in `cols` of `df` and return a `cls` with `df` as `xtra`; `kwargs` are passed to `cls`.\"\n",
    "        # a single column name/index would give 1D items, and `get` would then return the first character\n",
    "        if isinstance(cols, (int, np.integer, str)): cols = [cols]\n",
    "        # integer entries are treated as positional column indices (e.g. the default 0)\n",
    "        cols = [df.columns[c] if isinstance(c, (int, np.integer)) else c for c in cols]\n",
    "        return cls(items=df[cols].values, xtra=df, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(42) # random_split_by_pct draws a numpy permutation; seed it so the train/valid split is reproducible\n",
    "data = (SeqItemList.import_from_df(seq_df, ['Sequences'])\n",
    "        .random_split_by_pct(valid_pct=0.25)\n",
    "        #.split_by_idxs(range(1500), range(1500,2000))\n",
    "        .label_from_df(['Target', 'NotTarget'])\n",
    "        .databunch(bs=bs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Verify data object"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Check data object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ImageDataBunch;\n",
       "\n",
       "Train: LabelList\n",
       "y: MultiCategoryList (1500 items)\n",
       "[MultiCategory NotTarget, MultiCategory NotTarget, MultiCategory Target, MultiCategory Target, MultiCategory NotTarget]...\n",
       "Path: .\n",
       "x: SeqItemList (1500 items)\n",
       "[Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50)]...\n",
       "Path: .;\n",
       "\n",
       "Valid: LabelList\n",
       "y: MultiCategoryList (500 items)\n",
       "[MultiCategory Target, MultiCategory NotTarget, MultiCategory NotTarget, MultiCategory NotTarget, MultiCategory Target]...\n",
       "Path: .\n",
       "x: SeqItemList (500 items)\n",
       "[Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50)]...\n",
       "Path: .;\n",
       "\n",
       "Test: None"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, ['Target', 'NotTarget'])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check classes\n",
    "data.c, data.classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "64"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_dl.batch_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Check data points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD4C+Nvxp8V/s9eMrL4bfCrTdI0vw5P4R8P+JtP0y00uOCbTdU1PR7PWIrqK/h2X8k+n3d7J9huJ7maeCENA0ksM91HOfBL45fET46eMr34fa/caRpeh2/hHxB4pu/D+ieGNPj0e+1rRtHvNUtL9tJkgk06CdjYWlnM9tbQme0SWJ9xuLl5iivpv7Jyz/UX+0XRi6/subnavLm57czbbvK32mua/vcylq/2w679kmSf9t7/AIWbY/tCW+kX2lfCr4Ra18QfD2i6J4T0rSIZLvTfJ8vTGksbWKe30uU3E8k1paS26vNPLcApPI8zeRf8NrfGeT/iY3K6QuuTfNqfi3TbA6ZrGqSR/vbOe7vLB4Jbie3vcagJnYyXF3HBJeNdi2tkhKKMlynLMZxhnOArUYyo4f6p7KFvdh7XD1KlXlV7fvJxjOd+ZSlFNq6uB1/xJ+NPiv4I+Dfh/wCJPgbpukeFtN+IfhGbxHfeFk0uPVbDSbtdY1LSLiOwbVvtdzbwXlrpdut5EZmW8Rnhn8y28u3TkP8AhtH4rP8A6be+HPCOpah/Dd+IvC8Gr20Gfkf7Ppt+J9Nst0EOnWo+zWsXlW2kWUMXlIJhMUV7OQcPZHjssVbEYeM5udVOTTcmo168I80uZSlaMIxvKTbUVzOTu5B1/wC0B4y+BP7PXx58bfALRv2Kvh1r1n4H8Xal4ftdc8Qa14o+36jHZXUlstzc/ZdZgg8+QRh38mGKPezbI0XCgooo4cyHLsfw7gsViPaSqVKNGcn7fEaynRpTk9MQlrKUnoktdElZIP/Z\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAHRJREFUKJFjdGEM+c+ABHY+u8DgLmWAwYbxGRgYUMSw6SMkDhPDZj42u3GpRwaMMI8gOxLdEHTHIxsMA8h8bA7HZjY2gK4Wn3nIcox/n6v8JyYkCTkAl6MIhTi+WMZmFq4YZnRhDPlPDUeiyxHjOGLNI0YfAGvwc9QnBcXaAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "Image (1, 4, 50)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 2\n",
    "data.x[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiCategory Target"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.y[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD4w/ZJl1L9tn/hZumfHvxFq+o6H8MPhFrXxHs/Dy6zctDrutaf5LH+0Z5pJLuaC4mu9QuJI454jFNqd89q1qbqXcfbvFcf7eX/AA7us/Her2/gKX4u/wDCtdR8maMPqHh9NU/s+3imtCh05p7cvdXsVwtosov764u2ZpmV0KK5MZy0uMOIsBBJUMJgo1qMLR5aVV4erUdSEeWym5xhLmd3eENUoRUf2w674o/Aj4U/C/wb8UviP/wiv/CSeIfhH4u8T+HfDeq+ML6fUftVtoGseHNI01ryCR/s8+y28RYMPlLbN/Y2mp5Ii+2xXvkf7P8A8ZNZ+P3x58E/s6SeDPCPhDw94/8AF2m+FPELeDfCVnb3Muh3d1HbpZNJPHN532fznnjuJxLdPcpb3FxNcS2Vm9uUUcKc2acD5pmGMbqVqFOUqcpOT5GsBCsnGPMoK1Zuol7NpSs0lZJByH/DZHjyw/e+FvAXhHTZ/wDUvPc6XLq/2iwHyx6Zcpq013He2UMUOmxwW9ysqW39kWksHlTm5muOv0X40+K/E/wG8SfHHxfpuka/F4S8XaJ4T8P+EvE+lx6tYWPh/VLXVbiXS4Li983UbaCEaPbRWrQ3kctqk100MiS3EkpKK/Ss74eyTB0KEqNCKlKtQg3rzcs6jjJczk5K8dLxcXbRNLQDkP8Aht747T/vtZ034daveP8ANdat4g+DHhfUb+9kP3p7m7utOknup3OWeaZ3kkdmZ2ZmJJRRXsf6p8Mf9AdP7mvwVRJeiSS2SSSSD//Z\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAHNJREFUKJFj/Ptc5b+7lAHDzmcXGJBpdIAsz8DAgKIGlx5cAF09MXx0O9HVMrowhvwn5CBs4sQ4Hpvn8QUYshi+AEOX2/nsAsQjxHoAn+PRHYjNwfgcj8uDxNjJwMCAGiPICmA+JmQQesgQ0oPLIcQCXHoAnhaF1FMzixMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "Image (1, 4, 50)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 3\n",
    "data.x[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiCategory Target"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.y[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjgAAAETCAYAAAA79nyeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADwxJREFUeJzt3WuM7HV9x/HPF08ajCjgpRDwkmgsRh9UaxoTRUxqqPUClVC8P/DBMdSEtmKtMa3FaLRpWlOtwYREHzSpjVFSa4xojLaxNqkXitBokyKlXqgHVBQOpSKC59cH8ycOyyJzZmdnhu95vZLN2Z3/zP+yu/M77/nPb2ZrjBEAgE6O2/QOAACsmsABANoROABAOwIHAGhH4AAA7QgcAKAdgQMAtCNwuI+qun3u40hV3TH39avXvC/HV9Woqseuc7vA/jPWsJ8ObHoH2D5jjBPu+byqvpXk4Bjjc8usq6oOjDHuXtW+AX0Ya9hPzuBw1KrqOVX15ao6XFWHquo9VXVgWnbPo6DXV9X1Sb4+Xf7iqrquqm6tqvdW1Zeq6jVz67ywqq6tqh9V1RVVdfq06AvTv9dOj+peutaDBTbGWMNeCByWcVeSi5I8Mslzk5yT5OCO67wkyTOTPKOqTk3ykSQXJ3lMkkPTsiRJVb0iyRum9ZyS5OokH5oWnzX9e8YY44Qxxsf344CArWSsYWkCh6M2xvjKGOPKMcbPxhjXJ/lgkuftuNq7xhi3jjHuSHJukivHGJ8cY9yV5N1Jbpm77oVJ3jnG+Ma0/O1JzqyqU9ZwOMCWMtawFwKHo1ZVT62qT1fV96rqtiSXJHn0jqvdMPf5afNfjzGOJPnu3PInJLlsOqV8a5IfJLk7icl+cAwz1rAXAodlfCDJV5M8aYzxiCTvSFI7rjP/Z+pvzNwAUlXHJTl9bvkNSV47xjhp7uOhY4yrdqwHOLYYa1iawGEZD09yeIxxe1U9LcnrHuD6n0jyrKp60TRB8I1JTp5bflmSt1bVGUlSVSdX1flJMsa4M8nhJE9c9UEAW89Yw9IEDsu4OMnBqro9yfszm9R3v8YYNyZ5ZZL3Jbk5s0dYX0ty57T8w0kuTfKx6TT0NUnOnlvFJUkun04rn7viYwG2l7GGpdUYzsqxXtMjq5uSnDPG+OKm9wfoyVhzbHMGh7WoqhdW1YlVdXyStyX5cZKrNrxbQDPGGu4hcFiXs5J8M8n3kzw/yXljjJ9udpeAhow1JPEUFQDQkDM4AEA7AgcAaGdjf038yE1PXui5sRec9vT7XPaZQ9csdL3d7LztordbZF33Zy/b2IS9fH/3so3dLPtz3e22XX9ei1r0+Be12/fps0cu3/kmbBt39nEXbO3z8Ov4nVz1/XmRfV7HeLGOsXs3y253U+PqJn531nGsi4w1zuAAAO0IHACgHYEDALSzsZeJr/p58VU+V7ibdTyPud/7ssq5S/d3203MrVn0tts0j2i/v0+bmkdkDs7PbeL3by/3q93s91hjXtz2HP8qf2brGN/NwQEAjkkCBwBoR+AAAO0IHACgna2fZLyOibHrXtd+WHbi17Yf1yZsy5tl3d9td7PImxquY5Lrcadet3WTjFf9pqKL2sT9aJt+dzdhW/Z3U29M2OFns+j1TDIGAI5JAgcAaEfgAADtCBwAoJ2tn2Tc1bZMnt6LTe3HJra7Ld/zbfdgfifj/f4Zd7jP78WmJkB3+Gvqe9F1gvIiL2hwBgcAaEfgAADtCBwAoB2BAwC0s7FJxgAA+8UZHACgHYEDALQjcACAdgQOANCOwAEA2hE4AEA7AgcAaEfgAADtCBwAoB2BAwC0I3AAgHYEDgDQjsABANoROABAOwIHAGhH4AAA7QgcAKAdgQMAtCNwAIB2BA4A0I7AAQDaETgAQDsCBwBoR+AAAO0IHACgHYEDALQjcACAdgQOANCOwAEA2hE4AEA7AgcAaEfgAADt
CBwAoB2BAwC0I3AAgHYEDgDQjsABANoROABAOwIHAGhH4AAA7QgcAKAdgQMAtCNwAIB2BA4A0I7AAQDaETgAQDsCBwBoR+AAAO0IHACgHYEDALQjcACAdgQOANCOwAEA2hE4AEA7AgcAaEfgAADtCBwAoB2BAwC0I3AAgHYEDgDQjsABANoROABAOwIHAGhH4AAA7QgcAKAdgQMAtCNwAIB2BA4A0I7AAQDaETgAQDsCBwBoR+AAAO0IHACgHYEDALQjcACAdgQOANCOwAEA2hE4AEA7AgcAaEfgAADtCBwAoB2BAwC0I3AAgHYEDgDQjsABANoROABAOwIHAGhH4AAA7QgcAKAdgQMAtCNwAIB2BA4A0I7AAQDaETgAQDsCBwBoR+AAAO0IHACgHYEDALQjcACAdgQOANCOwAEA2hE4AEA7AgcAaEfgAADtCBwAoB2BAwC0I3AAgHYEDgDQjsABANoROABAOwIHAGhH4AAA7QgcAKAdgQMAtCNwAIB2BA4A0I7AAQDaETgAQDsCBwBoR+AAAO0IHACgHYEDALQjcACAdgQOANCOwAEA2hE4AEA7AgcAaEfgAADtCBwAoB2BAwC0I3AAgHYEDgDQjsABANoROABAOwIHAGhH4AAA7QgcAKAdgQMAtCNwAIB2BA4A0I7AAQDaETgAQDsCBwBoR+AAAO0IHACgHYEDALQjcACAdgQOANCOwAEA2hE4AEA7AgcAaEfgAADtCBwAoB2BAwC0I3AAgHYEDgDQjsABANoROABAOwIHAGhH4AAA7QgcAKAdgQMAtCNwAIB2BA4A0I7AAQDaETgAQDsCBwBoR+AAAO0IHACgHYEDALQjcACAdgQOANCOwAEA2hE4AEA7Aof7qKrb5z6OVNUdc1+/es37cnxVjap67Dq3Cyyvqr5VVd+rqofNXXawqj6/wG0/X1UHp88fv2M8GlX1f3NfP3cfD2O3fXtKVd29zm2yvAOb3gG2zxjjhHs+r6pvJTk4xvjcMuuqqgNjDAMCHHsOJPmDJH+27ArGGN9JMj8ejSS/Osb4r2XWZzw6tjiDw1GrqudU1Zer6nBVHaqq91TVgWnZPWdcXl9V1yf5+nT5i6vquqq6tareW1VfqqrXzK3zwqq6tqp+VFVXVNXp06IvTP9eOz1ie+laDxZY1l8meVNVnbRzQVU9u6qunMaQK6vq2dPl70ry3CSXTvf3Sx9oI1V1XlX9e1XdVlXfrqo/nlv2lKq6u6peV1U3JPnUdPnBqvpOVf2gqt5cVTdV1ZnTsodU1Z9W1X9X1c1V9Xdzx/CFJA+ZO4P0jL1+k9g/Aodl3JXkoiSPzGwwOifJwR3XeUmSZyZ5RlWdmuQjSS5O8pgkh6ZlSZKqekWSN0zrOSXJ1Uk+NC0+a/r3jDHGCWOMj+/HAQEr929JPp/kTfMXVtUjk1yR5H1JHpXkr5JcUVWPGmP8SZJ/SXLRdH+/aIHt3JbkVUlOSnJeZlH1W3PLH5LkWUnOSPLbVfX0aZsvS/LY6ePRc9f/oyS/meTMadldSd4zLTsryc+mfTthjHH1It8INkPgcNTGGF8ZY1w5xvjZGOP6JB9M8rwdV3vXGOPWMcYdSc5NcuUY45NjjLuSvDvJLXPXvTDJO8cY35iWvz3JmVV1yhoOB9g/lyT5vap6zNxlL05y3Rjjb8cYd48xPpzkPzN7gHPUxhj/OMb4jzHGkTHGV5N8NPcdjy4ZY/x4Go9eluTvxxhfGmPcmeStuff/hRcmecsY49AY4yeZjUcvr6paZv/YHIHDUauqp1bVp6dJhLdlNog9esfVbpj7/LT5r8cYR5J8d275E5JcNj19dWuSHyS5O7NHT8CD1Bjj60k+meQtcxefluTbO6767SSnZwnTU+b/PD3ddDjJa3Pv8ejIGOPQju3Pj0e3JTk8rauSPC7Jp+bGo6sz+7/yUcvsH5sjcFjGB5J8NcmTxhiPSPKOJDsf3Yy5z2/MXKxU1XG592B2Q5LXjjFO
mvt46Bjjqh3rAR583pbkdfn5ff5QZg9q5j0+P3/Qc7T3+Y9m9hT448YYJyb5m9x7PNq5vp3j0SOSnJgkY4wx7cdv7BiPjh9j3LzEvrFBAodlPDzJ4THG7VX1tMwGr1/kE0meVVUvmiYjvzHJyXPLL0vy1qo6I0mq6uSqOj9JplPIh5M8cdUHAey/6RVPH0ny+9NFn0ryK1X1qqo6UFUvT/LUzM70JMn3suD9fTrjckKSH44xfjJNVr7gAW720STnV9WvV9UvZfYA7cjc8suS/HlVPW7axi9X1T1Pn30/s0nGj19k/9gsgcMyLk5ysKpuT/L+zAav+zXGuDHJKzObVHhzZo+evpbkzmn5h5NcmuRj01Ne1yQ5e24VlyS5fDplfO6KjwXYf+9I8rAkGWP8MLMXIfxhkh8meXOSl0xnSJLkr5P8TlXdUlXv+0Urnc64/G6Sd1fV/07ruvwBbnN1ZhOJ/yGzszU3ZvYg6s7pKn+R5HNJ/mla578m+bXptrdMy6+axqOnL/wdYO1q9vsB6zOdxbkpyTljjC9uen+AY1dVnZzkR0lOmx6M0YQzOKxFVb2wqk6squMze07+x0mu2vBuAcegqjq3qh5aVSdk9pLxL4ubfgQO63JWkm9m9hz285OcN8b46WZ3CThGXZDZWeT/yWzy81r/BA3r4SkqAKAdZ3AAgHY29sc2j9z05PucOnrBafs/If0zh655wG3uvM79Xa+r3Y5/Uav+fq7yZ7HocW3Tz3ovP4tFrPpYP3vk8q17t9dFx5ptud8vuh+b2t9FtrvN+3Y011t2u3vZ5jq+d6sckze1b4uMNc7gAADtCBwAoB2BAwC0I3AAgHY29jLxs4+7YKUb3pYJgova70lue13ffu/Hqice72aVk5EXPYZV7se22+34jzv1uq2bZLyXsWZbJmOu+n61LePFoh5s+7tNtuVFHnux7FjjDA4A0I7AAQDaETgAQDsCBwBop80k42Vt0wSsTby75jZx/NsxoW83q3x30XXzgob9nzS/ifvatr/j737b5n3bzar31zsZAwDHJIEDALQjcACAdgQOANDOVk0yXnSS2272+906N/Xutqtc37H+zrt7sc3vvLwXq963B8sk4028QGCbfw/WYdXj+162u9/vAr2pd57exItmVjk5fdFtJiYZAwDHKIEDALQjcACAdgQOANDOxiYZAwDsF2dwAIB2BA4A0I7AAQDaETgAQDsCBwBoR+AAAO0IHACgHYEDALQjcACAdgQOANCOwAEA2hE4AEA7AgcAaEfgAADtCBwAoB2BAwC0I3AAgHYEDgDQjsABANoROABAOwIHAGhH4AAA7QgcAKCd/wc6xV0unK5QSAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 576x576 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "data.show_batch(rows=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom fastai ResNet model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn_dummy = create_cnn(data, models.resnet18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BCEWithLogitsLoss()"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn_dummy.loss_func.func"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "#learn_dummy.model[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): AdaptiveConcatPool2d(\n",
       "    (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "    (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "  )\n",
       "  (1): Lambda()\n",
       "  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Dropout(p=0.25)\n",
       "  (4): Linear(in_features=1024, out_features=512, bias=True)\n",
       "  (5): ReLU(inplace)\n",
       "  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (7): Dropout(p=0.5)\n",
       "  (8): Linear(in_features=512, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn_dummy.model[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup custom input stage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define function to create 3 channel image from 1 channel image.\n",
    "# -1 keeps the batch size and the spatial size of the input, so any sequence length works.\n",
    "def ExpandInput(): return Lambda(lambda x: x.expand(-1, 3, -1, -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "EI = ExpandInput()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Test ExpandInput layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1, 4, 50])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tt = torch.rand((64,1,4,50)); tt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 3, 4, 50])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tt.expand(-1, 3, 4, 50).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDnviXJ4M8PfFfQvhDffCLwjqWn+AvG3xP8MaXdSaItlPfWXhuw8N6/owvRYG3ju/Jubx43jkQw3EUUJuIppoxPX1l+2v8ABDwH+zl8Yvh1ofwXgvdCn8W+Itd8K6hqtrfyGeLR7fQovELW0asTCTPPZ2lrLM8bzNZ2sMKyJ5auCigDzv8Aae+MHjz4W/Df4j+LZr+z8R3Pw7+G/hnWvDX9vaTbK0F/rT+Io9Ru1mtI4J4rqWa0t7z7XDJHdLcq8gmAmmWR97+zB4O+JH7VPwX+Amj+M/FXg3Q/iF8G/Fl7qV14M1s2mpWZtLRNMhihv3WS5kj8uFJmFzJOzztNIzH7ROJCigDxzxh+0f8AF/wXa/sx6rceLrjX7zxuPB1x4g1DxRK9/eXEmq3Wsyy5upGNwyQxyXdtDE0jRpb6heRlH89jXX/8Eofj94s/bL8P/GL4XfGnSrS603wfpUOih0vL2WbV/wC0dU1m+kvbt7m4lzdx3OkQypLD5WWnm8xZf3PlFFAG38P/AI93yeAtES8+Efw6uphpFsJbqbwHYq8zeUuXYRxqoJPJCqq88ADiiiigD//Z\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAstJREFUKJEFwU8oswEcB/DvuylPmSfNiGVTiOI6FzWxi7KLspY1mpVdVnoeJRntoNwUicRzsFL+tDZao205SE9abezJwZo8KHGYhSzbaKvf+/n8a2trI6PRiM7OTiSTSYiiCEEQMD09DZPJhJqaGvT19SEQCGBmZgajo6O4ublBPp/H9fU1HA4HhoaGEA6HIcsyzs7OkEgk0NraCp7nsb29DbVajVKphEwmA5VKhe7ublitVvj9fjgcDuRyOdTX12NgYACJRAIGgwHhcBiiKMJsNoPneSwtLSGTySCXy2FiYgIulwssy+Lr6wstLS34x3Eczc3NQalUIhKJ4Pv7G4Ig4OHhAZIk4fT0FIVCAVNTU7Barfj5+YHP50MoFIJWq0WlUoHb7UY6ncbv7y/i8TicTicUCgUGBwchSRKen5/h8Xiws7ODsbExMAyDv78/RCIRDA8P4+7uDpubm9Dr9ZAkCYVCAclkEhzHYXZ2FizLwu124+TkBDqdDhzHIRgM4uLiApOTk9BoNFAolUro9XoQEcxmM9bX13F0dIR4PI7b21vYbDY0NDQgm83C4/Hg4OAA6XQau7u7KJVKuLy8xP7+Pvb29sDzPCwWC8bHx/H5+Qmv1wtZlvHy8gKlUonFxUUwDIP7+3vY7XZUV1eDZVn09PRArVbj+PgYq6urKJfLYBgG7+/vaGxshMlkgk6nw/z8PKLRKLLZLIxGIwwGAxYWFiDLMvD4+Egul4vq6uqoUqmQSqWi/v5+isViZLPZ6OPjg/L5PAUCASqXy7S2tka1tbUkiiIJgkA+n4+urq7o/Pycmpub6fX1lYLBIDmdTmpvb6eNjQ06PDykYrFIW1tbpNVqyev1UrFYpKenJ+ro6KBQKESpVIpGRkZoZWWFmpqaqKuri/x+P1ksFtJoNNTb20tVVVX09vZGsViMUqkURaNRstvttLy8TP8B0X5r1+LxSH4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "Image (3, 4, 50)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# run the dummy batch through `EI` and render the first sample (3 x 4 x 50)\n",
    "Image(EI(tt)[0, ...])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Insert custom input stage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# insert an ExpandInput layer in front of learn_dummy's model\n",
    "net_custom_resnet = nn.Sequential(ExpandInput(), learn_dummy.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 3, 4, 50])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sanity check: feed the dummy batch through the new input stage only\n",
    "net_custom_resnet[0](tt).size()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## TensorBoard logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# From https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514\n",
    "\"\"\"Simple example on how to log scalars and images to tensorboard without tensor ops.\n",
    "License: Copyleft\n",
    "\"\"\"\n",
    "#__author__ = \"Michael Gygli\"\n",
    "\n",
    "#import tensorflow as tf\n",
    "#from StringIO import StringIO\n",
    "#import matplotlib.pyplot as plt\n",
    "#import numpy as np\n",
    "\n",
    "class Logger:\n",
    "    \"\"\"Log scalars, images and histograms to TensorBoard without building TensorFlow ops.\n",
    "\n",
    "    Wraps a TF1-style `tf.summary.FileWriter`; every `log_*` call writes one `tf.Summary`\n",
    "    protobuf for the given global `step` and flushes it to disk.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, log_dir):\n",
    "        \"\"\"Create a summary writer that writes event files to `log_dir`.\"\"\"\n",
    "        self.writer = tf.summary.FileWriter(log_dir)\n",
    "\n",
    "    def log_scalar(self, tag, value, step):\n",
    "        \"\"\"Log a scalar value.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        tag : str\n",
    "            Name of the scalar.\n",
    "        value : float-like\n",
    "            Anything `float()` accepts (Python/NumPy numbers, 0-d torch tensors).\n",
    "        step : int\n",
    "            Training iteration / epoch the value belongs to.\n",
    "        \"\"\"\n",
    "        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=float(value))])\n",
    "        self.writer.add_summary(summary, step)\n",
    "        # flush so scalars show up even when no histogram (which also flushes) is logged\n",
    "        self.writer.flush()\n",
    "\n",
    "    def log_images(self, tag, images, step):\n",
    "        \"\"\"Log a list of images, one `tag/<index>` entry per image.\n",
    "\n",
    "        Each image must be accepted by `plt.imsave`; `img.shape[0]` and `img.shape[1]`\n",
    "        are recorded as its height and width.\n",
    "        \"\"\"\n",
    "        from io import BytesIO  # encoded PNG data is bytes, so a binary buffer is required\n",
    "\n",
    "        im_summaries = []\n",
    "        for nr, img in enumerate(images):\n",
    "            buf = BytesIO()\n",
    "            plt.imsave(buf, img, format='png')\n",
    "            img_sum = tf.Summary.Image(encoded_image_string=buf.getvalue(),\n",
    "                                       height=img.shape[0],\n",
    "                                       width=img.shape[1])\n",
    "            im_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, nr), image=img_sum))\n",
    "\n",
    "        summary = tf.Summary(value=im_summaries)\n",
    "        self.writer.add_summary(summary, step)\n",
    "        self.writer.flush()\n",
    "\n",
    "    def log_histogram(self, tag, values, step, bins=1000):\n",
    "        \"\"\"Log the histogram of a list/array of values (empty inputs are skipped).\"\"\"\n",
    "        values = np.asarray(values)\n",
    "        if values.size == 0:\n",
    "            return  # min/max are undefined for an empty array\n",
    "\n",
    "        counts, bin_edges = np.histogram(values, bins=bins)\n",
    "\n",
    "        hist = tf.HistogramProto()\n",
    "        hist.min = float(np.min(values))\n",
    "        hist.max = float(np.max(values))\n",
    "        hist.num = int(values.size)\n",
    "        hist.sum = float(np.sum(values))\n",
    "        hist.sum_squares = float(np.sum(values ** 2))\n",
    "\n",
    "        # The proto needs one bucket_limit per bucket and its first bucket implicitly starts at\n",
    "        # -DBL_MAX, so the left edge of numpy's first bin is dropped. See\n",
    "        # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto#L30\n",
    "        hist.bucket_limit.extend(bin_edges[1:].tolist())\n",
    "        hist.bucket.extend(counts.tolist())\n",
    "\n",
    "        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])\n",
    "        self.writer.add_summary(summary, step)\n",
    "        self.writer.flush()\n",
    "\n",
    "# `TBLogger`: a fastai callback that saves tracked metrics into log files for TensorBoard.\n",
    "# Based on https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514\n",
    "# and devforfu: https://nbviewer.jupyter.org/gist/devforfu/ea0b3fcfe194dad323c3762492b05cae\n",
    "# Contribution from MicPie\n",
    "\n",
    "__all__ = ['TBLogger']  # only meaningful once this cell is moved into a module\n",
    "\n",
    "@dataclass\n",
    "class TBLogger(LearnerCallback):\n",
    "    \"A `LearnerCallback` that writes epoch metrics, parameter values and per-epoch summed gradients to TensorBoard.\"\n",
    "\n",
    "    log_dir:str = 'logs'   # directory under `learn.path` that holds the event files\n",
    "    log_name:str = 'data'  # sub-directory (run name) inside `log_dir`\n",
    "    log_scalar:bool = True # log scalar values for Tensorboard scalar summary\n",
    "    log_hist:bool = True   # log values and gradients of the parameters for Tensorboard histogram summary\n",
    "    log_img:bool = False   # reserved for an image summary; currently not used\n",
    "\n",
    "    def __post_init__(self):\n",
    "        super().__post_init__()\n",
    "        self.path = self.learn.path\n",
    "        (self.path/self.log_dir).mkdir(parents=True, exist_ok=True) # setup logs directory\n",
    "        self.Log = Logger(str(self.path/self.log_dir/self.log_name))\n",
    "        self.epoch = 0\n",
    "        self.batch = 0\n",
    "        self.log_grads = {} # '<param path>/grad' -> gradient summed over the current epoch's batches\n",
    "\n",
    "    @staticmethod\n",
    "    def _tag(name:str) -> str:\n",
    "        \"Turn a parameter name such as `1.0.weight` into the TensorBoard group path `1/0/weight`.\"\n",
    "        return name.replace('.', '/')\n",
    "\n",
    "    def on_backward_end(self, **kwargs:Any) -> None:\n",
    "        \"Add this batch's gradients to the running per-epoch sums.\"\n",
    "        self.batch += 1\n",
    "        if not self.log_hist: return\n",
    "        for name, param in self.learn.model.named_parameters(): # the callback's own learner, not a global\n",
    "            if param.grad is None: continue # frozen parameters (requires_grad=False) have no gradient\n",
    "            tag_grad = self._tag(name) + '/grad'\n",
    "            # copy: on CPU `.numpy()` shares memory with `param.grad`, which the optimizer zeroes in place\n",
    "            grad = param.grad.detach().cpu().numpy().copy()\n",
    "            if tag_grad in self.log_grads: self.log_grads[tag_grad] += grad\n",
    "            else: self.log_grads[tag_grad] = grad\n",
    "\n",
    "    def on_epoch_end(self, epoch:int, smooth_loss:Tensor, last_metrics:MetricsList, **kwargs:Any) -> None:\n",
    "        \"Write this epoch's scalars and histograms, then reset the batch counter and gradient sums.\"\n",
    "        last_metrics = ifnone(last_metrics, [])\n",
    "        tr_info = {name: stat for name, stat in zip(self.learn.recorder.names, [epoch, smooth_loss] + last_metrics)}\n",
    "        self.epoch = tr_info['epoch']\n",
    "        step = self.epoch + 1\n",
    "\n",
    "        if self.log_scalar:\n",
    "            for tag, value in tr_info.items():\n",
    "                if tag == 'epoch' or value is None: continue\n",
    "                self.Log.log_scalar(tag, float(value), step) # float() also unwraps 0-d tensors\n",
    "\n",
    "        if self.log_hist:\n",
    "            for name, param in self.learn.model.named_parameters():\n",
    "                tag = self._tag(name)\n",
    "                self.Log.log_histogram(tag, param.detach().cpu().numpy(), step)\n",
    "                grad_sum = self.log_grads.get(tag + '/grad')\n",
    "                if grad_sum is not None: # absent for parameters that never received a gradient\n",
    "                    self.Log.log_histogram(tag + '/grad', grad_sum, step)\n",
    "\n",
    "        self.batch = 0      # reset batch count\n",
    "        self.log_grads = {} # gradient sums are per epoch: start the next epoch from zero"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the custom ResNet with fastai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# wrap `data` and the custom network in a fastai Learner\n",
    "learn_resnet = Learner(data, net_custom_resnet, metrics=[accuracy_thresh])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: uncomment to list the shape of every parameter tensor.\n",
    "# [p.shape for p in net_custom_resnet.parameters()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======================================================================\n",
      "Layer (type)         Output Shape         Param #    Trainable \n",
      "======================================================================\n",
      "Lambda               [64, 3, 4, 50]       0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 64, 2, 25]      9408       False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 64, 2, 25]      128        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 64, 2, 25]      0          False     \n",
      "______________________________________________________________________\n",
      "MaxPool2d            [64, 64, 1, 13]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 64, 1, 13]      36864      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 64, 1, 13]      128        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 64, 1, 13]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 64, 1, 13]      36864      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 64, 1, 13]      128        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 64, 1, 13]      36864      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 64, 1, 13]      128        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 64, 1, 13]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 64, 1, 13]      36864      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 64, 1, 13]      128        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 128, 1, 7]      73728      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 128, 1, 7]      256        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 128, 1, 7]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 128, 1, 7]      147456     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 128, 1, 7]      256        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 128, 1, 7]      8192       False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 128, 1, 7]      256        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 128, 1, 7]      147456     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 128, 1, 7]      256        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 128, 1, 7]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 128, 1, 7]      147456     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 128, 1, 7]      256        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 256, 1, 4]      294912     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 256, 1, 4]      512        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 256, 1, 4]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 256, 1, 4]      589824     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 256, 1, 4]      512        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 256, 1, 4]      32768      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 256, 1, 4]      512        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 256, 1, 4]      589824     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 256, 1, 4]      512        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 256, 1, 4]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 256, 1, 4]      589824     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 256, 1, 4]      512        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 512, 1, 2]      1179648    False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 512, 1, 2]      1024       True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 512, 1, 2]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 512, 1, 2]      2359296    False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 512, 1, 2]      1024       True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 512, 1, 2]      131072     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 512, 1, 2]      1024       True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 512, 1, 2]      2359296    False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 512, 1, 2]      1024       True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 512, 1, 2]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 512, 1, 2]      2359296    False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 512, 1, 2]      1024       True      \n",
      "______________________________________________________________________\n",
      "AdaptiveAvgPool2d    [64, 512, 1, 1]      0          False     \n",
      "______________________________________________________________________\n",
      "AdaptiveMaxPool2d    [64, 512, 1, 1]      0          False     \n",
      "______________________________________________________________________\n",
      "Lambda               [64, 1024]           0          False     \n",
      "______________________________________________________________________\n",
      "BatchNorm1d          [64, 1024]           2048       True      \n",
      "______________________________________________________________________\n",
      "Dropout              [64, 1024]           0          False     \n",
      "______________________________________________________________________\n",
      "Linear               [64, 512]            524800     True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 512]            0          False     \n",
      "______________________________________________________________________\n",
      "BatchNorm1d          [64, 512]            1024       True      \n",
      "______________________________________________________________________\n",
      "Dropout              [64, 512]            0          False     \n",
      "______________________________________________________________________\n",
      "Linear               [64, 2]              1026       True      \n",
      "______________________________________________________________________\n",
      "\n",
      "Total params:  11705410\n",
      "Total trainable params:  538498\n",
      "Total non-trainable params:  11166912\n"
     ]
    }
   ],
   "source": [
    "# print per-layer output shapes, parameter counts and trainability\n",
    "learn_resnet.summary();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# learning-rate range test, then plot loss against learning rate\n",
    "learn_resnet.lr_find()\n",
    "learn_resnet.recorder.plot();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='2' class='' max='10', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      20.00% [2/10 03:32<14:11]\n",
       "    </div>\n",
       "    \n",
       "<table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy_thresh</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>0.777768</th>\n",
       "    <th>0.793338</th>\n",
       "    <th>0.515000</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>0.692816</th>\n",
       "    <th>0.765466</th>\n",
       "    <th>0.633000</th>\n",
       "  </tr>\n",
       "</table>\n",
       "\n",
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='0' class='progress-bar-interrupted' max='23', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      Interrupted\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "Traceback (most recent call last):\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/queues.py\", line 240, in _feed\n",
      "    send_bytes(obj)\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 200, in send_bytes\n",
      "    self._send_bytes(m[offset:offset + size])\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 404, in _send_bytes\n",
      "    self._send(header + buf)\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 368, in _send\n",
      "    n = write(self._handle, buf)\n",
      "BrokenPipeError: [Errno 32] Broken pipe\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/queues.py\", line 240, in _feed\n",
      "    send_bytes(obj)\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 200, in send_bytes\n",
      "    self._send_bytes(m[offset:offset + size])\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 404, in _send_bytes\n",
      "    self._send(header + buf)\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 368, in _send\n",
      "    n = write(self._handle, buf)\n",
      "BrokenPipeError: [Errno 32] Broken pipe\n",
      "Traceback (most recent call last):\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/queues.py\", line 240, in _feed\n",
      "    send_bytes(obj)\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 200, in send_bytes\n",
      "    self._send_bytes(m[offset:offset + size])\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 404, in _send_bytes\n",
      "    self._send(header + buf)\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 368, in _send\n",
      "    n = write(self._handle, buf)\n",
      "BrokenPipeError: [Errno 32] Broken pipe\n",
      "Traceback (most recent call last):\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/queues.py\", line 240, in _feed\n",
      "    send_bytes(obj)\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 200, in send_bytes\n",
      "    self._send_bytes(m[offset:offset + size])\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 404, in _send_bytes\n",
      "    self._send(header + buf)\n",
      "  File \"/Users/MMP/anaconda3/lib/python3.6/multiprocessing/connection.py\", line 368, in _send\n",
      "    n = write(self._handle, buf)\n",
      "BrokenPipeError: [Errno 32] Broken pipe\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-25-d4b78c83e318>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mlearn_resnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_one_cycle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_lr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1e-2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/Downloads/fastai/fastai/train.py\u001b[0m in \u001b[0;36mfit_one_cycle\u001b[0;34m(learn, cyc_len, max_lr, moms, div_factor, pct_start, wd, callbacks, **kwargs)\u001b[0m\n\u001b[1;32m     19\u001b[0m     callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor,\n\u001b[1;32m     20\u001b[0m                                         pct_start=pct_start, **kwargs))\n\u001b[0;32m---> 21\u001b[0;31m     \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcyc_len\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_lr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwd\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mlr_find\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mLearner\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart_lr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1e-7\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_lr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_it\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstop_div\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/basic_train.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, epochs, lr, wd, callbacks)\u001b[0m\n\u001b[1;32m    164\u001b[0m         \u001b[0mcallbacks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_fns\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlistify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    165\u001b[0m         fit(epochs, self.model, self.loss_func, opt=self.opt, data=self.data, metrics=self.metrics,\n\u001b[0;32m--> 166\u001b[0;31m             callbacks=self.callbacks+callbacks)\n\u001b[0m\u001b[1;32m    167\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    168\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mcreate_opt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwd\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m->\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/basic_train.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(epochs, model, loss_func, opt, data, callbacks, metrics)\u001b[0m\n\u001b[1;32m     82\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0myb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mprogress_bar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_dl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpbar\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     83\u001b[0m                 \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m                 \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     85\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/basic_train.py\u001b[0m in \u001b[0;36mloss_batch\u001b[0;34m(model, xb, yb, loss_func, opt, cb_handler)\u001b[0m\n\u001b[1;32m     24\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     25\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_backward_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     27\u001b[0m         \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_backward_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     28\u001b[0m         \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    100\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    101\u001b[0m         \"\"\"\n\u001b[0;32m--> 102\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    103\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    104\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     88\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     89\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m     91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "learn_resnet.fit_one_cycle(10, max_lr=1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn_resnet.recorder.plot_losses()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn_resnet.recorder.plot_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}