Skip to content

Commit c6464a9

Browse files
committed
Revert "hotfix: fix distributed initialization bytedance#285"
This reverts commit a692fea.
1 parent ba60a76 commit c6464a9

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

byteps/mxnet/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,16 +217,15 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre
217217
self.root_rank = root_rank
218218
self._intra_compressors = {}
219219
for i, param in enumerate(self._params):
220+
byteps_declare_tensor("parameter_" + str(i))
220221
self._intra_compressors[param.name] = type(self._intra_compressor)(
221222
**self._intra_compressor.__dict__)
222223
if param.grad_req != 'null':
223224
byteps_params = dict(
224225
filter(lambda attr: attr[0].startswith(
225226
"byteps_",), param.__dict__.items())
226227
)
227-
byteps_declare_tensor("tensor_" + str(i), **byteps_params)
228-
else:
229-
byteps_declare_tensor("tensor_" + str(i))
228+
byteps_declare_tensor("gradient_" + str(i), **byteps_params)
230229

231230
def _register_compressor(self, params, optimizer_params, compression_params):
232231
"""Register compressor for BytePS
@@ -314,7 +313,7 @@ def _allreduce_grads(self):
314313
compressed, ctx = self._intra_compressors[param.name].compress(
315314
param._grad[0])
316315
byteps_push_pull(compressed, is_average=False,
317-
name="tensor_" + str(i), priority=-i)
316+
name="gradient_" + str(i), priority=-i)
318317
param._grad[0][:] = self._intra_compressors[param.name].decompress(
319318
compressed, ctx, x=param._data[0])
320319

@@ -327,10 +326,13 @@ def _init_params(self):
327326
param_arrays = param._check_and_get(param._data, list)
328327
idx = self._param2idx[param.name]
329328

329+
if rank() != self.root_rank:
330+
param_arrays[0].__imul__(0)
331+
330332
compressed, ctx = self._intra_compressors[param.name].compress(
331333
param_arrays[0])
332334
byteps_push_pull(compressed, version=0, priority=0,
333-
name="tensor_" + str(idx), is_average=False)
335+
name="parameter_" + str(idx), is_average=False)
334336
param_arrays[0][:] = self._intra_compressors[param.name].decompress(
335337
compressed, ctx, x=param._data[0])
336338

0 commit comments

Comments
 (0)