Skip to content

sacred config for faster rcnn #1358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jul 8, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
revert #1249
  • Loading branch information
Jerryzcn committed May 6, 2020
commit f71a9842f0333dfa958826566242fcde65db33de
85 changes: 85 additions & 0 deletions gluoncv/model_zoo/rcnn/mask_rcnn/data_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import mxnet as mx
from mxnet import autograd
from mxnet.contrib import amp

from gluoncv.utils.parallel import Parallelizable


class ForwardBackwardTask(Parallelizable):
def __init__(self, net, optimizer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss,
rcnn_mask_loss, amp_enabled):
super(ForwardBackwardTask, self).__init__()
self.net = net
self._optimizer = optimizer
self.rpn_cls_loss = rpn_cls_loss
self.rpn_box_loss = rpn_box_loss
self.rcnn_cls_loss = rcnn_cls_loss
self.rcnn_box_loss = rcnn_box_loss
self.rcnn_mask_loss = rcnn_mask_loss
self.amp_enabled = amp_enabled

def forward_backward(self, x):
data, label, gt_mask, rpn_cls_targets, rpn_box_targets, rpn_box_masks = x
with autograd.record():
gt_label = label[:, :, 4:5]
gt_box = label[:, :, :4]
cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors, \
cls_targets, box_targets, box_masks, indices = self.net(data, gt_box, gt_label)
# losses of rpn
rpn_score = rpn_score.squeeze(axis=-1)
num_rpn_pos = (rpn_cls_targets >= 0).sum()
rpn_loss1 = self.rpn_cls_loss(rpn_score, rpn_cls_targets,
rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
rpn_loss2 = self.rpn_box_loss(rpn_box, rpn_box_targets,
rpn_box_masks) * rpn_box.size / num_rpn_pos
# rpn overall loss, use sum rather than average
rpn_loss = rpn_loss1 + rpn_loss2

# losses of rcnn
num_rcnn_pos = (cls_targets >= 0).sum()
rcnn_loss1 = self.rcnn_cls_loss(cls_pred, cls_targets,
cls_targets.expand_dims(-1) >= 0) * cls_targets.size / \
num_rcnn_pos
rcnn_loss2 = self.rcnn_box_loss(box_pred, box_targets, box_masks) * box_pred.size / \
num_rcnn_pos
rcnn_loss = rcnn_loss1 + rcnn_loss2

# generate targets for mask
roi = mx.nd.concat(
*[mx.nd.take(roi[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1, 4))
m_cls_targets = mx.nd.concat(
*[mx.nd.take(cls_targets[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1))
matches = mx.nd.concat(
*[mx.nd.take(matches[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1))
mask_targets, mask_masks = self.net.mask_target(roi, gt_mask, matches, m_cls_targets)
# loss of mask
mask_loss = self.rcnn_mask_loss(mask_pred, mask_targets, mask_masks) * \
mask_targets.size / mask_masks.sum()

# overall losses
total_loss = rpn_loss.sum() + rcnn_loss.sum() + mask_loss.sum()

rpn_loss1_metric = rpn_loss1.mean()
rpn_loss2_metric = rpn_loss2.mean()
rcnn_loss1_metric = rcnn_loss1.sum()
rcnn_loss2_metric = rcnn_loss2.sum()
mask_loss_metric = mask_loss.sum()
rpn_acc_metric = [[rpn_cls_targets, rpn_cls_targets >= 0], [rpn_score]]
rpn_l1_loss_metric = [[rpn_box_targets, rpn_box_masks], [rpn_box]]
rcnn_acc_metric = [[cls_targets], [cls_pred]]
rcnn_l1_loss_metric = [[box_targets, box_masks], [box_pred]]
rcnn_mask_metric = [[mask_targets, mask_masks], [mask_pred]]
rcnn_fgmask_metric = [[mask_targets, mask_masks], [mask_pred]]

if self.amp_enabled:
with amp.scale_loss(total_loss, self._optimizer) as scaled_losses:
autograd.backward(scaled_losses)
else:
total_loss.backward()

return rpn_loss1_metric, rpn_loss2_metric, rcnn_loss1_metric, rcnn_loss2_metric, \
mask_loss_metric, rpn_acc_metric, rpn_l1_loss_metric, rcnn_acc_metric, \
rcnn_l1_loss_metric, rcnn_mask_metric, rcnn_fgmask_metric
7 changes: 4 additions & 3 deletions scripts/detection/faster_rcnn/train_faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,10 +481,10 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
logger.info(net.collect_train_params().keys())
logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
best_map = [0]
rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
rcnn_box_loss, mix_ratio=1.0, amp_enabled=args.amp)
executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
for epoch in range(args.start_epoch, args.epochs):
rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
rcnn_box_loss, mix_ratio=1.0, amp_enabled=args.amp)
executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
mix_ratio = 1.0
if not args.disable_hybridization:
net.hybridize(static_alloc=args.static_alloc)
Expand Down Expand Up @@ -537,6 +537,7 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
for pred in records:
metric.update(pred[0], pred[1])
trainer.step(batch_size)
break

# update metrics
if (not args.horovod or hvd.rank() == 0) and args.log_interval \
Expand Down
85 changes: 3 additions & 82 deletions scripts/instance/mask_rcnn/train_mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,85 +471,6 @@ def get_lr_at_iter(alpha, lr_warmup_factor=1. / 3.):
return lr_warmup_factor * (1 - alpha) + alpha


class ForwardBackwardTask(Parallelizable):
def __init__(self, net, optimizer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss,
rcnn_mask_loss):
super(ForwardBackwardTask, self).__init__()
self.net = net
self._optimizer = optimizer
self.rpn_cls_loss = rpn_cls_loss
self.rpn_box_loss = rpn_box_loss
self.rcnn_cls_loss = rcnn_cls_loss
self.rcnn_box_loss = rcnn_box_loss
self.rcnn_mask_loss = rcnn_mask_loss

def forward_backward(self, x):
data, label, gt_mask, rpn_cls_targets, rpn_box_targets, rpn_box_masks = x
with autograd.record():
gt_label = label[:, :, 4:5]
gt_box = label[:, :, :4]
cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors, \
cls_targets, box_targets, box_masks, indices = self.net(data, gt_box, gt_label)
# losses of rpn
rpn_score = rpn_score.squeeze(axis=-1)
num_rpn_pos = (rpn_cls_targets >= 0).sum()
rpn_loss1 = self.rpn_cls_loss(rpn_score, rpn_cls_targets,
rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
rpn_loss2 = self.rpn_box_loss(rpn_box, rpn_box_targets,
rpn_box_masks) * rpn_box.size / num_rpn_pos
# rpn overall loss, use sum rather than average
rpn_loss = rpn_loss1 + rpn_loss2

# losses of rcnn
num_rcnn_pos = (cls_targets >= 0).sum()
rcnn_loss1 = self.rcnn_cls_loss(cls_pred, cls_targets,
cls_targets.expand_dims(-1) >= 0) * cls_targets.size / \
num_rcnn_pos
rcnn_loss2 = self.rcnn_box_loss(box_pred, box_targets, box_masks) * box_pred.size / \
num_rcnn_pos
rcnn_loss = rcnn_loss1 + rcnn_loss2

# generate targets for mask
roi = mx.nd.concat(
*[mx.nd.take(roi[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1, 4))
m_cls_targets = mx.nd.concat(
*[mx.nd.take(cls_targets[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1))
matches = mx.nd.concat(
*[mx.nd.take(matches[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1))
mask_targets, mask_masks = self.net.mask_target(roi, gt_mask, matches, m_cls_targets)
# loss of mask
mask_loss = self.rcnn_mask_loss(mask_pred, mask_targets, mask_masks) * \
mask_targets.size / mask_masks.sum()

# overall losses
total_loss = rpn_loss.sum() + rcnn_loss.sum() + mask_loss.sum()

rpn_loss1_metric = rpn_loss1.mean()
rpn_loss2_metric = rpn_loss2.mean()
rcnn_loss1_metric = rcnn_loss1.sum()
rcnn_loss2_metric = rcnn_loss2.sum()
mask_loss_metric = mask_loss.sum()
rpn_acc_metric = [[rpn_cls_targets, rpn_cls_targets >= 0], [rpn_score]]
rpn_l1_loss_metric = [[rpn_box_targets, rpn_box_masks], [rpn_box]]
rcnn_acc_metric = [[cls_targets], [cls_pred]]
rcnn_l1_loss_metric = [[box_targets, box_masks], [box_pred]]
rcnn_mask_metric = [[mask_targets, mask_masks], [mask_pred]]
rcnn_fgmask_metric = [[mask_targets, mask_masks], [mask_pred]]

if args.amp:
with amp.scale_loss(total_loss, self._optimizer) as scaled_losses:
autograd.backward(scaled_losses)
else:
total_loss.backward()

return rpn_loss1_metric, rpn_loss2_metric, rcnn_loss1_metric, rcnn_loss2_metric, \
mask_loss_metric, rpn_acc_metric, rpn_l1_loss_metric, rcnn_acc_metric, \
rcnn_l1_loss_metric, rcnn_mask_metric, rcnn_fgmask_metric


def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args):
"""Training pipeline"""
args.kv_store = 'device' if (args.amp and 'nccl' in args.kv_store) else args.kv_store
Expand Down Expand Up @@ -615,10 +536,10 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args)
logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
best_map = [0]
base_lr = trainer.learning_rate
rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
rcnn_box_loss, rcnn_mask_loss)
executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
for epoch in range(args.start_epoch, args.epochs):
rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
rcnn_box_loss, rcnn_mask_loss)
executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
if not args.disable_hybridization:
net.hybridize(static_alloc=args.static_alloc)
while lr_steps and epoch >= lr_steps[0]:
Expand Down