[de45a9]: / deprecated / DL_Genomics_v1.ipynb

Download this file

810 lines (809 with data), 37.1 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep learning in genomics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is based on the [jupyter notebook](https://github.com/abidlabs/deep-learning-genomics-primer/blob/master/A_Primer_on_Deep_Learning_in_Genomics_Public.ipynb) from the publication [\"A primer on deep learning in genomics\"](https://www.nature.com/articles/s41588-018-0295-5) but uses the [fastai](https://www.fast.ai) library based on [PyTorch](https://pytorch.org)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder, OneHotEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.0.31.dev0'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# fastai version?\n",
    "__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "URL_seq = 'https://raw.githubusercontent.com/abidlabs/deep-learning-genomics-primer/master/sequences.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the sequences file and split it into lines.\n",
    "seq_response = requests.get(URL_seq)\n",
    "seq_response.raise_for_status() # fail loudly instead of parsing an HTTP error page as sequences\n",
    "# Keep only the non-empty strings produced by the split.\n",
    "seq_raw = [line for line in seq_response.text.split('\\n') if line]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check length\n",
    "len(seq_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup df from list\n",
    "seq_df = pd.DataFrame(seq_raw, columns=['Sequences'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 434,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show head of dataframe\n",
    "#seq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "URL_labels = 'https://raw.githubusercontent.com/abidlabs/deep-learning-genomics-primer/master/labels.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_labels = requests.get(URL_labels).text.split('\\n')\n",
    "seq_labels = list(filter(None, seq_labels)) # Removes empty entries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(seq_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_df['Target'] = seq_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Sequences</th>\n",
       "      <th>Target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           Sequences Target\n",
       "0  CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGA...      0\n",
       "1  GAGTTTATATGGCGCGAGCCTAGTGGTTTTTGTACTTGTTTGTCGC...      0\n",
       "2  GATCAGTAGGGAAACAAACAGAGGGCCCAGCCACATCTAGCAGGTA...      0\n",
       "3  GTCCACGACCGAACTCCCACCTTGACCGCAGAGGTACCACCAGAGC...      1\n",
       "4  GGCGACCGAACTCCAACTAGAACCTGCATAACTGGCCTGGGAGATA...      1"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_labels = np.array(seq_labels) # transform to np.array"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data encoding test (incorporated into \"open_seq_image\" function)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup class instance to encode the four different bases to integer values (1D)\n",
    "int_enc = LabelEncoder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up a one-hot encoder that turns the integer-encoded bases (one column) into a one-hot encoded array with 4 columns\n",
    "one_hot_enc = OneHotEncoder(categories=[range(4)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the label encoder once on the full alphabet so that A,C,G,T always map to 0,1,2,3,\n",
    "# even for a sequence that lacks one of the bases (fitting per sequence would shift the codes).\n",
    "# Characters outside ACGT now raise a ValueError in `transform` instead of being encoded as a base.\n",
    "int_enc.fit(list('ACGT'))\n",
    "\n",
    "seq_enc = []\n",
    "for s in seq_raw: # `seq_raw` is the list of downloaded sequence strings\n",
    "    codes = int_enc.transform(list(s)).reshape(-1, 1) # bases to int codes, as a column (rank 2) array\n",
    "    one_hot = one_hot_enc.fit_transform(codes) # sparse one-hot matrix: one row per position, 4 columns\n",
    "    seq_enc.append(one_hot.toarray()) # export sparse matrix to np array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(seq_enc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 462,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0., 0., 0., 1., ..., 0., 1., 0., 0.],\n",
       "        [1., 1., 0., 0., ..., 1., 0., 1., 1.],\n",
       "        [0., 0., 1., 0., ..., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., ..., 0., 0., 0., 0.]]), (4, 50))"
      ]
     },
     "execution_count": 462,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_enc[0].T, seq_enc[0].T.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAMAAAD7wpwzAAADAFBMVEUAAAABAQECAgIDAwMEBAQFBQUGBgYHBwcICAgJCQkKCgoLCwsMDAwNDQ0ODg4PDw8QEBARERESEhITExMUFBQVFRUWFhYXFxcYGBgZGRkaGhobGxscHBwdHR0eHh4fHx8gICAhISEiIiIjIyMkJCQlJSUmJiYnJycoKCgpKSkqKiorKyssLCwtLS0uLi4vLy8wMDAxMTEyMjIzMzM0NDQ1NTU2NjY3Nzc4ODg5OTk6Ojo7Ozs8PDw9PT0+Pj4/Pz9AQEBBQUFCQkJDQ0NERERFRUVGRkZHR0dISEhJSUlKSkpLS0tMTExNTU1OTk5PT09QUFBRUVFSUlJTU1NUVFRVVVVWVlZXV1dYWFhZWVlaWlpbW1tcXFxdXV1eXl5fX19gYGBhYWFiYmJjY2NkZGRlZWVmZmZnZ2doaGhpaWlqampra2tsbGxtbW1ubm5vb29wcHBxcXFycnJzc3N0dHR1dXV2dnZ3d3d4eHh5eXl6enp7e3t8fHx9fX1+fn5/f3+AgICBgYGCgoKDg4OEhISFhYWGhoaHh4eIiIiJiYmKioqLi4uMjIyNjY2Ojo6Pj4+QkJCRkZGSkpKTk5OUlJSVlZWWlpaXl5eYmJiZmZmampqbm5ucnJydnZ2enp6fn5+goKChoaGioqKjo6OkpKSlpaWmpqanp6eoqKipqamqqqqrq6usrKytra2urq6vr6+wsLCxsbGysrKzs7O0tLS1tbW2tra3t7e4uLi5ubm6urq7u7u8vLy9vb2+vr6/v7/AwMDBwcHCwsLDw8PExMTFxcXGxsbHx8fIyMjJycnKysrLy8vMzMzNzc3Ozs7Pz8/Q0NDR0dHS0tLT09PU1NTV1dXW1tbX19fY2NjZ2dna2trb29vc3Nzd3d3e3t7f39/g4ODh4eHi4uLj4+Pk5OTl5eXm5ubn5+fo6Ojp6enq6urr6+vs7Ozt7e3u7u7v7+/w8PDx8fHy8vLz8/P09PT19fX29vb39/f4+Pj5+fn6+vr7+/v8/Pz9/f3+/v7////isF19AAAAO0lEQVR4nGWNSQoAMAgD8/9PT6loIq0HMauShOR9DwphIkO5eALbEI1xp5DUoxYnNtf3t1rci9mGy30A2+0xz9q2+b0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.Image.Image image mode=P size=50x4 at 0x1A20462668>"
      ]
     },
     "execution_count": 311,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "PIL.Image.fromarray(seq_enc[0].T.astype('uint8')*255).convert('P')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data setup for NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 347,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _maybe_squeeze(arr):\n",
    "    \"Return `arr` unchanged when `is1d(arr)` is true, otherwise return `np.squeeze(arr)`.\"\n",
    "    if is1d(arr): return arr\n",
    "    return np.squeeze(arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 435,
   "metadata": {},
   "outputs": [],
   "source": [
    "# open sequence image function\n",
    "def open_seq_image(seq:str, div:bool=True, cls:type=Image)->Image:\n",
    "    \"\"\"Return a `cls` object (default `Image`) holding the one-hot encoding of sequence string `seq`.\n",
    "\n",
    "    `seq` must consist of the bases A, C, G and T only; any other character raises a `ValueError`.\n",
    "    The encoded array has one row per base (A, C, G, T, in that order) and one column per position of `seq`.\n",
    "    If `div` is true the resulting tensor is divided by 255 in place.\n",
    "    \"\"\"\n",
    "    # Fit on the fixed alphabet rather than on `seq` itself: fitting per sequence would shift the codes\n",
    "    # (and therefore the rows) whenever a sequence lacks one of the four bases.\n",
    "    int_enc = LabelEncoder().fit(list('ACGT')) # A,C,G,T -> 0,1,2,3\n",
    "    one_hot_enc = OneHotEncoder(categories=[range(4)]) # always emit 4 columns, one per integer code\n",
    "    \n",
    "    codes = int_enc.transform(list(seq)).reshape(-1,1) # column (rank 2) array of integer codes\n",
    "    enc = one_hot_enc.fit_transform(codes).toarray().T # sparse -> dense, transposed to (4, len(seq))\n",
    "    \n",
    "    # https://stackoverflow.com/questions/22902040/convert-black-and-white-array-into-an-image-in-python\n",
    "    x = PIL.Image.fromarray(enc.astype('uint8')*255).convert('P')\n",
    "    \n",
    "    x = pil2tensor(x,np.float32)\n",
    "    if div: x.div_(255)\n",
    "    return cls(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 436,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD80/8AhrrUr/8A0rxT8B/h1q95B/pGltNotzaWGm6iflfUYdIsrmDSvPkjjsopUeze3nTTrfzYZGadpvuj4Jf8EwP2bf2gf2wL39kD4j6n4uuYLDwj4g1u38e/8JEz68IdG8WXnhax0rMqvZ/YlsbG3faLUTCVcJMkAWBSivm/HbGYzgrIq2IyKrLDTjhcTUThKSanTdBQkuZztbnleKXLK95xqOMXD9sPnT9kmOD9rf8A4Wbb6rcav4Ls/gx8Ita+Inwp0/wV4s1VYfCuqaf5LmKyF/dXRhguriVbqcg+eZreIxTQqHR+R+CXxp8V/tC+Mr34X/FXTdI1fwnpnhHxBrWlaFqWlx3s2nWmjaPearZaJaapd+bqtnpYkso4TbwXkbeTLOFkWSZ5SUV9lXwmGhmfENJRVsJDCOj3pSq0HOpKEvjUpz9+cudylJuTbbk2Hrn7DnwE+DH7cXg3xd4t1X4aaR8PovBvi7wX4fsNJ8DWguIbqPxRrA0fUbm4l1v+0Ll50tdptsTLHazJ58UayvI7/On/AA2R480r/ibfDvwF4R8F+IW+SXxP4N0uXT7mWE/vXiaCOb7H/wAfuL+OUW4mtLmK3+ySW0VpaQwFFHCuFoZnx5n2W4xOrQw06KpQnKUoxU6VKUk+ab51KUm2qjqpt+SSA/4ba+Mn/QmfCL/xH/wf/wDKqiiiv0n/AFT4Y/6A6f3S/wDlgH//2Q==\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAHNJREFUKJFjdGEM+c8ABTufXWBwlzJgQAe4xNHVMDAwoKiD6cMmRwlANhdmJuPf5yr/cVlCjAdIdQAMELIP3bHobkEPJEZYjKBLYLMMV6gji+GykJDHcHkOm9uwqWNETlqEDEF2LD7LiUlK+JIxTC8pKQIA7jhk1J8wofAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "Image (1, 4, 50)"
      ]
     },
     "execution_count": 436,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# test open sequence image function\n",
    "open_seq_image('CCGAGGGCTATGGTTTGGAAGTTAGAACCCTGGGGCTTCTCGCGGACACC')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 437,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SeqItemList(ImageItemList):\n",
    "    \"`ImageItemList` whose items are sequence strings, opened as images by `open_seq_image`.\"\n",
    "    _bunch,_square_show = ImageDataBunch,True\n",
    "    def __post_init__(self):\n",
    "        super().__post_init__()\n",
    "        self.sizes={}\n",
    "    \n",
    "    def open(self, seq):\n",
    "        \"Turn sequence string `seq` into an image via `open_seq_image`.\"\n",
    "        return open_seq_image(seq)\n",
    "    \n",
    "    def get(self, i):\n",
    "        \"Return item `i` opened as a sequence image.\"\n",
    "        item = self.items[i]\n",
    "        # `df[cols].values` is 2D when `cols` is a list (each item is a one-element row) but 1D when `cols`\n",
    "        # is a single column name (each item is the string itself). Accept both instead of always taking `[0]`,\n",
    "        # which would return only the first character of a plain string.\n",
    "        seq = item if isinstance(item, str) else item[0]\n",
    "        return self.open(seq)\n",
    "    \n",
    "    @classmethod\n",
    "    def import_from_df(cls, df:DataFrame, cols:IntsOrStrs=0, **kwargs)->'ItemList':\n",
    "        \"Create an item list from the sequences in `cols` of `df`; `kwargs` are forwarded to the constructor.\"\n",
    "        # NOTE(review): `_maybe_squeeze` is applied to the item list, not to the raw values -- presumably a no-op; confirm.\n",
    "        return _maybe_squeeze(cls(items=df[cols].values, **kwargs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 438,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = (SeqItemList.import_from_df(seq_df, ['Sequences'])\n",
    "        .split_by_idxs(range(1800), range(1800,2000))\n",
    "        .label_from_list(seq_df['Target'].values)\n",
    "        .databunch())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 439,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ImageDataBunch;\n",
       "Train: LabelList\n",
       "y: CategoryList (2000 items)\n",
       "[Category 0, Category 0, Category 0, Category 1, Category 1]...\n",
       "Path: .\n",
       "x: SeqItemList (1800 items)\n",
       "[Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50)]...\n",
       "Path: .;\n",
       "Valid: LabelList\n",
       "y: CategoryList (2000 items)\n",
       "[Category 0, Category 0, Category 0, Category 1, Category 1]...\n",
       "Path: .\n",
       "x: SeqItemList (200 items)\n",
       "[Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50), Image (1, 4, 50)]...\n",
       "Path: .;\n",
       "Test: None"
      ]
     },
     "execution_count": 439,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
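    "# TODO: the repr below reports 2000 items for `y` in both Train and Valid, while `x` is split 1800/200 -- investigate why the labels are not split.\n",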
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 440,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAEADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD4C+JPxp8a/Dzwb8P/AIm3msav4j8WfFTwjNq/xD8V+IPGWtNf6/pyaxqWkP4fuZIL2IS6XLb6Val0ZTcF9wW4WMJGnrn7P/wI+FP7QvwG8E/HrXfCv9kz+MP2i9N+DGn+GrG+nvdK8K6Df2sd1dXOlxam93JFetLcXjrJNJPFFJeyzRQxzrDNEUVyca3yfgqhj8C3SrSxM6blF8v7tPNFGFlaKjFYahGKUY8sKapx5acqlOp+2Hkf7P8A8afFf7Rvx58E/s+eK9N0jRNK8eeLtN0LVb3wVpcejTWTahdR2F7fWUVpstrOe406T7BPFFCtrdQxxPcW808STqfG34paN8BvGVl8KPCXwY8I3fgfVfCPh/xLrfgzVYbyS21G/wBV0ez1UPJcrcrfL9je8a2tDHco8VsjozSNe6g94UV9jXyrLocf0MshTth5YapUcE5KLqRqRjGek0+eKbUZc147xal7wHrnwS+CPgP4+eDb3xF4z0/7HpEP7OniD41an4N8OxRadpWr+JNJ1i80mOO4S3RZhZSwWjubaKWIW0l/e/YjZxzmMfOn/DaPxWs/+RY8OeEfD3k/6Tp3/CM+F4NO+w6qeusJ9nCeZep5l39nebzI9P8At832COz2weSUV4/A2Ew2dZ7m+Gx8fa06E6KpxldxiprF81lezcvZU03Pnk1BR5uW8ZAf8NNfDS9/0zxH+wx8ItS1CX57/UfO8SWX2qY8vL9nsdZgtbfc2W8q3hihTO2OONAqgoor9J/1WyTpGa9K+JS+SWJsl2S0S0WiQH//2Q==\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAAECAYAAADMHGwBAAAABHNCSVQICAgIfAhkiAAAAG5JREFUKJFj/Ptc5T8DAwODu5QBAwzsfHaBAZsYMh8bQFaDSz0hc0iRR2YzujCG/MfrOgIGonsan0OQ9WBTj00vIXtgNIpHsCmGAUKxQ8hx6HqQ7SIUe9j0ovPhHiE26WALFXzq8TkUm1ps9hADAFNpfNSob50MAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "Image (1, 4, 50)"
      ]
     },
     "execution_count": 440,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.x[1799]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 441,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Category 0"
      ]
     },
     "execution_count": 441,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.y[1799]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 442,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((Image (1, 4, 50), Category 0), ['0', '1'])"
      ]
     },
     "execution_count": 442,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds[0], data.train_ds.classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 443,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHJCAYAAAAVcogaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEIFJREFUeJzt3T/LXOW+x+HfL3mKSDSFjSK7jY0bTG8tKS02NuKG3Sq+gBQ5jeQt2AspLAQ9bLAQ211amMImVhY7iHDYBA2iYNYpjOc8Gck4M1n/Zr7XVSWTZ9a6MzP3zIe17llPD8NQAADkuLD0AAAAmJcABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCMAT093Pd/en3f2gu7/t7reWHhMcK/MJxtPd73X3l939c3d/uPR40p0tPQBG90FV/VJVL1TVtar6rLvvDMPw9bLDgqNkPsF47lXVraq6XlXPLDyWeO03gZyO7r5cVf+pqr8Ow3D30W23q+rfwzDcWHRwcGTMJ5hGd9+qqr8Mw/CPpceSzCng0/JyVf36+4fVI3eq6pWFxgPHzHwCTpYAPC3PVtX9jdvuV9VzC4wFjp35BJwsAXhafqyqKxu3XamqHxYYCxw78wk4WQLwtNytqrPuvnrutleryoJ12J/5BJwsAXhChmF4UFWfVNX73X25u1+rqjeq6vayI4PjYz7BuLr7rLsvVdXFqrrY3Ze629VIFiIAT8+79dvX67+vqo+q6h2XrICDmU8wnptV9VNV3aiqtx/9+eaiIwrmMjAAAGEcAQQACCMAAQDCCEAAgDACEAAgjAAEAAiz2PV3Hn539bGvH19/6dpO9/v83leP/X3b/c7/7K4/t89Y9rHrWHbdxtNsZ5997Lq/KZ6XQ83xOH3x8OMefaNP6dA5tc0Uj+XS2xzrNb7tfofuY2prfq9b25x6/cKbT7xExrbndx9TvxYOfb6Xfv2ONZ933cahxnpP2Ob8dvZ5XHaZT44AAgCEEYAAAGEWuxD0tsPrczj0NMWaTivvY4pD+lM/FnOfVt7nEP2FF79Z1emqqsPn1BzPz67bWXqeLG2Kx2LuuXHo62dtc2qOz6gx3kPnWP5w6Fi2OZalUGt+T9r2eO8ynxwBBAAIIwABAMKs5hTw0t86epJDD3cfelp5m6f5VtHc38qd+hT7HGPZZm3fWKw6/FuL215Xc889p5X/37EsndhnLOdt7m9tc2qKb9VvM9bztqb33l33t+lYrw4xxf6m/Fa9I4AAAGEEIABAGAEIABBmNWsAt9l1vcOmMb4qfwprEcZyrL/BYextVK1vvVLV9nW1c6ynnOKK/GNsc4nfEjL3byLZdRvbLL0Oe21zao75tM2xrGOdY5xTv5dtmmKt/9ysAQQA4A8EIABAGAEIABBmsTWAAAAswxFAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwA
BAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwghAAIAwAhAAIIwABAAIIwABAMIIQACAMAIQACCMAAQACCMAAQDCCEAAgDACEAAgjAAEAAgjAAEAwgjAE9Pdz3f3p939oLu/7e63lh4THKvufq+7v+zun7v7w6XHA8fOZ9R6nC09AEb3QVX9UlUvVNW1qvqsu+8Mw/D1ssOCo3Svqm5V1fWqembhscAp8Bm1Ej0Mw9JjYCTdfbmq/lNVfx2G4e6j225X1b+HYbix6ODgiHX3rar6yzAM/1h6LHCsfEati1PAp+Xlqvr194n1yJ2qemWh8QDA73xGrYgAPC3PVtX9jdvuV9VzC4wFAM7zGbUiAvC0/FhVVzZuu1JVPywwFgA4z2fUigjA03K3qs66++q5216tKotrAViaz6gVEYAnZBiGB1X1SVW9392Xu/u1qnqjqm4vOzI4Tt191t2XqupiVV3s7kvd7eoJcACfUesiAE/Pu/Xb5Sq+r6qPquodX6+Hg92sqp+q6kZVvf3ozzcXHREcN59RK+EyMAAAYRwBBAAIIwABAMIIQACAMAIQACCMAAQACLPY9axev/DmY18//vzeV//35+svXRt9f+e3v88+tt3v0G0ubXPc5+3zf9i2nUO3eej+pn5eNrd54cVv+qk3OrLNOXXePq/jJ/3c5s9u+7c/287Y9xvrdTvHe8/U+1varq+RzX/74uHHq5pT2+bTPtb0epti31N8di/92brWHtnHLvPJEUAAgDACEAAgzGIXgn743dXHdjzGqZ6lDxuPYZ//w9yn4/bZzhhjGcuhp6S2Wdvpqqr9TlmNMafGeq1OMc5dx7XP/qZ4Hc/xGO6z/yeNJXFZxaFLKvaxpsd/2/aXfO1NtZ0x9r/P5+Wu2x/ruXYKGACAPxCAAABhFjsFPMU3rKY4tbTP/aY+VZ3wjeRjGfcaTwFvLqs4b02P6xSv/ynut4+5T2dNcepp21jWcspqTlMsqdjHFEtXxnDosoVNcyybmMKaHvttdllS4QggAEAYAQgAEEYAAgCEWeUawLkvPTHHOoVt21zrWo8lTP1bIcb6jRhrW69UNd66Wn6z9FqjKUxx+ZKx1geubU4duqZ27tfNHJeIWXq97ZK/9WfuNcRzXqrMEUAAgDACEAAgjAAEAAiz2BpAAACW4QggAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIA
hBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQTgCenu97r7y+7+ubs/XHo8cOzMKRhXdz/f3Z9294Pu/ra731p6TKnOlh4Ao7pXVbeq6npVPbPwWOAUmFMwrg+q6peqeqGqrlXVZ919ZxiGr5cdVh5HAE/IMAyfDMPw31X1P0uPBU6BOQXj6e7LVfW3qvqvYRh+HIbhX1X1z6r6+7IjyyQAAYA5vFxVvw7DcPfcbXeq6pWFxhNNAAIAc3i2qu5v3Ha/qp5bYCzxBCAAMIcfq+rKxm1XquqHBcYSTwACAHO4W1Vn3X313G2vVpUvgCxAAJ6Q7j7r7ktVdbGqLnb3pe72TW84kDkF4xmG4UFVfVJV73f35e5+rareqKrby44skwA8LTer6qequlFVbz/6881FRwTHzZyCcb1bv11S6fuq+qiq3nEJmGX0MAxLjwEAgBk5AggAEEYAAgCEEYAAAGEEIABAGAEIABBmsetZvX7hTV8/3tPn97567O/XX7p28H0P3c4h+zt0nJv3O3SbT9rGn21n2/6+ePhxHzSACT387uoT59Su/8997rePXZ+7fcYyxuthm6eZb7tuZ4p97PP47rq/OV4ja5tTm/PpFN4nd30e535dbrvfWGMb6/90qG2P/Xljvc/tMp8cAQQACCMAAQDCLHYh6G2nqzbNfeh9itMpUxx6H+u0zNSnDOY4BbbNFKeV13a6qmqaZRVTnJKd+1TMWKf9x3odT72UYe73kH3Gss2FF79Z1ZzanE9TnK5NtvQp2W0O/dx70s/NYXMsu8wnRwABAMIIQACAMIudAt7ndNWhh1mX/NbopjFOO431Lapjsab/wyGH1+e2bU5N8Viu6fnZxxyn5KZ47zlvjlOQc5+6XNuyikPn01j/ts2avpF8yPb33cexvIbHcOhr5JArVTgCCAAQRgACAIQRgAAAYY5iDeA2U1+mZNv+jmVNwT7meCymXh+1aYorxa9tvVLV4etql17vM8blY8ZYf/tnP7vtvnP/xpQ1O5VLK61pPu267839T/XbbHY19zq/ffY/9XcLll4jbQ0gAAB/IAABAMIIQACAMIutAQQAYBmOAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAGAEIABBGAAIAhBGAAABhBCAAQBgBCAAQRgACAIQRgAAAYQQgAEAYAQgAEEYAAgCEEYAAAGEEIABAmP8FUCudvf/ncusAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 648x720 with 9 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "data.show_batch(rows=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 444,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Flatten():\n",
    "    \"Create a `Lambda` layer that views its input as shape (size of first dimension, -1).\"\n",
    "    def _flatten(x): return x.view((x.size(0), -1))\n",
    "    return Lambda(_flatten)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 458,
   "metadata": {},
   "outputs": [],
   "source": [
    "basic_model = nn.Sequential(nn.Conv1d(in_channels=4, out_channels=32, kernel_size=12),\n",
    "                            nn.MaxPool1d(kernel_size=4),\n",
    "                            Flatten(),\n",
    "                            nn.Linear(in_features=288, out_features=16),\n",
    "                            nn.ReLU(),\n",
    "                            nn.Linear(in_features=16, out_features=1)\n",
    "                           )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 459,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv1d(4, 32, kernel_size=(12,), stride=(1,))\n",
       "  (1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)\n",
       "  (2): Lambda()\n",
       "  (3): Linear(in_features=288, out_features=16, bias=True)\n",
       "  (4): ReLU()\n",
       "  (5): Linear(in_features=16, out_features=1, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 459,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "basic_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 464,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Image (1, 4, 50), Category 0)"
      ]
     },
     "execution_count": 464,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 460,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 4, 50])"
      ]
     },
     "execution_count": 460,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds[0][0].data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 461,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0896]], grad_fn=<ThAddmmBackward>)"
      ]
     },
     "execution_count": 461,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "basic_model(data.train_ds[0][0].data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learner setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 465,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "conv1d(): argument 'input' (position 1) must be Tensor, not bool",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-465-9fd4f8d6552c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mlearn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_cnn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbasic_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpretrained\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/Downloads/fastai/fastai/vision/learner.py\u001b[0m in \u001b[0;36mcreate_cnn\u001b[0;34m(data, arch, cut, pretrained, lin_ftrs, ps, custom_head, split_on, bn_final, **kwargs)\u001b[0m\n\u001b[1;32m     52\u001b[0m     \u001b[0;34m\"Build convnet style learners.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     53\u001b[0m     \u001b[0mmeta\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcnn_config\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0march\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m     \u001b[0mbody\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_body\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0march\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpretrained\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mifnone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcut\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmeta\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'cut'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     55\u001b[0m     \u001b[0mnf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnum_features_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbody\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     56\u001b[0m     \u001b[0mhead\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcustom_head\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mcreate_head\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlin_ftrs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbn_final\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbn_final\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    475\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    476\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 477\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    478\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    479\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m     90\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     91\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_modules\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 92\u001b[0;31m             \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     93\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    475\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    476\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 477\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    478\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    479\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    179\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    180\u001b[0m         return F.conv1d(input, self.weight, self.bias, self.stride,\n\u001b[0;32m--> 181\u001b[0;31m                         self.padding, self.dilation, self.groups)\n\u001b[0m\u001b[1;32m    182\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    183\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: conv1d(): argument 'input' (position 1) must be Tensor, not bool"
     ]
    }
   ],
   "source": [
    "learn = create_cnn(data, basic_model, pretrained=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}