@@ -28,29 +28,35 @@ def parse_args():
                         help="Base network name which serves as feature extraction base.")
     parser.add_argument('--dataset', type=str, default='voc',
                         help='Training dataset. Now support voc.')
+    parser.add_argument('--short', type=str, default='',
+                        help='Resize image to the given short side, default to 600 for voc.')
+    parser.add_argument('--max-size', type=str, default='',
+                        help='Max size of either side of image, default to 1000 for voc.')
     parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
                         default=4, help='Number of data workers, you can use larger '
                         'number to accelerate data loading, if your CPU and GPUs are powerful.')
     parser.add_argument('--gpus', type=str, default='0',
                         help='Training with GPUs, you can specify 1,3 for example.')
-    parser.add_argument('--epochs', type=int, default=30,
+    parser.add_argument('--epochs', type=str, default='',
                         help='Training epochs.')
     parser.add_argument('--resume', type=str, default='',
                         help='Resume from previously saved parameters if not None. '
                         'For example, you can resume from ./faster_rcnn_xxx_0123.params')
     parser.add_argument('--start-epoch', type=int, default=0,
                         help='Starting epoch for resuming, default is 0 for new training. '
                         'You can specify it to 100 for example to start from epoch 100.')
-    parser.add_argument('--lr', type=float, default=0.001,
-                        help='Learning rate, default is 0.001')
+    parser.add_argument('--lr', type=str, default='',
+                        help='Learning rate, default is 0.001 for voc single-GPU training.')
     parser.add_argument('--lr-decay', type=float, default=0.1,
                         help='Decay rate of learning rate, default is 0.1.')
-    parser.add_argument('--lr-decay-epoch', type=str, default='14,20',
-                        help='Epochs at which learning rate decays, default is 14,20.')
+    parser.add_argument('--lr-decay-epoch', type=str, default='',
+                        help='Epochs at which learning rate decays, default is 14,20 for voc.')
+    parser.add_argument('--lr-warmup', type=str, default='',
+                        help='Warmup iterations to adjust learning rate, default is 0 (no warmup) for voc.')
     parser.add_argument('--momentum', type=float, default=0.9,
                         help='SGD momentum, default is 0.9')
-    parser.add_argument('--wd', type=float, default=0.0005,
-                        help='Weight decay, default is 5e-4')
+    parser.add_argument('--wd', type=str, default='',
+                        help='Weight decay, default is 5e-4 for voc.')
     parser.add_argument('--log-interval', type=int, default=100,
                         help='Logging mini-batch interval. Default is 100.')
     parser.add_argument('--save-prefix', type=str, default='',
@@ -65,6 +71,28 @@ def parse_args():
     parser.add_argument('--verbose', dest='verbose', action='store_true',
                         help='Print helpful debugging info once set.')
     args = parser.parse_args()
+    if args.dataset == 'voc':
+        args.short = int(args.short) if args.short else 600
+        args.max_size = int(args.max_size) if args.max_size else 1000
+        args.epochs = int(args.epochs) if args.epochs else 20
+        args.lr_decay_epoch = args.lr_decay_epoch if args.lr_decay_epoch else '14,20'
+        args.lr = float(args.lr) if args.lr else 0.001
+        args.lr_warmup = int(args.lr_warmup) if args.lr_warmup else -1
+        args.wd = float(args.wd) if args.wd else 5e-4
+    elif args.dataset == 'coco':
+        args.short = int(args.short) if args.short else 800
+        args.max_size = int(args.max_size) if args.max_size else 1333
+        args.epochs = int(args.epochs) if args.epochs else 24
+        args.lr_decay_epoch = args.lr_decay_epoch if args.lr_decay_epoch else '16,21'
+        args.lr = float(args.lr) if args.lr else 0.00125
+        args.lr_warmup = int(args.lr_warmup) if args.lr_warmup else 8000
+        args.wd = float(args.wd) if args.wd else 1e-4
+    num_gpus = len(args.gpus.split(','))
+    if num_gpus == 1:
+        args.lr_warmup = -1
+    else:
+        args.lr *= num_gpus
+        args.lr_warmup /= num_gpus
     return args

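The block above resolves every ''-default to a dataset-specific value once the dataset is known, then applies the usual linear scaling rule for multi-GPU training: multiply the base learning rate by the number of GPUs and divide the warmup length by the same factor. A minimal standalone sketch of the same pattern (DEFAULTS and resolve_hyperparams are hypothetical names for illustration, not part of this file):

# Sketch only: illustrates the "empty string means dataset-specific default"
# pattern used in parse_args() above, plus linear lr scaling across GPUs.
DEFAULTS = {
    'voc':  {'short': 600, 'max_size': 1000, 'epochs': 20, 'lr': 0.001,   'lr_warmup': -1},
    'coco': {'short': 800, 'max_size': 1333, 'epochs': 24, 'lr': 0.00125, 'lr_warmup': 8000},
}

def resolve_hyperparams(dataset, gpus, lr=''):
    cfg = dict(DEFAULTS[dataset])
    if lr:                    # a user-supplied value always wins over the default
        cfg['lr'] = float(lr)
    num_gpus = len(gpus.split(','))
    if num_gpus > 1:          # linear scaling: bigger effective batch, bigger lr
        cfg['lr'] *= num_gpus
        cfg['lr_warmup'] /= num_gpus
    return cfg

print(resolve_hyperparams('coco', '0,1,2,3'))  # lr scaled 4x, warmup shortened 4x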
@@ -163,10 +191,8 @@ def get_dataset(dataset, args):
         raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
     return train_dataset, val_dataset, val_metric

-def get_dataloader(net, train_dataset, val_dataset, batch_size, num_workers):
+def get_dataloader(net, train_dataset, val_dataset, short, max_size, batch_size, num_workers):
     """Get dataloader."""
-    short, max_size = 600, 1000
-
     train_bfn = batchify.Tuple(*[batchify.Append() for _ in range(5)])
     train_loader = mx.gluon.data.DataLoader(
         train_dataset.transform(FasterRCNNDefaultTrainTransform(short, max_size, net)),
@@ -177,15 +203,19 @@ def get_dataloader(net, train_dataset, val_dataset, batch_size, num_workers):
         batch_size, False, batchify_fn=val_bfn, last_batch='keep', num_workers=num_workers)
     return train_loader, val_loader

-def save_params(net, best_map, current_map, epoch, save_interval, prefix):
+def save_params(net, logger, best_map, current_map, epoch, save_interval, prefix):
     current_map = float(current_map)
     if current_map > best_map[0]:
+        logger.info('[Epoch {}] mAP {} higher than current best {}, saving to {}'.format(
+            epoch, current_map, best_map, '{:s}_best.params'.format(prefix)))
         best_map[0] = current_map
-        net.save_params('{:s}_best.params'.format(prefix, epoch, current_map))
+        net.save_parameters('{:s}_best.params'.format(prefix))
         with open(prefix + '_best_map.log', 'a') as f:
             f.write('\n{:04d}:\t{:.4f}'.format(epoch, current_map))
-    if save_interval and epoch % save_interval == 0:
-        net.save_params('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))
+    if save_interval and (epoch + 1) % save_interval == 0:
+        logger.info('[Epoch {}] Saving parameters to {}'.format(
+            epoch, '{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map)))
+        net.save_parameters('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))

 def split_and_load(batch, ctx_list):
     """Split data to 1 batch each device."""
@@ -201,7 +231,7 @@ def validate(net, val_data, ctx, eval_metric):
     eval_metric.reset()
     # set nms threshold and topk constraint
     net.set_nms(nms_thresh=0.3, nms_topk=400)
-    net.hybridize()
+    net.hybridize(static_alloc=True)
     for batch in val_data:
         batch = split_and_load(batch, ctx_list=ctx)
         det_bboxes = []
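hybridize(static_alloc=True) asks MXNet to plan and reuse memory for the cached graph instead of allocating on every forward pass, which typically speeds up repeated passes at some memory cost; the training loop below gets the same flag. A self-contained sketch, assuming an MXNet version that supports static_alloc:

import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Dense(10))
net.initialize()
net.hybridize(static_alloc=True)  # build the cached graph with static memory planning
out = net(mx.nd.ones((1, 4)))     # first call triggers graph construction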
@@ -231,6 +261,9 @@ def validate(net, val_data, ctx, eval_metric):
         eval_metric.update(det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff)
     return eval_metric.get()

+def get_lr_at_iter(alpha):
+    return 1. / 3. * (1 - alpha) + alpha
+
 def train(net, train_data, val_data, eval_metric, args):
     """Training pipeline"""
     net.collect_params().reset_ctx(ctx)
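The new get_lr_at_iter helper is a linear warmup factor: it returns 1/3 at alpha = 0 (so training starts at a third of the base learning rate) and ramps linearly to 1 at alpha = 1, the warmup shape commonly used in large-minibatch Faster R-CNN training. For example:

# Warmup factor ramps linearly from 1/3 at the start to 1.0 at the end of warmup.
def get_lr_at_iter(alpha):
    return 1. / 3. * (1 - alpha) + alpha

for alpha in (0.0, 0.5, 1.0):
    print(alpha, get_lr_at_iter(alpha))  # 0.333..., 0.666..., 1.0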
@@ -245,6 +278,7 @@ def train(net, train_data, val_data, eval_metric, args):
     # lr decay policy
     lr_decay = float(args.lr_decay)
     lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
+    lr_warmup = int(args.lr_warmup)

     # TODO(zhreshold) losses?
     rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
@@ -288,8 +322,14 @@ def train(net, train_data, val_data, eval_metric, args):
         metric.reset()
         tic = time.time()
         btic = time.time()
-        net.hybridize()
+        net.hybridize(static_alloc=True)
+        base_lr = trainer.learning_rate
         for i, batch in enumerate(train_data):
+            if epoch == 0 and i <= lr_warmup:
+                new_lr = base_lr * get_lr_at_iter((i // 500) / (lr_warmup / 500.))
+                if new_lr != trainer.learning_rate:
+                    logger.info('[Epoch 0 Iteration {}] Set learning rate to {}'.format(i, new_lr))
+                    trainer.set_learning_rate(new_lr)
             batch = split_and_load(batch, ctx_list=ctx)
             batch_size = len(batch[0])
             losses = []
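Note the integer division by 500: during the first epoch the learning rate rises in steps of 500 iterations rather than continuously, and set_learning_rate is only called when the value actually changes. A small demo of the resulting schedule (base_lr and lr_warmup values chosen purely for illustration):

def get_lr_at_iter(alpha):
    return 1. / 3. * (1 - alpha) + alpha

base_lr, lr_warmup = 0.005, 2000
seen = set()
for i in range(lr_warmup + 1):
    new_lr = base_lr * get_lr_at_iter((i // 500) / (lr_warmup / 500.))
    if new_lr not in seen:  # lr only changes at i = 0, 500, 1000, 1500, 2000
        seen.add(new_lr)
        print(i, round(new_lr, 6))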
@@ -350,7 +390,7 @@ def train(net, train_data, val_data, eval_metric, args):
             current_map = float(mean_ap[-1])
         else:
             current_map = 0.
-        save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)
+        save_params(net, logger, best_map, current_map, epoch, args.save_interval, args.save_prefix)

 if __name__ == '__main__':
     args = parse_args()
@@ -367,7 +407,7 @@ def train(net, train_data, val_data, eval_metric, args):
     args.save_prefix += net_name
     net = get_model(net_name, pretrained_base=True)
     if args.resume.strip():
-        net.load_params(args.resume.strip())
+        net.load_parameters(args.resume.strip())
     else:
         for param in net.collect_params().values():
             if param._data is not None:
@@ -377,7 +417,7 @@ def train(net, train_data, val_data, eval_metric, args):
     # training data
     train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)
     train_data, val_data = get_dataloader(
-        net, train_dataset, val_dataset, args.batch_size, args.num_workers)
+        net, train_dataset, val_dataset, args.short, args.max_size, args.batch_size, args.num_workers)

     # training
     train(net, train_data, val_data, eval_metric, args)
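With these changes one script serves both datasets; a usage sketch (the script filename here is an assumption):

python train_faster_rcnn.py --dataset voc --gpus 0
# resolves to short=600, max_size=1000, lr=0.001, 20 epochs, no warmup
python train_faster_rcnn.py --dataset coco --gpus 0,1,2,3
# resolves to short=800, max_size=1333, lr=0.00125*4=0.005, warmup 8000/4=2000 iterations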