[de45a9]: / DL_Genomics_v10_resnet-fastai-fastq.ipynb

Download this file

1805 lines (1804 with data), 137.7 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep learning in genomics - Basic model with PyTorch and fastai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is based on the [jupyter notebook](https://nbviewer.jupyter.org/github/abidlabs/deep-learning-genomics-primer/blob/master/A_Primer_on_Deep_Learning_in_Genomics_Public.ipynb) from the publication [\"A primer on deep learning in genomics\"](https://www.nature.com/articles/s41588-018-0295-5) but uses the [fastai](https://www.fast.ai) library based on [PyTorch](https://pytorch.org)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder, OneHotEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.0.39.dev0'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# fastai version\n",
    "__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From https://forum.pyro.ai/t/a-clever-trick-to-debug-tensor-memory/556\n",
    "def debug_memory():\n",
    "    import collections, gc, torch\n",
    "    tensors = collections.Counter((str(o.device), o.dtype, tuple(o.shape))\n",
    "                                  for o in gc.get_objects()\n",
    "                                  if torch.is_tensor(o))\n",
    "    for line in sorted(tensors.items()):\n",
    "        print('{}\\t{}'.format(*line))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## Data frame setup from raw data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 274,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PosixPath('/Volumes/HDD08/ugenomfit/data')"
      ]
     },
     "execution_count": 274,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "PATH_raw_data = Path('/Volumes/HDD08/ugenomfit/data/'); PATH_raw_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 275,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "from Bio import SeqIO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 276,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def get_seqs(path, seq_per_file=500):\n",
    "    \"\"\"Collect the leading records of every `*.fastq` file located directly in `path`.\n",
    "\n",
    "    Args:\n",
    "        path: directory (`pathlib.Path`) that is searched for `*.fastq` files.\n",
    "        seq_per_file: number of records taken from the start of each file. A value that\n",
    "            no record index ever equals (e.g. `None`) reads each file completely.\n",
    "\n",
    "    Returns:\n",
    "        Three parallel lists `(label, seq, score)`: record id, sequence as `str`, and the\n",
    "        record's `phred_quality` letter annotation.\n",
    "    \"\"\"\n",
    "    label, seq, score = [], [], []\n",
    "    # Sort the paths: glob yields them in file-system order, which would otherwise let the\n",
    "    # row order of the resulting data differ between machines and runs.\n",
    "    for n, fn in enumerate(sorted(path.glob('*.fastq'))):\n",
    "        print(n, fn)  # progress: file number and file name\n",
    "        for m, record in enumerate(SeqIO.parse(fn, 'fastq')):\n",
    "            if m == seq_per_file: break  # stop once `seq_per_file` records were taken\n",
    "            label.append(str(record.id))\n",
    "            seq.append(str(record.seq))\n",
    "            score.append(record.letter_annotations['phred_quality'])\n",
    "    return label, seq, score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 277,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 /Volumes/HDD08/ugenomfit/data/SRR1993099.sra_1.fastq\n",
      "1 /Volumes/HDD08/ugenomfit/data/SRR1993099.sra_2.fastq\n",
      "2 /Volumes/HDD08/ugenomfit/data/SRR5665975.sra_1.fastq\n",
      "3 /Volumes/HDD08/ugenomfit/data/SRR5665975.sra_2.fastq\n"
     ]
    }
   ],
   "source": [
    "label, seq, score = get_seqs(PATH_raw_data, seq_per_file=5000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 278,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#label[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#seq[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 280,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#score[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 281,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20000, 20000, 20000)"
      ]
     },
     "execution_count": 281,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(label), len(seq), len(score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 282,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "col_names = ['label', 'sequence', 'score']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 283,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "seq_score = list(zip(seq, score))  # one (sequence, quality scores) tuple per read"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 284,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "fastq_df = pd.DataFrame({'label': label, 'seq_score': seq_score})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 285,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>seq_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>SRR1993099.sra.1</td>\n",
       "      <td>(ATCTACAAGAAGGGTGAAGTGCTTTTCGAATTTTGCCACTGCAAG...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>SRR1993099.sra.2</td>\n",
       "      <td>(TGTGCTCCATGTCGATATTTCGTGGAGCAAACCAAAAAAGATGCG...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>SRR1993099.sra.3</td>\n",
       "      <td>(NACCCTTTATAAAAGCCTAGATGTAGCAGTGCGAAGCGAACTCGA...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>SRR1993099.sra.4</td>\n",
       "      <td>(CGAATGGGACCTTGAATGGATTAACGAGATTCCCACTGTCCCTAT...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>SRR1993099.sra.5</td>\n",
       "      <td>(CACACCGCAGTAGATGGAAAACTCGAGTTTTACATGGAACAGGTA...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              label                                          seq_score\n",
       "0  SRR1993099.sra.1  (ATCTACAAGAAGGGTGAAGTGCTTTTCGAATTTTGCCACTGCAAG...\n",
       "1  SRR1993099.sra.2  (TGTGCTCCATGTCGATATTTCGTGGAGCAAACCAAAAAAGATGCG...\n",
       "2  SRR1993099.sra.3  (NACCCTTTATAAAAGCCTAGATGTAGCAGTGCGAAGCGAACTCGA...\n",
       "3  SRR1993099.sra.4  (CGAATGGGACCTTGAATGGATTAACGAGATTCCCACTGTCCCTAT...\n",
       "4  SRR1993099.sra.5  (CACACCGCAGTAGATGGAAAACTCGAGTTTTACATGGAACAGGTA..."
      ]
     },
     "execution_count": 285,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fastq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 286,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def recode_label(label, inv=False):\n",
    "\n",
    "    if label.find('SRR1993099') != -1:\n",
    "        return 0 if inv == False else 1\n",
    "       \n",
    "    elif label.find('SRR5665975') != -1:\n",
    "        return 1 if inv == False else 0\n",
    "    \n",
    "    else:\n",
    "        return np.nan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 287,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "fastq_df['target'] = fastq_df['label'].apply(recode_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 288,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "fastq_df['nottarget'] = fastq_df['label'].apply(partial(recode_label,inv=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 289,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>seq_score</th>\n",
       "      <th>target</th>\n",
       "      <th>nottarget</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>SRR1993099.sra.1</td>\n",
       "      <td>(ATCTACAAGAAGGGTGAAGTGCTTTTCGAATTTTGCCACTGCAAG...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>SRR1993099.sra.2</td>\n",
       "      <td>(TGTGCTCCATGTCGATATTTCGTGGAGCAAACCAAAAAAGATGCG...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>SRR1993099.sra.3</td>\n",
       "      <td>(NACCCTTTATAAAAGCCTAGATGTAGCAGTGCGAAGCGAACTCGA...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>SRR1993099.sra.4</td>\n",
       "      <td>(CGAATGGGACCTTGAATGGATTAACGAGATTCCCACTGTCCCTAT...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>SRR1993099.sra.5</td>\n",
       "      <td>(CACACCGCAGTAGATGGAAAACTCGAGTTTTACATGGAACAGGTA...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              label                                          seq_score  \\\n",
       "0  SRR1993099.sra.1  (ATCTACAAGAAGGGTGAAGTGCTTTTCGAATTTTGCCACTGCAAG...   \n",
       "1  SRR1993099.sra.2  (TGTGCTCCATGTCGATATTTCGTGGAGCAAACCAAAAAAGATGCG...   \n",
       "2  SRR1993099.sra.3  (NACCCTTTATAAAAGCCTAGATGTAGCAGTGCGAAGCGAACTCGA...   \n",
       "3  SRR1993099.sra.4  (CGAATGGGACCTTGAATGGATTAACGAGATTCCCACTGTCCCTAT...   \n",
       "4  SRR1993099.sra.5  (CACACCGCAGTAGATGGAAAACTCGAGTTTTACATGGAACAGGTA...   \n",
       "\n",
       "   target  nottarget  \n",
       "0       0          1  \n",
       "1       0          1  \n",
       "2       0          1  \n",
       "3       0          1  \n",
       "4       0          1  "
      ]
     },
     "execution_count": 289,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fastq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 290,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>seq_score</th>\n",
       "      <th>target</th>\n",
       "      <th>nottarget</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>19995</th>\n",
       "      <td>SRR5665975.sra.4996</td>\n",
       "      <td>(TTTCTACCATCTCCAGCGGAGCAGAATCGGTGGCGGTTTCCGCTT...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19996</th>\n",
       "      <td>SRR5665975.sra.4997</td>\n",
       "      <td>(CCCGTGCGCGCGGCTATACGCCGGGACGTTTCAGCTTTAACGTTC...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19997</th>\n",
       "      <td>SRR5665975.sra.4998</td>\n",
       "      <td>(ATTTTCCGCAAACCAGGCTTTAACCGCGCCAGCGAAGGCTATTTT...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19998</th>\n",
       "      <td>SRR5665975.sra.4999</td>\n",
       "      <td>(TGCAAAATGGTGCTAATTTTGGCCCCTGGCGATTACGCAACTATT...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19999</th>\n",
       "      <td>SRR5665975.sra.5000</td>\n",
       "      <td>(AACGGCAAAATTATCGCCGTTGCCAGCAATATCCCTTCTGACATT...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     label                                          seq_score  \\\n",
       "19995  SRR5665975.sra.4996  (TTTCTACCATCTCCAGCGGAGCAGAATCGGTGGCGGTTTCCGCTT...   \n",
       "19996  SRR5665975.sra.4997  (CCCGTGCGCGCGGCTATACGCCGGGACGTTTCAGCTTTAACGTTC...   \n",
       "19997  SRR5665975.sra.4998  (ATTTTCCGCAAACCAGGCTTTAACCGCGCCAGCGAAGGCTATTTT...   \n",
       "19998  SRR5665975.sra.4999  (TGCAAAATGGTGCTAATTTTGGCCCCTGGCGATTACGCAACTATT...   \n",
       "19999  SRR5665975.sra.5000  (AACGGCAAAATTATCGCCGTTGCCAGCAATATCCCTTCTGACATT...   \n",
       "\n",
       "       target  nottarget  \n",
       "19995       1          0  \n",
       "19996       1          0  \n",
       "19997       1          0  \n",
       "19998       1          0  \n",
       "19999       1          0  "
      ]
     },
     "execution_count": 290,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fastq_df.tail()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 291,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 20000 entries, 0 to 19999\n",
      "Data columns (total 4 columns):\n",
      "label        20000 non-null object\n",
      "seq_score    20000 non-null object\n",
      "target       20000 non-null int64\n",
      "nottarget    20000 non-null int64\n",
      "dtypes: int64(2), object(2)\n",
      "memory usage: 625.1+ KB\n"
     ]
    }
   ],
   "source": [
    "fastq_df.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 292,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# pickle data frame to disk\n",
    "fastq_df.to_pickle('fastq_df.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load pickled data frame\n",
    "fastq_df = pd.read_pickle('fastq_df.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>seq_score</th>\n",
       "      <th>target</th>\n",
       "      <th>nottarget</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>SRR1993099.sra.1</td>\n",
       "      <td>(ATCTACAAGAAGGGTGAAGTGCTTTTCGAATTTTGCCACTGCAAG...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>SRR1993099.sra.2</td>\n",
       "      <td>(TGTGCTCCATGTCGATATTTCGTGGAGCAAACCAAAAAAGATGCG...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>SRR1993099.sra.3</td>\n",
       "      <td>(NACCCTTTATAAAAGCCTAGATGTAGCAGTGCGAAGCGAACTCGA...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>SRR1993099.sra.4</td>\n",
       "      <td>(CGAATGGGACCTTGAATGGATTAACGAGATTCCCACTGTCCCTAT...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>SRR1993099.sra.5</td>\n",
       "      <td>(CACACCGCAGTAGATGGAAAACTCGAGTTTTACATGGAACAGGTA...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              label                                          seq_score  \\\n",
       "0  SRR1993099.sra.1  (ATCTACAAGAAGGGTGAAGTGCTTTTCGAATTTTGCCACTGCAAG...   \n",
       "1  SRR1993099.sra.2  (TGTGCTCCATGTCGATATTTCGTGGAGCAAACCAAAAAAGATGCG...   \n",
       "2  SRR1993099.sra.3  (NACCCTTTATAAAAGCCTAGATGTAGCAGTGCGAAGCGAACTCGA...   \n",
       "3  SRR1993099.sra.4  (CGAATGGGACCTTGAATGGATTAACGAGATTCCCACTGTCCCTAT...   \n",
       "4  SRR1993099.sra.5  (CACACCGCAGTAGATGGAAAACTCGAGTTTTACATGGAACAGGTA...   \n",
       "\n",
       "   target  nottarget  \n",
       "0       0          1  \n",
       "1       0          1  \n",
       "2       0          1  \n",
       "3       0          1  \n",
       "4       0          1  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show pandas data frame head\n",
    "fastq_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>seq_score</th>\n",
       "      <th>target</th>\n",
       "      <th>nottarget</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>19995</th>\n",
       "      <td>SRR5665975.sra.4996</td>\n",
       "      <td>(TTTCTACCATCTCCAGCGGAGCAGAATCGGTGGCGGTTTCCGCTT...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19996</th>\n",
       "      <td>SRR5665975.sra.4997</td>\n",
       "      <td>(CCCGTGCGCGCGGCTATACGCCGGGACGTTTCAGCTTTAACGTTC...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19997</th>\n",
       "      <td>SRR5665975.sra.4998</td>\n",
       "      <td>(ATTTTCCGCAAACCAGGCTTTAACCGCGCCAGCGAAGGCTATTTT...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19998</th>\n",
       "      <td>SRR5665975.sra.4999</td>\n",
       "      <td>(TGCAAAATGGTGCTAATTTTGGCCCCTGGCGATTACGCAACTATT...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19999</th>\n",
       "      <td>SRR5665975.sra.5000</td>\n",
       "      <td>(AACGGCAAAATTATCGCCGTTGCCAGCAATATCCCTTCTGACATT...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     label                                          seq_score  \\\n",
       "19995  SRR5665975.sra.4996  (TTTCTACCATCTCCAGCGGAGCAGAATCGGTGGCGGTTTCCGCTT...   \n",
       "19996  SRR5665975.sra.4997  (CCCGTGCGCGCGGCTATACGCCGGGACGTTTCAGCTTTAACGTTC...   \n",
       "19997  SRR5665975.sra.4998  (ATTTTCCGCAAACCAGGCTTTAACCGCGCCAGCGAAGGCTATTTT...   \n",
       "19998  SRR5665975.sra.4999  (TGCAAAATGGTGCTAATTTTGGCCCCTGGCGATTACGCAACTATT...   \n",
       "19999  SRR5665975.sra.5000  (AACGGCAAAATTATCGCCGTTGCCAGCAATATCCCTTCTGACATT...   \n",
       "\n",
       "       target  nottarget  \n",
       "19995       1          0  \n",
       "19996       1          0  \n",
       "19997       1          0  \n",
       "19998       1          0  \n",
       "19999       1          0  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fastq_df.tail()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 20000 entries, 0 to 19999\n",
      "Data columns (total 4 columns):\n",
      "label        20000 non-null object\n",
      "seq_score    20000 non-null object\n",
      "target       20000 non-null int64\n",
      "nottarget    20000 non-null int64\n",
      "dtypes: int64(2), object(2)\n",
      "memory usage: 625.1+ KB\n"
     ]
    }
   ],
   "source": [
    "fastq_df.info()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fastai data object"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup custom fastai data object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# open sequence image function\n",
    "def open_fastq_image(seq_score:list, seq_len:int=50, cls:type=Image)->Image:\n",
    "    \"\"\"Return an `Image` of height 1 built from one `(seq, score)` pair.\n",
    "\n",
    "    Every position of the sequence becomes one RGB pixel:\n",
    "      R - base pair:       0 for A/T, 255 for G/C\n",
    "      G - strand position: 0 for A/C, 255 for T/G\n",
    "      B - the quality score of that position, stored unscaled\n",
    "    Bases other than A, C, G, T (e.g. `N`) get 128 in both R and G.\n",
    "\n",
    "    Args:\n",
    "        seq_score: `(seq, score)` pair, or a numpy array holding that pair as its first element.\n",
    "        seq_len: number of leading positions to keep; `None` keeps all of them.\n",
    "        cls: class wrapped around the resulting tensor.\n",
    "\n",
    "    Returns:\n",
    "        `cls(x)`, where `x` is the float32 image tensor divided by 255.\n",
    "    \"\"\"\n",
    "    # fixed lookup: base -> (base pair value, strand position value)\n",
    "    base_channels = {'A': (0, 0), 'C': (255, 0), 'G': (255, 255), 'T': (0, 255)}\n",
    "    unknown_base = (128, 128)\n",
    "    \n",
    "    # unpack the (seq, score) pair; a numpy array input carries the pair as its first element\n",
    "    if isinstance(seq_score, np.ndarray):\n",
    "        seq, score = seq_score[0]\n",
    "    else:\n",
    "        seq, score = seq_score\n",
    "    \n",
    "    if seq_len is not None:\n",
    "        seq = seq[:seq_len]\n",
    "        score = score[:seq_len]\n",
    "    \n",
    "    # A fixed table replaces the per-sequence LabelEncoder fit: fitted codes depend on which\n",
    "    # characters occur in the read, so an `N` or a missing base shifted the codes of the others.\n",
    "    base_enc = np.array([base_channels.get(base, unknown_base) for base in str(seq).upper()]).reshape(-1, 2)\n",
    "    basepair_enc, strandposition_enc = base_enc[:, 0], base_enc[:, 1]\n",
    "    score_enc = np.array(score)\n",
    "    \n",
    "    # Stack along the last axis so the three values of one position form one pixel;\n",
    "    # stacking on axis 0 and reshaping to (1, -1, 3) mixed values of neighbouring positions.\n",
    "    enc = np.stack((basepair_enc, strandposition_enc, score_enc), axis=-1).reshape(1, -1, 3)\n",
    "    \n",
    "    # https://stackoverflow.com/questions/22902040/convert-black-and-white-array-into-an-image-in-python\n",
    "    x = PIL.Image.fromarray(enc.astype('uint8')).convert('RGB')\n",
    "    x = pil2tensor(x, np.float32)\n",
    "    x.div_(255)\n",
    "    \n",
    "    return cls(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAABADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD4q/YJ/wCTS/i7/wBj74T/APTP4spNG/5RFfGr/r/H/pP4Goor+vuLP+UjOJP+wqr/AOqTBn3fGP8Ayizw9/2G4j/1IifV3/BE7/kzfxb/ANlT+D3/AKdPh/W9+17/AMnk6D/2kH8df+mrwrRRX4rw58Oaf1/zM6p8pwN8XHX/AGFZl/6rpn5ZePv+Q3af9i5pH/putq564/1H/bvH/MUUV+f8Sf8AI/x//X6v/wCpOINOI/8AkbV/8X/tlMZB/qE/3B/Kiiivianxv1f5s8mXxM//2Q==\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAABCAYAAACc0f2yAAAABHNCSVQICAgIfAhkiAAAAGlJREFUCJlNzEEKwkAQRNHXyaggwRzB3P9yQjAEJUO7yETtRUHXL36Q6e8yg4BECClFq1IGsaMjfos8PiKaS9KU2cDXmTvL+2QqRUGphZ61qy7xYBl1w9kb61YN+dKdmLen3ui6VNtt8QEKyixbTwHqlAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "Image (3, 1, 50)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# test open sequence image function\n",
    "open_fastq_image(fastq_df.seq_score[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SeqItemList(ImageItemList):\n",
    "    \"`ImageItemList` whose items are sequence/score entries that `open_fastq_image` renders to images.\"\n",
    "    # class-level settings assigned here: the data bunch class and the square-show flag\n",
    "    _bunch,_square_show = ImageDataBunch,True\n",
    "    def __post_init__(self):\n",
    "        # run the parent hook first, then start with an empty `sizes` dict\n",
    "        super().__post_init__()\n",
    "        self.sizes={}\n",
    "    \n",
    "    # Delegates to `open_fastq_image` using its default `seq_len` and `cls`; no per-list length is passed on.\n",
    "    def open(self, seq_score): return open_fastq_image(seq_score)\n",
    "    \n",
    "    def get(self, i):\n",
    "        \"Return the opened image for item `i`.\"\n",
    "        # each item is unwrapped with `[0]` before opening - presumably undoing the 1-tuples\n",
    "        # that `zip` builds in `import_from_df`; TODO confirm against the fastai item storage\n",
    "        seq = self.items[i][0]\n",
    "        res = self.open(seq)\n",
    "        return res\n",
    "    \n",
    "    @classmethod\n",
    "    def import_from_df(cls, df:DataFrame, seq_score:IntsOrStrs=0, **kwargs)->'ItemList':\n",
    "        \"Build the item list from column(s) `seq_score` of `df`; `df` itself is passed along as `xtra`.\"\n",
    "        # NOTE(review): `**kwargs` is accepted but never used, so extra arguments such as `seq_len` have no effect.\n",
    "        return cls(items=zip(df[seq_score].values), xtra=df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the data bunch: items from the `seq_score` column, random 75/25 train/valid split, batches of `bs`.\n",
    "# NOTE(review): the random split is given no seed, so the validation set presumably differs between runs - confirm.\n",
    "data = (SeqItemList.import_from_df(fastq_df, ['seq_score'], seq_len=50) # NOTE: `seq_len` is swallowed by `**kwargs` and ignored; the length used is the `open_fastq_image` default\n",
    "        .random_split_by_pct(valid_pct=0.25)\n",
    "        #.split_by_idxs(range(1500), range(1500,2000)) # alternative: fixed, index-based split\n",
    "        .label_from_df(['target', 'nottarget']) # two label columns; TODO(review): the earlier note \"one_hot=True !\" suggests one-hot labels were intended - confirm\n",
    "        .databunch(bs=bs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Verify data object"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Check data object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ImageDataBunch;\n",
       "\n",
       "Train: LabelList\n",
       "y: MultiCategoryList (15000 items)\n",
       "[MultiCategory nottarget, MultiCategory nottarget, MultiCategory nottarget, MultiCategory nottarget, MultiCategory nottarget]...\n",
       "Path: .\n",
       "x: SeqItemList (15000 items)\n",
       "[Image (3, 1, 50), Image (3, 1, 50), Image (3, 1, 50), Image (3, 1, 50), Image (3, 1, 50)]...\n",
       "Path: .;\n",
       "\n",
       "Valid: LabelList\n",
       "y: MultiCategoryList (5000 items)\n",
       "[MultiCategory target, MultiCategory target, MultiCategory nottarget, MultiCategory target, MultiCategory target]...\n",
       "Path: .\n",
       "x: SeqItemList (5000 items)\n",
       "[Image (3, 1, 50), Image (3, 1, 50), Image (3, 1, 50), Image (3, 1, 50), Image (3, 1, 50)]...\n",
       "Path: .;\n",
       "\n",
       "Test: None"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, ['target', 'nottarget'])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check classes\n",
    "data.c, data.classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "64"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check batch size\n",
    "data.train_dl.batch_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Check data points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAABADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDA/wCC2X/KOrRP+y8/GH/07+Jq7X9uH/k3v9pT/sMWH/q47eiivxTgX/kTVf8ArzmP/q4y8+v4d/3bB/8AY6w//peWnafsHf8ANB/+yZ6x/wC9Or4c/Zb/AOPzwR/2aX8O/wD1ammUUV/QHAX/ACa7Mv8ArzD/ANImeV9KX/k5HEH/AHM/+k5mfLvx6/5Kxf8A/XrY/wDpJBXKXH+rP+6P60UV8rkf/Iow3/XuH/pKPofFP/k5ed/9heJ/9PTFXoPpRRRXqnwB/9k=\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAABCAYAAACc0f2yAAAABHNCSVQICAgIfAhkiAAAAGBJREFUCJlNzTEOwlAQA9HnjRTEKcL9L0hB+KZIClyNrJGdVpNqAwKpiqC9yrQa0qhKov7SktxeaO+t2+/FXF9JaTQ4Dq99t8MaM7X6lu1pcloexsdaY2az1hecZUtk4geo+jB3hBcTdAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "Image (3, 1, 50)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 2\n",
    "data.x[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiCategory nottarget"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.y[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAABADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDP/bT/AOSb6N/2bx4B/wDVnaRX5w6d/wAi/wDtH/8Acc/9L7uiivs/Cf8A5I7xJ/7GWJ/9Oo/srhz4eHf+vOI/9VWfifsN/wDJhH7SH/ZNIf8A06LX6eftVf8AIreA/wDsk3gv/wBRz43UUV+g+LP+54f/ALAMp/8AVfjD+QMX/wAi3KP+wqX55efl7+0B/wAnI/EX/soGt/8ApfPXH9n/AN5f6UUV/OuTf8iPC/8AXml/6apHyvCX/JK5f/14of8ApmkPj/1a/wC6KKKK9Y9s/9k=\n",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAADIAAAABCAYAAACc0f2yAAAABHNCSVQICAgIfAhkiAAAAF9JREFUCJlFyzEOwkAAA8HxXaDiFeH/v6MBEpkiEFxY8mqdalMqxJGSUAWRA/p6etJ/enZK83v12EiirSS0h1PcV+ucJsYYGMaA3d5h5mV/LnKtWETtW43Lw3u7WWZ8ALB/KHh7WYNvAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "Image (3, 1, 50)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 3\n",
    "data.x[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiCategory nottarget"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.y[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 864x864 with 9 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "data.show_batch(rows=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom fastai ResNet model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = create_cnn(data, models.resnet18, metrics=accuracy_thresh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BCEWithLogitsLoss()"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.loss_func.func"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "#learn.model[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): AdaptiveConcatPool2d(\n",
       "    (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "    (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "  )\n",
       "  (1): Lambda()\n",
       "  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Dropout(p=0.25)\n",
       "  (4): Linear(in_features=1024, out_features=512, bias=True)\n",
       "  (5): ReLU(inplace)\n",
       "  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (7): Dropout(p=0.5)\n",
       "  (8): Linear(in_features=512, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.model[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train ResNet18 model with fastai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 315,
   "metadata": {},
   "outputs": [],
   "source": [
    "#learn = Learner(data, net_basic_fastai, loss_func=F.binary_cross_entropy_with_logits, metrics=accuracy_thresh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "======================================================================\n",
      "Layer (type)         Output Shape         Param #    Trainable \n",
      "======================================================================\n",
      "Conv2d               [64, 64, 1, 25]      9408       False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 64, 1, 25]      128        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 64, 1, 25]      0          False     \n",
      "______________________________________________________________________\n",
      "MaxPool2d            [64, 64, 1, 13]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 64, 1, 13]      36864      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 64, 1, 13]      128        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 64, 1, 13]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 64, 1, 13]      36864      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 64, 1, 13]      128        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 64, 1, 13]      36864      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 64, 1, 13]      128        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 64, 1, 13]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 64, 1, 13]      36864      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 64, 1, 13]      128        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 128, 1, 7]      73728      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 128, 1, 7]      256        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 128, 1, 7]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 128, 1, 7]      147456     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 128, 1, 7]      256        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 128, 1, 7]      8192       False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 128, 1, 7]      256        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 128, 1, 7]      147456     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 128, 1, 7]      256        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 128, 1, 7]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 128, 1, 7]      147456     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 128, 1, 7]      256        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 256, 1, 4]      294912     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 256, 1, 4]      512        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 256, 1, 4]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 256, 1, 4]      589824     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 256, 1, 4]      512        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 256, 1, 4]      32768      False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 256, 1, 4]      512        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 256, 1, 4]      589824     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 256, 1, 4]      512        True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 256, 1, 4]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 256, 1, 4]      589824     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 256, 1, 4]      512        True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 512, 1, 2]      1179648    False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 512, 1, 2]      1024       True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 512, 1, 2]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 512, 1, 2]      2359296    False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 512, 1, 2]      1024       True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 512, 1, 2]      131072     False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 512, 1, 2]      1024       True      \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 512, 1, 2]      2359296    False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 512, 1, 2]      1024       True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 512, 1, 2]      0          False     \n",
      "______________________________________________________________________\n",
      "Conv2d               [64, 512, 1, 2]      2359296    False     \n",
      "______________________________________________________________________\n",
      "BatchNorm2d          [64, 512, 1, 2]      1024       True      \n",
      "______________________________________________________________________\n",
      "AdaptiveAvgPool2d    [64, 512, 1, 1]      0          False     \n",
      "______________________________________________________________________\n",
      "AdaptiveMaxPool2d    [64, 512, 1, 1]      0          False     \n",
      "______________________________________________________________________\n",
      "Lambda               [64, 1024]           0          False     \n",
      "______________________________________________________________________\n",
      "BatchNorm1d          [64, 1024]           2048       True      \n",
      "______________________________________________________________________\n",
      "Dropout              [64, 1024]           0          False     \n",
      "______________________________________________________________________\n",
      "Linear               [64, 512]            524800     True      \n",
      "______________________________________________________________________\n",
      "ReLU                 [64, 512]            0          False     \n",
      "______________________________________________________________________\n",
      "BatchNorm1d          [64, 512]            1024       True      \n",
      "______________________________________________________________________\n",
      "Dropout              [64, 512]            0          False     \n",
      "______________________________________________________________________\n",
      "Linear               [64, 2]              1026       True      \n",
      "______________________________________________________________________\n",
      "\n",
      "Total params:  11705410\n",
      "Total trainable params:  538498\n",
      "Total non-trainable params:  11166912\n"
     ]
    }
   ],
   "source": [
    "learn.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.lr_find()\n",
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 50:06 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy_thresh</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>0.432172</th>\n",
       "    <th>0.960091</th>\n",
       "    <th>0.520400</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>0.187206</th>\n",
       "    <th>3.096432</th>\n",
       "    <th>0.525900</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>0.144700</th>\n",
       "    <th>2.654047</th>\n",
       "    <th>0.530000</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>0.120802</th>\n",
       "    <th>0.110985</th>\n",
       "    <th>0.948200</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(4, max_lr=1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot_losses()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 49:36 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy_thresh</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>0.144766</th>\n",
       "    <th>1.269982</th>\n",
       "    <th>0.513300</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>0.149426</th>\n",
       "    <th>1.271635</th>\n",
       "    <th>0.512700</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>0.123069</th>\n",
       "    <th>0.155888</th>\n",
       "    <th>0.938300</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>0.107704</th>\n",
       "    <th>0.106997</th>\n",
       "    <th>0.950500</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(4, max_lr=1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot_losses()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.recorder.plot_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('resnet18_epoch-8')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interpretation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "interpret = learn.interpret()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "#interpret.confusion_matrix() # does not work yet!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "#debug_memory()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}