Skip to content

Commit dae88d6

Browse files
Merge pull request #225 from vycezhong/gradient_compression
gradient compression support
2 parents b8948f0 + ef6f916 commit dae88d6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+6262
-410
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,4 @@ venv.bak/
116116

117117
# for development
118118
scripts/
119+
exps/

byteps/common/common.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ int GetCommandType(RequestType requestType, int d) {
100100
return (((m + d) * (m + d + 1)) / 2) + d;
101101
}
102102

103+
#ifndef BYTEPS_BUILDING_SERVER
103104
ncclDataType_t getNcclDataType(DataType dtype) {
104105
switch (dtype) {
105106
case BYTEPS_FLOAT32:
@@ -121,6 +122,7 @@ ncclDataType_t getNcclDataType(DataType dtype) {
121122
}
122123
return ncclFloat32;
123124
}
125+
#endif
124126

125127
int getDataTypeLength(int dtype) {
126128
switch (dtype) {

byteps/common/common.h

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,23 @@
3131
#include <vector>
3232

3333
// Add for profiling communication events
34-
#include <fstream>
3534
#include <stdio.h>
3635
#include <stdlib.h>
37-
#include <iostream>
38-
#include <thread>
36+
3937
#include <chrono>
38+
#include <fstream>
39+
#include <iostream>
4040
#include <queue>
41+
#include <thread>
4142

4243
namespace byteps {
4344
namespace common {
45+
namespace compressor {
46+
struct BPSTensor;
47+
typedef BPSTensor tensor_t;
48+
class Compressor;
49+
class ErrorFeedback;
50+
} // namespace compressor
4451

4552
// Device ID used for CPU.
4653
#define CPU_DEVICE_ID (-1)
@@ -83,8 +90,10 @@ enum QueueType {
8390
COPYD2H,
8491
PCIE_REDUCE,
8592
COORDINATE_PUSH,
93+
COMPRESS,
8694
PUSH,
8795
PULL,
96+
DECOMPRESS,
8897
COPYH2D,
8998
COORDINATE_BROADCAST,
9099
BROADCAST,
@@ -94,10 +103,18 @@ enum QueueType {
94103
const int QueueNum =
95104
(int)QUEUE_NUM_AND_NOT_A_REAL_QUEUE_TYPE_AND_MUST_BE_THE_LAST;
96105

97-
const std::vector<std::string> LogStrings = {
98-
"COORDINATE_REDUCE", "REDUCE", "COPYD2H", "PCIE_REDUCE",
99-
"COORDINATE_PUSH", "PUSH", "PULL", "COPYH2D",
100-
"COORDINATE_BROADCAST", "BROADCAST"};
106+
const std::vector<std::string> LogStrings = {"COORDINATE_REDUCE",
107+
"REDUCE",
108+
"COPYD2H",
109+
"PCIE_REDUCE",
110+
"COORDINATE_PUSH",
111+
"COMPRESS",
112+
"PUSH",
113+
"PULL",
114+
"DECOMPRESS",
115+
"COPYH2D",
116+
"COORDINATE_BROADCAST",
117+
"BROADCAST"};
101118

102119
class Status {
103120
public:
@@ -173,11 +190,17 @@ typedef struct BytePSContext {
173190
std::vector<void*> pcie_cpubuff;
174191
size_t buff_len;
175192
// Used for profiling communication events
176-
std::queue<BPSCommTime *> comm_time;
193+
std::queue<BPSCommTime*> comm_time;
177194
bool profile_flag = false;
178195
int step_cnt = 0;
179196
int local_rank = 0;
180-
std::unordered_map<uint64_t, std::unordered_map<int, std::queue<BPSCommTime *>>> part_comm_time;
197+
std::unordered_map<uint64_t,
198+
std::unordered_map<int, std::queue<BPSCommTime*>>>
199+
part_comm_time;
200+
// Compressor list
201+
std::vector<std::shared_ptr<compressor::Compressor>> compressor_list;
202+
// kwargs
203+
std::unordered_map<std::string, std::string> kwargs;
181204
} BPSContext;
182205

183206
class Tensor {
@@ -233,6 +256,10 @@ struct TensorTableEntry {
233256
std::shared_ptr<std::atomic_int> counter_ptr;
234257
// How many partitions
235258
unsigned int total_partnum = 0;
259+
// Compressor
260+
std::shared_ptr<compressor::Compressor> compressor;
261+
// Compressed
262+
std::shared_ptr<compressor::tensor_t> compressed;
236263
};
237264
using TensorTable = std::unordered_map<std::string, TensorTableEntry>;
238265

@@ -250,6 +277,11 @@ ncclDataType_t getNcclDataType(DataType dtype);
250277

251278
int getDataTypeLength(int dtype);
252279

280+
inline size_t Align(size_t size, int dtype) {
281+
const size_t min_size =
282+
(getDataTypeLength(dtype) * getDataTypeLength(dtype)) * 8;
283+
return size + (min_size - size % min_size) % min_size;
284+
}
253285
} // namespace common
254286
} // namespace byteps
255287

byteps/common/compressor/common.h

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// Copyright 2019 Amazon Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
// =============================================================================
15+
16+
#ifndef BYTEPS_COMPRESSOR_COMMON_H
17+
#define BYTEPS_COMPRESSOR_COMMON_H
18+
19+
#include <unordered_map>
20+
#if __F16C__
21+
#include "../half.h"
22+
using half_t = mshadow::half::half_t;
23+
#endif
24+
25+
namespace byteps {
26+
namespace common {
27+
namespace compressor {
28+
typedef char byte_t;
29+
/*!
30+
* \brief Tensor type
31+
*/
32+
typedef struct BPSTensor {
33+
byte_t* data;
34+
size_t size;
35+
int dtype;
36+
37+
BPSTensor() : data(nullptr), size(0), dtype(0) {}
38+
BPSTensor(void* data, size_t size = 0, int dtype = 0)
39+
: data(reinterpret_cast<byte_t*>(data)), size(size), dtype(dtype) {}
40+
} tensor_t;
41+
42+
using kwargs_t = std::unordered_map<std::string, std::string>;
43+
44+
#define COMPRESS_IMPL_SWITCH(dtype, func, dst, src, size) \
45+
switch (dtype) { \
46+
case BYTEPS_FLOAT16: \
47+
return func(reinterpret_cast<uint16_t*>(dst), \
48+
reinterpret_cast<const half_t*>(src), \
49+
size / sizeof(half_t)); \
50+
case BYTEPS_FLOAT32: \
51+
return func(reinterpret_cast<uint32_t*>(dst), \
52+
reinterpret_cast<const float*>(src), size / sizeof(float)); \
53+
case BYTEPS_FLOAT64: \
54+
return func(reinterpret_cast<uint64_t*>(dst), \
55+
reinterpret_cast<const double*>(src), \
56+
size / sizeof(double)); \
57+
default: \
58+
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
59+
}
60+
61+
#define DECOMPRESS_IMPL_SWITCH(dtype, func, dst, src, compressed_size) \
62+
switch (dtype) { \
63+
case BYTEPS_FLOAT16: \
64+
return func(reinterpret_cast<half_t*>(dst), \
65+
reinterpret_cast<const uint16_t*>(src), compressed_size); \
66+
case BYTEPS_FLOAT32: \
67+
return func(reinterpret_cast<float*>(dst), \
68+
reinterpret_cast<const uint32_t*>(src), compressed_size); \
69+
case BYTEPS_FLOAT64: \
70+
return func(reinterpret_cast<double*>(dst), \
71+
reinterpret_cast<const uint64_t*>(src), compressed_size); \
72+
default: \
73+
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
74+
}
75+
76+
#define FAST_UPDATE_ERROR_IMPL_SWITCH(dtype, func, dst, src1, src2, \
77+
compressed_size) \
78+
switch (dtype) { \
79+
case BYTEPS_FLOAT16: \
80+
return func(reinterpret_cast<half_t*>(dst), \
81+
reinterpret_cast<half_t*>(src1), \
82+
reinterpret_cast<const uint16_t*>(src2), compressed_size); \
83+
case BYTEPS_FLOAT32: \
84+
return func(reinterpret_cast<float*>(dst), \
85+
reinterpret_cast<float*>(src1), \
86+
reinterpret_cast<const uint32_t*>(src2), compressed_size); \
87+
case BYTEPS_FLOAT64: \
88+
return func(reinterpret_cast<double*>(dst), \
89+
reinterpret_cast<double*>(src1), \
90+
reinterpret_cast<const uint64_t*>(src2), compressed_size); \
91+
default: \
92+
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
93+
}
94+
95+
} // namespace compressor
96+
} // namespace common
97+
} // namespace byteps
98+
99+
#endif // BYTEPS_COMPRESSOR_COMMON_H

byteps/common/compressor/compressor.h

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// Copyright 2019 Amazon Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
// =============================================================================
15+
16+
#ifndef BYTEPS_COMPRESSOR_COMPRESSOR_H
17+
#define BYTEPS_COMPRESSOR_COMPRESSOR_H
18+
19+
#include <memory>
20+
21+
#include "../common.h"
22+
#include "../logging.h"
23+
#include "common.h"
24+
25+
namespace byteps {
26+
namespace common {
27+
namespace compressor {
28+
/*!
29+
* \brief Compressor interface
30+
* Compressor defines two universal API - Compress & Decompress
31+
*
32+
* \par
33+
* The caller do not need to allocate additional memory to store compressed data
34+
* because there is an internal buffer to store the compressed data and the
35+
* pointer will be returned to the caller. Then the caller can send the returned
36+
* compressed data as normal.
37+
*
38+
* \par
39+
* There are two optional features of the compressor - error-feedback &
40+
* momentum. These two features can be added to any common compressors like 1bit
41+
* and topk. To be generic, these two features are also compressors, exposing
42+
* the same API as Compressor. More details can be found in their own files.
43+
*
44+
* \par
45+
* To add a new compressor, developers need to inherit this class in 'impl'
46+
* directory. If a new optional feature like error-feedback is needed,
47+
* developers need to use decorator pattern and add new files in the current
48+
* directory. The existing implementation can be used as a reference.
49+
*
50+
*
51+
* \sa ErrorFeedback, Momentum
52+
*/
53+
class Compressor {
54+
public:
55+
Compressor(size_t size, DataType dtype)
56+
: _size(size), _dtype(dtype), _buf(new byte_t[size]){};
57+
virtual ~Compressor() = default;
58+
59+
/*!
60+
* \brief Compress function
61+
*
62+
* \note Except for error-feedback and momentum, the underlying data of input
63+
* should never be changed. this is because input is still used in error
64+
* feedback if enabled.
65+
*
66+
* \note Compressed data should be stored in the buffer of the compressor. So
67+
* it is not an inplace operation.
68+
*
69+
* \param grad gradient tensor, passed by value.
70+
* \return compressed tensor. it is the buffer of the compressor,
71+
* which contains the compressed data. the returned size is the size of
72+
* compressed data.
73+
*/
74+
virtual tensor_t Compress(tensor_t grad) = 0;
75+
76+
/*!
77+
* \brief Decompress function
78+
*
79+
* \note For servers, decompression is not an inplace operation. The
80+
* decompressed results locates in the buffer of the compressor. For workers,
81+
* it is an inplace operation.
82+
*
83+
* \param compressed compressed tensor.
84+
* \return decompressed tensor. For servers, it is the buffer of the
85+
* compressor, which contains the decompressed data. For workers, its pointer
86+
* is the same as the input's, while the size is decompressed size, which is
87+
* also the original size.
88+
*/
89+
virtual tensor_t Decompress(tensor_t compressed) = 0;
90+
91+
/*!
92+
* \brief faster version of `UpdateError` via operation fusion
93+
*
94+
* \par
95+
* This is a helper function implemented by each compressor. If defined,
96+
* `ErrorFeedback` will use this function instead of defualt `UpdateError`
97+
* function implemented in error_feedback.cc. If undefined, default
98+
* `UpdateError` will be used.
99+
*
100+
* \par
101+
* Typically `UpdateError` needs to decompress and do a substraction. But for
102+
* most compressors, the step of decompression can be avoided. For example,
103+
* for topk compressor, `UpdateError` can be simplied in this way:
104+
* 1. e <- p (e is the error and p is the corrected gradient)
105+
* 2. zero-fill e with selected k indices
106+
*
107+
* Actually it is a fusion of original decompression and substraction. It is
108+
* optional to override.
109+
*
110+
* \param corrected gradient corrected with error
111+
* \param error error
112+
* \param compressed compressed gradient
113+
*/
114+
virtual void FastUpdateError(tensor_t error, tensor_t corrected,
115+
tensor_t compressed) {
116+
BPS_LOG(FATAL) << "FastUpdateError is not implemented";
117+
};
118+
119+
protected:
120+
/*! \brief original size */
121+
size_t _size;
122+
123+
DataType _dtype;
124+
125+
/*! \brief buffer to store compressed grad */
126+
std::unique_ptr<byte_t[]> _buf;
127+
};
128+
129+
} // namespace compressor
130+
} // namespace common
131+
} // namespace byteps
132+
133+
#endif // BYTEPS_COMPRESSOR_COMPRESSOR_H

0 commit comments

Comments
 (0)