Diff of /src/model.py [000000] .. [f45789]

Switch to side-by-side view

--- a
+++ b/src/model.py
@@ -0,0 +1,156 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+import torchvision
+from torchvision import datasets, models, transforms
+import matplotlib.pyplot as plt
+import time
+import os
+import copy
+
+import torch.nn.functional as F
+
+RESNET = ['resnet18','resnet34','resnet50','resnet101','resnet152']
+
+def set_parameter_requires_grad(model, feature_extracting):
+    if feature_extracting:
+        for param in model.parameters():
+            param.requires_grad = False
+
+def initialize_model(conf):
+    model_name = conf['model']['name']
+    feature_extract = conf['model']['feature_extract']
+    use_pretrained = conf['model']['use_pretrained']
+    print_model = conf['model']['print_model']
+    num_classes = len(conf['data']['classes'])
+
+    if model_name in RESNET:
+        model = getattr(models, model_name)(pretrained=use_pretrained)
+        set_parameter_requires_grad(model, feature_extract)
+        num_ftrs = model.fc.in_features
+        model.fc = nn.Linear(num_ftrs, num_classes)
+    elif model_name == 'efficientdet_d0':
+        model = EfficientClassification(num_classes)
+    else:
+        print("Invalid model name, exiting...")
+        exit()
+
+    if print_model: print(model)
+    model.name = model_name
+    return model
+
+class EfficientClassification(nn.Module):
+
+    def __init__(self, num_classes):
+        super(EfficientClassification, self).__init__()
+        from effdet import create_model
+        self.effdet = create_model(model_name='efficientdet_d0')
+        self.effdet.box_net = nn.Identity()
+        self.effdet.class_net = nn.Identity()
+        self.resnet = models.resnet18(pretrained=True)
+
+        num_ftrs = self.resnet.fc.in_features
+        self.resnet.fc = nn.Linear(num_ftrs, num_classes)
+
+        self.deconv0 = nn.ConvTranspose2d(in_channels=64,
+                                          out_channels=16,
+                                          kernel_size=19,
+                                          stride=3,
+                                          padding=1,
+                                          dilation=2)
+
+        self.deconv1 = nn.ConvTranspose2d(in_channels=64,
+                                          out_channels=12,
+                                          kernel_size=9,
+                                          stride=7,
+                                          padding=1,
+                                          dilation=1)
+
+        self.deconv2 = nn.ConvTranspose2d(in_channels=64,
+                                          out_channels=8,
+                                          kernel_size=24,
+                                          stride=9,
+                                          padding=2,
+                                          dilation=4)
+
+        self.deconv3 = nn.ConvTranspose2d(in_channels=64,
+                                          out_channels=4,
+                                          kernel_size=28,
+                                          stride=9,
+                                          padding=1,
+                                          dilation=6)
+
+        self.deconv4 = nn.ConvTranspose2d(in_channels=64,
+                                          out_channels=2,
+                                          kernel_size=30,
+                                          stride=8,
+                                          padding=2,
+                                          dilation=7)
+
+        self.conv0 = nn.Conv2d(in_channels=42,
+                              out_channels=16,
+                              kernel_size=5,
+                              padding=2)
+
+        self.conv1 = nn.Conv2d(in_channels=16,
+                              out_channels=3,
+                              kernel_size=3,
+                              padding=1)
+
+    def forward(self, x):
+        # EffNet + BiFPN
+        fpn_out, _ = self.effdet(x)
+
+        # Convolution Transpose
+        out0 = self.deconv0(fpn_out[0])
+        out1 = self.deconv1(fpn_out[1])
+        out2 = self.deconv2(fpn_out[2])
+        out3 = self.deconv3(fpn_out[3])
+        out4 = self.deconv4(fpn_out[4])
+        deconv_out = torch.cat([out0,out1,out2,out3,out4], dim=1)
+
+        # Convolution
+        conv_out = self.conv1(self.conv0(deconv_out))
+
+        # Resnet18
+        out = self.resnet(conv_out)
+        return out
+
+class EfficientClassification2(nn.Module):
+
+    def __init__(self, num_classes):
+        super(EfficientClassification2, self).__init__()
+        from effdet import create_model
+        self.effdet = create_model(model_name='efficientdet_d0')
+        self.effdet.box_net = nn.Identity()
+        self.effdet.class_net = nn.Identity()
+
+        # In features from FPN
+        fc_in_features = [64 * i*i for i in [64,32,16,8,4]]
+        mid = 64
+        self.fc0 = nn.Linear(fc_in_features[0], mid)
+        self.fc1 = nn.Linear(fc_in_features[1], mid)
+        self.fc2 = nn.Linear(fc_in_features[2], mid)
+        self.fc3 = nn.Linear(fc_in_features[3], mid)
+        self.fc4 = nn.Linear(fc_in_features[4], mid)
+        self.fc_out = nn.Linear(5 * mid, num_classes)
+
+    def forward(self, x):
+        fpn_out, _ = self.effdet(x)
+        fpn_out = list(map(lambda t: torch.flatten(t, start_dim=1), fpn_out))
+        out0 = self.fc0(fpn_out[0])
+        out1 = self.fc1(fpn_out[1])
+        out2 = self.fc2(fpn_out[2])
+        out3 = self.fc3(fpn_out[3])
+        out4 = self.fc4(fpn_out[4])
+        fc_outs = torch.cat([out0,out1,out2,out3,out4], dim=1)
+        out = self.fc_out(fc_outs)
+        return out
+
+if __name__ == '__main__':
+
+    x = torch.randn(20, 3, 512, 512)
+    model = EfficientClassification(num_classes=2)
+    fpn_out = model(x)
+    print('FIN')
\ No newline at end of file