|
a |
|
b/Projects/NCS1/Classes/inception_utils.py |
|
|
1 |
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
|
|
2 |
# |
|
|
3 |
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
4 |
# you may not use this file except in compliance with the License. |
|
|
5 |
# You may obtain a copy of the License at |
|
|
6 |
# |
|
|
7 |
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
8 |
# |
|
|
9 |
# Unless required by applicable law or agreed to in writing, software |
|
|
10 |
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
12 |
# See the License for the specific language governing permissions and |
|
|
13 |
# limitations under the License. |
|
|
14 |
# ============================================================================== |
|
|
15 |
"""Contains common code shared by all inception models. |
|
|
16 |
|
|
|
17 |
Usage of arg scope: |
|
|
18 |
with slim.arg_scope(inception_arg_scope()): |
|
|
19 |
logits, end_points = inception.inception_v3(images, num_classes, |
|
|
20 |
is_training=is_training) |
|
|
21 |
|
|
|
22 |
""" |
|
|
23 |
from __future__ import absolute_import |
|
|
24 |
from __future__ import division |
|
|
25 |
from __future__ import print_function |
|
|
26 |
|
|
|
27 |
import tensorflow as tf |
|
|
28 |
|
|
|
29 |
# Short alias for TF-Slim, used below for arg_scope, conv2d, batch_norm, etc.
# NOTE(review): `tf.contrib` is only available in TensorFlow 1.x; this module
# presumably targets a TF 1.x install — confirm before upgrading TensorFlow.
slim = tf.contrib.slim
|
|
30 |
|
|
|
31 |
|
|
|
32 |
def inception_arg_scope(weight_decay=0.00004,
                        use_batch_norm=True,
                        batch_norm_decay=0.9997,
                        batch_norm_epsilon=0.001,
                        activation_fn=tf.nn.relu,
                        batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS,
                        batch_norm_scale=None):
  """Defines the default arg scope for inception models.

  Args:
    weight_decay: The weight decay to use for regularizing the model. It is
      applied as an L2 regularizer on the weights of both `slim.conv2d` and
      `slim.fully_connected` layers.
    use_batch_norm: If `True`, batch_norm is applied after each convolution.
    batch_norm_decay: Decay for batch norm moving average.
    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
      in batch norm.
    activation_fn: Activation function for conv2d.
    batch_norm_updates_collections: Collection that batch norm's moving-average
      update ops are added to. Defaults to `tf.GraphKeys.UPDATE_OPS`, the
      value this function always used before the parameter existed.
    batch_norm_scale: Optional bool forwarded to batch norm as `scale`. When
      left as `None`, no `scale` is passed, so batch norm's own default (or
      any enclosing arg_scope) decides, exactly as before this parameter
      existed.

  Returns:
    An `arg_scope` to use for the inception models. Inside it, conv2d layers
    use a variance-scaling initializer, `activation_fn`, and (optionally)
    batch norm; conv2d and fully_connected layers get L2 weight decay.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      # collection containing update_ops.
      'updates_collections': batch_norm_updates_collections,
      # use fused batch norm if possible.
      'fused': None,
  }
  if batch_norm_scale is not None:
    # Only forwarded when explicitly requested: an explicit kwarg here would
    # override any `scale` set by an enclosing arg_scope on batch_norm.
    batch_norm_params['scale'] = batch_norm_scale
  if use_batch_norm:
    normalizer_fn = slim.batch_norm
    normalizer_params = batch_norm_params
  else:
    normalizer_fn = None
    normalizer_params = {}
  # Set weight_decay for weights in Conv and FC layers.
  with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      weights_regularizer=slim.l2_regularizer(weight_decay)):
    # Conv-only defaults, nested so they combine with the regularizer above.
    with slim.arg_scope(
        [slim.conv2d],
        weights_initializer=slim.variance_scaling_initializer(),
        activation_fn=activation_fn,
        normalizer_fn=normalizer_fn,
        normalizer_params=normalizer_params) as sc:
      return sc