Download this file

141 lines (109 with data), 4.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torch import nn
from mmseg.models import BACKBONES, HEADS
from mmseg.models.decode_heads.cascade_decode_head import BaseCascadeDecodeHead
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions, unpacked as
            ``(N, C, H, W)``.
        num_classes (int): Number of semantic classes. Generated labels
            lie in ``[0, num_classes - 1]``.

    Returns:
        dict: A dict with three entries:

        - ``imgs``: float tensor of shape ``input_shape`` with values
          drawn uniformly from ``[0, 1)``.
        - ``img_metas``: list of ``N`` meta dicts, one per image.
        - ``gt_semantic_seg``: long tensor of shape ``(N, 1, H, W)``.
    """
    (N, C, H, W) = input_shape

    # Fixed seed so that every call produces the same demo batch.
    rng = np.random.RandomState(0)

    imgs = rng.rand(*input_shape)
    # ``randint`` excludes ``high``, so pass ``num_classes`` itself: this lets
    # the last class id be sampled and keeps ``num_classes=1`` valid
    # (``high=num_classes - 1`` would raise because low == high).
    # int64 avoids wrapping label ids above 255 and matches the long tensor
    # built below.
    segs = rng.randint(
        low=0, high=num_classes, size=(N, 1, H, W)).astype(np.int64)

    # One identical meta dict per image in the batch.
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
        'flip_direction': 'horizontal'
    } for _ in range(N)]

    mm_inputs = {
        'imgs': torch.from_numpy(imgs).float(),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.from_numpy(segs).long()
    }
    return mm_inputs
@BACKBONES.register_module()
class ExampleBackbone(nn.Module):
    """Minimal backbone for tests: a single 3x3 convolution, 3 -> 3 channels."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)

    def init_weights(self, pretrained=None):
        """Do nothing; ``pretrained`` is accepted but ignored."""

    def forward(self, x):
        """Apply the convolution and return the result in a one-item list."""
        feature = self.conv(x)
        return [feature]
@HEADS.register_module()
class ExampleDecodeHead(BaseDecodeHead):
    """Minimal decode head for tests, configured with 19 classes."""

    def __init__(self):
        # Positional arguments are forwarded unchanged to the base class.
        super().__init__(3, 3, num_classes=19)

    def forward(self, inputs):
        """Run ``cls_seg`` on the first element of ``inputs``."""
        first_feature = inputs[0]
        return self.cls_seg(first_feature)
@HEADS.register_module()
class ExampleCascadeDecodeHead(BaseCascadeDecodeHead):
    """Minimal cascade decode head for tests, configured with 19 classes."""

    def __init__(self):
        # Positional arguments are forwarded unchanged to the base class.
        super().__init__(3, 3, num_classes=19)

    def forward(self, inputs, prev_out):
        """Run ``cls_seg`` on the first element of ``inputs``.

        ``prev_out`` is accepted but not used.
        """
        first_feature = inputs[0]
        return self.cls_seg(first_feature)
def _segmentor_forward_train_test(segmentor):
    """Smoke-test the training and inference entry points of a segmentor.

    Builds a demo batch with ``_demo_mm_inputs`` and then exercises, in order:
    ``forward`` with ``return_loss=True``, ``train_step``, ``val_step`` and
    two ``forward`` calls with ``return_loss=False``. The loss and step
    outputs are checked to be dicts; step outputs must also contain the
    ``loss``, ``log_vars`` and ``num_samples`` keys.

    Args:
        segmentor (nn.Module): Model exposing ``decode_head``, ``forward``,
            ``train_step``, ``val_step``, ``eval`` and ``cuda``.
    """

    def _assert_step_outputs(step_outputs):
        # Shared checks for the results of train_step and val_step.
        assert isinstance(step_outputs, dict)
        for required_key in ('loss', 'log_vars', 'num_samples'):
            assert required_key in step_outputs

    # When decode_head is a ModuleList, the last head supplies num_classes.
    if isinstance(segmentor.decode_head, nn.ModuleList):
        final_head = segmentor.decode_head[-1]
    else:
        final_head = segmentor.decode_head
    num_classes = final_head.num_classes

    # input_shape is left at the _demo_mm_inputs default, (1, 3, 8, 16),
    # i.e. a batch of one image.
    demo_batch = _demo_mm_inputs(num_classes=num_classes)
    imgs = demo_batch['imgs']
    img_metas = demo_batch['img_metas']
    gt_semantic_seg = demo_batch['gt_semantic_seg']

    # Move the model and tensors to the GPU when one is available.
    if torch.cuda.is_available():
        segmentor = segmentor.cuda()
        imgs = imgs.cuda()
        gt_semantic_seg = gt_semantic_seg.cuda()

    # Test forward train
    losses = segmentor.forward(
        imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)
    assert isinstance(losses, dict)

    # Test train_step
    train_batch = dict(
        img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
    _assert_step_outputs(segmentor.train_step(train_batch, None))

    # Test val_step
    with torch.no_grad():
        segmentor.eval()
        val_batch = dict(
            img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
        _assert_step_outputs(segmentor.val_step(val_batch, None))

    # Test forward simple test
    with torch.no_grad():
        segmentor.eval()
        # Pack into lists: one (1, C, H, W) tensor and one single-item meta
        # list per image in the batch.
        single_imgs = [img.unsqueeze(0) for img in imgs]
        single_metas = [[meta] for meta in img_metas]
        segmentor.forward(single_imgs, single_metas, return_loss=False)

    # Test forward aug test
    with torch.no_grad():
        segmentor.eval()
        # Same packing as above, with every entry listed twice.
        doubled_imgs = [img.unsqueeze(0) for img in imgs] * 2
        doubled_metas = [[meta] for meta in img_metas] * 2
        segmentor.forward(doubled_imgs, doubled_metas, return_loss=False)