Skip to content

sacred config for faster rcnn #1358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jul 8, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
move dataset to init
  • Loading branch information
Jerryzcn committed Jun 16, 2020
commit c5ea69342bebbb149fd861d3cf70b8f2b8f36d46
47 changes: 36 additions & 11 deletions gluoncv/estimators/rcnn/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from mxnet import gluon

from ... import data as gdata
from ...data.batchify import FasterRCNNTrainBatchify, Tuple, Append
from ...data.sampler import SplitSortedBucketSampler
from ...data.transforms import presets
Expand All @@ -14,8 +15,10 @@
from ...model_zoo import get_model
from ...model_zoo.rcnn.faster_rcnn.data_parallel import ForwardBackwardTask
from ...nn.bbox import BBoxClipToImage
from ...utils.metrics.coco_detection import COCODetectionMetric
from ...utils.metrics.rcnn import RPNAccMetric, RPNL1LossMetric, RCNNAccMetric, \
RCNNL1LossMetric
from ...utils.metrics.voc_detection import VOC07MApMetric
from ...utils.parallel import Parallel

logging.basicConfig()
Expand Down Expand Up @@ -85,18 +88,42 @@ def get_dataloader(net, train_dataset, val_dataset, train_transform, val_transfo
return train_loader, val_loader


def get_dataset(dataset, args):
if dataset.lower() == 'voc':
train_dataset = gdata.VOCDetection(
splits=[(2007, 'trainval'), (2012, 'trainval')])
val_dataset = gdata.VOCDetection(
splits=[(2007, 'test')])
val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
elif dataset.lower() in ['clipart', 'comic', 'watercolor']:
root = os.path.join('~', '.mxnet', 'datasets', dataset.lower())
train_dataset = gdata.CustomVOCDetection(root=root, splits=[('', 'train')],
generate_classes=True)
val_dataset = gdata.CustomVOCDetection(root=root, splits=[('', 'test')],
generate_classes=True)
val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
elif dataset.lower() == 'coco':
train_dataset = gdata.COCODetection(splits='instances_train2017', use_crowd=False)
val_dataset = gdata.COCODetection(splits='instances_val2017', skip_empty=False)
val_metric = COCODetectionMetric(val_dataset, args.save_prefix + '_eval', cleanup=True)
else:
raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
if args.mixup:
from gluoncv.data.mixup import detection
train_dataset = detection.MixupDetection(train_dataset)
return train_dataset, val_dataset, val_metric


class FasterRCNNEstimator:
""" Estimator for Faster R-CNN.
"""

def __init__(self, cfg, eval_metric=None):
def __init__(self, cfg):
"""
Parameters
----------
cfg : configuration object
configuration object containing information for constructing Faster R-CNN estimator.
eval_metric : evaluation metric
evaluation metric that would be used for validation.
"""
self.cfg = cfg
# training contexts
Expand All @@ -105,7 +132,9 @@ def __init__(self, cfg, eval_metric=None):
else:
ctx = [mx.gpu(int(i)) for i in self.cfg.gpus.split(',') if i.strip()]
self.ctx = ctx if ctx else [mx.cpu()]
self.eval_metric = eval_metric
# training data
self.train_dataset, self.val_dataset, self.eval_metric = \
get_dataset(self.cfg.dataset, self.cfg)
# network
kwargs = {}
module_list = []
Expand Down Expand Up @@ -263,18 +292,14 @@ def _validate(self, val_data, ctx, eval_metric):
eval_metric.update(det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff)
return eval_metric.get()

def fit(self, train_data, val_data=None):
def fit(self):

"""
Fit faster R-CNN models.


Examples
--------
Fit faster R-CNN models. All parameters are set
"""
batch_size = self.cfg.batch_size // self.num_gpus if self.cfg.horovod else self.cfg.batch_size
train_data, val_data = get_dataloader(
self.net, train_data, val_data, FasterRCNNDefaultTrainTransform,
self.net, self.train_data, self.val_data, FasterRCNNDefaultTrainTransform,
FasterRCNNDefaultValTransform, batch_size, len(self.ctx), self.cfg)

self.cfg.kv_store = 'device' if (self.cfg.amp and 'nccl' in self.cfg.kv_store) \
Expand Down
4 changes: 0 additions & 4 deletions scripts/detection/faster_rcnn/eval_faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@
import glob
import logging
logging.basicConfig(level=logging.INFO)
import time
import numpy as np
import mxnet as mx
from tqdm import tqdm
from mxnet import nd
from mxnet import gluon
import gluoncv as gcv
gcv.utils.check_version('0.6.0')
from gluoncv import data as gdata
Expand Down
36 changes: 2 additions & 34 deletions scripts/detection/faster_rcnn/train_faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@
import gluoncv as gcv

gcv.utils.check_version('0.7.0')
from gluoncv import data as gdata
from gluoncv import utils as gutils
from gluoncv.estimators.rcnn.faster_rcnn import FasterRCNNEstimator
from gluoncv.utils.metrics.voc_detection import VOC07MApMetric
from gluoncv.utils.metrics.coco_detection import COCODetectionMetric

try:
import horovod.mxnet as hvd
Expand Down Expand Up @@ -282,32 +279,6 @@ def str_args2num_args(arguments, args_name, num_type):
return args


def get_dataset(dataset, args):
if dataset.lower() == 'voc':
train_dataset = gdata.VOCDetection(
splits=[(2007, 'trainval'), (2012, 'trainval')])
val_dataset = gdata.VOCDetection(
splits=[(2007, 'test')])
val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
elif dataset.lower() in ['clipart', 'comic', 'watercolor']:
root = os.path.join('~', '.mxnet', 'datasets', dataset.lower())
train_dataset = gdata.CustomVOCDetection(root=root, splits=[('', 'train')],
generate_classes=True)
val_dataset = gdata.CustomVOCDetection(root=root, splits=[('', 'test')],
generate_classes=True)
val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
elif dataset.lower() == 'coco':
train_dataset = gdata.COCODetection(splits='instances_train2017', use_crowd=False)
val_dataset = gdata.COCODetection(splits='instances_val2017', skip_empty=False)
val_metric = COCODetectionMetric(val_dataset, args.save_prefix + '_eval', cleanup=True)
else:
raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
if args.mixup:
from gluoncv.data.mixup import detection
train_dataset = detection.MixupDetection(train_dataset)
return train_dataset, val_dataset, val_metric


if __name__ == '__main__':
import sys

Expand All @@ -319,10 +290,7 @@ def get_dataset(dataset, args):
if args.amp:
amp.init()

# training data
train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)

frcnn_estimator = FasterRCNNEstimator(args, eval_metric)
frcnn_estimator = FasterRCNNEstimator(args)

# training
frcnn_estimator.fit(train_dataset, val_dataset)
frcnn_estimator.fit()