--- a +++ b/utils.py @@ -0,0 +1,43 @@ +import yaml +from argparse import Namespace + +import random +import numpy as np +import torch +import torch.nn as nn +import torchvision.transforms.functional as TF +from matplotlib.figure import Figure + +def load_yaml(file: str) -> Namespace: + with open(file) as f: + config = yaml.safe_load(f) + return Namespace(**config) + +def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def get_output_shape(model: nn.Module, input_shape: tuple): + t = torch.rand(input_shape) + return model(t).shape + +def filename_to_width_height(filename: str) -> tuple: + splitted = filename.split('_') + return int(splitted[2]), int(splitted[3]) + +def show_valid_image_during_training(model: nn.Module, image, input_resolution: int, padding_mode: str, device: str) -> Figure: + padding = ((input_resolution - image.sw) // 2, (input_resolution - image.sh) // 2) + image_tensor = TF.pad(image.tensor, padding=padding, padding_mode=padding_mode).unsqueeze(0).unsqueeze(0).to(device) + model = model.to(device) + preds = model(image_tensor) + segmentations = dict() + for i, organ in enumerate(image.organs): + pred = preds[:, i, :, :] + segmentation = TF.center_crop(pred, output_size=(image.sh, image.sw)) + segmentations[organ] = segmentation.squeeze(0).detach().cpu().numpy() + fig = image.show_segmented_images(segmentations=segmentations) + return fig