[1fc74a]: / BioSeqNet / resnest / gluon / dropblock.py

Download this file

74 lines (63 with data), 2.8 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import mxnet as mx
from functools import partial
from mxnet.gluon.nn import MaxPool1D, Block, HybridBlock
__all__ = ['DropBlock', 'set_drop_prob', 'DropBlockScheduler']
class DropBlock(HybridBlock):
    """DropBlock regularisation layer for a feature map of fixed size.

    While autograd is in training mode and ``drop_prob`` is positive, a
    random 0/1 mask of shape ``(1, c, h, w)`` is sampled, every sampled
    point is grown with a ``block_size`` x ``block_size`` max-pooling
    window, and the input is multiplied by the complement of that mask.
    The product is then rescaled by ``c * h * w`` divided by the number of
    kept mask positions.  In every other case the input is returned as is.

    Parameters
    ----------
    drop_prob :
        Target drop probability; a value ``<= 0`` disables the layer.
    block_size :
        Side length of the square pooling window used to grow dropped points.
    c, h, w :
        Sizes of the sampled mask.  Presumably the channel count, height and
        width of the incoming feature map -- TODO confirm against callers.
    """

    def __init__(self, drop_prob, block_size, c, h, w):
        super().__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size
        self.c, self.h, self.w = c, h, w
        # Number of mask elements; numerator of the rescaling factor.
        self.numel = c * h * w
        # Total padding per spatial axis, split into a leading and a trailing
        # part (the trailing part is one larger when ``block_size`` is even).
        pad_h = max((block_size - 1), 0)
        pad_w = max((block_size - 1), 0)
        self.padding = (pad_h // 2, pad_h - pad_h // 2,
                        pad_w // 2, pad_w - pad_w // 2)
        # dtype of the sampled mask; updated by ``cast``.
        self.dtype = 'float32'

    def hybrid_forward(self, F, x):
        """Apply DropBlock to ``x`` in training mode, otherwise return ``x``.

        ``F`` is the MXNet API namespace handed in by ``HybridBlock``.
        ``x`` must be broadcast-compatible with a ``(1, c, h, w)`` mask.
        """
        if not mx.autograd.is_training() or self.drop_prob <= 0:
            return x
        # Per-position sampling rate chosen so that, after each sampled point
        # is grown to a block, roughly ``drop_prob`` of the map is dropped.
        gamma = self.drop_prob * (self.h * self.w) / (self.block_size ** 2) / \
            ((self.w - self.block_size + 1) * (self.h - self.block_size + 1))
        # generate mask
        mask = F.random.uniform(0, 1, shape=(1, self.c, self.h, self.w),
                                dtype=self.dtype) < gamma
        # NOTE(review): ``self.padding`` has four entries while the kernel is
        # 2-D; the MXNet Pooling docs describe ``pad`` as (y, x) for 2-D
        # pooling -- confirm this is accepted and pads as intended.
        mask = F.Pooling(mask, pool_type='max',
                         kernel=(self.block_size, self.block_size),
                         pad=self.padding)
        # Flip so that 1 marks kept positions and 0 marks dropped ones.
        mask = 1 - mask
        # Count of kept positions, reshaped to 4-D so it broadcasts below.
        kept = mask.sum(axis=0, exclude=True) \
            .expand_dims(1).expand_dims(1).expand_dims(1)
        y = F.broadcast_mul(F.broadcast_mul(x, mask),
                            (1.0 * self.numel / kept))
        return y

    def cast(self, dtype):
        """Cast the block and remember ``dtype`` for the sampled mask."""
        super().cast(dtype)
        self.dtype = dtype

    def __repr__(self):
        """Return ``DropBlock(drop_prob: <p>, block_size: <k>)``."""
        reprstr = self.__class__.__name__ + '(' + \
            'drop_prob: {}, block_size: {}'.format(self.drop_prob,
                                                   self.block_size) + ')'
        return reprstr
def set_drop_prob(drop_prob, module):
    """Assign ``drop_prob`` to ``module`` if it is a ``DropBlock``.

    Objects of any other type are left untouched, so the function can be
    handed to a recursive visitor that walks every child of a network.

    Example::

        from functools import partial
        apply_drop_prob = partial(set_drop_prob, 0.1)
        net.apply(apply_drop_prob)
    """
    if not isinstance(module, DropBlock):
        return
    module.drop_prob = drop_prob
class DropBlockScheduler(object):
    """Callable that moves a network's DropBlock probability linearly per epoch.

    Calling the scheduler with ``epoch`` computes
    ``start_prob + (end_prob - start_prob) * (epoch + 1) / num_epochs``,
    pushes that value into the network through ``net.apply`` and
    ``set_drop_prob``, and then calls ``net.hybridize()``.
    """

    def __init__(self, net, start_prob, end_prob, num_epochs):
        self.net, self.num_epochs = net, num_epochs
        self.start_prob, self.end_prob = start_prob, end_prob

    def __call__(self, epoch):
        """Set the drop probability for ``epoch`` on every DropBlock in ``net``."""
        span = self.end_prob - self.start_prob
        ratio = self.start_prob + 1.0 * span * (epoch + 1) / self.num_epochs
        # The scheduled value must stay a valid probability.
        assert 0 <= ratio <= 1
        self.net.apply(partial(set_drop_prob, ratio))
        # Presumably re-hybridized so a cached graph picks up the new
        # probability -- TODO confirm.
        self.net.hybridize()