## BioSeqNet/resnest/torch/resnest.py
1
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2
## Created by: Hang Zhang
3
## Email: zhanghang0704@gmail.com
4
## Copyright (c) 2020
5
##
6
## LICENSE file in the root directory of this source tree 
7
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8
"""ResNeSt models"""
9
10
import torch
11
from .resnet import ResNet, Bottleneck
12
13
# Names exported by a star-import of this module.
__all__ = [
    'resnest50',
    'resnest101',
    'resnest200',
    'resnest269',
    'resnest14',
    'resnest26',
]

# Download URL template with two positional slots; the module fills them
# with the model name and the first eight characters of its checksum.
_url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'

# Checksum strings keyed by model name.
_model_sha256 = {
    'resnest50': '528c19ca',
    'resnest101': '22405ba7',
    'resnest200': '75117900',
    'resnest269': '0cc87c48',
}
23
24
def short_hash(name):
    """Return the first eight characters of the checksum stored for ``name``.

    Args:
        name: model name used as the key into ``_model_sha256``.

    Returns:
        The leading eight characters of the stored checksum string.

    Raises:
        ValueError: if ``name`` has no entry in ``_model_sha256``.
    """
    if name in _model_sha256:
        return _model_sha256[name][:8]
    message = 'Pretrained model for {name} is not available.'.format(name=name)
    raise ValueError(message)
28
29
# Full download URL for every model listed in ``_model_sha256``: the URL
# template filled with the model name and the first eight characters of
# its checksum.
resnest_model_urls = {
    model_name: _url_format.format(model_name, checksum[:8])
    for model_name, checksum in _model_sha256.items()
}
32
33
def resnest14(pretrained=False, root='~/.encoding/models', **kwargs):
    """Build a ResNeSt-14 model.

    The network is a ``ResNet`` made of ``Bottleneck`` blocks with the
    layer list ``[1, 1, 1, 1]``, ``radix=2`` and ``stem_width=32``.

    Args:
        pretrained: when true, download the weights registered for
            ``'resnest14'`` in ``resnest_model_urls`` and load them.
        root: not used by this function; kept in the signature so existing
            callers keep working.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.

    Raises:
        KeyError: if ``pretrained`` is true and ``resnest_model_urls`` has
            no entry for ``'resnest14'``.
    """
    model = ResNet(Bottleneck, [1, 1, 1, 1],
                   radix=2, groups=1, bottleneck_width=64,
                   deep_stem=True, stem_width=32, avg_down=True,
                   avd=True, avd_first=False, **kwargs)
    if pretrained:
        url = resnest_model_urls.get('resnest14')
        if url is None:
            # The checksum table in this module has no 'resnest14' entry, so
            # a plain subscript would fail with an uninformative
            # ``KeyError: 'resnest14'``. Keep the exception type but say why.
            raise KeyError('Pretrained model for resnest14 is not available.')
        model.load_state_dict(torch.hub.load_state_dict_from_url(
            url, progress=True, check_hash=True))
    return model
42
43
def resnest26(pretrained=False, root='~/.encoding/models', **kwargs):
    """Build a ResNeSt-26 model.

    The network is a ``ResNet`` made of ``Bottleneck`` blocks with the
    layer list ``[2, 2, 2, 2]``, ``radix=2`` and ``stem_width=32``.

    Args:
        pretrained: when true, download the weights registered for
            ``'resnest26'`` in ``resnest_model_urls`` and load them.
        root: not used by this function; kept in the signature so existing
            callers keep working.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.

    Raises:
        KeyError: if ``pretrained`` is true and ``resnest_model_urls`` has
            no entry for ``'resnest26'``.
    """
    model = ResNet(Bottleneck, [2, 2, 2, 2],
                   radix=2, groups=1, bottleneck_width=64,
                   deep_stem=True, stem_width=32, avg_down=True,
                   avd=True, avd_first=False, **kwargs)
    if pretrained:
        url = resnest_model_urls.get('resnest26')
        if url is None:
            # The checksum table in this module has no 'resnest26' entry, so
            # a plain subscript would fail with an uninformative
            # ``KeyError: 'resnest26'``. Keep the exception type but say why.
            raise KeyError('Pretrained model for resnest26 is not available.')
        model.load_state_dict(torch.hub.load_state_dict_from_url(
            url, progress=True, check_hash=True))
    return model
52
53
54
55
def resnest50(pretrained=False, root='~/.encoding/models', **kwargs):
    """Build a ResNeSt-50 model.

    The network is a ``ResNet`` made of ``Bottleneck`` blocks with the
    layer list ``[3, 4, 6, 3]``, ``radix=2`` and ``stem_width=32``.

    Args:
        pretrained: when true, download the weights at
            ``resnest_model_urls['resnest50']`` (with ``check_hash=True``)
            and load them into the model.
        root: not used by this function.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    model = ResNet(
        Bottleneck, [3, 4, 6, 3],
        radix=2, groups=1, bottleneck_width=64,
        deep_stem=True, stem_width=32, avg_down=True,
        avd=True, avd_first=False, **kwargs)
    if not pretrained:
        return model

    state_dict = torch.hub.load_state_dict_from_url(
        resnest_model_urls['resnest50'], progress=True, check_hash=True)
    model.load_state_dict(state_dict)
    return model
64
65
def resnest101(pretrained=False, root='~/.encoding/models', **kwargs):
    """Build a ResNeSt-101 model.

    The network is a ``ResNet`` made of ``Bottleneck`` blocks with the
    layer list ``[3, 4, 23, 3]``, ``radix=2`` and ``stem_width=64``.

    Args:
        pretrained: when true, download the weights at
            ``resnest_model_urls['resnest101']`` (with ``check_hash=True``)
            and load them into the model.
        root: not used by this function.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    model = ResNet(
        Bottleneck, [3, 4, 23, 3],
        radix=2, groups=1, bottleneck_width=64,
        deep_stem=True, stem_width=64, avg_down=True,
        avd=True, avd_first=False, **kwargs)
    if not pretrained:
        return model

    state_dict = torch.hub.load_state_dict_from_url(
        resnest_model_urls['resnest101'], progress=True, check_hash=True)
    model.load_state_dict(state_dict)
    return model
74
75
def resnest200(pretrained=False, root='~/.encoding/models', **kwargs):
    """Build a ResNeSt-200 model.

    The network is a ``ResNet`` made of ``Bottleneck`` blocks with the
    layer list ``[3, 24, 36, 3]``, ``radix=2`` and ``stem_width=64``.

    Args:
        pretrained: when true, download the weights at
            ``resnest_model_urls['resnest200']`` (with ``check_hash=True``)
            and load them into the model.
        root: not used by this function.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    model = ResNet(
        Bottleneck, [3, 24, 36, 3],
        radix=2, groups=1, bottleneck_width=64,
        deep_stem=True, stem_width=64, avg_down=True,
        avd=True, avd_first=False, **kwargs)
    if not pretrained:
        return model

    state_dict = torch.hub.load_state_dict_from_url(
        resnest_model_urls['resnest200'], progress=True, check_hash=True)
    model.load_state_dict(state_dict)
    return model
84
85
def resnest269(pretrained=False, root='~/.encoding/models', **kwargs):
    """Build a ResNeSt-269 model.

    The network is a ``ResNet`` made of ``Bottleneck`` blocks with the
    layer list ``[3, 30, 48, 8]``, ``radix=2`` and ``stem_width=64``.

    Args:
        pretrained: when true, download the weights at
            ``resnest_model_urls['resnest269']`` (with ``check_hash=True``)
            and load them into the model.
        root: not used by this function.
        **kwargs: forwarded unchanged to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    model = ResNet(
        Bottleneck, [3, 30, 48, 8],
        radix=2, groups=1, bottleneck_width=64,
        deep_stem=True, stem_width=64, avg_down=True,
        avd=True, avd_first=False, **kwargs)
    if not pretrained:
        return model

    state_dict = torch.hub.load_state_dict_from_url(
        resnest_model_urls['resnest269'], progress=True, check_hash=True)
    model.load_state_dict(state_dict)
    return model