a b/algorithms/arch/resnet.py
1
import os, sys
2
import torch
3
import torchvision
4
from torch import nn
5
6
7
8
def loadResnetBackbone(arch, torch_pretrain=None, freeze=False):
    """Build a torchvision ResNet with its classification head removed.

    Args:
        arch: One of ``"resnet18"``, ``"resnet34"``, ``"resnet50"``,
            ``"resnet101"``, ``"resnet152"``.
        torch_pretrain: ``"DEFAULT"`` / ``"IMAGENET-1K"`` to load torchvision's
            default pretrained weights, or ``None`` / ``"NONE"`` for random
            init. String values are matched case-insensitively.
        freeze: If True, set ``requires_grad=False`` on every parameter.

    Returns:
        ``(backbone, outfeat_size)``: the model with ``fc`` replaced by
        ``nn.Identity()``, and the width of the features it now outputs.

    Raises:
        ValueError: if ``torch_pretrain`` or ``arch`` is not recognised.
    """
    ## pretrain setting
    # Normalise strings so e.g. "default" and "None" behave like "DEFAULT"/"NONE".
    pretrain_key = torch_pretrain.upper() if isinstance(torch_pretrain, str) else torch_pretrain
    if pretrain_key in ("DEFAULT", "IMAGENET-1K"):
        # torchvision picks the concrete weights that "DEFAULT" maps to
        # per architecture.
        torch_pretrain = "DEFAULT"
    elif pretrain_key is None or pretrain_key == "NONE":
        torch_pretrain = None
    else:
        raise ValueError(f"Unknown pretrain weight type requested: {torch_pretrain!r}")
    print("Torch Pretrain Set to ...", torch_pretrain)

    ## Model loading
    # Width of the pooled feature vector that feeds the (removed) fc layer.
    outfeat_sizes = {
        "resnet18": 512,
        "resnet34": 512,
        "resnet50": 2048,
        "resnet101": 2048,
        "resnet152": 2048,
    }
    if arch not in outfeat_sizes:
        raise ValueError(
            f"Unknown Model Implementation {arch!r} called in "
            f"{os.path.basename(__file__)}; expected one of {sorted(outfeat_sizes)}"
        )
    # Each supported name is also the torchvision builder's function name.
    builder = getattr(torchvision.models, arch)
    # zero_init_residual only affects random init; loaded pretrained weights
    # overwrite it.
    backbone = builder(zero_init_residual=True, weights=torch_pretrain)
    outfeat_size = outfeat_sizes[arch]
    backbone.fc = nn.Identity()  # remove fc of default arch

    # freeze model
    if freeze:
        print("Freezing Resnet weights ...")
        # NOTE: this only stops gradient updates. BatchNorm running statistics
        # still update whenever the module is in train() mode.
        for param in backbone.parameters():
            param.requires_grad = False

    return backbone, outfeat_size