|
a |
|
b/algorithms/arch/resnet.py |
|
|
1 |
import os, sys |
|
|
2 |
import torch |
|
|
3 |
import torchvision |
|
|
4 |
from torch import nn |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
|
|
|
8 |
def loadResnetBackbone(arch, torch_pretrain=None, freeze=False):
    """Build a torchvision ResNet with its final fully-connected layer removed.

    Args:
        arch: Architecture name; one of 'resnet18', 'resnet34', 'resnet50',
            'resnet101' or 'resnet152'.
        torch_pretrain: "DEFAULT" or "IMAGENET-1K" are both forwarded to
            torchvision as ``weights="DEFAULT"``; None, "NONE" or "none"
            are forwarded as ``weights=None``.
        freeze: When truthy, ``requires_grad`` is set to False on every
            parameter of the returned backbone.

    Returns:
        A ``(backbone, outfeat_size)`` tuple: the model whose ``fc``
        attribute has been replaced by ``nn.Identity()``, and the feature
        width recorded for the architecture (512 for resnet18/34, 2048 for
        resnet50/101/152).

    Raises:
        ValueError: If ``torch_pretrain`` or ``arch`` is not one of the
            values listed above.
    """
    # Feature width associated with each supported torchvision constructor.
    feature_sizes = {
        'resnet18': 512,
        'resnet34': 512,
        'resnet50': 2048,
        'resnet101': 2048,
        'resnet152': 2048,
    }

    ## pretrain setting
    if torch_pretrain in ["DEFAULT", "IMAGENET-1K"]:
        torch_pretrain = "DEFAULT"  # "IMAGENET1K_V2"
    elif torch_pretrain in [None, "NONE", "none"]:
        torch_pretrain = None
    else:
        # A single formatted message: passing the value as a second
        # positional argument would make str(exc) render as a tuple.
        raise ValueError(
            f"Unknown pretrain weight type requested {torch_pretrain!r}"
        )
    print("Torch Pretrain Set to ...", torch_pretrain)

    ## Model loading
    # The isinstance guard keeps unhashable inputs on the ValueError path
    # instead of failing inside the dict lookup.
    if not isinstance(arch, str) or arch not in feature_sizes:
        raise ValueError(
            f"Unknown Model Implementation {arch!r} called in "
            f"{os.path.basename(__file__)}"
        )
    outfeat_size = feature_sizes[arch]
    # The constructor is looked up by name only after validation, so
    # torchvision is never touched for a rejected request.
    build_model = getattr(torchvision.models, arch)
    backbone = build_model(zero_init_residual=True, weights=torch_pretrain)
    backbone.fc = nn.Identity()  # remove fc of default arch

    # freeze model
    if freeze:
        print("Freezing Resnet weights ...")
        for param in backbone.parameters():
            param.requires_grad = False

    return backbone, outfeat_size