
Commit 9975c04

ijkguo authored and zhreshold committed

Add static alloc and fix load/save_params (#183)
* fix save_params
* add warmup lr
* add static alloc
* tune coco settings
* fix load_params
* add logging to saving parameters
* tune coco param num_sample, test_post_nms
* fix params doc
* add coco settings to eval
* change coco to 2x lr schedule
* fix load_params in eval, pretrained backbone is still unchanged
1 parent 5921740 commit 9975c04

File tree (4 files changed: +105 -31 lines)

gluoncv/model_zoo/faster_rcnn/faster_rcnn.py
scripts/detection/faster_rcnn/demo_faster_rcnn.py
scripts/detection/faster_rcnn/eval_faster_rcnn.py
scripts/detection/faster_rcnn/train_faster_rcnn.py
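Two MXNet Gluon APIs drive most of the edits below: hybridize(static_alloc=True), which pre-allocates and reuses memory for the cached graph, and save_parameters/load_parameters, which replace the deprecated save_params/load_params. A minimal sketch of both on a toy Gluon block (not the detector itself), assuming an MXNet version that supports static_alloc:

from mxnet import gluon, nd

net = gluon.nn.Dense(10)
net.initialize()
net.hybridize(static_alloc=True)   # static memory allocation for the cached graph
out = net(nd.ones((1, 20)))        # first forward pass builds and caches the graph
net.save_parameters('toy.params')  # replaces the deprecated net.save_params
net.load_parameters('toy.params')  # replaces the deprecated net.load_params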

gluoncv/model_zoo/faster_rcnn/faster_rcnn.py

Lines changed: 20 additions & 6 deletions
@@ -47,6 +47,14 @@ class FasterRCNN(RCNN):
         This is usually the ratio between original image size and feature map size.
     rpn_channel : int, default is 1024
         Channel number used in RPN convolutional layers.
+    rpn_train_pre_nms : int, default is 12000
+        Filter top proposals before NMS in training of RPN.
+    rpn_train_post_nms : int, default is 2000
+        Return top proposal results after NMS in training of RPN.
+    rpn_test_pre_nms : int, default is 6000
+        Filter top proposals before NMS in testing of RPN.
+    rpn_test_post_nms : int, default is 300
+        Return top proposal results after NMS in testing of RPN.
     nms_thresh : float, default is 0.3.
         Non-maximum suppression threshold. You can speficy < 0 or > 1 to disable NMS.
     nms_topk : int, default is 400
@@ -73,16 +81,20 @@ class FasterRCNN(RCNN):

     """
     def __init__(self, features, top_features, scales, ratios, classes, roi_mode, roi_size,
-                 stride=16, rpn_channel=1024, num_sample=128, pos_iou_thresh=0.5,
-                 neg_iou_thresh_high=0.5, neg_iou_thresh_low=0.0, pos_ratio=0.25, **kwargs):
+                 stride=16, rpn_channel=1024, rpn_train_pre_nms=12000, rpn_train_post_nms=2000,
+                 rpn_test_pre_nms=6000, rpn_test_post_nms=300,
+                 num_sample=128, pos_iou_thresh=0.5, neg_iou_thresh_high=0.5,
+                 neg_iou_thresh_low=0.0, pos_ratio=0.25, **kwargs):
         super(FasterRCNN, self).__init__(
             features, top_features, classes, roi_mode, roi_size, **kwargs)
         self.stride = stride
         self._max_batch = 1  # currently only support batch size = 1
         self._max_roi = 100000  # maximum allowed ROIs
         self._target_generator = set([RCNNTargetGenerator(self.num_class)])
         with self.name_scope():
-            self.rpn = RPN(rpn_channel, stride, scales=scales, ratios=ratios)
+            self.rpn = RPN(rpn_channel, stride, scales=scales, ratios=ratios,
+                           train_pre_nms=rpn_train_pre_nms, train_post_nms=rpn_train_post_nms,
+                           test_pre_nms=rpn_test_pre_nms, test_post_nms=rpn_test_post_nms)
             self.sampler = RCNNTargetSampler(num_sample, pos_iou_thresh, neg_iou_thresh_high,
                                              neg_iou_thresh_low, pos_ratio)

@@ -238,7 +250,7 @@ def get_faster_rcnn(name, features, top_features, scales, ratios, classes,
     if pretrained:
         from ..model_store import get_model_file
         full_name = '_'.join(('faster_rcnn', name, dataset))
-        net.load_params(get_model_file(full_name, root=root), ctx=ctx)
+        net.load_parameters(get_model_file(full_name, root=root), ctx=ctx)
     return net

 def faster_rcnn_resnet50_v1b_voc(pretrained=False, pretrained_base=True, **kwargs):
@@ -319,7 +331,8 @@ def faster_rcnn_resnet50_v1b_coco(pretrained=False, pretrained_base=True, **kwar
                            ratios=(0.5, 1, 2), classes=classes, dataset='coco',
                            roi_mode='align', roi_size=(14, 14), stride=16,
                            rpn_channel=1024, train_patterns=train_patterns,
-                           pretrained=pretrained, **kwargs)
+                           pretrained=pretrained, num_sample=512, rpn_test_post_nms=1000,
+                           **kwargs)

 def faster_rcnn_resnet50_v2a_voc(pretrained=False, pretrained_base=True, **kwargs):
     r"""Faster RCNN model from the paper
@@ -399,7 +412,8 @@ def faster_rcnn_resnet50_v2a_coco(pretrained=False, pretrained_base=True, **kwar
                            ratios=(0.5, 1, 2), classes=classes, dataset='coco',
                            roi_mode='align', roi_size=(14, 14), stride=16,
                            rpn_channel=1024, train_patterns=train_patterns,
-                           pretrained=pretrained, **kwargs)
+                           pretrained=pretrained, num_sample=512, rpn_test_post_nms=1000,
+                           **kwargs)

 def faster_rcnn_resnet50_v2_voc(pretrained=False, pretrained_base=True, **kwargs):
     r"""Faster RCNN model from the paper
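The new RPN proposal-count settings are ordinary constructor keyword arguments, and the model builders above forward extra **kwargs, so they can also be overridden from user code. A hedged sketch (the values are illustrative, not recommended settings):

from gluoncv import model_zoo

# Sketch only: build a VOC Faster R-CNN and override the test-time proposal counts.
net = model_zoo.get_model('faster_rcnn_resnet50_v1b_voc', pretrained=False,
                          pretrained_base=False,
                          rpn_test_pre_nms=6000, rpn_test_post_nms=300)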

scripts/detection/faster_rcnn/demo_faster_rcnn.py

Lines changed: 13 additions & 2 deletions
@@ -10,13 +10,24 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Test with Faster RCNN networks.')
     parser.add_argument('--network', type=str, default='faster_rcnn_resnet50_v2a_voc',
                         help="Faster RCNN full network name")
+    parser.add_argument('--short', type=str, default='',
+                        help='Resize image to the given short side side, default to 600 for voc.')
+    parser.add_argument('--max-size', type=str, default='',
+                        help='Max size of either side of image, default to 1000 for voc.')
     parser.add_argument('--images', type=str, default='',
                         help='Test images, use comma to split multiple.')
     parser.add_argument('--gpus', type=str, default='0',
                         help='Training with GPUs, you can specify 1,3 for example.')
     parser.add_argument('--pretrained', type=str, default='True',
                         help='Load weights from previously saved parameters. You can specify parameter file name.')
     args = parser.parse_args()
+    dataset = args.network.split('_')[-1]
+    if dataset == 'voc':
+        args.short = int(args.short) if args.short else 600
+        args.max_size = int(args.max_size) if args.max_size else 1000
+    elif dataset == 'coco':
+        args.short = int(args.short) if args.short else 800
+        args.max_size = int(args.max_size) if args.max_size else 1333
     return args

 if __name__ == '__main__':
@@ -37,12 +48,12 @@ def parse_args():
         net = gcv.model_zoo.get_model(args.network, pretrained=True)
     else:
         net = gcv.model_zoo.get_model(args.network, pretrained=False)
-        net.load_params(args.pretrained)
+        net.load_parameters(args.pretrained)
     net.set_nms(0.3, 200)

     ax = None
     for image in image_list:
-        x, img = presets.rcnn.load_test(image, short=600, max_size=1000)
+        x, img = presets.rcnn.load_test(image, short=args.short, max_size=args.max_size)
         ids, scores, bboxes = [xx.asnumpy() for xx in net(x)]
         ax = gcv.utils.viz.plot_bbox(img, bboxes, scores, ids,
                                      class_names=net.classes, ax=ax)
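The demo now infers the test resolution from the dataset suffix of the network name instead of hard-coding 600/1000. A standalone sketch of that default-resolution logic (the helper name is mine, not part of the script):

def resolve_test_shape(dataset, short='', max_size=''):
    """Return (short, max_size), falling back to the VOC/COCO defaults used above."""
    defaults = {'voc': (600, 1000), 'coco': (800, 1333)}
    d_short, d_max = defaults.get(dataset, (600, 1000))
    return (int(short) if short else d_short,
            int(max_size) if max_size else d_max)

print(resolve_test_shape('coco'))        # (800, 1333)
print(resolve_test_shape('voc', '500'))  # (500, 1000)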

scripts/detection/faster_rcnn/eval_faster_rcnn.py

Lines changed: 13 additions & 4 deletions
@@ -24,6 +24,10 @@ def parse_args():
                         help="Base feature extraction network name")
     parser.add_argument('--dataset', type=str, default='voc',
                         help='Training dataset.')
+    parser.add_argument('--short', type=str, default='',
+                        help='Resize image to the given short side side, default to 600 for voc.')
+    parser.add_argument('--max-size', type=str, default='',
+                        help='Max size of either side of image, default to 1000 for voc.')
     parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
                         default=4, help='Number of data workers')
     parser.add_argument('--gpus', type=str, default='0',
@@ -33,6 +37,12 @@ def parse_args():
     parser.add_argument('--save-prefix', type=str, default='',
                         help='Saving parameter prefix')
     args = parser.parse_args()
+    if args.dataset == 'voc':
+        args.short = int(args.short) if args.short else 600
+        args.max_size = int(args.max_size) if args.max_size else 1000
+    elif args.dataset == 'coco':
+        args.short = int(args.short) if args.short else 800
+        args.max_size = int(args.max_size) if args.max_size else 1333
     return args

 def get_dataset(dataset, args):
@@ -47,9 +57,8 @@ def get_dataset(dataset, args):
         raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
     return val_dataset, val_metric

-def get_dataloader(net, val_dataset, batch_size, num_workers):
+def get_dataloader(net, val_dataset, short, max_size, batch_size, num_workers):
     """Get dataloader."""
-    short, max_size = 600, 1000
     val_bfn = batchify.Tuple(*[batchify.Append() for _ in range(3)])
     val_loader = mx.gluon.data.DataLoader(
         val_dataset.transform(FasterRCNNDefaultValTransform(short, max_size)),
@@ -116,12 +125,12 @@ def validate(net, val_data, ctx, eval_metric, size):
         net = gcv.model_zoo.get_model(net_name, pretrained=True)
     else:
         net = gcv.model_zoo.get_model(net_name, pretrained=False)
-        net.load_params(args.pretrained.strip())
+        net.load_parameters(args.pretrained.strip())

     # training data
     val_dataset, eval_metric = get_dataset(args.dataset, args)
     val_data = get_dataloader(
-        net, val_dataset, args.batch_size, args.num_workers)
+        net, val_dataset, args.short, args.max_size, args.batch_size, args.num_workers)

     # validation
     names, values = validate(net, val_data, ctx, eval_metric, len(val_dataset))
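Callers now pass short/max_size into get_dataloader explicitly. A hedged sketch of building an equivalent VOC validation loader directly (the split and worker count are illustrative, and the VOC data must already be available on disk):

import mxnet as mx
from gluoncv.data import VOCDetection
from gluoncv.data.batchify import Tuple, Append
from gluoncv.data.transforms.presets.rcnn import FasterRCNNDefaultValTransform

short, max_size = 600, 1000                     # VOC defaults resolved in parse_args()
val_dataset = VOCDetection(splits=[(2007, 'test')])
val_bfn = Tuple(*[Append() for _ in range(3)])  # three fields per sample, as in the script above
val_loader = mx.gluon.data.DataLoader(
    val_dataset.transform(FasterRCNNDefaultValTransform(short, max_size)),
    batch_size=1, shuffle=False, batchify_fn=val_bfn, last_batch='keep', num_workers=4)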

scripts/detection/faster_rcnn/train_faster_rcnn.py

Lines changed: 59 additions & 19 deletions
@@ -28,29 +28,35 @@ def parse_args():
                         help="Base network name which serves as feature extraction base.")
     parser.add_argument('--dataset', type=str, default='voc',
                         help='Training dataset. Now support voc.')
+    parser.add_argument('--short', type=str, default='',
+                        help='Resize image to the given short side side, default to 600 for voc.')
+    parser.add_argument('--max-size', type=str, default='',
+                        help='Max size of either side of image, default to 1000 for voc.')
     parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
                         default=4, help='Number of data workers, you can use larger '
                                         'number to accelerate data loading, if you CPU and GPUs are powerful.')
     parser.add_argument('--gpus', type=str, default='0',
                         help='Training with GPUs, you can specify 1,3 for example.')
-    parser.add_argument('--epochs', type=int, default=30,
+    parser.add_argument('--epochs', type=str, default='',
                         help='Training epochs.')
     parser.add_argument('--resume', type=str, default='',
                         help='Resume from previously saved parameters if not None. '
                              'For example, you can resume from ./faster_rcnn_xxx_0123.params')
     parser.add_argument('--start-epoch', type=int, default=0,
                         help='Starting epoch for resuming, default is 0 for new training.'
                              'You can specify it to 100 for example to start from 100 epoch.')
-    parser.add_argument('--lr', type=float, default=0.001,
-                        help='Learning rate, default is 0.001')
+    parser.add_argument('--lr', type=str, default='',
+                        help='Learning rate, default is 0.001 for voc single gpu training.')
     parser.add_argument('--lr-decay', type=float, default=0.1,
                         help='decay rate of learning rate. default is 0.1.')
-    parser.add_argument('--lr-decay-epoch', type=str, default='14,20',
-                        help='epoches at which learning rate decays. default is 14,20.')
+    parser.add_argument('--lr-decay-epoch', type=str, default='',
+                        help='epoches at which learning rate decays. default is 14,20 for voc.')
+    parser.add_argument('--lr-warmup', type=str, default='',
+                        help='warmup iterations to adjust learning rate, default is 0 for voc.')
     parser.add_argument('--momentum', type=float, default=0.9,
                         help='SGD momentum, default is 0.9')
-    parser.add_argument('--wd', type=float, default=0.0005,
-                        help='Weight decay, default is 5e-4')
+    parser.add_argument('--wd', type=str, default='',
+                        help='Weight decay, default is 5e-4 for voc')
     parser.add_argument('--log-interval', type=int, default=100,
                         help='Logging mini-batch interval. Default is 100.')
     parser.add_argument('--save-prefix', type=str, default='',
@@ -65,6 +71,28 @@ def parse_args():
     parser.add_argument('--verbose', dest='verbose', action='store_true',
                         help='Print helpful debugging info once set.')
     args = parser.parse_args()
+    if args.dataset == 'voc':
+        args.short = int(args.short) if args.short else 600
+        args.max_size = int(args.max_size) if args.max_size else 1000
+        args.epochs = int(args.epochs) if args.epochs else 20
+        args.lr_decay_epoch = args.lr_decay_epoch if args.lr_decay_epoch else '14,20'
+        args.lr = float(args.lr) if args.lr else 0.001
+        args.lr_warmup = args.lr_warmup if args.lr_warmup else -1
+        args.wd = float(args.wd) if args.wd else 5e-4
+    elif args.dataset == 'coco':
+        args.short = int(args.short) if args.short else 800
+        args.max_size = int(args.max_size) if args.max_size else 1333
+        args.epochs = int(args.epochs) if args.epochs else 24
+        args.lr_decay_epoch = args.lr_decay_epoch if args.lr_decay_epoch else '16,21'
+        args.lr = float(args.lr) if args.lr else 0.00125
+        args.lr_warmup = args.lr_warmup if args.lr_warmup else 8000
+        args.wd = float(args.wd) if args.wd else 1e-4
+    num_gpus = len(args.gpus.split(','))
+    if num_gpus == 1:
+        args.lr_warmup = -1
+    else:
+        args.lr *= num_gpus
+        args.lr_warmup /= num_gpus
     return args


@@ -163,10 +191,8 @@ def get_dataset(dataset, args):
         raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
     return train_dataset, val_dataset, val_metric

-def get_dataloader(net, train_dataset, val_dataset, batch_size, num_workers):
+def get_dataloader(net, train_dataset, val_dataset, short, max_size, batch_size, num_workers):
     """Get dataloader."""
-    short, max_size = 600, 1000
-
     train_bfn = batchify.Tuple(*[batchify.Append() for _ in range(5)])
     train_loader = mx.gluon.data.DataLoader(
         train_dataset.transform(FasterRCNNDefaultTrainTransform(short, max_size, net)),
@@ -177,15 +203,19 @@ def get_dataloader(net, train_dataset, val_dataset, batch_size, num_workers):
         batch_size, False, batchify_fn=val_bfn, last_batch='keep', num_workers=num_workers)
     return train_loader, val_loader

-def save_params(net, best_map, current_map, epoch, save_interval, prefix):
+def save_params(net, logger, best_map, current_map, epoch, save_interval, prefix):
     current_map = float(current_map)
     if current_map > best_map[0]:
+        logger.info('[Epoch {}] mAP {} higher than current best {} saving to {}'.format(
+                    epoch, current_map, best_map, '{:s}_best.params'.format(prefix)))
         best_map[0] = current_map
-        net.save_params('{:s}_best.params'.format(prefix, epoch, current_map))
+        net.save_parameters('{:s}_best.params'.format(prefix))
         with open(prefix+'_best_map.log', 'a') as f:
             f.write('\n{:04d}:\t{:.4f}'.format(epoch, current_map))
-    if save_interval and epoch % save_interval == 0:
-        net.save_params('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))
+    if save_interval and (epoch + 1) % save_interval == 0:
+        logger.info('[Epoch {}] Saving parameters to {}'.format(
+            epoch, '{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map)))
+        net.save_parameters('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))

 def split_and_load(batch, ctx_list):
     """Split data to 1 batch each device."""
@@ -201,7 +231,7 @@ def validate(net, val_data, ctx, eval_metric):
     eval_metric.reset()
     # set nms threshold and topk constraint
     net.set_nms(nms_thresh=0.3, nms_topk=400)
-    net.hybridize()
+    net.hybridize(static_alloc=True)
     for batch in val_data:
         batch = split_and_load(batch, ctx_list=ctx)
         det_bboxes = []
@@ -231,6 +261,9 @@ def validate(net, val_data, ctx, eval_metric):
             eval_metric.update(det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff)
     return eval_metric.get()

+def get_lr_at_iter(alpha):
+    return 1. / 3. * (1 - alpha) + alpha
+
 def train(net, train_data, val_data, eval_metric, args):
     """Training pipeline"""
     net.collect_params().reset_ctx(ctx)
@@ -245,6 +278,7 @@ def train(net, train_data, val_data, eval_metric, args):
     # lr decay policy
     lr_decay = float(args.lr_decay)
     lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
+    lr_warmup = int(args.lr_warmup)

     # TODO(zhreshold) losses?
     rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
@@ -288,8 +322,14 @@ def train(net, train_data, val_data, eval_metric, args):
             metric.reset()
         tic = time.time()
         btic = time.time()
-        net.hybridize()
+        net.hybridize(static_alloc=True)
+        base_lr = trainer.learning_rate
         for i, batch in enumerate(train_data):
+            if epoch == 0 and i <= lr_warmup:
+                new_lr = base_lr * get_lr_at_iter((i // 500) / (lr_warmup / 500.))
+                if new_lr != trainer.learning_rate:
+                    logger.info('[Epoch 0 Iteration {}] Set learning rate to {}'.format(i, new_lr))
+                    trainer.set_learning_rate(new_lr)
             batch = split_and_load(batch, ctx_list=ctx)
             batch_size = len(batch[0])
             losses = []
@@ -350,7 +390,7 @@ def train(net, train_data, val_data, eval_metric, args):
             current_map = float(mean_ap[-1])
         else:
             current_map = 0.
-        save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)
+        save_params(net, logger, best_map, current_map, epoch, args.save_interval, args.save_prefix)

 if __name__ == '__main__':
     args = parse_args()
@@ -367,7 +407,7 @@ def train(net, train_data, val_data, eval_metric, args):
     args.save_prefix += net_name
     net = get_model(net_name, pretrained_base=True)
     if args.resume.strip():
-        net.load_params(args.resume.strip())
+        net.load_parameters(args.resume.strip())
     else:
         for param in net.collect_params().values():
             if param._data is not None:
@@ -377,7 +417,7 @@ def train(net, train_data, val_data, eval_metric, args):
     # training data
     train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)
     train_data, val_data = get_dataloader(
-        net, train_dataset, val_dataset, args.batch_size, args.num_workers)
+        net, train_dataset, val_dataset, args.short, args.max_size, args.batch_size, args.num_workers)

     # training
     train(net, train_data, val_data, eval_metric, args)
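The warmup added to the training loop starts at one third of the base learning rate and ramps linearly toward it in 500-iteration steps until lr_warmup iterations have passed. A standalone numeric sketch of that schedule (the base_lr and lr_warmup values here are illustrative):

def get_lr_at_iter(alpha):
    return 1. / 3. * (1 - alpha) + alpha

base_lr, lr_warmup = 0.01, 8000
for i in (0, 500, 4000, 8000):
    alpha = (i // 500) / (lr_warmup / 500.)
    print(i, base_lr * get_lr_at_iter(alpha))
# 0 -> 0.00333..., 500 -> 0.00375, 4000 -> 0.00666..., 8000 -> 0.01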
