
Commit e043d56

move rcnn forward backward task to model zoo (#1288)
* move rcnn forward backward task to model zoo
* revert #1249
* fix
* fix
* docstring
* fix style
* add docs
1 parent 2716ec9 commit e043d56

File tree

6 files changed (+205, -147 lines)


gluoncv/model_zoo/rcnn/faster_rcnn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@
 from .faster_rcnn import *
 from .predefined_models import *
 from .rcnn_target import RCNNTargetGenerator, RCNNTargetSampler
+from .data_parallel import ForwardBackwardTask
gluoncv/model_zoo/rcnn/faster_rcnn/data_parallel.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
"""Data parallel task for Faster RCNN Model."""

from mxnet import autograd
from mxnet.contrib import amp

from gluoncv.utils.parallel import Parallelizable


class ForwardBackwardTask(Parallelizable):
    """ Faster R-CNN training task that can be scheduled concurrently using Parallel.
    Parameters
    ----------
    net : gluon.HybridBlock
        Faster R-CNN network.
    optimizer : gluon.Trainer
        Optimizer for the training.
    rpn_cls_loss : gluon.loss
        RPN box classification loss.
    rpn_box_loss : gluon.loss
        RPN box regression loss.
    rcnn_cls_loss : gluon.loss
        R-CNN box head classification loss.
    rcnn_box_loss : gluon.loss
        R-CNN box head regression loss.
    mix_ratio : float
        Object detection mixup ratio.
    amp_enabled : bool
        Whether to enable Automatic Mixed Precision.
    """

    def __init__(self, net, optimizer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss,
                 mix_ratio, amp_enabled):
        super(ForwardBackwardTask, self).__init__()
        self.net = net
        self._optimizer = optimizer
        self.rpn_cls_loss = rpn_cls_loss
        self.rpn_box_loss = rpn_box_loss
        self.rcnn_cls_loss = rcnn_cls_loss
        self.rcnn_box_loss = rcnn_box_loss
        self.mix_ratio = mix_ratio
        self.amp_enabled = amp_enabled

    def forward_backward(self, x):
        data, label, rpn_cls_targets, rpn_box_targets, rpn_box_masks = x
        with autograd.record():
            gt_label = label[:, :, 4:5]
            gt_box = label[:, :, :4]
            cls_pred, box_pred, _, _, _, rpn_score, rpn_box, _, cls_targets, \
                box_targets, box_masks, _ = self.net(data, gt_box, gt_label)
            # losses of rpn
            rpn_score = rpn_score.squeeze(axis=-1)
            num_rpn_pos = (rpn_cls_targets >= 0).sum()
            rpn_loss1 = self.rpn_cls_loss(rpn_score, rpn_cls_targets,
                                          rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
            rpn_loss2 = self.rpn_box_loss(rpn_box, rpn_box_targets,
                                          rpn_box_masks) * rpn_box.size / num_rpn_pos
            # rpn overall loss, use sum rather than average
            rpn_loss = rpn_loss1 + rpn_loss2
            # losses of rcnn
            num_rcnn_pos = (cls_targets >= 0).sum()
            rcnn_loss1 = self.rcnn_cls_loss(
                cls_pred, cls_targets, cls_targets.expand_dims(-1) >= 0) * cls_targets.size / \
                num_rcnn_pos
            rcnn_loss2 = self.rcnn_box_loss(box_pred, box_targets, box_masks) * box_pred.size / \
                num_rcnn_pos
            rcnn_loss = rcnn_loss1 + rcnn_loss2
            # overall losses
            total_loss = rpn_loss.sum() * self.mix_ratio + rcnn_loss.sum() * self.mix_ratio

            rpn_loss1_metric = rpn_loss1.mean() * self.mix_ratio
            rpn_loss2_metric = rpn_loss2.mean() * self.mix_ratio
            rcnn_loss1_metric = rcnn_loss1.mean() * self.mix_ratio
            rcnn_loss2_metric = rcnn_loss2.mean() * self.mix_ratio
            rpn_acc_metric = [[rpn_cls_targets, rpn_cls_targets >= 0], [rpn_score]]
            rpn_l1_loss_metric = [[rpn_box_targets, rpn_box_masks], [rpn_box]]
            rcnn_acc_metric = [[cls_targets], [cls_pred]]
            rcnn_l1_loss_metric = [[box_targets, box_masks], [box_pred]]

            if self.amp_enabled:
                with amp.scale_loss(total_loss, self._optimizer) as scaled_losses:
                    autograd.backward(scaled_losses)
            else:
                total_loss.backward()

        return rpn_loss1_metric, rpn_loss2_metric, rcnn_loss1_metric, rcnn_loss2_metric, \
               rpn_acc_metric, rpn_l1_loss_metric, rcnn_acc_metric, rcnn_l1_loss_metric
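For orientation, here is a minimal sketch of how the relocated task is meant to be driven together with gluoncv.utils.parallel.Parallel. The network, trainer, loss choices, thread count, and the per-device batch tuples below are illustrative placeholders modeled on the Faster R-CNN training script; they are not part of this diff:

import mxnet as mx
from mxnet import gluon
import gluoncv as gcv
from gluoncv.utils.parallel import Parallel
from gluoncv.model_zoo.rcnn.faster_rcnn import ForwardBackwardTask  # re-exported by the __init__.py change above

# Illustrative network and trainer; the training script builds these from its CLI arguments.
net = gcv.model_zoo.get_model('faster_rcnn_resnet50_v1b_voc', pretrained_base=False)
net.initialize()
trainer = gluon.Trainer(net.collect_train_params(), 'sgd', {'learning_rate': 1e-3})

# Illustrative loss choices; any compatible gluon losses would work here.
rpn_cls_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
rpn_box_loss = gluon.loss.HuberLoss(rho=1. / 9.)
rcnn_cls_loss = gluon.loss.SoftmaxCrossEntropyLoss()
rcnn_box_loss = gluon.loss.HuberLoss()

task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss,
                           rcnn_cls_loss, rcnn_box_loss,
                           mix_ratio=1.0, amp_enabled=False)
executor = Parallel(1, task)  # the script passes args.executor_threads here

# per_device_batches is a hypothetical list with one
# (data, label, rpn_cls_targets, rpn_box_targets, rpn_box_masks) tuple per GPU,
# produced by the FasterRCNNDefaultTrainTransform pipeline and split across contexts.
for device_batch in per_device_batches:
    executor.put(device_batch)          # forward/backward runs concurrently per device
metrics = [executor.get() for _ in per_device_batches]  # per-batch loss values and metric pairs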

gluoncv/model_zoo/rcnn/mask_rcnn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@
 
 from .mask_rcnn import *
 from .predefined_models import *
+from .data_parallel import ForwardBackwardTask
gluoncv/model_zoo/rcnn/mask_rcnn/data_parallel.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
"""Data parallel task for Mask R-CNN Model."""

import mxnet as mx
from mxnet import autograd
from mxnet.contrib import amp

from gluoncv.utils.parallel import Parallelizable


class ForwardBackwardTask(Parallelizable):
    """ Mask R-CNN training task that can be scheduled concurrently using Parallel.
    Parameters
    ----------
    net : gluon.HybridBlock
        Mask R-CNN network.
    optimizer : gluon.Trainer
        Optimizer for the training.
    rpn_cls_loss : gluon.loss
        RPN box classification loss.
    rpn_box_loss : gluon.loss
        RPN box regression loss.
    rcnn_cls_loss : gluon.loss
        R-CNN box head classification loss.
    rcnn_box_loss : gluon.loss
        R-CNN box head regression loss.
    rcnn_mask_loss : gluon.loss
        R-CNN mask head segmentation loss.
    amp_enabled : bool
        Whether to enable Automatic Mixed Precision.
    """
    def __init__(self, net, optimizer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss,
                 rcnn_mask_loss, amp_enabled):
        super(ForwardBackwardTask, self).__init__()
        self.net = net
        self._optimizer = optimizer
        self.rpn_cls_loss = rpn_cls_loss
        self.rpn_box_loss = rpn_box_loss
        self.rcnn_cls_loss = rcnn_cls_loss
        self.rcnn_box_loss = rcnn_box_loss
        self.rcnn_mask_loss = rcnn_mask_loss
        self.amp_enabled = amp_enabled

    def forward_backward(self, x):
        data, label, gt_mask, rpn_cls_targets, rpn_box_targets, rpn_box_masks = x
        with autograd.record():
            gt_label = label[:, :, 4:5]
            gt_box = label[:, :, :4]
            cls_pred, box_pred, mask_pred, roi, _, matches, rpn_score, rpn_box, _, \
                cls_targets, box_targets, box_masks, indices = self.net(data, gt_box, gt_label)
            # losses of rpn
            rpn_score = rpn_score.squeeze(axis=-1)
            num_rpn_pos = (rpn_cls_targets >= 0).sum()
            rpn_loss1 = self.rpn_cls_loss(rpn_score, rpn_cls_targets,
                                          rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
            rpn_loss2 = self.rpn_box_loss(rpn_box, rpn_box_targets,
                                          rpn_box_masks) * rpn_box.size / num_rpn_pos
            # rpn overall loss, use sum rather than average
            rpn_loss = rpn_loss1 + rpn_loss2

            # losses of rcnn
            num_rcnn_pos = (cls_targets >= 0).sum()
            rcnn_loss1 = self.rcnn_cls_loss(
                cls_pred, cls_targets, cls_targets.expand_dims(-1) >= 0) * cls_targets.size / \
                num_rcnn_pos
            rcnn_loss2 = self.rcnn_box_loss(box_pred, box_targets, box_masks) * box_pred.size / \
                num_rcnn_pos
            rcnn_loss = rcnn_loss1 + rcnn_loss2

            # generate targets for mask
            roi = mx.nd.concat(
                *[mx.nd.take(roi[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                .reshape((indices.shape[0], -1, 4))
            m_cls_targets = mx.nd.concat(
                *[mx.nd.take(cls_targets[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                .reshape((indices.shape[0], -1))
            matches = mx.nd.concat(
                *[mx.nd.take(matches[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
                .reshape((indices.shape[0], -1))
            mask_targets, mask_masks = self.net.mask_target(roi, gt_mask, matches, m_cls_targets)
            # loss of mask
            mask_loss = self.rcnn_mask_loss(mask_pred, mask_targets, mask_masks) * \
                        mask_targets.size / mask_masks.sum()

            # overall losses
            total_loss = rpn_loss.sum() + rcnn_loss.sum() + mask_loss.sum()

            rpn_loss1_metric = rpn_loss1.mean()
            rpn_loss2_metric = rpn_loss2.mean()
            rcnn_loss1_metric = rcnn_loss1.sum()
            rcnn_loss2_metric = rcnn_loss2.sum()
            mask_loss_metric = mask_loss.sum()
            rpn_acc_metric = [[rpn_cls_targets, rpn_cls_targets >= 0], [rpn_score]]
            rpn_l1_loss_metric = [[rpn_box_targets, rpn_box_masks], [rpn_box]]
            rcnn_acc_metric = [[cls_targets], [cls_pred]]
            rcnn_l1_loss_metric = [[box_targets, box_masks], [box_pred]]
            rcnn_mask_metric = [[mask_targets, mask_masks], [mask_pred]]
            rcnn_fgmask_metric = [[mask_targets, mask_masks], [mask_pred]]

            if self.amp_enabled:
                with amp.scale_loss(total_loss, self._optimizer) as scaled_losses:
                    autograd.backward(scaled_losses)
            else:
                total_loss.backward()

        return rpn_loss1_metric, rpn_loss2_metric, rcnn_loss1_metric, rcnn_loss2_metric, \
               mask_loss_metric, rpn_acc_metric, rpn_l1_loss_metric, rcnn_acc_metric, \
               rcnn_l1_loss_metric, rcnn_mask_metric, rcnn_fgmask_metric
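The Mask R-CNN task differs from the Faster R-CNN one in two ways: its constructor takes an extra rcnn_mask_loss (and no mix_ratio), and forward_backward returns eleven values instead of eight. Below is a small sketch of constructing it and unpacking the result; the mask loss choice and the surrounding objects (net, trainer, the RPN/box losses, device_batch) are illustrative placeholders reused from the earlier Faster R-CNN sketch:

from mxnet import gluon
from gluoncv.utils.parallel import Parallel
from gluoncv.model_zoo.rcnn.mask_rcnn import ForwardBackwardTask  # re-exported by the __init__.py change above

# Illustrative mask loss; the Mask R-CNN training script configures its own.
rcnn_mask_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)

mask_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss,
                                rcnn_cls_loss, rcnn_box_loss,
                                rcnn_mask_loss, amp_enabled=False)
executor = Parallel(1, mask_task)

# Each queued tuple now also carries gt_mask:
# (data, label, gt_mask, rpn_cls_targets, rpn_box_targets, rpn_box_masks)
executor.put(device_batch)
(rpn_l1, rpn_l2, rcnn_l1, rcnn_l2, mask_l,
 rpn_acc, rpn_l1loss, rcnn_acc, rcnn_l1loss,
 rcnn_mask, rcnn_fgmask) = executor.get()
# The first five values are per-batch loss scalars; the remaining six are
# [labels], [preds] pairs intended for the RPN/R-CNN/mask metric updates
# performed by the training scripts.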

scripts/detection/faster_rcnn/train_faster_rcnn.py

Lines changed: 5 additions & 63 deletions
@@ -16,7 +16,6 @@
 import numpy as np
 import mxnet as mx
 from mxnet import gluon
-from mxnet import autograd
 from mxnet.contrib import amp
 import gluoncv as gcv
 
@@ -29,9 +28,10 @@
     FasterRCNNDefaultValTransform
 from gluoncv.utils.metrics.voc_detection import VOC07MApMetric
 from gluoncv.utils.metrics.coco_detection import COCODetectionMetric
-from gluoncv.utils.parallel import Parallelizable, Parallel
+from gluoncv.utils.parallel import Parallel
 from gluoncv.utils.metrics.rcnn import RPNAccMetric, RPNL1LossMetric, RCNNAccMetric, \
     RCNNL1LossMetric
+from gluoncv.model_zoo.rcnn.faster_rcnn.data_parallel import ForwardBackwardTask
 
 try:
     import horovod.mxnet as hvd
@@ -415,64 +415,6 @@ def get_lr_at_iter(alpha, lr_warmup_factor=1. / 3.):
     return lr_warmup_factor * (1 - alpha) + alpha
 
 
-class ForwardBackwardTask(Parallelizable):
-    def __init__(self, net, optimizer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss,
-                 mix_ratio):
-        super(ForwardBackwardTask, self).__init__()
-        self.net = net
-        self._optimizer = optimizer
-        self.rpn_cls_loss = rpn_cls_loss
-        self.rpn_box_loss = rpn_box_loss
-        self.rcnn_cls_loss = rcnn_cls_loss
-        self.rcnn_box_loss = rcnn_box_loss
-        self.mix_ratio = mix_ratio
-
-    def forward_backward(self, x):
-        data, label, rpn_cls_targets, rpn_box_targets, rpn_box_masks = x
-        with autograd.record():
-            gt_label = label[:, :, 4:5]
-            gt_box = label[:, :, :4]
-            cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \
-                box_targets, box_masks, _ = self.net(data, gt_box, gt_label)
-            # losses of rpn
-            rpn_score = rpn_score.squeeze(axis=-1)
-            num_rpn_pos = (rpn_cls_targets >= 0).sum()
-            rpn_loss1 = self.rpn_cls_loss(rpn_score, rpn_cls_targets,
-                                          rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
-            rpn_loss2 = self.rpn_box_loss(rpn_box, rpn_box_targets,
-                                          rpn_box_masks) * rpn_box.size / num_rpn_pos
-            # rpn overall loss, use sum rather than average
-            rpn_loss = rpn_loss1 + rpn_loss2
-            # losses of rcnn
-            num_rcnn_pos = (cls_targets >= 0).sum()
-            rcnn_loss1 = self.rcnn_cls_loss(cls_pred, cls_targets,
-                                            cls_targets.expand_dims(-1) >= 0) * cls_targets.size / \
-                num_rcnn_pos
-            rcnn_loss2 = self.rcnn_box_loss(box_pred, box_targets, box_masks) * box_pred.size / \
-                num_rcnn_pos
-            rcnn_loss = rcnn_loss1 + rcnn_loss2
-            # overall losses
-            total_loss = rpn_loss.sum() * self.mix_ratio + rcnn_loss.sum() * self.mix_ratio
-
-            rpn_loss1_metric = rpn_loss1.mean() * self.mix_ratio
-            rpn_loss2_metric = rpn_loss2.mean() * self.mix_ratio
-            rcnn_loss1_metric = rcnn_loss1.mean() * self.mix_ratio
-            rcnn_loss2_metric = rcnn_loss2.mean() * self.mix_ratio
-            rpn_acc_metric = [[rpn_cls_targets, rpn_cls_targets >= 0], [rpn_score]]
-            rpn_l1_loss_metric = [[rpn_box_targets, rpn_box_masks], [rpn_box]]
-            rcnn_acc_metric = [[cls_targets], [cls_pred]]
-            rcnn_l1_loss_metric = [[box_targets, box_masks], [box_pred]]
-
-            if args.amp:
-                with amp.scale_loss(total_loss, self._optimizer) as scaled_losses:
-                    autograd.backward(scaled_losses)
-            else:
-                total_loss.backward()
-
-        return rpn_loss1_metric, rpn_loss2_metric, rcnn_loss1_metric, rcnn_loss2_metric, \
-            rpn_acc_metric, rpn_l1_loss_metric, rcnn_acc_metric, rcnn_l1_loss_metric
-
-
 def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
     """Training pipeline"""
     args.kv_store = 'device' if (args.amp and 'nccl' in args.kv_store) else args.kv_store
@@ -539,10 +481,10 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
     logger.info(net.collect_train_params().keys())
     logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
     best_map = [0]
-    rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
-                                    rcnn_box_loss, mix_ratio=1.0)
-    executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
     for epoch in range(args.start_epoch, args.epochs):
+        rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
+                                        rcnn_box_loss, mix_ratio=1.0, amp_enabled=args.amp)
+        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
         mix_ratio = 1.0
         if not args.disable_hybridization:
             net.hybridize(static_alloc=args.static_alloc)
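One consequence of the move is that the task no longer reads the script-level args for AMP; the flag is passed in explicitly as amp_enabled. The sketch below shows how that flag would fit into the surrounding AMP plumbing, assuming the standard mxnet.contrib.amp workflow (amp.init, amp.init_trainer); the ordering and argument names here are illustrative, not part of this diff:

from mxnet.contrib import amp

if args.amp:
    amp.init()                     # patch float32 ops for mixed precision, before the net is built

# ... build net, dataloaders and trainer as the script already does ...

if args.amp:
    amp.init_trainer(trainer)      # attach loss-scaling state to the trainer

# The flag then travels into the task (as in the hunk above) and selects
# amp.scale_loss(...) over a plain total_loss.backward() inside forward_backward.
rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
                                rcnn_box_loss, mix_ratio=1.0, amp_enabled=args.amp)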

0 commit comments
