Commit cdeb2a4

compression: update cifar100 training script (#15)
* cifar: update cifar script
* cifar: update lr
* cifar: add warmup
* cifar: update parse
* cifar: update
* cifar: add log
* cifar: fix typo
* cifar: fix bug
* cifar: fix lr
* cifar: fix typo
* cifar: update num samples
1 parent 9cef619 commit cdeb2a4

File tree

1 file changed: +52 -44 lines

example/mxnet/train_cifar100_byteps_gc.py

Lines changed: 52 additions & 44 deletions
@@ -5,6 +5,7 @@
 from gluoncv.data import transforms as gcv_transforms
 from gluoncv.utils import makedirs, TrainingHistory
 from gluoncv.model_zoo import get_model
+from gluoncv.utils import makedirs, LRSequential, LRScheduler
 import gluoncv as gcv
 from mxnet.gluon.data.vision import transforms
 from mxnet.gluon import nn
@@ -32,20 +33,24 @@ def parse_args():
                         help='model to use. options are resnet and wrn. default is resnet.')
     parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int,
                         help='number of preprocessing workers')
-    parser.add_argument('--num-epochs', type=int, default=3,
+    parser.add_argument('--num-epochs', type=int, default=200,
                         help='number of training epochs.')
     parser.add_argument('--lr', type=float, default=0.1,
                         help='learning rate. default is 0.1.')
     parser.add_argument('--momentum', type=float, default=0.9,
                         help='momentum value for optimizer, default is 0.9.')
-    parser.add_argument('--wd', type=float, default=0.0001,
-                        help='weight decay rate. default is 0.0001.')
+    parser.add_argument('--wd', type=float, default=0.0005,
+                        help='weight decay rate. default is 0.0005.')
     parser.add_argument('--lr-decay', type=float, default=0.1,
                         help='decay rate of learning rate. default is 0.1.')
     parser.add_argument('--lr-decay-period', type=int, default=0,
                         help='period in epoch for learning rate decays. default is 0 (has no effect).')
-    parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
-                        help='epochs at which learning rate decays. default is 40,60.')
+    parser.add_argument('--lr-decay-epoch', type=str, default='100,150',
+                        help='epochs at which learning rate decays. default is 100,150.')
+    parser.add_argument('--warmup-lr', type=float, default=0.0,
+                        help='starting warmup learning rate. default is 0.0.')
+    parser.add_argument('--warmup-epochs', type=int, default=0,
+                        help='number of warmup epochs.')
     parser.add_argument('--drop-rate', type=float, default=0.0,
                         help='dropout rate for wide resnet. default is 0.')
     parser.add_argument('--mode', type=str,
@@ -63,14 +68,16 @@ def parse_args():
     # additional arguments for gradient compression
     parser.add_argument('--compressor', type=str, default='',
                         help='which compressor')
-    parser.add_argument('--ef', type=str, default=None,
-                        help='enable error-feedback')
+    parser.add_argument('--ef', type=str, default='',
+                        help='which error-feedback')
+    parser.add_argument('--compress-momentum', type=str, default='',
+                        help='which compress momentum')
     parser.add_argument('--onebit-scaling', action='store_true', default=False,
                         help='enable scaling for onebit compressor')
+    parser.add_argument('--k', default=1, type=int,
+                        help='topk or randomk')
     parser.add_argument('--fp16-pushpull', action='store_true', default=False,
                         help='use fp16 compression during pushpull')
-    parser.add_argument('--compress-momentum', action='store_true', default=False,
-                        help='enable compress momentum.')
     opt = parser.parse_args()
     return opt
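For reference, a small hypothetical sketch (not part of the commit) of how the reworked compression flags parse. The flag names and defaults mirror the diff above; the example values are only illustrative: 'topk'/'randomk' come from the --k help string and 'nesterov' from the momentum type the old code hard-wired, while the strings accepted for --ef depend on BytePS and are an assumption here.

# Illustrative only: a stripped-down parser with just the compression flags
# added or changed in this commit, fed a hand-written argv.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--compressor', type=str, default='',
                    help='which compressor')
parser.add_argument('--ef', type=str, default='',
                    help='which error-feedback')
parser.add_argument('--compress-momentum', type=str, default='',
                    help='which compress momentum')
parser.add_argument('--onebit-scaling', action='store_true', default=False,
                    help='enable scaling for onebit compressor')
parser.add_argument('--k', default=1, type=int,
                    help='topk or randomk')

# e.g. top-k compression keeping 8 elements, with Nesterov momentum on the compressed side
opt = parser.parse_args(['--compressor', 'topk', '--k', '8',
                         '--compress-momentum', 'nesterov'])
print(opt.compressor, opt.k, opt.compress_momentum, opt.ef == '')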

@@ -104,6 +111,17 @@ def main():
     lr_decay = opt.lr_decay
     lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]
 
+    num_batches = 50000 // (opt.batch_size * nworker)
+    lr_scheduler = LRSequential([
+        LRScheduler('linear', base_lr=opt.warmup_lr, target_lr=opt.lr * nworker / bps.local_size(),
+                    nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
+        LRScheduler('step', base_lr=opt.lr * nworker / bps.local_size(), target_lr=0,
+                    nepochs=opt.num_epochs - opt.warmup_epochs,
+                    iters_per_epoch=num_batches,
+                    step_epoch=lr_decay_epoch,
+                    step_factor=lr_decay, power=2)
+    ])
+
     model_name = opt.model
     if model_name.startswith('cifar_wideresnet'):
         kwargs = {'classes': classes,
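As a rough standalone illustration (not taken from the commit): GluonCV's LRSequential chains a linear warmup phase into a step schedule, and the resulting scheduler is handed to the optimizer through optimizer_params instead of calling trainer.set_learning_rate by hand, which is what the hunk above sets up. The sketch below assumes a single worker, batch size 128, and 5 warmup epochs; the dummy Dense layer only stands in for the real network.

# Sketch of a warmup + step schedule like the one built above, with assumed
# values (single worker, batch size 128, 5 warmup epochs).
from mxnet import gluon
from gluoncv.utils import LRSequential, LRScheduler

base_lr = 0.1
num_batches = 50000 // 128          # CIFAR-100 has 50,000 training images

lr_scheduler = LRSequential([
    # linear ramp from 0 up to base_lr over the first 5 epochs
    LRScheduler('linear', base_lr=0.0, target_lr=base_lr,
                nepochs=5, iters_per_epoch=num_batches),
    # then step decay by 10x at epochs 100 and 150
    LRScheduler('step', base_lr=base_lr, target_lr=0,
                nepochs=200 - 5, iters_per_epoch=num_batches,
                step_epoch=[100, 150], step_factor=0.1),
])

net = gluon.nn.Dense(100)           # placeholder model for the sketch
net.initialize()
# The scheduler rides along in optimizer_params, as in the updated script.
trainer = gluon.Trainer(net.collect_params(), 'nag',
                        {'lr_scheduler': lr_scheduler, 'wd': 0.0005, 'momentum': 0.9})
print(trainer.learning_rate)        # starts at the warmup value (0.0)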
@@ -113,7 +131,11 @@ def main():
     net = get_model(model_name, **kwargs)
     if opt.resume_from:
         net.load_parameters(opt.resume_from, ctx=context)
-    optimizer = 'sgd'
+
+    if opt.compressor:
+        optimizer = 'sgd'
+    else:
+        optimizer = 'nag'
 
     save_period = opt.save_period
     if opt.save_dir and save_period:
@@ -166,34 +188,26 @@ def train(epochs, ctx):
         batch_size=batch_size, shuffle=False, num_workers=num_workers)
 
     params = net.collect_params()
-    if opt.compressor:
-        for _, param in params.items():
-            setattr(param, "byteps_compressor_type", opt.compressor)
-            if opt.ef:
-                setattr(param, "byteps_error_feedback_type", opt.ef)
-            if opt.onebit_scaling:
-                setattr(
-                    param, "byteps_compressor_onebit_enable_scale", opt.onebit_scaling)
-            if opt.compress_momentum:
-                setattr(param, "byteps_momentum_type", "nesterov")
-                setattr(param, "byteps_momentum_mu", opt.momentum)
-
-    optimizer_params = {'learning_rate': opt.lr *
-                        nworker, 'wd': opt.wd, 'momentum': opt.momentum}
-    if opt.compress_momentum:
-        del optimizer_params["momentum"]
-
-    compression = bps.Compression.fp16 if opt.fp16_pushpull else bps.Compression.none
+
+    compression_params = {
+        "compressor": opt.compressor,
+        "ef": opt.ef,
+        "momentum": opt.compress_momentum,
+        "scaling": opt.onebit_scaling,
+        "k": opt.k
+    }
+
+    optimizer_params = {'lr_scheduler': lr_scheduler,
+                        'wd': opt.wd, 'momentum': opt.momentum}
+
     trainer = bps.DistributedTrainer(params,
-                                     optimizer, optimizer_params, compression=compression)
+                                     optimizer, optimizer_params, compression_params=compression_params)
     metric = mx.metric.Accuracy()
     train_metric = mx.metric.Accuracy()
     loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
     train_history = TrainingHistory(['training-error', 'validation-error'])
 
     iteration = 0
-    lr_decay_count = 0
-
     best_val_score = 0
 
     for epoch in range(epochs):
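Putting the pieces together, a condensed, hypothetical sketch of the new wiring. The compression_params keys and the DistributedTrainer call follow the diff above; everything else (the SimpleNamespace of hard-coded flag values, the dummy model, a fixed learning rate standing in for the lr_scheduler the script builds) is made up for illustration, and the snippet only actually runs inside a BytePS launch environment.

# Condensed illustration of the new BytePS wiring; flag values are assumptions,
# and bps.init() needs a proper BytePS launch to succeed.
from types import SimpleNamespace
from mxnet import gluon
import byteps.mxnet as bps

opt = SimpleNamespace(compressor='onebit', ef='', compress_momentum='nesterov',
                      onebit_scaling=True, k=1, wd=0.0005, momentum=0.9)

bps.init()

# As in the diff: with a compressor configured the local optimizer is plain
# 'sgd', otherwise Nesterov ('nag').
optimizer = 'sgd' if opt.compressor else 'nag'

compression_params = {
    "compressor": opt.compressor,
    "ef": opt.ef,
    "momentum": opt.compress_momentum,
    "scaling": opt.onebit_scaling,
    "k": opt.k,
}
optimizer_params = {'learning_rate': 0.1, 'wd': opt.wd, 'momentum': opt.momentum}

net = gluon.nn.Dense(100)           # placeholder model for the sketch
net.initialize()
trainer = bps.DistributedTrainer(net.collect_params(), optimizer, optimizer_params,
                                 compression_params=compression_params)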
@@ -202,11 +216,6 @@ def train(epochs, ctx):
         metric.reset()
         train_loss = 0
         num_batch = len(train_data)
-        alpha = 1
-
-        if epoch == lr_decay_epoch[lr_decay_count]:
-            trainer.set_learning_rate(trainer.learning_rate*lr_decay)
-            lr_decay_count += 1
 
         for i, batch in enumerate(train_data):
             data = gluon.utils.split_and_load(
@@ -230,16 +239,15 @@ def train(epochs, ctx):
         name, acc = train_metric.get()
         throughput = int(batch_size * nworker * i / (time.time() - tic))
 
-        if rank == 0:
-            logger.info('[Epoch %d] training: %s=%f' %
-                        (epoch, name, acc))
-            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' %
-                        (epoch, throughput, time.time()-tic))
+        logger.info('[Epoch %d] training: %s=%f' %
+                    (epoch, name, acc))
+        logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f lr=%f' %
+                    (epoch, throughput, time.time()-tic, trainer.learning_rate))
 
         name, val_acc = test(ctx, val_data)
-        if rank == 0:
-            logger.info('[Epoch %d] validation: %s=%f' %
-                        (epoch, name, val_acc))
+
+        logger.info('[Epoch %d] validation: %s=%f' %
+                    (epoch, name, val_acc))
 
         train_history.update([1-acc, 1-val_acc])
         train_history.plot(save_path='%s/%s_history.png' %
