a b/mmseg/ops/wrappers.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import warnings
3
4
import torch.nn as nn
5
import torch.nn.functional as F
6
7
8
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Resize ``input`` via :func:`torch.nn.functional.interpolate`.

    Thin wrapper around ``F.interpolate`` that can additionally warn when
    ``align_corners`` is enabled and an enlarging 2D output size does not
    line up with the input grid.

    Args:
        input (Tensor): Tensor to resize; its spatial dims are read from
            ``input.shape[2:]``.
        size (int | Sequence[int] | None): Target spatial size, forwarded to
            ``F.interpolate``. A scalar int applies to every spatial dim.
        scale_factor (float | Sequence[float] | None): Forwarded to
            ``F.interpolate``.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool | None): Forwarded to ``F.interpolate``.
        warning (bool): Whether to run the alignment check. Default: True.

    Returns:
        Tensor: The result of ``F.interpolate``.
    """
    if warning and size is not None and align_corners:
        in_size = tuple(int(x) for x in input.shape[2:])
        if isinstance(size, int):
            # F.interpolate accepts a scalar size for all spatial dims.
            out_size = (size, ) * len(in_size)
        else:
            out_size = tuple(int(x) for x in size)
        # The alignment heuristic is only defined for two spatial dims (H, W);
        # other ranks are passed straight through instead of crashing here.
        if len(in_size) == 2 and len(out_size) == 2:
            input_h, input_w = in_size
            output_h, output_w = out_size
            # Only check when enlarging along at least one axis.
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)
28
29
30
class Upsample(nn.Module):
    """``nn.Module`` wrapper around :func:`resize`.

    Args:
        size (Sequence[int] | None): Fixed output spatial size. When falsy,
            the output size is derived from ``scale_factor`` instead.
        scale_factor (float | Sequence[float] | None): Multiplier(s) applied
            to the last two dims of the input. A tuple/list gives one factor
            per dim; a scalar is used for both.
        mode (str): Interpolation mode passed to :func:`resize`.
            Default: 'nearest'.
        align_corners (bool | None): Passed to :func:`resize`.
    """

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 mode='nearest',
                 align_corners=None):
        super(Upsample, self).__init__()
        self.size = size
        if isinstance(scale_factor, (tuple, list)):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.align_corners = align_corners

    def _output_size(self, x):
        """Return the spatial output size to request for input ``x``.

        Raises:
            ValueError: If a per-dim ``scale_factor`` does not have one entry
                per spatial dim of ``x.shape[-2:]``.
        """
        if self.size:
            return self.size
        spatial = x.shape[-2:]
        if isinstance(self.scale_factor, tuple):
            # Per-dim factors: multiplying an int by the tuple itself would
            # repeat the tuple rather than scale the dim.
            factors = self.scale_factor
            if len(factors) != len(spatial):
                raise ValueError(
                    f'scale_factor {factors} must have {len(spatial)} '
                    'entries, one per spatial dim')
        else:
            factors = (self.scale_factor, ) * len(spatial)
        return [int(t * f) for t, f in zip(spatial, factors)]

    def forward(self, x):
        """Resize ``x`` to ``self.size`` or by ``self.scale_factor``."""
        return resize(x, self._output_size(x), None, self.mode,
                      self.align_corners)