a b/config.py
1
""" Config class for training """
2
import argparse
3
import os
4
from functools import partial
5
import torch
6
7
8
def get_parser(name):
    """ Build an ArgumentParser whose help output always lists defaults.

    ``ArgumentDefaultsHelpFormatter`` only appends "(default: ...)" to options
    that carry a non-empty help string, so every ``add_argument`` call on the
    returned parser falls back to ``help=' '`` unless the caller passes its own.

    Args:
        name: program name shown in usage/help text (``prog``).

    Returns:
        argparse.ArgumentParser with a wrapped ``add_argument``.
    """
    parser = argparse.ArgumentParser(name, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    wrapped_add = parser.add_argument

    def add_argument(*names, **options):
        # an explicit help= from the caller wins; otherwise use a blank one
        options.setdefault('help', ' ')
        return wrapped_add(*names, **options)

    parser.add_argument = add_argument
    return parser
14
15
16
def parse_gpus(gpus):
    """ Convert a ``--gpus`` option string into a list of integer device ids.

    Args:
        gpus: either the literal ``'all'`` or a comma-separated list of device
            indices such as ``'0,1,3'``. Whitespace around entries and empty
            entries (e.g. from a trailing comma, or an empty string) are ignored.

    Returns:
        list[int]: ``list(range(torch.cuda.device_count()))`` for ``'all'``,
        otherwise the parsed indices in the order given.

    Raises:
        ValueError: if a non-empty entry is not an integer.
    """
    if gpus.strip() == 'all':
        return list(range(torch.cuda.device_count()))
    # skip blank tokens so '0,1,' or '' do not fail on int('')
    return [int(token) for token in gpus.split(',') if token.strip()]
21
22
23
class BaseConfig(argparse.Namespace):
    """ Namespace of configuration values with helpers to log them. """

    def print_params(self, prtf=print):
        """ Emit every attribute as ``NAME=value``, sorted by attribute name.

        The listing is framed by an empty line, a "Parameters:" header and a
        trailing empty line.

        Args:
            prtf: callable that receives one string per output line; defaults
                to ``print`` but any single-argument callable works.
        """
        prtf("")
        prtf("Parameters:")
        for attr, value in sorted(vars(self).items()):
            prtf("{}={}".format(attr.upper(), value))
        prtf("")

    def as_markdown(self):
        """ Return configs as a two-column (name, value) Markdown table.

        Rows are sorted by attribute name and each line ends with two spaces
        (a Markdown hard line break). Pipe characters inside names or values
        are escaped with a backslash so they do not split the table cell.
        """
        def cell(item):
            return "{}".format(item).replace("|", "\\|")

        lines = ["|name|value|  \n", "|-|-|  \n"]
        lines.extend(
            "|{}|{}|  \n".format(cell(attr), cell(value))
            for attr, value in sorted(vars(self).items())
        )
        # join once instead of repeated string concatenation
        return "".join(lines)
38
39
40
class TrainConfig(BaseConfig):
    """ Training configuration populated from command-line arguments.

    Every option defined in ``build_parser`` becomes an attribute of the
    instance; ``path`` and ``gpus`` are post-processed in ``__init__``.
    """

    def build_parser(self):
        """ Return the ArgumentParser describing all training options. """
        parser = get_parser("Train config")
        # NOTE(review): absolute default; os.path.join in __init__ will then
        # ignore the 'train' prefix (see __init__ docstring).
        parser.add_argument('--name', default="/home/ubuntu/zhaoqianfei/UNet_Mini/LACB_Net_A/log")
        parser.add_argument('--batch_size', type=int, default=1, help='batch size')
        parser.add_argument('--input_channels', type=int, default=1, help='input channels')
        parser.add_argument('--n_classes', type=int, default=1, help='number classes')
        parser.add_argument('--lr', type=float, default=0.0025, help='lr for weights')
        parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
        parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
        parser.add_argument('--grad_clip', type=float, default=5.,
                            help='gradient clipping for weights')
        parser.add_argument('--print_freq', type=int, default=20, help='print frequency')
        parser.add_argument('--gpus', default='0', help='gpu device ids separated by comma. '
                                                        '`all` indicates use all gpus.')
        parser.add_argument('--epochs', type=int, default=400, help='# of training epochs')
        parser.add_argument('--init_channels', type=int, default=12)
        parser.add_argument('--seed', type=int, default=2, help='random seed')
        parser.add_argument('--workers', type=int, default=4, help='# of workers')
        # NOTE(review): the '.../' defaults below look like placeholders; as
        # written, '...' is a literal relative directory name - confirm.
        parser.add_argument('--training_summary_dir', default=".../model/unet")
        parser.add_argument('--training_checkpoint_prefix', default=".../model/unet")
        parser.add_argument('--testing_checkpoint_name', default=".../model/unet_400.pt")
        parser.add_argument('--testing_output_dir', default=".../result")
        # NOTE(review): 'lalel' may be a typo for 'label' - confirm against the data layout.
        parser.add_argument('--root_dir', default=".../data_train_test/lalel")
        parser.add_argument('--validing_checkpoint_prefix', default=".../model")

        return parser

    def __init__(self, argv=None):
        """ Parse options and store each one as an attribute.

        Args:
            argv: list of argument strings to parse. ``None`` (the default)
                parses ``sys.argv[1:]``, matching the previous behaviour;
                pass an explicit list to build a config from code or tests.

        Derived attributes:
            path: ``os.path.join('train', self.name)``. os.path.join discards
                earlier components when ``name`` is absolute, which is the
                case for the default ``--name``, so ``path`` then equals
                ``name``.
            gpus: the raw ``--gpus`` string converted to a list of ints by
                ``parse_gpus``.
        """
        parser = self.build_parser()
        args = parser.parse_args(argv)
        super().__init__(**vars(args))

        self.path = os.path.join('train', self.name)
        self.gpus = parse_gpus(self.gpus)