
Commit ee9719a

Estimator rcnn (#1366)
* move rcnn forward backward task to model zoo
* revert #1249
* fix
* fix
* docstring
* fix style
* add docs
* faster rcnn estimator
* refactor
* move dataset to init
* lint
* merge
* disable sacred config for now
* logger fix
* fix fit
* autogluon integration
* fix small bug. training working
* lint
* sacred config for faster rcnn
* add config docs
* move all logging into base estimator logdir
1 parent ad4ff10 commit ee9719a

6 files changed: +197 −185 lines changed

gluoncv/auto/estimators/base_estimator.py

Lines changed: 13 additions & 12 deletions
@@ -47,7 +47,7 @@ def _apply(cls):
     return _apply
 
 
-class DotDict(dict):
+class ConfigDict(dict):
     MARKER = object()
     """The view of a config dict where keys can be accessed like attribute, it also prevents
     naive modifications to the key-values.
@@ -72,6 +72,7 @@ def __init__(self, value=None):
                 self.__setitem__(key, value[key])
         else:
             raise TypeError('expected dict')
+        self.freeze()
 
     def freeze(self):
         self.__dict__['_freeze'] = True
@@ -87,16 +88,16 @@ def __setitem__(self, key, value):
             msg = ('You are trying to modify the config to "{}={}" after initialization, '
                    ' this may result in unpredictable behaviour'.format(key, value))
             warnings.warn(msg)
-        if isinstance(value, dict) and not isinstance(value, DotDict):
-            value = DotDict(value)
-        super(DotDict, self).__setitem__(key, value)
+        if isinstance(value, dict) and not isinstance(value, ConfigDict):
+            value = ConfigDict(value)
+        super(ConfigDict, self).__setitem__(key, value)
 
     def __getitem__(self, key):
-        found = self.get(key, DotDict.MARKER)
-        if found is DotDict.MARKER:
-            found = DotDict()
-            super(DotDict, self).__setitem__(key, found)
-        if isinstance(found, DotDict):
+        found = self.get(key, ConfigDict.MARKER)
+        if found is ConfigDict.MARKER:
+            found = ConfigDict()
+            super(ConfigDict, self).__setitem__(key, found)
+        if isinstance(found, ConfigDict):
             found.__dict__['_freeze'] = self.__dict__['_freeze']
         return found
 
@@ -120,7 +121,7 @@ def __init__(self, config, logger=None, reporter=None, name=None):
 
         # try to auto resume
         prefix = None
-        if r.config.get('train_hp', {}).get('auto_resume', True):
+        if r.config.get('train', {}).get('auto_resume', True):
             exists = [d for d in os.listdir(self._logdir) if d.startswith(name)]
             # latest timestamp
             exists = sorted(exists)
@@ -129,7 +130,7 @@ def __init__(self, config, logger=None, reporter=None, name=None):
         if prefix:
             self._ex.add_config(os.path.join(self._logdir, prefix, 'config.yaml'))
             r2 = self._ex.run('_get_config', options={'--loglevel': 50})
-            if _compare_config(r2.config, r.config):
+            if _compare_config(r2.config, r.config):
                 self._logger.info('Auto resume detected previous run: {}'.format(prefix))
                 r.config['seed'] = r2.config['seed']
         else:
@@ -147,7 +148,7 @@ def __init__(self, config, logger=None, reporter=None, name=None):
             save_config(r.config, self._logger, config_file)
 
         # dot access for config
-        self._cfg = DotDict(r.config)
+        self._cfg = ConfigDict(r.config)
         self._cfg.freeze()
         _random.seed(self._cfg.seed)

gluoncv/auto/estimators/center_net.py

Lines changed: 31 additions & 32 deletions
@@ -10,14 +10,13 @@
 from ...data.transforms.presets.center_net import CenterNetDefaultTrainTransform
 from ...data.transforms.presets.center_net import CenterNetDefaultValTransform, get_post_transform
 from ...data.batchify import Tuple, Stack, Pad
-from ...utils.metrics.accuracy import Accuracy
 from ...utils import LRScheduler, LRSequential
 from ...model_zoo.center_net import get_center_net, get_base_network
 from ...loss import MaskedL1Loss, HeatmapFocalLoss
 
 from ..data.coco_detection import coco_detection, load_coco_detection
-from .base_estimator import BaseEstimator, set_default, DotDict
-from .common import train_hp, valid_hp
+from .base_estimator import BaseEstimator, set_default, ConfigDict
+from .common import train, validation
 
 from sacred import Experiment, Ingredient
 
@@ -41,7 +40,7 @@ def center_net_default():
     center_reg_weight = 1.0  # Center regression loss weight
     data_shape = (512, 512)
 
-@train_hp.config
+@train.config
 def update_train_config():
     gpus = (0, 1, 2, 3, 4, 5, 6, 7)
     pretrained_base = True  # whether load the imagenet pre-trained base
@@ -55,15 +54,15 @@ def update_train_config():
     warmup_epochs = 0  # number of warmup epochs
 
 
-@valid_hp.config
+@validation.config
 def update_valid_config():
     flip_test = True  # use flip in validation test
     nms_thresh = 0  # 0 means disable
     nms_topk = 400  # pre nms topk
     post_nms = 100  # post nms topk
 
 ex = Experiment('center_net_default',
-                ingredients=[coco_detection, train_hp, valid_hp, center_net])
+                ingredients=[coco_detection, train, validation, center_net])
 
 @ex.config
 def default_configs():
@@ -90,15 +89,15 @@ def __init__(self, config, logger=None, reporter=None):
             raise NotImplementedError
 
         # network
-        ctx = [mx.gpu(int(i)) for i in self._cfg.train_hp.gpus]
+        ctx = [mx.gpu(int(i)) for i in self._cfg.train.gpus]
         ctx = ctx if ctx else [mx.cpu()]
         self._ctx = ctx
         net_name = '_'.join(('center_net', self._cfg.center_net.base_network, self._cfg.dataset))
         heads = OrderedDict([
             ('heatmap', {'num_output': train_dataset.num_class, 'bias': self._cfg.center_net.heads.bias}),
             ('wh', {'num_output': self._cfg.center_net.heads.wh_outputs}),
             ('reg', {'num_output': self._cfg.center_net.heads.reg_outputs})])
-        base_network = get_base_network(self._cfg.center_net.base_network, pretrained=self._cfg.train_hp.pretrained_base)
+        base_network = get_base_network(self._cfg.center_net.base_network, pretrained=self._cfg.train.pretrained_base)
         net = get_center_net(self._cfg.center_net.base_network,
                              self._cfg.dataset,
                              base_network=base_network,
@@ -108,8 +107,8 @@ def __init__(self, config, logger=None, reporter=None):
                              scale=self._cfg.center_net.scale,
                              topk=self._cfg.center_net.topk,
                              norm_layer=gluon.nn.BatchNorm)
-        if self._cfg.train_hp.resume.strip():
-            net.load_parameters(self._cfg.train_hp.resume.strip())
+        if self._cfg.train.resume.strip():
+            net.load_parameters(self._cfg.train.resume.strip())
         elif os.path.isfile(os.path.join(self._logdir, 'latest.params')):
             net.load_parameters(os.path.join(self._logdir, 'latest.params'))
         else:
@@ -118,20 +117,20 @@ def __init__(self, config, logger=None, reporter=None):
             net.initialize()
 
         # dataloader
-        batch_size = self._cfg.train_hp.batch_size
+        batch_size = self._cfg.train.batch_size
         width, height = self._cfg.center_net.data_shape
         num_class = len(train_dataset.classes)
         batchify_fn = Tuple([Stack() for _ in range(6)])  # stack image, cls_targets, box_targets
         train_loader = gluon.data.DataLoader(
             train_dataset.transform(CenterNetDefaultTrainTransform(
                 width, height, num_class=num_class, scale_factor=net.scale)),
             batch_size, True, batchify_fn=batchify_fn, last_batch='rollover',
-            num_workers=self._cfg.train_hp.num_workers)
+            num_workers=self._cfg.train.num_workers)
         val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
         val_loader = gluon.data.DataLoader(
             val_dataset.transform(CenterNetDefaultValTransform(width, height)),
-            self._cfg.valid_hp.batch_size, False, batchify_fn=val_batchify_fn, last_batch='keep',
-            num_workers=self._cfg.valid_hp.num_workers)
+            self._cfg.validation.batch_size, False, batchify_fn=val_batchify_fn, last_batch='keep',
+            num_workers=self._cfg.validation.num_workers)
 
         self._train_data = train_loader
         self._val_data = val_loader
@@ -140,25 +139,25 @@ def __init__(self, config, logger=None, reporter=None):
         # trainer
         self._net = net
         self._net.collect_params().reset_ctx(ctx)
-        lr_decay = float(self._cfg.train_hp.lr_decay)
-        lr_steps = sorted(self._cfg.train_hp.lr_decay_epoch)
-        lr_decay_epoch = [e - self._cfg.train_hp.warmup_epochs for e in lr_steps]
-        num_batches = len(train_dataset) // self._cfg.train_hp.batch_size
+        lr_decay = float(self._cfg.train.lr_decay)
+        lr_steps = sorted(self._cfg.train.lr_decay_epoch)
+        lr_decay_epoch = [e - self._cfg.train.warmup_epochs for e in lr_steps]
+        num_batches = len(train_dataset) // self._cfg.train.batch_size
         lr_scheduler = LRSequential([
-            LRScheduler('linear', base_lr=0, target_lr=self._cfg.train_hp.lr,
-                        nepochs=self._cfg.train_hp.warmup_epochs, iters_per_epoch=num_batches),
-            LRScheduler(self._cfg.train_hp.lr_mode, base_lr=self._cfg.train_hp.lr,
-                        nepochs=self._cfg.train_hp.epochs - self._cfg.train_hp.warmup_epochs,
+            LRScheduler('linear', base_lr=0, target_lr=self._cfg.train.lr,
+                        nepochs=self._cfg.train.warmup_epochs, iters_per_epoch=num_batches),
+            LRScheduler(self._cfg.train.lr_mode, base_lr=self._cfg.train.lr,
+                        nepochs=self._cfg.train.epochs - self._cfg.train.warmup_epochs,
                         iters_per_epoch=num_batches,
                         step_epoch=lr_decay_epoch,
-                        step_factor=self._cfg.train_hp.lr_decay, power=2),
+                        step_factor=self._cfg.train.lr_decay, power=2),
         ])
 
         for k, v in self._net.collect_params('.*bias').items():
             v.wd_mult = 0.0
         self._trainer = gluon.Trainer(
             self._net.collect_params(), 'adam',
-            {'learning_rate': self._cfg.train_hp.lr, 'wd': self._cfg.train_hp.wd,
+            {'learning_rate': self._cfg.train.lr, 'wd': self._cfg.train.wd,
              'lr_scheduler': lr_scheduler})
 
         self._save_prefix = os.path.join(self._logdir, net_name)
@@ -172,7 +171,7 @@ def _fit(self):
         wh_metric = mx.metric.Loss('WHL1')
         center_reg_metric = mx.metric.Loss('CenterRegL1')
 
-        for epoch in range(self._cfg.train_hp.start_epoch, self._cfg.train_hp.epochs):
+        for epoch in range(self._cfg.train.start_epoch, self._cfg.train.epochs):
            wh_metric.reset()
            center_reg_metric.reset()
            heatmap_loss_metric.reset()
@@ -183,7 +182,7 @@ def _fit(self):
            for i, batch in enumerate(self._train_data):
                split_data = [gluon.utils.split_and_load(batch[ind], ctx_list=self._ctx, batch_axis=0) for ind in range(6)]
                data, heatmap_targets, wh_targets, wh_masks, center_reg_targets, center_reg_masks = split_data
-                batch_size = self._cfg.train_hp.batch_size
+                batch_size = self._cfg.train.batch_size
                with autograd.record():
                    sum_losses = []
                    heatmap_losses = []
@@ -206,7 +205,7 @@ def _fit(self):
                heatmap_loss_metric.update(0, heatmap_losses)
                wh_metric.update(0, wh_losses)
                center_reg_metric.update(0, center_reg_losses)
-                if self._cfg.train_hp.log_interval and not (i + 1) % self._cfg.train_hp.log_interval:
+                if self._cfg.train.log_interval and not (i + 1) % self._cfg.train.log_interval:
                    name2, loss2 = wh_metric.get()
                    name3, loss3 = center_reg_metric.get()
                    name4, loss4 = heatmap_loss_metric.get()
@@ -219,22 +218,22 @@ def _fit(self):
            name4, loss4 = heatmap_loss_metric.get()
            self._log.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(
                epoch, (time.time()-tic), name2, loss2, name3, loss3, name4, loss4))
-            if (epoch % self._cfg.valid_hp.interval == 0) or \
-               (self._cfg.train_hp.save_interval and epoch % self._cfg.train_hp.save_interval == 0) or \
-               (epoch == self._cfg.train_hp.epochs - 1):
+            if (epoch % self._cfg.validation.interval == 0) or \
+               (self._cfg.train.save_interval and epoch % self._cfg.train.save_interval == 0) or \
+               (epoch == self._cfg.train.epochs - 1):
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = self._evaluate()
                val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                self._log.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
-            save_params(current_map, epoch, self._cfg.train_hp.save_interval, self._save_prefix)
+            save_params(current_map, epoch, self._cfg.train.save_interval, self._save_prefix)
 
     def _evaluate(self):
         """Test on validation dataset."""
         self._eval_metric.reset()
-        self._net.flip_test = self._cfg.valid_hp.flip_test
+        self._net.flip_test = self._cfg.validation.flip_test
         mx.nd.waitall()
         self._net.hybridize()
         for batch in self._val_data:
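For readers unfamiliar with the batchify helpers used in the dataloader section above, here is a small self-contained sketch of the `Tuple(Stack(), Pad(pad_val=-1))` pattern; the shapes and values below are illustrative only, not taken from this commit:

```python
# Sketch of the val_batchify_fn pattern: images of identical shape are stacked,
# while per-image labels with a variable number of boxes are padded to a common
# length with -1 so they can form a single batch array.
import mxnet as mx
from gluoncv.data.batchify import Tuple, Stack, Pad

val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))

sample_a = (mx.nd.zeros((3, 512, 512)), mx.nd.ones((2, 5)))  # 2 ground-truth boxes
sample_b = (mx.nd.zeros((3, 512, 512)), mx.nd.ones((5, 5)))  # 5 ground-truth boxes

images, labels = val_batchify_fn([sample_a, sample_b])
print(images.shape)  # (2, 3, 512, 512)
print(labels.shape)  # (2, 5, 5); the shorter label array is padded with -1
```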

gluoncv/auto/estimators/common.py

Lines changed: 17 additions & 14 deletions
@@ -1,26 +1,29 @@
 """Common training hyperparameters"""
 from sacred import Ingredient
 
-train_hp = Ingredient('train_hp')
+train = Ingredient('train')
 
-@train_hp.config
+
+@train.config
 def cfg():
-    gpus = (0, 1, 2, 3) # gpu individual ids, not necessarily consecutive
-    num_workers = 16 # cpu workers, the larger the more processes used
+    gpus = (0, 1, 2, 3)  # gpu individual ids, not necessarily consecutive
+    num_workers = 16  # cpu workers, the larger the more processes used
     batch_size = 32
     epochs = 3
     resume = ''
-    auto_resume = True # try to automatically resume last trial if config is default
+    auto_resume = True  # try to automatically resume last trial if config is default
     start_epoch = 0
-    momentum = 0.9 # SGD momentum
-    wd = 1e-4 # weight decay
-    save_interval = 10 # Saving parameters epoch interval, best model will always be saved
-    log_interval = 100 # logging interval
+    momentum = 0.9  # SGD momentum
+    wd = 1e-4  # weight decay
+    save_interval = 10  # Saving parameters epoch interval, best model will always be saved
+    log_interval = 100  # logging interval
+
+
+validation = Ingredient('validation')
 
-valid_hp = Ingredient('valid_hp')
 
-@valid_hp.config
+@validation.config
 def cfg():
-    num_workers = 32 # cpu workers, the larger the more processes used
-    batch_size = 32 # validation batch size
-    interval = 10 # validation epoch interval, for slow validations
+    num_workers = 32  # cpu workers, the larger the more processes used
+    batch_size = 32  # validation batch size
+    interval = 10  # validation epoch interval, for slow validations
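The rename matters for anyone composing these ingredients into their own sacred Experiment: config sections are keyed by the ingredient name, so overrides now target 'train' and 'validation' instead of 'train_hp' and 'valid_hp'. A minimal standalone sketch (the values mirror common.py, but the Experiment below is only for illustration, not the estimators' actual wiring):

```python
# Standalone sketch of how ingredient names become config section names in sacred.
from sacred import Experiment, Ingredient

train = Ingredient('train')
validation = Ingredient('validation')

@train.config
def train_cfg():
    batch_size = 32
    epochs = 3

@validation.config
def validation_cfg():
    batch_size = 32
    interval = 10

ex = Experiment('demo', ingredients=[train, validation])

@ex.main
def main(_config):
    # sections are named after the ingredients, e.g. _config['train']['batch_size']
    print(_config['train']['batch_size'], _config['validation']['interval'])

# overriding a nested value targets the new section names
ex.run(config_updates={'train': {'batch_size': 16}})
```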

gluoncv/auto/estimators/rcnn/default.py

Lines changed: 3 additions & 1 deletion
@@ -89,7 +89,7 @@ def train_cfg():
     # Whether load the imagenet pre-trained base
     pretrained_base = True
     # Batch size during training
-    batch_size = 16
+    batch_size = 8
     # starting epoch
     start_epoch = 0
     # total epoch for training
@@ -195,3 +195,5 @@ def default_configs():
     kv_store = 'nccl'
     # Whether to disable hybridize the model. Memory usage and speed will decrese.
     disable_hybridization = False
+    # Output directory for all training/validation artifacts.
+    logdir = None
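A hypothetical override sketch for the two defaults touched above, assuming `batch_size` lives under the train section and `logdir` at the top level as in the defaults file; leaving `logdir` as None falls back to the estimator's auto-generated log directory:

```python
# Hypothetical config override for the Faster R-CNN estimator defaults changed here.
config_updates = {
    'train': {'batch_size': 8},           # new default introduced by this commit
    'logdir': './faster_rcnn_artifacts',  # new top-level key for training/validation artifacts
}
```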
