mmseg/ops/encoding.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
from torch.nn import functional as F


class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Input is of shape (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Args:
        channels (int): Dimension of the features or feature channels.
        num_codes (int): Number of code words.
    """

    def __init__(self, channels, num_codes):
        super(Encoding, self).__init__()
        # init codewords and smoothing factor
        self.channels, self.num_codes = channels, num_codes
        std = 1. / ((num_codes * channels)**0.5)
        # [num_codes, channels]
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # [num_codes]
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x, codewords, scale):
        num_codes, channels = codewords.size()
        batch_size = x.size(0)
        # [1, 1, num_codes]
        reshaped_scale = scale.view((1, 1, num_codes))
        # [batch_size, height x width, num_codes, channels]
        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))

        # scaled squared L2 distance between each feature and each codeword:
        # [batch_size, height x width, num_codes]
        scaled_l2_norm = reshaped_scale * (
            expanded_x - reshaped_codewords).pow(2).sum(dim=3)
        return scaled_l2_norm

    @staticmethod
    def aggregate(assignment_weights, x, codewords):
        num_codes, channels = codewords.size()
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
        batch_size = x.size(0)

        # [batch_size, height x width, num_codes, channels]
        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        # sum the assignment-weighted residuals over all positions:
        # [batch_size, num_codes, channels]
        encoded_feat = (assignment_weights.unsqueeze(3) *
                        (expanded_x - reshaped_codewords)).sum(dim=1)
        return encoded_feat

    def forward(self, x):
        assert x.dim() == 4 and x.size(1) == self.channels
        # [batch_size, channels, height, width]
        batch_size = x.size(0)
        # [batch_size, height x width, channels]
        x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
        # assignment_weights: [batch_size, height x width, num_codes]
        assignment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # aggregate
        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
        return encoded_feat

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW => Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str
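
A minimal usage sketch (not part of the file; the sizes below are arbitrary example values): the layer maps a (batch_size, channels, height, width) feature map to (batch_size, num_codes, channels) residual encodings. Because scale is initialized in (-1, 0), the softmax over the scaled distances assigns larger weights to nearer codewords. The torch.cdist cross-check is an assumed equivalence, verified only up to floating-point error.

import torch

from mmseg.ops.encoding import Encoding

layer = Encoding(channels=16, num_codes=32)
x = torch.randn(2, 16, 8, 8)              # (batch_size, channels, H, W)
out = layer(x)
assert out.shape == (2, 32, 16)           # (batch_size, num_codes, channels)

# scaled_l2 should match a cdist-based formulation of the same quantity.
flat = x.view(2, 16, -1).transpose(1, 2)  # (batch_size, H*W, channels)
ref = layer.scale.view(1, 1, -1) * torch.cdist(
    flat, layer.codewords.unsqueeze(0)).pow(2)
assert torch.allclose(
    Encoding.scaled_l2(flat, layer.codewords, layer.scale), ref, atol=1e-4)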