Skip to content

Commit 11a7ec0

Browse files
authored
1bit: don't do wd mom for uncompressed gradients (#60)
* 1bit: update * 1bit: test * 1bit: register wdmom
1 parent 7dc8d7f commit 11a7ec0

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

byteps/common/compressor/impl/onebit.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
#include <cstring>
1717

18-
#include "onebit.h"
1918
#include "../compressor_registry.h"
19+
#include "onebit.h"
2020

2121
namespace byteps {
2222
namespace common {
@@ -51,10 +51,11 @@ tensor_t OnebitCompressor::CompressImpl(index_t* dst, const scalar_t* src,
5151

5252
#pragma omp parallel for simd
5353
for (size_t i = 0; i < chunk_len; ++i) {
54-
index_t x = src[i * PACKING_SIZE] < 0;
54+
size_t idx = i * PACKING_SIZE;
55+
index_t x = src[idx] < 0;
5556
for (size_t j = 1; j < PACKING_SIZE; ++j) {
5657
x <<= 1;
57-
x |= src[i * PACKING_SIZE + j] < 0;
58+
x |= src[idx + j] < 0;
5859
}
5960
dst[i] = x;
6061
}
@@ -90,9 +91,10 @@ tensor_t OnebitCompressor::DecompressImpl(scalar_t* dst, const index_t* src,
9091
#pragma omp parallel for simd
9192
for (int i = chunk_len - 1; i >= 0; --i) {
9293
index_t x = ptr[i];
94+
size_t idx = i * PACKING_SIZE;
9395
for (int j = PACKING_SIZE - 1; j >= 0; --j) {
9496
int sign = 1 - ((x & 0x01) << 1);
95-
dst[i * PACKING_SIZE + j] = sign * scale;
97+
dst[idx + j] = sign * scale;
9698
x >>= 1;
9799
}
98100
}
@@ -123,10 +125,10 @@ void OnebitCompressor::FastUpdateErrorImpl(scalar_t* error, scalar_t* corrected,
123125
#pragma omp parallel for simd
124126
for (int i = chunk_len - 1; i >= 0; --i) {
125127
index_t x = compressed[i];
128+
size_t idx = i * PACKING_SIZE;
126129
for (int j = PACKING_SIZE - 1; j >= 0; --j) {
127130
int sign = ((x & 0x01) << 1) - 1;
128-
error[i * PACKING_SIZE + j] =
129-
corrected[i * PACKING_SIZE + j] + sign * scale;
131+
error[idx + j] = corrected[idx + j] + sign * scale;
130132
x >>= 1;
131133
}
132134
}

byteps/mxnet/__init__.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
import struct
2222
import warnings
23+
from functools import reduce
2324

2425
import mxnet as mx
2526
import mxnet.ndarray as nd
@@ -218,8 +219,6 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre
218219
self._intra_compressors = {}
219220
for i, param in enumerate(self._params):
220221
byteps_declare_tensor("parameter_" + str(i))
221-
self._intra_compressors[param.name] = type(self._intra_compressor)(
222-
**self._intra_compressor.__dict__)
223222
if param.grad_req != 'null':
224223
byteps_params = dict(
225224
filter(lambda attr: attr[0].startswith(
@@ -280,7 +279,7 @@ def _register_compressor(self, params, optimizer_params, compression_params):
280279
if compression_params.get("momentum"):
281280
# 1bit compressor use an additional momentum for weight decay
282281
if compressor == "onebit" and "wd" in optimizer_params:
283-
intra_compressor = Compression.wdmom(
282+
Compression.wdmom = Compression.wdmom(
284283
intra_compressor, optimizer_params["momentum"], optimizer_params["wd"])
285284
del optimizer_params["wd"]
286285

@@ -316,6 +315,7 @@ def _allreduce_grads(self):
316315

317316
def _init_params(self):
318317
tensors = []
318+
threshold = int(os.environ.get("BYTEPS_MIN_COMPRESS_BYTES", 65536))
319319
for param in self._params_to_init:
320320
if param._deferred_init:
321321
tensors.append(param)
@@ -326,6 +326,15 @@ def _init_params(self):
326326
if rank() != self.root_rank:
327327
param_arrays[0].__imul__(0)
328328

329+
# register intra-node compressor
330+
size = reduce(lambda x, y: x*y, param_arrays[0].shape)
331+
if size >= threshold:
332+
self._intra_compressors[param.name] = type(
333+
Compression.wdmom)(**Compression.wdmom.__dict__)
334+
else:
335+
self._intra_compressors[param.name] = type(
336+
self._intra_compressor)(**self._intra_compressor.__dict__)
337+
329338
compressed, ctx = self._intra_compressors[param.name].compress(
330339
param_arrays[0])
331340
byteps_push_pull(compressed, version=0, priority=0,

0 commit comments

Comments
 (0)