Diff of /utils.py [000000] .. [70e190]

Switch to unified view

a b/utils.py
1
import argparse
2
import numpy as np
3
import nibabel as nib
4
from skimage.transform import resize as skires
5
import csv
6
import yaml
7
import numpy as np
8
9
def str2bool(v):
10
    if v.lower() in ['true', 1]:
11
        return True
12
    elif v.lower() in ['false', 0]:
13
        return False
14
    else:
15
        raise argparse.ArgumentTypeError('Boolean value expected.')
16
    
17
class DotDict:
18
    def __init__(self, dictionary):
19
        self._dict = dictionary
20
21
    def __getattr__(self, attr):
22
        value = self._dict[attr]
23
        if isinstance(value, dict):
24
            return DotDict(value)
25
        return value
26
    
27
def parse_args():
28
    parser = argparse.ArgumentParser()
29
    parser.add_argument('--network', default=None, help='architecture name')
30
    parser.add_argument('--name', default=None, help='model name')
31
    parser = parser.parse_args()
32
    with open(f'outputs/{parser.network}/{parser.name}/config.yml', 'r') as f:
33
        config = yaml.load(f, Loader=yaml.FullLoader)
34
    config = DotDict(config)
35
    return config
36
        
37
def save_vol(vol, type, path):
38
    vol = np.transpose(vol, (2, 1, 0))
39
    affine = np.eye(4)
40
    nifti_file = nib.Nifti1Image(vol.astype(np.int8), affine) if type=='labels' else nib.Nifti1Image(vol, affine)
41
    nib.save(nifti_file, path)    
42
    
43
def resize_vol(vol, new_size):
44
    return skires(vol, new_size, order=1, preserve_range=True, anti_aliasing=False)
45
46
47
def write_csv(path, data):
48
    with open(path, mode='a', newline='') as file:
49
        iteration = csv.writer(file)
50
        iteration.writerow(data)
51
    file.close()
52
    
53