|
a |
|
b/OurModel.py |
|
|
1 |
import torch |
|
|
2 |
import torchvision |
|
|
3 |
import torch.nn.functional as F |
|
|
4 |
|
|
|
5 |
import pandas as pd |
|
|
6 |
import numpy as np |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
class Block(torch.nn.Module): |
|
|
10 |
def __init__(self, in_channels, mid_channel, out_channels, max_pool_kernel_size, batch_norm=False): |
|
|
11 |
super().__init__() |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
self.max_pool_kernel_size=max_pool_kernel_size |
|
|
15 |
self.conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=mid_channel, kernel_size=3, padding=1) |
|
|
16 |
self.conv2 = torch.nn.Conv2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, padding=1) |
|
|
17 |
|
|
|
18 |
self.batch_norm = batch_norm |
|
|
19 |
if batch_norm: |
|
|
20 |
self.bn1 = torch.nn.BatchNorm2d(mid_channel) |
|
|
21 |
self.bn2 = torch.nn.BatchNorm2d(out_channels) |
|
|
22 |
self.bn3 = torch.nn.BatchNorm2d(out_channels) |
|
|
23 |
|
|
|
24 |
def forward(self, x): |
|
|
25 |
x = self.conv1(x) |
|
|
26 |
if self.batch_norm: |
|
|
27 |
x = self.bn1(x) |
|
|
28 |
x = torch.nn.functional.relu(x, inplace=True) |
|
|
29 |
|
|
|
30 |
x = self.conv2(x) |
|
|
31 |
if self.batch_norm: |
|
|
32 |
x = self.bn2(x) |
|
|
33 |
x = torch.nn.functional.relu(x, inplace=True) |
|
|
34 |
|
|
|
35 |
if self.max_pool_kernel_size!=1: |
|
|
36 |
x = torch.nn.functional.max_pool2d(x, kernel_size=self.max_pool_kernel_size) |
|
|
37 |
if self.batch_norm: |
|
|
38 |
x = self.bn3(x) |
|
|
39 |
out = x |
|
|
40 |
return out |
|
|
41 |
|
|
|
42 |
|
|
|
43 |
|
|
|
44 |
|
|
|
45 |
class CxlNet(torch.nn.Module): |
|
|
46 |
def up(self, x, size): |
|
|
47 |
return torch.nn.functional.interpolate(x, size=size, mode=self.upscale_mode) |
|
|
48 |
|
|
|
49 |
def down(self, x): |
|
|
50 |
return torch.nn.functional.max_pool2d(x, kernel_size=2) |
|
|
51 |
|
|
|
52 |
|
|
|
53 |
def __init__(self, in_channels, out_channels, batch_norm=False, upscale_mode="nearest",image_size=512): |
|
|
54 |
super().__init__() |
|
|
55 |
self.in_channels = in_channels |
|
|
56 |
self.out_channels = out_channels |
|
|
57 |
self.batch_norm = batch_norm |
|
|
58 |
self.upscale_mode = upscale_mode |
|
|
59 |
self.image_size=image_size |
|
|
60 |
self.enc1 = Block(in_channels, 32, 64,2, batch_norm) |
|
|
61 |
self.enc2 = Block(64, 64, 64, 2, batch_norm) |
|
|
62 |
self.enc3 = Block(64, 128, 128, 2, batch_norm) |
|
|
63 |
self.enc4 = Block(128, 256, 256, 2, batch_norm) |
|
|
64 |
#self.enc3 = Block(256, 128, 128, 2, batch_norm) |
|
|
65 |
#self.enc4 = Block(128, 64, 64, 2, batch_norm) |
|
|
66 |
|
|
|
67 |
self.dec3 = Block(512, 256, 256, 1, batch_norm) |
|
|
68 |
self.dec2 = Block(256, 128, 128, 1, batch_norm) |
|
|
69 |
self.dec1 = Block(128, 64, 64, 1, batch_norm) |
|
|
70 |
self.dec0 = Block(64, 32, out_channels, 1, batch_norm) |
|
|
71 |
def forward(self, x): |
|
|
72 |
enc1 = self.enc1(x) |
|
|
73 |
enc2 = self.enc2(enc1) |
|
|
74 |
enc3 = self.enc3(enc2) |
|
|
75 |
enc4 = self.enc4(enc3) |
|
|
76 |
outOfDec3 = self.dec3(torch.cat([enc1, |
|
|
77 |
self.up(enc2, enc1.size()[-2:]), |
|
|
78 |
self.up(enc3, enc1.size()[-2:]), |
|
|
79 |
self.up(enc4, enc1.size()[-2:]), |
|
|
80 |
], 1)) |
|
|
81 |
|
|
|
82 |
outOfDec2 = self.dec2(self.up(outOfDec3, (self.image_size,self.image_size))) |
|
|
83 |
outOfDec1 = self.dec1(outOfDec2) |
|
|
84 |
outOfDec0 = self.dec0(outOfDec1) |
|
|
85 |
return outOfDec0 |