Skip to content

Commit b382f99

Browse files
authored
cifar: update (#19)
1 parent 2f4c641 commit b382f99

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

example/mxnet/train_cifar100_byteps_gc.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,18 +179,24 @@ def main():
179179
save_dir = ''
180180
save_period = 0
181181

182+
# from https://github.com/weiaicunzai/pytorch-cifar/blob/master/conf/global_settings.py
183+
CIFAR100_TRAIN_MEAN = [0.5070751592371323,
184+
0.48654887331495095, 0.4409178433670343]
185+
CIFAR100_TRAIN_STD = [0.2673342858792401,
186+
0.2564384629170883, 0.27615047132568404]
187+
182188
transform_train = transforms.Compose([
183189
gcv_transforms.RandomCrop(32, pad=4),
184190
transforms.RandomFlipLeftRight(),
185191
transforms.ToTensor(),
186-
transforms.Normalize([0.4914, 0.4822, 0.4465],
187-
[0.2023, 0.1994, 0.2010])
192+
transforms.Normalize(CIFAR100_TRAIN_MEAN,
193+
CIFAR100_TRAIN_STD)
188194
])
189195

190196
transform_test = transforms.Compose([
191197
transforms.ToTensor(),
192-
transforms.Normalize([0.4914, 0.4822, 0.4465],
193-
[0.2023, 0.1994, 0.2010])
198+
transforms.Normalize(CIFAR100_TRAIN_MEAN,
199+
CIFAR100_TRAIN_STD)
194200
])
195201

196202
def test(ctx, val_data):

0 commit comments

Comments
 (0)