a b/select_model.py
1
from Segmentation.model.unet import UNet, R2_UNet, Nested_UNet, Nested_UNet_v2
2
from Segmentation.model.segnet import SegNet
3
from Segmentation.model.deeplabv3 import Deeplabv3, Deeplabv3_plus
4
from Segmentation.model.vnet import VNet
5
from Segmentation.model.Hundred_Layer_Tiramisu import Hundred_Layer_Tiramisu
6
7
from absl import logging
8
9
def select_model(FLAGS, num_classes):
    """Pick the model class and its positional constructor arguments.

    Dispatches on ``FLAGS.model_architecture`` and binds two locals:
    ``model_fn`` (the model class to instantiate) and ``model_args`` (a list
    of positional arguments for it, taken from ``FLAGS`` plus
    ``num_classes``).

    NOTE(review): how ``model_fn`` / ``model_args`` are handed back to the
    caller is outside the reviewed span; presumably
    ``return model_fn, model_args`` follows the dispatch below -- confirm.

    Args:
        FLAGS: flag object. ``model_architecture`` selects the model; which
            other attributes are read depends on that choice.
        num_classes: number of output classes, forwarded to the chosen model.

    Raises:
        ValueError: if ``FLAGS.model_architecture`` is not one of 'unet',
            'vnet', 'r2unet', 'segnet', 'unet++', '100-Layer-Tiramisu',
            'deeplabv3' or 'deeplabv3_plus'.
    """
    # Every argument list below is positional, so its order must match the
    # corresponding constructor signature in Segmentation.model.* exactly.
    architecture = FLAGS.model_architecture

    if architecture == 'unet':
        model_args = [FLAGS.num_filters,
                      num_classes,
                      FLAGS.use_2d,
                      FLAGS.backbone_architecture,
                      FLAGS.num_conv,
                      FLAGS.kernel_size,
                      FLAGS.activation,
                      FLAGS.use_attention,
                      FLAGS.use_batchnorm,
                      FLAGS.use_bias,
                      FLAGS.use_dropout,
                      FLAGS.dropout_rate,
                      FLAGS.use_spatial,
                      FLAGS.channel_order]
        model_fn = UNet

    elif architecture == 'vnet':
        model_args = [FLAGS.num_filters,
                      num_classes,
                      FLAGS.use_2d,
                      FLAGS.num_conv,
                      FLAGS.kernel_size,
                      FLAGS.activation,
                      FLAGS.use_batchnorm,
                      FLAGS.dropout_rate,
                      FLAGS.use_spatial,
                      FLAGS.channel_order]
        model_fn = VNet

    elif architecture == 'r2unet':
        model_args = [FLAGS.num_filters,
                      num_classes,
                      FLAGS.use_2d,
                      FLAGS.num_conv,
                      FLAGS.kernel_size,
                      FLAGS.activation,
                      2,  # hard-coded; presumably the recurrence count -- confirm vs R2_UNet
                      FLAGS.use_attention,
                      FLAGS.use_batchnorm,
                      FLAGS.use_bias,
                      FLAGS.channel_order]
        model_fn = R2_UNet

    elif architecture == 'segnet':
        model_args = [FLAGS.num_filters,
                      num_classes,
                      FLAGS.backbone_architecture,
                      FLAGS.kernel_size,
                      (2, 2),  # hard-coded; presumably the pool size -- confirm vs SegNet
                      FLAGS.activation,
                      FLAGS.use_batchnorm,
                      FLAGS.use_bias,
                      FLAGS.use_transpose,
                      FLAGS.use_dropout,
                      FLAGS.dropout_rate,
                      FLAGS.use_spatial,
                      FLAGS.channel_order]
        model_fn = SegNet

    elif architecture == 'unet++':
        model_args = [FLAGS.num_filters,
                      num_classes,
                      FLAGS.num_conv,
                      FLAGS.kernel_size,
                      FLAGS.activation,
                      FLAGS.use_batchnorm,
                      FLAGS.use_bias,
                      FLAGS.channel_order]
        model_fn = Nested_UNet

    elif architecture == '100-Layer-Tiramisu':
        # num_classes is the 4th positional argument here, not the 2nd.
        model_args = [FLAGS.growth_rate,
                      FLAGS.layers_per_block,
                      FLAGS.init_num_channels,
                      num_classes,
                      FLAGS.kernel_size,
                      FLAGS.pool_size,
                      FLAGS.activation,
                      FLAGS.dropout_rate,
                      FLAGS.strides,
                      FLAGS.padding]
        model_fn = Hundred_Layer_Tiramisu

    elif architecture == 'deeplabv3':
        model_args = [num_classes,
                      FLAGS.kernel_size_initial_conv,
                      FLAGS.num_filters_atrous,
                      FLAGS.num_filters_DCNN,
                      FLAGS.num_filters_ASPP,
                      FLAGS.kernel_size_atrous,
                      FLAGS.kernel_size_DCNN,
                      FLAGS.kernel_size_ASPP,
                      'same',  # hard-coded; presumably padding -- confirm vs Deeplabv3
                      FLAGS.activation,
                      FLAGS.use_batchnorm,
                      FLAGS.use_bias,
                      FLAGS.channel_order,
                      FLAGS.MultiGrid,
                      FLAGS.rate_ASPP,
                      FLAGS.output_stride]
        model_fn = Deeplabv3

    elif architecture == 'deeplabv3_plus':
        model_args = [num_classes,
                      FLAGS.kernel_size_initial_conv,
                      FLAGS.num_filters_atrous,
                      FLAGS.num_filters_DCNN,
                      FLAGS.num_filters_ASPP,
                      FLAGS.kernel_size_atrous,
                      FLAGS.kernel_size_DCNN,
                      FLAGS.kernel_size_ASPP,
                      FLAGS.num_filters_final_encoder,
                      FLAGS.num_filters_from_backbone,
                      FLAGS.num_channels_UpConv,
                      FLAGS.kernel_size_UpConv,
                      (2, 2),  # hard-coded; meaning defined by Deeplabv3_plus -- confirm
                      False,   # hard-coded flag; meaning defined by Deeplabv3_plus -- confirm
                      FLAGS.use_transpose,
                      'same',  # hard-coded; presumably padding -- confirm vs Deeplabv3_plus
                      FLAGS.activation,
                      FLAGS.use_batchnorm,
                      FLAGS.use_bias,
                      FLAGS.channel_order,
                      FLAGS.MultiGrid,
                      FLAGS.rate_ASPP,
                      FLAGS.output_stride]
        model_fn = Deeplabv3_plus

    else:
        # Fail fast: merely logging would leave model_fn / model_args unbound,
        # turning a bad flag value into an opaque UnboundLocalError later on.
        logging.error('The model architecture %s is not supported!', architecture)
        raise ValueError(
            'The model architecture {} is not supported!'.format(architecture))