|
a |
|
b/model/network/basic_blocks.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
import torch.nn.functional as F |
|
|
4 |
|
|
|
5 |
class BasicConv2d(nn.Module):
    """Bias-free 2-D convolution followed by an in-place leaky ReLU.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size: convolution kernel size, forwarded to ``nn.Conv2d``.
        **kwargs: any further ``nn.Conv2d`` options (``padding``,
            ``stride``, ...). ``bias`` is always fixed to ``False`` here, so
            passing ``bias`` again raises ``TypeError``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            bias=False,
            **kwargs,
        )

    def forward(self, x):
        # The activation overwrites the freshly produced conv output in
        # place, which avoids allocating a second tensor of the same size.
        return F.leaky_relu(self.conv(x), inplace=True)
|
|
13 |
|
|
|
14 |
|
|
|
15 |
class SetBlock(nn.Module):
    """Apply a 4-D (``N, C, H, W``) module to every element of a 5-D tensor.

    The first two dimensions of the input are merged before calling
    ``forward_block`` and split apart again afterwards. Any module that
    accepts a ``(N, C, H, W)`` tensor can therefore be wrapped.

    Args:
        forward_block: module applied to the flattened ``(n * s, c, h, w)``
            tensor; it must return a 4-D tensor.
        pooling: if true, follow ``forward_block`` with ``nn.MaxPool2d(2)``
            (2x2 window, stride 2), halving the spatial size, rounded down.
    """

    def __init__(self, forward_block, pooling=False):
        super(SetBlock, self).__init__()
        self.forward_block = forward_block
        self.pooling = pooling
        if pooling:
            self.pool2d = nn.MaxPool2d(2)

    def forward(self, x):
        """Run the wrapped block over a 5-D tensor ``x`` of shape ``(n, s, c, h, w)``.

        Returns:
            Tensor of shape ``(n, s, c', h', w')``. Here ``c', h', w'`` are the
            channel/spatial sizes produced by ``forward_block`` and, if
            enabled, the max-pool.
        """
        n, s, c, h, w = x.size()
        # reshape rather than view: x may be non-contiguous (e.g. after an
        # upstream transpose/permute), in which case merging the first two
        # dims with view() raises. For contiguous input reshape is a view.
        x = self.forward_block(x.reshape(-1, c, h, w))
        if self.pooling:
            x = self.pool2d(x)
        _, c, h, w = x.size()
        # Split the merged leading dim back into (n, s).
        return x.reshape(n, s, c, h, w)