Diff of /algorithms/arch/resnet.py [000000] .. [a18f15]

Switch to side-by-side view

--- a
+++ b/algorithms/arch/resnet.py
@@ -0,0 +1,53 @@
+import os, sys
+import torch
+import torchvision
+from torch import nn
+
+
+
+def loadResnetBackbone(arch, torch_pretrain= None, freeze= False):
+
+    ## pretrain setting
+    if torch_pretrain in ["DEFAULT", "IMAGENET-1K"]:
+        torch_pretrain = "DEFAULT" #"IMAGENET1K_V2"
+    elif torch_pretrain in [None, "NONE", "none"]:
+        torch_pretrain = None
+    else:
+        raise ValueError("Unknown pretrain weight type requested ", torch_pretrain )
+    print("Torch Pretrain Set to ...", torch_pretrain)
+
+    ## Model loading
+    if arch == 'resnet18':
+        backbone = torchvision.models.resnet18(zero_init_residual=True,
+                                weights=torch_pretrain)
+        outfeat_size = 512
+    elif arch == 'resnet34':
+        backbone = torchvision.models.resnet34(zero_init_residual=True,
+                            weights=torch_pretrain)
+        outfeat_size = 512
+    elif arch == 'resnet50':
+        backbone = torchvision.models.resnet50(zero_init_residual=True,
+                            weights=torch_pretrain)
+        outfeat_size = 2048
+
+    elif arch == 'resnet101':
+        backbone = torchvision.models.resnet101(zero_init_residual=True,
+                            weights=torch_pretrain)
+        outfeat_size = 2048
+
+    elif arch == 'resnet152':
+        backbone = torchvision.models.resnet152(zero_init_residual=True,
+                            weights=torch_pretrain)
+        outfeat_size = 2048
+
+    else:
+        raise ValueError(f"Unknown Model Implementation called in {os.path.basename(__file__)}")
+    backbone.fc = nn.Identity() #remove fc of default arch
+
+    # freeze model
+    if freeze:
+        print("Freezing Resnet weights ...")
+        for param in backbone.parameters():
+            param.requires_grad = False
+
+    return backbone, outfeat_size
\ No newline at end of file