Skip to content

Commit ba60a76

Browse files
committed
hotfix: merge two buffer (mentioned in bytedance#285)
1 parent a692fea commit ba60a76

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

byteps/mxnet/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,16 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre
217217
self.root_rank = root_rank
218218
self._intra_compressors = {}
219219
for i, param in enumerate(self._params):
220-
byteps_declare_tensor("parameter_" + str(i))
221220
self._intra_compressors[param.name] = type(self._intra_compressor)(
222221
**self._intra_compressor.__dict__)
223222
if param.grad_req != 'null':
224223
byteps_params = dict(
225224
filter(lambda attr: attr[0].startswith(
226225
"byteps_",), param.__dict__.items())
227226
)
228-
byteps_declare_tensor("gradient_" + str(i), **byteps_params)
227+
byteps_declare_tensor("tensor_" + str(i), **byteps_params)
228+
else:
229+
byteps_declare_tensor("tensor_" + str(i))
229230

230231
def _register_compressor(self, params, optimizer_params, compression_params):
231232
"""Register compressor for BytePS
@@ -313,7 +314,7 @@ def _allreduce_grads(self):
313314
compressed, ctx = self._intra_compressors[param.name].compress(
314315
param._grad[0])
315316
byteps_push_pull(compressed, is_average=False,
316-
name="gradient_" + str(i), priority=-i)
317+
name="tensor_" + str(i), priority=-i)
317318
param._grad[0][:] = self._intra_compressors[param.name].decompress(
318319
compressed, ctx, x=param._data[0])
319320

@@ -328,8 +329,8 @@ def _init_params(self):
328329

329330
compressed, ctx = self._intra_compressors[param.name].compress(
330331
param_arrays[0])
331-
byteps_push_pull(param_arrays[0], version=0, priority=0,
332-
name="parameter_" + str(idx), is_average=False)
332+
byteps_push_pull(compressed, version=0, priority=0,
333+
name="tensor_" + str(idx), is_average=False)
333334
param_arrays[0][:] = self._intra_compressors[param.name].decompress(
334335
compressed, ctx, x=param._data[0])
335336

0 commit comments

Comments
 (0)