 from gluoncv.data import transforms as gcv_transforms
 from gluoncv.utils import makedirs, TrainingHistory
 from gluoncv.model_zoo import get_model
+from gluoncv.utils import makedirs, LRSequential, LRScheduler
 import gluoncv as gcv
 from mxnet.gluon.data.vision import transforms
 from mxnet.gluon import nn
@@ -32,20 +33,24 @@ def parse_args():
                         help='model to use. options are resnet and wrn. default is resnet.')
     parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int,
                         help='number of preprocessing workers')
-    parser.add_argument('--num-epochs', type=int, default=3,
+    parser.add_argument('--num-epochs', type=int, default=200,
                         help='number of training epochs.')
     parser.add_argument('--lr', type=float, default=0.1,
                         help='learning rate. default is 0.1.')
     parser.add_argument('--momentum', type=float, default=0.9,
                         help='momentum value for optimizer, default is 0.9.')
-    parser.add_argument('--wd', type=float, default=0.0001,
-                        help='weight decay rate. default is 0.0001.')
+    parser.add_argument('--wd', type=float, default=0.0005,
+                        help='weight decay rate. default is 0.0005.')
     parser.add_argument('--lr-decay', type=float, default=0.1,
                         help='decay rate of learning rate. default is 0.1.')
     parser.add_argument('--lr-decay-period', type=int, default=0,
                         help='period in epoch for learning rate decays. default is 0 (has no effect).')
-    parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
-                        help='epochs at which learning rate decays. default is 40,60.')
+    parser.add_argument('--lr-decay-epoch', type=str, default='100,150',
+                        help='epochs at which learning rate decays. default is 100,150.')
+    parser.add_argument('--warmup-lr', type=float, default=0.0,
+                        help='starting warmup learning rate. default is 0.0.')
+    parser.add_argument('--warmup-epochs', type=int, default=0,
+                        help='number of warmup epochs.')
     parser.add_argument('--drop-rate', type=float, default=0.0,
                         help='dropout rate for wide resnet. default is 0.')
     parser.add_argument('--mode', type=str,
@@ -63,14 +68,16 @@ def parse_args():
     # additional arguments for gradient compression
     parser.add_argument('--compressor', type=str, default='',
                         help='which compressor')
-    parser.add_argument('--ef', type=str, default=None,
-                        help='enable error-feedback')
+    parser.add_argument('--ef', type=str, default='',
+                        help='which error-feedback')
+    parser.add_argument('--compress-momentum', type=str, default='',
+                        help='which compress momentum')
     parser.add_argument('--onebit-scaling', action='store_true', default=False,
                         help='enable scaling for onebit compressor')
+    parser.add_argument('--k', default=1, type=int,
+                        help='topk or randomk')
     parser.add_argument('--fp16-pushpull', action='store_true', default=False,
                         help='use fp16 compression during pushpull')
-    parser.add_argument('--compress-momentum', action='store_true', default=False,
-                        help='enable compress momentum.')
     opt = parser.parse_args()
     return opt

@@ -104,6 +111,17 @@ def main():
     lr_decay = opt.lr_decay
     lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

+    num_batches = 50000 // (opt.batch_size * nworker)
+    lr_scheduler = LRSequential([
+        LRScheduler('linear', base_lr=opt.warmup_lr, target_lr=opt.lr * nworker / bps.local_size(),
+                    nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
+        LRScheduler('step', base_lr=opt.lr * nworker / bps.local_size(), target_lr=0,
+                    nepochs=opt.num_epochs - opt.warmup_epochs,
+                    iters_per_epoch=num_batches,
+                    step_epoch=lr_decay_epoch,
+                    step_factor=lr_decay, power=2)
+    ])
+
     model_name = opt.model
     if model_name.startswith('cifar_wideresnet'):
         kwargs = {'classes': classes,
@@ -113,7 +131,11 @@ def main():
     net = get_model(model_name, **kwargs)
     if opt.resume_from:
         net.load_parameters(opt.resume_from, ctx=context)
-    optimizer = 'sgd'
+
+    if opt.compressor:
+        optimizer = 'sgd'
+    else:
+        optimizer = 'nag'

     save_period = opt.save_period
     if opt.save_dir and save_period:
@@ -166,34 +188,26 @@ def train(epochs, ctx):
         batch_size=batch_size, shuffle=False, num_workers=num_workers)

     params = net.collect_params()
-    if opt.compressor:
-        for _, param in params.items():
-            setattr(param, "byteps_compressor_type", opt.compressor)
-            if opt.ef:
-                setattr(param, "byteps_error_feedback_type", opt.ef)
-            if opt.onebit_scaling:
-                setattr(
-                    param, "byteps_compressor_onebit_enable_scale", opt.onebit_scaling)
-            if opt.compress_momentum:
-                setattr(param, "byteps_momentum_type", "nesterov")
-                setattr(param, "byteps_momentum_mu", opt.momentum)
-
-    optimizer_params = {'learning_rate': opt.lr *
-                        nworker, 'wd': opt.wd, 'momentum': opt.momentum}
-    if opt.compress_momentum:
-        del optimizer_params["momentum"]
-
-    compression = bps.Compression.fp16 if opt.fp16_pushpull else bps.Compression.none
+
+    compression_params = {
+        "compressor": opt.compressor,
+        "ef": opt.ef,
+        "momentum": opt.compress_momentum,
+        "scaling": opt.onebit_scaling,
+        "k": opt.k
+    }
+
+    optimizer_params = {'lr_scheduler': lr_scheduler,
+                        'wd': opt.wd, 'momentum': opt.momentum}
+
     trainer = bps.DistributedTrainer(params,
-                                     optimizer, optimizer_params, compression=compression)
+                                     optimizer, optimizer_params, compression_params=compression_params)
     metric = mx.metric.Accuracy()
     train_metric = mx.metric.Accuracy()
     loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
     train_history = TrainingHistory(['training-error', 'validation-error'])

     iteration = 0
-    lr_decay_count = 0
-
     best_val_score = 0

     for epoch in range(epochs):
@@ -202,11 +216,6 @@ def train(epochs, ctx):
         metric.reset()
         train_loss = 0
         num_batch = len(train_data)
-        alpha = 1
-
-        if epoch == lr_decay_epoch[lr_decay_count]:
-            trainer.set_learning_rate(trainer.learning_rate * lr_decay)
-            lr_decay_count += 1

         for i, batch in enumerate(train_data):
             data = gluon.utils.split_and_load(
@@ -230,16 +239,15 @@ def train(epochs, ctx):
         name, acc = train_metric.get()
         throughput = int(batch_size * nworker * i / (time.time() - tic))

-        if rank == 0:
-            logger.info('[Epoch %d] training: %s=%f' %
-                        (epoch, name, acc))
-            logger.info('[Epoch %d] speed: %d samples/sec\t time cost: %f' %
-                        (epoch, throughput, time.time()-tic))
+        logger.info('[Epoch %d] training: %s=%f' %
+                    (epoch, name, acc))
+        logger.info('[Epoch %d] speed: %d samples/sec\t time cost: %f lr=%f' %
+                    (epoch, throughput, time.time()-tic, trainer.learning_rate))

         name, val_acc = test(ctx, val_data)
-        if rank == 0:
-            logger.info('[Epoch %d] validation: %s=%f' %
-                        (epoch, name, val_acc))
+
+        logger.info('[Epoch %d] validation: %s=%f' %
+                    (epoch, name, val_acc))

         train_history.update([1 - acc, 1 - val_acc])
         train_history.plot(save_path='%s/%s_history.png' %
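
Note: the change above replaces the old per-parameter setattr attributes and the Compression enum with a single compression_params dict passed to bps.DistributedTrainer. Below is a minimal usage sketch of that new interface. The dict keys ("compressor", "ef", "momentum", "scaling", "k") mirror the diff; the concrete values "onebit", "vanilla", "nesterov", the stand-in Dense model, and the plain learning_rate optimizer settings are illustrative assumptions rather than part of this change.

import mxnet as mx
from mxnet.gluon import nn
import byteps.mxnet as bps

bps.init()

net = nn.Dense(10)                      # stand-in for the CIFAR model built in main()
net.initialize(mx.init.Xavier())

compression_params = {
    "compressor": "onebit",             # which gradient compressor to apply (assumed value)
    "ef": "vanilla",                    # error-feedback variant; '' disables it (assumed value)
    "momentum": "nesterov",             # compressor-side momentum; '' disables it (assumed value)
    "scaling": True,                    # enable scaling for the onebit compressor
    "k": 1,                             # k used by topk / randomk compressors
}

trainer = bps.DistributedTrainer(
    net.collect_params(),
    "sgd",
    {"learning_rate": 0.1, "wd": 0.0005, "momentum": 0.9},
    compression_params=compression_params,
)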