# tools/ensemble_inferencer.py
# NOTE: this file was recovered from a diff-viewer paste ("b" side of the
# diff); the viewer's "Switch to unified view" header and per-line numbers
# have been removed.
# Standard library.
import argparse
import os
import shutil
import warnings

# Third-party.
import mmcv
import numpy as np
import torch
from PIL import Image
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)
from mmcv.utils import DictAction

# Project.
from mmseg.apis import multi_gpu_test, single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor
# Ensemble members: configs[i] is paired with ckpts[i].  The first config
# additionally defines the shared test dataset/pipeline used below.
configs = [
    "../work_dirs/remote/swb384_22k_1x_16bs_all/remote_swb.py",
    "../work_dirs/remote/dl3pr101_1x_16bs_all/remote_dl3pr101.py",
]

# Checkpoints, index-aligned with `configs`.
ckpts = [
    "../work_dirs/remote/swb384_22k_1x_16bs_all/latest.pth",
    "../work_dirs/remote/dl3pr101_1x_16bs_all/latest.pth",
]
# Build the shared test dataset and loader from the FIRST config; every
# ensemble member is evaluated on this same loader.
cfg = mmcv.Config.fromfile(configs[0])
torch.backends.cudnn.benchmark = True
cfg.model.pretrained = None        # weights come from the checkpoints below
cfg.data.test.test_mode = True     # dataset yields images without GT loading

dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
    dataset,
    samples_per_gpu=64,
    # Hard-coded worker count; cfg.data.workers_per_gpu would honour the
    # config instead.
    workers_per_gpu=4,
    dist=False,
    shuffle=False)
# Instantiate every ensemble member, load its weights, and move it to GPU 0.
models = []
for config, ckpt in zip(configs, ckpts):
    cfg = mmcv.Config.fromfile(config)
    torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True
    cfg.model.train_cfg = None
    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
    # FIX: `fp16_cfg` was fetched and then ignored while wrap_fp16_model()
    # ran unconditionally, forcing fp16 inference for every member.  Only
    # wrap when the config actually enables fp16 (mirrors mmseg's test
    # tooling).
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, ckpt, map_location='cpu')
    # Attach dataset metadata so result visualisation/palette lookup works.
    model.CLASSES = dataset.CLASSES
    model.PALETTE = dataset.PALETTE
    torch.cuda.empty_cache()
    # Output directory for the ensembled prediction PNGs; `tmpdir` is also
    # read by the inference loop below.
    eval_kwargs = {"imgfile_prefix": "../work_imgs"}
    tmpdir = eval_kwargs['imgfile_prefix']
    mmcv.mkdir_or_exist(tmpdir)
    model = MMDataParallel(model, device_ids=[0])
    model.eval()
    models.append(model)
# Ensemble inference: for every batch, collect each member's output, sum the
# per-class scores image by image, argmax over classes, and save the fused
# prediction as a uint8 PNG named after the sample's ground-truth seg map.
results = []
dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
loader_indices = data_loader.batch_sampler

for batch_indices, data in zip(loader_indices, data_loader):
    per_model_outputs = []
    for model in models:
        with torch.no_grad():
            per_model_outputs.append(model(return_loss=False, **data))
    # zip(*...) regroups the outputs image-by-image across the members.
    # NOTE(review): this assumes model(return_loss=False, ...) returns
    # per-class score maps (class axis first) so that argmax(0) is
    # meaningful — confirm against the models' test_cfg.
    fused = [
        np.stack(image_outputs, 0).sum(0).argmax(0)
        for image_outputs in zip(*per_model_outputs)
    ]
    for seg_pred, sample_index in zip(fused, batch_indices):
        img_info = dataset.img_infos[sample_index]
        file_name = os.path.join(tmpdir, img_info['ann']['seg_map'])
        Image.fromarray(seg_pred.astype(np.uint8)).save(file_name)
        prog_bar.update()