Diff of /bpnet/heads.py [000000] .. [d45a3a]

"""Head modules
"""
import numpy as np
from bpnet.utils import dict_prefix_key
from bpnet.metrics import ClassificationMetrics, RegressionMetrics
import keras.backend as K
import tensorflow as tf
import keras.layers as kl
import gin
import os
import abc


class BaseHead:

    # Expected attributes:
    #   loss -> the loss function to use
    #   weight -> loss weight (1 by default)
    #   kwargs -> kwargs for the model
    #   name -> name of the module
    #   _model -> gets set up in the init

    @abc.abstractmethod
    def get_target(self, task):
        pass

    @abc.abstractmethod
    def __call__(self, inp, task):
        """Useful for wiring together the model.

        Returns the output tensor
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_preact_tensor(self, graph=None):
        """Return the single pre-activation tensor
        """
        pass

    @abc.abstractmethod
    def intp_tensors(self, preact_only=False, graph=None):
        """Dictionary of all available interpretation tensors
        for `get_interpretation_node`
        """
        raise NotImplementedError

    # @abc.abstractmethod
    # def get_intp_tensor(self, which=None):
    #     """Returns a target tensor which is a scalar
    #     w.r.t. which to compute the outputs

    #     Args:
    #       which [string]: If None, use the default
    #       **kwargs: optional kwargs for the interpretation method

    #     Returns:
    #       scalar tensor
    #     """
    #     raise NotImplementedError

    def copy(self):
        from copy import deepcopy
        return deepcopy(self)


class BaseHeadWBias(BaseHead):
    """Head that can additionally consume a measurement-bias input."""

    @abc.abstractmethod
    def get_bias_input(self, task):
        pass

    @abc.abstractmethod
    def neutral_bias_input(self, task, length, seqlen):
        pass


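# Illustrative sketch (not part of the original file): a model-building
# function would typically loop over heads and tasks, collecting the named
# output tensors and the extra bias inputs returned by `head.__call__`:
#
#     outputs, bias_inputs = [], []
#     for head in heads:
#         for task in tasks:
#             o, binp = head(bottleneck, task)
#             outputs.append(o)
#             bias_inputs += binp
#     # model = keras.models.Model([seq_input] + bias_inputs, outputs)

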
def id_fn(x):
    return x


def named_tensor(x, name):
    """Wrap `x` in an identity Lambda layer so that the tensor gets
    a fixed, predictable name."""
    return kl.Lambda(id_fn, name=name)(x)

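
# Usage sketch (illustrative, not part of the original file): naming the
# output tensor lets the training targets be passed to Keras as a dictionary
# keyed by that name:
#
#     inp = kl.Input((10,))
#     out = named_tensor(kl.Dense(1)(inp), name='task1/scalar')
#     # model = keras.models.Model(inp, out)
#     # model.fit(x, {'task1/scalar': y})
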

# --------------------------------------------
# Head implementations

@gin.configurable
class ScalarHead(BaseHeadWBias):
    """Head predicting a single scalar per task."""

    def __init__(self, target_name,  # "{task}/scalar"
                 net,  # function that takes a keras tensor and returns a keras tensor
                 activation=None,
                 loss='mse',
                 loss_weight=1,
                 metric=RegressionMetrics(),
                 postproc_fn=None,  # post-processing so that the predictions are on the right scale
                 # bias input
                 use_bias=False,
                 bias_net=None,
                 bias_input='bias/{task}/scalar',
                 bias_shape=(1,),
                 ):
        self.net = net
        self.loss = loss
        self.loss_weight = loss_weight
        self.metric = metric
        self.postproc_fn = postproc_fn
        self.target_name = target_name
        self.activation = activation
        self.bias_input = bias_input
        self.bias_net = bias_net
        self.use_bias = use_bias
        self.bias_shape = bias_shape

    def get_target(self, task):
        return self.target_name.format(task=task)

    def __call__(self, inp, task):
        o = self.net(inp)

        # remember the tensors useful for interpretation (referred to by name)
        self.pre_act = o.name

        # add the target bias
        if self.use_bias:
            binp = kl.Input(self.bias_shape, name=self.get_bias_input(task))
            bias_inputs = [binp]

            # add the bias term
            if self.bias_net is not None:
                # This allows the bias data to be normalized first
                # (e.g. when we have profile counts, aggregate them first)
                bias_x = self.bias_net(binp)
            else:
                # Don't use the nn 'bias' so that when the measurement bias = 0,
                # this term vanishes
                bias_x = kl.Dense(1, use_bias=False)(binp)
            o = kl.add([o, bias_x])
        else:
            bias_inputs = []

        if self.activation is not None:
            if isinstance(self.activation, str):
                o = kl.Activation(self.activation)(o)
            else:
                o = self.activation(o)

        self.post_act = o.name

        # label the target op so that we can use a dictionary of targets
        # to train the model
        return named_tensor(o, name=self.get_target(task)), bias_inputs

    def get_preact_tensor(self, graph=None):
        if graph is None:
            graph = tf.get_default_graph()
        return graph.get_tensor_by_name(self.pre_act)

    def intp_tensors(self, preact_only=False, graph=None):
        """Return the required interpretation tensors
        """
        if graph is None:
            graph = tf.get_default_graph()

        if self.activation is None:
            # the post-activation doesn't have any specific meaning
            # when no activation function is used
            return {"pre-act": graph.get_tensor_by_name(self.pre_act)}

        if preact_only:
            return {"pre-act": graph.get_tensor_by_name(self.pre_act)}
        else:
            return {"pre-act": graph.get_tensor_by_name(self.pre_act),
                    "output": graph.get_tensor_by_name(self.post_act)}

    # def get_intp_tensor(self, which='pre-act'):
    #     return self.intp_tensors()[which]

    def get_bias_input(self, task):
        return self.bias_input.format(task=task)

    def neutral_bias_input(self, task, length, seqlen):
        """Create a dummy (all-zeros) bias input

        Returns: (key, array) tuple
        """
        shape = tuple([x if x is not None else seqlen
                       for x in self.bias_shape])
        return (self.get_bias_input(task), np.zeros((length,) + shape))

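
# Usage sketch (illustrative; `count_net` and `bottleneck` are hypothetical
# stand-ins for a real head network and the shared trunk): building a scalar
# head with a bias input, and generating the neutral all-zeros bias input
# used for bias-free prediction:
#
#     head = ScalarHead('{task}/scalar', net=count_net, use_bias=True)
#     out, bias_inputs = head(bottleneck, task='Oct4')
#     k, v = head.neutral_bias_input('Oct4', length=32, seqlen=1000)
#     # k == 'bias/Oct4/scalar', v.shape == (32, 1)
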
@gin.configurable
class BinaryClassificationHead(ScalarHead):
    """ScalarHead with sigmoid activation and binary cross-entropy loss."""

    def __init__(self, target_name,  # "{task}/scalar"
                 net,  # function that takes a keras tensor and returns a keras tensor
                 activation='sigmoid',
                 loss='binary_crossentropy',
                 loss_weight=1,
                 metric=ClassificationMetrics(),
                 postproc_fn=None,
                 # bias input
                 use_bias=False,
                 bias_net=None,
                 bias_input='bias/{task}/scalar',
                 bias_shape=(1,),
                 ):
        # override the parent defaults
        super().__init__(target_name,
                         net,
                         activation=activation,
                         loss=loss,
                         loss_weight=loss_weight,
                         metric=metric,
                         postproc_fn=postproc_fn,
                         use_bias=use_bias,
                         bias_net=bias_net,
                         bias_input=bias_input,
                         bias_shape=bias_shape)

        # TODO - maybe override the way we call outputs?

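
# The heads are registered with @gin.configurable, so their constructor
# arguments can also be bound from a gin config file; a hypothetical binding:
#
#     BinaryClassificationHead.target_name = '{task}/class'
#     BinaryClassificationHead.loss_weight = 0.5
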
@gin.configurable
class ProfileHead(BaseHeadWBias):
    """Deals with the case where the output consists of multiple tracks of
    total shape (L, C) (L = sequence length, C = number of channels)

    Note: Since the contribution score will be a single scalar, the
    interpretation method has to aggregate both across channels
    and across positions
    """

    def __init__(self, target_name,  # "{task}/profile"
                 net,  # function that takes a keras tensor and returns a keras tensor
                 activation=None,
                 loss='mse',
                 loss_weight=1,
                 metric=RegressionMetrics(),
                 postproc_fn=None,
                 # bias input
                 use_bias=False,
                 bias_net=None,
                 bias_input='bias/{task}/profile',
                 bias_shape=(None, 1),
                 ):
        self.net = net
        self.loss = loss
        self.loss_weight = loss_weight
        self.metric = metric
        self.postproc_fn = postproc_fn
        self.target_name = target_name
        self.activation = activation
        self.bias_input = bias_input
        self.bias_net = bias_net
        self.use_bias = use_bias
        self.bias_shape = bias_shape

    def get_target(self, task):
        return self.target_name.format(task=task)

    def __call__(self, inp, task):
        o = self.net(inp)

        # remember the tensors useful for interpretation (referred to by name)
        self.pre_act = o.name

        # add the target bias
        if self.use_bias:
            binp = kl.Input(self.bias_shape, name=self.get_bias_input(task))
            bias_inputs = [binp]

            # add the bias term
            if self.bias_net is not None:
                # This allows the bias data to be normalized first
                # (e.g. when we have profile counts, aggregate them first)
                bias_x = self.bias_net(binp)
            else:
                # Don't use the nn 'bias' so that when the measurement bias = 0,
                # this term vanishes
                bias_x = kl.Conv1D(1, kernel_size=1, use_bias=False)(binp)
            o = kl.add([o, bias_x])
        else:
            bias_inputs = []

        if self.activation is not None:
            if isinstance(self.activation, str):
                o = kl.Activation(self.activation)(o)
            else:
                o = self.activation(o)

        self.post_act = o.name

        # label the target op so that we can use a dictionary of targets
        # to train the model
        return named_tensor(o, name=self.get_target(task)), bias_inputs

    def get_preact_tensor(self, graph=None):
        if graph is None:
            graph = tf.get_default_graph()
        return graph.get_tensor_by_name(self.pre_act)

    @staticmethod
    def profile_contrib(p):
        """Summarize the profile into scalars for the contribution scores

        wn: normalized contribution - weighted sum of p, where the
          weights are softmax(p) (gradient-stopped)
        w1: sum(p)
        w2: sum of squares: sum(p**2)
        winf: max(p)
        """
        # Note: unfortunately we have to use the kl.Lambda boiler-plate
        # to be able to do Model(inp, outputs) in the deep-explain code

        # Normalized contribution - TODO - update with tensorflow
        wn = kl.Lambda(lambda p:
                       K.mean(K.sum(K.stop_gradient(tf.nn.softmax(p, axis=-2)) * p, axis=-2), axis=-1)
                       )(p)

        # Squared weight
        w2 = kl.Lambda(lambda p:
                       K.mean(K.sum(p * p, axis=-2), axis=-1)
                       )(p)

        # W1 weight
        w1 = kl.Lambda(lambda p:
                       K.mean(K.sum(p, axis=-2), axis=-1)
                       )(p)

        # Winf: max across the positional axis, averaged over the channels
        winf = kl.Lambda(lambda p:
                         K.mean(K.max(p, axis=-2), axis=-1)
                         )(p)

        return {"wn": wn,
                "w1": w1,
                "w2": w2,
                "winf": winf}

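    # Worked sketch (illustrative): the same four summaries computed in NumPy
    # for a single example `p` of shape (L, C):
    #
    #     import numpy as np
    #     p = np.random.randn(1000, 2)  # (positions, channels)
    #     smax = np.exp(p) / np.exp(p).sum(0, keepdims=True)  # softmax over positions
    #     wn = (smax * p).sum(0).mean()  # weighted sum, averaged over channels
    #     w1 = p.sum(0).mean()
    #     w2 = (p * p).sum(0).mean()
    #     winf = p.max(0).mean()
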
    def intp_tensors(self, preact_only=False, graph=None):
        """Return the required interpretation tensors (scalars)

        Note: Since we are predicting a track,
            we should return a single scalar here
        """
        if graph is None:
            graph = tf.get_default_graph()

        preact = graph.get_tensor_by_name(self.pre_act)
        postact = graph.get_tensor_by_name(self.post_act)

        # Construct the profile summary ops
        preact_tensors = self.profile_contrib(preact)
        postact_tensors = dict_prefix_key(self.profile_contrib(postact), 'output_')

        if self.activation is None:
            # the post-activation doesn't have any specific meaning
            # when no activation function is used
            return preact_tensors

        if preact_only:
            return preact_tensors
        else:
            return {**preact_tensors, **postact_tensors}

    # def get_intp_tensor(self, which='wn'):
    #     return self.intp_tensors()[which]

    def get_bias_input(self, task):
        return self.bias_input.format(task=task)

    def neutral_bias_input(self, task, length, seqlen):
        """Create a dummy (all-zeros) bias input

        Returns: (key, array) tuple
        """
        shape = tuple([x if x is not None else seqlen
                       for x in self.bias_shape])
        return (self.get_bias_input(task), np.zeros((length,) + shape))
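
# Usage sketch (illustrative; `profile_net` and `bottleneck` are hypothetical
# stand-ins): a profile head with a measurement-bias track, and the neutral
# all-zeros bias input used at prediction time:
#
#     head = ProfileHead('{task}/profile', net=profile_net,
#                        use_bias=True, bias_shape=(None, 1))
#     out, bias_inputs = head(bottleneck, task='Oct4')
#     k, v = head.neutral_bias_input('Oct4', length=32, seqlen=1000)
#     # k == 'bias/Oct4/profile', v.shape == (32, 1000, 1)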