# FastRCNN/train.py (revision 974c13)
import numpy as np
import chainer
from chainer.datasets import TransformDataset
from chainer import training
from chainer.training import extensions
from chainer.training.triggers import ManualScheduleTrigger
from chainercv.extensions import DetectionVOCEvaluator
from chainercv.links import FasterRCNNVGG16
from chainercv.links.model.faster_rcnn import FasterRCNNTrainChain
from chainercv import transforms
from utils import DetectionDataset
from utils import rotate_bbox , random_resize , random_distort , random_crop_with_bbox_constraints
import matplotlib.pyplot as plt
# Select the non-interactive Agg backend so plots (e.g. PlotReport's loss.png)
# can be rendered to files without a display.
plt.switch_backend(newbackend='agg')
class Transform(object):
    """Training-time data augmentation for Faster R-CNN.

    Each ``(img, bbox, label)`` example goes through these steps, with
    ``bbox`` transformed in lock-step with the image at every stage:

    1. Photometric distortion.
    2. A random rotation by a multiple of 90 degrees.
    3. An optional random expansion.
    4. A constrained random crop.
    5. The detector's own ``prepare`` resize.
    6. Random horizontal/vertical flips.
    """

    def __init__(self, faster_rcnn):
        # The detector is used only for its ``prepare`` method, which maps
        # the augmented image to the network's input size.
        self.faster_rcnn = faster_rcnn

    def __call__(self, in_data):
        """Augment a single example.

        Args:
            in_data: ``(img, bbox, label)`` tuple. ``img`` is channel-first
                (``C, H, W``); ``label`` is indexable by the crop's kept-box
                index array.

        Returns:
            ``(img, bbox, label, scale)``, where ``scale`` is the ratio of
            the prepared image height to the height just before ``prepare``.
        """
        img, bbox, label = in_data
        _, H, W = img.shape

        # Random brightness / contrast distortion.
        img = random_distort(img)

        # Rotate by k * 90 degrees; ``k`` is needed to rotate the boxes too.
        img, param = transforms.random_rotate(img, return_param=True)
        bbox = rotate_bbox(bbox, (H, W), param['k'])

        # With probability 1/2, place the image on a larger canvas (ratio
        # drawn from [1, 2]) filled with the per-channel mean, and shift the
        # boxes by the placement offset.
        if np.random.randint(2):
            fill_value = img.mean(axis=(1, 2)).reshape(-1, 1, 1)
            img, param = transforms.random_expand(
                img, max_ratio=2, fill=fill_value, return_param=True)
            bbox = transforms.translate_bbox(
                bbox, y_offset=param['y_offset'], x_offset=param['x_offset'])

        # Random crop under box constraints; boxes whose centres fall outside
        # the crop are dropped, and their labels with them.
        img, param = random_crop_with_bbox_constraints(
            img, bbox, min_scale=0.75, max_aspect_ratio=1.25,
            return_param=True)
        bbox, param = transforms.crop_bbox(
            bbox, y_slice=param['y_slice'], x_slice=param['x_slice'],
            allow_outside_center=False, return_param=True)
        label = label[param['index']]

        # If the crop removed every box, fall back to the un-augmented
        # example rather than train on an image with no targets.
        if bbox.shape[0] == 0:
            img, bbox, label = in_data

        # Size of the image *before* ``prepare``. It must be read here, after
        # rotate/expand/crop (or the fallback), since each of those changes it.
        _, t_H, t_W = img.shape
        img = self.faster_rcnn.prepare(img)
        _, o_H, o_W = img.shape
        bbox = transforms.resize_bbox(bbox, (t_H, t_W), (o_H, o_W))

        # Independent random horizontal and vertical flips.
        img, param = transforms.random_flip(
            img, x_random=True, y_random=True, return_param=True)
        bbox = transforms.flip_bbox(
            bbox, (o_H, o_W), x_flip=param['x_flip'], y_flip=param['y_flip'])

        # float() avoids integer division under Python 2.
        scale = float(o_H) / t_H
        return img, bbox, label, scale
def main():
    """Train a single-class Faster R-CNN (VGG16 backbone) on DetectionDataset.

    Training runs for ``n_itrs`` iterations on GPU 0. The learning rate is
    multiplied by 0.1 every ``n_step`` iterations. VOC2007-metric mAP on the
    test split is evaluated on a fixed schedule. Snapshots, logs and a loss
    plot are written under ``result/``.
    """
    # Must be a tuple: ``('bsite')`` is the plain string 'bsite', which
    # DetectionVOCEvaluator would iterate as five one-letter label names.
    bbox_label_names = ('bsite',)
    n_itrs = 100000  # total training iterations
    n_step = 50000   # period (in iterations) of the learning-rate decay
    np.random.seed(0)

    train_data = DetectionDataset(split='train')
    test_data = DetectionDataset(split='test')

    # NOTE(review): min_size=8 presumably keeps smaller proposals than the
    # chainercv default would -- confirm against ProposalCreator docs.
    proposal_params = {'min_size': 8}
    # One foreground class, ImageNet-initialised backbone, custom anchor
    # ratios/scales, and min_size/max_size=1024 for prepare()'s resize.
    faster_rcnn = FasterRCNNVGG16(
        n_fg_class=1, pretrained_model='imagenet',
        ratios=[0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4],
        anchor_scales=[1, 4, 8, 16], min_size=1024, max_size=1024,
        proposal_creator_params=proposal_params)
    faster_rcnn.use_preset('evaluate')
    model = FasterRCNNTrainChain(faster_rcnn)
    chainer.cuda.get_device_from_id(0).use()
    model.to_gpu()

    optimizer = chainer.optimizers.MomentumSGD(lr=5e-4, momentum=0.9)  # reduce lr from 1e-3
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(rate=0.0005))

    train_data = TransformDataset(train_data, Transform(faster_rcnn))
    # n_processes=None lets the iterator choose the worker count; shared_mem
    # is the per-example shared-memory buffer size in bytes.
    train_iter = chainer.iterators.MultiprocessIterator(
        train_data, batch_size=1, n_processes=None, shared_mem=100000000)
    test_iter = chainer.iterators.SerialIterator(
        test_data, batch_size=1, repeat=False, shuffle=False)
    updater = chainer.training.updater.StandardUpdater(
        train_iter, optimizer, device=0)
    trainer = training.Trainer(
        updater, (n_itrs, 'iteration'), out='result')

    # Snapshot the detector five times over the run. Floor division keeps
    # the trigger period an int (n_itrs / 5 is a float on Python 3).
    trainer.extend(
        extensions.snapshot_object(
            model.faster_rcnn, 'snapshot_model_{.updater.iteration}.npz'),
        trigger=(n_itrs // 5, 'iteration'))
    trainer.extend(extensions.ExponentialShift('lr', 0.1),
                   trigger=(n_step, 'iteration'))

    log_interval = (50, 'iteration')
    plot_interval = (100, 'iteration')
    print_interval = (20, 'iteration')

    trainer.extend(chainer.training.extensions.observe_lr(),
                   trigger=log_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.PrintReport(
        ['iteration', 'epoch', 'elapsed_time', 'lr',
         'main/loss',
         'main/roi_loc_loss',
         'main/roi_cls_loss',
         'main/rpn_loc_loss',
         'main/rpn_cls_loss',
         'validation/main/map',
         ]), trigger=print_interval)
    trainer.extend(extensions.ProgressBar(update_interval=5))

    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(
                ['main/loss'],
                file_name='loss.png', trigger=plot_interval
            ),
            trigger=plot_interval
        )

    # VOC2007-metric mAP on the test split at a hand-picked schedule that
    # includes the LR-decay point and the final iteration.
    trainer.extend(
        DetectionVOCEvaluator(
            test_iter, model.faster_rcnn, use_07_metric=True,
            label_names=bbox_label_names),
        trigger=ManualScheduleTrigger(
            [100, 500, 1000, 5000, 10000, 20000, 40000, 60000, n_step, n_itrs],
            'iteration'))
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.run()
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()