Skip to content

Commit bebd996

Browse files
authored
Support silero-vad v4 exported by k2-fsa (#2372)
1 parent 0d44df9 commit bebd996

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

sherpa-onnx/csrc/silero-vad-model.cc

Lines changed: 27 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,8 @@ class SileroVadModel::Impl {
180180
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
181181
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
182182

183-
if (input_names_.size() == 4 && output_names_.size() == 3) {
183+
if ((input_names_.size() == 4 && output_names_.size() == 3) ||
184+
IsExportedByK2Fsa()) {
184185
is_v5_ = false;
185186
} else if (input_names_.size() == 3 && output_names_.size() == 2) {
186187
is_v5_ = true;
@@ -248,7 +249,23 @@ class SileroVadModel::Impl {
248249
}
249250
}
250251

252+
bool IsExportedByK2Fsa() const {
253+
if (input_names_.size() == 3 && input_names_[0] == "x" &&
254+
input_names_[1] == "h" && input_names_[2] == "c" &&
255+
output_names_.size() == 3 && output_names_[0] == "prob" &&
256+
output_names_[1] == "new_h" && output_names_[2] == "new_c") {
257+
// this version is exported and maintained by us (k2-fsa)
258+
return true;
259+
}
260+
261+
return false;
262+
}
263+
251264
void CheckV4() const {
265+
if (IsExportedByK2Fsa()) {
266+
return;
267+
}
268+
252269
if (input_names_.size() != 4) {
253270
SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d",
254271
static_cast<int32_t>(input_names_.size()));
@@ -393,9 +410,15 @@ class SileroVadModel::Impl {
393410
Ort::Value sr =
394411
Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);
395412

396-
std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr),
397-
std::move(states_[0]),
398-
std::move(states_[1])};
413+
std::vector<Ort::Value> inputs;
414+
inputs.reserve(input_names_.size());
415+
416+
inputs.push_back(std::move(x));
417+
if (input_names_.size() == 4) {
418+
inputs.push_back(std::move(sr));
419+
}
420+
inputs.push_back(std::move(states_[0]));
421+
inputs.push_back(std::move(states_[1]));
399422

400423
auto out =
401424
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),

0 commit comments

Comments
 (0)