Skip to content

Commit f888c8d

Browse files
committed
1bit: rm wd mom
1 parent 7f66e90 commit f888c8d

File tree

2 files changed

+1
-45
lines changed

2 files changed

+1
-45
lines changed

byteps/mxnet/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,6 @@ def _register_compressor(self, params, optimizer_params, compression_params):
276276

277277
# change
278278
if compression_params.get("momentum"):
279-
# 1bit compressor use an additional momentum for weight decay
280-
if compressor == "onebit" and "wd" in optimizer_params:
281-
intra_compressor = Compression.wdmom(
282-
intra_compressor, optimizer_params["momentum"], optimizer_params["wd"])
283-
del optimizer_params["wd"]
284-
285279
del optimizer_params['momentum']
286280

287281
return intra_compressor
@@ -308,7 +302,7 @@ def _allreduce_grads(self):
308302
byteps_push_pull(compressed, is_average=False,
309303
name="gradient_" + str(i), priority=-i)
310304
param._grad[0] = self._intra_compressors[i].decompress(
311-
compressed, ctx, x=param._data[0])
305+
compressed, ctx)
312306

313307
def _init_params(self):
314308
tensors = []

byteps/mxnet/compression.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -65,41 +65,6 @@ def decompress(self, tensor, ctx, *args, **kwargs):
6565
return tensor_decompressed
6666

6767

68-
class WeightDecayMomentum(Compressor):
69-
"""For 1bit compression."""
70-
71-
def __init__(self, compressor, mu, wd, *args, **kwargs):
72-
self.compressor = compressor
73-
self.mom = None
74-
self.cache = None
75-
self.mu = mu
76-
self.wd = wd
77-
78-
def compress(self, tensor, *args, **kwargs):
79-
"""Returns the tensor unmodified."""
80-
return self.compressor.compress(tensor)
81-
82-
def decompress(self, tensor, ctx, *args, **kwargs):
83-
"""Returns the tensor added with additional momentum for wd
84-
m_t = \mu * m_{t-1} + wd * x_t
85-
x_{t+1} = x_t - \eta_t (tensor + \mu m_t + wd * x_t)
86-
"""
87-
if "x" not in kwargs:
88-
return self.compressor.decompress(tensor, ctx)
89-
90-
x = kwargs["x"]
91-
92-
if self.mom is None:
93-
self.mom = nd.zeros_like(tensor)
94-
self.cache = nd.zeros_like(tensor)
95-
96-
nd._internal._mul_scalar(x, self.wd, out=self.cache)
97-
self.mom += self.cache
98-
nd._internal._mul_scalar(self.mom, self.mu, out=self.mom)
99-
tensor += self.mom + self.cache
100-
return self.compressor.decompress(tensor, ctx)
101-
102-
10368
class Compression(object):
10469
"""Optional gradient compression algorithm used during push_pull."""
10570

@@ -109,9 +74,6 @@ class Compression(object):
10974
"""Compress all floating point gradients to 16-bit."""
11075
fp16 = FP16Compressor()
11176

112-
"""Additional Momentum for weight decay. This is only for 1bit. This is a wrapper."""
113-
wdmom = WeightDecayMomentum
114-
11577

11678
# if __name__ == "__main__":
11779
# x = WeightDecayMomentum(Compression.none, 0.9, 1e-4)

0 commit comments

Comments
 (0)