[de45a9]: / deprecated / DL_Genomics_v6.ipynb

Download this file

3364 lines (3363 with data), 178.2 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep learning in genomics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is based on the [jupyter notebook](https://github.com/abidlabs/deep-learning-genomics-primer/blob/master/A_Primer_on_Deep_Learning_in_Genomics_Public.ipynb) from the publication [\"A primer on deep learning in genomics\"](https://www.nature.com/articles/s41588-018-0295-5) but uses the [fastai](https://www.fast.ai) library based on [PyTorch](https://pytorch.org)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder, OneHotEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.0.35.dev0'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# fastai version\n",
    "__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## Load data from the web, generate dataframe, and save to disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "import requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "URL_seq = 'https://raw.githubusercontent.com/abidlabs/deep-learning-genomics-primer/master/sequences.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# get data from URL\n",
    "seq_raw = requests.get(URL_seq).text.split('\\n')\n",
    "seq_raw = list(filter(None, seq_raw)) # Removes empty lists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check length\n",
    "len(seq_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# setup df from list\n",
    "seq_df = pd.DataFrame(seq_raw, columns=['Sequences'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# show head of dataframe\n",
    "#seq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "URL_labels = 'https://raw.githubusercontent.com/abidlabs/deep-learning-genomics-primer/master/labels.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "seq_labels = requests.get(URL_labels).text.split('\\n')\n",
    "seq_labels = list(filter(None, seq_labels)) # Removes empty entries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(seq_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "seq_label_series = pd.Series(seq_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "seq_df['Target'] = seq_label_series.astype('int')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Sequences</th>\n",
       "      <th>Target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           Sequences  Target\n",
       "0  CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...       0\n",
       "1  GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...       0\n",
       "2  GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...       0\n",
       "3  GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...       1\n",
       "4  GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...       1"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "seq_df.to_csv('seq_df.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_df = pd.read_csv('seq_df.csv', index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "#seq_df.drop('Unnamed: 0', axis=1, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Sequences</th>\n",
       "      <th>Target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           Sequences  Target\n",
       "0  CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...       0\n",
       "1  GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...       0\n",
       "2  GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...       0\n",
       "3  GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...       1\n",
       "4  GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...       1"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess y values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(seq_df['Target'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 1, ..., 1, 0, 1, 1])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "targA = seq_df['Target'].values; targA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 1, 1, 0, ..., 0, 1, 0, 0])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "targB = np.logical_not(seq_df['Target'].values).astype(int); targB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([[0, 1], [0, 1], [0, 1], [1, 0]], '...', [[1, 0], [0, 1], [1, 0], [1, 0]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "targ = [[a,b] for a, b in zip(targA, targB)]; targ[:4], '...', targ[-4:]#, len(targ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_df['MultiTarget'] = targ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def joinarray(x):\n",
    "    y = ';'.join((str(x[0]),str(x[1])))\n",
    "    return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_df['MultiTarget'] = seq_df['MultiTarget'].apply(joinarray)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Sequences</th>\n",
       "      <th>Target</th>\n",
       "      <th>MultiTarget</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...</td>\n",
       "      <td>0</td>\n",
       "      <td>0;1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...</td>\n",
       "      <td>0</td>\n",
       "      <td>0;1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...</td>\n",
       "      <td>0</td>\n",
       "      <td>0;1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...</td>\n",
       "      <td>1</td>\n",
       "      <td>1;0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...</td>\n",
       "      <td>1</td>\n",
       "      <td>1;0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           Sequences  Target MultiTarget\n",
       "0  CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...       0         0;1\n",
       "1  GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...       0         0;1\n",
       "2  GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...       0         0;1\n",
       "3  GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...       1         1;0\n",
       "4  GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...       1         1;0"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## Data encoding test (incorporated into \"open_seq_image\" function)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# setup class instance to encode the four different bases to integer values (1D)\n",
    "int_enc = LabelEncoder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# setup one hot encoder to encode integer encoded classes (1D) to one hot encoded array (4D)\n",
    "one_hot_enc = OneHotEncoder(categories=[range(4)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "seq_enc = []\n",
    "\n",
    "for s in seq:\n",
    "    enc = int_enc.fit_transform(list(s)) # bases (ACGT) to int (0,1,2,3)\n",
    "    enc = np.array(enc).reshape(-1,1) # reshape to get rank 2 array (from rank 1 array)\n",
    "    enc = one_hot_enc.fit_transform(enc) # encoded integer encoded bases to sparse matrix (sparse matrix dtype)\n",
    "    seq_enc.append(enc.toarray()) # export sparse matrix to np array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(seq_enc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 462,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0., 0., 0., 1., ..., 0., 1., 0., 0.],\n",
       "        [1., 1., 0., 0., ..., 1., 0., 1., 1.],\n",
       "        [0., 0., 1., 0., ..., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., ..., 0., 0., 0., 0.]]), (4, 50))"
      ]
     },
     "execution_count": 462,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_enc[0].T, seq_enc[0].T.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAMAAAD7wpwzAAADAFBMVEUAAAABAQECAgIDAwMEBAQFBQUGBgYHBwcICAgJCQkKCgoLCwsMDAwNDQ0ODg4PDw8QEBARERESEhITExMUFBQVFRUWFhYXFxcYGBgZGRkaGhobGxscHBwdHR0eHh4fHx8gICAhISEiIiIjIyMkJCQlJSUmJiYnJycoKCgpKSkqKiorKyssLCwtLS0uLi4vLy8wMDAxMTEyMjIzMzM0NDQ1NTU2NjY3Nzc4ODg5OTk6Ojo7Ozs8PDw9PT0+Pj4/Pz9AQEBBQUFCQkJDQ0NERERFRUVGRkZHR0dISEhJSUlKSkpLS0tMTExNTU1OTk5PT09QUFBRUVFSUlJTU1NUVFRVVVVWVlZXV1dYWFhZWVlaWlpbW1tcXFxdXV1eXl5fX19gYGBhYWFiYmJjY2NkZGRlZWVmZmZnZ2doaGhpaWlqampra2tsbGxtbW1ubm5vb29wcHBxcXFycnJzc3N0dHR1dXV2dnZ3d3d4eHh5eXl6enp7e3t8fHx9fX1+fn5/f3+AgICBgYGCgoKDg4OEhISFhYWGhoaHh4eIiIiJiYmKioqLi4uMjIyNjY2Ojo6Pj4+QkJCRkZGSkpKTk5OUlJSVlZWWlpaXl5eYmJiZmZmampqbm5ucnJydnZ2enp6fn5+goKChoaGioqKjo6OkpKSlpaWmpqanp6eoqKipqamqqqqrq6usrKytra2urq6vr6+wsLCxsbGysrKzs7O0tLS1tbW2tra3t7e4uLi5ubm6urq7u7u8vLy9vb2+vr6/v7/AwMDBwcHCwsLDw8PExMTFxcXGxsbHx8fIyMjJycnKysrLy8vMzMzNzc3Ozs7Pz8/Q0NDR0dHS0tLT09PU1NTV1dXW1tbX19fY2NjZ2dna2trb29vc3Nzd3d3e3t7f39/g4ODh4eHi4uLj4+Pk5OTl5eXm5ubn5+fo6Ojp6enq6urr6+vs7Ozt7e3u7u7v7+/w8PDx8fHy8vLz8/P09PT19fX29vb39/f4+Pj5+fn6+vr7+/v8/Pz9/f3+/v7////isF19AAAAO0lEQVR4nGWNSQoAMAgD8/9PT6loIq0HMauShOR9DwphIkO5eALbEI1xp5DUoxYnNtf3t1rci9mGy30A2+0xz9q2+b0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.Image.Image image mode=P size=50x4 at 0x1A20462668>"
      ]
     },
     "execution_count": 311,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "PIL.Image.fromarray(seq_enc[0].T.astype('uint8')*255).convert('P')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## fastai data setup for NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# open sequence image function\n",
    "def open_seq_image(seq:str, cls:type=Image)->Image:\n",
    "    \"Return `Image` object created from sequence string `seq`.\"\n",
    "    \n",
    "    int_enc = LabelEncoder() # setup class instance to encode the four different bases to integer values (1D)\n",
    "    one_hot_enc = OneHotEncoder(categories=[range(4)]) # setup one hot encoder to encode integer encoded classes (1D) to one hot encoded array (4D)\n",
    "    \n",
    "    enc = int_enc.fit_transform(list(seq)) # bases (ACGT) to int (0,1,2,3)\n",
    "    enc = np.array(enc).reshape(-1,1) # reshape to get rank 2 array (from rank 1 array)\n",
    "    enc = one_hot_enc.fit_transform(enc) # encoded integer encoded bases to sparse matrix (sparse matrix dtype)\n",
    "    enc = enc.toarray().T # export sparse matrix to np array\n",
    "    #print('enc', enc, enc.shape)\n",
    "    \n",
    "    # https://stackoverflow.com/questions/22902040/convert-black-and-white-array-into-an-image-in-python\n",
    "    x = PIL.Image.fromarray(enc.astype('uint8')).convert('P')\n",
    "    x = pil2tensor(x,np.float32)\n",
    "    #x = x.view(4,-1) # remove first dimension\n",
    "    #x = x.expand(3, 4, 50) # expand to 3 channel image\n",
    "    #print('x', x, x.shape)\n",
    "    return cls(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD80/8AhrrUr/8A0rxT8B/h1q95B/pGltNotzaWGm6iflfUYdIsrmDSvPkjjsopUeze3nTTrfzYZGadpvuj4Jf8EwP2bf2gf2wL39kD4j6n4uuYLDwj4g1u38e/8JEz68IdG8WXnhax0rMqvZ/YlsbG3faLUTCVcJMkAWBSivm/HbGYzgrIq2IyKrLDTjhcTUThKSanTdBQkuZztbnleKXLK95xqOMXD9sPnT9kmOD9rf8A4Wbb6rcav4Ls/gx8Ita+Inwp0/wV4s1VYfCuqaf5LmKyF/dXRhguriVbqcg+eZreIxTQqHR+R+CXxp8V/tC+Mr34X/FXTdI1fwnpnhHxBrWlaFqWlx3s2nWmjaPearZaJaapd+bqtnpYkso4TbwXkbeTLOFkWSZ5SUV9lXwmGhmfENJRVsJDCOj3pSq0HOpKEvjUpz9+cudylJuTbbk2Hrn7DnwE+DH7cXg3xd4t1X4aaR8PovBvi7wX4fsNJ8DWguIbqPxRrA0fUbm4l1v+0Ll50tdptsTLHazJ58UayvI7/On/AA2R480r/ibfDvwF4R8F+IW+SXxP4N0uXT7mWE/vXiaCOb7H/wAfuL+OUW4mtLmK3+ySW0VpaQwFFHCuFoZnx5n2W4xOrQw06KpQnKUoxU6VKUk+ab51KUm2qjqpt+SSA/4ba+Mn/QmfCL/xH/wf/wDKqiiiv0n/AFT4Y/6A6f3S/wDlgH//2Q==\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAHNJREFUKJFjdGEM+c8ABTufXWBwlzJgQAe4xNHVMDAwoKiD6cMmRwlANhdmJuPf5yr/cVlCjAdIdQAMELIP3bHobkEPJEZYjKBLYLMMV6gji+GykJDHcHkOm9uwqWNETlqEDEF2LD7LiUlK+JIxTC8pKQIA7jhk1J8wofAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "Image (1, 4, 50)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# test open sequence image function\n",
    "open_seq_image('CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGACACC')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "class SeqItemList(ImageItemList):\n",
    "    _bunch,_square_show = ImageDataBunch,True\n",
    "    def __post_init__(self):\n",
    "        super().__post_init__()\n",
    "        self.sizes={}\n",
    "    \n",
    "    def open(self, seq): return open_seq_image(seq)\n",
    "    \n",
    "    def get(self, i):\n",
    "        seq = self.items[i][0]\n",
    "        res = self.open(seq)\n",
    "        return res\n",
    "    \n",
    "    @classmethod\n",
    "    def import_from_df(cls, df:DataFrame, cols:IntsOrStrs=0, **kwargs)->'ItemList':\n",
    "        \"Get the sequences in `col` of `df` and will had `path/folder` in front of them, `suffix` at the end.\"\n",
    "        return cls(items=df[cols].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "bs = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "data = (SeqItemList.import_from_df(seq_df, ['Sequences'])\n",
    "        .random_split_by_pct(valid_pct=0.25)\n",
    "        #.split_by_idxs(range(1500), range(1500,2000))\n",
    "        #.label_from_list(seq_df['Target'].values) # --> Two categories! AND WRONG LOSS FUNCTION!?\n",
    "        #.label_from_list(targ) # --> MultiCategory\n",
    "        #.label_from_list(seq_df['MultiTarget'], sep=';')\n",
    "        .label_from_list(seq_df['Target'].values, sep=';')\n",
    "        .databunch(bs=bs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## Verify data setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "### Check data object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[[[1., 0., 0.,  ..., 0., 1., 0.],\n",
       "           [0., 1., 1.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 0.,  ..., 1., 0., 1.],\n",
       "           [0., 0., 0.,  ..., 0., 0., 0.]]],\n",
       " \n",
       " \n",
       "         [[[1., 0., 0.,  ..., 0., 0., 1.],\n",
       "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 1., 1.,  ..., 1., 1., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 0., 0.]]],\n",
       " \n",
       " \n",
       "         [[[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [1., 0., 0.,  ..., 1., 1., 1.],\n",
       "           [0., 0., 1.,  ..., 0., 0., 0.],\n",
       "           [0., 1., 0.,  ..., 0., 0., 0.]]],\n",
       " \n",
       " \n",
       "         ...,\n",
       " \n",
       " \n",
       "         [[[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 1.,  ..., 0., 1., 0.],\n",
       "           [1., 1., 0.,  ..., 1., 0., 1.]]],\n",
       " \n",
       " \n",
       "         [[[0., 0., 0.,  ..., 1., 0., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [1., 1., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 1.,  ..., 0., 1., 1.]]],\n",
       " \n",
       " \n",
       "         [[[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 1.,  ..., 1., 0., 0.],\n",
       "           [1., 1., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 1., 1.]]]]), tensor([[1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [1., 0.]])]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(data.train_dl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ImageDataBunch;\n",
       "\n",
       "Train: LabelList\n",
       "y: MultiCategoryList (2000 items)\n",
       "[MultiCategory 0, MultiCategory 0, MultiCategory 0, MultiCategory 1, MultiCategory 1]...\n",
       "Path: .\n",
       "x: SeqItemList (1500 items)\n",
       "[Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50)]...\n",
       "Path: .;\n",
       "\n",
       "Valid: LabelList\n",
       "y: MultiCategoryList (2000 items)\n",
       "[MultiCategory 0, MultiCategory 0, MultiCategory 0, MultiCategory 1, MultiCategory 1]...\n",
       "Path: .\n",
       "x: SeqItemList (500 items)\n",
       "[Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50)]...\n",
       "Path: .;\n",
       "\n",
       "Test: None"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, ['0', '1'])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check classes\n",
    "data.c, data.classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "64"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_dl.batch_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "### Check data points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 217,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD4w8Ny6l4h/wCEMjg8RavpkXi34ReJviPYjS9ZuY5vDeteHv7ej0f7DcmRrmSC0tfD1taQR3kt0beG7vfIaKWcSp5F/wANpfGDwt/xKvgLN/wq/wAPP++uPB/g3XtWl0qe/wD4dTaDUr27xeptgMdwpV4JLS3li8uWJZAUV9NwtlOWZnisTDF0Y1Iw5bRklKOtXGxu4tOLfLRp2lKMpJx5oyjKU5VP2wP+G0vjB4k/dfGub/hY8EnyX8HjDXtWX+0oT8zxXj2N7bSXm+WLTpDLMzzD+xtNiWRYLZYSf8NmfESD/QtL+HXw6s9Ktfm0bw/D8PtPawspF+SGeaGSNhq88Nu1xbxTar9ukjS8uJFYTyeeCivsf9U+GtlhIJfypOMV5qEZRhF9Lxgnb3bqOgHrvhC+8V/Ef9lTXP2vPiV471fXvEei/wBqxNaXc0cMN9aWV94XsZLOWe3SO9SC9i8X6018IbmJr2aZZ5mebfJJyOi/GDUvEHwG8SftK6t4U0ibxZoPi7RPDmoyuLk2HiGyv7XVbq3iv9OM32JoNOl0TTTZ2cMMNkqWyRz29xGqopRX5tl+EwtetjoygrU8wp0IJLlUaMnHmppR5VyO+1t7NNNRcQ/RX9n/AP4N/P2OP2kfgN4J/aJ8d/Ej4i2euePvCOm+I9Zs/D91o9jYQXd9ax3M0dtbR6bst4FeVgkS/KiBVHAooor/ADY4j8ePGTAcRY3C4bPcRCnTrVoQipQtGMK1WEYr909IxjGK1eiWr3Yf/9k=\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAHBJREFUKJFjdGEM+b/z2QUGdykDBgYGBgZcbBgfBtylDLDykWlsAJuZ2PQQcge6+YwujCH/8SnABYgxHFkO2cPIYuieIMdcdykDBsa/z1X+oytENxjdEcR4EJ8+YgMMmztweRgeI6SEBK0AOUkTJgYA3Nhz1DR798MAAAAASUVORK5CYII=\n",
      "text/plain": [
       "Image (1, 4, 50)"
      ]
     },
     "execution_count": 217,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 2\n",
    "data.x[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 218,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiCategory 0"
      ]
     },
     "execution_count": 218,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.y[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 219,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD4C+Nvxp8V/s9eMrL4bfCrTdI0vw5P4R8P+JtP0y00uOCbTdU1PR7PWIrqK/h2X8k+n3d7J9huJ7maeCENA0ksM91HOfBL45fET46eMr34fa/caRpeh2/hHxB4pu/D+ieGNPj0e+1rRtHvNUtL9tJkgk06CdjYWlnM9tbQme0SWJ9xuLl5iivpv7Jyz/UX+0XRi6/subnavLm57czbbvK32mua/vcylq/2w679kmSf9t7/AIWbY/tCW+kX2lfCr4Ra18QfD2i6J4T0rSIZLvTfJ8vTGksbWKe30uU3E8k1paS26vNPLcApPI8zeRf8NrfGeT/iY3K6QuuTfNqfi3TbA6ZrGqSR/vbOe7vLB4Jbie3vcagJnYyXF3HBJeNdi2tkhKKMlynLMZxhnOArUYyo4f6p7KFvdh7XD1KlXlV7fvJxjOd+ZSlFNq6uB1/xJ+NPiv4I+Dfh/wCJPgbpukeFtN+IfhGbxHfeFk0uPVbDSbtdY1LSLiOwbVvtdzbwXlrpdut5EZmW8Rnhn8y28u3TkP8AhtH4rP8A6be+HPCOpah/Dd+IvC8Gr20Gfkf7Ppt+J9Nst0EOnWo+zWsXlW2kWUMXlIJhMUV7OQcPZHjssVbEYeM5udVOTTcmo168I80uZSlaMIxvKTbUVzOTu5B1/wC0B4y+BP7PXx58bfALRv2Kvh1r1n4H8Xal4ftdc8Qa14o+36jHZXUlstzc/ZdZgg8+QRh38mGKPezbI0XCgooo4cyHLsfw7gsViPaSqVKNGcn7fEaynRpTk9MQlrKUnoktdElZIP/Z\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAHRJREFUKJFjdGEM+c+ABHY+u8DgLmWAwYbxGRgYUMSw6SMkDhPDZj42u3GpRwaMMI8gOxLdEHTHIxsMA8h8bA7HZjY2gK4Wn3nIcox/n6v8JyYkCTkAl6MIhTi+WMZmFq4YZnRhDPlPDUeiyxHjOGLNI0YfAGvwc9QnBcXaAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "Image (1, 4, 50)"
      ]
     },
     "execution_count": 219,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 3\n",
    "data.x[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 220,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiCategory 1"
      ]
     },
     "execution_count": 220,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.y[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 221,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjgAAAETCAYAAAA79nyeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAACIFJREFUeJzt3TGrHFUYx+H3XFMY1BQ2SrCNjYIprOxDPoGNWNoofoAUKf0K9oKFgqCVQshXsEgKm6SyMIiNiIooMWNhBLl3Vsa7Z/bM/vM8VVju7pzduzn5MfvupE3TVAAASU5GLwAAoDeBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBw2paa8+31r5orf3aWvu2tfbW6DUBWewz7HJh9AKI9mFV/VFVL1TV1ar6srV2d5qmb8YuCwhin2FWcyVj1tBae6aqfqyqV6dpuvf4to+r6rtpmm4MXRwQwT7Df/ERFWt5uar+/GfTeexuVb0yaD1AHvsMOwkc1vJsVf106rafquq5AWsBMtln2EngsJZfqurSqdsuVdXPA9YCZLLPsJPAYS33qupCa+3Kv257raoM/gG92GfYyZAxq2mtfVpVU1W9U39/u+GrqnrDtxuAXuwz7OIMDmt6r6ouVtUPVfVJVb1r0wE6s88wyxkcACCOMzgAQByBAwDEETgAQByBAwDEETgAQJxh/5v4tZM3N/v1rVsP7py57frlqwNWMm/E+vY55tx958w93j73Xdsh1rb2MXq/l24/+qyd+84r2cpes6W/Q/scd0t74RKH2H96Pn7q72bEXuMMDgAQR+AAAHEEDgAQZ9iVjOc+F+/9Gd1WZlXm9F7HVj6P3co65ixd25afw1KjPts/lhmcQ7w+p39uS++1Q+y1Sx7fDMrxvSeWvK97m1vbyYv3zeAAAE8egQMAxBE4AEAcgQMAxBk2ZPzo+ytnDtz7QktbHjgbofeQ25aG/NZey5ae65wRg39zjmXIeM5WfsfHNmS66+cO/fjJtvzajboIpQv9AQBPJIEDAMQROABAHIEDAMTZ1JWMD6HnFTeX3ve86+h9jK0Mpe2y9tVVRz3/rQzKH+J/s15yddFDW/qFhjlbHu6cs5V19HaI38OT9NrN2cr7eukxDRkDAE8kgQMAxBE4AEAcgQMAxBk2ZAwAsBZncACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgj
cACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcACAOAIHAIgjcFhFa+391trXrbXfW2sfjV4PkMlewy4XRi+AWA+q6oOqul5VFwevBchlr2GWwGEV0zR9XlXVWnu9ql4avBwglL2GXXxEBQDEETgAQByBAwDEETgAQBxDxqyitXah/n5/PVVVT7XWnq6qh9M0PRy7MiCJvYZdnMFhLTer6requlFVbz/+882hKwIS2WuY1aZpGr0GAICunMEBAOIIHAAgjsABAOIIHAAgzrCviV87efPMdPOtB3fO/Nz1y1dXXcfSY8793Jze993H6eMuffwRv4d9LVlz6vPa0jpuP/qsHWI9/8ej76+c2Wv2eQ23sk8tdWzP9RB6P68le+3axzzUcbdiyV7jDA4AEEfgAABxBA4AEEfgAABxhl3JuPfgX6qeA2LHOGy2z5rPO2S9zzoO8RovGTjdZ6B8n8fb4pDx3BcaettnCPi0Q/yd7P2liSVrNmS7bVt5nXp+ocEZHAAgjsABAOIIHAAgjsABAOIMGzJeOvg34oqT++g5FPt/7ntevQf1Rl0Zuufg9dJ19HZsz2vuGCcv3j/aIeOtDFkm6L0PzBl1teCla+l5zN7/hhzbgPZ59xpncACAOAIHAIgjcACAOAIHAIiz+SHjOb0H05Y8/paGDbcyGHtsg2q9HWKg+tgcy5WMt/zeHXXF332s/eWNQxhx1fgtvw936XnV7jk9r5ruDA4AEEfgAABxBA4AEEfgAABxhg0ZAwCsxRkcACCOwAEA4ggcACCOwAEA4ggcACCOwAEA4ggcACCOwAEA4ggcACCOwAEA4ggcACCOwAEA4ggcACCOwAEA4ggcACCOwAEA4ggcACCOwAEA4ggcACCOwAEA4ggcACCOwAEA4ggcACDOXxoNqLkuqPO9AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 576x576 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "data.show_batch(rows=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## Setup fastai ResNet18 architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "test_learn = create_cnn(data, models.resnet18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BCEWithLogitsLoss()"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_learn.loss_func.func"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#test_learn.model[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): AdaptiveConcatPool2d(\n",
       "    (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "    (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "  )\n",
       "  (1): Lambda()\n",
       "  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Dropout(p=0.25)\n",
       "  (4): Linear(in_features=1024, out_features=512, bias=True)\n",
       "  (5): ReLU(inplace)\n",
       "  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (7): Dropout(p=0.5)\n",
       "  (8): Linear(in_features=512, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_learn.model[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def ExpandInput(n_channels=3):\n",
    "    \"Layer that repeats the size-1 channel axis of a 4D batch `n_channels` times; all other axes keep their size.\"\n",
    "    # -1 tells `Tensor.expand` to leave a dimension unchanged, so the layer no longer hard-codes the 4x50 input size.\n",
    "    # `expand` returns a broadcast view (no copy), which is enough to feed a 1-channel input to a 3-channel stem.\n",
    "    return Lambda(lambda x: x.expand(-1, n_channels, -1, -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "EI = ExpandInput()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Test ExpandInput layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1, 4, 50])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tt = torch.rand((64,1,4,50)); tt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 3, 4, 50])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tt.expand(-1, 3, 4, 50).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDzK9/bD+K3iz9rzwX8GPHmi+Etd0fVIvtGpm98I2cMt7a32lWy3thcG2jiW4hntWs7JpJVe4FrpdpEk6A3JuO81678XNomreKfFHj7U/EFtoPwv+JEejeHdbgtZNLgu9H0Sy1SLUXtkhRbq7uLy6uJrqS581Z5JWkKiTDgooA0fh5qcXhX4R/tDeJ00HSb7xJ8N/jVPZ6N4p1DSoXvbySz8SrJa3l2FVYZ7mI2+I5fLXylnlWMICoX0D9ohfBPhL9mNdT0n4U+HRYW3xTs/B0eh+TOlvHpc15b2TRRSJKtzbJ9lub+AW8E0dt5OpXcXkmOUoCigBPjQfiV8NP+CgHgn9jDRvjx4vudK+LmqeHdV8TeML65tW1+C81SDVvPmtrhbdYkaKTSbCW3LxObZraNYikSJGvjf7POn2HxR+Cul/tb/Euzj17WrHQ59Vj8P6ku7R2uYNB8HTuhtxhlsppr5XbTEddPjWytoYLaCASwylFAHs3gHxvrlt4E0W3jubgLHpFsqiLU7qJQBEo4SOVUQf7KqFHQADiiiigD/9k=\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAsVJREFUKJEFwV8oAwEcB/Dv7U+0W1v+LPJAkUeipMyfyLU8SK3wsNWSrtTUrTzyciip5WWd2nlktFopYixKvJiNpfyZmbiIsYd1U+uk9PP5YH19ndLpNFmtVvr8/CSHw0E8z9P9/T2ZzWZKpVLU3NxM39/fdHZ2RqIo0urqKvE8TxzHUSQSoYmJCWptbaXOzk6KxWI0Pz9PLMtSNpuliooK2tjYoKurK3I6nRSNRqm2tpbe39+pvr6eVFWlh4cH6u7upng8Tjs7O/T09ERer5cKhQINDw9TLpej6elpkmWZLBYLLSwsEM/z5HA4yO120/7+PhlqamrAMAwKhQKKxSIMBgP6+vqQzWYhyzJmZ2dxeXmJk5MTRCIRqKoKnU6HyclJGI1G5PN5vL6+QhRF/P39IRwOQxAEMAwDVVURDoexubmJm5sbLC0toVgsYnt7G0NDQ+jv70c6nYbL5cLMzAx+f3/x/PwMTdMwMjICj8eD3d1dmM1mWK1WeDwe9Pb2IpVKIZPJ4O7uDn6/H9fX19C3t7eLo6OjsNls8Hq9ODo6gtvtxs/PDwRBQDAYBBFhZWUFFxcX+Pr6QiAQgNFoxOLiIs7Pz8EwDFiWRVdXF1iWRU9PD9bW1pDL5dDU1ASXywW9Xg+TyYTl5WUMDg6iqqoKPM+D53nY7Xbs7e2hVCqB4zgEg0FomobDw0PMzc1BVVUkEgmIoojx8XHk83nIsozb21scHBzAbrfDkEwmIUkSQqEQqqur0djYiLq6OpSXlyMUCmFsbAyKoqCtrQ1bW1uYmppCNBqFz+dDJpPB6ekpBgYGYDAYoNPpwHEckskkTCYTFEWBpmkIBAKorKxEQ0MDfD4fYrEYJElCqVSC3+9HS0sLXl5e4HQ6UVZWhkQigY6ODjw+PsJms0GSJLy9vUFRFAiCAEEQcHx8jHg8DovFgo+PD/wDZ+hmByIU2BsAAAAASUVORK5CYII=\n",
      "text/plain": [
       "Image (3, 4, 50)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Image(EI(tt)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "new_model = nn.Sequential(ExpandInput(), test_learn.model) # insert ExpandInput layer at the beginning of the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 3, 4, 50])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_model[0](tt).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Targets are multi-hot float tensors (MultiCategory labels, BCEWithLogits loss), so plain `accuracy`\n",
    "# (argmax == target) fails with a Long/Float dtype mismatch; use the thresholded multi-label metric instead.\n",
    "new_learn = Learner(data, new_model, metrics=accuracy_thresh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#[p.shape for p in new_model.parameters()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='0' class='' max='5', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      0.00% [0/5 00:00<00:00]\n",
       "    </div>\n",
       "    \n",
       "<table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy</th>\n",
       "  </tr>\n",
       "</table>\n",
       "\n",
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='0' class='progress-bar-interrupted' max='8', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      Interrupted\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "RuntimeError",
     "evalue": "Expected object of scalar type Long but got scalar type Float for argument #2 'other'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-99-c937c2aa1812>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnew_learn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_one_cycle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/Downloads/fastai/fastai/train.py\u001b[0m in \u001b[0;36mfit_one_cycle\u001b[0;34m(learn, cyc_len, max_lr, moms, div_factor, pct_start, wd, callbacks, **kwargs)\u001b[0m\n\u001b[1;32m     19\u001b[0m     callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor,\n\u001b[1;32m     20\u001b[0m                                         pct_start=pct_start, **kwargs))\n\u001b[0;32m---> 21\u001b[0;31m     \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcyc_len\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_lr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwd\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mlr_find\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mLearner\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart_lr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1e-7\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_lr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_it\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstop_div\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/basic_train.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, epochs, lr, wd, callbacks)\u001b[0m\n\u001b[1;32m    164\u001b[0m         \u001b[0mcallbacks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_fns\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlistify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    165\u001b[0m         fit(epochs, self.model, self.loss_func, opt=self.opt, data=self.data, metrics=self.metrics,\n\u001b[0;32m--> 166\u001b[0;31m             callbacks=self.callbacks+callbacks)\n\u001b[0m\u001b[1;32m    167\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    168\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mcreate_opt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwd\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m->\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/basic_train.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(epochs, model, loss_func, opt, data, callbacks, metrics)\u001b[0m\n\u001b[1;32m     92\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     93\u001b[0m         \u001b[0mexception\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     95\u001b[0m     \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_train_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexception\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     96\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/basic_train.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(epochs, model, loss_func, opt, data, callbacks, metrics)\u001b[0m\n\u001b[1;32m     87\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'valid_dl'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalid_dl\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalid_ds\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     88\u001b[0m                 val_loss = validate(model, data.valid_dl, loss_func=loss_func,\n\u001b[0;32m---> 89\u001b[0;31m                                        cb_handler=cb_handler, pbar=pbar)\n\u001b[0m\u001b[1;32m     90\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mval_loss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     91\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_epoch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/basic_train.py\u001b[0m in \u001b[0;36mvalidate\u001b[0;34m(model, dl, loss_func, cb_handler, pbar, average, n_batch)\u001b[0m\n\u001b[1;32m     52\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_listy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0myb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0myb\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     53\u001b[0m             \u001b[0mnums\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myb\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m             \u001b[0;32mif\u001b[0m \u001b[0mcb_handler\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_losses\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     55\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mn_batch\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnums\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m>=\u001b[0m\u001b[0mn_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     56\u001b[0m         \u001b[0mnums\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnums\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/callback.py\u001b[0m in \u001b[0;36mon_batch_end\u001b[0;34m(self, loss)\u001b[0m\n\u001b[1;32m    237\u001b[0m         \u001b[0;34m\"Handle end of processing one batch with `loss`.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    238\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'last_loss'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 239\u001b[0;31m         \u001b[0mstop\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'batch_end'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    240\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    241\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'iteration'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/callback.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, cb_name, call_mets, **kwargs)\u001b[0m\n\u001b[1;32m    185\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcb_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcall_mets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m->\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    186\u001b[0m         \u001b[0;34m\"Call through to all of the `CallbakHandler` functions.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m         \u001b[0;32mif\u001b[0m \u001b[0mcall_mets\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf'on_{cb_name}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmet\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    188\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf'on_{cb_name}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/callback.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    185\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcb_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcall_mets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m->\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    186\u001b[0m         \u001b[0;34m\"Call through to all of the `CallbakHandler` functions.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m         \u001b[0;32mif\u001b[0m \u001b[0mcall_mets\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf'on_{cb_name}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmet\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    188\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf'on_{cb_name}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/callback.py\u001b[0m in \u001b[0;36mon_batch_end\u001b[0;34m(self, last_output, last_target, **kwargs)\u001b[0m\n\u001b[1;32m    272\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_listy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlast_target\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlast_target\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlast_target\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    273\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcount\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mlast_target\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mval\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mlast_target\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlast_output\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mlast_target\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    275\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    276\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mon_epoch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Downloads/fastai/fastai/metrics.py\u001b[0m in \u001b[0;36maccuracy\u001b[0;34m(input, targs)\u001b[0m\n\u001b[1;32m     37\u001b[0m     \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     38\u001b[0m     \u001b[0mtargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 39\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mtargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     41\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0merror_rate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m->\u001b[0m\u001b[0mRank0Tensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Expected object of scalar type Long but got scalar type Float for argument #2 'other'"
     ]
    }
   ],
   "source": [
    "new_learn.fit_one_cycle(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "new_learn.recorder.plot_losses()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "test_learn.get_predicts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "del test_learn\n",
    "del new_model\n",
    "del new_learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup custom model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Flatten():\n",
    "    \"Return a `Lambda` layer that views `x` as `(x.size(0), -1)`: the first dimension is kept, the rest is flattened.\"\n",
    "    def _flatten(x): return x.view((x.size(0), -1))\n",
    "    return Lambda(_flatten)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ResizeInput():\n",
    "    \"Return a `Lambda` layer that views `x` as `(-1, H, W)`, where `(H, W)` are the last two dimensions of `x`.\"\n",
    "    def _resize(x):\n",
    "        trailing = x.size()[-2:]  # keep the last two dimensions as they are\n",
    "        return x.view((-1,) + trailing)\n",
    "    return Lambda(_resize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "#def ResizeOutput(): return Lambda(lambda x: x.view(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "drop_p = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inplace=True seems to generate problems?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1D CNN: conv -> max-pool -> flatten -> 16-unit hidden layer -> 2 output values.\n",
    "# Shape comments are for a batch of 64 inputs of size 4 x 50 (as reported by `learn.summary()`).\n",
    "net = nn.Sequential(\n",
    "    ResizeInput(),           # [64, 4, 50]\n",
    "    nn.Conv1d(4, 32, 12),    # [64, 32, 39]\n",
    "    nn.MaxPool1d(4),         # [64, 32, 9]\n",
    "    Flatten(),               # [64, 288] = 32 * 9\n",
    "    nn.Dropout(p=drop_p),\n",
    "    nn.Linear(288, 16),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(p=drop_p),\n",
    "    nn.Linear(16, 2),        # no dropout after the output layer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Lambda()\n",
       "  (1): Conv1d(4, 32, kernel_size=(12,), stride=(1,))\n",
       "  (2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)\n",
       "  (3): Lambda()\n",
       "  (4): Dropout(p=0.2)\n",
       "  (5): Linear(in_features=288, out_features=16, bias=True)\n",
       "  (6): ReLU()\n",
       "  (7): Dropout(p=0.2)\n",
       "  (8): Linear(in_features=16, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "apply_init(net, nn.init.kaiming_normal_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run sequence through network for testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'data' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-25-37a889cb3037>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_ds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_ds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'data' is not defined"
     ]
    }
   ],
   "source": [
    "data.train_ds[0], data.train_ds[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net(data.train_ds[0][0].data), net(data.train_ds[3][0].data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learner setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## TBLogger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "from io import BytesIO\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# `Logger` is adapted from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514\n",
    "# (\"Simple example on how to log scalars and images to tensorboard without tensor ops.\",\n",
    "#  License: Copyleft, author: Michael Gygli).\n",
    "\n",
    "class Logger(object):\n",
    "    \"\"\"Logging in tensorboard without tensorflow ops.\n",
    "\n",
    "    Every `log_*` method writes one summary and flushes the writer, so the\n",
    "    values reach the event file without waiting for the writer to be closed.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, log_dir):\n",
    "        \"\"\"Creates a summary writer logging to log_dir.\"\"\"\n",
    "        self.writer = tf.summary.FileWriter(log_dir)\n",
    "\n",
    "    def log_scalar(self, tag, value, step):\n",
    "        \"\"\"Log a scalar variable.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        tag : str\n",
    "            Name of the scalar\n",
    "        value\n",
    "            Anything `float()` accepts (Python number, 0-dim tensor, ...)\n",
    "        step : int\n",
    "            training iteration\n",
    "        \"\"\"\n",
    "        # `simple_value` is a protobuf float field, so convert explicitly (e.g. from a tensor).\n",
    "        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=float(value))])\n",
    "        self.writer.add_summary(summary, step)\n",
    "        self.writer.flush()\n",
    "\n",
    "    def log_images(self, tag, images, step):\n",
    "        \"\"\"Logs a list of images (each one is written as `<tag>/<index>`).\"\"\"\n",
    "        im_summaries = []\n",
    "        for nr, img in enumerate(images):\n",
    "            # PNG data is binary, so it has to be written to a bytes buffer.\n",
    "            # NOTE(review): `plt` is assumed to come from the notebook's earlier imports - confirm.\n",
    "            buf = BytesIO()\n",
    "            plt.imsave(buf, img, format='png')\n",
    "\n",
    "            # Create an Image object\n",
    "            img_sum = tf.Summary.Image(encoded_image_string=buf.getvalue(),\n",
    "                                       height=img.shape[0],\n",
    "                                       width=img.shape[1])\n",
    "            # Create a Summary value\n",
    "            im_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, nr), image=img_sum))\n",
    "\n",
    "        # Create and write Summary\n",
    "        summary = tf.Summary(value=im_summaries)\n",
    "        self.writer.add_summary(summary, step)\n",
    "        self.writer.flush()\n",
    "\n",
    "    def log_histogram(self, tag, values, step, bins=1000):\n",
    "        \"\"\"Logs the histogram of a list/vector of values.\"\"\"\n",
    "        # Convert to a numpy array\n",
    "        values = np.array(values)\n",
    "\n",
    "        # Create histogram using numpy\n",
    "        counts, bin_edges = np.histogram(values, bins=bins)\n",
    "\n",
    "        # Fill fields of histogram proto\n",
    "        hist = tf.HistogramProto()\n",
    "        hist.min = float(np.min(values))\n",
    "        hist.max = float(np.max(values))\n",
    "        hist.num = int(np.prod(values.shape))\n",
    "        hist.sum = float(np.sum(values))\n",
    "        hist.sum_squares = float(np.sum(values**2))\n",
    "\n",
    "        # Requires equal number as bins, where the first goes from -DBL_MAX to bin_edges[1]\n",
    "        # See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto#L30\n",
    "        # Thus, we drop the start of the first bin\n",
    "        bin_edges = bin_edges[1:]\n",
    "\n",
    "        # Add bin edges and counts\n",
    "        for edge in bin_edges:\n",
    "            hist.bucket_limit.append(edge)\n",
    "        for c in counts:\n",
    "            hist.bucket.append(c)\n",
    "\n",
    "        # Create and write Summary\n",
    "        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])\n",
    "        self.writer.add_summary(summary, step)\n",
    "        self.writer.flush()\n",
    "\n",
    "# A `Callback` that saves tracked metrics into a log file for Tensorboard.\n",
    "# Based on https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514\n",
    "# and devforfu: https://nbviewer.jupyter.org/gist/devforfu/ea0b3fcfe194dad323c3762492b05cae\n",
    "# Contribution from MicPie\n",
    "\n",
    "__all__ = ['TBLogger']\n",
    "\n",
    "@dataclass\n",
    "class TBLogger(LearnerCallback):\n",
    "    \"A `LearnerCallback` that saves history of metrics while training `learn` into log files for Tensorboard.\"\n",
    "\n",
    "    log_dir:str = 'logs'\n",
    "    log_name:str = 'data'\n",
    "    log_scalar:bool = True # log scalar values for Tensorboard scalar summary\n",
    "    log_hist:bool = True # log values and gradients of the parameters for Tensorboard histogram summary\n",
    "    log_img:bool = False # log values for Tensorboard image summary (currently not acted upon)\n",
    "\n",
    "    def __post_init__(self):\n",
    "        \"Create `learn.path/log_dir`, open a `Logger` on `log_dir/log_name` and reset the counters.\"\n",
    "        super().__post_init__()\n",
    "        self.path = self.learn.path\n",
    "        (self.path/self.log_dir).mkdir(parents=True, exist_ok=True) # setup logs directory\n",
    "        self.Log = Logger(str(self.path/self.log_dir/self.log_name))\n",
    "        self.epoch = 0\n",
    "        self.batch = 0\n",
    "        self.log_grads = {} # '<param/name>/grad' -> gradients summed over the batches of the current epoch\n",
    "\n",
    "    def on_backward_end(self, **kwargs:Any)->None:\n",
    "        \"Add the gradients of the current batch to the per-epoch sums in `self.log_grads`.\"\n",
    "        # Nothing is returned: the fastai version used here ignores the result of this event.\n",
    "        self.batch += 1\n",
    "        if not self.log_hist: return\n",
    "\n",
    "        # Use the learner this callback is attached to, not a global `learn`.\n",
    "        for tag, value in self.learn.model.named_parameters():\n",
    "            if value.grad is None: continue # parameter got no gradient in this backward pass\n",
    "            # `.numpy()` shares memory with a CPU tensor; copy so that later zeroing or\n",
    "            # re-accumulation of the gradient buffer cannot change what was stored.\n",
    "            grad = value.grad.detach().cpu().numpy().copy()\n",
    "            tag_grad = tag.replace('.', '/')+'/grad'\n",
    "            prev = self.log_grads.get(tag_grad)\n",
    "            self.log_grads[tag_grad] = grad if prev is None else prev + grad\n",
    "\n",
    "    def on_epoch_end(self, epoch:int, smooth_loss:Tensor, last_metrics:MetricsList, **kwargs:Any) -> bool:\n",
    "        \"Write the scalars and the parameter/gradient histograms of this epoch, then reset the per-epoch state.\"\n",
    "        last_metrics = ifnone(last_metrics, [])\n",
    "        tr_info = {name: stat for name, stat in zip(self.learn.recorder.names, [epoch, smooth_loss] + last_metrics)}\n",
    "        self.epoch = tr_info['epoch']\n",
    "        self.batch = 0 # reset batch count\n",
    "        step = self.epoch+1 # summaries are logged against a 1-based epoch number\n",
    "\n",
    "        if self.log_scalar:\n",
    "            for tag, value in tr_info.items():\n",
    "                if tag == 'epoch': continue\n",
    "                self.Log.log_scalar(tag, value, step)\n",
    "\n",
    "        if self.log_hist:\n",
    "            for tag, value in self.learn.model.named_parameters():\n",
    "                tag = tag.replace('.', '/')\n",
    "                self.Log.log_histogram(tag, value.data.cpu().numpy(), step)\n",
    "\n",
    "                tag_grad = tag+'/grad'\n",
    "                # Parameters without a recorded gradient (see `on_backward_end`) are skipped.\n",
    "                if tag_grad in self.log_grads:\n",
    "                    self.Log.log_histogram(tag_grad, self.log_grads[tag_grad], step)\n",
    "\n",
    "        # Start a fresh gradient sum so each epoch's histogram only covers that epoch.\n",
    "        self.log_grads = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## fastai learner setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 208,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#learn = Learner(data, net, loss_func=nn.functional.cross_entropy, metrics=accuracy)#, callback_fns=[TBLogger])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def accuracy_float(input:Tensor, targs:LongTensor)->Rank0Tensor:\n",
    "    \"\"\"Compute accuracy with `targs` when `input` is bs * n_classes.\n",
    "\n",
    "    Unlike the plain `accuracy` metric, `targs` may be a float tensor. It can hold\n",
    "    either class indices (shape `bs` or `bs * 1`) or one one-hot row per sample with\n",
    "    the same shape as `input`; one-hot rows are reduced to class indices first.\n",
    "    \"\"\"\n",
    "    n = targs.shape[0]\n",
    "    # One-hot targets must be turned into class indices before the comparison.\n",
    "    # Otherwise the predicted index (bs * 1) is broadcast against every target column,\n",
    "    # which for two classes matches exactly one column per row, i.e. always 0.5.\n",
    "    if targs.dim() > 1 and targs.shape == input.shape and input.shape[-1] > 1:\n",
    "        targs = targs.argmax(dim=-1)\n",
    "    input = input.argmax(dim=-1).view(n,-1)\n",
    "    targs = targs.type(torch.long).view(n,-1) # convert to torch.long = int64 --> difference to accuracy metric!\n",
    "    return (input==targs).float().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "learn = Learner(data, net, loss_func=BCEWithLogitsFlat(), metrics=accuracy_float)#, callback_fns=[TBLogger])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======================================================================\n",
      "Layer (type)         Output Shape         Param #    Trainable \n",
      "======================================================================\n",
      "Lambda               [64, 4, 50]          0          False     \n",
      "______________________________________________________________________\n",
      "Conv1d               [64, 32, 39]         1568       True      \n",
      "______________________________________________________________________\n",
      "MaxPool1d            [64, 32, 9]          0          False     \n",
      "______________________________________________________________________\n",
      "Lambda               [64, 288]            0          False     \n",
      "______________________________________________________________________\n",
      "Dropout              [64, 288]            0          False     \n",
      "______________________________________________________________________\n",
      "Linear               [64, 16]             4624       True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 16]             0          False     \n",
      "______________________________________________________________________\n",
      "Dropout              [64, 16]             0          False     \n",
      "______________________________________________________________________\n",
      "Linear               [64, 2]              34         True      \n",
      "______________________________________________________________________\n",
      "\n",
      "Total params:  6226\n",
      "Total trainable params:  6226\n",
      "Total non-trainable params:  0\n"
     ]
    }
   ],
   "source": [
    "learn.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## fastai training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.lr_find()\n",
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:14 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy_float</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>0.766329</th>\n",
       "    <th>0.692531</th>\n",
       "    <th>0.500000</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>0.721194</th>\n",
       "    <th>0.693430</th>\n",
       "    <th>0.500000</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>0.707417</th>\n",
       "    <th>0.693202</th>\n",
       "    <th>0.500000</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>0.701203</th>\n",
       "    <th>0.692725</th>\n",
       "    <th>0.500000</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>0.697864</th>\n",
       "    <th>0.692572</th>\n",
       "    <th>0.500000</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>0.695997</th>\n",
       "    <th>0.692562</th>\n",
       "    <th>0.500000</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>7</th>\n",
       "    <th>0.694890</th>\n",
       "    <th>0.692712</th>\n",
       "    <th>0.500000</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>8</th>\n",
       "    <th>0.694252</th>\n",
       "    <th>0.692631</th>\n",
       "    <th>0.500000</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>9</th>\n",
       "    <th>0.693816</th>\n",
       "    <th>0.692724</th>\n",
       "    <th>0.500000</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>10</th>\n",
       "    <th>0.693551</th>\n",
       "    <th>0.692713</th>\n",
       "    <th>0.500000</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(10, max_lr=1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot_losses()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADmlJREFUeJzt3H+s3fVdx/Hna+26OGAZGxcDpXg70+jYRmAcK4mGoAmsM7HVsGk1cXSK5Y81uIV/mBpRiH84f5AYcUmJ3ZiZlolOL8hWgUhcVLCnhgGlMirCelcCd1RgcRnY8faPe4qH09uec2/vvafd5/lIbnq/n+/ne8/nfHN43u/93nNJVSFJasObxr0ASdLyMfqS1BCjL0kNMfqS1BCjL0kNMfqS1BCjL0kNMfqS1BCjL0kNWTnuBQw666yzanJyctzLkKRTyp49e75ZVRPD5p100Z+cnKTb7Y57GZJ0SknyzCjzvL0jSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUkJGin2RDkieS7E9ywxz7tySZSfJw7+Oa3vhFSf41yd4kjyT5+cV+ApKk0a0cNiHJCuBW4ApgGtidZKqqHh+YekdVbRsY+zbwkap6Msm5wJ4ku6rqxcVYvCRpfka50l8P7K+qp6rqVWAnsGmUL15VX6uqJ3ufHwSeByYWulhJ0okZJfqrgQN929O9sUFX9W7h3JlkzeDOJOuBVcB/LmilkqQTNkr0M8dYDWzfBUxW1YXAfcDtb/gCyTnAnwMfrarXjnqAZGuSbpLuzMzMaCuXJM3bKNGfBvqv3M8DDvZPqKoXquqV3uZtwCVH9iV5G/D3wG9W1YNzPUBVba+qTlV1Jia8+yNJS2WU6O8G1iVZm2QVsBmY6p/Qu5I/YiOwrze+Cvgi8Lmq+qvFWbIkaaGGvnunqg4n2QbsAlYAO6pqb5KbgG5VTQHXJdkIHAYOAVt6h/8ccBnwziRHxrZU1cOL+zQkSaNI1eDt+fHqdDrV7XbHvQxJOqUk2VNVnWHz/ItcSWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWrISNFPsiHJE0n2J7lhjv1bkswkebj3cU3fvi8neTHJ3Yu5cEnS/K0cNiHJCuBW4ApgGtidZKqqHh+YekdVbZvjS/w+8Fbg2hNdrCTpxIxypb8e2F9VT1XVq8BOYNOoD1BV9wPfWuD6JEmLaJTorwYO9G1P98YGXZXkkSR3JlmzKKuTJC2qUaKfOcZqYPsuYLKqLgTuA26fzyKSbE3STdKdmZmZz6GSpHkYJfrTQP+V+3nAwf4JVfVCVb3S27wNuGQ+i6iq7VXVqarOxMTEfA6VJM3DKNHfDaxLsjbJKmAzMNU/Ick5fZsbgX2Lt0RJ0mIZ+u6dqjqcZBuwC1gB7KiqvUluArpVNQVcl2QjcBg4BGw5cnySrwA/DJyeZBr4laratfhPRZI0TKoGb8+PV6fTqW63O+5lSNIpJcmequoMm+df5EpSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDXE6EtSQ4y+JDVkpOgn2ZDkiST7k9wwx/4tSWaSPNz7uKZv39VJ
nux9XL2Yi5ckzc/KYROSrABuBa4ApoHdSaaq6vGBqXdU1baBY98B3Ah0gAL29I7970VZvSRpXka50l8P7K+qp6rqVWAnsGnEr/8B4N6qOtQL/b3AhoUtVZJ0ooZe6QOrgQN929PAj84x76oklwFfAz5RVQeOcezqBa51qN+5ay+PH3x5qb68JC2pC859Gzf+9HuW9DFGudLPHGM1sH0XMFlVFwL3AbfP41iSbE3STdKdmZkZYUmSpIUY5Up/GljTt30ecLB/QlW90Ld5G/B7fcdePnDsA4MPUFXbge0AnU7nqG8Ko1rq75CSdKob5Up/N7Auydokq4DNwFT/hCTn9G1uBPb1Pt8FXJnkzCRnAlf2xiRJYzD0Sr+qDifZxmysVwA7qmpvkpuAblVNAdcl2QgcBg4BW3rHHkpyM7PfOABuqqpDS/A8JEkjSNWC76YsiU6nU91ud9zLkKRTSpI9VdUZNs+/yJWkhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0ZekhowU/SQbkjyRZH+SG44z70NJKkmnt70qyWeSPJrkq0kuX6R1S5IWYOWwCUlWALcCVwDTwO4kU1X1+MC8M4DrgIf6hn8VoKrel+Rs4EtJfqSqXlusJyBJGt0oV/rrgf1V9VRVvQrsBDbNMe9m4FPAd/rGLgDuB6iq54EXgc4JrViStGCjRH81cKBve7o39rokFwNrqurugWO/CmxKsjLJWuASYM0JrFeSdAKG3t4BMsdYvb4zeRNwC7Bljnk7gHcDXeAZ4F+Aw0c9QLIV2Apw/vnnj7AkSdJCjHKlP80br87PAw72bZ8BvBd4IMnTwKXAVJJOVR2uqk9U1UVVtQl4O/Dk4ANU1faq6lRVZ2JiYqHPRZI0xCjR3w2sS7I2ySpgMzB1ZGdVvVRVZ1XVZFVNAg8CG6uqm+StSU4DSHIFcHjwF8CSpOUz9PZOVR1Osg3YBawAdlTV3iQ3Ad2qmjrO4WcDu5K8BnwD+KXFWLQkaWFGuadPVd0D3DMw9lvHmHt53+dPAz+08OVJkhaTf5ErSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUEKMvSQ0x+pLUkJGin2RDkieS7E9yw3HmfShJJen0tt+c5PYkjybZl+STi7VwSdL8DY1+khXArcAHgQuAX0hywRzzzgCuAx7qG/4w8Jaqeh9wCXBtkskTX7YkaSFGudJfD+yvqqeq6lVgJ7Bpjnk3A58CvtM3VsBpSVYC3we8Crx8YkuWJC3UKNFfDRzo257ujb0uycXAmqq6e+DYO4H/AZ4Fvg78QVUdWvhyJUknYpToZ46xen1n8ibgFuD6OeatB74LnAusBa5P8q6jHiDZmqSbpDszMzPSwiVJ8zdK9KeBNX3b5wEH+7bPAN4LPJDkaeBSYKr3y9xfBL5cVf9bVc8D/wx0Bh+gqrZXVaeqOhMTEwt7JpKkoUaJ/m5gXZK1SVYBm4GpIzur6qWqOquqJqtqEngQ2FhVXWZv6fxkZp3G7DeE/1j0ZyFJGsnQ6FfVYWAbsAvYB3yhqvYmuSnJxiGH3wqcDjzG7DePz1TVIye4ZknSAqWqhs9aRp1Op7rd7riXIUmnlCR7quqo2+eD/ItcSWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9
SWqI0Zekhhh9SWqI0Zekhhh9SWqI0Zekhhh9SWpIqmrca3iDJDPAM+NexxI5C/jmuBdxCvA8Dec5Gk1L5+kHqmpi2KSTLvrfy5J0q6oz7nWc7DxPw3mORuN5Opq3dySpIUZfkhpi9JfX9nEv4BTheRrOczQaz9MA7+lLUkO80pekhhj9JZTk6SSPJnk4Sbc39o4k9yZ5svfvmeNe53JLsiPJ80ke6xub87xk1h8n2Z/kkSTvH9/Kl88xztFvJ/lG7/X0cJKf6tv3yd45eiLJB8az6uWVZE2Sf0yyL8neJL/WG/e1dBxGf+n9RFVd1Pe2sRuA+6tqHXB/b7s1nwU2DIwd67x8EFjX+9gKfHqZ1jhun+XocwRwS+/1dFFV3QOQ5AJgM/Ce3jF/mmTFsq10fA4D11fVu4FLgY/1zoWvpeMw+stvE3B77/PbgZ8Z41rGoqr+CTg0MHys87IJ+FzNehB4e5Jzlmel43OMc3Qsm4CdVfVKVf0XsB9Yv2SLO0lU1bNV9e+9z78F7ANW42vpuIz+0irgH5LsSbK1N/b9VfUszL5ogbPHtrqTy7HOy2rgQN+86d5Yq7b1bk3s6Ls12Pw5SjIJXAw8hK+l4zL6S+vHqur9zP5Y+bEkl417QaegzDHW6lvOPg38IHAR8Czwh73xps9RktOBvwY+XlUvH2/qHGPNnKcjjP4SqqqDvX+fB77I7I/czx35kbL37/PjW+FJ5VjnZRpY0zfvPODgMq/tpFBVz1XVd6vqNeA2/v8WTrPnKMmbmQ3+56vqb3rDvpaOw+gvkSSnJTnjyOfAlcBjwBRwdW/a1cDfjWeFJ51jnZcp4CO9d15cCrx05Ef31gzcf/5ZZl9PMHuONid5S5K1zP6i8t+We33LLUmAPwP2VdUf9e3ytXQc/nHWEknyLmav7gFWAn9RVb+b5J3AF4Dzga8DH66qUX9h9z0hyV8ClzP7f0B8DrgR+FvmOC+9/7D/hNl3pXwb+GhVdcex7uV0jHN0ObO3dgp4Grj2SLSS/Abwy8y+o+XjVfWlZV/0Mkvy48BXgEeB13rDv87sfX1fS8dg9CWpId7ekaSGGH1JaojRl6SGGH1JaojRl6SGGH1JaojRl6SGGH1Jasj/AQ7Y/tMWNxljAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[0.5050, 0.4949],\n",
       "         [0.5050, 0.4949],\n",
       "         [0.5050, 0.4949],\n",
       "         ...,\n",
       "         [0.5050, 0.4949],\n",
       "         [0.5050, 0.4949],\n",
       "         [0.5050, 0.4949]]), tensor([[0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         ...,\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [1., 0.]])]"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.get_preds(DatasetType.Train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(MultiCategory 0, MultiCategory 1)"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds[0][1], data.train_ds[3][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#net(data.train_ds[0][0].data), net(data.train_ds[3][0].data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 0.0202, -0.0202]], grad_fn=<ThAddmmBackward>),\n",
       " tensor([[ 0.0202, -0.0202]], grad_fn=<ThAddmmBackward>))"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net(data.train_ds[0][0].data), net(data.train_ds[3][0].data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "preds, targs = learn.get_preds()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#targs.view(-1,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([0.4612, 0.5234]), tensor([1., 0.]))"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 0\n",
    "preds[i], targs[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((Image (1, 4, 50), MultiCategory 0), (Image (1, 4, 50), MultiCategory 1))"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds[0], data.train_ds[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((MultiCategory 0, tensor([1., 0.]), tensor([0.5242, 0.4852])),\n",
       " (MultiCategory 0, tensor([1., 0.]), tensor([0.5242, 0.4852])))"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.predict(data.train_ds[0][0]), learn.predict(data.train_ds[3][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#learn.save('v6')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#learn.load('v6')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "interpret = ClassificationInterpretation.from_learner(learn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "Expected object of scalar type Float but got scalar type Long for argument #2 'other'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-72-7cdf6701acab>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0minterpret\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfusion_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/Downloads/fastai/fastai/vision/learner.py\u001b[0m in \u001b[0;36mconfusion_matrix\u001b[0;34m(self, slice_size)\u001b[0m\n\u001b[1;32m    107\u001b[0m         \u001b[0;34m\"Confusion matrix as an `np.ndarray`.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    108\u001b[0m         \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 109\u001b[0;31m         \u001b[0;32mif\u001b[0m \u001b[0mslice_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpred_class\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m&\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    110\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    111\u001b[0m             \u001b[0mcm\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Expected object of scalar type Float but got scalar type Long for argument #2 'other'"
     ]
    }
   ],
   "source": [
    "interpret.confusion_matrix()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#interpret.plot_confusion_matrix()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def seq2array(seq:str)->List:\n",
    "    \"Return `List` object with np.array created from sequence string `seq`.\"\n",
    "    \n",
    "    int_enc = LabelEncoder() # setup class instance to encode the four different bases to integer values (1D)\n",
    "    one_hot_enc = OneHotEncoder(categories=[range(4)]) # setup one hot encoder to encode integer encoded classes (1D) to one hot encoded array (4D)\n",
    "    \n",
    "    enc = int_enc.fit_transform(list(seq)) # bases (ACGT) to int (0,1,2,3)\n",
    "    enc = np.array(enc).reshape(-1,1) # reshape to get rank 2 array (from rank 1 array)\n",
    "    enc = one_hot_enc.fit_transform(enc) # encoded integer encoded bases to sparse matrix (sparse matrix dtype)\n",
    "    enc = enc.toarray().T # export sparse matrix to np array\n",
    "    \n",
    "    return enc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0., 0., 0., 1., ..., 0., 1., 0., 0.],\n",
       "        [1., 1., 0., 0., ..., 1., 0., 1., 1.],\n",
       "        [0., 0., 1., 0., ..., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., ..., 0., 0., 0., 0.]]), (4, 50))"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# test open sequence image function\n",
    "test_arr = seq2array('CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGACACC'); test_arr, test_arr.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAA9CAYAAABWdClAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAAbhJREFUeJzt3cFNwzAYhmESMQTizp0lEBMwJRMgluDeO2KKhgmcguXEzpfnubZq3Ch5Val/3WlZljsAssy9FwBAe+IOEEjcAQKJO0AgcQcIJO4AgcQdIJC4AwQSd4BA970O/DK/Nf1p7Mf3V/Gx18fnloeqMvr6trD2nktan4ua816z7rXXI0PtPdz63v+8vk9/eZ5P7gCBxB0gkLgDBBJ3gEDiDhBI3AECTb3+rOP68/TvA48yanbGsUZuqx2hLBnhWtriWi+9Zutxwj3v0y3GJEvmh4tRSICzEneAQOIOEEjcAQKJO0CgbtMyNRuH7fmN9Jo9pxhG33yr9TpGmHAYYUrlltbXdMmRz0XrJoxyLmwcBnBi4g4QSNwBAok7QCBxBwgk7gCBDjUKeQR7bYpUs4Zao4yAHXl8rcboG9S1HpEd/f2OwigkwImJO0AgcQcIJO4AgcQdIJC4AwTqNgoJwHZ8cgcIJO4AgcQdIJC4AwQSd4BA4g4QSNwBAok7QCBxBwgk7gCBxB0gkLgDBBJ3gEDiDhBI3AECiTtAIHEHCCTuAIHEHSCQuAMEEneAQOIOEEjcAQL9Al/6fGlF2Jm2AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(test_arr)\n",
    "plt.axis('off');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0.],\n",
       "         [1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0.,\n",
       "          0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1.],\n",
       "         [0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1.,\n",
       "          1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1.,\n",
       "          1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 0.,\n",
       "          0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,\n",
       "          0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_ten = tensor(test_arr).view(1,4,50).type(torch.float); test_ten"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.9148, -0.4416]], grad_fn=<ThAddmmBackward>)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net(test_ten)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "#x = [seq2array(s) for s in seq_df['Sequences'].values]\n",
    "x = [tensor(seq2array(s)).view(1,4,50).type(torch.float) for s in seq_df['Sequences'].values]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.stack(x) # convert list to tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = tensor(targ).type(torch.float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2000, 2000)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(x), len(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 1.,\n",
       "           1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0.,\n",
       "           1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,\n",
       "           0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 0., 1.,\n",
       "           0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0.],\n",
       "          [1., 0., 0., 0., 0., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0.,\n",
       "           0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
       "           0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0.],\n",
       "          [0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "           0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]]]),\n",
       " tensor([0., 1.]))"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 2\n",
    "x[i], y[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "#train_ds = TensorDataset(*map(tensor, (x[:1500],y[:1500])))\n",
    "#valid_ds = TensorDataset(*map(tensor, (x[-500:],y[-500:])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = TensorDataset(x[:1500],y[:1500])\n",
    "valid_ds = TensorDataset(x[-500:],y[-500:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1500, 500)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_ds), len(valid_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)\n",
    "valid_dl = DataLoader(valid_ds, batch_size=bs*2, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From https://github.com/fastai/fastai_docs/blob/master/dev_nb/mnist_sample.py\n",
    "\n",
    "def simple_loss_batch(model, loss_func, xb, yb, opt=None):\n",
    "    loss = loss_func(model(xb), yb)\n",
    "\n",
    "    if opt is not None:\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.zero_grad()\n",
    "\n",
    "    return loss.item(), len(xb)\n",
    "\n",
    "def simple_fit(epochs, model, loss_func, opt, train_dl, valid_dl):\n",
    "    for epoch in range(epochs):\n",
    "        model.train()\n",
    "        for xb,yb in train_dl: simple_loss_batch(model, loss_func, xb, yb, opt)\n",
    "\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            losses,nums = zip(*[simple_loss_batch(model, loss_func, xb, yb)\n",
    "                                for xb,yb in valid_dl])\n",
    "            \n",
    "            # Accuracy metric:\n",
    "            thr = 0.5\n",
    "            comparison = [np.array(torch.sigmoid(net(xb)).detach().numpy()>thr).astype(int)==yb.detach().numpy().astype(int)\n",
    "                          for xb,yb in valid_dl]\n",
    "            accuracy = np.concatenate(comparison).mean()\n",
    "            \n",
    "        val_loss = np.sum(np.multiply(losses,nums)) / np.sum(nums)\n",
    "\n",
    "        #if epoch % 10 == 0:\n",
    "        print(epoch, val_loss, accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.68692982006073 0.547\n",
      "1 0.6797177696228027 0.563\n",
      "2 0.6759244503974915 0.571\n",
      "3 0.6710091686248779 0.608\n",
      "4 0.6701397824287415 0.606\n",
      "5 0.6652094225883484 0.624\n",
      "6 0.6608625617027283 0.626\n",
      "7 0.6558821315765381 0.634\n",
      "8 0.6506783146858215 0.645\n",
      "9 0.6421114640235901 0.657\n",
      "10 0.6383359127044678 0.653\n",
      "11 0.6326445651054382 0.66\n",
      "12 0.622108877658844 0.673\n",
      "13 0.6152496490478515 0.684\n",
      "14 0.609443440914154 0.683\n",
      "15 0.6009174609184265 0.698\n",
      "16 0.5920117449760437 0.703\n",
      "17 0.5872178387641906 0.712\n",
      "18 0.5786009964942932 0.717\n",
      "19 0.5699475564956665 0.719\n",
      "20 0.5632805652618408 0.724\n",
      "21 0.5551681814193725 0.732\n",
      "22 0.5455421586036682 0.734\n",
      "23 0.5378458318710327 0.738\n",
      "24 0.5294094376564026 0.742\n",
      "25 0.5234851956367492 0.745\n",
      "26 0.5149962892532348 0.751\n",
      "27 0.5067254395484925 0.755\n",
      "28 0.5028578705787659 0.75\n",
      "29 0.4942864866256714 0.768\n",
      "30 0.4868652968406677 0.767\n",
      "31 0.4809345436096191 0.766\n",
      "32 0.47337263011932373 0.778\n",
      "33 0.46588316559791565 0.782\n",
      "34 0.45987958240509036 0.786\n",
      "35 0.453939590215683 0.793\n",
      "36 0.44865147876739503 0.793\n",
      "37 0.4399886100292206 0.796\n",
      "38 0.4356886956691742 0.806\n",
      "39 0.4297665309906006 0.81\n",
      "40 0.4235528767108917 0.812\n",
      "41 0.4227601671218872 0.811\n",
      "42 0.4155415644645691 0.818\n",
      "43 0.41043953800201416 0.82\n",
      "44 0.4036995759010315 0.821\n",
      "45 0.39897701191902163 0.826\n",
      "46 0.3943213264942169 0.831\n",
      "47 0.3920146129131317 0.827\n",
      "48 0.3889892637729645 0.834\n",
      "49 0.38026661682128904 0.836\n"
     ]
    }
   ],
   "source": [
    "simple_fit(50, net, F.binary_cross_entropy_with_logits, opt, train_dl, valid_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_dl = DataLoader(valid_ds, batch_size=bs*2, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([128, 1, 4, 50]), torch.Size([128, 2]))"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xbatch, ybatch = next(iter(valid_dl)); xbatch.shape, ybatch.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 1],\n",
       "       [1, 0],\n",
       "       [1, 0],\n",
       "       [1, 0],\n",
       "       [1, 0],\n",
       "       [1, 0],\n",
       "       [1, 0],\n",
       "       [1, 0],\n",
       "       [0, 1],\n",
       "       [1, 0]])"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ybatch[:10].detach().numpy().astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 1],\n",
       "       [1, 0],\n",
       "       [1, 0],\n",
       "       [1, 0],\n",
       "       [1, 0],\n",
       "       [1, 0],\n",
       "       [0, 1],\n",
       "       [1, 0],\n",
       "       [0, 1],\n",
       "       [1, 0]])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(torch.sigmoid(net(xbatch[:10])).detach().numpy()>0.5).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "#np.array(torch.sigmoid(net(xbatch)).detach().numpy()>0.5).astype(int)==ybatch.detach().numpy().astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8359375"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(np.array(torch.sigmoid(net(xbatch)).detach().numpy()>0.5).astype(int)==ybatch.detach().numpy().astype(int)).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8359375"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.concatenate(np.array(torch.sigmoid(net(xbatch)).detach().numpy()>0.5).astype(int)==ybatch.detach().numpy().astype(int)).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.836"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.concatenate([np.array(torch.sigmoid(net(xb)).detach().numpy()>0.5).astype(int)==yb.detach().numpy().astype(int) for xb,yb in valid_dl]).mean()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.8359375, 0.86328125, 0.83203125, 0.8103448275862069]"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[np.concatenate(np.array(torch.sigmoid(net(xb)).detach().numpy()>0.5).astype(int)==yb.detach().numpy().astype(int)).mean() for xb,yb in valid_dl]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'data' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-58-6ff02be695bd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_ds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_ds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'data' is not defined"
     ]
    }
   ],
   "source": [
    "data.train_ds[0][1], data.train_ds[3][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 780,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-1.4258,  1.5087]], grad_fn=<ThAddmmBackward>),\n",
       " tensor([[ 1.2086, -1.2096]], grad_fn=<ThAddmmBackward>))"
      ]
     },
     "execution_count": 780,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net(data.train_ds[0][0].data), net(data.train_ds[3][0].data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check fastai dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 777,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiCategory 0"
      ]
     },
     "execution_count": 777,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds.y[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 778,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[[[0., 1., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [1., 0., 0.,  ..., 1., 0., 0.],\n",
       "           [0., 0., 1.,  ..., 0., 1., 1.]]],\n",
       " \n",
       " \n",
       "         [[[0., 0., 0.,  ..., 1., 0., 0.],\n",
       "           [1., 0., 0.,  ..., 0., 1., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 1., 1.,  ..., 0., 0., 1.]]],\n",
       " \n",
       " \n",
       "         [[[0., 1., 1.,  ..., 0., 0., 1.],\n",
       "           [0., 0., 0.,  ..., 1., 0., 0.],\n",
       "           [1., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 1., 0.]]],\n",
       " \n",
       " \n",
       "         ...,\n",
       " \n",
       " \n",
       "         [[[0., 0., 1.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 1., 0.],\n",
       "           [1., 1., 0.,  ..., 1., 0., 1.]]],\n",
       " \n",
       " \n",
       "         [[[0., 0., 0.,  ..., 1., 1., 1.],\n",
       "           [0., 1., 1.,  ..., 0., 0., 0.],\n",
       "           [1., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 0., 0.]]],\n",
       " \n",
       " \n",
       "         [[[1., 1., 1.,  ..., 0., 0., 1.],\n",
       "           [0., 0., 0.,  ..., 1., 0., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "           [0., 0., 0.,  ..., 0., 1., 0.]]]]), tensor([[1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [1., 0.],\n",
       "         [0., 1.],\n",
       "         [0., 1.],\n",
       "         [0., 1.]])]"
      ]
     },
     "execution_count": 778,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(data.train_dl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}