[de45a9]: / deprecated / DL_Genomics_v4.ipynb

Download this file

1944 lines (1943 with data), 132.6 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep learning in genomics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is based on the [jupyter notebook](https://github.com/abidlabs/deep-learning-genomics-primer/blob/master/A_Primer_on_Deep_Learning_in_Genomics_Public.ipynb) from the publication [\"A primer on deep learning in genomics\"](https://www.nature.com/articles/s41588-018-0295-5) but uses the [fastai](https://www.fast.ai) library based on [PyTorch](https://pytorch.org)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder, OneHotEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.0.32.dev0'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# fastai version\n",
    "__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## Data preprocessing (not needed, go to \"Data loading\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "import requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "URL_seq = 'https://raw.githubusercontent.com/abidlabs/deep-learning-genomics-primer/master/sequences.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# download the raw sequences (one per line); raise on HTTP errors so that an error page\n",
    "# is never silently parsed as sequence data\n",
    "seq_response = requests.get(URL_seq, timeout=30)\n",
    "seq_response.raise_for_status()\n",
    "# splitlines() also handles CRLF line endings; blank lines are dropped\n",
    "seq_raw = [line.strip() for line in seq_response.text.splitlines() if line.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check length\n",
    "len(seq_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# setup df from list\n",
    "seq_df = pd.DataFrame(seq_raw, columns=['Sequences'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# show head of dataframe\n",
    "#seq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "URL_labels = 'https://raw.githubusercontent.com/abidlabs/deep-learning-genomics-primer/master/labels.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# download the labels (one per line); raise on HTTP errors so that an error page\n",
    "# is never silently parsed as label data\n",
    "labels_response = requests.get(URL_labels, timeout=30)\n",
    "labels_response.raise_for_status()\n",
    "# splitlines() also handles CRLF line endings; blank lines are dropped\n",
    "seq_labels = [line.strip() for line in labels_response.text.splitlines() if line.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(seq_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "seq_label_series = pd.Series(seq_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "seq_df['Target'] = seq_label_series.astype('int')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Sequences</th>\n",
       "      <th>Target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           Sequences  Target\n",
       "0  CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...       0\n",
       "1  GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...       0\n",
       "2  GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...       0\n",
       "3  GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...       1\n",
       "4  GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...       1"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# index=False: the default integer index would otherwise be written as an extra column\n",
    "# that comes back as `Unnamed: 0` when the file is read again\n",
    "seq_df.to_csv('seq_df.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_df = pd.read_csv('seq_df.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>Sequences</th>\n",
       "      <th>Target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0                                          Sequences  Target\n",
       "0           0  CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...       0\n",
       "1           1  GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...       0\n",
       "2           2  GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...       0\n",
       "3           3  GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...       1\n",
       "4           4  GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...       1"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 1, ..., 1, 0, 1, 1])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_df['Target'].values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## Data encoding test (incorporated into \"open_seq_image\" function)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# setup class instance to encode the four different bases to integer values (1D)\n",
    "int_enc = LabelEncoder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# setup one hot encoder to encode integer encoded classes (1D) to one hot encoded array (4D)\n",
    "one_hot_enc = OneHotEncoder(categories=[range(4)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# fit the label encoder ONCE on the full alphabet so that A,C,G,T always map to 0,1,2,3,\n",
    "# even for a sequence that happens to lack one of the bases\n",
    "int_enc.fit(list('ACGT'))\n",
    "\n",
    "seq_enc = []\n",
    "\n",
    "# iterate over the loaded dataframe column (the notebook never defines a bare `seq` variable)\n",
    "for s in seq_df['Sequences']:\n",
    "    enc = int_enc.transform(list(s)) # bases (ACGT) to int (0,1,2,3)\n",
    "    enc = np.array(enc).reshape(-1,1) # reshape to get rank 2 array (from rank 1 array)\n",
    "    enc = one_hot_enc.fit_transform(enc) # integer encoded bases to one hot encoded sparse matrix\n",
    "    seq_enc.append(enc.toarray()) # export sparse matrix to np array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(seq_enc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 462,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0., 0., 0., 1., ..., 0., 1., 0., 0.],\n",
       "        [1., 1., 0., 0., ..., 1., 0., 1., 1.],\n",
       "        [0., 0., 1., 0., ..., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., ..., 0., 0., 0., 0.]]), (4, 50))"
      ]
     },
     "execution_count": 462,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_enc[0].T, seq_enc[0].T.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAMAAAD7wpwzAAADAFBMVEUAAAABAQECAgIDAwMEBAQFBQUGBgYHBwcICAgJCQkKCgoLCwsMDAwNDQ0ODg4PDw8QEBARERESEhITExMUFBQVFRUWFhYXFxcYGBgZGRkaGhobGxscHBwdHR0eHh4fHx8gICAhISEiIiIjIyMkJCQlJSUmJiYnJycoKCgpKSkqKiorKyssLCwtLS0uLi4vLy8wMDAxMTEyMjIzMzM0NDQ1NTU2NjY3Nzc4ODg5OTk6Ojo7Ozs8PDw9PT0+Pj4/Pz9AQEBBQUFCQkJDQ0NERERFRUVGRkZHR0dISEhJSUlKSkpLS0tMTExNTU1OTk5PT09QUFBRUVFSUlJTU1NUVFRVVVVWVlZXV1dYWFhZWVlaWlpbW1tcXFxdXV1eXl5fX19gYGBhYWFiYmJjY2NkZGRlZWVmZmZnZ2doaGhpaWlqampra2tsbGxtbW1ubm5vb29wcHBxcXFycnJzc3N0dHR1dXV2dnZ3d3d4eHh5eXl6enp7e3t8fHx9fX1+fn5/f3+AgICBgYGCgoKDg4OEhISFhYWGhoaHh4eIiIiJiYmKioqLi4uMjIyNjY2Ojo6Pj4+QkJCRkZGSkpKTk5OUlJSVlZWWlpaXl5eYmJiZmZmampqbm5ucnJydnZ2enp6fn5+goKChoaGioqKjo6OkpKSlpaWmpqanp6eoqKipqamqqqqrq6usrKytra2urq6vr6+wsLCxsbGysrKzs7O0tLS1tbW2tra3t7e4uLi5ubm6urq7u7u8vLy9vb2+vr6/v7/AwMDBwcHCwsLDw8PExMTFxcXGxsbHx8fIyMjJycnKysrLy8vMzMzNzc3Ozs7Pz8/Q0NDR0dHS0tLT09PU1NTV1dXW1tbX19fY2NjZ2dna2trb29vc3Nzd3d3e3t7f39/g4ODh4eHi4uLj4+Pk5OTl5eXm5ubn5+fo6Ojp6enq6urr6+vs7Ozt7e3u7u7v7+/w8PDx8fHy8vLz8/P09PT19fX29vb39/f4+Pj5+fn6+vr7+/v8/Pz9/f3+/v7////isF19AAAAO0lEQVR4nGWNSQoAMAgD8/9PT6loIq0HMauShOR9DwphIkO5eALbEI1xp5DUoxYnNtf3t1rci9mGy30A2+0xz9q2+b0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.Image.Image image mode=P size=50x4 at 0x1A20462668>"
      ]
     },
     "execution_count": 311,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "PIL.Image.fromarray(seq_enc[0].T.astype('uint8')*255).convert('P')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data setup for NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# open sequence image function\n",
    "def open_seq_image(seq:str, cls:type=Image)->Image:\n",
    "    \"\"\"Return `Image` object created from sequence string `seq`.\n",
    "\n",
    "    `seq` is one hot encoded, case-insensitively, with a fixed row order (A, C, G, T) into a\n",
    "    (4, len(seq)) array that is wrapped as a single channel image tensor of shape (1, 4, len(seq)).\n",
    "    Raises `ValueError` if `seq` is empty or contains anything other than A, C, G or T.\n",
    "    \"\"\"\n",
    "    bases = 'ACGT'\n",
    "    # Fixed base -> row lookup. Fitting a LabelEncoder on every sequence separately shifts the rows\n",
    "    # whenever a sequence lacks one of the four bases, and can silently encode an unexpected symbol\n",
    "    # as if it were a real base.\n",
    "    base_to_row = {b:i for i,b in enumerate(bases)}\n",
    "    if not seq: raise ValueError('Cannot encode an empty sequence')\n",
    "    try: rows = [base_to_row[b] for b in seq.upper()]\n",
    "    except KeyError as e: raise ValueError(f'Unexpected base {e.args[0]!r} in sequence, expected only {bases}') from None\n",
    "    \n",
    "    enc = np.zeros((len(bases), len(rows)), dtype='uint8')\n",
    "    enc[rows, np.arange(len(rows))] = 1 # one hot: row = base, column = position in sequence\n",
    "    \n",
    "    # https://stackoverflow.com/questions/22902040/convert-black-and-white-array-into-an-image-in-python\n",
    "    x = PIL.Image.fromarray(enc).convert('P')\n",
    "    x = pil2tensor(x,np.float32)\n",
    "    return cls(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD80/8AhrrUr/8A0rxT8B/h1q95B/pGltNotzaWGm6iflfUYdIsrmDSvPkjjsopUeze3nTTrfzYZGadpvuj4Jf8EwP2bf2gf2wL39kD4j6n4uuYLDwj4g1u38e/8JEz68IdG8WXnhax0rMqvZ/YlsbG3faLUTCVcJMkAWBSivm/HbGYzgrIq2IyKrLDTjhcTUThKSanTdBQkuZztbnleKXLK95xqOMXD9sPnT9kmOD9rf8A4Wbb6rcav4Ls/gx8Ita+Inwp0/wV4s1VYfCuqaf5LmKyF/dXRhguriVbqcg+eZreIxTQqHR+R+CXxp8V/tC+Mr34X/FXTdI1fwnpnhHxBrWlaFqWlx3s2nWmjaPearZaJaapd+bqtnpYkso4TbwXkbeTLOFkWSZ5SUV9lXwmGhmfENJRVsJDCOj3pSq0HOpKEvjUpz9+cudylJuTbbk2Hrn7DnwE+DH7cXg3xd4t1X4aaR8PovBvi7wX4fsNJ8DWguIbqPxRrA0fUbm4l1v+0Ll50tdptsTLHazJ58UayvI7/On/AA2R480r/ibfDvwF4R8F+IW+SXxP4N0uXT7mWE/vXiaCOb7H/wAfuL+OUW4mtLmK3+ySW0VpaQwFFHCuFoZnx5n2W4xOrQw06KpQnKUoxU6VKUk+ab51KUm2qjqpt+SSA/4ba+Mn/QmfCL/xH/wf/wDKqiiiv0n/AFT4Y/6A6f3S/wDlgH//2Q==\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAHNJREFUKJFjdGEM+c8ABTufXWBwlzJgQAe4xNHVMDAwoKiD6cMmRwlANhdmJuPf5yr/cVlCjAdIdQAMELIP3bHobkEPJEZYjKBLYLMMV6gji+GykJDHcHkOm9uwqWNETlqEDEF2LD7LiUlK+JIxTC8pKQIA7jhk1J8wofAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "Image (1, 4, 50)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# test open sequence image function\n",
    "open_seq_image('CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGACACC')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SeqItemList(ImageItemList):\n",
    "    \"`ImageItemList` whose items are sequence strings that are opened with `open_seq_image`.\"\n",
    "    _bunch,_square_show = ImageDataBunch,True\n",
    "    def __post_init__(self):\n",
    "        super().__post_init__()\n",
    "        self.sizes={}\n",
    "    \n",
    "    def open(self, seq): return open_seq_image(seq)\n",
    "    \n",
    "    def get(self, i):\n",
    "        \"Return the encoded `Image` for item `i`.\"\n",
    "        item = self.items[i]\n",
    "        # A list of columns gives 2D `items` with the sequence in column 0; a single column name gives\n",
    "        # 1D `items`, where indexing with [0] would wrongly return only the first base.\n",
    "        seq = item if isinstance(item, str) else item[0]\n",
    "        return self.open(seq)\n",
    "    \n",
    "    @classmethod\n",
    "    def import_from_df(cls, df:DataFrame, cols:IntsOrStrs=0, **kwargs)->'ItemList':\n",
    "        \"Create the item list from `df[cols]` (the sequence must be the first column); `kwargs` are passed to the constructor.\"\n",
    "        return cls(items=df[cols].values, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep every label next to its sequence inside `items` so that the labels stay aligned with the\n",
    "# sequences after the random split. Passing the full, unshuffled `Target` column to `label_from_list`\n",
    "# attached all 2000 labels in their original order to both the shuffled train and valid subsets.\n",
    "data = (SeqItemList.import_from_df(seq_df, ['Sequences', 'Target'])\n",
    "        .random_split_by_pct(valid_pct=0.25)\n",
    "        #.split_by_idxs(range(1500), range(1500,2000))\n",
    "        .label_from_func(lambda o: o[1]) # o = [sequence, target] -> label is the target\n",
    "        .databunch())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ImageDataBunch;\n",
       "Train: LabelList\n",
       "y: CategoryList (2000 items)\n",
       "[Category 0, Category 0, Category 0, Category 1, Category 1]...\n",
       "Path: .\n",
       "x: SeqItemList (1500 items)\n",
       "[Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50)]...\n",
       "Path: .;\n",
       "Valid: LabelList\n",
       "y: CategoryList (2000 items)\n",
       "[Category 0, Category 0, Category 0, Category 1, Category 1]...\n",
       "Path: .\n",
       "x: SeqItemList (500 items)\n",
       "[Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50)]...\n",
       "Path: .;\n",
       "Test: None"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD4w8Ny6l4h/wCEMjg8RavpkXi34ReJviPYjS9ZuY5vDeteHv7ej0f7DcmRrmSC0tfD1taQR3kt0beG7vfIaKWcSp5F/wANpfGDwt/xKvgLN/wq/wAPP++uPB/g3XtWl0qe/wD4dTaDUr27xeptgMdwpV4JLS3li8uWJZAUV9NwtlOWZnisTDF0Y1Iw5bRklKOtXGxu4tOLfLRp2lKMpJx5oyjKU5VP2wP+G0vjB4k/dfGub/hY8EnyX8HjDXtWX+0oT8zxXj2N7bSXm+WLTpDLMzzD+xtNiWRYLZYSf8NmfESD/QtL+HXw6s9Ktfm0bw/D8PtPawspF+SGeaGSNhq88Nu1xbxTar9ukjS8uJFYTyeeCivsf9U+GtlhIJfypOMV5qEZRhF9Lxgnb3bqOgHrvhC+8V/Ef9lTXP2vPiV471fXvEei/wBqxNaXc0cMN9aWV94XsZLOWe3SO9SC9i8X6018IbmJr2aZZ5mebfJJyOi/GDUvEHwG8SftK6t4U0ibxZoPi7RPDmoyuLk2HiGyv7XVbq3iv9OM32JoNOl0TTTZ2cMMNkqWyRz29xGqopRX5tl+EwtetjoygrU8wp0IJLlUaMnHmppR5VyO+1t7NNNRcQ/RX9n/AP4N/P2OP2kfgN4J/aJ8d/Ej4i2euePvCOm+I9Zs/D91o9jYQXd9ax3M0dtbR6bst4FeVgkS/KiBVHAooor/ADY4j8ePGTAcRY3C4bPcRCnTrVoQipQtGMK1WEYr909IxjGK1eiWr3Yf/9k=\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAHBJREFUKJFjdGEM+b/z2QUGdykDBgYGBgZcbBgfBtylDLDykWlsAJuZ2PQQcge6+YwujCH/8SnABYgxHFkO2cPIYuieIMdcdykDBsa/z1X+oytENxjdEcR4EJ8+YgMMmztweRgeI6SEBK0AOUkTJgYA3Nhz1DR798MAAAAASUVORK5CYII=\n",
      "text/plain": [
       "Image (1, 4, 50)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 2\n",
    "data.x[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Category 0"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.y[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD4C+Nvxp8V/s9eMrL4bfCrTdI0vw5P4R8P+JtP0y00uOCbTdU1PR7PWIrqK/h2X8k+n3d7J9huJ7maeCENA0ksM91HOfBL45fET46eMr34fa/caRpeh2/hHxB4pu/D+ieGNPj0e+1rRtHvNUtL9tJkgk06CdjYWlnM9tbQme0SWJ9xuLl5iivpv7Jyz/UX+0XRi6/subnavLm57czbbvK32mua/vcylq/2w679kmSf9t7/AIWbY/tCW+kX2lfCr4Ra18QfD2i6J4T0rSIZLvTfJ8vTGksbWKe30uU3E8k1paS26vNPLcApPI8zeRf8NrfGeT/iY3K6QuuTfNqfi3TbA6ZrGqSR/vbOe7vLB4Jbie3vcagJnYyXF3HBJeNdi2tkhKKMlynLMZxhnOArUYyo4f6p7KFvdh7XD1KlXlV7fvJxjOd+ZSlFNq6uB1/xJ+NPiv4I+Dfh/wCJPgbpukeFtN+IfhGbxHfeFk0uPVbDSbtdY1LSLiOwbVvtdzbwXlrpdut5EZmW8Rnhn8y28u3TkP8AhtH4rP8A6be+HPCOpah/Dd+IvC8Gr20Gfkf7Ppt+J9Nst0EOnWo+zWsXlW2kWUMXlIJhMUV7OQcPZHjssVbEYeM5udVOTTcmo168I80uZSlaMIxvKTbUVzOTu5B1/wC0B4y+BP7PXx58bfALRv2Kvh1r1n4H8Xal4ftdc8Qa14o+36jHZXUlstzc/ZdZgg8+QRh38mGKPezbI0XCgooo4cyHLsfw7gsViPaSqVKNGcn7fEaynRpTk9MQlrKUnoktdElZIP/Z\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAHRJREFUKJFjdGEM+c+ABHY+u8DgLmWAwYbxGRgYUMSw6SMkDhPDZj42u3GpRwaMMI8gOxLdEHTHIxsMA8h8bA7HZjY2gK4Wn3nIcox/n6v8JyYkCTkAl6MIhTi+WMZmFq4YZnRhDPlPDUeiyxHjOGLNI0YfAGvwc9QnBcXaAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "Image (1, 4, 50)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 3\n",
    "data.x[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Category 1"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.y[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Image (1, 4, 50), Category 0)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds[0]#, data.train_ds.classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjgAAAETCAYAAAA79nyeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAACLtJREFUeJzt3TGPXFcZx+H3tbdwFEhBQxSlNQ1IuKCjjlxSpUEpaEH5AC5cIX8F+kgpKCKFiiLKV0iRFDRJRYEV0aAIoigo+FLESMh7Vxrv3DvnzH+ep7JGu3PPnZm9/mn2nbO9LEsBACS5M3oBAABbEzgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETjsorvf7e5Puvvb7n5v9HqATK413ORq9AKI9bSqnlTVw6p6ZfBagFyuNawSOOxiWZYPq6q6+xdV9ebg5QChXGu4iV9RAQBxBA4AEEfgAABxBA4AEMeQMbvo7qv6/vV1t6rudve9qvpuWZbvxq4MSOJaw028g8NeHlfVN1X1qKreef7vx0NXBCRyrWFVL8syeg0AAJvyDg4AEEfgAABxBA4AEEfgAABxBA4AEGfYPjjPvrx/7eNbD994cO3rPnr66bXb1r5uzW2/99Dv2/rrjlnLqe/rZY6x5tzO/xSP3TFeXN8pnte1Y3z87IPe9MAbOMW1ZhZbn8PWP8/nZubzP8Xa9v4/9JjvPeRa4x0cACCOwAEA4ggcACDOsJ2M37rz9q0PvPXv9257X4fe/8xzLqPmDmaeh5lpBmtLh75u1hy63hlncI651qw55nnf0pbXvJe5v0OOcczjsfV5HXqMWa4rHMYMDgBwkQQOABBH4AAAcQQOABBnqiHjmYewRg0P7z34d8w6Znq+9h4eP2YdIzZ/HPXcXMKQ8ZoRj/cpNvWbeRg+4VzXjNpccMTGuGsO/d47r39hyBgAuDwCBwCII3AAgDgCBwCIM9WQ8TFmHhpbY71jjjHDMW8y818uXnPbwb9TO+YDDTO9Ps7JpT1ut/2Qw7ntfL92f6OeVzsZAwAXSeAAAHEEDgAQR+AAAHGGDRkDAOzFOzgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgA
QByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETgAQByBAwDEETjsorvf7e5Puvvb7n5v9HqATN39o+7+U3d/3d1/7e5fj14Tc7gavQBiPa2qJ1X1sKpeGbwWINcfqurfVfXjqnpQVX/u7s+WZfnL2GUxWi/LMnoNBOvuJ1X15rIsvxm9FiBLd79aVf+oqp8ty/L589ver6q/LcvyaOjiGM6vqAA4Vz+pqv/8L26e+6yqfjpoPUxE4ABwrn5QVV+9cNtXVfXDAWthMgIHgHP1r6p67YXbXquqfw5YC5MROACcq8+r6qq77//fbT+vKgPGCBz20d1X3X2vqu5W1d3uvtfdPrUHbGZZlq+r6sOq+n13v9rdv6yqX1XV+2NXxgwEDnt5XFXfVNWjqnrn+b8fD10RkOh39f1WFH+vqj9W1W99RJwqHxMHAAJ5BwcAiCNwAIA4AgcAiCNwAIA4wz62++zL+wdNNz9848G12z56+ummX3dbW9//2v2t2fIctnbMY7L387r36+FlnOJxOvV9VVV9/OyDvvU37+StO29fu9ac4md3lmvNLNfLUzjFOWx5/Vlziv9DDj3u3teaQ9ex5pBrjXdwAIA4AgcAiCNwAIA4AgcAiDNsJ+O1wb9ZbD3Qt+YUA42HrOXchghfxovnP/u5bvnaGTWwfi5DxmtmHhQ+xf0dc9wR67gkMw9P731fNzFkDABcJIEDAMQROABAHIEDAMSZfsh4xG7Bp9jlc5ZdQ2ffIfXcdmFN2I16zaHndef1L6YbMl7bNX3EkP+InapHHeOYHXXXXNI1eXZ7D5lvea3xDg4AEEfgAABxBA4AEEfgAABxph8yPtSlD4jN/GftDz3GzEPBhxr1mtty1+Zjnptz3sl4zaXvbjzzz+kpdpcfsWv8uX2wYhQ7GQMAF0ngAABxBA4AEEfgAABxhg0ZAwDsxTs4AEAcgQMAxBE4AEAcgQMAxBE4AEAcgQMAxBE4AEAcgQMAxBE4AEAcgQMAxBE4AEAcgQMAxBE4AEAcgQMAxBE4AEAcgQMAxBE4AEAcgQMAxBE4AEAcgQMAxBE4AEAcgQMAxBE4AECc/wJLUH4q+JoU2gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 576x576 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "data.show_batch(rows=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## Lookup ResNet34 architecture (not needed, just for reference)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "test_learn = create_cnn(data, models.resnet18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<function torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='elementwise_mean')>"
      ]
     },
     "execution_count": 144,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_learn.loss_func"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#test_learn.model[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): AdaptiveConcatPool2d(\n",
       "    (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "    (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "  )\n",
       "  (1): Lambda()\n",
       "  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Dropout(p=0.25)\n",
       "  (4): Linear(in_features=1024, out_features=512, bias=True)\n",
       "  (5): ReLU(inplace)\n",
       "  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (7): Dropout(p=0.5)\n",
       "  (8): Linear(in_features=512, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 149,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_learn.model[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def ExpandInput(n_channels=3):\n",
    "    \"\"\"Return a layer that repeats a single-channel input along the channel axis.\n",
    "\n",
    "    The sequences in this notebook are (batch, 1, 4, 50) tensors, while the\n",
    "    resnet18 body used for this test is fed 3-channel input.\n",
    "\n",
    "    n_channels: number of channels to expand to (default 3, the previous fixed value).\n",
    "\n",
    "    Batch, height and width are passed as -1 (keep the existing size), so the\n",
    "    layer is no longer tied to 4x50 inputs. `expand` can only grow dimensions\n",
    "    of size 1, so the input must still have exactly one channel.\n",
    "    \"\"\"\n",
    "    return Lambda(lambda x: x.expand(-1, n_channels, -1, -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "EI = ExpandInput()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Test ExpandInput layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1, 4, 50])"
      ]
     },
     "execution_count": 166,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tt = torch.rand((64,1,4,50)); tt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 3, 4, 50])"
      ]
     },
     "execution_count": 167,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tt.expand(-1, 3, 4, 50).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD2D9jrwz4l/aM8P+IvBfif4x+N9CtPAWo+EjY/8Ib4jk0oatN4j0GDV7y51OOACLUblLq5cpcTo8sqZS6e6Es3m+IfEb4vax8Rb/xJ8evEWiWkGt6/4y8CwSQ6HfX2lWtqDDoesiRIbK4iEk63GuXkazzebMsCQwh9qvvKKAOH1vx1490T4laH8X/BHi0+Hn8afAW+8f8AijQNK0ewNhfX+hSeIr6xtMT28k0VkH0iCJoYpUPkXF1EjRpOwr2j9m39rr42/Gvw1+1X+0N8QdctL3Xfht8NPBniLw3BJYI9p9r1fSrhdQWaJwwmilgt4Ldo2ODFDGDlo0ZSigDY/Yk8G+Gv2t/jj4k8E/GrR7a+uvEvwH034v6p4ntrdItUuNXv4o4VsZpguLywtI5JVtY7pZpoWk8zzmljikj88+LNjdfEXWPh9rHjTxJreq6Qv7L2j/GufwfrWuXOo6Ve60G1LUBp92l48st7p3nWEqmO4kknaPUrpGnYJafZSigD0q0/4J6fCY2sR0n4q/FTTLUxr9m03TfiVqEdvaR4+WGJPMO2NRhVXJwABRRRQB//2Q==\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAs5JREFUKJEFwU0o8wEAB+Dfei01lwlbm7GVjwNKajVMPg/aZeWwuAhlTXYQuzisWD6nrbXisuSwFNFqReKw2Z80X5G01bSY5bC1A5FR6Pc+jyiZTLKrqwuzs7Po6+uDVCrF5eUllEolbm5uUF1djbOzM1xcXODp6Qk+nw8HBwcIh8OYmZnB+vo6tFotDg8PsbOzg9PTUzw+PiKbzUIul8NoNEIikWBsbAwulwsrKyuoqanBwMAAnE4npqamIAgCLBYLSMLj8SCXyyEcDkOpVGJ6ehpXV1eoqqrC7e0t6uvrsb+/j2AwiN/fX0ilUgwODgLNzc1MpVL8/Pyk2+1mWVkZTSYTS0tL6fF4+PPzw1AoxEwmQ7FYTK1WS6fTycLCQo6OjvLh4YFWq5XDw8M0m810OBz8+vpiW1sbY7EYE4kEI5EIvV4v7XY7I5EI5+fnGYvFmM1m6ff7abfb6XK5GI/HGQgEWFxczEgkwsnJSep0Ora2tvL6+podHR20WCzUaDSUy+V8fX2lXq9nZWUl/5lMplmv1wuxWAy1Wg2FQoHn52csLy9jc3MTBoMBx8fH2N7eRnl5OXw+H1QqFYqKitDd3Y1sNguJRIK/vz9sbW0hFArB4/Egk8mgsbER6XQaLy8viMfjODo6gl6vR3t7O3Z3d5FOp6FWq2Gz2fD29oaRkRHc3d3BbrdDEASIxWI4HA7U1dXBYrFgbm4ONpsNZrMZ0WgU4+PjUCqVEIlEgNFoZCKRYElJCScmJpjP5xmNRrm3t8dEIsHe3l4KgkCZTMZAIMClpSUuLi6yoaGBCwsL7OzspNVqZSaTYVNTEw0GA3t6ephKpahSqajT6bi2tsba2loWFBQwl8vx/f2dGxsb1Ol01Gg0PD8/5+rqKhUKBfv7+/n9/c2Pjw+enJywoqKCyWSS+XyeMpmMQ0NDdLvdvL+/Z0tLC/1+P4PBIP8D9TaVeZbwiwgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "Image (3, 4, 50)"
      ]
     },
     "execution_count": 168,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Image(EI(tt)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "new_model = nn.Sequential(ExpandInput(), test_learn.model) # insert ExpandInput layer at the beginning of the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 3, 4, 50])"
      ]
     },
     "execution_count": 156,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_model[0](tt).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "new_learn = Learner(data, new_model, metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#[p.shape for p in new_model.parameters()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 03:42\n",
      "epoch  train_loss  valid_loss  accuracy\n",
      "1      0.664468    0.764640    0.532000  (00:43)\n",
      "2      0.659713    0.819829    0.480000  (00:43)\n",
      "3      0.635178    0.888000    0.482000  (00:49)\n",
      "4      0.593556    0.875104    0.502000  (00:43)\n",
      "5      0.561881    0.893715    0.472000  (00:42)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "new_learn.fit_one_cycle(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "new_learn.recorder.plot_losses()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "del test_learn\n",
    "del new_model\n",
    "del new_learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Flatten():\n",
    "    \"Layer that keeps the first (batch) dimension and collapses all remaining ones into one.\"\n",
    "    def _flatten(x): return x.view(x.size(0), -1)\n",
    "    return Lambda(_flatten)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ResizeInput():\n",
    "    \"\"\"Layer that views its input as (-1, H, W), keeping only the last two dimensions.\n",
    "\n",
    "    A (batch, 1, 4, 50) input becomes (batch, 4, 50), the layout `nn.Conv1d` is given below.\n",
    "    \"\"\"\n",
    "    def _resize(x):\n",
    "        trailing = x.size()[-2:]\n",
    "        return x.view(-1, *trailing)\n",
    "    return Lambda(_resize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "#def ResizeOutput(): return Lambda(lambda x: x.view(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small 1D CNN on the one-hot encoded sequences; shapes are for a batch of 64\n",
    "net = nn.Sequential(\n",
    "    ResizeInput(),                                               # -> [64, 4, 50]\n",
    "    nn.Conv1d(in_channels=4, out_channels=32, kernel_size=12),   # -> [64, 32, 39]\n",
    "    nn.MaxPool1d(kernel_size=4),                                 # -> [64, 32, 9]\n",
    "    Flatten(),                                                   # -> [64, 288]\n",
    "    nn.Linear(in_features=288, out_features=16),                 # 288 = 32 channels * 9 positions\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(in_features=16, out_features=2),                   # -> [64, 2], one output per class\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Lambda()\n",
       "  (1): Conv1d(4, 32, kernel_size=(12,), stride=(1,))\n",
       "  (2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)\n",
       "  (3): Lambda()\n",
       "  (4): Linear(in_features=288, out_features=16, bias=True)\n",
       "  (5): ReLU()\n",
       "  (6): Linear(in_features=16, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "apply_init(net, nn.init.kaiming_normal_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run sequence through network for testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((Image (1, 4, 50), Category 0), (Image (1, 4, 50), Category 1))"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds[0], data.train_ds[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.2031, 0.2126]], grad_fn=<ThAddmmBackward>),\n",
       " tensor([[-0.2022,  1.0126]], grad_fn=<ThAddmmBackward>))"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net(data.train_ds[0][0].data), net(data.train_ds[3][0].data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learner setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## TBLogger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# From https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514\n",
    "\"\"\"Simple example on how to log scalars and images to tensorboard without tensor ops.\n",
    "License: Copyleft\n",
    "\"\"\"\n",
    "#__author__ = \"Michael Gygli\"\n",
    "\n",
    "#import tensorflow as tf\n",
    "#from StringIO import StringIO\n",
    "#import matplotlib.pyplot as plt\n",
    "#import numpy as np\n",
    "\n",
    "class Logger(object):\n",
    "    \"\"\"Write scalar, image and histogram summaries for Tensorboard without tensorflow ops.\n",
    "\n",
    "    Summary protos are built by hand and passed to a `tf.summary.FileWriter`\n",
    "    (TensorFlow 1.x API). Every log method flushes the writer so the event\n",
    "    file is up to date as soon as the call returns.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, log_dir):\n",
    "        \"\"\"Create a summary writer logging to `log_dir`.\"\"\"\n",
    "        self.writer = tf.summary.FileWriter(log_dir)\n",
    "\n",
    "    def log_scalar(self, tag, value, step):\n",
    "        \"\"\"Log a scalar variable.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        tag : str\n",
    "            Name of the scalar\n",
    "        value : number or single-element tensor/array\n",
    "            Value to record; converted with `float()` because the summary\n",
    "            field only holds a plain float\n",
    "        step : int\n",
    "            Training iteration\n",
    "        \"\"\"\n",
    "        summary = tf.Summary(value=[tf.Summary.Value(tag=tag,\n",
    "                                                     simple_value=float(value))])\n",
    "        self.writer.add_summary(summary, step)\n",
    "        self.writer.flush()\n",
    "\n",
    "    def log_images(self, tag, images, step):\n",
    "        \"\"\"Log a list of images under the tags `tag/0`, `tag/1`, ...\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        tag : str\n",
    "            Common prefix of the image tags\n",
    "        images : iterable of arrays\n",
    "            Each item is passed to `plt.imsave` and must have a `.shape`\n",
    "        step : int\n",
    "            Training iteration\n",
    "        \"\"\"\n",
    "        # PNG data is binary, so it needs a bytes buffer; the `StringIO` used\n",
    "        # before was never imported and cannot hold bytes on Python 3.\n",
    "        from io import BytesIO\n",
    "\n",
    "        im_summaries = []\n",
    "        for nr, img in enumerate(images):\n",
    "            # Encode the image as PNG into an in-memory buffer\n",
    "            buf = BytesIO()\n",
    "            plt.imsave(buf, img, format='png')\n",
    "\n",
    "            # Create an Image object\n",
    "            img_sum = tf.Summary.Image(encoded_image_string=buf.getvalue(),\n",
    "                                       height=img.shape[0],\n",
    "                                       width=img.shape[1])\n",
    "            # Create a Summary value\n",
    "            im_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, nr),\n",
    "                                                 image=img_sum))\n",
    "\n",
    "        # Create and write Summary\n",
    "        summary = tf.Summary(value=im_summaries)\n",
    "        self.writer.add_summary(summary, step)\n",
    "        self.writer.flush()\n",
    "\n",
    "    def log_histogram(self, tag, values, step, bins=1000):\n",
    "        \"\"\"Log the histogram of a list/vector of values.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        tag : str\n",
    "            Name of the histogram\n",
    "        values : array-like\n",
    "            Values to bin; converted with `np.array`\n",
    "        step : int\n",
    "            Training iteration\n",
    "        bins : int\n",
    "            Number of bins handed to `np.histogram`\n",
    "        \"\"\"\n",
    "        # Convert to a numpy array\n",
    "        values = np.array(values)\n",
    "\n",
    "        # Create histogram using numpy\n",
    "        counts, bin_edges = np.histogram(values, bins=bins)\n",
    "\n",
    "        # Fill fields of histogram proto\n",
    "        hist = tf.HistogramProto()\n",
    "        hist.min = float(np.min(values))\n",
    "        hist.max = float(np.max(values))\n",
    "        hist.num = int(np.prod(values.shape))\n",
    "        hist.sum = float(np.sum(values))\n",
    "        hist.sum_squares = float(np.sum(values**2))\n",
    "\n",
    "        # Requires equal number as bins, where the first goes from -DBL_MAX to bin_edges[1]\n",
    "        # See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto#L30\n",
    "        # Thus, we drop the start of the first bin\n",
    "        bin_edges = bin_edges[1:]\n",
    "\n",
    "        # Add bin edges and counts\n",
    "        for edge in bin_edges:\n",
    "            hist.bucket_limit.append(edge)\n",
    "        for c in counts:\n",
    "            hist.bucket.append(c)\n",
    "\n",
    "        # Create and write Summary\n",
    "        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])\n",
    "        self.writer.add_summary(summary, step)\n",
    "        self.writer.flush()\n",
    "        \n",
    "\"A `Callback` that saves tracked metrics into a log file for Tensorboard.\"\n",
    "# Based on https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514\n",
    "# and devforfu: https://nbviewer.jupyter.org/gist/devforfu/ea0b3fcfe194dad323c3762492b05cae\n",
    "# Contribution from MicPie\n",
    "\n",
    "#from ..torch_core import *\n",
    "#from ..basic_data import DataBunch\n",
    "#from ..callback import *\n",
    "#from ..basic_train import Learner, LearnerCallback\n",
    "#import tensorflow as tf\n",
    "\n",
    "__all__ = ['TBLogger']\n",
    "\n",
    "@dataclass\n",
    "class TBLogger(LearnerCallback):\n",
    "    \"A `LearnerCallback` that saves history of metrics while training `learn` into log files for Tensorboard.\"\n",
    "\n",
    "    log_dir:str = 'logs' # directory (below `learn.path`) that holds all log files\n",
    "    log_name:str = 'data' # sub-directory of `log_dir` used for this run\n",
    "    log_scalar:bool = True # log scalar values for Tensorboard scalar summary\n",
    "    log_hist:bool = True # log values and gradients of the parameters for Tensorboard histogram summary\n",
    "    log_img:bool = False # log values for Tensorboard image summary (accepted, but not implemented yet)\n",
    "\n",
    "    def __post_init__(self):\n",
    "        \"Create the log directory and the `Logger` that writes into it.\"\n",
    "        super().__post_init__()\n",
    "        self.path = self.learn.path\n",
    "        (self.path/self.log_dir).mkdir(parents=True, exist_ok=True) # setup logs directory\n",
    "        self.Log = Logger(str(self.path/self.log_dir/self.log_name))\n",
    "        self.epoch = 0\n",
    "        self.batch = 0\n",
    "        self.log_grads = {} # maps '<param/name>/grad' to the running sum of that parameter's gradients\n",
    "\n",
    "    def on_backward_end(self, **kwargs:Any):\n",
    "        \"Count the batch and add the gradient of every parameter to its running sum.\"\n",
    "        self.batch = self.batch+1\n",
    "\n",
    "        if self.log_hist:\n",
    "            # Use the learner this callback is attached to, not a notebook-global `learn`\n",
    "            for tag, value in self.learn.model.named_parameters():\n",
    "                if value.grad is None: continue # frozen or unused parameter: nothing to accumulate\n",
    "                tag_grad = tag.replace('.', '/')+'/grad'\n",
    "                grad = value.grad.data.cpu().detach().numpy()\n",
    "\n",
    "                if tag_grad in self.log_grads:\n",
    "                    # Out-of-place add, so the parameter's own grad buffer is never modified\n",
    "                    self.log_grads[tag_grad] = self.log_grads[tag_grad] + grad # gradients are summed up from every batch\n",
    "                else:\n",
    "                    # On CPU `.numpy()` shares memory with `.grad`, which is zeroed in place\n",
    "                    # before the next backward pass; copy so the first batch is not lost\n",
    "                    self.log_grads[tag_grad] = grad.copy()\n",
    "        return self.log_grads\n",
    "\n",
    "    def on_epoch_end(self, epoch:int, smooth_loss:Tensor, last_metrics:MetricsList, **kwargs:Any) -> bool:\n",
    "        \"Write the metrics of this epoch and the parameter/gradient histograms to the log.\"\n",
    "        last_metrics = ifnone(last_metrics, [])\n",
    "        tr_info = {name: stat for name, stat in zip(self.learn.recorder.names, [epoch, smooth_loss] + last_metrics)}\n",
    "        self.epoch = tr_info['epoch']\n",
    "        self.batch = 0 # reset batch count\n",
    "\n",
    "        if self.log_scalar:\n",
    "            for tag, value in tr_info.items():\n",
    "                if tag == 'epoch': continue\n",
    "                self.Log.log_scalar(tag, value, self.epoch+1)\n",
    "\n",
    "        if self.log_hist:\n",
    "            for tag, value in self.learn.model.named_parameters():\n",
    "                tag = tag.replace('.', '/')\n",
    "                self.Log.log_histogram(tag, value.data.cpu().numpy(), self.epoch+1)\n",
    "\n",
    "                # Parameters that never received a gradient have no accumulated entry\n",
    "                tag_grad = tag+'/grad'\n",
    "                if tag_grad in self.log_grads:\n",
    "                    self.Log.log_histogram(tag_grad, self.log_grads[tag_grad], self.epoch+1)\n",
    "\n",
    "        # NOTE(review): `self.log_grads` is never cleared, so the sums keep growing across\n",
    "        # epochs - confirm whether a per-epoch reset is intended.\n",
    "        # TODO: image logging for `log_img` (via `Logger.log_images`) is not implemented yet."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learner setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, net, loss_func=nn.functional.cross_entropy, metrics=accuracy, callback_fns=[TBLogger])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Layer (type)               Output Shape         Param #   \n",
      "================================================================================\n",
      "Lambda                    [64, 4, 50]          0                   \n",
      "________________________________________________________________________________\n",
      "Conv1d                    [64, 32, 39]         1568                \n",
      "________________________________________________________________________________\n",
      "MaxPool1d                 [64, 32, 9]          0                   \n",
      "________________________________________________________________________________\n",
      "Lambda                    [64, 288]            0                   \n",
      "________________________________________________________________________________\n",
      "Linear                    [64, 16]             4624                \n",
      "________________________________________________________________________________\n",
      "ReLU                      [64, 16]             0                   \n",
      "________________________________________________________________________________\n",
      "Linear                    [64, 2]              34                  \n",
      "________________________________________________________________________________\n",
      "Total params:  6226\n"
     ]
    }
   ],
   "source": [
    "learn.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.lr_find()\n",
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 00:55\n",
      "epoch  train_loss  valid_loss  accuracy\n",
      "1      0.694043    0.694633    0.476000  (00:01)\n",
      "2      0.693227    0.693842    0.510000  (00:01)\n",
      "3      0.692341    0.693214    0.520000  (00:02)\n",
      "4      0.691532    0.693981    0.516000  (00:01)\n",
      "5      0.690240    0.693385    0.536000  (00:01)\n",
      "6      0.688261    0.694868    0.502000  (00:01)\n",
      "7      0.684323    0.696151    0.514000  (00:01)\n",
      "8      0.679574    0.703019    0.466000  (00:01)\n",
      "9      0.670428    0.706717    0.494000  (00:01)\n",
      "10     0.659853    0.717627    0.482000  (00:01)\n",
      "11     0.647399    0.731758    0.482000  (00:01)\n",
      "12     0.630531    0.734997    0.490000  (00:01)\n",
      "13     0.612728    0.781713    0.494000  (00:01)\n",
      "14     0.595601    0.755801    0.476000  (00:01)\n",
      "15     0.576510    0.771777    0.482000  (00:01)\n",
      "16     0.556412    0.786781    0.478000  (00:01)\n",
      "17     0.537086    0.798632    0.478000  (00:02)\n",
      "18     0.518724    0.815592    0.484000  (00:02)\n",
      "19     0.501062    0.828895    0.476000  (00:02)\n",
      "20     0.485658    0.859679    0.466000  (00:03)\n",
      "21     0.470815    0.864150    0.478000  (00:03)\n",
      "22     0.457828    0.868159    0.476000  (00:02)\n",
      "23     0.449460    0.869975    0.458000  (00:02)\n",
      "24     0.438471    0.874046    0.464000  (00:01)\n",
      "25     0.429300    0.882240    0.454000  (00:01)\n",
      "26     0.423787    0.885527    0.454000  (00:01)\n",
      "27     0.418790    0.883194    0.470000  (00:01)\n",
      "28     0.416018    0.886720    0.458000  (00:01)\n",
      "29     0.412148    0.888030    0.452000  (00:01)\n",
      "30     0.411357    0.888070    0.452000  (00:01)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "learn.fit_one_cycle(30, max_lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot_losses()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 2.2482, -2.3293]], grad_fn=<ThAddmmBackward>),\n",
       " tensor([[-2.8670,  2.5273]], grad_fn=<ThAddmmBackward>))"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net(data.train_ds[0][0].data), net(data.train_ds[3][0].data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds, targs = learn.get_preds()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 206,
   "metadata": {},
   "outputs": [],
   "source": [
    "#targs.view(-1,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 207,
   "metadata": {},
   "outputs": [],
   "source": [
    "#preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([0.0346, 0.9654]), tensor(0))"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 0\n",
    "preds[i], targs[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((Image (1, 4, 50), Category 0), (Image (1, 4, 50), Category 1))"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds[0], data.train_ds[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((Category 0, tensor(0), tensor([0.9898, 0.0102])),\n",
       " (Category 1, tensor(1), tensor([0.0045, 0.9955])))"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.predict(data.train_ds[0][0]), learn.predict(data.train_ds[3][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('epoch15')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Learner(data=ImageDataBunch;\n",
       "Train: LabelList\n",
       "y: CategoryList (2000 items)\n",
       "[Category 0, Category 0, Category 0, Category 1, Category 1]...\n",
       "Path: .\n",
       "x: SeqItemList (1500 items)\n",
       "[Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50)]...\n",
       "Path: .;\n",
       "Valid: LabelList\n",
       "y: CategoryList (2000 items)\n",
       "[Category 0, Category 0, Category 0, Category 1, Category 1]...\n",
       "Path: .\n",
       "x: SeqItemList (500 items)\n",
       "[Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50)]...\n",
       "Path: .;\n",
       "Test: None, model=Sequential(\n",
       "  (0): Lambda()\n",
       "  (1): Conv1d(4, 32, kernel_size=(12,), stride=(1,))\n",
       "  (2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)\n",
       "  (3): Lambda()\n",
       "  (4): Linear(in_features=288, out_features=16, bias=True)\n",
       "  (5): ReLU()\n",
       "  (6): Linear(in_features=16, out_features=2, bias=True)\n",
       "), opt_func=functools.partial(<class 'torch.optim.adam.Adam'>, betas=(0.9, 0.99)), loss_func=<function cross_entropy at 0x1a1881ff28>, metrics=[<function accuracy at 0x1a18a44a60>], true_wd=True, bn_wd=True, wd=0.01, train_bn=True, path=PosixPath('.'), model_dir='models', callback_fns=[<class 'fastai.basic_train.Recorder'>, <class '__main__.TBLogger'>], callbacks=[], layer_groups=[Sequential(\n",
       "  (0): Lambda()\n",
       "  (1): Conv1d(4, 32, kernel_size=(12,), stride=(1,))\n",
       "  (2): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)\n",
       "  (3): Lambda()\n",
       "  (4): Linear(in_features=288, out_features=16, bias=True)\n",
       "  (5): ReLU()\n",
       "  (6): Linear(in_features=16, out_features=2, bias=True)\n",
       ")])"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.load('epoch15')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "interpret = ClassificationInterpretation.from_learner(learn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARoAAAEmCAYAAAC9C19sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAE3tJREFUeJzt3XmUnXWd5/H3t7KQQEgqCxBJIOxEsO0ICijSIuLWItjN6VFEWhinGXXsaZrWFpczgtojgzqncVxaON2I0CqtfbRBnEZNiyIG2aRZmi2IGTAgISEh+1L1nT/uU6GyVFJZvnVreb/OuYf7/J7ffZ7vpao+9fv9nqduIjORpEod7S5A0vBn0EgqZ9BIKmfQSCpn0EgqZ9BIKmfQjFARMT4iboiIZRHx7V04ztkR8cPdWVu7RMRJEfFwu+sYjsL7aAa3iHgncCEwG1gO3AP8TWb+fBePew7w58CrMnPDLhc6yEVEAodn5vx21zISOaIZxCLiQuBvgf8J7AccCHwZOGM3HH4W8MhICJn+iIjR7a5hWMtMH4PwAUwCVgB/so0+e9AKooXN42+BPZp9JwNPAn8FPAM8BZzX7LsEWAesb87xHuBi4Npexz4ISGB0s30u8Gtao6rHgbN7tf+81+teBdwBLGv++6pe+24GPgXc2hznh8C0Pt5bT/1/3av+twF/CDwCLAE+2qv/ccA8YGnT94vA2Gbfz5r3srJ5v2/vdfwPA08D1/S0Na85tDnHMc32/sCzwMnt/t4Yio+2F+Cjjy8MvAnY0POD3kefTwK3AfsC+wC/AD7V7Du5ef0ngTHND+gqYHKzf/Ng6TNogL2A54Ejm30vAo5unm8MGmAK8BxwTvO6s5rtqc3+m4HHgCOA8c32pX28t576/0dT/58Bi4BvAHsDRwNrgEOa/scCJzTnPQh4ELig1/ESOGwrx/9ftAJ7fO+gafr8WXOcPYGbgM+1+/tiqD6cOg1eU4Fnc9tTm7OBT2bmM5m5iNZI5Zxe+9c3+9dn5g9o/TY/cifr6QZeEhHjM/OpzHxgK33eAjyamddk5obM/CbwEPDWXn2uysxHMnM18E/AnG2ccz2t9aj1wLeAacDlmbm8Of8DwEsBMvOuzLytOe9vgK8Cr+nHe/pEZq5t6tlEZl4JPAr8kla4fmw7x1MfDJrBazEwbTtrB/sDC3ptL2jaNh5js6BaBUzY0UIycyWt6cZ7gaci4saImN2PenpqmtFr++kdqGdxZnY1z3uC4He99q/ueX1EHBER34+IpyPieVrrWtO2cWyARZm5Zjt9rgReAvyfzFy7nb7qg0EzeM2jNTV42zb6LKS1qNvjwKZtZ6ykNUXoMb33zsy8KTNfT+s3+0O0fgC3V09PTb/dyZp2xFdo1XV4Zk4EPgrEdl6zzUuuETGB1rrX3wMXR8SU3VHoSGTQDFKZuYzW+sSXIuJtEbFnRIyJiDdHxGVNt28CH4+IfSJiWtP/2p085T3AH0TEgRExCfhIz46I2C8iTo+IvYC1tKZgXVs5xg+AIyLinRExOiLeDhwFfH8na9oRe9NaR1rRjLbet9n+3wGH7OAxLwfuysz/AtwI/N0uVzlCGTSDWGb+b1r30Hyc1kLoE8AHgO81XT4N3AncC9wH3N207cy5fgRc1xzrLjYNhw5aV68W0roS8xrg/Vs5xmLgtKbvYlpXjE7LzGd3pqYd9EHgnbSuZl1J6730djFwdUQsjYj/tL2DRcQZtBbk39s0XQgcExFn77aKRxBv2JNUzhGNpHIGjaRyBo2kcgaNpHKD6g/JJk2emtNnHNDuMlTg0V/v7O09Gsxy3XJyw+rt3a80uIJm+owD+PJ3ftzuMlTgtLMubncJKrD24X/qVz+nTpLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSS
yhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0ksoZNJLKGTSSyhk0BV48fQInHTaF4w/u3Ni2795jOf7gTk45cip7jxu9sT2Ao140geMP6uSEgzuZNWV8GypWf/3dJ85mwdzPcOe3P7rFvgvOeR2rf/VFpnbutUn7sUcdyIo7v8AfnTpnoMocdAyaAk8tW8M9TyzbpG3F2i7u++1ylq7esEn7vhP3oCOCX/5mKbf/ZikzJo9j3Bi/LIPVNTfcxhn/7UtbtM/cr5NTTpjN/3tqySbtHR3Bp//iDH4078GBKnFQ8ju6wNLVG1jfnZu0rVrXxap1XVt2zqSjIwigI4JM2NCVW/bToHDr3Y+xZNmqLdov++CZfOzy75G56dfu/e94Dd+b++8sWrJ8oEoclAyaNntm+Tq6u5NXHzaFVx82hQWLV7Gh26AZSt7ymt9j4TNLue+R327Svv8+kzj9lN/nyu/c0qbKBo/SoImIN0XEwxExPyIuqjzXUDVx/Ggyk5/PX8Ktjy3hwCnjnToNIePHjeHD73kjn/zKjVvs++yHzuTjl/8L3f7iYPT2u+yciBgFfAl4PfAkcEdEXJ+Z/1F1zqFo+sQ9WLxyPQms70qWrd7AxHGjWbN+XbtLUz8cMnMfZs2Yyu3XfQSAGft2Mu8bH+akcz7LMUcdyNcvPQ+AqZ0TeOOrj2bDhm5uuPnedpbcFmVBAxwHzM/MXwNExLeAMwCDppc167uZvOcYnn5+LR0Bk8aP5onnVre7LPXTA/MXMut1H9m4/dCNl3Di2ZexeOlKXnzaxRvbr7jkXfzfW+4fkSEDtVOnGcATvbafbNqGvaP335uXz+pkz7GjOPHQybxo0h7sM2EsJx46mUnjRjNn5kTmzJwIwJPPrWZUR3D8wZ0cd1AnC5etZcXarSwaa1C4+jPncvPVf8URs/Zj/r9+ine/7ZXtLmlIqBzRxFbatpisRsT5wPkA++4/s7CcgfPAwq1fYVi0YsvpUFfC/X301+Dz7o98bZv7Z7/lE1ttP/8T1xZUM3RUjmieBA7otT0TWLh5p8y8IjNfnpkv75w8tbAcSe1SGTR3AIdHxMERMRZ4B3B94fkkDVJlU6fM3BARHwBuAkYB/5CZD1SdT9LgVblGQ2b+APhB5TkkDX7eGSapnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6mcQSOpnEEjqZxBI6nc6L52RMQNQPa1PzNPL6lI0rDTZ9AAnxuwKiQNa30GTWb+dCALkTR8bWtEA0BEHA58BjgKGNfTnpmHFNYlaRjpz2LwVcBXgA3Aa4GvA9dUFiVpeOlP0IzPzLlAZOaCzLwYOKW2LEnDyXanTsCaiOgAHo2IDwC/BfatLUvScNKfEc0FwJ7AfweOBc4B3l1ZlKThZbsjmsy8o3m6AjivthxJw1F/rjr9hK3cuJeZrtNI6pf+rNF8sNfzccCZtK5ASVK/9GfqdNdmTbdGRMnNfE8/v5bL5s6vOLTa7Lk7vtjuElTgxONv61e//kydpvTa7KC1IDx958qSNBL1Z+p0F601mqA1ZXoceE9lUZKGl/4EzYszc03vhojYo6geScNQf+6j+cVW2ubt7kIkDV/b+jya6cAMYHxEvIzW1AlgIq0b+CSpX7Y1dXojcC4wE/g8LwTN88BHa8uSNJxs6/NorgaujogzM/OfB7AmScNMf9Zojo2Izp6N
iJgcEZ8urEnSMNOfoHlzZi7t2cjM54A/rCtJ0nDTn6AZ1ftydkSMB7y8Lanf+nMfzbXA3Ii4qtk+D7i6riRJw01//tbpsoi4FziV1pWnfwVmVRcmafjo7z8g9zTQTesvt18HPFhWkaRhZ1s37B0BvAM4C1gMXEfrc4NfO0C1SRomtjV1egi4BXhrZs4HiIi/HJCqJA0r25o6nUlryvSTiLgyIl7HC3cHS1K/9Rk0mfndzHw7MBu4GfhLYL+I+EpEvGGA6pM0DGx3MTgzV2bmP2bmabT+7uke4KLyyiQNG/296gRAZi7JzK/6weSSdsQOBY0k7QyDRlI5g0ZSOYNGUjmDRlI5g0ZSOYNGUjmDRlI5g0ZSOYNGUjmDRlI5g0ZSOYNGUjmDRlI5g0ZSOYNGUjmDRlI5g0ZSOYNGUjmDRlI5g0ZSOYNGUjmDRlI5g0ZSOYNGUjmDRlI5g0ZSOYNGUjmDRlI5g0ZSOYNGUjmDRlI5g0ZSudHtLmC4uvCUQzh+1mSWrl7Pf/3WvQD86XEzeeXBk0lg6ar1fG7uYyxZtR6A9500i+NmTWbN+i4+P/cx5j+7qo3Vqy+jO2BUQALrul5oH9MBEZAJ67u3bAfo6oauHNByBw1HNEV++OAiPnbDg5u0fedXT/G+6+7j/dfdxy8XLOVdr5gJwCtmdTJj0njOu/YeLr/5cf785EPaUbL6oat704CBVvh0Z6u9O1vbsGkgret6oX0kGsFvvdb9Ty1n+dpNvyNXrX9he9zoDnp+ub3y4Mn8+OFFADz0uxXsNXYUU/YcM1ClagdsbUDSES+MVLqytb256OO1I4VTpwF27vEHcOqR01i5rou//t5/ADBtr7EsWrFuY59nV65j6l5jN06rNLhtnis9210JYwL2GNXa7j2lGmnKRjQR8Q8R8UxE3F91jqHoa798gnd9/Vf82yPPcvpLp/fZbyT/9hsuOpqp01qnTqVTp68Bbyo8/pD2k0ef5dWHTAFaI5h9JozduG/aXmNZsnJdXy/VILP5L4We7VHRWtPpacvccvQzUpQFTWb+DFhSdfyhaP9J4zY+P+GgyTzx3GoAbnv8OU49ch8AZu83gVXrupw2DSHd2QoVaP23u0maBEb1+gnrGeGMRG1fo4mI84HzAcZN7nsqMdRc9PrDeOmMiUwaN5pr3/0yrrn9SY6b1cnMzvF0Z/LM8nV84ae/BuD2BUt5xaxOrnrXHNZu6Obzcx9rc/Xqy5iOFxZ79xgFG7pbjzEdrVDpfXm7p32sazTtD5rMvAK4AmDSgS8eNoF/6Y/mb9F204OL+uz/pZ/9prAa7S59hcWOto80I3h5StJAMWgklau8vP1NYB5wZEQ8GRHvqTqXpMGtbI0mM8+qOrakocWpk6RyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcgaNpHIGjaRyBo2kcpGZ7a5ho4hYBCxodx0DZBrwbLuL0G430r6uszJzn+11GlRBM5JExJ2Z+fJ216Hdy6/r1jl1klTOoJFUzqBpnyvaXYBK+HXdCtdoJJVzRCOpnEEjqZxBI6nc6HYXMBJExGzgDGAGkMBC4PrMfLCthUkDxBFNsYj4MPAtIIDbgTua59+MiIvaWZs0ULzqVCwiHgGOzsz1m7WPBR7IzMPbU5kqRcR5mXlVu+sYLBzR1OsG9t9K+4uafRqeLml3AYOJazT1LgDmRsSjwBNN24HAYcAH2laVdllE3NvXLmC/gaxlsHPqNAAiogM4jtZicABPAndkZldbC9MuiYjf
AW8Entt8F/CLzNzaSHZEckQzADKzG7it3XVot/s+MCEz79l8R0TcPPDlDF6OaCSVczFYUjmDRlI5g0YARERXRNwTEfdHxLcjYs9dONbJEfH95vnp27oxMSI6I+L9O3GOiyPigztbowaWQaMeqzNzTma+BFgHvLf3zmjZ4e+XzLw+My/dRpdOYIeDRkOLQaOtuQU4LCIOiogHI+LLwN3AARHxhoiYFxF3NyOfCQAR8aaIeCgifg78cc+BIuLciPhi83y/iPhuRPx783gVcClwaDOa+mzT70MRcUdE3BsRl/Q61sci4uGI+DFw5ID939AuM2i0iYgYDbwZuK9pOhL4ema+DFgJfBw4NTOPAe4ELoyIccCVwFuBk4DpfRz+C8BPM/P3gWOAB4CLgMea0dSHIuINwOG07juaAxwbEX8QEccC7wBeRivIXrGb37oKeR+NeoyPiJ77QW4B/p7Wn04syMyee4BOAI4Cbo0IgLHAPGA28HhmPgoQEdcC52/lHKcAfwrQ3Ky4LCImb9bnDc3jV832BFrBszfw3cxc1Zzj+l16txpQBo16rM7MOb0bmjBZ2bsJ+FFmnrVZvzm0Pv5idwjgM5n51c3OccFuPIcGmFMn7YjbgBMj4jCAiNgzIo4AHgIOjohDm35n9fH6ucD7mteOioiJwHJao5UeNwH/udfaz4yI2Bf4GfBHETE+IvamNU3TEGHQqN8ycxFwLq3P0rmXVvDMzsw1tKZKNzaLwX39a6N/Abw2Iu4D7qL18RmLaU3F7o+Iz2bmD4FvAPOaft8B9s7Mu4HrgHuAf6Y1vdMQ4Z8gSCrniEZSOYNGUjmDRlI5g0ZSOYNGUjmDRlI5g0ZSuf8P4bKxNeJExWUAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "interpret.plot_confusion_matrix()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}