a b/BioSeqNet/resnest/gluon/model_zoo.py
1
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2
## Created by: Hang Zhang
3
## Email: zhanghang0704@gmail.com
4
## Copyright (c) 2020
5
##
6
## This source code is licensed under the terms found in the LICENSE file in the root directory of this source tree.
7
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8
9
from .resnest import *
10
from .ablation import *
11
12
# NOTE(review): ``_all__`` looks like a typo for ``__all__``. As spelled it is an
# ordinary (private) module variable and does NOT restrict ``from ... import *``,
# so a star import of this module exposes every public name bound here
# (e.g. the constructors referenced below, ``models``, ``get_model``,
# ``get_model_list``). Renaming it would shrink that export set -- confirm no
# caller depends on the current behavior before fixing.
_all__ = ['get_model', 'get_model_list']

# Registry: lower-case model name -> constructor callable.
# get_model() looks names up here; get_model_list() reports the keys in
# this insertion order.
models = dict(
    resnest50=resnest50,
    resnest101=resnest101,
    resnest200=resnest200,
    resnest269=resnest269,
    resnest50_fast_1s1x64d=resnest50_fast_1s1x64d,
    resnest50_fast_2s1x64d=resnest50_fast_2s1x64d,
    resnest50_fast_4s1x64d=resnest50_fast_4s1x64d,
    resnest50_fast_1s2x40d=resnest50_fast_1s2x40d,
    resnest50_fast_2s2x40d=resnest50_fast_2s2x40d,
    resnest50_fast_4s2x40d=resnest50_fast_4s2x40d,
    resnest50_fast_1s4x24d=resnest50_fast_1s4x24d,
)
27
28
def get_model(name, **kwargs):
    """Build one of the registered models by name.

    Parameters
    ----------
    name : str
        Model identifier, matched case-insensitively against the keys of
        the module-level ``models`` registry.
    **kwargs
        Forwarded unchanged to the selected constructor. Earlier docs list
        ``pretrained`` (bool, whether to load pretrained weights) and
        ``root`` (str, default '~/.encoding/models', where parameters are
        kept). NOTE(review): those options are handled by the constructors,
        not by this function -- confirm against ``.resnest`` / ``.ablation``.

    Returns
    -------
    object
        Whatever the registered constructor returns (the model).

    Raises
    ------
    ValueError
        If ``name`` is not registered; the message shows the rejected name
        followed by every valid name, sorted, one per line.
    """
    key = name.lower()
    if key not in models:
        choices = '\n\t'.join(sorted(models.keys()))
        raise ValueError('%s\n\t%s' % (str(key), choices))
    constructor = models[key]
    return constructor(**kwargs)
50
51
def get_model_list():
    """Return the names of every model registered in ``models``.

    Returns
    -------
    list of str
        A new list of the registry keys, in registration order. A fresh
        list (rather than the live ``dict_keys`` view the registry exposes)
        is returned so the result matches this documented type, supports
        indexing, and does not change if the registry is mutated later.
    """
    return list(models)
59