
Commit f096034

Add LODR support to online and offline recognizers (#2026)
This PR integrates LODR (Low-Order Density Ratio) support from Icefall into both online and offline recognizers, enabling LODR for LM shallow fusion and LM rescoring.

- Extended OnlineLMConfig and OfflineLMConfig to include lodr_fst, lodr_scale, and lodr_backoff_id.
- Implemented LodrFst and LodrStateCost classes and wired them into RNN LM scoring in both online and offline code paths.
- Updated Python bindings, CLI entry points, examples, and CI test scripts to accept and exercise the new LODR options.
1 parent 6122a67 commit f096034

21 files changed: +613 -14 lines changed
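
The new options are exposed end to end: as --lodr-fst / --lodr-scale flags on the C++ binaries and the Python CLI examples, and as lodr_fst / lodr_scale keyword arguments in the Python API (see the diffs below). The following is a minimal sketch, not part of this commit, of how the offline recognizer can be created with RNN-LM rescoring plus LODR; model file names and scale values are placeholders borrowed from the CI test below.

#!/usr/bin/env python3
# Sketch only: offline decoding with RNN-LM rescoring + LODR.
# Paths are placeholders; replace them with real model files.
import numpy as np
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    tokens="tokens.txt",
    encoder="encoder-epoch-99-avg-1.onnx",
    decoder="decoder-epoch-99-avg-1.onnx",
    joiner="joiner-epoch-99-avg-1.onnx",
    num_threads=2,
    sample_rate=16000,
    feature_dim=80,
    decoding_method="modified_beam_search",  # LM rescoring/LODR apply only with beam search
    lm="no-state-epoch-99-avg-1.onnx",       # RNN LM exported to ONNX
    lm_scale=0.1,
    lodr_fst="2gram.fst",                    # low-order (bi-gram) backoff LM as an FST
    lodr_scale=-0.5,                         # negative: the bi-gram score is subtracted
)

# samples: 1-D float32 array in [-1, 1] at 16 kHz; zeros stand in for real audio here
samples = np.zeros(16000, dtype=np.float32)
stream = recognizer.create_stream()
stream.accept_waveform(16000, samples)
recognizer.decode_stream(stream)
print(stream.result.text)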

.github/scripts/test-offline-transducer.sh

Lines changed: 33 additions & 1 deletion
@@ -281,7 +281,39 @@ time $EXE \
   $repo/test_wavs/1.wav \
   $repo/test_wavs/8k.wav
 
-rm -rf $repo
+lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+log "Download pre-trained RNN-LM model from ${lm_repo_url}"
+GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
+lm_repo=$(basename $lm_repo_url)
+pushd $lm_repo
+git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx"
+popd
+
+bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
+log "Download bi-gram LM from ${bigram_repo_url}"
+GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
+bigramlm_repo=$(basename $bigram_repo_url)
+pushd $bigramlm_repo
+git lfs pull --include "2gram.fst"
+popd
+
+log "Start testing with LM and bi-gram LODR"
+# TODO: find test examples that change with the LODR
+time $EXE \
+  --tokens=$repo/tokens.txt \
+  --encoder=$repo/encoder-epoch-99-avg-1.onnx \
+  --decoder=$repo/decoder-epoch-99-avg-1.onnx \
+  --joiner=$repo/joiner-epoch-99-avg-1.onnx \
+  --num-threads=2 \
+  --decoding_method="modified_beam_search" \
+  --lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \
+  --lodr-fst=$bigramlm_repo/2gram.fst \
+  --lodr-scale=-0.5 \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/8k.wav
+
+rm -rf $repo $lm_repo $bigramlm_repo
 
 log "------------------------------------------------------------"
 log "Run Paraformer (Chinese)"

.github/scripts/test-online-transducer.sh

Lines changed: 54 additions & 1 deletion
@@ -174,7 +174,60 @@ for wave in ${waves[@]}; do
     $wave
 done
 
-rm -rf $repo
+lm_repo_url=https://huggingface.co/vsd-vector/icefall-librispeech-rnn-lm
+log "Download pre-trained RNN-LM model from ${lm_repo_url}"
+GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
+lm_repo=$(basename $lm_repo_url)
+pushd $lm_repo
+git lfs pull --include "with-state-epoch-99-avg-1.onnx"
+popd
+
+bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
+log "Download bi-gram LM from ${bigram_repo_url}"
+GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
+bigramlm_repo=$(basename $bigram_repo_url)
+pushd $bigramlm_repo
+git lfs pull --include "2gram.fst"
+popd
+
+log "Start testing LODR"
+
+waves=(
+  $repo/test_wavs/0.wav
+  $repo/test_wavs/1.wav
+  $repo/test_wavs/8k.wav
+)
+
+for wave in ${waves[@]}; do
+  time $EXE \
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-99-avg-1.onnx \
+    --decoder=$repo/decoder-epoch-99-avg-1.onnx \
+    --joiner=$repo/joiner-epoch-99-avg-1.onnx \
+    --num-threads=2 \
+    --decoding_method="modified_beam_search" \
+    --lm=$lm_repo/with-state-epoch-99-avg-1.onnx \
+    --lodr-fst=$bigramlm_repo/2gram.fst \
+    --lodr-scale=-0.5 \
+    $wave
+done
+
+for wave in ${waves[@]}; do
+  time $EXE \
+    --tokens=$repo/tokens.txt \
+    --encoder=$repo/encoder-epoch-99-avg-1.onnx \
+    --decoder=$repo/decoder-epoch-99-avg-1.onnx \
+    --joiner=$repo/joiner-epoch-99-avg-1.onnx \
+    --num-threads=2 \
+    --decoding_method="modified_beam_search" \
+    --lm=$lm_repo/with-state-epoch-99-avg-1.onnx \
+    --lodr-fst=$bigramlm_repo/2gram.fst \
+    --lodr-scale=-0.5 \
+    --lm-shallow-fusion=true \
+    $wave
+done
+
+rm -rf $repo $bigramlm_repo $lm_repo
 
 log "------------------------------------------------------------"
 log "Run streaming Zipformer transducer (Bilingual, Chinese + English)"

.github/scripts/test-python.sh

Lines changed: 31 additions & 1 deletion
@@ -562,9 +562,39 @@ python3 ./python-api-examples/offline-decode-files.py \
   $repo/test_wavs/1.wav \
   $repo/test_wavs/8k.wav
 
+lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+log "Download pre-trained RNN-LM model from ${lm_repo_url}"
+GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
+lm_repo=$(basename $lm_repo_url)
+pushd $lm_repo
+git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx"
+popd
+
+bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
+log "Download bi-gram LM from ${bigram_repo_url}"
+GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
+bigramlm_repo=$(basename $bigram_repo_url)
+pushd $bigramlm_repo
+git lfs pull --include "2gram.fst"
+popd
+
+log "Perform offline decoding with RNN-LM and LODR"
+python3 ./python-api-examples/offline-decode-files.py \
+  --tokens=$repo/tokens.txt \
+  --encoder=$repo/encoder-epoch-99-avg-1.onnx \
+  --decoder=$repo/decoder-epoch-99-avg-1.onnx \
+  --joiner=$repo/joiner-epoch-99-avg-1.onnx \
+  --decoding-method=modified_beam_search \
+  --lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \
+  --lodr-fst=$bigramlm_repo/2gram.fst \
+  --lodr-scale=-0.5 \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/8k.wav
+
 python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
 
-rm -rf $repo
+rm -rf $repo $lm_repo $bigramlm_repo
 
 log "Test non-streaming paraformer models"

python-api-examples/offline-decode-files.py

Lines changed: 56 additions & 0 deletions
@@ -35,6 +35,25 @@
   /path/to/0.wav \
   /path/to/1.wav
 
+also with RNN LM rescoring and LODR (optional):
+
+./python-api-examples/offline-decode-files.py \
+  --tokens=/path/to/tokens.txt \
+  --encoder=/path/to/encoder.onnx \
+  --decoder=/path/to/decoder.onnx \
+  --joiner=/path/to/joiner.onnx \
+  --num-threads=2 \
+  --decoding-method=modified_beam_search \
+  --debug=false \
+  --sample-rate=16000 \
+  --feature-dim=80 \
+  --lm=/path/to/lm.onnx \
+  --lm-scale=0.1 \
+  --lodr-fst=/path/to/lodr.fst \
+  --lodr-scale=-0.1 \
+  /path/to/0.wav \
+  /path/to/1.wav
+
 (3) For CTC models from NeMo
 
 python3 ./python-api-examples/offline-decode-files.py \

@@ -269,6 +288,39 @@ def get_args():
         default="greedy_search",
         help="Valid values are greedy_search and modified_beam_search",
     )
+
+    parser.add_argument(
+        "--lm",
+        metavar="file",
+        type=str,
+        default="",
+        help="Path to RNN LM model",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        metavar="lm_scale",
+        type=float,
+        default=0.1,
+        help="LM model scale for rescoring",
+    )
+
+    parser.add_argument(
+        "--lodr-fst",
+        metavar="file",
+        type=str,
+        default="",
+        help="Path to LODR FST model. Used only when --lm is given.",
+    )
+
+    parser.add_argument(
+        "--lodr-scale",
+        metavar="lodr_scale",
+        type=float,
+        default=-0.1,
+        help="LODR scale for rescoring. Used only when --lodr-fst is given.",
+    )
+
     parser.add_argument(
         "--debug",
         type=bool,

@@ -364,6 +416,10 @@ def main():
         num_threads=args.num_threads,
         sample_rate=args.sample_rate,
         feature_dim=args.feature_dim,
+        lm=args.lm,
+        lm_scale=args.lm_scale,
+        lodr_fst=args.lodr_fst,
+        lodr_scale=args.lodr_scale,
         decoding_method=args.decoding_method,
         hotwords_file=args.hotwords_file,
         hotwords_score=args.hotwords_score,

python-api-examples/online-decode-files.py

Lines changed: 34 additions & 0 deletions
@@ -21,6 +21,22 @@
   ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
   ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
 
+or with RNN LM rescoring and LODR:
+
+./python-api-examples/online-decode-files.py \
+  --tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \
+  --encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
+  --decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
+  --joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
+  --decoding-method=modified_beam_search \
+  --lm=/path/to/lm.onnx \
+  --lm-scale=0.1 \
+  --lodr-fst=/path/to/lodr.fst \
+  --lodr-scale=-0.1 \
+  ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \
+  ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
+  ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
+
 (2) Streaming paraformer
 
 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2

@@ -186,6 +202,22 @@ def get_args():
         """
     )
 
+    parser.add_argument(
+        "--lodr-fst",
+        metavar="file",
+        type=str,
+        default="",
+        help="Path to LODR FST model. Used only when --lm is given.",
+    )
+
+    parser.add_argument(
+        "--lodr-scale",
+        metavar="lodr_scale",
+        type=float,
+        default=-0.1,
+        help="LODR scale for rescoring. Used only when --lodr-fst is given.",
+    )
+
     parser.add_argument(
         "--provider",
         type=str,

@@ -320,6 +352,8 @@ def main():
         max_active_paths=args.max_active_paths,
         lm=args.lm,
         lm_scale=args.lm_scale,
+        lodr_fst=args.lodr_fst,
+        lodr_scale=args.lodr_scale,
         hotwords_file=args.hotwords_file,
         hotwords_score=args.hotwords_score,
         modeling_unit=args.modeling_unit,
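
The streaming recognizer gets the same treatment: lodr_fst and lodr_scale are forwarded to OnlineRecognizer.from_transducer, as the hunk above shows. A rough end-to-end sketch, again not part of this commit and with placeholder file names; the chunked feeding loop mirrors the existing online-decode-files.py example.

# Sketch only: streaming decoding with RNN-LM rescoring plus LODR.
import numpy as np
import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
    tokens="tokens.txt",
    encoder="encoder-epoch-99-avg-1-chunk-16-left-64.onnx",
    decoder="decoder-epoch-99-avg-1-chunk-16-left-64.onnx",
    joiner="joiner-epoch-99-avg-1-chunk-16-left-64.onnx",
    num_threads=2,
    decoding_method="modified_beam_search",
    lm="with-state-epoch-99-avg-1.onnx",  # stateful RNN LM, as in the CI test
    lm_scale=0.1,
    lodr_fst="2gram.fst",
    lodr_scale=-0.5,
)

samples = np.zeros(5 * 16000, dtype=np.float32)  # stand-in for real 16 kHz audio
stream = recognizer.create_stream()

chunk = 1600  # feed 0.1 s at a time, as a capture callback would
for start in range(0, len(samples), chunk):
    stream.accept_waveform(16000, samples[start : start + chunk])
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

# flush the final frames and drain the decoder
stream.accept_waveform(16000, np.zeros(8000, dtype=np.float32))
stream.input_finished()
while recognizer.is_ready(stream):
    recognizer.decode_stream(stream)

print(recognizer.get_result(stream))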

sherpa-onnx/csrc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ set(sources
   jieba.cc
   keyword-spotter-impl.cc
   keyword-spotter.cc
+  lodr-fst.cc
   offline-canary-model-config.cc
   offline-canary-model.cc
   offline-ctc-fst-decoder-config.cc

sherpa-onnx/csrc/hypothesis.h

Lines changed: 5 additions & 0 deletions
@@ -12,9 +12,11 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include <memory>
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/context-graph.h"
+#include "sherpa-onnx/csrc/lodr-fst.h"
 #include "sherpa-onnx/csrc/math.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 

@@ -61,6 +63,9 @@
   // the nn lm states
   std::vector<CopyableOrtValue> nn_lm_states;
 
+  // the LODR states
+  std::shared_ptr<LodrStateCost> lodr_state;
+
   const ContextState *context_state;
 
   // TODO(fangjun): Make it configurable
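
The new lodr_state member means every hypothesis in the beam remembers where it currently sits in the LODR n-gram FST, in addition to its RNN-LM state. Below is a conceptual Python sketch, not the actual C++ LodrStateCost added by this commit, of what such a per-hypothesis state looks like for a deterministic backoff bi-gram FST; the dictionary layout and names are illustrative only, and the backoff arcs correspond to the label configured via lodr_backoff_id.

# Hypothetical sketch of a per-hypothesis LODR state over a backoff bi-gram FST.
# The real logic lives in the C++ LodrFst / LodrStateCost classes (lodr-fst.cc).
class BigramLodrState:
    def __init__(self, arcs, backoffs, state=0, score=0.0):
        # arcs:     {(state, token_id): (next_state, log_prob)}
        # backoffs: {state: (backoff_state, backoff_log_prob)}, i.e. the arcs
        #           whose label is the configured lodr_backoff_id
        self.arcs = arcs
        self.backoffs = backoffs
        self.state = state
        self.score = score  # accumulated bi-gram log-probability

    def forward_one_step(self, token_id):
        # Consume one decoded token; follow backoff arcs until the token is
        # found. In a backoff bi-gram FST the unigram (start) state has an
        # arc for every token, so this loop terminates.
        state, cost = self.state, 0.0
        while (state, token_id) not in self.arcs:
            backoff_state, backoff_logp = self.backoffs[state]
            state, cost = backoff_state, cost + backoff_logp
        next_state, arc_logp = self.arcs[(state, token_id)]
        return BigramLodrState(
            self.arcs, self.backoffs, next_state, self.score + cost + arc_logp
        )

During beam search, each expanded hypothesis advances this state by the emitted token and adds lodr_scale times the score increment to its total log-probability, alongside the lm_scale-weighted RNN-LM score; lodr_scale is typically negative (for example -0.5 in the CI tests above), since LODR subtracts the low-order source-domain LM estimate in the density-ratio sense.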
