[f1e01c]: /data/tract/submission.ipynb

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mmseg.apis import init_segmentor, inference_segmentor\n",
    "from mmcv import Config\n",
    "\n",
    "# Config/checkpoint pairs for the ensemble (a single baseline run here)\n",
    "cfgs = [\n",
    "    \"../../work_dirs/tract/baseline/tract_baseline.py\"\n",
    "]\n",
    "\n",
    "ckpts = [\n",
    "    \"../../work_dirs/tract/baseline/latest.pth\"\n",
    "]\n",
    "\n",
    "models = []\n",
    "for cfg, ckpt in zip(cfgs, ckpts):\n",
    "    cfg = Config.fromfile(cfg)\n",
    "    # Return per-class outputs instead of argmax maps (project-specific test_cfg flag)\n",
    "    cfg.model.test_cfg.logits = True\n",
    "    # Identity Normalize so the raw 16-bit intensities pass through unchanged\n",
    "    cfg.data.test.pipeline[1].transforms.insert(\n",
    "        2, dict(type=\"Normalize\", mean=[0, 0, 0], std=[1, 1, 1], to_rgb=False))\n",
    "\n",
    "    model = init_segmentor(cfg, ckpt)\n",
    "    models.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import glob\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Run-length encode a binary mask into the competition's 'start length' string format\n",
    "def rle_encode(img):\n",
    "    pixels = img.flatten()\n",
    "    pixels = np.concatenate([[0], pixels, [0]])\n",
    "    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1\n",
    "    runs[1::2] -= runs[::2]\n",
    "    return ' '.join(str(x) for x in runs)\n",
    "\n",
    "classes = ['large_bowel', 'small_bowel', 'stomach']\n",
    "data_dir = \".\"\n",
    "test_dir = os.path.join(data_dir, \"test\")\n",
    "sub = pd.read_csv(os.path.join(data_dir, \"sample_submission.csv\"))\n",
    "test_images = glob.glob(os.path.join(test_dir, \"**\", \"*.png\"), recursive=True)\n",
    "\n",
    "# No hidden test set available locally: fall back to the train images so the\n",
    "# notebook can be dry-run end to end\n",
    "if len(test_images) == 0:\n",
    "    test_dir = os.path.join(data_dir, \"train\")\n",
    "    sub = pd.read_csv(os.path.join(data_dir, \"train.csv\"))[[\"id\", \"class\"]]\n",
    "    sub[\"predicted\"] = \"\"\n",
    "    test_images = glob.glob(os.path.join(test_dir, \"**\", \"*.png\"), recursive=True)\n",
    "\n",
    "# Map submission ids (case<id>_day<d>_slice_<nnnn>) to image paths; the hard-coded\n",
    "# split indices assume the ./<split>/<case>/<case_day>/scans/<file>.png layout\n",
    "id2img = {p.split(\"/\")[3] + \"_\" + \"_\".join(p.split(\"/\")[5].split(\"_\")[:2]): p for p in test_images}\n",
    "sub[\"file_name\"] = sub.id.map(id2img)\n",
    "sub[\"days\"] = sub.id.apply(lambda x: \"_\".join(x.split(\"_\")[:2]))\n",
    "fname2index = {f + c: i for f, c, i in zip(sub.file_name, sub[\"class\"], sub.index)}\n",
    "sub"
   ]
  },
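  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quick sanity check of `rle_encode` on a tiny hand-computed mask (illustrative only, not part of the submission pipeline)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tiny mask whose encoding is easy to verify by hand: the row-major flatten is\n",
    "# [0, 1, 1, 0, 1, 0], i.e. a run of two 1s starting at (1-indexed) pixel 2 and\n",
    "# a run of one 1 starting at pixel 5.\n",
    "toy = np.array([[0, 1, 1],\n",
    "                [0, 1, 0]], dtype=np.uint8)\n",
    "assert rle_encode(toy) == \"2 2 5 1\""
   ]
  },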
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for day, group in tqdm(sub.groupby(\"days\")):\n",
    "    for file_name in group.file_name.unique():\n",
    "        # Scans are 16-bit single-channel PNGs\n",
    "        img = cv2.imread(file_name, cv2.IMREAD_ANYDEPTH)\n",
    "        old_size = img.shape[:2]\n",
    "\n",
    "        # Replicate to 3 float32 channels before inference (assumption consistent\n",
    "        # with the identity Normalize inserted into the test pipeline above)\n",
    "        new_img = np.stack([img.astype(np.float32)] * 3, axis=-1)\n",
    "\n",
    "        # Simple ensemble: average the per-model outputs\n",
    "        res = [inference_segmentor(model, new_img)[0] for model in models]\n",
    "        res = sum(res) / len(res)\n",
    "\n",
    "        # Resize back to the original slice size and binarize before encoding\n",
    "        # (the 0.5 threshold on the averaged outputs is an assumption)\n",
    "        res = cv2.resize(res, old_size[::-1], interpolation=cv2.INTER_NEAREST)\n",
    "        res = (res > 0.5).astype(np.uint8)\n",
    "\n",
    "        for j in range(3):\n",
    "            rle = rle_encode(res[..., j])\n",
    "            index = fname2index[file_name + classes[j]]\n",
    "            sub.loc[index, \"predicted\"] = rle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the required columns and write the submission file\n",
    "sub = sub[[\"id\", \"class\", \"predicted\"]]\n",
    "sub.to_csv(\"submission.csv\", index=False)\n",
    "sub"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "798318ca82443b72eb6fef8740009541d10bc0f18c7a2ab2923389b003353be2"
  },
  "kernelspec": {
   "display_name": "Python 3.7.10 64-bit ('mmcv': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}