--- a
+++ b/data.py
@@ -0,0 +1,180 @@
+import os
+import math
+import random
+import torch
+import torch.nn as nn
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import torchvision.transforms.functional as TF
+from matplotlib.figure import Figure
+from pathlib import Path
+from torch.utils.data import Dataset, DataLoader
+from utils import get_output_shape
+
+class GIImage(object):
+    organs = ["stomach", "small_bowel", "large_bowel"]
+
+    def __init__(self, fpath: Path, label_df: pd.DataFrame = None):
+        # directory layout: .../<case>/<case_day>/scans/<file>.png
+        case_day = fpath.parents[1].name
+        metadata = fpath.stem.split('_')
+        slice_no = metadata[1]
+        numbers = metadata[2:]
+        # metadata: image id, matching the "id" column of the label CSV
+        self.id = f"{case_day}_slice_{slice_no}"
+        # metadata: slice width/height (pixels) and pixel width/height (spacing)
+        self.sw = int(numbers[0])
+        self.sh = int(numbers[1])
+        self.pw = float(numbers[2])
+        self.ph = float(numbers[3])
+        # data: 2D pixel array
+        self.data = plt.imread(fpath)
+        self.labels = None
+        if label_df is not None:
+            self.labels = self.get_labels(label_df)
+
+    @property
+    def tensor(self) -> torch.Tensor:
+        return torch.from_numpy(self.data)
+
+    @property
+    def label_tensors(self) -> dict:
+        if self.labels:
+            return {organ: torch.from_numpy(self.labels[organ]) for organ in self.organs}
+        return None
+
+    def get_labels(self, label_df: pd.DataFrame) -> dict:
+        labels = label_df.loc[label_df.id == self.id]
+        organ2label = dict()
+        for _, row in labels.iterrows():
+            organ2label[row["class"]] = self.seg_to_label(row["segmentation"])
+        return organ2label
+
+    # converting run-length encoding to a pixel-wise label mask
+    def seg_to_label(self, seg) -> np.ndarray:
+        label = np.zeros(shape=(self.sh, self.sw))
+        # a missing segmentation is NaN in the CSV and yields an all-zero mask
+        if isinstance(seg, str):
+            numbers = seg.split(' ')
+            assert len(numbers) % 2 == 0
+            for i in range(0, len(numbers), 2):
+                start_id = int(numbers[i])  # 1-indexed flat pixel position
+                length = int(numbers[i + 1])
+                for j in range(length):
+                    pixel = start_id + j
+                    px = (pixel - 1) // self.sw
+                    py = (pixel - 1) % self.sw
+                    label[px, py] = 1
+        return label
+
+    def label_to_seg(self, label):
+        raise NotImplementedError
+
+    def print_image_info(self) -> None:
+        print(f"Image ID: {self.id}; slice width/height = ({self.sw}, {self.sh}); data shape = {self.data.shape}")
+
+    def show_segmented_images(self, segmentations: dict[str, np.ndarray] = None) -> Figure:
+        if segmentations is None:
+            segmentations = self.labels
+
+        fig, axs = plt.subplots(ncols=len(self.organs), squeeze=False, figsize=(15, 5))
+        for i, organ in enumerate(self.organs):
+            axs[0, i].imshow(self.data, cmap="gray")
+            if segmentations:
+                axs[0, i].imshow(segmentations[organ], cmap="gray", alpha=0.4)
+            axs[0, i].set_title(organ)
+        return fig
+
+class GIImageDataset(Dataset):
+
+    def __init__(
+        self,
+        image_path: Path,
+        label_path: Path = None,
+        cases: set[str] = None
+    ):
+        if cases:
+            self._image_paths = self.get_image_files_by_cases(image_path, cases)
+        else:
+            self._image_paths = list(self.image_files_walker(image_path))
+        self._label_df = None
+        if label_path:
+            self._label_df = pd.read_csv(label_path)
+
+    def __len__(self):
+        return len(self._image_paths)
+
+    def __getitem__(self, idx: int):
+        # images are loaded lazily; parsing and label decoding happen here
+        return GIImage(fpath=self._image_paths[idx], label_df=self._label_df)
+
+    @staticmethod
+    def get_image_files_by_cases(image_path: str, cases: set[str]) -> list:
+        image_path = Path(image_path)
+        image_paths = []
+        for case in os.listdir(image_path):
+            if case in cases:
+                case_path = image_path / case
+                for day in os.listdir(case_path):
+                    day_path = case_path / day / "scans"
+                    for file in os.listdir(day_path):
+                        fpath = day_path / file
+                        image_paths.append(fpath)
+        return image_paths
+
+    @staticmethod
+    def image_files_walker(image_path: str):
+        for dirname, _, filenames in os.walk(image_path):
+            dirpath = Path(dirname)
+            for filename in filenames:
+                yield dirpath / filename
+
+def train_valid_split_cases(image_path: str, valid_size: float = 0.2) -> tuple[set, set]:
+    cases = os.listdir(image_path)
+    valid_cases = set(random.sample(cases, math.ceil(len(cases) * valid_size)))
+    train_cases = set(cases) - valid_cases
+    # the two splits must be disjoint and together cover all cases
+    assert ((valid_cases & train_cases) == set()) and ((valid_cases | train_cases) == set(cases))
+    return train_cases, valid_cases
+
+class GIImageDataLoader(object):
+
+    def __init__(
+        self,
+        model: nn.Module,
+        dataset: GIImageDataset,
+        batch_size: int,
+        shuffle: bool = True,
+        input_resolution: int = 572,
+        padding_mode: str = "reflect"
+    ):
+        self._model = model
+        self._dataset = dataset
+        self._batch_size = batch_size
+        self._shuffle = shuffle
+        self._input_resolution = input_resolution
+        # resolution the model produces for this input size (see utils.get_output_shape);
+        # label masks are center-cropped to it in collate_fn
+        self._output_resolution = get_output_shape(model, input_shape=(1, 1, input_resolution, input_resolution))[-1]
+        self._padding_mode = padding_mode
+
+    def collate_fn(self, images: list[GIImage]):
+        inputs = list()
+        labels = list()
+        for image in images:
+            # pad each slice to the model's input resolution; the right/bottom
+            # pads absorb any odd remainder
+            pad_left = (self._input_resolution - image.sw) // 2
+            pad_top = (self._input_resolution - image.sh) // 2
+            pad_right = self._input_resolution - image.sw - pad_left
+            pad_bottom = self._input_resolution - image.sh - pad_top
+            padding = [pad_left, pad_top, pad_right, pad_bottom]
+            inputs.append(TF.pad(image.tensor, padding=padding, padding_mode=self._padding_mode).unsqueeze(0))
+            label_tensors = image.label_tensors
+            organs = list()
+            for organ in GIImage.organs:
+                organ_label = TF.pad(label_tensors[organ], padding=padding, padding_mode=self._padding_mode)
+                organ_label = TF.center_crop(organ_label, output_size=self._output_resolution)
+                organs.append(organ_label)
+            labels.append(torch.stack(organs))
+        return torch.stack(inputs), torch.stack(labels)
+
+    def get_data_loader(self) -> DataLoader:
+        return DataLoader(
+            dataset=self._dataset,
+            batch_size=self._batch_size,
+            shuffle=self._shuffle,
+            collate_fn=self.collate_fn,
+            pin_memory=True
+        )
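
For reference, a minimal usage sketch of the classes added above. The data paths and the tiny convolutional stand-in model are assumptions for illustration; the real segmentation network and dataset root would be used in practice.

    from pathlib import Path
    import torch.nn as nn
    from data import GIImageDataset, GIImageDataLoader, train_valid_split_cases

    image_root = Path("input/train")      # assumed layout: <root>/<case>/<case_day>/scans/*.png
    label_csv = Path("input/train.csv")   # assumed RLE label file
    train_cases, valid_cases = train_valid_split_cases(image_root, valid_size=0.2)
    train_set = GIImageDataset(image_root, label_path=label_csv, cases=train_cases)

    # placeholder model, only used here to determine the output resolution;
    # substitute the actual segmentation model
    model = nn.Sequential(nn.Conv2d(1, 3, kernel_size=3), nn.Conv2d(3, 3, kernel_size=3))
    loader = GIImageDataLoader(model, train_set, batch_size=4).get_data_loader()
    inputs, labels = next(iter(loader))   # inputs: [B, 1, 572, 572]; labels: [B, 3, out, out]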