@@ -217,15 +217,16 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre
217
217
self .root_rank = root_rank
218
218
self ._intra_compressors = {}
219
219
for i , param in enumerate (self ._params ):
220
- byteps_declare_tensor ("parameter_" + str (i ))
221
220
self ._intra_compressors [param .name ] = type (self ._intra_compressor )(
222
221
** self ._intra_compressor .__dict__ )
223
222
if param .grad_req != 'null' :
224
223
byteps_params = dict (
225
224
filter (lambda attr : attr [0 ].startswith (
226
225
"byteps_" ,), param .__dict__ .items ())
227
226
)
228
- byteps_declare_tensor ("gradient_" + str (i ), ** byteps_params )
227
+ byteps_declare_tensor ("tensor_" + str (i ), ** byteps_params )
228
+ else :
229
+ byteps_declare_tensor ("tensor_" + str (i ))
229
230
230
231
def _register_compressor (self , params , optimizer_params , compression_params ):
231
232
"""Register compressor for BytePS
@@ -313,7 +314,7 @@ def _allreduce_grads(self):
313
314
compressed , ctx = self ._intra_compressors [param .name ].compress (
314
315
param ._grad [0 ])
315
316
byteps_push_pull (compressed , is_average = False ,
316
- name = "gradient_ " + str (i ), priority = - i )
317
+ name = "tensor_ " + str (i ), priority = - i )
317
318
param ._grad [0 ][:] = self ._intra_compressors [param .name ].decompress (
318
319
compressed , ctx , x = param ._data [0 ])
319
320
@@ -328,8 +329,8 @@ def _init_params(self):
328
329
329
330
compressed , ctx = self ._intra_compressors [param .name ].compress (
330
331
param_arrays [0 ])
331
- byteps_push_pull (param_arrays [ 0 ] , version = 0 , priority = 0 ,
332
- name = "parameter_ " + str (idx ), is_average = False )
332
+ byteps_push_pull (compressed , version = 0 , priority = 0 ,
333
+ name = "tensor_ " + str (idx ), is_average = False )
333
334
param_arrays [0 ][:] = self ._intra_compressors [param .name ].decompress (
334
335
compressed , ctx , x = param ._data [0 ])
335
336
0 commit comments