Diff of /OurModel.py [000000] .. [2507a0]

Switch to side-by-side view

--- a
+++ b/OurModel.py
@@ -0,0 +1,85 @@
+import torch
+import torchvision
+import torch.nn.functional as F
+
+import pandas as pd
+import numpy as np
+
+
+class Block(torch.nn.Module):
+    def __init__(self, in_channels, mid_channel, out_channels, max_pool_kernel_size, batch_norm=False):
+        super().__init__()
+
+
+        self.max_pool_kernel_size=max_pool_kernel_size
+        self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=mid_channel, kernel_size=3, padding=1)
+        self.conv2 = torch.nn.Conv2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, padding=1)
+
+        self.batch_norm = batch_norm
+        if batch_norm:
+            self.bn1 = torch.nn.BatchNorm2d(mid_channel)
+            self.bn2 = torch.nn.BatchNorm2d(out_channels)
+            self.bn3 = torch.nn.BatchNorm2d(out_channels)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        if self.batch_norm:
+            x = self.bn1(x)
+        x = torch.nn.functional.relu(x, inplace=True)
+
+        x = self.conv2(x)
+        if self.batch_norm:
+            x = self.bn2(x)
+        x = torch.nn.functional.relu(x, inplace=True)
+
+        if self.max_pool_kernel_size!=1:
+            x = torch.nn.functional.max_pool2d(x, kernel_size=self.max_pool_kernel_size)
+        if self.batch_norm:
+            x = self.bn3(x)
+        out = x
+        return out
+
+ 
+
+
+class CxlNet(torch.nn.Module):
+    def up(self, x, size):
+        return torch.nn.functional.interpolate(x, size=size, mode=self.upscale_mode)
+
+    def down(self, x):
+        return torch.nn.functional.max_pool2d(x, kernel_size=2)
+
+
+    def __init__(self, in_channels, out_channels, batch_norm=False, upscale_mode="nearest",image_size=512):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.batch_norm = batch_norm
+        self.upscale_mode = upscale_mode
+        self.image_size=image_size
+        self.enc1 = Block(in_channels, 32, 64,2, batch_norm)
+        self.enc2 = Block(64, 64, 64, 2, batch_norm)
+        self.enc3 = Block(64, 128, 128, 2, batch_norm)
+        self.enc4 = Block(128, 256, 256, 2, batch_norm)
+        #self.enc3 = Block(256, 128, 128, 2, batch_norm)
+        #self.enc4 = Block(128, 64, 64, 2, batch_norm)
+
+        self.dec3 = Block(512, 256, 256, 1, batch_norm)
+        self.dec2 = Block(256, 128, 128, 1, batch_norm)
+        self.dec1 = Block(128, 64, 64, 1, batch_norm)
+        self.dec0 = Block(64, 32, out_channels, 1, batch_norm)
+    def forward(self, x):
+        enc1 = self.enc1(x)
+        enc2 = self.enc2(enc1)
+        enc3 = self.enc3(enc2)
+        enc4 = self.enc4(enc3)
+        outOfDec3 = self.dec3(torch.cat([enc1,
+                                    self.up(enc2, enc1.size()[-2:]),
+                                    self.up(enc3, enc1.size()[-2:]),
+                                    self.up(enc4, enc1.size()[-2:]),
+        ], 1))
+
+        outOfDec2 = self.dec2(self.up(outOfDec3, (self.image_size,self.image_size)))
+        outOfDec1 = self.dec1(outOfDec2)
+        outOfDec0 = self.dec0(outOfDec1)
+        return outOfDec0
\ No newline at end of file