[40f229]: / model / network / basic_blocks.py

Download this file

29 lines (24 with data), 896 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicConv2d(nn.Module):
    """2-D convolution followed by a leaky ReLU activation.

    The convolution is created without a bias term unless the caller
    explicitly passes ``bias=...`` through ``kwargs``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        """Build the convolution layer.

        Args:
            in_channels: Number of channels of the input feature map.
            out_channels: Number of channels produced by the convolution.
            kernel_size: Kernel size forwarded to ``nn.Conv2d``.
            **kwargs: Any further ``nn.Conv2d`` keyword arguments
                (``stride``, ``padding``, ...).
        """
        super().__init__()
        # Bias-free by default; setdefault (rather than a hard-coded
        # ``bias=False`` next to ``**kwargs``) keeps a caller-supplied
        # ``bias`` from raising "got multiple values for keyword argument".
        kwargs.setdefault("bias", False)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs)

    def forward(self, x):
        """Apply the convolution, then leaky ReLU, and return the result."""
        x = self.conv(x)
        # In-place activation only touches the tensor the convolution just
        # produced, never the caller's input.
        return F.leaky_relu(x, inplace=True)
class SetBlock(nn.Module):
    """Apply a 4-D module to a 5-D tensor by folding its first two dims.

    The input of shape ``(n, s, c, h, w)`` is flattened to
    ``(n * s, c, h, w)``, passed through ``forward_block`` (and optionally a
    2x2 max-pool), and then unfolded back to ``(n, s, c', h', w')``.
    """

    def __init__(self, forward_block, pooling=False):
        """Store the wrapped module and optionally create the pooling layer.

        Args:
            forward_block: Module applied to the flattened 4-D tensor.
            pooling: When true, a ``nn.MaxPool2d(2)`` is applied after
                ``forward_block``.
        """
        super().__init__()
        self.forward_block = forward_block
        self.pooling = pooling
        if pooling:
            self.pool2d = nn.MaxPool2d(2)

    def forward(self, x):
        """Run ``forward_block`` on every ``(c, h, w)`` slice of ``x``.

        Args:
            x: 5-D tensor of shape ``(n, s, c, h, w)``.

        Returns:
            5-D tensor ``(n, s, c', h', w')`` where the trailing three sizes
            are whatever ``forward_block`` (plus pooling) produced.
        """
        n, s, c, h, w = x.size()
        # reshape instead of view: view raises on non-contiguous inputs
        # (e.g. after a transpose/permute of the leading dims), whereas
        # reshape returns a view when possible and copies only if needed.
        x = self.forward_block(x.reshape(-1, c, h, w))
        if self.pooling:
            x = self.pool2d(x)
        # Channel and spatial sizes may have changed; re-read them.
        _, c, h, w = x.size()
        return x.reshape(n, s, c, h, w)