@@ -10,14 +10,13 @@
 from ...data.transforms.presets.center_net import CenterNetDefaultTrainTransform
 from ...data.transforms.presets.center_net import CenterNetDefaultValTransform, get_post_transform
 from ...data.batchify import Tuple, Stack, Pad
-from ...utils.metrics.accuracy import Accuracy
 from ...utils import LRScheduler, LRSequential
 from ...model_zoo.center_net import get_center_net, get_base_network
 from ...loss import MaskedL1Loss, HeatmapFocalLoss
 
 from ..data.coco_detection import coco_detection, load_coco_detection
-from .base_estimator import BaseEstimator, set_default, DotDict
-from .common import train_hp, valid_hp
+from .base_estimator import BaseEstimator, set_default, ConfigDict
+from .common import train, validation
 
 from sacred import Experiment, Ingredient
 
@@ -41,7 +40,7 @@ def center_net_default():
     center_reg_weight = 1.0  # Center regression loss weight
     data_shape = (512, 512)
 
-@train_hp.config
+@train.config
 def update_train_config():
     gpus = (0, 1, 2, 3, 4, 5, 6, 7)
     pretrained_base = True  # whether load the imagenet pre-trained base
@@ -55,15 +54,15 @@ def update_train_config():
     warmup_epochs = 0  # number of warmup epochs
 
 
-@valid_hp.config
+@validation.config
 def update_valid_config():
     flip_test = True  # use flip in validation test
     nms_thresh = 0  # 0 means disable
     nms_topk = 400  # pre nms topk
     post_nms = 100  # post nms topk
 
 ex = Experiment('center_net_default',
-                ingredients=[coco_detection, train_hp, valid_hp, center_net])
+                ingredients=[coco_detection, train, validation, center_net])
 
 @ex.config
 def default_configs():
@@ -90,15 +89,15 @@ def __init__(self, config, logger=None, reporter=None):
             raise NotImplementedError
 
         # network
-        ctx = [mx.gpu(int(i)) for i in self._cfg.train_hp.gpus]
+        ctx = [mx.gpu(int(i)) for i in self._cfg.train.gpus]
         ctx = ctx if ctx else [mx.cpu()]
         self._ctx = ctx
         net_name = '_'.join(('center_net', self._cfg.center_net.base_network, self._cfg.dataset))
         heads = OrderedDict([
             ('heatmap', {'num_output': train_dataset.num_class, 'bias': self._cfg.center_net.heads.bias}),
             ('wh', {'num_output': self._cfg.center_net.heads.wh_outputs}),
             ('reg', {'num_output': self._cfg.center_net.heads.reg_outputs})])
-        base_network = get_base_network(self._cfg.center_net.base_network, pretrained=self._cfg.train_hp.pretrained_base)
+        base_network = get_base_network(self._cfg.center_net.base_network, pretrained=self._cfg.train.pretrained_base)
         net = get_center_net(self._cfg.center_net.base_network,
                              self._cfg.dataset,
                              base_network=base_network,
@@ -108,8 +107,8 @@ def __init__(self, config, logger=None, reporter=None):
                              scale=self._cfg.center_net.scale,
                              topk=self._cfg.center_net.topk,
                              norm_layer=gluon.nn.BatchNorm)
-        if self._cfg.train_hp.resume.strip():
-            net.load_parameters(self._cfg.train_hp.resume.strip())
+        if self._cfg.train.resume.strip():
+            net.load_parameters(self._cfg.train.resume.strip())
         elif os.path.isfile(os.path.join(self._logdir, 'latest.params')):
             net.load_parameters(os.path.join(self._logdir, 'latest.params'))
         else:
@@ -118,20 +117,20 @@ def __init__(self, config, logger=None, reporter=None):
             net.initialize()
 
         # dataloader
-        batch_size = self._cfg.train_hp.batch_size
+        batch_size = self._cfg.train.batch_size
         width, height = self._cfg.center_net.data_shape
         num_class = len(train_dataset.classes)
         batchify_fn = Tuple([Stack() for _ in range(6)])  # stack image, cls_targets, box_targets
         train_loader = gluon.data.DataLoader(
             train_dataset.transform(CenterNetDefaultTrainTransform(
                 width, height, num_class=num_class, scale_factor=net.scale)),
             batch_size, True, batchify_fn=batchify_fn, last_batch='rollover',
-            num_workers=self._cfg.train_hp.num_workers)
+            num_workers=self._cfg.train.num_workers)
         val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
         val_loader = gluon.data.DataLoader(
             val_dataset.transform(CenterNetDefaultValTransform(width, height)),
-            self._cfg.valid_hp.batch_size, False, batchify_fn=val_batchify_fn, last_batch='keep',
-            num_workers=self._cfg.valid_hp.num_workers)
+            self._cfg.validation.batch_size, False, batchify_fn=val_batchify_fn, last_batch='keep',
+            num_workers=self._cfg.validation.num_workers)
 
         self._train_data = train_loader
         self._val_data = val_loader
@@ -140,25 +139,25 @@ def __init__(self, config, logger=None, reporter=None):
         # trainer
         self._net = net
         self._net.collect_params().reset_ctx(ctx)
-        lr_decay = float(self._cfg.train_hp.lr_decay)
-        lr_steps = sorted(self._cfg.train_hp.lr_decay_epoch)
-        lr_decay_epoch = [e - self._cfg.train_hp.warmup_epochs for e in lr_steps]
-        num_batches = len(train_dataset) // self._cfg.train_hp.batch_size
+        lr_decay = float(self._cfg.train.lr_decay)
+        lr_steps = sorted(self._cfg.train.lr_decay_epoch)
+        lr_decay_epoch = [e - self._cfg.train.warmup_epochs for e in lr_steps]
+        num_batches = len(train_dataset) // self._cfg.train.batch_size
         lr_scheduler = LRSequential([
-            LRScheduler('linear', base_lr=0, target_lr=self._cfg.train_hp.lr,
-                        nepochs=self._cfg.train_hp.warmup_epochs, iters_per_epoch=num_batches),
-            LRScheduler(self._cfg.train_hp.lr_mode, base_lr=self._cfg.train_hp.lr,
-                        nepochs=self._cfg.train_hp.epochs - self._cfg.train_hp.warmup_epochs,
+            LRScheduler('linear', base_lr=0, target_lr=self._cfg.train.lr,
+                        nepochs=self._cfg.train.warmup_epochs, iters_per_epoch=num_batches),
+            LRScheduler(self._cfg.train.lr_mode, base_lr=self._cfg.train.lr,
+                        nepochs=self._cfg.train.epochs - self._cfg.train.warmup_epochs,
                         iters_per_epoch=num_batches,
                         step_epoch=lr_decay_epoch,
-                        step_factor=self._cfg.train_hp.lr_decay, power=2),
+                        step_factor=self._cfg.train.lr_decay, power=2),
         ])
 
         for k, v in self._net.collect_params('.*bias').items():
             v.wd_mult = 0.0
         self._trainer = gluon.Trainer(
             self._net.collect_params(), 'adam',
-            {'learning_rate': self._cfg.train_hp.lr, 'wd': self._cfg.train_hp.wd,
+            {'learning_rate': self._cfg.train.lr, 'wd': self._cfg.train.wd,
              'lr_scheduler': lr_scheduler})
 
         self._save_prefix = os.path.join(self._logdir, net_name)
@@ -172,7 +171,7 @@ def _fit(self):
         wh_metric = mx.metric.Loss('WHL1')
         center_reg_metric = mx.metric.Loss('CenterRegL1')
 
-        for epoch in range(self._cfg.train_hp.start_epoch, self._cfg.train_hp.epochs):
+        for epoch in range(self._cfg.train.start_epoch, self._cfg.train.epochs):
            wh_metric.reset()
            center_reg_metric.reset()
            heatmap_loss_metric.reset()
@@ -183,7 +182,7 @@ def _fit(self):
             for i, batch in enumerate(self._train_data):
                 split_data = [gluon.utils.split_and_load(batch[ind], ctx_list=self._ctx, batch_axis=0) for ind in range(6)]
                 data, heatmap_targets, wh_targets, wh_masks, center_reg_targets, center_reg_masks = split_data
-                batch_size = self._cfg.train_hp.batch_size
+                batch_size = self._cfg.train.batch_size
                 with autograd.record():
                     sum_losses = []
                     heatmap_losses = []
@@ -206,7 +205,7 @@ def _fit(self):
                 heatmap_loss_metric.update(0, heatmap_losses)
                 wh_metric.update(0, wh_losses)
                 center_reg_metric.update(0, center_reg_losses)
-                if self._cfg.train_hp.log_interval and not (i + 1) % self._cfg.train_hp.log_interval:
+                if self._cfg.train.log_interval and not (i + 1) % self._cfg.train.log_interval:
                     name2, loss2 = wh_metric.get()
                     name3, loss3 = center_reg_metric.get()
                     name4, loss4 = heatmap_loss_metric.get()
@@ -219,22 +218,22 @@ def _fit(self):
             name4, loss4 = heatmap_loss_metric.get()
             self._log.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(
                 epoch, (time.time() - tic), name2, loss2, name3, loss3, name4, loss4))
-            if (epoch % self._cfg.valid_hp.interval == 0) or \
-               (self._cfg.train_hp.save_interval and epoch % self._cfg.train_hp.save_interval == 0) or \
-               (epoch == self._cfg.train_hp.epochs - 1):
+            if (epoch % self._cfg.validation.interval == 0) or \
+               (self._cfg.train.save_interval and epoch % self._cfg.train.save_interval == 0) or \
+               (epoch == self._cfg.train.epochs - 1):
                 # consider reduce the frequency of validation to save time
                 map_name, mean_ap = self._evaluate()
                 val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                 self._log.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                 current_map = float(mean_ap[-1])
             else:
                 current_map = 0.
-            save_params(current_map, epoch, self._cfg.train_hp.save_interval, self._save_prefix)
+            save_params(current_map, epoch, self._cfg.train.save_interval, self._save_prefix)
 
     def _evaluate(self):
         """Test on validation dataset."""
         self._eval_metric.reset()
-        self._net.flip_test = self._cfg.valid_hp.flip_test
+        self._net.flip_test = self._cfg.validation.flip_test
         mx.nd.waitall()
         self._net.hybridize()
         for batch in self._val_data:
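
Since train_hp and valid_hp are Sacred ingredients, renaming them to train and validation also renames the config namespaces used for nested access (self._cfg.train.* and self._cfg.validation.* above) and for command-line overrides. A minimal, self-contained sketch of that mechanism follows, assuming a stripped-down experiment with only the two renamed ingredients; the flip_test = True default is taken from the diff, while lr and the override values are purely illustrative.

    from sacred import Experiment, Ingredient

    # Renamed ingredients (formerly `train_hp` and `valid_hp` in .common).
    train = Ingredient('train')
    validation = Ingredient('validation')

    @train.config
    def train_defaults():
        lr = 1.25e-4  # illustrative value, not from the diff

    @validation.config
    def valid_defaults():
        flip_test = True  # use flip in validation test (default shown in the diff)

    ex = Experiment('center_net_default', ingredients=[train, validation])

    @ex.main
    def main(_config):
        # Ingredient configs are nested under the new ingredient names,
        # i.e. _config['train']['lr'] rather than _config['train_hp']['lr'].
        print(_config['train']['lr'], _config['validation']['flip_test'])

    if __name__ == '__main__':
        # Command-line overrides follow the same renamed paths, e.g.
        #   python example.py with train.lr=5e-4 validation.flip_test=False
        ex.run(config_updates={'train': {'lr': 5e-4}})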