|
a |
|
b/BioSeqNet/resnest/gluon/model_zoo.py |
|
|
1 |
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|
|
2 |
## Created by: Hang Zhang |
|
|
3 |
## Email: zhanghang0704@gmail.com |
|
|
4 |
## Copyright (c) 2020 |
|
|
5 |
## |
|
|
6 |
## This source code is licensed under the terms found in the LICENSE file in the root directory of this source tree |
|
|
7 |
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|
|
8 |
|
|
|
9 |
from .resnest import * |
|
|
10 |
from .ablation import * |
|
|
11 |
|
|
|
12 |
# Names exported by ``from <this module> import *``. Must be spelled with two
# leading and two trailing underscores, otherwise Python ignores it.
__all__ = ['get_model', 'get_model_list']
|
|
13 |
|
|
|
14 |
# Registry mapping lower-case model names to their constructor functions.
# get_model() resolves names against the keys of this dict.
# NOTE(review): the constructors are expected to be provided by the
# ``from .resnest import *`` / ``from .ablation import *`` star imports —
# TODO confirm which module defines each one.
models = dict(
    resnest50=resnest50,
    resnest101=resnest101,
    resnest200=resnest200,
    resnest269=resnest269,
    resnest50_fast_1s1x64d=resnest50_fast_1s1x64d,
    resnest50_fast_2s1x64d=resnest50_fast_2s1x64d,
    resnest50_fast_4s1x64d=resnest50_fast_4s1x64d,
    resnest50_fast_1s2x40d=resnest50_fast_1s2x40d,
    resnest50_fast_2s2x40d=resnest50_fast_2s2x40d,
    resnest50_fast_4s2x40d=resnest50_fast_4s2x40d,
    resnest50_fast_1s4x24d=resnest50_fast_1s4x24d,
)
|
|
27 |
|
|
|
28 |
def get_model(name, **kwargs):
    """Return a pre-defined model by name.

    Parameters
    ----------
    name : str
        Name of the model; lower-cased before being matched against the
        keys of ``models``.
    **kwargs
        Forwarded unchanged to the selected model constructor. Options
        documented for these constructors include (actual defaults are
        defined by each constructor):

        pretrained : bool
            Whether to load the pretrained weights for model.
        root : str, default '~/.encoding/models'
            Location for keeping the model parameters.

    Returns
    -------
    Module:
        The model built by the selected constructor.

    Raises
    ------
    ValueError
        If ``name`` does not match any registered model. The message lists
        all available model names.
    """
    key = name.lower()
    # Look the constructor up separately from calling it, so that a KeyError
    # raised *inside* a constructor is not misreported as an unknown name.
    try:
        constructor = models[key]
    except KeyError:
        # ``from None`` hides the internal KeyError from the traceback; the
        # ValueError message already carries all the useful information.
        raise ValueError('Unknown model name: %s\nAvailable models:\n\t%s'
                         % (key, '\n\t'.join(sorted(models)))) from None
    return constructor(**kwargs)
|
|
50 |
|
|
|
51 |
def get_model_list():
    """Get the entire list of model names in model_zoo.

    Returns
    -------
    list of str
        Entire list of model names in model_zoo, in the order they appear
        in ``models``. A new list is built on every call, so callers may
        index or mutate it without affecting the registry.
    """
    # Materialize the keys: a bare ``models.keys()`` view is not indexable
    # and would not match the documented ``list of str`` return type.
    return list(models)
|
|
59 |
|