#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

+ """
+ <|en|>
+ <|pnc|>
+ <|noitn|>
+ <|nodiarize|>
+ <|notimestamp|>
+ """
+
import os
- from typing import Tuple
+ from typing import Dict, Tuple

import nemo
- import onnxmltools
+ import onnx
import torch
from nemo.collections.common.parts import NEG_INF
- from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxruntime.quantization import QuantType, quantize_dynamic

"""
@@ -64,10 +71,25 @@ def fixed_form_attention_mask(input_mask, diagonal=None):
from nemo.collections.asr.models import EncDecMultiTaskModel


- def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
-     onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
-     onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
-     onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
+ def add_meta_data(filename: str, meta_data: Dict[str, str]):
+     """Add metadata to an ONNX model. The model file is changed in-place.
+
+     Args:
+       filename:
+         Filename of the ONNX model to be changed.
+       meta_data:
+         Key-value pairs.
+     """
+     model = onnx.load(filename)
+     while len(model.metadata_props):
+         model.metadata_props.pop()
+
+     for key, value in meta_data.items():
+         meta = model.metadata_props.add()
+         meta.key = key
+         meta.value = str(value)
+
+     onnx.save(model, filename)


def lens_to_mask(lens, max_length):
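
The metadata written by add_meta_data() can be read back at inference time through onnxruntime's model-metadata API. A minimal sketch, assuming onnxruntime is installed and encoder.onnx has already been exported:

import onnxruntime

# custom_metadata_map exposes the key-value pairs stored in metadata_props;
# all values come back as strings, e.g. {"vocab_size": "5248", "feat_dim": "128", ...}.
sess = onnxruntime.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])
print(sess.get_modelmeta().custom_metadata_map)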
@@ -222,7 +244,7 @@ def export_decoder(canary_model):
        ),
        "decoder.onnx",
        dynamo=True,
-         opset_version=18,
+         opset_version=14,
        external_data=False,
        input_names=[
            "decoder_input_ids",
@@ -269,6 +291,29 @@ def export_tokens(canary_model):
@torch.no_grad()
def main():
    canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
+     canary_model.eval()
+
+     preprocessor = canary_model.cfg["preprocessor"]
+     sample_rate = preprocessor["sample_rate"]
+     normalize_type = preprocessor["normalize"]
+     window_size = preprocessor["window_size"]  # seconds (0.025 s = 25 ms)
+     window_stride = preprocessor["window_stride"]  # seconds (0.01 s = 10 ms)
+     window = preprocessor["window"]
+     features = preprocessor["features"]
+     n_fft = preprocessor["n_fft"]
+     vocab_size = canary_model.tokenizer.vocab_size  # 5248
+
+     subsampling_factor = canary_model.cfg["encoder"]["subsampling_factor"]
+
+     assert sample_rate == 16000, sample_rate
+     assert normalize_type == "per_feature", normalize_type
+     assert window_size == 0.025, window_size
+     assert window_stride == 0.01, window_stride
+     assert window == "hann", window
+     assert features == 128, features
+     assert n_fft == 512, n_fft
+     assert subsampling_factor == 8, subsampling_factor
+
    export_tokens(canary_model)
    export_encoder(canary_model)
    export_decoder(canary_model)
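
As a sanity check on the preprocessor values asserted in main() above (a sketch, not part of this diff): at 16 kHz, the 0.025 s window and 0.01 s stride correspond to 400 and 160 samples, and after 8x subsampling each encoder output frame covers 80 ms of audio.

sample_rate = 16000
win_length = int(0.025 * sample_rate)  # 400 samples per 25 ms window
hop_length = int(0.01 * sample_rate)   # 160 samples per 10 ms stride
# With 8x subsampling, one encoder frame covers 8 * 10 ms = 80 ms of audio,
# i.e. 12.5 encoder frames per second.
print(win_length, hop_length, sample_rate / hop_length / 8)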
@@ -280,7 +325,32 @@ def main():
            weight_type=QuantType.QUInt8,
        )

-         export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")
+     meta_data = {
+         "vocab_size": vocab_size,
+         "normalize_type": normalize_type,
+         "subsampling_factor": subsampling_factor,
+         "model_type": "EncDecMultiTaskModel",
+         "version": "1",
+         "model_author": "NeMo",
+         "url": "https://huggingface.co/nvidia/canary-180m-flash",
+         "feat_dim": features,
+     }
+
+     add_meta_data("encoder.onnx", meta_data)
+     add_meta_data("encoder.int8.onnx", meta_data)
+
+     """
+     To fix the following error with onnxruntime 1.17.1 and 1.16.3:
+
+     onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from ./decoder.int8.onnx failed:/Users/runner/work/1/s/onnxruntime/core/graph/model.cc:150 onnxruntime::Model::Model(onnx::ModelProto &&, const onnxruntime::PathString &, const onnxruntime::IOnnxRuntimeOpSchemaRegistryList *, const logging::Logger &, const onnxruntime::ModelOptions &)
+     Unsupported model IR version: 10, max supported IR version: 9
+     """
+     for filename in ["./decoder.onnx", "./decoder.int8.onnx"]:
+         model = onnx.load(filename)
+         print("old", model.ir_version)
+         model.ir_version = 9
+         print("new", model.ir_version)
+         onnx.save(model, filename)

    os.system("ls -lh *.onnx")
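
Rewriting ir_version alone is presumably viable here because the graph itself only uses opset-14 operators; the IR version is the version of the serialized container format, which these older runtimes cap at 9. A minimal check that the patched decoder now loads, assuming the files exist in the working directory:

import onnxruntime

# If the IR-version patch above worked, this no longer raises
# "Unsupported model IR version: 10, max supported IR version: 9".
sess = onnxruntime.InferenceSession("./decoder.int8.onnx", providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])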