# model/network/basic_blocks.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
5
class BasicConv2d(nn.Module):
    """A bias-free 2-D convolution followed by a leaky ReLU activation.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        **kwargs: Additional keyword arguments forwarded unchanged to
            ``nn.Conv2d`` (e.g. ``padding``, ``stride``).  ``bias`` must not
            be supplied here because it is always set to ``False``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super(BasicConv2d, self).__init__()
        # The convolution is created without a bias term.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)

    def forward(self, x):
        """Apply the convolution, then leaky ReLU (default negative slope)."""
        x = self.conv(x)
        # In-place activation overwrites the fresh conv output, not the
        # caller's tensor.
        return F.leaky_relu(x, inplace=True)
13
14
15
class SetBlock(nn.Module):
    """Apply a 2-D block to a 5-D tensor by folding its first two dimensions.

    The input is unpacked as ``(n, s, c, h, w)``.  The first two dimensions
    are merged so that ``forward_block`` sees a 4-D ``(n * s, c, h, w)``
    tensor, and the result is unfolded back to five dimensions afterwards.
    NOTE(review): ``n`` and ``s`` are presumably batch size and set/sequence
    length; confirm against the callers.

    Args:
        forward_block: Callable (typically an ``nn.Module``) applied to the
            folded 4-D tensor; it must return a 4-D tensor whose first
            dimension is still ``n * s``.
        pooling: When true, a ``nn.MaxPool2d(2)`` is applied to the output of
            ``forward_block`` before unfolding.
    """

    def __init__(self, forward_block, pooling=False):
        super(SetBlock, self).__init__()
        self.forward_block = forward_block
        self.pooling = pooling
        # The pooling layer only exists when pooling was requested.
        if pooling:
            self.pool2d = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``forward_block`` (and optional pooling) applied per slice.

        Args:
            x: 5-D tensor of shape ``(n, s, c, h, w)``.

        Returns:
            5-D tensor of shape ``(n, s, c', h', w')`` where the trailing
            three sizes are whatever the block (and pooling) produced.
        """
        n, s, c, h, w = x.size()
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as transposed tensors; for contiguous inputs it is still a view.
        x = self.forward_block(x.reshape(-1, c, h, w))
        if self.pooling:
            x = self.pool2d(x)
        # Channel and spatial sizes may have changed; re-read them.
        _, c, h, w = x.size()
        return x.reshape(n, s, c, h, w)