Switch to side-by-side view

--- a
+++ b/src/examples/classify_fasta.py
@@ -0,0 +1,206 @@
+import argparse
+import os
+import numpy as np
+import pandas as pd
+import pkg_resources
+
+from keras import backend as K
+from keras.layers import Conv2D
+from keras.layers import GlobalAveragePooling2D
+from keras.layers import Maximum
+
+from janggu import Janggu
+from janggu import Scorer
+from janggu import inputlayer
+from janggu import outputdense
+from janggu.data import Array
+from janggu.data import Bioseq
+from janggu.layers import Complement
+from janggu.layers import DnaConv2D
+from janggu.layers import Reverse
+from janggu.utils import ExportClustermap
+from janggu.utils import ExportTsv
+
+import matplotlib
+matplotlib.use('Agg')
+
+np.random.seed(1234)
+
+
+# Fetch parser arguments
+PARSER = argparse.ArgumentParser(description='Command description.')
+PARSER.add_argument('model', choices=['single', 'double', 'dnaconv'],
+                    help="Single or double stranded model.")
+PARSER.add_argument('-path', dest='path',
+                    default='tf_results',
+                    help="Output directory for the examples.")
+PARSER.add_argument('-order', dest='order', type=int,
+                    default=1,
+                    help="One-hot order.")
+
+args = PARSER.parse_args()
+
+os.environ['JANGGU_OUTPUT'] = args.path
+
+
+# helper function
+def nseqs(filename):
+    """Extract the number of rows in the file.
+
+    Note however, that this is a simplification
+    that might not always work. In general, one would
+    need to parse for '>' occurrences.
+    """
+    return sum((1 for line in open(filename) if line[0] == '>'))
+
+
+# load the dataset
+DATA_PATH = pkg_resources.resource_filename('janggu', 'resources/')
+SAMPLE_1 = os.path.join(DATA_PATH, 'sample.fa')
+SAMPLE_2 = os.path.join(DATA_PATH, 'sample2.fa')
+
+# DNA sequences in one-hot encoding will be used as input
+DNA = Bioseq.create_from_seq('dna', fastafile=[SAMPLE_1, SAMPLE_2],
+                             order=args.order, cache=True)
+
+# An array of 1/0 will be used as labels for training
+Y = np.asarray([[1] for line in range(nseqs(SAMPLE_1))] +
+               [[0] for line in range(nseqs(SAMPLE_2))])
+LABELS = Array('y', Y, conditions=['TF-binding'])
+annot = pd.DataFrame(Y[:], columns=LABELS.conditions).applymap(
+    lambda x: 'Oct4' if x == 1 else 'Mafk').to_dict(orient='list')
+
+# Define the model templates
+
+@inputlayer
+@outputdense('sigmoid')
+def single_stranded_model(inputs, inp, oup, params):
+    """ keras model that scans a DNA sequence using
+    a number of motifs.
+
+    This model only scans one strand for sequence patterns.
+    """
+    with inputs.use('dna') as layer:
+        # the name in inputs.use() should be the same as the dataset name.
+        layer = Conv2D(params[0], (params[1], 1), activation=params[2])(layer)
+    output = GlobalAveragePooling2D(name='motif')(layer)
+    return inputs, output
+
+
+@inputlayer
+@outputdense('sigmoid')
+def double_stranded_model(inputs, inp, oup, params):
+    """ keras model for scanning both DNA strands.
+
+    Sequence patterns may be present on either strand.
+    By scanning both DNA strands with the same motifs (kernels)
+    the performance of the model will generally improve.
+
+    In the model below, this is achieved by reverse complementing
+    the input tensor and keeping the convolution filters fixed.
+    """
+    with inputs.use('dna') as layer:
+        # the name in inputs.use() should be the same as the dataset name.
+        forward = layer
+    convlayer = Conv2D(params[0], (params[1], 1), activation=params[2])
+    revcomp = Reverse()(forward)
+    revcomp = Complement()(revcomp)
+
+    forward = convlayer(forward)
+    revcomp = convlayer(revcomp)
+    revcomp = Reverse()(revcomp)
+    layer = Maximum()([forward, revcomp])
+    output = GlobalAveragePooling2D(name='motif')(layer)
+    return inputs, output
+
+
+@inputlayer
+@outputdense('sigmoid')
+def double_stranded_model_dnaconv(inputs, inp, oup, params):
+    """ keras model for scanning both DNA strands.
+
+    A more elegant way of scanning both strands for motif occurrences
+    is achieved by the DnaConv2D layer wrapper, which internally
+    performs the convolution operation with the normal kernel weights
+    and the reverse complemented weights.
+    """
+    with inputs.use('dna') as layer:
+        # the name in inputs.use() should be the same as the dataset name.
+        conv = DnaConv2D(Conv2D(params[0],
+                                (params[1], 1),
+                                activation=params[2]), name='conv1')(layer)
+
+    output = GlobalAveragePooling2D(name='motif')(conv)
+    return inputs, output
+
+
+if args.model == 'single':
+    modeltemplate = single_stranded_model
+elif args.model == 'double':
+    modeltemplate = double_stranded_model
+else:
+    modeltemplate = double_stranded_model_dnaconv
+
+K.clear_session()
+
+# create a new model object
+model = Janggu.create(template=modeltemplate,
+                      modelparams=(30, 21, 'relu'),
+                      inputs=DNA,
+                      outputs=LABELS,
+                      name='fasta_seqs_m{}_o{}'.format(args.model, args.order))
+
+model.compile(optimizer='adadelta', loss='binary_crossentropy',
+              metrics=['acc'])
+model.summary()
+
+# fit the model
+hist = model.fit(DNA, LABELS, epochs=100)
+
+print('#' * 40)
+print('loss: {}, acc: {}'.format(hist.history['loss'][-1],
+                                 hist.history['acc'][-1]))
+print('#' * 40)
+
+# load test data
+SAMPLE_1 = os.path.join(DATA_PATH, 'sample_test.fa')
+SAMPLE_2 = os.path.join(DATA_PATH, 'sample2_test.fa')
+
+DNA_TEST = Bioseq.create_from_seq('dna', fastafile=[SAMPLE_1, SAMPLE_2],
+                                  order=args.order, cache=True)
+
+Y = np.asarray([[1] for _ in range(nseqs(SAMPLE_1))] +
+               [[0] for _ in range(nseqs(SAMPLE_2))])
+LABELS_TEST = Array('y', Y, conditions=['TF-binding'])
+annot_test = pd.DataFrame(Y[:], columns=LABELS_TEST.conditions).applymap(
+    lambda x: 'Oct4' if x == 1 else 'Mafk').to_dict(orient='list')
+
+# clustering plots based on hidden features
+heatmap_eval = Scorer('heatmap', exporter=ExportClustermap(annot=annot_test,
+                                                           z_score=1.))
+
+# output the predictions as tables or json files
+pred_tsv = Scorer('pred', exporter=ExportTsv(annot=annot_test,
+                                             row_names=DNA_TEST.gindexer.chrs))
+
+# do the evaluation on the independent test data
+# after the evaluation and prediction has been performed,
+# the callbacks further process the results allowing
+# to automatically generate summary statistics or figures
+# into the JANGGU_OUTPUT directory.
+model.evaluate(DNA_TEST, LABELS_TEST, datatags=['test'],
+               callbacks=['auc', 'auprc', 'roc', 'auroc'])
+
+pred = model.predict(DNA_TEST, datatags=['test'],
+              callbacks=[pred_tsv, heatmap_eval],
+              layername='motif')
+
+pred = model.predict(DNA_TEST)
+print('Oct4 predictions scores should be greater than Mafk scores:')
+print('Prediction score examples for Oct4')
+for i in range(4):
+    print('{}.: {}'.format(i, pred[i]))
+print('Prediction score examples for Mafk')
+for i in range(1, 5):
+    print('{}.: {}'.format(i, pred[-i]))
+