tests/test_models/test_backbones/test_mit.py

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.backbones import MixVisionTransformer
from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN


def test_mit():
    with pytest.raises(TypeError):
        # `pretrained` must be a pretrained checkpoint path/URL (str) or None.
        MixVisionTransformer(pretrained=123)

    # Test normal input
    H, W = (224, 224)
    temp = torch.randn((1, 3, H, W))
    model = MixVisionTransformer(
        embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(0, 1, 2, 3))
    model.init_weights()
    outs = model(temp)
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 160, H // 16, W // 16)
    assert outs[3].shape == (1, 256, H // 32, W // 32)
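    # The four stages downsample the input by 4, 8, 16 and 32, and the channel
    # counts asserted above (32, 64, 160, 256) correspond to
    # embed_dims * num_heads[i] for this configuration.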

    # Test non-squared input
    H, W = (224, 256)
    temp = torch.randn((1, 3, H, W))
    outs = model(temp)
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 160, H // 16, W // 16)
    assert outs[3].shape == (1, 256, H // 32, W // 32)
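    # H and W are downsampled independently, so non-square feature maps
    # (e.g. 56 x 64 at stride 4) are expected here.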

    # Test MixFFN
    FFN = MixFFN(64, 128)
    hw_shape = (32, 32)
    token_len = 32 * 32
    temp = torch.randn((1, token_len, 64))
    # Self identity
    out = FFN(temp, hw_shape)
    assert out.shape == (1, token_len, 64)
    # Out identity
    outs = FFN(temp, hw_shape, temp)
    assert outs.shape == (1, token_len, 64)
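    # MixFFN folds the token sequence back into an (H, W) feature map for its
    # 3x3 depthwise convolution, hence the required hw_shape argument; the
    # optional third argument supplies the identity (residual) tensor.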

    # Test EfficientMHA
    MHA = EfficientMultiheadAttention(64, 2)
    hw_shape = (32, 32)
    token_len = 32 * 32
    temp = torch.randn((1, token_len, 64))
    # Self identity
    out = MHA(temp, hw_shape)
    assert out.shape == (1, token_len, 64)
    # Out identity
    outs = MHA(temp, hw_shape, temp)
    assert outs.shape == (1, token_len, 64)
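    # With the default spatial-reduction ratio (sr_ratio=1) no reduction conv
    # is applied, so attention runs over the full 32 * 32 token sequence.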


def test_mit_init():
    path = 'PATH_THAT_DO_NOT_EXIST'

    # Test all combinations of pretrained and init_cfg.
    # pretrained=None, init_cfg=None
    model = MixVisionTransformer(pretrained=None, init_cfg=None)
    assert model.init_cfg is None
    model.init_weights()

    # pretrained=None
    # init_cfg loads pretrained weights from a non-existent file
    model = MixVisionTransformer(
        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Loading a checkpoint from a non-existent file raises an error.
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained=None
    # init_cfg=123, whose type is unsupported
    model = MixVisionTransformer(pretrained=None, init_cfg=123)
    with pytest.raises(TypeError):
        model.init_weights()

    # pretrained loads pretrained weights from a non-existent file
    # init_cfg=None
    model = MixVisionTransformer(pretrained=path, init_cfg=None)
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Loading a checkpoint from a non-existent file raises an error.
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained loads pretrained weights from a non-existent file
    # init_cfg loads pretrained weights from a non-existent file
    with pytest.raises(AssertionError):
        MixVisionTransformer(
            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
    with pytest.raises(AssertionError):
        MixVisionTransformer(pretrained=path, init_cfg=123)

    # pretrained=123, whose type is unsupported
    # init_cfg=None
    with pytest.raises(TypeError):
        MixVisionTransformer(pretrained=123, init_cfg=None)

    # pretrained=123, whose type is unsupported
    # init_cfg loads pretrained weights from a non-existent file
    with pytest.raises(AssertionError):
        MixVisionTransformer(
            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))

    # pretrained=123, whose type is unsupported
    # init_cfg=123, whose type is unsupported
    with pytest.raises(AssertionError):
        MixVisionTransformer(pretrained=123, init_cfg=123)
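

# A minimal convenience guard for running this file directly (not part of the
# original test suite); the usual entry point is
# `pytest tests/test_models/test_backbones/test_mit.py`.
if __name__ == '__main__':
    test_mit()
    test_mit_init()
    print('All MiT backbone tests passed.')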