
Commit db8559f

Fix a logical error in the Mask R-CNN training script (#1249)
* Fix a logical error in the Mask R-CNN training script. After this fix, training memory use stays steady. Previously, given a very large number of epochs (common with small datasets), the bug would consume more GPU memory with every epoch until none was left. The Faster R-CNN training script has the same problem; I'll fix it once I get some tests passing.
* Fix hybridize.
1 parent d8601c9 commit db8559f
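
The root cause is easy to see in isolation: constructing a new ForwardBackwardTask and Parallel executor inside the epoch loop means each epoch leaves behind another executor, along with whatever its worker threads still reference, so GPU memory grows with the epoch count. Below is a minimal sketch of the before/after pattern, not the GluonCV code itself; it reuses the names from these scripts but reduces the per-epoch loop body to a hypothetical train_one_epoch helper:

# BEFORE (leaky): a fresh task and executor are built every epoch.
# Each Parallel instance keeps its worker threads, and the objects
# they reference, alive, so memory use grows as epochs accumulate.
for epoch in range(args.start_epoch, args.epochs):
    rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss,
                                    rcnn_cls_loss, rcnn_box_loss, mix_ratio=1.0)
    executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
    train_one_epoch(executor)  # hypothetical stand-in for the loop body

# AFTER (steady): build the task and executor once, reuse them each epoch.
rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss,
                                rcnn_cls_loss, rcnn_box_loss, mix_ratio=1.0)
executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
for epoch in range(args.start_epoch, args.epochs):
    train_one_epoch(executor)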

File tree

2 files changed (+6, -6 lines)


scripts/detection/faster_rcnn/train_faster_rcnn.py

Lines changed: 3 additions & 3 deletions
@@ -539,13 +539,13 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
     logger.info(net.collect_train_params().keys())
     logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
     best_map = [0]
+    rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
+                                    rcnn_box_loss, mix_ratio=1.0)
+    executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
     for epoch in range(args.start_epoch, args.epochs):
         mix_ratio = 1.0
         if not args.disable_hybridization:
             net.hybridize(static_alloc=args.static_alloc)
-        rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
-                                        rcnn_box_loss, mix_ratio=1.0)
-        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
         if args.mixup:
             # TODO(zhreshold) only support evenly mixup now, target generator needs to be modified otherwise
             train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)

scripts/instance/mask_rcnn/train_mask_rcnn.py

Lines changed: 3 additions & 3 deletions
@@ -615,12 +615,12 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args):
     logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
     best_map = [0]
     base_lr = trainer.learning_rate
+    rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
+                                    rcnn_box_loss, rcnn_mask_loss)
+    executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
     for epoch in range(args.start_epoch, args.epochs):
         if not args.disable_hybridization:
             net.hybridize(static_alloc=args.static_alloc)
-        rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
-                                        rcnn_box_loss, rcnn_mask_loss)
-        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
         while lr_steps and epoch >= lr_steps[0]:
             new_lr = trainer.learning_rate * lr_decay
             lr_steps.pop(0)
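
A quick way to confirm the "training memory will be steady" claim is to log free GPU memory at each epoch boundary and check that it stays flat instead of shrinking. A minimal sketch, assuming MXNet's mx.context.gpu_memory_info is available (MXNet 1.4+); the logger call mirrors the style already used in these scripts:

import mxnet as mx

# At the end of each epoch: free/total device memory in bytes on GPU 0.
free, total = mx.context.gpu_memory_info(0)
logger.info('[Epoch {}] GPU 0 memory: {:.0f}/{:.0f} MiB free'.format(
    epoch, free / 2**20, total / 2**20))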
