In [None]:
import os
import numpy as np
import pandas as pd
import cv2
from PIL import Image
import glob
from rich import print
from ipywidgets import interact
from tqdm.auto import tqdm
import albumentations as A

In [None]:
# Load the run-length-encoded annotations and order the rows by id, then class.
df_train = (
    pd.read_csv("./train.csv")
    .sort_values(["id", "class"])
    .reset_index(drop=True)
)

# Tokenise each id once: the first "_"-separated token is stored as the patient,
# the first two tokens joined back together are stored as the scan day.
id_tokens = df_train.id.map(lambda slice_id: slice_id.split("_"))
df_train["patient"] = id_tokens.map(lambda tokens: tokens[0])
df_train["days"] = id_tokens.map(lambda tokens: "_".join(tokens[:2]))

num_slices = df_train.id.nunique()
# An id counts as empty when none of its rows carries a segmentation string.
num_empty_slices = df_train.segmentation.isna().groupby(df_train.id).all().sum()
num_patients = df_train.patient.nunique()
num_days = df_train.days.nunique()

summary = {
    "#slices:": num_slices,
    "#empty slices:": num_empty_slices,
    "#patients": num_patients,
    "#days": num_days,
}
print(summary)

In [None]:
def _scan_sort_key(path):
    """Sort key for a scan path: its 4th "/"-separated component, "_", then the file name."""
    parts = path.split("/")
    return parts[3] + "_" + parts[5]


all_image_files = sorted(glob.glob("./train/*/*/scans/*.png"), key=_scan_sort_key)

# Split every file stem on "_" once. The trailing fields are read positionally:
# [-5] -> slice, [-4]/[-3] -> size_x/size_y, [-2]/[-1] -> spacing_x/spacing_y.
name_fields = [os.path.basename(path)[:-4].split("_") for path in all_image_files]
size_x = [int(fields[-4]) for fields in name_fields]
size_y = [int(fields[-3]) for fields in name_fields]
spacing_x = [float(fields[-2]) for fields in name_fields]
spacing_y = [float(fields[-1]) for fields in name_fields]
slice_numbers = [int(fields[-5]) for fields in name_fields]

# Every per-image value is repeated three times before being attached to df_train.
# NOTE(review): this assumes df_train holds exactly three consecutive rows per
# image, in the same order as the sorted file list - confirm against the data.
per_image_columns = (
    ("image_files", all_image_files),
    ("spacing_x", spacing_x),
    ("spacing_y", spacing_y),
    ("size_x", size_x),
    ("size_y", size_y),
    ("slice", slice_numbers),
)
for column, per_image in per_image_columns:
    df_train[column] = np.repeat(per_image, 3)
df_train

In [None]:
# Tally how many rows share each (size_x, size_y, spacing_x, spacing_y) combination.
geometry_columns = ["size_x", "size_y", "spacing_x", "spacing_y"]
print(df_train[geometry_columns].value_counts())

def norm(x):
    """Min-max rescale an array to the 0..255 range and return it as uint8.

    Parameters
    ----------
    x : array-like
        Numeric image data of any shape.

    Returns
    -------
    numpy.ndarray
        uint8 array of the same shape; the minimum of ``x`` maps to 0 and the
        maximum to 255 (fractional values are truncated by the uint8 cast).
        An input whose values are all equal maps to all zeros.
    """
    x = np.asarray(x)
    span = np.ptp(x)
    if span == 0:
        # A constant image has no dynamic range; dividing by the zero span
        # would produce NaNs and an undefined cast to uint8.
        return np.zeros(x.shape, dtype=np.uint8)
    return ((x - x.min()) / span * 255).astype(np.uint8)


# Per-class RGB tint that is added on top of the half-dimmed image in previews.
colors = {"large_bowel": [127, 0, 0], "small_bowel": [0, 127, 0], "stomach": [0, 0, 127]}

def rle_decode(mask_rle, shape):
    """Expand a run-length-encoded string into a binary mask.

    Parameters
    ----------
    mask_rle : str
        Whitespace-separated integers alternating ``start length``; starts are
        1-based positions into the flattened mask.
    shape : tuple of (int, int)
        ``(height, width)`` of the mask to produce.

    Returns
    -------
    numpy.ndarray
        uint8 array of the given shape holding 1 inside every run, 0 elsewhere.
    """
    tokens = np.array(mask_rle.split(), dtype=int)
    run_starts = tokens[0::2] - 1  # convert 1-based starts to 0-based offsets
    run_lengths = tokens[1::2]
    run_stops = run_starts + run_lengths

    height, width = shape
    flat = np.zeros((height * width,), dtype=np.uint8)
    for begin, stop in zip(run_starts, run_stops):
        flat[begin:stop] = 1
    return flat.reshape(shape)

def rle_encode(img):
    """Run-length encode a binary mask.

    The mask is flattened, and every maximal run of non-zero pixels is reported
    as a ``start length`` pair with a 1-based start.

    Parameters
    ----------
    img : numpy.ndarray
        Mask whose pixels are 0 (background) or 1 (foreground).

    Returns
    -------
    str
        Space-separated ``start length`` pairs; an empty string when the mask
        contains no foreground.
    """
    # Zero sentinels on both sides make runs touching either end produce a
    # rising and a falling edge like any other run.
    padded = np.concatenate([[0], img.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Edges alternate run-start / run-end; turn each end into a run length.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

def display_pre(file_name):
    """Build a side-by-side preview of a scan and its annotated overlay.

    Parameters
    ----------
    file_name : str
        Path of the scan image; it is also used to look up the matching rows
        of ``df_train`` via the ``image_files`` column.

    Returns
    -------
    numpy.ndarray
        uint8 RGB array of shape ``(H, 2 * W + 10, 3)``: the normalised image,
        a 10-pixel white gap, then the same image with every annotated class
        tinted using ``colors``.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read ``file_name``.
    """
    img = cv2.imread(file_name, cv2.IMREAD_ANYDEPTH)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {file_name}")
    # IMREAD_ANYDEPTH without a colour flag yields a single-channel array, which
    # COLOR_BGR2RGB rejects; expand grayscale input to three channels instead.
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = norm(img)

    overlay = img.copy()
    rows = df_train.loc[df_train.image_files == file_name]
    for segm, label in zip(rows.segmentation, rows["class"]):
        if pd.isna(segm):
            continue  # this class has no annotation on this image
        region = rle_decode(segm, img.shape[:2]) == 1
        # Halve the brightness (<= 127) before adding the tint (<= 127) so the
        # sum always fits in uint8.
        overlay[region] = overlay[region] // 2 + colors[label]

    gap = np.ones((img.shape[0], 10, 3), dtype=np.uint8) * 255
    return np.concatenate([img, gap, overlay], 1)

# Show one randomly chosen example preview for every distinct
# (size_x, size_y, spacing_x, spacing_y) combination.
for info, group in df_train.groupby(["size_x", "size_y", "spacing_x", "spacing_y"]):
    print(info)
    # Prefer slice 70 so the previews of different groups are comparable.
    candidates = group.loc[group.slice == 70, "image_files"]
    if candidates.empty:
        # np.random.choice raises on an empty sequence; fall back to any image
        # of the group when it has no slice numbered 70.
        candidates = group.image_files
    file_name = np.random.choice(candidates)
    img = display_pre(file_name)
    display(Image.fromarray(img))

In [None]:
def show(idx):
    """Display the preview built by ``display_pre`` for row ``idx`` of ``df_train``.

    Parameters
    ----------
    idx : int
        Index label of the ``df_train`` row whose ``image_files`` entry is shown.
    """
    preview = display_pre(df_train.loc[idx, "image_files"])
    display(Image.fromarray(preview))

# Interactive browser over df_train rows. ipywidgets derives the control from
# the default value of `idx`; a (min, max, step) tuple of ints is expected to
# render as an integer slider.
# The step of 3 matches how image_files was attached with np.repeat(..., 3),
# so each slider position lands on a different image.
@interact
def f(idx = (0, len(df_train) - 1, 3)):
    """Show the preview for the df_train row selected by the widget."""
    show(idx)

In [None]:
# Export every scan day as numbered image / label PNG pairs under ./mmseg_train.
image_dir = "./mmseg_train/images"
label_dir = "./mmseg_train/labels"
# cv2.imwrite does not create missing directories; it only returns False, so
# make sure the targets exist before writing anything.
os.makedirs(image_dir, exist_ok = True)
os.makedirs(label_dir, exist_ok = True)

for day, group in tqdm(df_train.groupby("days")):
    # `i` numbers the images of a day in the order they first appear in df_train.
    for i, file_name in enumerate(tqdm(group.image_files.unique(), leave = False)):
        img = cv2.imread(file_name, cv2.IMREAD_ANYDEPTH)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {file_name}")

        segms = group.loc[group.image_files == file_name]
        masks = {}
        for segm, label in zip(segms.segmentation, segms["class"]):
            if pd.isna(segm):
                # No annotation for this class: keep an all-zero channel.
                masks[label] = np.zeros(img.shape[:2], dtype = np.uint8)
            else:
                masks[label] = rle_decode(segm, img.shape[:2])
        # One channel per class, ordered alphabetically by class name.
        msk = np.stack([masks[k] for k in sorted(masks)], -1)

        new_image_name = f"{day}_{i}.png"
        outputs = (
            (f"{image_dir}/{new_image_name}", img),
            (f"{label_dir}/{new_image_name}", msk),
        )
        for target, array in outputs:
            # Surface write failures instead of silently losing data.
            if not cv2.imwrite(target, array):
                raise OSError(f"failed to write {target}")

In [None]:
from sklearn.model_selection import GroupKFold

# glob returns files in filesystem-dependent order; sort so the generated split
# files are identical from run to run.
all_image_files = sorted(glob.glob("./mmseg_train/images/*"))
# The first "_"-separated token of each exported file name is used as the
# patient, so all images of one patient land in the same fold.
patients = [os.path.basename(path).split("_")[0] for path in all_image_files]

split = list(GroupKFold(5).split(patients, groups = patients))

split_dir = "./mmseg_train/splits"
# open(..., "w") fails when the directory is missing, so create it up front.
os.makedirs(split_dir, exist_ok = True)


def _write_split(path, indices):
    """Write the extension-less file names selected by `indices`, one per line."""
    # The handle is deliberately not called `f`, which is the name of the
    # interactive viewer function defined earlier in the notebook.
    with open(path, "w") as handle:
        for idx in indices:
            handle.write(os.path.basename(all_image_files[idx])[:-4] + "\n")


for fold, (train_idx, valid_idx) in enumerate(split):
    _write_split(f"{split_dir}/fold_{fold}.txt", train_idx)
    _write_split(f"{split_dir}/holdout_{fold}.txt", valid_idx)