Skip to content

Commit 0e738c3

Browse files
authored
Add C++ runtime and Python API for NeMo Canary models (#2352)
1 parent f8d957a commit 0e738c3

24 files changed

+1091
-8
lines changed

.github/scripts/test-python.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ log() {
88
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
99
}
1010

11+
log "test nemo canary"
12+
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
13+
tar xvf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
14+
rm sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
15+
python3 ./python-api-examples/offline-nemo-canary-decode-files.py
16+
rm -rf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
17+
1118
log "test spleeter"
1219

1320
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
This file shows how to use a non-streaming Canary model from NeMo
5+
to decode files.
6+
7+
Please download model files from
8+
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
9+
10+
11+
The example model supports 4 languages and it is converted from
12+
https://huggingface.co/nvidia/canary-180m-flash
13+
14+
It supports automatic speech-to-text recognition (ASR) in 4 languages
15+
(English, German, French, Spanish) and translation from English to
16+
German/French/Spanish and from German/French/Spanish to English with or
17+
without punctuation and capitalization (PnC).
18+
"""
19+
20+
from pathlib import Path
21+
22+
import sherpa_onnx
23+
import soundfile as sf
24+
25+
26+
def create_recognizer():
    """Create the Canary recognizer used by this example.

    Returns:
      A tuple ``(recognizer, en_wav, de_wav)``: the object returned by
      ``sherpa_onnx.OfflineRecognizer.from_nemo_canary`` followed by the
      paths of the English and the German test wave.

    Raises:
      ValueError: If any of the model files or test waves is missing.
    """
    model_dir = "./sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8"

    encoder = f"{model_dir}/encoder.int8.onnx"
    decoder = f"{model_dir}/decoder.int8.onnx"
    tokens = f"{model_dir}/tokens.txt"

    en_wav = f"{model_dir}/test_wavs/en.wav"
    de_wav = f"{model_dir}/test_wavs/de.wav"

    # Check every file this example touches, not only the encoder and the
    # English wave, so that an incomplete download is reported up front.
    required = (encoder, decoder, tokens, en_wav, de_wav)
    missing = [name for name in required if not Path(name).is_file()]
    if missing:
        raise ValueError(
            f"Missing file(s): {', '.join(missing)}\n"
            "Please download model files from\n"
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models\n"
        )

    recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_canary(
        encoder=encoder,
        decoder=decoder,
        tokens=tokens,
        debug=True,
    )
    return recognizer, en_wav, de_wav
50+
51+
52+
def decode(recognizer, samples, sample_rate, src_lang, tgt_lang):
    """Decode one utterance with the given source/target language pair.

    Args:
      recognizer: The recognizer returned by ``create_recognizer``.
      samples: Audio samples passed to ``stream.accept_waveform``.
      sample_rate: Sample rate of ``samples``.
      src_lang: Source language code forwarded to the Canary config.
      tgt_lang: Target language code forwarded to the Canary config.

    Returns:
      The ``text`` field of the stream's result after decoding.
    """
    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)

    # Build the nested config from the inside out, then hand it to the
    # wrapped recognizer before decoding.
    canary_config = sherpa_onnx.OfflineCanaryModelConfig(
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    model_config = sherpa_onnx.OfflineModelConfig(canary=canary_config)
    new_config = sherpa_onnx.OfflineRecognizerConfig(model_config=model_config)
    recognizer.recognizer.set_config(config=new_config)

    recognizer.decode_stream(stream)

    result = stream.result
    return result.text
69+
70+
71+
def main():
    """Decode the English and German test waves and print all results."""
    recognizer, en_wav, de_wav = create_recognizer()

    def first_channel(filename):
        # always_2d=True yields a 2-D array; keep only the first channel
        data, rate = sf.read(filename, dtype="float32", always_2d=True)
        return data[:, 0], rate

    en_audio, en_sample_rate = first_channel(en_wav)
    de_audio, de_sample_rate = first_channel(de_wav)

    # Run every decode before printing anything so that the results are
    # shown together at the end.
    en_results = [
        (
            f"en_wav_{tgt}_result",
            decode(recognizer, en_audio, en_sample_rate, src_lang="en", tgt_lang=tgt),
        )
        for tgt in ("en", "es", "de", "fr")
    ]
    de_results = [
        (
            f"de_wav_{tgt}_result",
            decode(recognizer, de_audio, de_sample_rate, src_lang="de", tgt_lang=tgt),
        )
        for tgt in ("en", "de")
    ]

    for label, text in en_results:
        print(label, text)
    print("-" * 10)
    for label, text in de_results:
        print(label, text)
107+
108+
109+
if __name__ == "__main__":
110+
main()

scripts/nemo/canary/export_onnx_180m_flash.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,14 @@ def export_decoder(canary_model):
281281

282282

283283
def export_tokens(canary_model):
    """Write the tokenizer vocabulary of ``canary_model`` to ./tokens.txt.

    One line is written per token id in ``range(vocab_size)``, in the form
    ``"<token> <id>"``. A leading space in the decoded token text is replaced
    with "▁" (U+2581).

    Args:
      canary_model: Object exposing ``tokenizer.vocab_size`` and
        ``tokenizer.ids_to_text``.
    """
    underline = "▁"
    tokenizer = canary_model.tokenizer

    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i in range(tokenizer.vocab_size):
            s = tokenizer.ids_to_text([i])

            # startswith() instead of s[0] so that a token which decodes to
            # an empty string does not raise IndexError.
            if s.startswith(" "):
                s = underline + s[1:]

            f.write(f"{s} {i}\n")

    print("Saved to tokens.txt")
289294

scripts/nemo/canary/test_180m_flash.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,13 @@ def main():
289289
tokens.append(t)
290290
print("len(tokens)", len(tokens))
291291
print("tokens", tokens)
292+
292293
text = "".join([id2token[i] for i in tokens])
294+
295+
underline = "▁"
296+
# underline = b"\xe2\x96\x81".decode()
297+
298+
text = text.replace(underline, " ").strip()
293299
print("text:", text)
294300

295301

sherpa-onnx/c-api/cxx-api.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include <algorithm>
77
#include <cstring>
8+
#include <utility>
89

910
namespace sherpa_onnx::cxx {
1011

sherpa-onnx/csrc/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ set(sources
2525
jieba.cc
2626
keyword-spotter-impl.cc
2727
keyword-spotter.cc
28+
offline-canary-model-config.cc
29+
offline-canary-model.cc
2830
offline-ctc-fst-decoder-config.cc
2931
offline-ctc-fst-decoder.cc
3032
offline-ctc-greedy-search-decoder.cc
@@ -50,15 +52,13 @@ set(sources
5052
offline-rnn-lm.cc
5153
offline-sense-voice-model-config.cc
5254
offline-sense-voice-model.cc
53-
5455
offline-source-separation-impl.cc
5556
offline-source-separation-model-config.cc
5657
offline-source-separation-spleeter-model-config.cc
5758
offline-source-separation-spleeter-model.cc
5859
offline-source-separation-uvr-model-config.cc
5960
offline-source-separation-uvr-model.cc
6061
offline-source-separation.cc
61-
6262
offline-stream.cc
6363
offline-tdnn-ctc-model.cc
6464
offline-tdnn-model-config.cc
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// sherpa-onnx/csrc/offline-canary-model-config.cc
2+
//
3+
// Copyright (c) 2025 Xiaomi Corporation
4+
5+
#include "sherpa-onnx/csrc/offline-canary-model-config.h"
6+
7+
#include <sstream>
8+
9+
#include "sherpa-onnx/csrc/file-utils.h"
10+
#include "sherpa-onnx/csrc/macros.h"
11+
12+
namespace sherpa_onnx {
13+
14+
// Registers the Canary command-line options with `po`. Each option is bound
// to the member of this config that it sets.
void OfflineCanaryModelConfig::Register(ParseOptions *po) {
  struct StringOption {
    const char *name;
    std::string *value;
    const char *help;
  };

  // String-valued options, in registration order.
  const StringOption string_options[] = {
      {"canary-encoder", &encoder,
       "Path to onnx encoder of Canary, e.g., encoder.int8.onnx"},
      {"canary-decoder", &decoder,
       "Path to onnx decoder of Canary, e.g., decoder.int8.onnx"},
      {"canary-src-lang", &src_lang,
       "Valid values: en, de, es, fr. If empty, default to use en"},
      {"canary-tgt-lang", &tgt_lang,
       "Valid values: en, de, es, fr. If empty, default to use en"},
  };

  for (const auto &option : string_options) {
    po->Register(option.name, option.value, option.help);
  }

  po->Register("canary-use-pnc", &use_pnc,
               "true to enable punctuations and casing. false to disable them");
}
30+
31+
bool OfflineCanaryModelConfig::Validate() const {
32+
if (encoder.empty()) {
33+
SHERPA_ONNX_LOGE("Please provide --canary-encoder");
34+
return false;
35+
}
36+
37+
if (!FileExists(encoder)) {
38+
SHERPA_ONNX_LOGE("Canary encoder file '%s' does not exist",
39+
encoder.c_str());
40+
return false;
41+
}
42+
43+
if (decoder.empty()) {
44+
SHERPA_ONNX_LOGE("Please provide --canary-decoder");
45+
return false;
46+
}
47+
48+
if (!FileExists(decoder)) {
49+
SHERPA_ONNX_LOGE("Canary decoder file '%s' does not exist",
50+
decoder.c_str());
51+
return false;
52+
}
53+
54+
if (!src_lang.empty()) {
55+
if (src_lang != "en" && src_lang != "de" && src_lang != "es" &&
56+
src_lang != "fr") {
57+
SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-src-lang");
58+
return false;
59+
}
60+
}
61+
62+
if (!tgt_lang.empty()) {
63+
if (tgt_lang != "en" && tgt_lang != "de" && tgt_lang != "es" &&
64+
tgt_lang != "fr") {
65+
SHERPA_ONNX_LOGE("Please use en, de, es, or fr for --canary-tgt-lang");
66+
return false;
67+
}
68+
}
69+
70+
return true;
71+
}
72+
73+
// Returns a human-readable, single-line description of this config, e.g.
// OfflineCanaryModelConfig(encoder="...", ..., use_pnc=True)
std::string OfflineCanaryModelConfig::ToString() const {
  const char *pnc_text = use_pnc ? "True" : "False";

  std::ostringstream os;
  os << "OfflineCanaryModelConfig("
     << "encoder=\"" << encoder << "\", "
     << "decoder=\"" << decoder << "\", "
     << "src_lang=\"" << src_lang << "\", "
     << "tgt_lang=\"" << tgt_lang << "\", "
     << "use_pnc=" << pnc_text << ")";

  return os.str();
}
85+
86+
} // namespace sherpa_onnx
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// sherpa-onnx/csrc/offline-canary-model-config.h
2+
//
3+
// Copyright (c) 2025 Xiaomi Corporation
4+
5+
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
6+
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
7+
8+
#include <string>
9+
10+
#include "sherpa-onnx/csrc/parse-options.h"
11+
12+
namespace sherpa_onnx {
13+
14+
struct OfflineCanaryModelConfig {
15+
std::string encoder;
16+
std::string decoder;
17+
18+
// en, de, es, fr, or leave it empty to use en
19+
std::string src_lang;
20+
21+
// en, de, es, fr, or leave it empty to use en
22+
std::string tgt_lang;
23+
24+
// true to enable punctuations and casing
25+
// false to disable punctuations and casing
26+
bool use_pnc = true;
27+
28+
OfflineCanaryModelConfig() = default;
29+
OfflineCanaryModelConfig(const std::string &encoder,
30+
const std::string &decoder,
31+
const std::string &src_lang,
32+
const std::string &tgt_lang, bool use_pnc)
33+
: encoder(encoder),
34+
decoder(decoder),
35+
src_lang(src_lang),
36+
tgt_lang(tgt_lang),
37+
use_pnc(use_pnc) {}
38+
39+
void Register(ParseOptions *po);
40+
bool Validate() const;
41+
42+
std::string ToString() const;
43+
};
44+
45+
} // namespace sherpa_onnx
46+
47+
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_CONFIG_H_
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// sherpa-onnx/csrc/offline-canary-model-meta-data.h
2+
//
3+
// Copyright (c) 2024 Xiaomi Corporation
4+
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
5+
#define SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_
6+
7+
#include <string>
8+
#include <unordered_map>
9+
#include <vector>
10+
11+
namespace sherpa_onnx {
12+
13+
// Meta data describing a Canary model.
struct OfflineCanaryModelMetaData {
  // Initialised to 0 so that a default-constructed object never holds an
  // indeterminate value.
  int32_t vocab_size = 0;

  int32_t subsampling_factor = 8;

  int32_t feat_dim = 120;

  // NOTE(review): presumably the feature normalization method expected by
  // the model; confirm against the code that fills it in.
  std::string normalize_type;

  // NOTE(review): presumably maps a language name to its token id; confirm
  // against the code that fills it in.
  std::unordered_map<std::string, int32_t> lang2id;
};
20+
21+
} // namespace sherpa_onnx
22+
23+
#endif // SHERPA_ONNX_CSRC_OFFLINE_CANARY_MODEL_META_DATA_H_

0 commit comments

Comments
 (0)