--- a +++ b/src/examples/keras_convnet_example.ipynb @@ -0,0 +1,533 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Predict regulatory regions from the DNA sequence using keras" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook we illustrate several variants how to predict regulatory regions (of a toy example) from the DNA sequence.\n", + "The reference genome is made up of a concatenation of Oct4 and Mafk binding sites and we shall use all regions on chromosome 'pseudo1' as training\n", + "and 'pseudo2' as test chromosomes." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/wkopp/anaconda3/envs/jdev/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", + " from ._conv import register_converters as _register_converters\n", + "Using TensorFlow backend.\n" + ] + } + ], + "source": [ + "import os\n", + "\n", + "import numpy as np\n", + "from keras import Model\n", + "from keras import backend as K\n", + "from keras.layers import Conv2D\n", + "\n", + "from keras.layers import Dense\n", + "from keras.layers import GlobalAveragePooling2D\n", + "from keras.layers import Input\n", + "\n", + "from pkg_resources import resource_filename\n", + "\n", + "from janggu.data import Bioseq\n", + "from janggu.data import Cover\n", + "from janggu.data import ReduceDim\n", + "from janggu.layers import DnaConv2D\n", + "\n", + "from sklearn.metrics import roc_auc_score" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Image\n", + "\n", + "np.random.seed(1234)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we need to specify the output directory in which the results are stored." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ['JANGGU_OUTPUT'] = '/home/wkopp/janggu_examples'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify the DNA sequence feature order. Order 1, 2 and 3 correspond to mono-, di- and tri-nucleotide based features (see Tutorial)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "order = 3" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# load the dataset\n", + "# The pseudo genome represents just a concatenation of all sequences\n", + "# in sample.fa and sample2.fa. Therefore, the results should be almost\n", + "# identically to the models obtained from classify_fasta.py.\n", + "REFGENOME = resource_filename('janggu', 'resources/pseudo_genome.fa')\n", + "# ROI contains regions spanning positive and negative examples\n", + "ROI_TRAIN_FILE = resource_filename('janggu', 'resources/roi_train.bed')\n", + "ROI_TEST_FILE = resource_filename('janggu', 'resources/roi_test.bed')\n", + "# PEAK_FILE only contains positive examples\n", + "PEAK_FILE = resource_filename('janggu', 'resources/scores.bed')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the datasets for training and testing" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading from lazy loader\n", + "reload /home/wkopp/janggu_examples/datasets/dna/09dc9f2602f1bea4bf22fd52b8adfa21c7500a31f70c0ef7a506f79a4f92b43a.npz\n", + "loading from bed lazy loader\n", + "reload /home/wkopp/janggu_examples/datasets/peaks/fd9826cf7fb9cc044a6c1354a14688c1be0f0bd9c593fdba2e9af3284a2be099.npz\n", + "loading from lazy loader\n", + "loading from bed lazy loader\n" + ] + } + ], + "source": [ + "# Training input and labels are purely defined genomic coordinates\n", + "DNA = Bioseq.create_from_refgenome('dna', refgenome=REFGENOME,\n", + " roi=ROI_TRAIN_FILE,\n", + " binsize=200,\n", + " order=order,\n", + " cache=True)\n", + "\n", + "# The ReduceDim wrapper transforms the dataset from a 4D object to a 2D table-like representation\n", + "LABELS = ReduceDim(Cover.create_from_bed('peaks', roi=ROI_TRAIN_FILE,\n", + " bedfiles=PEAK_FILE,\n", + " binsize=200,\n", + " resolution=200,\n", + " cache=True,\n", + " storage='sparse'))\n", + "\n", + "\n", + "DNA_TEST = Bioseq.create_from_refgenome('dna', refgenome=REFGENOME,\n", + " roi=ROI_TEST_FILE,\n", + " binsize=200,\n", + " order=order)\n", + "\n", + "LABELS_TEST = ReduceDim(Cover.create_from_bed('peaks',\n", + " bedfiles=PEAK_FILE,\n", + " roi=ROI_TEST_FILE,\n", + " binsize=200,\n", + " resolution=200,\n", + " storage='sparse'))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "training set: (7797, 198, 1, 64) (7797, 1)\n", + "test set: (200, 198, 1, 64) (200, 1)\n" + ] + } + ], + "source": [ + "print('training set:', DNA.shape, LABELS.shape)\n", + "print('test set:', DNA_TEST.shape, LABELS_TEST.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Specify a keras model with compatible dimesions input and output dimensions for the example." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_1 (InputLayer) (None, 198, 1, 64) 0 \n", + "_________________________________________________________________\n", + "dna_conv2d_1 (DnaConv2D) (None, 178, 1, 30) 40350 \n", + "_________________________________________________________________\n", + "global_average_pooling2d_1 ( (None, 30) 0 \n", + "_________________________________________________________________\n", + "dense_1 (Dense) (None, 1) 31 \n", + "=================================================================\n", + "Total params: 40,381\n", + "Trainable params: 40,381\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "xin = Input((200 - order + 1, 1, pow(4, order)))\n", + "layer = DnaConv2D(Conv2D(30, (21, 1),\n", + " activation='relu'))(xin)\n", + "layer = GlobalAveragePooling2D()(layer)\n", + "output = Dense(1, activation='sigmoid')(layer)\n", + "\n", + "\n", + "\n", + "model = Model(xin, output)\n", + "\n", + "model.compile(optimizer='adadelta', loss='binary_crossentropy',\n", + " metrics=['acc'])\n", + "model.summary()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/100\n", + "7797/7797 [==============================] - 5s 594us/step - loss: 0.5611 - acc: 0.7354\n", + "Epoch 2/100\n", + "7797/7797 [==============================] - 4s 453us/step - loss: 0.4607 - acc: 0.7980\n", + "Epoch 3/100\n", + "7797/7797 [==============================] - 3s 429us/step - loss: 0.4115 - acc: 0.8222\n", + "Epoch 4/100\n", + "7797/7797 [==============================] - 3s 409us/step - loss: 0.3785 - acc: 0.8424\n", + "Epoch 5/100\n", + "7797/7797 [==============================] - 3s 424us/step - loss: 0.3513 - acc: 0.8551\n", + "Epoch 6/100\n", + "7797/7797 [==============================] - 3s 434us/step - loss: 0.3256 - acc: 0.8657\n", + "Epoch 7/100\n", + "7797/7797 [==============================] - 3s 400us/step - loss: 0.3037 - acc: 0.8821\n", + "Epoch 8/100\n", + "7797/7797 [==============================] - 3s 396us/step - loss: 0.2801 - acc: 0.8916\n", + "Epoch 9/100\n", + "7797/7797 [==============================] - 3s 406us/step - loss: 0.2602 - acc: 0.9006\n", + "Epoch 10/100\n", + "7797/7797 [==============================] - 3s 408us/step - loss: 0.2412 - acc: 0.9127\n", + "Epoch 11/100\n", + "7797/7797 [==============================] - 3s 394us/step - loss: 0.2240 - acc: 0.9209\n", + "Epoch 12/100\n", + "7797/7797 [==============================] - 3s 441us/step - loss: 0.2083 - acc: 0.9278\n", + "Epoch 13/100\n", + "7797/7797 [==============================] - 3s 443us/step - loss: 0.1951 - acc: 0.9318\n", + "Epoch 14/100\n", + "7797/7797 [==============================] - 3s 404us/step - loss: 0.1822 - acc: 0.9396\n", + "Epoch 15/100\n", + "7797/7797 [==============================] - 3s 401us/step - loss: 0.1712 - acc: 0.9425\n", + "Epoch 16/100\n", + "7797/7797 [==============================] - 3s 415us/step - loss: 0.1621 - acc: 0.9450\n", + "Epoch 17/100\n", + "7797/7797 [==============================] - 3s 403us/step - loss: 0.1535 - acc: 0.9484\n", + "Epoch 18/100\n", + "7797/7797 [==============================] - 3s 408us/step - loss: 0.1448 - acc: 0.9511\n", + "Epoch 19/100\n", + "7797/7797 [==============================] - 3s 420us/step - loss: 0.1387 - acc: 0.9534\n", + "Epoch 20/100\n", + "7797/7797 [==============================] - 3s 433us/step - loss: 0.1312 - acc: 0.9595\n", + "Epoch 21/100\n", + "7797/7797 [==============================] - 3s 410us/step - loss: 0.1249 - acc: 0.9593\n", + "Epoch 22/100\n", + "7797/7797 [==============================] - 3s 409us/step - loss: 0.1196 - acc: 0.9614\n", + "Epoch 23/100\n", + "7797/7797 [==============================] - 3s 415us/step - loss: 0.1136 - acc: 0.9647\n", + "Epoch 24/100\n", + "7797/7797 [==============================] - 3s 431us/step - loss: 0.1091 - acc: 0.9663\n", + "Epoch 25/100\n", + "7797/7797 [==============================] - 3s 409us/step - loss: 0.1039 - acc: 0.9668\n", + "Epoch 26/100\n", + "7797/7797 [==============================] - 3s 405us/step - loss: 0.0991 - acc: 0.9708\n", + "Epoch 27/100\n", + "7797/7797 [==============================] - 3s 398us/step - loss: 0.0965 - acc: 0.9701\n", + "Epoch 28/100\n", + "7797/7797 [==============================] - 3s 417us/step - loss: 0.0921 - acc: 0.9727\n", + "Epoch 29/100\n", + "7797/7797 [==============================] - 3s 406us/step - loss: 0.0887 - acc: 0.9751\n", + "Epoch 30/100\n", + "7797/7797 [==============================] - 3s 406us/step - loss: 0.0860 - acc: 0.9760\n", + "Epoch 31/100\n", + "7797/7797 [==============================] - 3s 414us/step - loss: 0.0826 - acc: 0.9764\n", + "Epoch 32/100\n", + "7797/7797 [==============================] - 3s 398us/step - loss: 0.0795 - acc: 0.9777\n", + "Epoch 33/100\n", + "7797/7797 [==============================] - 3s 412us/step - loss: 0.0768 - acc: 0.9808\n", + "Epoch 34/100\n", + "7797/7797 [==============================] - 3s 394us/step - loss: 0.0743 - acc: 0.9797\n", + "Epoch 35/100\n", + "7797/7797 [==============================] - 3s 414us/step - loss: 0.0714 - acc: 0.9810\n", + "Epoch 36/100\n", + "7797/7797 [==============================] - 3s 415us/step - loss: 0.0695 - acc: 0.9822\n", + "Epoch 37/100\n", + "7797/7797 [==============================] - 3s 393us/step - loss: 0.0669 - acc: 0.9833\n", + "Epoch 38/100\n", + "7797/7797 [==============================] - 3s 408us/step - loss: 0.0642 - acc: 0.9844\n", + "Epoch 39/100\n", + "7797/7797 [==============================] - 3s 420us/step - loss: 0.0620 - acc: 0.9849\n", + "Epoch 40/100\n", + "7797/7797 [==============================] - 3s 413us/step - loss: 0.0605 - acc: 0.9858\n", + "Epoch 41/100\n", + "7797/7797 [==============================] - 3s 400us/step - loss: 0.0584 - acc: 0.9860\n", + "Epoch 42/100\n", + "7797/7797 [==============================] - 3s 413us/step - loss: 0.0565 - acc: 0.9863\n", + "Epoch 43/100\n", + "7797/7797 [==============================] - 3s 403us/step - loss: 0.0545 - acc: 0.9878\n", + "Epoch 44/100\n", + "7797/7797 [==============================] - 3s 415us/step - loss: 0.0537 - acc: 0.9885\n", + "Epoch 45/100\n", + "7797/7797 [==============================] - 3s 406us/step - loss: 0.0517 - acc: 0.9895\n", + "Epoch 46/100\n", + "7797/7797 [==============================] - 3s 409us/step - loss: 0.0503 - acc: 0.9887\n", + "Epoch 47/100\n", + "7797/7797 [==============================] - 3s 401us/step - loss: 0.0489 - acc: 0.9890\n", + "Epoch 48/100\n", + "7797/7797 [==============================] - 3s 402us/step - loss: 0.0472 - acc: 0.9901\n", + "Epoch 49/100\n", + "7797/7797 [==============================] - 3s 410us/step - loss: 0.0459 - acc: 0.9905\n", + "Epoch 50/100\n", + "7797/7797 [==============================] - 3s 404us/step - loss: 0.0450 - acc: 0.9895\n", + "Epoch 51/100\n", + "7797/7797 [==============================] - 3s 415us/step - loss: 0.0432 - acc: 0.9897\n", + "Epoch 52/100\n", + "7797/7797 [==============================] - 3s 399us/step - loss: 0.0420 - acc: 0.9915\n", + "Epoch 53/100\n", + "7797/7797 [==============================] - 3s 416us/step - loss: 0.0402 - acc: 0.9928\n", + "Epoch 54/100\n", + "7797/7797 [==============================] - 3s 408us/step - loss: 0.0394 - acc: 0.9919\n", + "Epoch 55/100\n", + "7797/7797 [==============================] - 3s 428us/step - loss: 0.0383 - acc: 0.9920\n", + "Epoch 56/100\n", + "7797/7797 [==============================] - 3s 404us/step - loss: 0.0368 - acc: 0.9932\n", + "Epoch 57/100\n", + "7797/7797 [==============================] - 3s 407us/step - loss: 0.0361 - acc: 0.9932\n", + "Epoch 58/100\n", + "7797/7797 [==============================] - 3s 403us/step - loss: 0.0353 - acc: 0.9931\n", + "Epoch 59/100\n", + "7797/7797 [==============================] - 3s 412us/step - loss: 0.0343 - acc: 0.9937\n", + "Epoch 60/100\n", + "7797/7797 [==============================] - 3s 419us/step - loss: 0.0332 - acc: 0.9938\n", + "Epoch 61/100\n", + "7797/7797 [==============================] - 3s 389us/step - loss: 0.0322 - acc: 0.9942\n", + "Epoch 62/100\n", + "7797/7797 [==============================] - 3s 421us/step - loss: 0.0315 - acc: 0.9946\n", + "Epoch 63/100\n", + "7797/7797 [==============================] - 3s 415us/step - loss: 0.0306 - acc: 0.9944\n", + "Epoch 64/100\n", + "7797/7797 [==============================] - 3s 410us/step - loss: 0.0297 - acc: 0.9956\n", + "Epoch 65/100\n", + "7797/7797 [==============================] - 3s 423us/step - loss: 0.0291 - acc: 0.9959\n", + "Epoch 66/100\n", + "7797/7797 [==============================] - 3s 440us/step - loss: 0.0285 - acc: 0.9955\n", + "Epoch 67/100\n", + "7797/7797 [==============================] - 3s 423us/step - loss: 0.0277 - acc: 0.9947\n", + "Epoch 68/100\n", + "7797/7797 [==============================] - 3s 439us/step - loss: 0.0265 - acc: 0.9955\n", + "Epoch 69/100\n", + "7797/7797 [==============================] - 3s 417us/step - loss: 0.0260 - acc: 0.9963\n", + "Epoch 70/100\n", + "7797/7797 [==============================] - 3s 435us/step - loss: 0.0253 - acc: 0.9968\n", + "Epoch 71/100\n", + "7797/7797 [==============================] - 4s 462us/step - loss: 0.0246 - acc: 0.9967\n", + "Epoch 72/100\n", + "7797/7797 [==============================] - 4s 454us/step - loss: 0.0239 - acc: 0.9964\n", + "Epoch 73/100\n", + "7797/7797 [==============================] - 4s 481us/step - loss: 0.0231 - acc: 0.9977\n", + "Epoch 74/100\n", + "7797/7797 [==============================] - 3s 427us/step - loss: 0.0229 - acc: 0.9967\n", + "Epoch 75/100\n", + "7797/7797 [==============================] - 3s 441us/step - loss: 0.0219 - acc: 0.9977\n", + "Epoch 76/100\n", + "7797/7797 [==============================] - 3s 414us/step - loss: 0.0216 - acc: 0.9974\n", + "Epoch 77/100\n", + "7797/7797 [==============================] - 3s 409us/step - loss: 0.0209 - acc: 0.9976\n", + "Epoch 78/100\n", + "7797/7797 [==============================] - 3s 422us/step - loss: 0.0202 - acc: 0.9982\n", + "Epoch 79/100\n", + "7797/7797 [==============================] - 3s 406us/step - loss: 0.0197 - acc: 0.9979\n", + "Epoch 80/100\n", + "7797/7797 [==============================] - 3s 419us/step - loss: 0.0192 - acc: 0.9986\n", + "Epoch 81/100\n", + "7797/7797 [==============================] - 3s 399us/step - loss: 0.0187 - acc: 0.9981\n", + "Epoch 82/100\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7797/7797 [==============================] - 3s 408us/step - loss: 0.0184 - acc: 0.9985\n", + "Epoch 83/100\n", + "7797/7797 [==============================] - 3s 387us/step - loss: 0.0177 - acc: 0.9986\n", + "Epoch 84/100\n", + "7797/7797 [==============================] - 3s 412us/step - loss: 0.0172 - acc: 0.9987\n", + "Epoch 85/100\n", + "7797/7797 [==============================] - 3s 412us/step - loss: 0.0168 - acc: 0.9987\n", + "Epoch 86/100\n", + "7797/7797 [==============================] - 3s 406us/step - loss: 0.0165 - acc: 0.9988\n", + "Epoch 87/100\n", + "7797/7797 [==============================] - 3s 413us/step - loss: 0.0160 - acc: 0.9992\n", + "Epoch 88/100\n", + "7797/7797 [==============================] - 3s 412us/step - loss: 0.0156 - acc: 0.9991\n", + "Epoch 89/100\n", + "7797/7797 [==============================] - 3s 425us/step - loss: 0.0153 - acc: 0.9990\n", + "Epoch 90/100\n", + "7797/7797 [==============================] - 4s 449us/step - loss: 0.0149 - acc: 0.9988\n", + "Epoch 91/100\n", + "7797/7797 [==============================] - 3s 407us/step - loss: 0.0145 - acc: 0.9992\n", + "Epoch 92/100\n", + "7797/7797 [==============================] - 3s 447us/step - loss: 0.0144 - acc: 0.9991\n", + "Epoch 93/100\n", + "7797/7797 [==============================] - 4s 469us/step - loss: 0.0138 - acc: 0.9991\n", + "Epoch 94/100\n", + "7797/7797 [==============================] - 3s 422us/step - loss: 0.0133 - acc: 0.9995\n", + "Epoch 95/100\n", + "7797/7797 [==============================] - 3s 403us/step - loss: 0.0130 - acc: 0.9992\n", + "Epoch 96/100\n", + "7797/7797 [==============================] - 3s 436us/step - loss: 0.0125 - acc: 0.9995\n", + "Epoch 97/100\n", + "7797/7797 [==============================] - 3s 433us/step - loss: 0.0125 - acc: 0.9988\n", + "Epoch 98/100\n", + "7797/7797 [==============================] - 3s 407us/step - loss: 0.0123 - acc: 0.9994\n", + "Epoch 99/100\n", + "7797/7797 [==============================] - 3s 420us/step - loss: 0.0118 - acc: 0.9992\n", + "Epoch 100/100\n", + "7797/7797 [==============================] - 3s 406us/step - loss: 0.0117 - acc: 0.9995\n", + "########################################\n", + "loss: 0.011705336400447587, acc: 0.9994869821726305\n", + "########################################\n" + ] + } + ], + "source": [ + "hist = model.fit(DNA, LABELS, epochs=100)\n", + "\n", + "print('#' * 40)\n", + "print('loss: {}, acc: {}'.format(hist.history['loss'][-1],\n", + " hist.history['acc'][-1]))\n", + "print('#' * 40)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The predictions may be converted back to Cover object and subsequently exported as bigwig in order to inspect the plausibility of the results in the genome browser." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# convert the prediction to a cover object\n", + "pred = model.predict(DNA_TEST)\n", + "cov_pred = Cover.create_from_array('BindingProba', pred, LABELS_TEST.gindexer)\n", + "\n", + "# predictions (or feature activities) can finally be exported to bigwig\n", + "cov_pred.export_to_bigwig(output_dir=os.environ['JANGGU_OUTPUT'])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "AUC: 0.9955999999999999\n" + ] + } + ], + "source": [ + "print(\"AUC:\", roc_auc_score(LABELS_TEST[:], pred))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}