Diff of /utils.py [000000] .. [748954]

import random
from argparse import Namespace

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import yaml
from matplotlib.figure import Figure

def load_yaml(file: str) -> Namespace:
    """Load a YAML config file and expose its top-level keys as attributes."""
    with open(file) as f:
        config = yaml.safe_load(f)
    return Namespace(**config)

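# A minimal usage sketch (the path and the batch_size key are hypothetical,
# not taken from this repo):
#
#     config = load_yaml('config.yaml')
#     config.batch_size  # attribute access instead of config['batch_size']
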
def set_seed(seed: int):
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # Also seed every visible GPU, not just the current device.
        torch.cuda.manual_seed_all(seed)

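# A minimal usage sketch. Note this helper does not touch the cuDNN flags
# (torch.backends.cudnn.deterministic / .benchmark), so fully deterministic
# CUDA behaviour is left to the caller:
#
#     set_seed(42)
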
def get_output_shape(model: nn.Module, input_shape: tuple):
    """Infer a model's output shape by pushing a random tensor through it."""
    t = torch.rand(input_shape)
    with torch.no_grad():  # shape probe only; no autograd graph needed
        return model(t).shape

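# A minimal usage sketch (the network and shape are hypothetical; input_shape
# must include the batch dimension):
#
#     net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.MaxPool2d(2))
#     get_output_shape(net, (1, 1, 64, 64))  # -> torch.Size([1, 8, 31, 31])
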
def filename_to_width_height(filename: str) -> tuple:
    """Parse the width and height encoded in an underscore-delimited filename."""
    parts = filename.split('_')
    return int(parts[2]), int(parts[3])

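# A minimal usage sketch; the filename pattern below is an assumption
# consistent with the parsing above (width in field 2, height in field 3):
#
#     filename_to_width_height('patient_0001_512_384_ct.png')  # -> (512, 384)
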
def show_valid_image_during_training(model: nn.Module, image, input_resolution: int, padding_mode: str, device: str) -> Figure:
    """Predict segmentations for one validation image and return the plotted figure."""
    # Pad symmetrically up to the network input resolution (assumes the size
    # difference is even) and add batch and channel dimensions.
    padding = ((input_resolution - image.sw) // 2, (input_resolution - image.sh) // 2)
    image_tensor = TF.pad(image.tensor, padding=padding, padding_mode=padding_mode).unsqueeze(0).unsqueeze(0).to(device)
    model = model.to(device)
    with torch.no_grad():  # inference only; no autograd graph needed
        preds = model(image_tensor)
    segmentations = dict()
    for i, organ in enumerate(image.organs):
        # One output channel per organ; crop the padding back off.
        pred = preds[:, i, :, :]
        segmentation = TF.center_crop(pred, output_size=(image.sh, image.sw))
        segmentations[organ] = segmentation.squeeze(0).cpu().numpy()
    fig = image.show_segmented_images(segmentations=segmentations)
    return fig
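
# A minimal usage sketch (all values hypothetical; `image` must provide the
# .tensor, .sw, .sh, .organs, and .show_segmented_images members used above):
#
#     fig = show_valid_image_during_training(
#         model, valid_image, input_resolution=512,
#         padding_mode='reflect', device='cuda',
#     )
#     fig.savefig('valid_prediction.png')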