[a18f15]: / algorithms / arch / resnet.py

Download this file

53 lines (43 with data), 1.8 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os, sys
import torch
import torchvision
from torch import nn
def loadResnetBackbone(arch, torch_pretrain=None, freeze=False):
    """Build a torchvision ResNet whose final ``fc`` layer is replaced by identity.

    Args:
        arch: Architecture name. One of ``"resnet18"``, ``"resnet34"``,
            ``"resnet50"``, ``"resnet101"`` or ``"resnet152"``.
        torch_pretrain: Pretrained-weight selector. ``"DEFAULT"`` and
            ``"IMAGENET-1K"`` are both forwarded to torchvision as
            ``weights="DEFAULT"``; ``None``, ``"NONE"`` and ``"none"`` are
            forwarded as ``weights=None``.
        freeze: When true, ``requires_grad`` is set to ``False`` on every
            parameter of the backbone. The module is not switched to eval
            mode here (``backbone.eval()`` is never called).

    Returns:
        A ``(backbone, outfeat_size)`` tuple: the model with ``fc`` replaced by
        ``nn.Identity()`` and the feature size associated with the
        architecture (512 for resnet18/34, 2048 for resnet50/101/152).

    Raises:
        ValueError: If ``torch_pretrain`` or ``arch`` is not one of the
            accepted values. ``torch_pretrain`` is validated first.
    """
    # Feature size reported to the caller for each supported architecture.
    feature_dims = {
        "resnet18": 512,
        "resnet34": 512,
        "resnet50": 2048,
        "resnet101": 2048,
        "resnet152": 2048,
    }

    ## pretrain setting
    # Sequence membership (not a set) so unhashable inputs still reach the
    # ValueError branch instead of failing with a TypeError.
    if torch_pretrain in ("DEFAULT", "IMAGENET-1K"):
        # The alias is normalised to "DEFAULT"; which concrete checkpoint that
        # means is decided by torchvision for the chosen architecture.
        torch_pretrain = "DEFAULT"
    elif torch_pretrain in (None, "NONE", "none"):
        torch_pretrain = None
    else:
        # Single formatted message: passing the value as a second positional
        # argument made str(err) render as a tuple.
        raise ValueError(
            f"Unknown pretrain weight type requested: {torch_pretrain!r}"
        )
    print("Torch Pretrain Set to ...", torch_pretrain)

    ## Model loading
    if not isinstance(arch, str) or arch not in feature_dims:
        raise ValueError(
            f"Unknown Model Implementation called in {os.path.basename(__file__)}: "
            f"{arch!r} (supported: {', '.join(feature_dims)})"
        )
    # Every supported name matches a constructor of the same name in
    # torchvision.models, so one lookup replaces the per-arch branches.
    build_model = getattr(torchvision.models, arch)
    backbone = build_model(zero_init_residual=True, weights=torch_pretrain)
    outfeat_size = feature_dims[arch]

    backbone.fc = nn.Identity()  # remove fc of default arch

    # freeze model
    if freeze:
        print("Freezing Resnet weights ...")
        for param in backbone.parameters():
            param.requires_grad = False

    return backbone, outfeat_size