20
20
import os
21
21
import struct
22
22
import warnings
23
+ from functools import reduce
23
24
24
25
import mxnet as mx
25
26
import mxnet .ndarray as nd
@@ -218,8 +219,6 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre
218
219
self ._intra_compressors = {}
219
220
for i , param in enumerate (self ._params ):
220
221
byteps_declare_tensor ("parameter_" + str (i ))
221
- self ._intra_compressors [param .name ] = type (self ._intra_compressor )(
222
- ** self ._intra_compressor .__dict__ )
223
222
if param .grad_req != 'null' :
224
223
byteps_params = dict (
225
224
filter (lambda attr : attr [0 ].startswith (
@@ -280,7 +279,7 @@ def _register_compressor(self, params, optimizer_params, compression_params):
280
279
if compression_params .get ("momentum" ):
281
280
# 1bit compressor use an additional momentum for weight decay
282
281
if compressor == "onebit" and "wd" in optimizer_params :
283
- intra_compressor = Compression .wdmom (
282
+ Compression . wdmom = Compression .wdmom (
284
283
intra_compressor , optimizer_params ["momentum" ], optimizer_params ["wd" ])
285
284
del optimizer_params ["wd" ]
286
285
@@ -316,6 +315,7 @@ def _allreduce_grads(self):
316
315
317
316
def _init_params (self ):
318
317
tensors = []
318
+ threshold = int (os .environ .get ("BYTEPS_MIN_COMPRESS_BYTES" , 65536 ))
319
319
for param in self ._params_to_init :
320
320
if param ._deferred_init :
321
321
tensors .append (param )
@@ -326,6 +326,15 @@ def _init_params(self):
326
326
if rank () != self .root_rank :
327
327
param_arrays [0 ].__imul__ (0 )
328
328
329
+ # register intra-node compressor
330
+ size = reduce (lambda x , y : x * y , param_arrays [0 ].shape )
331
+ if size >= threshold :
332
+ self ._intra_compressors [param .name ] = type (
333
+ Compression .wdmom )(** Compression .wdmom .__dict__ )
334
+ else :
335
+ self ._intra_compressors [param .name ] = type (
336
+ self ._intra_compressor )(** self ._intra_compressor .__dict__ )
337
+
329
338
compressed , ctx = self ._intra_compressors [param .name ].compress (
330
339
param_arrays [0 ])
331
340
byteps_push_pull (compressed , version = 0 , priority = 0 ,
0 commit comments