Diff of /mmseg/ops/encoding.py [000000] .. [4e96d3]

Switch to unified view

a b/mmseg/ops/encoding.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import torch
3
from torch import nn
4
from torch.nn import functional as F
5
6
7
class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Every spatial position of the input is softly assigned to each code
    word (softmax over the scaled squared distances), and the
    assignment-weighted residuals ``x - codeword`` are summed over all
    positions.

    Input is of shape  (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Args:
        channels: dimension of the features or feature channels
        num_codes: number of code words
    """

    def __init__(self, channels, num_codes):
        super().__init__()
        # init codewords and smoothing factor
        self.channels, self.num_codes = channels, num_codes
        std = 1. / ((num_codes * channels)**0.5)
        # [num_codes, channels], drawn uniformly from (-std, std)
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # [num_codes], drawn from uniform_(-1, 0): with a negative scale a
        # smaller distance yields a larger softmax logit in ``forward``.
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x, codewords, scale):
        """Compute the scaled squared L2 distance to every code word.

        Args:
            x: features of shape (batch_size, num_points, channels).
            codewords: code words of shape (num_codes, channels).
            scale: per-codeword factors of shape (num_codes,).

        Returns:
            Tensor of shape (batch_size, num_points, num_codes) holding
            ``scale[k] * ||x[b, n] - codewords[k]||^2``.
        """
        num_codes, channels = codewords.size()
        # Broadcasting inserts the codeword axis:
        # [B, N, 1, C] - [1, 1, K, C] -> [B, N, K, C]
        residuals = x.unsqueeze(2) - codewords.view(
            (1, 1, num_codes, channels))
        return scale.view((1, 1, num_codes)) * residuals.pow(2).sum(dim=3)

    @staticmethod
    def aggregate(assignment_weights, x, codewords):
        """Sum the assignment-weighted residuals over all points.

        Args:
            assignment_weights: weights of shape
                (batch_size, num_points, num_codes).
            x: features of shape (batch_size, num_points, channels).
            codewords: code words of shape (num_codes, channels).

        Returns:
            Tensor of shape (batch_size, num_codes, channels).
        """
        num_codes, channels = codewords.size()
        # [B, N, 1, C] - [1, 1, K, C] -> [B, N, K, C]
        residuals = x.unsqueeze(2) - codewords.view(
            (1, 1, num_codes, channels))
        # Weight each residual, then reduce over the point axis.
        return (assignment_weights.unsqueeze(3) * residuals).sum(dim=1)

    def forward(self, x):
        """Encode a feature map.

        Args:
            x: tensor of shape (batch_size, channels, height, width).

        Returns:
            Tensor of shape (batch_size, num_codes, channels).
        """
        assert x.dim() == 4 and x.size(1) == self.channels, (
            f'expected input of shape (N, {self.channels}, H, W), '
            f'got {tuple(x.shape)}')
        # [batch_size, channels, height, width]
        #   -> [batch_size, height x width, channels]
        # ``flatten`` (unlike ``view``) copies when the spatial dims are not
        # laid out contiguously, and it also copes with an empty batch.
        x = x.flatten(2).transpose(1, 2).contiguous()
        # assignment_weights: [batch_size, height x width, num_codes]
        assignment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # aggregate: [batch_size, num_codes, channels]
        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
        return encoded_feat

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str