Switch to unified view

a b/BioSeqNet/resnest/gluon/resnest.py
1
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2
## Created by: Hang Zhang
3
## Email: zhanghang0704@gmail.com
4
## Copyright (c) 2020
5
##
6
## LICENSE file in the root directory of this source tree 
7
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8
"""ResNeSt implemented in Gluon."""
9
10
__all__ = ['resnest50', 'resnest101',
11
           'resnest200', 'resnest269']
12
13
from .resnet import ResNet, Bottleneck
14
from mxnet import cpu
15
16
def resnest50(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
17
    model = ResNet(Bottleneck, [3, 4, 6, 3],
18
                      radix=2, cardinality=1, bottleneck_width=64,
19
                      deep_stem=True, avg_down=True,
20
                      avd=True, avd_first=False,
21
                      use_splat=True, dropblock_prob=0.1,
22
                      name_prefix='resnest_', **kwargs)
23
    if pretrained:
24
        from .model_store import get_model_file
25
        model.load_parameters(get_model_file('resnest50', root=root), ctx=ctx)
26
    return model
27
28
def resnest101(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
29
    model = ResNet(Bottleneck, [3, 4, 23, 3],
30
                      radix=2, cardinality=1, bottleneck_width=64,
31
                      deep_stem=True, avg_down=True, stem_width=64,
32
                      avd=True, avd_first=False, use_splat=True, dropblock_prob=0.1,
33
                      name_prefix='resnest_', **kwargs)
34
    if pretrained:
35
        from .model_store import get_model_file
36
        model.load_parameters(get_model_file('resnest101', root=root), ctx=ctx)
37
    return model
38
39
def resnest200(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
40
    model = ResNet(Bottleneck, [3, 24, 36, 3], deep_stem=True, avg_down=True, stem_width=64,
41
                      avd=True, use_splat=True, dropblock_prob=0.1, final_drop=0.2,
42
                      name_prefix='resnest_', **kwargs)
43
    if pretrained:
44
        from .model_store import get_model_file
45
        model.load_parameters(get_model_file('resnest200', root=root), ctx=ctx)
46
    return model
47
48
def resnest269(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
49
    model = ResNet(Bottleneck, [3, 30, 48, 8], deep_stem=True, avg_down=True, stem_width=64,
50
                      avd=True, use_splat=True, dropblock_prob=0.1, final_drop=0.2,
51
                      name_prefix='resnest_', **kwargs)
52
    if pretrained:
53
        from .model_store import get_model_file
54
        model.load_parameters(get_model_file('resnest269', root=root), ctx=ctx)
55
    return model