In [None]:
from mmseg.apis import init_segmentor, inference_segmentor
from mmcv.utils import config

# Config files of the segmentors to ensemble, one entry per model.
configs = [
    "../../work_dirs/tract/baseline/tract_baseline.py"
]

# Checkpoint files, index-aligned with `configs`.
ckpts = [
    "../../work_dirs/tract/baseline/latest.pth"
]

# zip() silently truncates to the shorter list, which would quietly drop
# ensemble members; fail loudly instead.
if len(configs) != len(ckpts):
    raise ValueError(
        f"configs ({len(configs)}) and ckpts ({len(ckpts)}) must have the same length"
    )

# Build one initialized segmentor per (config, checkpoint) pair.
models = []
for cfg_path, ckpt in zip(configs, ckpts):
    cfg = config.Config.fromfile(cfg_path)
    # NOTE(review): `logits` is not a stock mmseg test_cfg key; presumably the
    # project's model code reads it to return score maps instead of label maps.
    cfg.model.test_cfg.logits = True
    # Insert an identity Normalize (mean 0, std 1, no channel swap) as the third
    # transform of the second test-pipeline stage.
    # NOTE(review): presumably needed so later transforms find the normalization
    # metadata they expect - confirm against the config.
    cfg.data.test.pipeline[1].transforms.insert(
        2, dict(type="Normalize", mean=[0, 0, 0], std=[1, 1, 1], to_rgb=False)
    )

    model = init_segmentor(cfg, ckpt)
    models.append(model)

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import glob
from tqdm.auto import tqdm

def rle_encode(img):
    """Run-length encode a binary mask as a space-separated string.

    The mask is flattened in row-major order and the result lists
    ``start length`` pairs, with 1-based start positions, for every run of
    non-zero pixels. An all-zero mask yields an empty string.
    """
    # Pad with a zero on both ends so runs touching the borders still produce
    # both an opening and a closing transition.
    padded = np.concatenate(([0], img.flatten(), [0]))
    # 1-based positions where the value changes: even slots are run starts,
    # odd slots are the position just past each run's end.
    transitions = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Convert every "end" position into a run length, in place.
    np.subtract(transitions[1::2], transitions[::2], out=transitions[1::2])
    return ' '.join(map(str, transitions))

# Class names, in the order the prediction channels are indexed by the
# inference cell.
classes = ['large_bowel', 'small_bowel', 'stomach']
data_dir = "."
test_dir = os.path.join(data_dir, "test")
sub = pd.read_csv(os.path.join(data_dir, "sample_submission.csv"))
test_images = glob.glob(os.path.join(test_dir, "**", "*.png"), recursive = True)

# No test images available: fall back to the training images so the
# notebook can still run end to end, with an empty prediction column.
if len(test_images) == 0:
    test_dir = os.path.join(data_dir, "train")
    sub = pd.read_csv(os.path.join(data_dir, "train.csv"))[["id", "class"]]
    sub["predicted"] = ""
    test_images = glob.glob(os.path.join(test_dir, "**", "*.png"), recursive = True)


def _slice_id(image_path):
    """Build the submission id for an image from its own path components.

    The id is ``<grandparent dir name>_<first two "_" tokens of the file name>``.
    It is derived from the tail of the path, so it does not depend on how deep
    ``data_dir`` is or on the OS path separator.
    """
    parent_dir, file_name = os.path.split(image_path)
    case_day = os.path.basename(os.path.dirname(parent_dir))
    return case_day + "_" + "_".join(file_name.split("_")[:2])


id2img = {_slice_id(path): path for path in test_images}
sub["file_name"] = sub.id.map(id2img)

# An unmatched id leaves NaN in `file_name`, which would otherwise surface
# below as an opaque "float + str" TypeError.
missing = sub.loc[sub.file_name.isna(), "id"]
if len(missing) > 0:
    raise FileNotFoundError(
        f"{len(missing)} submission rows have no matching image under {test_dir!r}, "
        f"e.g. id {missing.iloc[0]!r}"
    )

# First two "_" tokens of the id; used to group rows in the inference cell.
sub["days"] = sub.id.apply(lambda x: "_".join(x.split("_")[:2]))
# Lookup from (image path + class name) to the row that receives the prediction.
fname2index = {f + c: i for f, c, i in zip(sub.file_name, sub["class"], sub.index)}
sub.head()

In [None]:
# Cut-off applied to the model-averaged score map to obtain a binary mask.
# NOTE(review): assumes the models return per-class scores in [0, 1]
# (or binary masks, in which case this is a majority vote) - confirm;
# raw pre-sigmoid logits would need a threshold of 0 instead.
MASK_THRESHOLD = 0.5

# Predict every image once, ensemble the models, and write one RLE string per
# (image, class) row of `sub`.
for day, group in tqdm(sub.groupby("days")):
    for file_name in group.file_name.unique():
        img = cv2.imread(file_name, cv2.IMREAD_ANYDEPTH)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"could not read image: {file_name}")
        old_size = img.shape[:2]

        # NOTE(review): the original passed an undefined `new_img` here; the
        # loaded image is used directly. If the models expect extra
        # preprocessing (e.g. stacked neighbouring slices), add it before this call.
        res = [inference_segmentor(model, img)[0] for model in models]
        res = sum(res) / len(res)

        # RLE encoding is only meaningful for a binary mask, so binarize the
        # averaged scores before resizing back to the source resolution.
        res = (res > MASK_THRESHOLD).astype(np.uint8)
        # cv2.resize takes (width, height), hence the reversed shape.
        res = cv2.resize(res, old_size[::-1], interpolation = cv2.INTER_NEAREST)

        for j, class_name in enumerate(classes):
            rle = rle_encode(res[..., j])
            index = fname2index[file_name + class_name]
            sub.loc[index, "predicted"] = rle

In [None]:
# Drop the helper columns, keeping only what the submission file contains,
# then write it without the index and show the result.
submission_columns = ["id", "class", "predicted"]
sub = sub.loc[:, submission_columns]
sub.to_csv("submission.csv", index=False)
sub