a b/tests/test_data/test_tta.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import os.path as osp
3
4
import mmcv
5
import pytest
6
from mmcv.utils import build_from_cfg
7
8
from mmseg.datasets.builder import PIPELINES
9
10
11
def test_multi_scale_flip_aug():
    """Exercise ``MultiScaleFlipAug`` built from config via ``PIPELINES``.

    The test has two parts:

    1. Argument combinations that must make ``build_from_cfg`` raise
       ``AssertionError``.
    2. Valid combinations of ``img_scale``, ``img_ratios`` and ``flip``,
       for which the ``'scale'`` and ``'flip'`` entries of the transform's
       output are compared against the expected lists.
    """

    def _make_cfg(img_scale, img_ratios, **extra):
        # Every call returns a brand-new config (including a new
        # ``transforms`` list) so that no state is shared between cases.
        cfg = dict(
            type='MultiScaleFlipAug',
            img_scale=img_scale,
            img_ratios=img_ratios,
        )
        cfg.update(extra)
        cfg['transforms'] = [dict(type='Resize', keep_ratio=False)]
        return cfg

    # (img_scale, img_ratios) pairs that are expected to be rejected:
    #   * img_scale=None, img_ratios=1 (not float)
    #   * img_scale=None, img_ratios=None
    #   * img_scale=(512, 512), img_ratios=1 (not float)
    rejected_args = (
        (None, 1),
        (None, None),
        ((512, 512), 1),
    )
    for bad_scale, bad_ratios in rejected_args:
        with pytest.raises(AssertionError):
            build_from_cfg(_make_cfg(bad_scale, bad_ratios), PIPELINES)

    # Input image, shape (288, 512, 3).
    img_path = osp.join(osp.dirname(__file__), '../data/color.jpg')
    img = mmcv.imread(img_path, 'color')
    # ``pad_shape`` and ``scale_factor`` are initial values for the default
    # meta_keys.
    results = dict(
        img=img,
        img_shape=img.shape,
        ori_shape=img.shape,
        pad_shape=img.shape,
        scale_factor=1.0,
    )

    # Each row: (img_scale, img_ratios, flip, expected scale, expected flip)
    accepted_cases = (
        # Single base scale combined with a list of ratios.
        ((512, 512), [0.5, 1.0, 2.0], False,
         [(256, 256), (512, 512), (1024, 1024)],
         [False, False, False]),
        ((512, 512), [0.5, 1.0, 2.0], True,
         [(256, 256), (256, 256), (512, 512),
          (512, 512), (1024, 1024), (1024, 1024)],
         [False, True, False, True, False, True]),
        # Single base scale combined with a single float ratio.
        ((512, 512), 1.0, False,
         [(512, 512)],
         [False]),
        ((512, 512), 1.0, True,
         [(512, 512), (512, 512)],
         [False, True]),
        # No base scale, ratios only.
        (None, [0.5, 1.0, 2.0], False,
         [(256, 144), (512, 288), (1024, 576)],
         [False, False, False]),
        (None, [0.5, 1.0, 2.0], True,
         [(256, 144), (256, 144), (512, 288),
          (512, 288), (1024, 576), (1024, 576)],
         [False, True, False, True, False, True]),
        # Explicit list of scales, no ratios.
        ([(256, 256), (512, 512), (1024, 1024)], None, False,
         [(256, 256), (512, 512), (1024, 1024)],
         [False, False, False]),
        ([(256, 256), (512, 512), (1024, 1024)], None, True,
         [(256, 256), (256, 256), (512, 512),
          (512, 512), (1024, 1024), (1024, 1024)],
         [False, True, False, True, False, True]),
    )
    for img_scale, img_ratios, flip, want_scale, want_flip in accepted_cases:
        tta_module = build_from_cfg(
            _make_cfg(img_scale, img_ratios, flip=flip), PIPELINES)
        # Hand the transform a (shallow) copy of the shared input dict.
        tta_results = tta_module(results.copy())
        assert tta_results['scale'] == want_scale
        assert tta_results['flip'] == want_flip