Skip to content

Commit fce481c

Browse files
authored
Add metadata to NeMo canary ONNX models (#2351)
1 parent 25f9cec commit fce481c

File tree

4 files changed

+87
-68
lines changed

4 files changed

+87
-68
lines changed

.github/workflows/export-nemo-canary-180m-flash.yaml

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,7 @@ jobs:
6262
d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
6363
mkdir -p $d
6464
cp encoder.int8.onnx $d
65-
cp decoder.fp16.onnx $d
66-
cp tokens.txt $d
67-
68-
mkdir $d/test_wavs
69-
cp de.wav $d/test_wavs
70-
cp en.wav $d/test_wavs
71-
72-
tar cjfv $d.tar.bz2 $d
73-
74-
- name: Collect files (fp16)
75-
shell: bash
76-
run: |
77-
d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
78-
mkdir -p $d
79-
cp encoder.fp16.onnx $d
80-
cp decoder.fp16.onnx $d
65+
cp decoder.int8.onnx $d
8166
cp tokens.txt $d
8267
8368
mkdir $d/test_wavs
@@ -101,7 +86,6 @@ jobs:
10186
models=(
10287
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
10388
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
104-
sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
10589
)
10690
10791
for m in ${models[@]}; do

scripts/nemo/canary/export_onnx_180m_flash.py

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
#!/usr/bin/env python3
22
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
33

4+
"""
5+
<|en|>
6+
<|pnc|>
7+
<|noitn|>
8+
<|nodiarize|>
9+
<|notimestamp|>
10+
"""
11+
412
import os
5-
from typing import Tuple
13+
from typing import Dict, Tuple
614

715
import nemo
8-
import onnxmltools
16+
import onnx
917
import torch
1018
from nemo.collections.common.parts import NEG_INF
11-
from onnxmltools.utils.float16_converter import convert_float_to_float16
1219
from onnxruntime.quantization import QuantType, quantize_dynamic
1320

1421
"""
@@ -64,10 +71,25 @@ def fixed_form_attention_mask(input_mask, diagonal=None):
6471
from nemo.collections.asr.models import EncDecMultiTaskModel
6572

6673

67-
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
68-
onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
69-
onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
70-
onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
74+
def add_meta_data(filename: str, meta_data: Dict[str, str]):
75+
"""Add meta data to an ONNX model. It is changed in-place.
76+
77+
Args:
78+
filename:
79+
Filename of the ONNX model to be changed.
80+
meta_data:
81+
Key-value pairs.
82+
"""
83+
model = onnx.load(filename)
84+
while len(model.metadata_props):
85+
model.metadata_props.pop()
86+
87+
for key, value in meta_data.items():
88+
meta = model.metadata_props.add()
89+
meta.key = key
90+
meta.value = str(value)
91+
92+
onnx.save(model, filename)
7193

7294

7395
def lens_to_mask(lens, max_length):
@@ -222,7 +244,7 @@ def export_decoder(canary_model):
222244
),
223245
"decoder.onnx",
224246
dynamo=True,
225-
opset_version=18,
247+
opset_version=14,
226248
external_data=False,
227249
input_names=[
228250
"decoder_input_ids",
@@ -269,6 +291,29 @@ def export_tokens(canary_model):
269291
@torch.no_grad()
270292
def main():
271293
canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
294+
canary_model.eval()
295+
296+
preprocessor = canary_model.cfg["preprocessor"]
297+
sample_rate = preprocessor["sample_rate"]
298+
normalize_type = preprocessor["normalize"]
299+
window_size = preprocessor["window_size"] # ms
300+
window_stride = preprocessor["window_stride"] # ms
301+
window = preprocessor["window"]
302+
features = preprocessor["features"]
303+
n_fft = preprocessor["n_fft"]
304+
vocab_size = canary_model.tokenizer.vocab_size # 5248
305+
306+
subsampling_factor = canary_model.cfg["encoder"]["subsampling_factor"]
307+
308+
assert sample_rate == 16000, sample_rate
309+
assert normalize_type == "per_feature", normalize_type
310+
assert window_size == 0.025, window_size
311+
assert window_stride == 0.01, window_stride
312+
assert window == "hann", window
313+
assert features == 128, features
314+
assert n_fft == 512, n_fft
315+
assert subsampling_factor == 8, subsampling_factor
316+
272317
export_tokens(canary_model)
273318
export_encoder(canary_model)
274319
export_decoder(canary_model)
@@ -280,7 +325,32 @@ def main():
280325
weight_type=QuantType.QUInt8,
281326
)
282327

283-
export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")
328+
meta_data = {
329+
"vocab_size": vocab_size,
330+
"normalize_type": normalize_type,
331+
"subsampling_factor": subsampling_factor,
332+
"model_type": "EncDecMultiTaskModel",
333+
"version": "1",
334+
"model_author": "NeMo",
335+
"url": "https://huggingface.co/nvidia/canary-180m-flash",
336+
"feat_dim": features,
337+
}
338+
339+
add_meta_data("encoder.onnx", meta_data)
340+
add_meta_data("encoder.int8.onnx", meta_data)
341+
342+
"""
343+
To fix the following error with onnxruntime 1.17.1 and 1.16.3:
344+
345+
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 :FAIL : Load model from ./decoder.int8.onnx failed:/Users/runner/work/1/s/onnxruntime/core/graph/model.cc:150 onnxruntime::Model::Model(onnx::ModelProto &&, const onnxruntime::PathString &, const onnxruntime::IOnnxRuntimeOpSchemaRegistryList *, const logging::Logger &, const onnxruntime::ModelOptions &)
346+
Unsupported model IR version: 10, max supported IR version: 9
347+
"""
348+
for filename in ["./decoder.onnx", "./decoder.int8.onnx"]:
349+
model = onnx.load(filename)
350+
print("old", model.ir_version)
351+
model.ir_version = 9
352+
print("new", model.ir_version)
353+
onnx.save(model, filename)
284354

285355
os.system("ls -lh *.onnx")
286356

scripts/nemo/canary/run_180m_flash.sh

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ pip install \
1919
kaldi-native-fbank \
2020
librosa \
2121
onnx==1.17.0 \
22-
onnxmltools \
2322
onnxruntime==1.17.1 \
23+
onnxscript \
2424
soundfile
2525

2626
python3 ./export_onnx_180m_flash.py
@@ -66,65 +66,31 @@ log "-----int8------"
6666

6767
python3 ./test_180m_flash.py \
6868
--encoder ./encoder.int8.onnx \
69-
--decoder ./decoder.fp16.onnx \
69+
--decoder ./decoder.int8.onnx \
7070
--source-lang en \
7171
--target-lang en \
7272
--tokens ./tokens.txt \
7373
--wav ./en.wav
7474

7575
python3 ./test_180m_flash.py \
7676
--encoder ./encoder.int8.onnx \
77-
--decoder ./decoder.fp16.onnx \
77+
--decoder ./decoder.int8.onnx \
7878
--source-lang en \
7979
--target-lang de \
8080
--tokens ./tokens.txt \
8181
--wav ./en.wav
8282

8383
python3 ./test_180m_flash.py \
8484
--encoder ./encoder.int8.onnx \
85-
--decoder ./decoder.fp16.onnx \
85+
--decoder ./decoder.int8.onnx \
8686
--source-lang de \
8787
--target-lang de \
8888
--tokens ./tokens.txt \
8989
--wav ./de.wav
9090

9191
python3 ./test_180m_flash.py \
9292
--encoder ./encoder.int8.onnx \
93-
--decoder ./decoder.fp16.onnx \
94-
--source-lang de \
95-
--target-lang en \
96-
--tokens ./tokens.txt \
97-
--wav ./de.wav
98-
99-
log "-----fp16------"
100-
101-
python3 ./test_180m_flash.py \
102-
--encoder ./encoder.fp16.onnx \
103-
--decoder ./decoder.fp16.onnx \
104-
--source-lang en \
105-
--target-lang en \
106-
--tokens ./tokens.txt \
107-
--wav ./en.wav
108-
109-
python3 ./test_180m_flash.py \
110-
--encoder ./encoder.fp16.onnx \
111-
--decoder ./decoder.fp16.onnx \
112-
--source-lang en \
113-
--target-lang de \
114-
--tokens ./tokens.txt \
115-
--wav ./en.wav
116-
117-
python3 ./test_180m_flash.py \
118-
--encoder ./encoder.fp16.onnx \
119-
--decoder ./decoder.fp16.onnx \
120-
--source-lang de \
121-
--target-lang de \
122-
--tokens ./tokens.txt \
123-
--wav ./de.wav
124-
125-
python3 ./test_180m_flash.py \
126-
--encoder ./encoder.fp16.onnx \
127-
--decoder ./decoder.fp16.onnx \
93+
--decoder ./decoder.int8.onnx \
12894
--source-lang de \
12995
--target-lang en \
13096
--tokens ./tokens.txt \

scripts/nemo/canary/test_180m_flash.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ def init_encoder(self, encoder):
7979
)
8080

8181
meta = self.encoder.get_modelmeta().custom_metadata_map
82-
# self.normalize_type = meta["normalize_type"]
83-
self.normalize_type = "per_feature"
82+
self.normalize_type = meta["normalize_type"]
8483
print(meta)
8584

8685
def init_decoder(self, decoder):
@@ -267,7 +266,7 @@ def main():
267266

268267
for pos, decoder_input_id in enumerate(decoder_input_ids):
269268
logits, decoder_mems_list = model.run_decoder(
270-
np.array([[decoder_input_id,pos]], dtype=np.int32),
269+
np.array([[decoder_input_id, pos]], dtype=np.int32),
271270
decoder_mems_list,
272271
enc_states,
273272
enc_masks,

0 commit comments

Comments (0)