Diff of /mmseg/ops/encoding.py [000000] .. [4e96d3]

Switch to side-by-side view

--- a
+++ b/mmseg/ops/encoding.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Encoding(nn.Module):
+    """Encoding Layer: a learnable residual encoder.
+
+    Input is of shape  (batch_size, channels, height, width).
+    Output is of shape (batch_size, num_codes, channels).
+
+    Args:
+        channels: dimension of the features or feature channels
+        num_codes: number of code words
+    """
+
+    def __init__(self, channels, num_codes):
+        super(Encoding, self).__init__()
+        # init codewords and smoothing factor
+        self.channels, self.num_codes = channels, num_codes
+        std = 1. / ((num_codes * channels)**0.5)
+        # [num_codes, channels]
+        self.codewords = nn.Parameter(
+            torch.empty(num_codes, channels,
+                        dtype=torch.float).uniform_(-std, std),
+            requires_grad=True)
+        # [num_codes]
+        self.scale = nn.Parameter(
+            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
+            requires_grad=True)
+
+    @staticmethod
+    def scaled_l2(x, codewords, scale):
+        num_codes, channels = codewords.size()
+        batch_size = x.size(0)
+        reshaped_scale = scale.view((1, 1, num_codes))
+        expanded_x = x.unsqueeze(2).expand(
+            (batch_size, x.size(1), num_codes, channels))
+        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
+
+        scaled_l2_norm = reshaped_scale * (
+            expanded_x - reshaped_codewords).pow(2).sum(dim=3)
+        return scaled_l2_norm
+
+    @staticmethod
+    def aggregate(assignment_weights, x, codewords):
+        num_codes, channels = codewords.size()
+        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
+        batch_size = x.size(0)
+
+        expanded_x = x.unsqueeze(2).expand(
+            (batch_size, x.size(1), num_codes, channels))
+        encoded_feat = (assignment_weights.unsqueeze(3) *
+                        (expanded_x - reshaped_codewords)).sum(dim=1)
+        return encoded_feat
+
+    def forward(self, x):
+        assert x.dim() == 4 and x.size(1) == self.channels
+        # [batch_size, channels, height, width]
+        batch_size = x.size(0)
+        # [batch_size, height x width, channels]
+        x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
+        # assignment_weights: [batch_size, channels, num_codes]
+        assignment_weights = F.softmax(
+            self.scaled_l2(x, self.codewords, self.scale), dim=2)
+        # aggregate
+        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
+        return encoded_feat
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
+                    f'x{self.channels})'
+        return repr_str