[d7cf27]: / src / examples / keras_convnet_example_w_sequence.ipynb

Download this file

826 lines (825 with data), 37.2 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predict regulatory regions from the DNA sequence using keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is similar to the keras tutorial, but instead of using the Janggu Dataset objects directly, we shall wrap them in the built-in Sequence class (`JangguSequence`). This approach is also possible with the new keras and tensorflow (v2) interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "from keras import Model\n",
    "from keras import backend as K\n",
    "from keras.layers import Conv2D\n",
    "\n",
    "from keras.layers import Dense\n",
    "from keras.layers import GlobalAveragePooling2D\n",
    "from keras.layers import Input\n",
    "from keras.layers import Maximum\n",
    "\n",
    "from pkg_resources import resource_filename\n",
    "\n",
    "from janggu.data import Bioseq\n",
    "from janggu.data import Cover\n",
    "from janggu.data import ReduceDim\n",
    "from janggu.layers import Reverse, Complement\n",
    "from janggu.data import JangguSequence\n",
    "\n",
    "from sklearn.metrics import roc_auc_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "\n",
    "np.random.seed(1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('1.14.0', '2.2.5')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import keras, tensorflow\n",
    "\n",
    "tensorflow.__version__, keras.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we need to specify the output directory in which the results are stored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the results below the current user's home directory instead of a\n",
    "# machine-specific absolute path, so the notebook runs unchanged elsewhere.\n",
    "os.environ['JANGGU_OUTPUT'] = os.path.join(os.path.expanduser('~'), 'janggu_examples')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Specify the DNA sequence feature order. Order 1, 2 and 3 correspond to mono-, di- and tri-nucleotide based features (see Tutorial)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "order = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locate the example files that are bundled with the janggu package.\n",
    "# The pseudo genome represents just a concatenation of all sequences\n",
    "# in sample.fa and sample2.fa. Therefore, the results should be almost\n",
    "# identical to the models obtained from classify_fasta.py.\n",
    "REFGENOME, ROI_TRAIN_FILE, ROI_TEST_FILE, PEAK_FILE = [\n",
    "    resource_filename('janggu', relpath)\n",
    "    for relpath in (\n",
    "        'resources/pseudo_genome.fa',\n",
    "        # the ROI files contain regions spanning positive and negative examples\n",
    "        'resources/roi_train.bed',\n",
    "        'resources/roi_test.bed',\n",
    "        # the peak file only contains positive examples\n",
    "        'resources/scores.bed',\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the datasets for training and testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training inputs and labels are defined purely by genomic coordinates\n",
    "DNA = Bioseq.create_from_refgenome(\n",
    "    'dna',\n",
    "    refgenome=REFGENOME,\n",
    "    roi=ROI_TRAIN_FILE,\n",
    "    binsize=200,\n",
    "    order=order,\n",
    "    cache=True,\n",
    ")\n",
    "\n",
    "# The ReduceDim wrapper transforms the dataset from a 4D object\n",
    "# to a 2D table-like representation\n",
    "LABELS = ReduceDim(\n",
    "    Cover.create_from_bed(\n",
    "        'peaks',\n",
    "        roi=ROI_TRAIN_FILE,\n",
    "        bedfiles=PEAK_FILE,\n",
    "        binsize=200,\n",
    "        resolution=200,\n",
    "        cache=True,\n",
    "        storage='sparse',\n",
    "    )\n",
    ")\n",
    "\n",
    "# The test inputs and labels use the same settings on the test regions;\n",
    "# no `cache` argument is passed for them\n",
    "DNA_TEST = Bioseq.create_from_refgenome(\n",
    "    'dna',\n",
    "    refgenome=REFGENOME,\n",
    "    roi=ROI_TEST_FILE,\n",
    "    binsize=200,\n",
    "    order=order,\n",
    ")\n",
    "\n",
    "LABELS_TEST = ReduceDim(\n",
    "    Cover.create_from_bed(\n",
    "        'peaks',\n",
    "        roi=ROI_TEST_FILE,\n",
    "        bedfiles=PEAK_FILE,\n",
    "        binsize=200,\n",
    "        resolution=200,\n",
    "        storage='sparse',\n",
    "    )\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training set: (7797, 198, 1, 64) (7797, 1)\n",
      "test set: (200, 198, 1, 64) (200, 1)\n"
     ]
    }
   ],
   "source": [
    "print('training set:', DNA.shape, LABELS.shape)\n",
    "print('test set:', DNA_TEST.shape, LABELS_TEST.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Specify a keras model whose input and output dimensions are compatible with the example datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:66: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:541: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4432: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/keras/optimizers.py:793: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:3657: The name tf.log is deprecated. Please use tf.math.log instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "Model: \"model_1\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            (None, 198, 1, 64)   0                                            \n",
      "__________________________________________________________________________________________________\n",
      "reverse_1 (Reverse)             (None, 198, 1, 64)   0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "complement_1 (Complement)       (None, 198, 1, 64)   0           reverse_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_1 (Conv2D)               (None, 178, 1, 30)   40350       input_1[0][0]                    \n",
      "                                                                 complement_1[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "maximum_1 (Maximum)             (None, 178, 1, 30)   0           conv2d_1[0][0]                   \n",
      "                                                                 conv2d_1[1][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "global_average_pooling2d_1 (Glo (None, 30)           0           maximum_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 1)            31          global_average_pooling2d_1[0][0] \n",
      "==================================================================================================\n",
      "Total params: 40,381\n",
      "Trainable params: 40,381\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# The input shape matches the datasets loaded with binsize=200:\n",
    "# (200 - order + 1) positions, 1, and 4**order feature channels\n",
    "xin = Input(shape=(200 - order + 1, 1, 4 ** order))\n",
    "\n",
    "# a single convolution layer is applied to both strands of the DNA\n",
    "convlayer = Conv2D(filters=30, kernel_size=(21, 1), activation='relu')\n",
    "\n",
    "# scan the forward strand ...\n",
    "forward = convlayer(xin)\n",
    "# ... and the reverse complemented sequence with the same layer\n",
    "reverse = convlayer(Complement()(Reverse()(xin)))\n",
    "\n",
    "# merge the results by using the maximum activation wrt the strands,\n",
    "# then average the merged activations with global pooling\n",
    "layer = GlobalAveragePooling2D()(Maximum()([forward, reverse]))\n",
    "output = Dense(units=1, activation='sigmoid')(layer)\n",
    "\n",
    "model = Model(inputs=xin, outputs=output)\n",
    "\n",
    "model.compile(loss='binary_crossentropy', optimizer='adadelta',\n",
    "              metrics=['acc'])\n",
    "model.summary()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainseq = JangguSequence(DNA, LABELS, as_dict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/wkopp/anaconda3/envs/jdev1/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:1033: The name tf.assign_add is deprecated. Please use tf.compat.v1.assign_add instead.\n",
      "\n",
      "Epoch 1/100\n",
      "244/244 [==============================] - 10s 40ms/step - loss: 0.6216 - acc: 0.6472\n",
      "Epoch 2/100\n",
      "244/244 [==============================] - 9s 35ms/step - loss: 0.5521 - acc: 0.7284\n",
      "Epoch 3/100\n",
      "244/244 [==============================] - 9s 35ms/step - loss: 0.5313 - acc: 0.7427\n",
      "Epoch 4/100\n",
      "244/244 [==============================] - 12s 49ms/step - loss: 0.4979 - acc: 0.7717\n",
      "Epoch 5/100\n",
      "244/244 [==============================] - 15s 61ms/step - loss: 0.4840 - acc: 0.7726\n",
      "Epoch 6/100\n",
      "244/244 [==============================] - 15s 61ms/step - loss: 0.4699 - acc: 0.7812\n",
      "Epoch 7/100\n",
      "244/244 [==============================] - 15s 63ms/step - loss: 0.4639 - acc: 0.7855\n",
      "Epoch 8/100\n",
      "244/244 [==============================] - 15s 61ms/step - loss: 0.4352 - acc: 0.8015\n",
      "Epoch 9/100\n",
      "244/244 [==============================] - 15s 60ms/step - loss: 0.4258 - acc: 0.8076\n",
      "Epoch 10/100\n",
      "244/244 [==============================] - 16s 64ms/step - loss: 0.4179 - acc: 0.8139\n",
      "Epoch 11/100\n",
      "244/244 [==============================] - 17s 70ms/step - loss: 0.3992 - acc: 0.8210\n",
      "Epoch 12/100\n",
      "244/244 [==============================] - 18s 73ms/step - loss: 0.3872 - acc: 0.8309\n",
      "Epoch 13/100\n",
      "244/244 [==============================] - 18s 73ms/step - loss: 0.3703 - acc: 0.8411\n",
      "Epoch 14/100\n",
      "244/244 [==============================] - 17s 70ms/step - loss: 0.3627 - acc: 0.8471\n",
      "Epoch 15/100\n",
      "244/244 [==============================] - 16s 65ms/step - loss: 0.3503 - acc: 0.8560\n",
      "Epoch 16/100\n",
      "244/244 [==============================] - 16s 64ms/step - loss: 0.3320 - acc: 0.8633\n",
      "Epoch 17/100\n",
      "244/244 [==============================] - 17s 71ms/step - loss: 0.3180 - acc: 0.8733\n",
      "Epoch 18/100\n",
      "244/244 [==============================] - 18s 75ms/step - loss: 0.3035 - acc: 0.8804\n",
      "Epoch 19/100\n",
      "244/244 [==============================] - 19s 78ms/step - loss: 0.2894 - acc: 0.8854\n",
      "Epoch 20/100\n",
      "244/244 [==============================] - 18s 74ms/step - loss: 0.2742 - acc: 0.8950\n",
      "Epoch 21/100\n",
      "244/244 [==============================] - 11s 45ms/step - loss: 0.2631 - acc: 0.9034\n",
      "Epoch 22/100\n",
      "244/244 [==============================] - 11s 45ms/step - loss: 0.2505 - acc: 0.9073\n",
      "Epoch 23/100\n",
      "244/244 [==============================] - 11s 44ms/step - loss: 0.2384 - acc: 0.9127\n",
      "Epoch 24/100\n",
      "244/244 [==============================] - 13s 54ms/step - loss: 0.2302 - acc: 0.9176\n",
      "Epoch 25/100\n",
      "244/244 [==============================] - 11s 46ms/step - loss: 0.2188 - acc: 0.9224\n",
      "Epoch 26/100\n",
      "244/244 [==============================] - 11s 47ms/step - loss: 0.2073 - acc: 0.9260\n",
      "Epoch 27/100\n",
      "244/244 [==============================] - 11s 45ms/step - loss: 0.2008 - acc: 0.9291\n",
      "Epoch 28/100\n",
      "244/244 [==============================] - 11s 45ms/step - loss: 0.1923 - acc: 0.9327\n",
      "Epoch 29/100\n",
      "244/244 [==============================] - 10s 42ms/step - loss: 0.1830 - acc: 0.9356\n",
      "Epoch 30/100\n",
      "244/244 [==============================] - 10s 41ms/step - loss: 0.1719 - acc: 0.9442\n",
      "Epoch 31/100\n",
      "244/244 [==============================] - 10s 42ms/step - loss: 0.1700 - acc: 0.9423\n",
      "Epoch 32/100\n",
      "244/244 [==============================] - 10s 41ms/step - loss: 0.1630 - acc: 0.9459\n",
      "Epoch 33/100\n",
      "244/244 [==============================] - 9s 38ms/step - loss: 0.1550 - acc: 0.9486\n",
      "Epoch 34/100\n",
      "244/244 [==============================] - 9s 38ms/step - loss: 0.1492 - acc: 0.9506\n",
      "Epoch 35/100\n",
      "244/244 [==============================] - 9s 38ms/step - loss: 0.1481 - acc: 0.9511\n",
      "Epoch 36/100\n",
      "244/244 [==============================] - 9s 38ms/step - loss: 0.1407 - acc: 0.9538\n",
      "Epoch 37/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.1372 - acc: 0.9546\n",
      "Epoch 38/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.1320 - acc: 0.9565\n",
      "Epoch 39/100\n",
      "244/244 [==============================] - 9s 38ms/step - loss: 0.1270 - acc: 0.9577\n",
      "Epoch 40/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.1220 - acc: 0.9598\n",
      "Epoch 41/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.1199 - acc: 0.9606\n",
      "Epoch 42/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.1142 - acc: 0.9633\n",
      "Epoch 43/100\n",
      "244/244 [==============================] - 9s 38ms/step - loss: 0.1120 - acc: 0.9641\n",
      "Epoch 44/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.1073 - acc: 0.9654\n",
      "Epoch 45/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.1063 - acc: 0.9679\n",
      "Epoch 46/100\n",
      "244/244 [==============================] - 9s 38ms/step - loss: 0.1021 - acc: 0.9693\n",
      "Epoch 47/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.1005 - acc: 0.9675\n",
      "Epoch 48/100\n",
      "244/244 [==============================] - 9s 38ms/step - loss: 0.0972 - acc: 0.9683\n",
      "Epoch 49/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0951 - acc: 0.9684\n",
      "Epoch 50/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0923 - acc: 0.9694\n",
      "Epoch 51/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0896 - acc: 0.9740\n",
      "Epoch 52/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0854 - acc: 0.9748\n",
      "Epoch 53/100\n",
      "244/244 [==============================] - 9s 35ms/step - loss: 0.0847 - acc: 0.9748\n",
      "Epoch 54/100\n",
      "244/244 [==============================] - 9s 35ms/step - loss: 0.0807 - acc: 0.9749\n",
      "Epoch 55/100\n",
      "244/244 [==============================] - 9s 35ms/step - loss: 0.0800 - acc: 0.9753\n",
      "Epoch 56/100\n",
      "244/244 [==============================] - 9s 35ms/step - loss: 0.0769 - acc: 0.9766\n",
      "Epoch 57/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0761 - acc: 0.9767\n",
      "Epoch 58/100\n",
      "244/244 [==============================] - 9s 35ms/step - loss: 0.0743 - acc: 0.9784\n",
      "Epoch 59/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0718 - acc: 0.9777\n",
      "Epoch 60/100\n",
      "244/244 [==============================] - 9s 39ms/step - loss: 0.0703 - acc: 0.9789\n",
      "Epoch 61/100\n",
      "244/244 [==============================] - 10s 39ms/step - loss: 0.0673 - acc: 0.9791\n",
      "Epoch 62/100\n",
      "244/244 [==============================] - 9s 38ms/step - loss: 0.0671 - acc: 0.9789\n",
      "Epoch 63/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0661 - acc: 0.9809\n",
      "Epoch 64/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0637 - acc: 0.9813\n",
      "Epoch 65/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0603 - acc: 0.9834\n",
      "Epoch 66/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0590 - acc: 0.9835\n",
      "Epoch 67/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0600 - acc: 0.9812\n",
      "Epoch 68/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0575 - acc: 0.9848\n",
      "Epoch 69/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0567 - acc: 0.9845\n",
      "Epoch 70/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0560 - acc: 0.9848\n",
      "Epoch 71/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0544 - acc: 0.9855\n",
      "Epoch 72/100\n",
      "244/244 [==============================] - 9s 38ms/step - loss: 0.0531 - acc: 0.9860\n",
      "Epoch 73/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0515 - acc: 0.9859\n",
      "Epoch 74/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0502 - acc: 0.9869\n",
      "Epoch 75/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0500 - acc: 0.9878\n",
      "Epoch 76/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0485 - acc: 0.9873\n",
      "Epoch 77/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0475 - acc: 0.9868\n",
      "Epoch 78/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0462 - acc: 0.9892\n",
      "Epoch 79/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0453 - acc: 0.9885\n",
      "Epoch 80/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0446 - acc: 0.9894\n",
      "Epoch 81/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0433 - acc: 0.9883\n",
      "Epoch 82/100\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0430 - acc: 0.9898\n",
      "Epoch 83/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0410 - acc: 0.9899\n",
      "Epoch 84/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0414 - acc: 0.9899\n",
      "Epoch 85/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0396 - acc: 0.9908\n",
      "Epoch 86/100\n",
      "244/244 [==============================] - 9s 35ms/step - loss: 0.0390 - acc: 0.9913\n",
      "Epoch 87/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0386 - acc: 0.9900\n",
      "Epoch 88/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0383 - acc: 0.9907\n",
      "Epoch 89/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0368 - acc: 0.9912\n",
      "Epoch 90/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0368 - acc: 0.9912\n",
      "Epoch 91/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0355 - acc: 0.9918\n",
      "Epoch 92/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0343 - acc: 0.9924\n",
      "Epoch 93/100\n",
      "244/244 [==============================] - 9s 37ms/step - loss: 0.0337 - acc: 0.9926\n",
      "Epoch 94/100\n",
      "244/244 [==============================] - 9s 35ms/step - loss: 0.0332 - acc: 0.9935\n",
      "Epoch 95/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0323 - acc: 0.9926\n",
      "Epoch 96/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0321 - acc: 0.9937\n",
      "Epoch 97/100\n",
      "244/244 [==============================] - 9s 35ms/step - loss: 0.0314 - acc: 0.9932\n",
      "Epoch 98/100\n",
      "244/244 [==============================] - 9s 35ms/step - loss: 0.0307 - acc: 0.9941\n",
      "Epoch 99/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0301 - acc: 0.9933\n",
      "Epoch 100/100\n",
      "244/244 [==============================] - 9s 36ms/step - loss: 0.0291 - acc: 0.9940\n",
      "########################################\n",
      "loss: 0.029113427234549117, acc: 0.9939720405284084\n",
      "########################################\n"
     ]
    }
   ],
   "source": [
    "# fit the model on the batches provided by the training sequence\n",
    "hist = model.fit(trainseq, epochs=100)\n",
    "\n",
    "# report the loss and accuracy recorded for the last epoch\n",
    "print('#' * 40)\n",
    "print('loss: {}, acc: {}'.format(*(hist.history[key][-1]\n",
    "                                   for key in ('loss', 'acc'))))\n",
    "print('#' * 40)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The predictions may be converted back to a Cover object and subsequently exported as a bigwig file in order to inspect the plausibility of the results in the genome browser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "testseq = JangguSequence(DNA_TEST, as_dict=False)\n",
    "\n",
    "# convert the prediction to a cover object, reusing the genomic\n",
    "# coordinates (gindexer) of the test labels\n",
    "pred = model.predict(testseq)\n",
    "cov_pred = Cover.create_from_array('BindingProba', pred, LABELS_TEST.gindexer)\n",
    "\n",
    "# make sure the output directory exists, so that the export does not\n",
    "# rely on an earlier step having created it\n",
    "os.makedirs(os.environ['JANGGU_OUTPUT'], exist_ok=True)\n",
    "\n",
    "# predictions (or feature activities) can finally be exported to bigwig\n",
    "cov_pred.export_to_bigwig(output_dir=os.environ['JANGGU_OUTPUT'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[9.99151051e-01],\n",
       "       [9.75318134e-01],\n",
       "       [9.96574104e-01],\n",
       "       [9.95460033e-01],\n",
       "       [9.96464014e-01],\n",
       "       [8.64547014e-01],\n",
       "       [9.96426702e-01],\n",
       "       [9.99921978e-01],\n",
       "       [9.89611328e-01],\n",
       "       [9.98907030e-01],\n",
       "       [9.35466111e-01],\n",
       "       [9.86055791e-01],\n",
       "       [9.26077247e-01],\n",
       "       [9.66469288e-01],\n",
       "       [9.70292747e-01],\n",
       "       [9.98694539e-01],\n",
       "       [9.99305248e-01],\n",
       "       [9.97891068e-01],\n",
       "       [9.81406808e-01],\n",
       "       [9.99887824e-01],\n",
       "       [9.92626667e-01],\n",
       "       [9.98514771e-01],\n",
       "       [9.99931455e-01],\n",
       "       [9.99961138e-01],\n",
       "       [9.94316578e-01],\n",
       "       [9.97490048e-01],\n",
       "       [8.59833419e-01],\n",
       "       [9.17471051e-01],\n",
       "       [9.73466277e-01],\n",
       "       [9.76478577e-01],\n",
       "       [9.99038696e-01],\n",
       "       [9.97492373e-01],\n",
       "       [8.55449438e-01],\n",
       "       [9.97820377e-01],\n",
       "       [8.67090762e-01],\n",
       "       [7.52588272e-01],\n",
       "       [9.99957919e-01],\n",
       "       [9.38101947e-01],\n",
       "       [9.83132601e-01],\n",
       "       [9.99170959e-01],\n",
       "       [9.98886287e-01],\n",
       "       [9.99874949e-01],\n",
       "       [9.68174934e-01],\n",
       "       [3.78583193e-01],\n",
       "       [9.82189059e-01],\n",
       "       [7.07389772e-01],\n",
       "       [9.17288303e-01],\n",
       "       [9.28278089e-01],\n",
       "       [9.91107821e-01],\n",
       "       [9.99858558e-01],\n",
       "       [9.93714035e-01],\n",
       "       [9.48678851e-01],\n",
       "       [9.99234915e-01],\n",
       "       [9.93917823e-01],\n",
       "       [9.99332309e-01],\n",
       "       [9.99739230e-01],\n",
       "       [3.62716645e-01],\n",
       "       [9.93830442e-01],\n",
       "       [7.78486133e-01],\n",
       "       [9.32922006e-01],\n",
       "       [9.98838902e-01],\n",
       "       [9.88024175e-01],\n",
       "       [9.94903624e-01],\n",
       "       [9.98093247e-01],\n",
       "       [7.89099932e-01],\n",
       "       [9.39794064e-01],\n",
       "       [9.99178767e-01],\n",
       "       [9.48063254e-01],\n",
       "       [9.98469710e-01],\n",
       "       [9.98108506e-01],\n",
       "       [9.95718956e-01],\n",
       "       [9.93711233e-01],\n",
       "       [9.99210954e-01],\n",
       "       [9.83758092e-01],\n",
       "       [5.87057829e-01],\n",
       "       [9.95499671e-01],\n",
       "       [9.85926390e-01],\n",
       "       [9.98869359e-01],\n",
       "       [9.78308201e-01],\n",
       "       [9.96456385e-01],\n",
       "       [9.99986291e-01],\n",
       "       [9.85248029e-01],\n",
       "       [9.85621572e-01],\n",
       "       [9.29056883e-01],\n",
       "       [9.99831975e-01],\n",
       "       [9.65004444e-01],\n",
       "       [9.84891772e-01],\n",
       "       [9.94043827e-01],\n",
       "       [9.54253972e-01],\n",
       "       [9.60171700e-01],\n",
       "       [9.68681812e-01],\n",
       "       [9.75257874e-01],\n",
       "       [9.59723949e-01],\n",
       "       [9.98453915e-01],\n",
       "       [9.97150302e-01],\n",
       "       [9.99207020e-01],\n",
       "       [9.78554606e-01],\n",
       "       [9.90009785e-01],\n",
       "       [9.95299459e-01],\n",
       "       [9.99656081e-01],\n",
       "       [1.35146677e-02],\n",
       "       [8.34721327e-03],\n",
       "       [3.27438116e-04],\n",
       "       [9.67752934e-03],\n",
       "       [2.69535184e-02],\n",
       "       [9.81986523e-05],\n",
       "       [3.49961698e-01],\n",
       "       [4.88758087e-05],\n",
       "       [6.55651093e-06],\n",
       "       [4.32524681e-02],\n",
       "       [8.71717930e-05],\n",
       "       [2.53319740e-06],\n",
       "       [9.21566784e-02],\n",
       "       [2.68414021e-02],\n",
       "       [6.02006912e-06],\n",
       "       [2.11596489e-06],\n",
       "       [5.16249716e-01],\n",
       "       [4.17162776e-02],\n",
       "       [3.27658653e-03],\n",
       "       [9.49031115e-03],\n",
       "       [2.10910738e-02],\n",
       "       [1.76423788e-03],\n",
       "       [1.61764026e-02],\n",
       "       [1.12227499e-02],\n",
       "       [2.44669288e-01],\n",
       "       [1.06602907e-03],\n",
       "       [5.06916642e-03],\n",
       "       [1.52856112e-04],\n",
       "       [9.04305577e-01],\n",
       "       [1.49846077e-04],\n",
       "       [3.67297202e-01],\n",
       "       [2.70605087e-05],\n",
       "       [6.59823418e-04],\n",
       "       [3.25759649e-02],\n",
       "       [6.39557838e-05],\n",
       "       [1.47734582e-02],\n",
       "       [5.54416478e-02],\n",
       "       [9.39120710e-01],\n",
       "       [5.08606136e-02],\n",
       "       [9.50425863e-04],\n",
       "       [4.90278006e-04],\n",
       "       [4.81268764e-03],\n",
       "       [2.31248140e-03],\n",
       "       [1.67489052e-05],\n",
       "       [1.44871920e-01],\n",
       "       [5.42476773e-02],\n",
       "       [8.04282069e-01],\n",
       "       [2.62717903e-02],\n",
       "       [1.73801184e-03],\n",
       "       [7.49647617e-04],\n",
       "       [3.39749455e-03],\n",
       "       [2.62567401e-03],\n",
       "       [7.17562437e-03],\n",
       "       [6.17504120e-05],\n",
       "       [2.29504704e-03],\n",
       "       [5.80549240e-04],\n",
       "       [1.08796954e-02],\n",
       "       [1.48217678e-02],\n",
       "       [3.58194113e-04],\n",
       "       [8.82148743e-05],\n",
       "       [4.51746583e-03],\n",
       "       [2.05529898e-01],\n",
       "       [3.33523750e-03],\n",
       "       [1.55711174e-03],\n",
       "       [9.24408436e-04],\n",
       "       [1.54112071e-01],\n",
       "       [3.70740891e-04],\n",
       "       [5.39422035e-05],\n",
       "       [1.02728605e-04],\n",
       "       [6.61611557e-06],\n",
       "       [7.57816076e-01],\n",
       "       [1.35952234e-03],\n",
       "       [2.37476826e-03],\n",
       "       [3.35013866e-03],\n",
       "       [8.41557980e-04],\n",
       "       [2.10243464e-03],\n",
       "       [7.53968954e-04],\n",
       "       [4.60126996e-03],\n",
       "       [1.18502975e-03],\n",
       "       [1.31132960e-01],\n",
       "       [7.97893405e-02],\n",
       "       [1.54096812e-01],\n",
       "       [4.36970592e-03],\n",
       "       [6.08861446e-05],\n",
       "       [4.50151533e-01],\n",
       "       [5.43231845e-01],\n",
       "       [3.05288196e-01],\n",
       "       [8.88526440e-04],\n",
       "       [1.91780925e-03],\n",
       "       [2.67830491e-03],\n",
       "       [2.91100144e-03],\n",
       "       [7.61473179e-03],\n",
       "       [2.49683857e-04],\n",
       "       [3.79794955e-01],\n",
       "       [1.20103359e-03],\n",
       "       [5.83350658e-04],\n",
       "       [6.33895397e-05],\n",
       "       [4.75925207e-03],\n",
       "       [1.54508948e-02],\n",
       "       [4.18242007e-01]], dtype=float32)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUC: 0.9947\n"
     ]
    }
   ],
   "source": [
    "print(\"AUC:\", roc_auc_score(LABELS_TEST[:], pred))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}