|
a |
|
b/config.py |
|
|
1 |
""" Config class for training """ |
|
|
2 |
import argparse |
|
|
3 |
import os |
|
|
4 |
from functools import partial |
|
|
5 |
import torch |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
def get_parser(name):
    """Create an ``ArgumentParser`` called *name* meant to show defaults in ``--help``.

    The parser uses ``ArgumentDefaultsHelpFormatter``, and its ``add_argument``
    is wrapped so each argument gets a placeholder ``help=' '`` unless the
    caller supplies its own help text. The intent is that every argument has
    help text for the formatter to append its default value to.

    NOTE(review): some newer Python versions may skip whitespace-only help
    text when formatting; confirm defaults still appear in ``--help`` output.
    """
    new_parser = argparse.ArgumentParser(
        name,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    base_add = new_parser.add_argument
    # A keyword bound by partial is only a default: an explicit help=... passed
    # at call time replaces the placeholder.
    new_parser.add_argument = partial(base_add, help=' ')
    return new_parser
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def parse_gpus(gpus):
    """Convert a ``--gpus`` option string into a list of integer device ids.

    Args:
        gpus: either ``'all'`` (every CUDA device reported by
            ``torch.cuda.device_count()``) or a comma-separated list of ids
            such as ``'0,1'``. Whitespace around ids and empty segments
            (e.g. from a trailing comma) are ignored.

    Returns:
        list[int]: the selected device ids. For ``'all'`` this may be empty
        when no CUDA device is visible.

    Raises:
        ValueError: if a segment is not an integer, or if an explicit id list
            contains no ids at all.
    """
    spec = gpus.strip()
    if spec == 'all':
        return list(range(torch.cuda.device_count()))
    # Skip empty pieces so inputs like "0,1," do not fail on int('').
    ids = [int(piece) for piece in spec.split(',') if piece.strip()]
    if not ids:
        raise ValueError("no GPU ids given in {!r}".format(gpus))
    return ids
|
|
21 |
|
|
|
22 |
|
|
|
23 |
class BaseConfig(argparse.Namespace):
    """``argparse.Namespace`` with helpers to print or export its attributes."""

    def print_params(self, prtf=print):
        """Emit every attribute as ``NAME=value``, ordered by attribute name.

        The listing starts with an empty line and a ``Parameters:`` header
        and ends with another empty line. Names are upper-cased.

        Args:
            prtf: callable accepting a single string; defaults to ``print``.
        """
        prtf("")
        prtf("Parameters:")
        entries = sorted(vars(self).items())
        for key, val in entries:
            prtf("{}={}".format(key.upper(), val))
        prtf("")

    def as_markdown(self):
        """Return the attributes as a two-column markdown table sorted by name."""
        header = "|name|value| \n|-|-| \n"
        rows = ["|{}|{}| \n".format(key, val) for key, val in sorted(vars(self).items())]
        return header + "".join(rows)
|
|
38 |
|
|
|
39 |
|
|
|
40 |
class TrainConfig(BaseConfig):
    """Training options, parsed from the command line when constructed.

    After parsing, two attributes are derived:
      * ``path``: ``os.path.join('train', name)``.
      * ``gpus``: the ``--gpus`` string converted by ``parse_gpus``.
    """

    def build_parser(self):
        """Return the ``ArgumentParser`` that declares every training option."""
        parser = get_parser("Train config")
        # NOTE(review): machine-specific absolute default; consider passing
        # --name explicitly on other machines.
        parser.add_argument('--name', default="/home/ubuntu/zhaoqianfei/UNet_Mini/LACB_Net_A/log")
        parser.add_argument('--batch_size', type=int, default=1, help='batch size')
        parser.add_argument('--input_channels', type=int, default=1, help='input channels')
        parser.add_argument('--n_classes', type=int, default=1, help='number classes')
        parser.add_argument('--lr', type=float, default=0.0025, help='lr for weights')
        parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
        parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
        parser.add_argument('--grad_clip', type=float, default=5.,
                            help='gradient clipping for weights')
        parser.add_argument('--print_freq', type=int, default=20, help='print frequency')
        parser.add_argument('--gpus', default='0', help='gpu device ids separated by comma. '
                            '`all` indicates use all gpus.')
        parser.add_argument('--epochs', type=int, default=400, help='# of training epochs')
        parser.add_argument('--init_channels', type=int, default=12)
        parser.add_argument('--seed', type=int, default=2, help='random seed')
        parser.add_argument('--workers', type=int, default=4, help='# of workers')
        # NOTE(review): the ".../" prefixes below look like placeholders to be
        # replaced with real paths; confirm before relying on these defaults.
        parser.add_argument('--training_summary_dir', default=".../model/unet")
        parser.add_argument('--training_checkpoint_prefix', default=".../model/unet")
        parser.add_argument('--testing_checkpoint_name', default=".../model/unet_400.pt")
        parser.add_argument('--testing_output_dir', default=".../result")
        # NOTE(review): "lalel" may be a typo for "label"; verify against the data layout.
        parser.add_argument('--root_dir', default=".../data_train_test/lalel")
        parser.add_argument('--validing_checkpoint_prefix', default=".../model")

        return parser

    def __init__(self, args=None):
        """Parse the training options and populate this namespace.

        Args:
            args: list of argument strings to parse. ``None`` (the default)
                makes argparse read ``sys.argv[1:]``, the previous behavior.
                Passing a list allows programmatic or test construction.
        """
        parser = self.build_parser()
        parsed = parser.parse_args(args)
        super().__init__(**vars(parsed))

        # os.path.join discards earlier segments when a later one is absolute.
        # With an absolute --name (as the default is), path == name and the
        # 'train' prefix is dropped.
        self.path = os.path.join('train', self.name)
        self.gpus = parse_gpus(self.gpus)