Skip to content

Commit 58a52cc

Browse files
authored
Merge pull request #223 from didi/transfer
Transfer learning
2 parents 0a06c4a + 3b5e8d9 commit 58a52cc

File tree

19 files changed

+152
-73
lines changed

19 files changed

+152
-73
lines changed

delta/utils/misc.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ def len_to_padding(length, maxlen=None, dtype=tf.bool):
6363

6464
def log_vars(prefix, variables):
6565
''' logging TF variables metadata '''
66+
logging.info(f"{prefix}:")
6667
for var in variables:
67-
logging.info("{}: name: {} shape: {} device: {}".format(
68-
prefix, var.name, var.shape, var.device))
68+
logging.info(
69+
f"\tname = {var.name}, shape = {var.shape}, device = {var.device}")
6970

7071

7172
#pylint: disable=bad-continuation

delta/utils/model.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,6 @@ def print_ops(graph, prefix=''):
3030
logging.info('{} : op name: {}'.format(prefix, operator.name))
3131

3232

33-
def log_vars(prefix, variables):
34-
"""Print tensorflow variables."""
35-
for var in variables:
36-
logging.info("{}: name: {} shape: {} device: {}".format(
37-
prefix, var.name, var.shape, var.device))
38-
39-
4033
def model_size(variables):
4134
"""Get model size."""
4235
total_params = sum(

delta/utils/solver/estimator_solver.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
# ==============================================================================
1616
''' Estimator base class for classification '''
1717
import os
18+
import re
1819
import functools
20+
import collections
1921
from absl import logging
2022
import delta.compat as tf
2123
from tensorflow.python import debug as tf_debug #pylint: disable=no-name-in-module
@@ -72,6 +74,82 @@ def l2_loss(self, tvars=None):
7274
summary_lib.scalar('l2_loss', _l2_loss)
7375
return _l2_loss
7476

77+
def get_assignment_map_from_checkpoint(self, tvars, init_checkpoint):
  """Build the assignment map between current variables and a checkpoint.

  Only names that appear both in `tvars` (with any trailing `:<digits>`
  suffix stripped) and in the checkpoint are kept, i.e. the intersection of
  the two name sets.

  Args:
    tvars: iterable of variables; only their `name` attribute is read.
    init_checkpoint: checkpoint location handed to `tf.train.list_variables`.

  Returns:
    A tuple `(assignment_map, initialized_variable_names)`.
    `assignment_map` is an OrderedDict mapping each shared name to itself, in
    the order the checkpoint lists them. `initialized_variable_names` is a
    dict whose keys are each shared name and that name with `":0"` appended
    (values are always 1), so it can be probed with either spelling.
  """
  # Index the current variables by name without the ":<n>" suffix.
  name_to_variable = collections.OrderedDict()
  for var in tvars:
    name = var.name
    m = re.match("^(.*):\\d+$", name)
    if m is not None:
      name = m.group(1)
    name_to_variable[name] = var

  assignment_map = collections.OrderedDict()
  initialized_variable_names = {}
  for entry in tf.train.list_variables(init_checkpoint):
    # Only the first element of each entry (the variable name) is needed.
    name = entry[0]
    if name not in name_to_variable:
      continue
    assignment_map[name] = name
    initialized_variable_names[name] = 1
    initialized_variable_names[name + ":0"] = 1

  return (assignment_map, initialized_variable_names)
102+
103+
def init_from_checkpoint(self):
  ''' Do transfer learning by initializing a subset of variables from a checkpoint.

  Reads `self.config['solver']['transfer']` and returns immediately when that
  section is absent or its `enable` entry is falsy. Otherwise the keys
  `ckpt_path`, `include_reg` and `exclude_reg` are read, the trainable
  variables are filtered by those regular expressions, and the survivors that
  also exist in the checkpoint are passed to `tf.train.init_from_checkpoint`.

  Selection rule (patterns are applied with `re.match`, i.e. anchored at the
  start of `var.name`): a variable is selected when it matches at least one
  include pattern, or when exclude patterns are given and it matches none of
  them. Each variable is selected at most once. When both pattern lists are
  empty nothing is selected.

  Returns:
    None.
  '''
  if 'transfer' not in self.config['solver']:
    return
  transfer_cfg = self.config['solver']['transfer']
  if not transfer_cfg['enable']:
    return

  init_checkpoint = transfer_cfg['ckpt_path']
  exclude = transfer_cfg['exclude_reg']
  include = transfer_cfg['include_reg']
  logging.info(f"Transfer from checkpoint: {init_checkpoint}")
  logging.info(f"Transfer exclude: {exclude}")
  logging.info(f"Transfer include: {include}")

  def _as_patterns(regs):
    ''' Normalize a config value into a list of regex strings. '''
    if not regs:
      return []
    # A bare string is one pattern; iterating it would yield single characters.
    if isinstance(regs, str):
      return [regs]
    return list(regs)

  def _filter_by_reg(tvars, include, exclude):
    ''' Keep vars matching any include, or (if exclude given) no exclude. '''
    include = _as_patterns(include)
    exclude = _as_patterns(exclude)
    outs = []
    for var in tvars:
      name = var.name
      included = any(re.match(reg_str, name) for reg_str in include)
      # A variable escapes exclusion only if it matches none of the exclude
      # patterns; an empty exclude list selects nothing on its own.
      not_excluded = bool(exclude) and not any(
          re.match(reg_str, name) for reg_str in exclude)
      logging.debug(
          f"var:{name}, included: {included}, not excluded: {not_excluded}")
      if included or not_excluded:
        outs.append(var)
    return outs

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  if init_checkpoint:
    tvars = _filter_by_reg(tvars, include, exclude)
    assignment_map, initialized_variable_names = \
        self.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  # When a checkpoint path is set, `tvars` is the filtered list at this point.
  logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                 init_string)
152+
75153
def model_fn(self):
76154
''' return model_fn '''
77155
model_class = super().model_fn()
@@ -144,10 +222,11 @@ def _model_fn(features, labels, mode, params):
144222
# L2 loss
145223
loss_all += self.l2_loss()
146224

225+
utils.log_vars('****** Global Vars *****', tf.global_variables())
226+
self.init_from_checkpoint()
147227
train_op = self.get_train_op(loss_all)
148228
train_hooks = self.get_train_hooks(labels, logits, alpha=alignment)
149229

150-
utils.log_vars('Global Vars', tf.global_variables())
151230
return tf.estimator.EstimatorSpec( #pylint: disable=no-member
152231
mode=mode,
153232
loss=loss_all,
@@ -179,7 +258,7 @@ def create_estimator(self):
179258
# multi-gpus
180259
devices, num_gpu = utils.gpu_device_names()
181260
distribution = utils.get_distribution_strategy(num_gpu)
182-
logging.info('Device: {}/{}'.format(num_gpu, devices))
261+
logging.info('Device: num = {}, list = {}'.format(num_gpu, devices))
183262

184263
# run config
185264
tfconf = self.config['solver']['run_config']

deltann/api/c_api.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ See the License for the specific language governing permissions and
1414
limitations under the License.
1515
==============================================================================*/
1616

17+
#include <iostream>
1718
#include <string>
1819
#include <vector>
19-
#include <iostream>
2020

2121
#include "api/c_api.h"
2222
#include "core/config.h"
@@ -47,10 +47,10 @@ DeltaStatus DeltaSetInputs(InferHandel inf, Input* inputs, int num) {
4747
Runtime* rt = static_cast<Runtime*>(inf);
4848
std::vector<In> ins;
4949
for (int i = 0; i < num; ++i) {
50-
//std::cout << "set inputs name : " << inputs[i].input_name << "\n";
51-
//std::cout << "set inputs nelms: " << inputs[i].nelms << "\n";
50+
// std::cout << "set inputs name : " << inputs[i].input_name << "\n";
51+
// std::cout << "set inputs nelms: " << inputs[i].nelms << "\n";
5252

53-
const int *data = static_cast<const int*>(inputs[i].ptr);
53+
const int* data = static_cast<const int*>(inputs[i].ptr);
5454
if (inputs[i].shape == NULL) {
5555
ins.push_back(In(inputs[i].graph_name, inputs[i].input_name,
5656
inputs[i].ptr, inputs[i].nelms));

deltann/core/buffer.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,16 @@ class Buffer {
8585
void copy_from(const void* src, const std::size_t size) {
8686
DELTA_CHECK(_ptr);
8787
DELTA_CHECK(src);
88-
DELTA_CHECK(size <= _size)
89-
<< "expect size: " << size << " real size:" << _size;
88+
DELTA_CHECK(size <= _size) << "expect size: " << size
89+
<< " real size:" << _size;
9090
std::memcpy(_ptr, src, size);
9191
}
9292

9393
void copy_to(void* dst, const std::size_t size) {
9494
DELTA_CHECK(_ptr);
9595
DELTA_CHECK(dst);
96-
DELTA_CHECK(size <= _size)
97-
<< "expect size: " << size << " real size:" << _size;
96+
DELTA_CHECK(size <= _size) << "expect size: " << size
97+
<< " real size:" << _size;
9898
std::memcpy(dst, _ptr, size);
9999
}
100100

deltann/core/io.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ class BaseInOutData {
185185
std::size_t bytes = nelms * delta_dtype_size(dtype);
186186
this->resize(bytes);
187187
_data->copy_from(src, bytes);
188-
}
188+
}
189189

190190
void copy_from(const float* src) { copy_from(src, this->nelms()); }
191191

@@ -203,7 +203,7 @@ class BaseInOutData {
203203

204204
#ifdef USE_TF
205205
tensorflow::TensorShape tensor_shape() const {
206-
//tensorflow::Status::Status status;
206+
// tensorflow::Status::Status status;
207207
const Shape& shape = this->shape();
208208
tensorflow::TensorShape ts;
209209
auto s = shape.vec();

deltann/core/runtime.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,8 @@ DeltaStatus Runtime::set_inputs(const std::vector<In>& ins) {
151151
<< in._shape;
152152
input.set_shape(in._shape);
153153
}
154-
DELTA_CHECK_EQ(in._nelms, input.nelms())
155-
<< in._nelms << ":"
156-
<< input.nelms();
154+
DELTA_CHECK_EQ(in._nelms, input.nelms()) << in._nelms << ":"
155+
<< input.nelms();
157156

158157
InputData input_data(input);
159158
input_data.copy_from(in._ptr, in._nelms);

deltann/core/runtime.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ struct In {
6161
std::string _graph_name;
6262
std::string _input_name;
6363
const void* _ptr;
64-
std::size_t _nelms; // elements
64+
std::size_t _nelms; // elements
6565
Shape _shape;
6666
};
6767

deltann/core/tfmodel.cc

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -74,28 +74,28 @@ TFModel::TFModel(ModelMeta model_meta, int num_threads)
7474
void TFModel::feed_tensor(Tensor* tensor, const InputData& input) {
7575
std::int64_t num_elements = tensor->NumElements();
7676
switch (input.dtype()) {
77-
case DataType::DELTA_FLOAT32:{
78-
std::cout << "input: " << num_elements << " " << tensor->TotalBytes() << std::endl;
77+
case DataType::DELTA_FLOAT32: {
78+
std::cout << "input: " << num_elements << " " << tensor->TotalBytes()
79+
<< std::endl;
7980
auto ptr = tensor->flat<float>().data();
8081
std::fill_n(ptr, num_elements, 0.0);
81-
std::copy_n(static_cast<float*>(input.ptr()), num_elements,
82-
ptr);
82+
std::copy_n(static_cast<float*>(input.ptr()), num_elements, ptr);
8383
break;
8484
}
85-
case DataType::DELTA_INT32:{
85+
case DataType::DELTA_INT32: {
8686
std::copy_n(static_cast<int*>(input.ptr()), num_elements,
8787
tensor->flat<int>().data());
8888
break;
89-
}
89+
}
9090
case DataType::DELTA_CHAR: {
9191
char* cstr = static_cast<char*>(input.ptr());
9292
std::string str = std::string(cstr);
9393
tensor->scalar<tensorflow::tstring>()() = str;
9494
break;
9595
}
96-
default:{
96+
default: {
9797
LOG_FATAL << "Not support dtype:" << delta_dtype_str(input.dtype());
98-
}
98+
}
9999
}
100100
}
101101

@@ -107,7 +107,7 @@ void TFModel::fetch_tensor(const Tensor& tensor, OutputData* output) {
107107
// copy data
108108
std::size_t num_elements = tensor.NumElements();
109109
std::size_t total_bytes = tensor.TotalBytes();
110-
std::cout << "output: " << num_elements << " " << total_bytes << "\n";
110+
std::cout << "output: " << num_elements << " " << total_bytes << "\n";
111111
DELTA_CHECK(num_elements == output->nelms())
112112
<< "expect " << num_elements << "elems, but given " << output->nelms();
113113

@@ -186,33 +186,32 @@ int TFModel::run(const std::vector<InputData>& inputs,
186186
set_feeds(&feeds, inputs);
187187
set_fetches(&fetches, *output);
188188

189-
//std::cout << "input xxxxxxxxxxxxxxxxx"<< "\n";
190-
//auto ti = feeds[0].second;
191-
//for (auto i = 0; i < ti.NumElements(); i++){
189+
// std::cout << "input xxxxxxxxxxxxxxxxx"<< "\n";
190+
// auto ti = feeds[0].second;
191+
// for (auto i = 0; i < ti.NumElements(); i++){
192192
// std::cout << std::showpoint << ti.flat<float>()(i) << " ";
193193
// if (i % 40 == 1){std::cout << "\n";}
194194
//}
195-
//std::cout << "\n";
196-
//std::cout << "input -------------------"<< "\n";
197-
195+
// std::cout << "\n";
196+
// std::cout << "input -------------------"<< "\n";
198197

199198
// Session run
200199
RunOptions run_options;
201200
RunMetadata run_meta;
202-
tensorflow::Status s = _bundle.GetSession()->Run(run_options, feeds, fetches, {},
203-
&output_tensors, &run_meta);
201+
tensorflow::Status s = _bundle.GetSession()->Run(
202+
run_options, feeds, fetches, {}, &output_tensors, &run_meta);
204203
if (!s.ok()) {
205204
LOG_FATAL << "Error, TF Model run failed: " << s;
206205
exit(-1);
207206
}
208207

209-
//std::cout << "output xxxxxxxxxxxxxxxxx"<< "\n";
210-
//auto t = output_tensors[0];
211-
//for (auto i = 0; i < t.NumElements(); i++){
208+
// std::cout << "output xxxxxxxxxxxxxxxxx"<< "\n";
209+
// auto t = output_tensors[0];
210+
// for (auto i = 0; i < t.NumElements(); i++){
212211
// std::cout << std::showpoint << t.flat<float>()(i) << " ";
213212
//}
214-
//std::cout << "\n";
215-
//std::cout << "output -------------------"<< "\n";
213+
// std::cout << "\n";
214+
// std::cout << "output -------------------"<< "\n";
216215

217216
get_featches(output_tensors, output);
218217

@@ -287,13 +286,14 @@ DeltaStatus TFModel::load_from_saved_model() {
287286
LOG_INFO << "load saved model from path: " << path;
288287
if (!MaybeSavedModelDirectory(path)) {
289288
LOG_FATAL << "SaveModel not in :" << path;
290-
return DeltaStatus::STATUS_ERROR;
289+
return DeltaStatus::STATUS_ERROR;
291290
}
292291

293-
tensorflow::Status s = LoadSavedModel(options, run_options, path,
294-
{tensorflow::kSavedModelTagServe}, &_bundle);
292+
tensorflow::Status s = LoadSavedModel(
293+
options, run_options, path, {tensorflow::kSavedModelTagServe}, &_bundle);
295294
if (!s.ok()) {
296-
LOG_FATAL << "Failed Load model from saved_model.pb : " << s.error_message();
295+
LOG_FATAL << "Failed Load model from saved_model.pb : "
296+
<< s.error_message();
297297
}
298298

299299
return DeltaStatus::STATUS_OK;

deltann/examples/speaker/test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ struct DeltaModel {
7474
}
7575

7676
DeltaStatus SetInputs(T* buf, const std::vector<int> shape) {
77-
return this->SetInputs(buf, this->NumElems(shape), shape.data(), shape.size());
77+
return this->SetInputs(buf, this->NumElems(shape), shape.data(),
78+
shape.size());
7879
}
7980

8081
DeltaStatus SetInputs(T* buf, int nelms, const int* shape, const int ndims) {

0 commit comments

Comments
 (0)