|
| 1 | +import logging |
| 2 | + |
| 3 | +import autogluon as ag |
| 4 | +from autogluon.core.decorator import sample_config |
| 5 | +from autogluon.scheduler.resource import get_cpu_count, get_gpu_count |
| 6 | +from autogluon.task import BaseTask |
| 7 | +from autogluon.utils import collect_params |
| 8 | + |
| 9 | +from ..estimators.rcnn import FasterRCNNEstimator |
| 10 | +from ... import utils as gutils |
| 11 | + |
| 12 | +__all__ = ['ObjectDetection'] |
| 13 | + |
| 14 | + |
@ag.args()
def _train_object_detection(args, reporter):
    """Train a single object-detection trial.

    Parameters
    ----------
    args
        Trial configuration injected through ``ag.args``. This function reads
        ``seed``, ``meta_arch``, ``net``, ``data_shape``, ``batch_size``,
        ``num_gpus`` and ``final_fit`` from it, and writes the fixed
        Faster R-CNN settings below back into it before building the estimator.
    reporter
        Passed through unchanged to ``FasterRCNNEstimator``.

    Returns
    -------
    dict or None
        ``{'model_params': ...}`` collected from the trained network when
        ``args.final_fit`` is truthy; otherwise ``None``.

    Raises
    ------
    NotImplementedError
        If ``args.meta_arch`` is anything other than ``'faster_rcnn'``.
    """
    # fix seed for mxnet, numpy and python builtin random generator.
    gutils.random.seed(args.seed)

    # Only Faster R-CNN has an estimator wired up here; reject everything else
    # (including 'yolo3') before touching the arguments.
    if args.meta_arch != 'faster_rcnn':
        raise NotImplementedError('%s is not implemented.' % args.meta_arch)

    # A CPU-only trial has num_gpus == 0; treat it as a single device so the
    # per-device batch size does not divide by zero.
    num_devices = max(1, args.num_gpus)

    # Fixed Faster R-CNN (FPN) settings. These override any same-named
    # attribute already present on ``args``.
    # NOTE(review): 'kv_store' is hard-coded to 'nccl' -- confirm this is
    # acceptable for CPU-only or single-GPU trials.
    kwargs = {
        'network': args.net, 'base_network_name': args.net,
        'image_short': args.data_shape, 'max_size': 1000, 'nms_thresh': 0.5,
        'nms_topk': -1, 'min_stage': 2, 'max_stage': 6, 'post_nms': -1,
        'roi_mode': 'align', 'roi_size': (7, 7), 'strides': (4, 8, 16, 32, 64),
        'clip': 4.14, 'rpn_channel': 256, 'anchor_scales': (2, 4, 8, 16, 32),
        'anchor_aspect_ratio': (0.5, 1, 2), 'anchor_alloc_size': (384, 384),
        'rpn_nms_thresh': 0.7, 'rpn_train_pre_nms': 12000, 'rpn_train_post_nms': 2000,
        'rpn_test_pre_nms': 6000, 'rpn_test_post_nms': 1000, 'rpn_min_size': 1,
        'per_device_batch_size': args.batch_size // num_devices, 'num_sample': 512,
        'rcnn_pos_iou_thresh': 0.5, 'rcnn_pos_ratio': 0.25, 'max_num_gt': 100,
        'custom_model': True, 'no_pretrained_base': True, 'num_fpn_filters': 256,
        'num_box_head_conv': 4, 'num_box_head_conv_filters': 256, 'amp': False,
        'num_box_head_dense_filters': 1024, 'image_max_size': 1333, 'kv_store': 'nccl',
        'anchor_base_size': 16, 'rcnn_num_samples': 512, 'rpn_smoothl1_rho': 0.001,
        'rcnn_smoothl1_rho': 0.001, 'lr_warmup_factor': 1. / 3., 'lr_warmup': 500,
        'executor_threads': 4, 'disable_hybridization': False, 'static_alloc': False,
    }
    vars(args).update(kwargs)

    estimator = FasterRCNNEstimator(args, reporter=reporter)

    # training
    estimator.fit()

    if args.final_fit:
        return {'model_params': collect_params(estimator.net)}
    return None
| 55 | + |
| 56 | + |
class ObjectDetection(BaseTask):
    """Hyper-parameter search task for object detection.

    Registers the search arguments for ``_train_object_detection`` and prepares
    the scheduler options from ``config``; ``fit`` then runs the search and
    returns an estimator loaded with the best parameters found.

    Parameters
    ----------
    config
        Configuration object exposing the attributes read in ``__init__``
        (e.g. ``meta_arch``, ``net``, ``lr``, ``epochs``, ``ngpus_per_trial``,
        ``nthreads_per_trial``, ``search_strategy`` ...). A
        ``scheduler_options`` attribute is written onto it.
    logger : logging.Logger, optional
        Logger to use; defaults to ``logging.getLogger(__name__)``.
    """

    def __init__(self, config, logger=None):
        super(ObjectDetection, self).__init__()
        self._logger = logger if logger is not None else logging.getLogger(__name__)
        self._config = config

        # Clamp the requested per-trial resources to what is available.
        nthreads_per_trial = min(self._config.nthreads_per_trial, get_cpu_count())
        gpu_count = get_gpu_count()
        ngpus_per_trial = self._config.ngpus_per_trial
        if ngpus_per_trial > gpu_count:
            self._logger.warning(
                "The number of requested GPUs is greater than the number of available GPUs.")
            ngpus_per_trial = gpu_count

        # Pass the *clamped* resource counts to the training function so that
        # it agrees with the resources requested from the scheduler below.
        # NOTE(review): ``gpus`` is hard-coded to eight device ids regardless
        # of ``ngpus_per_trial`` -- confirm how the estimator consumes it.
        _train_object_detection.register_args(
            meta_arch=self._config.meta_arch, dataset=self._config.dataset, net=self._config.net,
            lr=self._config.lr, loss=self._config.loss, num_gpus=ngpus_per_trial,
            batch_size=self._config.batch_size, split_ratio=self._config.split_ratio,
            epochs=self._config.epochs, num_workers=nthreads_per_trial,
            hybridize=self._config.hybridize, verbose=self._config.verbose, final_fit=False,
            seed=self._config.seed, data_shape=self._config.data_shape, start_epoch=0,
            transfer=self._config.transfer, lr_mode=self._config.lr_mode,
            lr_decay=self._config.lr_decay, lr_decay_period=self._config.lr_decay_period,
            lr_decay_epoch=self._config.lr_decay_epoch, warmup_lr=self._config.warmup_lr,
            warmup_epochs=self._config.warmup_epochs, warmup_iters=self._config.warmup_iters,
            warmup_factor=self._config.warmup_factor, momentum=self._config.momentum,
            wd=self._config.wd, log_interval=self._config.log_interval,
            save_prefix=self._config.save_prefix, save_interval=self._config.save_interval,
            val_interval=self._config.val_interval, num_samples=self._config.num_samples,
            no_random_shape=self._config.no_random_shape, no_wd=self._config.no_wd,
            mixup=self._config.mixup, no_mixup_epochs=self._config.no_mixup_epochs,
            label_smooth=self._config.label_smooth, resume=self._config.resume,
            syncbn=self._config.syncbn, reuse_pred_weights=self._config.reuse_pred_weights,
            horovod=self._config.horovod, gpus='0,1,2,3,4,5,6,7', use_fpn=True,
            norm_layer='syncbn' if self._config.syncbn else None,
        )

        self._config.scheduler_options = {
            'resource': {'num_cpus': nthreads_per_trial, 'num_gpus': ngpus_per_trial},
            'checkpoint': self._config.checkpoint,
            'num_trials': self._config.num_trials,
            'time_out': self._config.time_limits,
            'resume': self._config.resume,
            'visualizer': self._config.visualizer,
            'time_attr': 'epoch',
            'reward_attr': 'map_reward',
            'dist_ip_addrs': self._config.dist_ip_addrs,
            'searcher': self._config.search_strategy,
            'search_options': self._config.search_options,
        }
        if self._config.search_strategy == 'hyperband':
            # Fall back to a quarter of the epochs when no grace period is
            # configured; floor at 1 so very short runs (epochs < 4) do not
            # yield 0 (assumed invalid for the scheduler -- TODO confirm).
            grace_period = self._config.grace_period or max(1, self._config.epochs // 4)
            self._config.scheduler_options.update({
                'searcher': 'random',
                'max_t': self._config.epochs,
                'grace_period': grace_period})

    def fit(self):
        """Run the search and return an estimator with the best parameters.

        Calls ``run_fit`` with ``_train_object_detection``, the configured
        search strategy and the scheduler options, then builds a
        ``FasterRCNNEstimator`` from the best configuration and loads the
        ``'model_params'`` entry of the results into it.

        Returns
        -------
        FasterRCNNEstimator
            Estimator constructed from the best config with parameters loaded.
        """
        results = self.run_fit(_train_object_detection, self._config.search_strategy,
                               self._config.scheduler_options)
        self._logger.info(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> finish model fitting")
        best_config = sample_config(_train_object_detection.args, results['best_config'])
        self._logger.info('The best config: %s', results['best_config'])

        estimator = FasterRCNNEstimator(best_config)
        estimator.load_parameters(results.pop('model_params'))
        return estimator
0 commit comments