
BioSeqNet/resnest/gluon/dropblock.py
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import mxnet as mx
from functools import partial
from mxnet.gluon.nn import MaxPool1D, Block, HybridBlock

__all__ = ['DropBlock', 'set_drop_prob', 'DropBlockScheduler']

class DropBlock(HybridBlock):
    """DropBlock regularization: during training, drops contiguous
    block_size x block_size regions of the feature map instead of
    independent activations, then rescales the surviving values."""
    def __init__(self, drop_prob, block_size, c, h, w):
        super().__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size
        self.c, self.h, self.w = c, h, w
        self.numel = c * h * w
        pad_h = max((block_size - 1), 0)
        pad_w = max((block_size - 1), 0)
        self.padding = (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)
        self.dtype = 'float32'

    def hybrid_forward(self, F, x):
        # Identity outside of training, or when dropping is disabled.
        if not mx.autograd.is_training() or self.drop_prob <= 0:
            return x
        # gamma is the Bernoulli rate for block seeds, chosen so that the
        # expected fraction of dropped activations is approximately drop_prob
        # after each seed is expanded to a block_size x block_size region.
        gamma = self.drop_prob * (self.h * self.w) / (self.block_size ** 2) / \
            ((self.w - self.block_size + 1) * (self.h - self.block_size + 1))
        # generate mask: sample block seeds, then grow them with max pooling
        mask = F.random.uniform(0, 1, shape=(1, self.c, self.h, self.w), dtype=self.dtype) < gamma
        mask = F.Pooling(mask, pool_type='max',
                         kernel=(self.block_size, self.block_size), pad=self.padding)
        # invert so 1 marks kept activations and 0 marks dropped blocks
        mask = 1 - mask
        # rescale kept activations by numel / (number of kept activations)
        y = F.broadcast_mul(F.broadcast_mul(x, mask),
                            (1.0 * self.numel / mask.sum(axis=0, exclude=True).expand_dims(1).expand_dims(1).expand_dims(1)))
        return y

    def cast(self, dtype):
        super(DropBlock, self).cast(dtype)
        self.dtype = dtype

    def __repr__(self):
        reprstr = self.__class__.__name__ + '(' + \
            'drop_prob: {}, block_size: {}'.format(self.drop_prob, self.block_size) + ')'
        return reprstr

def set_drop_prob(drop_prob, module):
    """
    Example:
        from functools import partial
        apply_drop_prob = partial(set_drop_prob, 0.1)
        net.apply(apply_drop_prob)
    """
    if isinstance(module, DropBlock):
        module.drop_prob = drop_prob


class DropBlockScheduler(object):
    def __init__(self, net, start_prob, end_prob, num_epochs):
        self.net = net
        self.start_prob = start_prob
        self.end_prob = end_prob
        self.num_epochs = num_epochs

    def __call__(self, epoch):
        # linearly anneal drop_prob from start_prob towards end_prob over num_epochs
        ratio = self.start_prob + 1.0 * (self.end_prob - self.start_prob) * (epoch + 1) / self.num_epochs
        assert 0 <= ratio <= 1
        apply_drop_prob = partial(set_drop_prob, ratio)
        self.net.apply(apply_drop_prob)
        # re-hybridize so the updated drop_prob is baked into the cached graph
        self.net.hybridize()
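
Usage sketch (not part of the file above): a minimal example of how these pieces could be wired together, assuming a toy Gluon network whose feature-map shape matches the c, h, w passed to DropBlock. The network, shapes, epoch count, and import path below are illustrative assumptions, not part of this module.

import mxnet as mx
from mxnet import gluon

# Import path is an assumption; adjust to wherever dropblock.py lives in the package.
from BioSeqNet.resnest.gluon.dropblock import DropBlock, DropBlockScheduler

# Hypothetical feature-map geometry: 64 channels, 32x32 spatial size after the conv.
net = gluon.nn.HybridSequential()
net.add(gluon.nn.Conv2D(64, kernel_size=3, padding=1))
net.add(DropBlock(drop_prob=0.0, block_size=7, c=64, h=32, w=32))
net.initialize()

# Anneal drop_prob linearly from 0.0 to 0.1 over 100 epochs.
scheduler = DropBlockScheduler(net, start_prob=0.0, end_prob=0.1, num_epochs=100)

for epoch in range(100):
    scheduler(epoch)   # updates drop_prob on every DropBlock, then re-hybridizes net
    # ... usual training loop for this epoch; DropBlock only drops activations
    # inside autograd.record(), and acts as an identity at inference time.
    x = mx.nd.random.uniform(shape=(8, 3, 32, 32))
    y = net(x)         # inference-mode call: DropBlock passes x through unchanged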