|
a |
|
b/BioSeqNet/resnest/gluon/dropblock.py |
|
|
1 |
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|
|
2 |
## Created by: Hang Zhang |
|
|
3 |
## Email: zhanghang0704@gmail.com |
|
|
4 |
## Copyright (c) 2020 |
|
|
5 |
## |
|
|
6 |
## LICENSE file in the root directory of this source tree |
|
|
7 |
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|
|
8 |
import mxnet as mx |
|
|
9 |
from functools import partial |
|
|
10 |
from mxnet.gluon.nn import MaxPool1D, Block, HybridBlock |
|
|
11 |
|
|
|
12 |
__all__ = ['DropBlock', 'set_drop_prob', 'DropBlockScheduler'] |
|
|
13 |
|
|
|
14 |
class DropBlock(HybridBlock): |
|
|
15 |
def __init__(self, drop_prob, block_size, c, h, w): |
|
|
16 |
super().__init__() |
|
|
17 |
self.drop_prob = drop_prob |
|
|
18 |
self.block_size = block_size |
|
|
19 |
self.c, self.h, self.w = c, h, w |
|
|
20 |
self.numel = c * h * w |
|
|
21 |
pad_h = max((block_size - 1), 0) |
|
|
22 |
pad_w = max((block_size - 1), 0) |
|
|
23 |
self.padding = (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2) |
|
|
24 |
self.dtype = 'float32' |
|
|
25 |
|
|
|
26 |
def hybrid_forward(self, F, x): |
|
|
27 |
if not mx.autograd.is_training() or self.drop_prob <= 0: |
|
|
28 |
return x |
|
|
29 |
gamma = self.drop_prob * (self.h * self.w) / (self.block_size ** 2) / \ |
|
|
30 |
((self.w - self.block_size + 1) * (self.h - self.block_size + 1)) |
|
|
31 |
# generate mask |
|
|
32 |
mask = F.random.uniform(0, 1, shape=(1, self.c, self.h, self.w), dtype=self.dtype) < gamma |
|
|
33 |
mask = F.Pooling(mask, pool_type='max', |
|
|
34 |
kernel=(self.block_size, self.block_size), pad=self.padding) |
|
|
35 |
mask = 1 - mask |
|
|
36 |
y = F.broadcast_mul(F.broadcast_mul(x, mask), |
|
|
37 |
(1.0 * self.numel / mask.sum(axis=0, exclude=True).expand_dims(1).expand_dims(1).expand_dims(1))) |
|
|
38 |
return y |
|
|
39 |
|
|
|
40 |
def cast(self, dtype): |
|
|
41 |
super(DropBlock, self).cast(dtype) |
|
|
42 |
self.dtype = dtype |
|
|
43 |
|
|
|
44 |
def __repr__(self): |
|
|
45 |
reprstr = self.__class__.__name__ + '(' + \ |
|
|
46 |
'drop_prob: {}, block_size{}'.format(self.drop_prob, self.block_size) +')' |
|
|
47 |
return reprstr |
|
|
48 |
|
|
|
49 |
def set_drop_prob(drop_prob, module): |
|
|
50 |
""" |
|
|
51 |
Example: |
|
|
52 |
from functools import partial |
|
|
53 |
apply_drop_prob = partial(set_drop_prob, 0.1) |
|
|
54 |
net.apply(apply_drop_prob) |
|
|
55 |
""" |
|
|
56 |
if isinstance(module, DropBlock): |
|
|
57 |
module.drop_prob = drop_prob |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
class DropBlockScheduler(object): |
|
|
61 |
def __init__(self, net, start_prob, end_prob, num_epochs): |
|
|
62 |
self.net = net |
|
|
63 |
self.start_prob = start_prob |
|
|
64 |
self.end_prob = end_prob |
|
|
65 |
self.num_epochs = num_epochs |
|
|
66 |
|
|
|
67 |
def __call__(self, epoch): |
|
|
68 |
ratio = self.start_prob + 1.0 * (self.end_prob - self.start_prob) * (epoch + 1) / self.num_epochs |
|
|
69 |
assert (ratio >= 0 and ratio <= 1) |
|
|
70 |
apply_drop_prob = partial(set_drop_prob, ratio) |
|
|
71 |
self.net.apply(apply_drop_prob) |
|
|
72 |
self.net.hybridize() |
|
|
73 |
|