a b/src/examples/classify_fasta.py
1
import argparse
2
import os
3
import numpy as np
4
import pandas as pd
5
import pkg_resources
6
7
from keras import backend as K
8
from keras.layers import Conv2D
9
from keras.layers import GlobalAveragePooling2D
10
from keras.layers import Maximum
11
12
from janggu import Janggu
13
from janggu import Scorer
14
from janggu import inputlayer
15
from janggu import outputdense
16
from janggu.data import Array
17
from janggu.data import Bioseq
18
from janggu.layers import Complement
19
from janggu.layers import DnaConv2D
20
from janggu.layers import Reverse
21
from janggu.utils import ExportClustermap
22
from janggu.utils import ExportTsv
23
24
import matplotlib
25
# Use the non-interactive Agg backend so plots can be written to files
# without a display.
# NOTE(review): this runs after the keras/janggu imports above; if one of
# them already imported matplotlib.pyplot, confirm the backend switch applies.
matplotlib.use('Agg')

# Seed NumPy's global random number generator for reproducible runs.
np.random.seed(1234)


# Parse the command-line arguments.
PARSER = argparse.ArgumentParser(
    description='Train and evaluate a motif-scanning classifier that '
                'separates the sequences of two FASTA files.')
PARSER.add_argument('model', choices=['single', 'double', 'dnaconv'],
                    help="Model type: 'single' scans one strand, 'double' "
                         "scans both strands by reverse complementing the "
                         "input, 'dnaconv' scans both strands via DnaConv2D.")
PARSER.add_argument('-path', dest='path',
                    default='tf_results',
                    help="Output directory for the examples.")
PARSER.add_argument('-order', dest='order', type=int,
                    default=1,
                    help="One-hot order.")

args = PARSER.parse_args()

# Results produced by janggu (evaluation summaries, figures, tables) are
# written to the directory named by JANGGU_OUTPUT.
os.environ['JANGGU_OUTPUT'] = args.path
44
45
46
# helper function
def nseqs(filename):
    """Count the sequences in a FASTA file.

    Every line that starts with '>' is counted as one sequence header.
    A '>' preceded by whitespace is not counted.

    Parameters
    ----------
    filename : str
        Path to the FASTA file.

    Returns
    -------
    int
        Number of header lines found (0 for an empty file).
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left open until garbage collection.
    with open(filename) as fasta:
        return sum(1 for line in fasta if line.startswith('>'))
55
56
57
# Training data bundled with the janggu package.
DATA_PATH = pkg_resources.resource_filename('janggu', 'resources/')
SAMPLE_1 = os.path.join(DATA_PATH, 'sample.fa')
SAMPLE_2 = os.path.join(DATA_PATH, 'sample2.fa')

# Model input: one-hot encoded DNA read from both FASTA files, SAMPLE_1 first.
DNA = Bioseq.create_from_seq('dna', fastafile=[SAMPLE_1, SAMPLE_2],
                             order=args.order, cache=True)

# Targets: one [1] row per SAMPLE_1 sequence followed by one [0] row per
# SAMPLE_2 sequence (assumes Bioseq keeps the sequences in file order).
Y = np.asarray([[1]] * nseqs(SAMPLE_1) + [[0]] * nseqs(SAMPLE_2))
LABELS = Array('y', Y, conditions=['TF-binding'])
# Per-condition annotation lists: label 1 -> 'Oct4', anything else -> 'Mafk'.
annot = (pd.DataFrame(Y, columns=LABELS.conditions)
         .applymap(lambda value: 'Oct4' if value == 1 else 'Mafk')
         .to_dict(orient='list'))
72
73
# Define the model templates
74
75
@inputlayer
@outputdense('sigmoid')
def single_stranded_model(inputs, inp, oup, params):
    """Model template that scans only the forward DNA strand for motifs.

    Parameters
    ----------
    inputs :
        Input container handed in via the ``inputlayer`` decorator.
    inp, oup :
        Input/output descriptors supplied by Janggu (not used here).
    params : sequence
        ``params[0]`` is the number of convolution filters, ``params[1]``
        the first kernel dimension and ``params[2]`` the activation.

    Returns
    -------
    tuple
        ``(inputs, output)``, where ``output`` is the global average pooling
        of the convolution, named ``'motif'``.
    """
    nfilters, kernel_len, activation = params[0], params[1], params[2]
    with inputs.use('dna') as dna_layer:
        # 'dna' has to match the name given to the Bioseq dataset.
        scanned = Conv2D(nfilters, (kernel_len, 1),
                         activation=activation)(dna_layer)
    motif = GlobalAveragePooling2D(name='motif')(scanned)
    return inputs, motif
88
89
90
@inputlayer
@outputdense('sigmoid')
def double_stranded_model(inputs, inp, oup, params):
    """Model template that scans both DNA strands with shared motifs.

    A pattern may occur on either strand. Here the input is reverse
    complemented and the very same convolution layer (hence the same
    kernels) is applied to the original and the reverse-complemented
    input. The two score maps are combined element-wise with ``Maximum``.

    Parameters
    ----------
    inputs :
        Input container handed in via the ``inputlayer`` decorator.
    inp, oup :
        Input/output descriptors supplied by Janggu (not used here).
    params : sequence
        ``params[0]`` is the number of convolution filters, ``params[1]``
        the first kernel dimension and ``params[2]`` the activation.

    Returns
    -------
    tuple
        ``(inputs, output)``, where ``output`` is the global average pooling
        of the combined scores, named ``'motif'``.
    """
    nfilters, kernel_len, activation = params[0], params[1], params[2]
    with inputs.use('dna') as dna_layer:
        # 'dna' has to match the name given to the Bioseq dataset.
        fwd_strand = dna_layer

    # A single layer instance: calling it twice shares its weights.
    shared_conv = Conv2D(nfilters, (kernel_len, 1), activation=activation)

    rc_strand = Reverse()(fwd_strand)
    rc_strand = Complement()(rc_strand)

    fwd_scores = shared_conv(fwd_strand)
    rc_scores = shared_conv(rc_strand)
    # Presumably restores forward orientation so both maps line up before
    # taking the element-wise maximum.
    rc_scores = Reverse()(rc_scores)

    best = Maximum()([fwd_scores, rc_scores])
    motif = GlobalAveragePooling2D(name='motif')(best)
    return inputs, motif
115
116
117
@inputlayer
@outputdense('sigmoid')
def double_stranded_model_dnaconv(inputs, inp, oup, params):
    """Model template that scans both DNA strands via ``DnaConv2D``.

    ``DnaConv2D`` wraps an ordinary ``Conv2D`` layer and internally
    convolves with both the original kernels and their reverse
    complements, so no explicit reverse-complement branch is needed.

    Parameters
    ----------
    inputs :
        Input container handed in via the ``inputlayer`` decorator.
    inp, oup :
        Input/output descriptors supplied by Janggu (not used here).
    params : sequence
        ``params[0]`` is the number of convolution filters, ``params[1]``
        the first kernel dimension and ``params[2]`` the activation.

    Returns
    -------
    tuple
        ``(inputs, output)``, where ``output`` is the global average pooling
        of the convolution, named ``'motif'``.
    """
    nfilters, kernel_len, activation = params[0], params[1], params[2]
    with inputs.use('dna') as dna_layer:
        # 'dna' has to match the name given to the Bioseq dataset.
        base_conv = Conv2D(nfilters, (kernel_len, 1), activation=activation)
        scanned = DnaConv2D(base_conv, name='conv1')(dna_layer)

    motif = GlobalAveragePooling2D(name='motif')(scanned)
    return inputs, motif
135
136
137
# Map the command-line choice to its model template; anything unrecognised
# falls back to the DnaConv2D-based template.
_TEMPLATES = {
    'single': single_stranded_model,
    'double': double_stranded_model,
    'dnaconv': double_stranded_model_dnaconv,
}
modeltemplate = _TEMPLATES.get(args.model, double_stranded_model_dnaconv)
143
144
# Reset the Keras backend state before building a new model.
K.clear_session()

# Build the model from the chosen template. modelparams is forwarded to the
# template as `params`: (number of filters, kernel length, activation).
model_name = 'fasta_seqs_m{}_o{}'.format(args.model, args.order)
model = Janggu.create(template=modeltemplate,
                      modelparams=(30, 21, 'relu'),
                      inputs=DNA,
                      outputs=LABELS,
                      name=model_name)

model.compile(loss='binary_crossentropy',
              optimizer='adadelta',
              metrics=['acc'])
model.summary()
156
157
# Train on the full training set.
hist = model.fit(DNA, LABELS, epochs=100)

# Report loss and accuracy from the last training epoch.
SEPARATOR = '#' * 40
print(SEPARATOR)
last_loss = hist.history['loss'][-1]
last_acc = hist.history['acc'][-1]
print('loss: {}, acc: {}'.format(last_loss, last_acc))
print(SEPARATOR)
164
165
# Independent test data, laid out like the training data (SAMPLE_1 first).
SAMPLE_1 = os.path.join(DATA_PATH, 'sample_test.fa')
SAMPLE_2 = os.path.join(DATA_PATH, 'sample2_test.fa')

DNA_TEST = Bioseq.create_from_seq('dna', fastafile=[SAMPLE_1, SAMPLE_2],
                                  order=args.order, cache=True)

# Targets: [1] rows for SAMPLE_1 sequences, then [0] rows for SAMPLE_2.
Y = np.asarray([[1]] * nseqs(SAMPLE_1) + [[0]] * nseqs(SAMPLE_2))
LABELS_TEST = Array('y', Y, conditions=['TF-binding'])
# Per-condition annotation lists: label 1 -> 'Oct4', anything else -> 'Mafk'.
annot_test = (pd.DataFrame(Y, columns=LABELS_TEST.conditions)
              .applymap(lambda value: 'Oct4' if value == 1 else 'Mafk')
              .to_dict(orient='list'))
177
178
# Scorer that exports a clustermap of hidden-layer features, annotated with
# the test labels.
heatmap_eval = Scorer(
    'heatmap',
    exporter=ExportClustermap(annot=annot_test, z_score=1.0),
)

# Scorer that exports predictions as a table, with the test annotation and
# row names taken from the test dataset's genomic index.
pred_tsv = Scorer(
    'pred',
    exporter=ExportTsv(annot=annot_test,
                       row_names=DNA_TEST.gindexer.chrs),
)
185
186
# Evaluate on the independent test data. After evaluation and prediction,
# the callbacks post-process the results, e.g. generating summary
# statistics or figures in the JANGGU_OUTPUT directory.
model.evaluate(DNA_TEST, LABELS_TEST, datatags=['test'],
               callbacks=['auc', 'auprc', 'roc', 'auroc'])

# Compute the activations of the 'motif' layer and pass them to the
# pred_tsv and heatmap_eval scorers. Only the callbacks' outputs are needed,
# so the return value is deliberately not bound to a name.
model.predict(DNA_TEST, datatags=['test'],
              callbacks=[pred_tsv, heatmap_eval],
              layername='motif')
197
198
# Output-layer scores for every test sequence.
pred = model.predict(DNA_TEST)
print('Oct4 predictions scores should be greater than Mafk scores:')

# The first rows correspond to SAMPLE_1 (label 1, annotated 'Oct4').
print('Prediction score examples for Oct4')
for position in range(4):
    score = pred[position]
    print('{}.: {}'.format(position, score))

# The last rows correspond to SAMPLE_2 (label 0, annotated 'Mafk').
print('Prediction score examples for Mafk')
for offset in range(1, 5):
    score = pred[-offset]
    print('{}.: {}'.format(offset, score))
206