|
a |
|
b/utils.py |
|
|
1 |
import yaml |
|
|
2 |
from argparse import Namespace |
|
|
3 |
|
|
|
4 |
import random |
|
|
5 |
import numpy as np |
|
|
6 |
import torch |
|
|
7 |
import torch.nn as nn |
|
|
8 |
import torchvision.transforms.functional as TF |
|
|
9 |
from matplotlib.figure import Figure |
|
|
10 |
|
|
|
11 |
def load_yaml(file: str) -> Namespace: |
|
|
12 |
with open(file) as f: |
|
|
13 |
config = yaml.safe_load(f) |
|
|
14 |
return Namespace(**config) |
|
|
15 |
|
|
|
16 |
def set_seed(seed: int): |
|
|
17 |
random.seed(seed) |
|
|
18 |
np.random.seed(seed) |
|
|
19 |
torch.manual_seed(seed) |
|
|
20 |
if torch.cuda.is_available(): |
|
|
21 |
torch.cuda.manual_seed(seed) |
|
|
22 |
torch.cuda.manual_seed_all(seed) |
|
|
23 |
|
|
|
24 |
def get_output_shape(model: nn.Module, input_shape: tuple): |
|
|
25 |
t = torch.rand(input_shape) |
|
|
26 |
return model(t).shape |
|
|
27 |
|
|
|
28 |
def filename_to_width_height(filename: str) -> tuple: |
|
|
29 |
splitted = filename.split('_') |
|
|
30 |
return int(splitted[2]), int(splitted[3]) |
|
|
31 |
|
|
|
32 |
def show_valid_image_during_training(model: nn.Module, image, input_resolution: int, padding_mode: str, device: str) -> Figure: |
|
|
33 |
padding = ((input_resolution - image.sw) // 2, (input_resolution - image.sh) // 2) |
|
|
34 |
image_tensor = TF.pad(image.tensor, padding=padding, padding_mode=padding_mode).unsqueeze(0).unsqueeze(0).to(device) |
|
|
35 |
model = model.to(device) |
|
|
36 |
preds = model(image_tensor) |
|
|
37 |
segmentations = dict() |
|
|
38 |
for i, organ in enumerate(image.organs): |
|
|
39 |
pred = preds[:, i, :, :] |
|
|
40 |
segmentation = TF.center_crop(pred, output_size=(image.sh, image.sw)) |
|
|
41 |
segmentations[organ] = segmentation.squeeze(0).detach().cpu().numpy() |
|
|
42 |
fig = image.show_segmented_images(segmentations=segmentations) |
|
|
43 |
return fig |