128 lines (127 with data), 4.0 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mmseg.apis import init_segmentor, inference_segmentor\n",
"from mmcv.utils import config\n",
"\n",
"# Config/checkpoint pairs for the ensemble: entry i of `configs` is loaded with entry i of `ckpts`.\n",
"configs = [\n",
"    \"../../work_dirs/tract/baseline/tract_baseline.py\"\n",
"]\n",
"\n",
"ckpts = [\n",
"    \"../../work_dirs/tract/baseline/latest.pth\"\n",
"]\n",
"\n",
"# zip() silently drops unmatched entries, so fail loudly if the two lists disagree.\n",
"assert len(configs) == len(ckpts), \"configs and ckpts must pair up one-to-one\"\n",
"\n",
"models = []\n",
"# The list is named `configs`; zipping the undefined name `cfgs` raised NameError on a fresh kernel.\n",
"for cfg_path, ckpt in zip(configs, ckpts):\n",
"    cfg = config.Config.fromfile(cfg_path)\n",
"    # NOTE(review): presumably a project-specific flag that makes the model return\n",
"    # per-class scores instead of a label map - confirm against the mmseg fork in use.\n",
"    cfg.model.test_cfg.logits = True\n",
"    # Insert an identity Normalize (mean 0, std 1, no channel swap) as the third\n",
"    # transform of test-pipeline step 1.\n",
"    # NOTE(review): presumably because inputs are scaled by the caller - confirm.\n",
"    cfg.data.test.pipeline[1].transforms.insert(2, dict(type=\"Normalize\", mean=[0,0,0], std=[1,1,1], to_rgb=False))\n",
"\n",
"    model = init_segmentor(cfg, ckpt)\n",
"    models.append(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import numpy as np\n",
"import pandas as pd\n",
"import glob\n",
"from tqdm.auto import tqdm\n",
"\n",
"def rle_encode(img):\n",
"    \"\"\"Run-length encode a mask as a space-separated string of `start length` pairs.\n",
"\n",
"    The array is flattened in row-major order and run starts are 1-based.\n",
"    A mask with no non-zero pixels yields an empty string.\n",
"    \"\"\"\n",
"    # Zero sentinels at both ends turn every run boundary into a value change.\n",
"    padded = np.concatenate([[0], img.flatten(), [0]])\n",
"    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1\n",
"    # Boundaries alternate between run start and one-past-run-end.\n",
"    starts = boundaries[::2]\n",
"    lengths = boundaries[1::2] - starts\n",
"    return ' '.join(f'{start} {length}' for start, length in zip(starts, lengths))\n",
"\n",
"classes = ['large_bowel', 'small_bowel', 'stomach']\n",
"data_dir = \".\"\n",
"test_dir = os.path.join(data_dir, \"test\")\n",
"sub = pd.read_csv(os.path.join(data_dir, \"sample_submission.csv\"))\n",
"test_images = glob.glob(os.path.join(test_dir, \"**\", \"*.png\"), recursive = True)\n",
"\n",
"# No PNGs under test/: fall back to the training scans and build an empty\n",
"# prediction frame from the id/class columns of train.csv.\n",
"if len(test_images) == 0:\n",
"    test_dir = os.path.join(data_dir, \"train\")\n",
"    sub = pd.read_csv(os.path.join(data_dir, \"train.csv\"))[[\"id\", \"class\"]]\n",
"    sub[\"predicted\"] = \"\"\n",
"    test_images = glob.glob(os.path.join(test_dir, \"**\", \"*.png\"), recursive = True)\n",
"\n",
"def _image_id(path):\n",
"    \"\"\"Return the id for a scan path: `<third-from-last dir>_<first two `_` tokens of the file name>`.\n",
"\n",
"    Components are counted from the end of the normalised path, so the result\n",
"    does not depend on how deep `data_dir` is or on the OS path separator.\n",
"    For the layout `./<split>/<a>/<b>/<c>/<file>.png` this equals the former\n",
"    `split(\"/\")[3]` / `split(\"/\")[5]` lookup.\n",
"    \"\"\"\n",
"    parts = os.path.normpath(path).split(os.sep)\n",
"    return parts[-3] + \"_\" + \"_\".join(parts[-1].split(\"_\")[:2])\n",
"\n",
"# id -> image path, keeping the path exactly as glob returned it.\n",
"id2img = {_image_id(path): path for path in test_images}\n",
"sub[\"file_name\"] = sub.id.map(id2img)\n",
"# Grouping key: the first two `_`-separated tokens of the id.\n",
"sub[\"days\"] = sub.id.apply(lambda x: \"_\".join(x.split(\"_\")[:2]))\n",
"# (file path + class name) -> row index, used to write predictions back into `sub`.\n",
"fname2index = {f + c: i for f, c, i in zip(sub.file_name, sub[\"class\"], sub.index)}\n",
"sub"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Predict every image, one `days` group at a time (the grouping only drives the\n",
"# progress bar), and write one RLE string per class into `sub[\"predicted\"]`.\n",
"for day, group in tqdm(sub.groupby(\"days\")):\n",
"    for file_name in group.file_name.unique():\n",
"        img = cv2.imread(file_name, cv2.IMREAD_ANYDEPTH)\n",
"        if img is None:\n",
"            # cv2.imread signals an unreadable file by returning None.\n",
"            raise FileNotFoundError(f\"could not read image: {file_name}\")\n",
"        old_size = img.shape[:2]\n",
"\n",
"        # `new_img` was never defined (NameError); build the model input from `img`.\n",
"        # NOTE(review): scaling to [0, 1] and replicating to 3 channels are assumptions -\n",
"        # this must match the preprocessing the checkpoints were trained with; confirm.\n",
"        new_img = img.astype(np.float32)\n",
"        peak = float(new_img.max())\n",
"        if peak > 0:\n",
"            new_img /= peak\n",
"        if new_img.ndim == 2:\n",
"            new_img = np.repeat(new_img[..., None], 3, axis=-1)\n",
"\n",
"        # Average the per-model outputs (simple ensemble).\n",
"        res = [inference_segmentor(model, new_img)[0] for model in models]\n",
"        res = sum(res) / len(res)\n",
"        # rle_encode needs a binary mask; averaged scores would give meaningless runs.\n",
"        # NOTE(review): 0.5 assumes per-class probabilities in [0, 1]; use 0 for raw logits - confirm.\n",
"        res = (res > 0.5).astype(np.uint8)\n",
"\n",
"        # cv2.resize takes (width, height), hence the reversed shape.\n",
"        res = cv2.resize(res, old_size[::-1], interpolation = cv2.INTER_NEAREST)\n",
"\n",
"        for j in range(3):\n",
"            rle = rle_encode(res[...,j])\n",
"            index = fname2index[file_name + classes[j]]\n",
"            sub.loc[index, \"predicted\"] = rle"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Bind the trimmed frame to a new name so `sub` keeps its helper columns\n",
"# (`file_name`, `days`) and the inference cell can be re-run after this one.\n",
"submission = sub[[\"id\", \"class\", \"predicted\"]]\n",
"submission.to_csv(\"submission.csv\", index = False)\n",
"submission"
]
}
],
"metadata": {
"interpreter": {
"hash": "798318ca82443b72eb6fef8740009541d10bc0f18c7a2ab2923389b003353be2"
},
"kernelspec": {
"display_name": "Python 3.7.10 64-bit ('mmcv': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}