@@ -217,16 +217,15 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre
217
217
self .root_rank = root_rank
218
218
self ._intra_compressors = {}
219
219
for i , param in enumerate (self ._params ):
220
+ byteps_declare_tensor ("parameter_" + str (i ))
220
221
self ._intra_compressors [param .name ] = type (self ._intra_compressor )(
221
222
** self ._intra_compressor .__dict__ )
222
223
if param .grad_req != 'null' :
223
224
byteps_params = dict (
224
225
filter (lambda attr : attr [0 ].startswith (
225
226
"byteps_" ,), param .__dict__ .items ())
226
227
)
227
- byteps_declare_tensor ("tensor_" + str (i ), ** byteps_params )
228
- else :
229
- byteps_declare_tensor ("tensor_" + str (i ))
228
+ byteps_declare_tensor ("gradient_" + str (i ), ** byteps_params )
230
229
231
230
def _register_compressor (self , params , optimizer_params , compression_params ):
232
231
"""Register compressor for BytePS
@@ -314,7 +313,7 @@ def _allreduce_grads(self):
314
313
compressed , ctx = self ._intra_compressors [param .name ].compress (
315
314
param ._grad [0 ])
316
315
byteps_push_pull (compressed , is_average = False ,
317
- name = "tensor_ " + str (i ), priority = - i )
316
+ name = "gradient_ " + str (i ), priority = - i )
318
317
param ._grad [0 ][:] = self ._intra_compressors [param .name ].decompress (
319
318
compressed , ctx , x = param ._data [0 ])
320
319
@@ -327,10 +326,13 @@ def _init_params(self):
327
326
param_arrays = param ._check_and_get (param ._data , list )
328
327
idx = self ._param2idx [param .name ]
329
328
329
+ if rank () != self .root_rank :
330
+ param_arrays [0 ].__imul__ (0 )
331
+
330
332
compressed , ctx = self ._intra_compressors [param .name ].compress (
331
333
param_arrays [0 ])
332
334
byteps_push_pull (compressed , version = 0 , priority = 0 ,
333
- name = "tensor_ " + str (idx ), is_average = False )
335
+ name = "parameter_ " + str (idx ), is_average = False )
334
336
param_arrays [0 ][:] = self ._intra_compressors [param .name ].decompress (
335
337
compressed , ctx , x = param ._data [0 ])
336
338
0 commit comments