Skip to content

Commit 7f1d71f

Browse files
authored
Add Dart API for ten-vad (#2386)
1 parent 71aea2f commit 7f1d71f

File tree

5 files changed

+195
-10
lines changed

5 files changed

+195
-10
lines changed

.github/scripts/test-dart.sh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ set -ex
44

55
cd dart-api-examples
66

7+
pushd vad
8+
./run-ten-vad.sh
9+
./run.sh
10+
rm *.onnx
11+
popd
12+
713
pushd non-streaming-asr
814

915
echo '----------Zipformer CTC----------'
@@ -186,9 +192,3 @@ echo '----------streaming paraformer----------'
186192
rm -rf sherpa-onnx-*
187193

188194
popd # streaming-asr
189-
190-
pushd vad
191-
./run.sh
192-
rm *.onnx
193-
popd
194-
dart-api-examples/vad/bin/ten-vad.dart

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright (c) 2024 Xiaomi Corporation
2+
import 'dart:io';
3+
import 'dart:typed_data';
4+
5+
import 'package:args/args.dart';
6+
import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx;
7+
import './init.dart';
8+
9+
/// Strips non-speech audio from a wave file using the ten-vad model.
///
/// Expects three command-line options: `--ten-vad` (model path),
/// `--input-wav` and `--output-wav`. When any of them is missing the usage
/// text is printed and the process exits with status 1. The same exit status
/// is used when the input file is not sampled at 16000 Hz.
///
/// The input is fed to the detector one window at a time; samples past the
/// last complete window are not fed. Every segment the detector reports is
/// concatenated, in order, and written to the output file.
void main(List<String> arguments) async {
  await initSherpaOnnx();

  final parser = ArgParser();
  parser.addOption('ten-vad', help: 'Path to ten-vad.onnx');
  parser.addOption('input-wav', help: 'Path to input.wav');
  parser.addOption('output-wav', help: 'Path to output.wav');

  final res = parser.parse(arguments);

  // All three options are mandatory.
  const requiredOptions = ['ten-vad', 'input-wav', 'output-wav'];
  if (requiredOptions.any((name) => res[name] == null)) {
    print(parser.usage);
    exit(1);
  }

  final modelPath = res['ten-vad'] as String;
  final inputWav = res['input-wav'] as String;
  final outputWav = res['output-wav'] as String;

  final config = sherpa_onnx.VadModelConfig(
    tenVad: sherpa_onnx.TenVadModelConfig(
      model: modelPath,
      threshold: 0.25,
      minSilenceDuration: 0.25,
      minSpeechDuration: 0.5,
      windowSize: 256,
    ),
    numThreads: 1,
    debug: true,
  );

  final detector = sherpa_onnx.VoiceActivityDetector(
    config: config,
    bufferSizeInSeconds: 10,
  );

  final waveData = sherpa_onnx.readWave(inputWav);
  if (waveData.sampleRate != 16000) {
    print('Only 16000 Hz is supported. Given: ${waveData.sampleRate}');
    exit(1);
  }

  final windowSize = config.tenVad.windowSize;
  final windowCount = waveData.samples.length ~/ windowSize;
  final segments = <List<double>>[];

  // Moves every queued speech segment from the detector into [segments].
  void drainSegments() {
    while (!detector.isEmpty()) {
      segments.add(detector.front().samples);
      detector.pop();
    }
  }

  var offset = 0;
  for (var w = 0; w < windowCount; ++w) {
    final end = offset + windowSize;
    detector.acceptWaveform(
      Float32List.sublistView(waveData.samples, offset, end),
    );
    offset = end;

    if (detector.isDetected()) {
      drainSegments();
    }
  }

  // Flush, then collect whatever segments are queued afterwards.
  detector.flush();
  drainSegments();

  detector.free();

  final speech = Float32List.fromList([
    for (final segment in segments) ...segment,
  ]);
  sherpa_onnx.writeWave(
    filename: outputWav,
    samples: speech,
    sampleRate: waveData.sampleRate,
  );

  print('Saved to $outputWav');
}

dart-api-examples/vad/run-ten-vad.sh

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/usr/bin/env bash
#
# Downloads the ten-vad model and a test wave file if they are not already
# present in the current directory, then runs ./bin/ten-vad.dart on them and
# lists the resulting *.wav files.

set -ex

dart pub get

release_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models

# Fetch each asset only when it is missing locally.
for asset in ten-vad.onnx lei-jun-test.wav; do
  if [[ ! -f ./$asset ]]; then
    curl -SL -O "$release_url/$asset"
  fi
done

dart run \
  ./bin/ten-vad.dart \
  --ten-vad ./ten-vad.onnx \
  --input-wav ./lei-jun-test.wav \
  --output-wav ./lei-jun-test-no-silence.wav

ls -lh *.wav

flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,25 @@ final class SherpaOnnxSileroVadModelConfig extends Struct {
487487
external double maxSpeechDuration;
488488
}
489489

490+
/// FFI struct holding the ten-vad model settings.
///
/// The order and native types of the fields below define the struct's memory
/// layout, so they must not be reordered or retyped independently of the
/// native declaration. NOTE(review): the native counterpart is not visible
/// here; confirm the layout against the C header before changing anything.
final class SherpaOnnxTenVadModelConfig extends Struct {
  /// Pointer to a UTF-8 encoded model string.
  external Pointer<Utf8> model;

  /// Stored as a native `float`.
  @Float()
  external double threshold;

  /// Stored as a native `float`; presumably in seconds (TODO confirm).
  @Float()
  external double minSilenceDuration;

  /// Stored as a native `float`; presumably in seconds (TODO confirm).
  @Float()
  external double minSpeechDuration;

  /// Stored as a native 32-bit signed integer.
  @Int32()
  external int windowSize;

  /// Stored as a native `float`; presumably in seconds (TODO confirm).
  @Float()
  external double maxSpeechDuration;
}
508+
490509
final class SherpaOnnxVadModelConfig extends Struct {
491510
external SherpaOnnxSileroVadModelConfig sileroVad;
492511

@@ -500,6 +519,8 @@ final class SherpaOnnxVadModelConfig extends Struct {
500519

501520
@Int32()
502521
external int debug;
522+
523+
external SherpaOnnxTenVadModelConfig tenVad;
503524
}
504525

505526
final class SherpaOnnxSpeechSegment extends Struct {

flutter/sherpa_onnx/lib/src/vad.dart

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,62 @@ class SileroVadModelConfig {
4949
final double maxSpeechDuration;
5050
}
5151

52+
/// Immutable settings for the ten-vad voice activity detection model.
///
/// Every field has a default, so `const TenVadModelConfig()` is valid. The
/// class round-trips through [toJson] and [TenVadModelConfig.fromJson] using
/// the field names as keys.
class TenVadModelConfig {
  /// Creates a configuration; omitted arguments take the defaults shown.
  const TenVadModelConfig(
      {this.model = '',
      this.threshold = 0.5,
      this.minSilenceDuration = 0.5,
      this.minSpeechDuration = 0.25,
      this.windowSize = 256,
      this.maxSpeechDuration = 5.0});

  /// Builds a configuration from a decoded JSON map.
  ///
  /// Missing or `null` entries fall back to the constructor defaults. All
  /// numeric entries are read as [num] and converted, so an integer-valued
  /// field may arrive as a double (e.g. `256.0`) and a double-valued field
  /// may arrive as an int (e.g. `1`) without a cast failure. A fractional
  /// `windowSize` is truncated by [num.toInt].
  factory TenVadModelConfig.fromJson(Map<String, dynamic> json) {
    return TenVadModelConfig(
      model: json['model'] as String? ?? '',
      threshold: (json['threshold'] as num?)?.toDouble() ?? 0.5,
      minSilenceDuration:
          (json['minSilenceDuration'] as num?)?.toDouble() ?? 0.5,
      minSpeechDuration:
          (json['minSpeechDuration'] as num?)?.toDouble() ?? 0.25,
      // Read as num rather than int: a JSON decoder yields a double for
      // "256.0", and `as int?` would throw on it.
      windowSize: (json['windowSize'] as num?)?.toInt() ?? 256,
      maxSpeechDuration:
          (json['maxSpeechDuration'] as num?)?.toDouble() ?? 5.0,
    );
  }

  @override
  String toString() {
    return 'TenVadModelConfig(model: $model, threshold: $threshold, minSilenceDuration: $minSilenceDuration, minSpeechDuration: $minSpeechDuration, windowSize: $windowSize, maxSpeechDuration: $maxSpeechDuration)';
  }

  /// Returns a JSON-encodable map keyed by field name.
  Map<String, dynamic> toJson() => {
        'model': model,
        'threshold': threshold,
        'minSilenceDuration': minSilenceDuration,
        'minSpeechDuration': minSpeechDuration,
        'windowSize': windowSize,
        'maxSpeechDuration': maxSpeechDuration,
      };

  /// Model string; the example in this change passes a path to an .onnx file.
  final String model;

  /// Detection threshold. NOTE(review): exact semantics are defined by the
  /// native library and are not visible here.
  final double threshold;

  /// Minimum silence duration; presumably in seconds (TODO confirm).
  final double minSilenceDuration;

  /// Minimum speech duration; presumably in seconds (TODO confirm).
  final double minSpeechDuration;

  /// Number of samples fed to the detector per window.
  final int windowSize;

  /// Maximum speech duration; presumably in seconds (TODO confirm).
  final double maxSpeechDuration;
}
95+
5296
class VadModelConfig {
5397
VadModelConfig({
5498
this.sileroVad = const SileroVadModelConfig(),
5599
this.sampleRate = 16000,
56100
this.numThreads = 1,
57101
this.provider = 'cpu',
58102
this.debug = true,
103+
this.tenVad = const TenVadModelConfig(),
59104
});
60105

61106
final SileroVadModelConfig sileroVad;
107+
final TenVadModelConfig tenVad;
62108
final int sampleRate;
63109
final int numThreads;
64110
final String provider;
@@ -68,6 +114,8 @@ class VadModelConfig {
68114
return VadModelConfig(
69115
sileroVad: SileroVadModelConfig.fromJson(
70116
json['sileroVad'] as Map<String, dynamic>? ?? const {}),
117+
tenVad: TenVadModelConfig.fromJson(
118+
json['tenVad'] as Map<String, dynamic>? ?? const {}),
71119
sampleRate: json['sampleRate'] as int? ?? 16000,
72120
numThreads: json['numThreads'] as int? ?? 1,
73121
provider: json['provider'] as String? ?? 'cpu',
@@ -77,6 +125,7 @@ class VadModelConfig {
77125

78126
Map<String, dynamic> toJson() => {
79127
'sileroVad': sileroVad.toJson(),
128+
'tenVad': tenVad.toJson(),
80129
'sampleRate': sampleRate,
81130
'numThreads': numThreads,
82131
'provider': provider,
@@ -85,7 +134,7 @@ class VadModelConfig {
85134

86135
@override
87136
String toString() {
88-
return 'VadModelConfig(sileroVad: $sileroVad, sampleRate: $sampleRate, numThreads: $numThreads, provider: $provider, debug: $debug)';
137+
return 'VadModelConfig(sileroVad: $sileroVad, tenVad: $tenVad, sampleRate: $sampleRate, numThreads: $numThreads, provider: $provider, debug: $debug)';
89138
}
90139
}
91140

@@ -168,15 +217,24 @@ class VoiceActivityDetector {
168217
{required VadModelConfig config, required double bufferSizeInSeconds}) {
169218
final c = calloc<SherpaOnnxVadModelConfig>();
170219

171-
final modelPtr = config.sileroVad.model.toNativeUtf8();
172-
c.ref.sileroVad.model = modelPtr;
220+
final sileroVadModelPtr = config.sileroVad.model.toNativeUtf8();
221+
c.ref.sileroVad.model = sileroVadModelPtr;
173222

174223
c.ref.sileroVad.threshold = config.sileroVad.threshold;
175224
c.ref.sileroVad.minSilenceDuration = config.sileroVad.minSilenceDuration;
176225
c.ref.sileroVad.minSpeechDuration = config.sileroVad.minSpeechDuration;
177226
c.ref.sileroVad.windowSize = config.sileroVad.windowSize;
178227
c.ref.sileroVad.maxSpeechDuration = config.sileroVad.maxSpeechDuration;
179228

229+
final tenVadModelPtr = config.tenVad.model.toNativeUtf8();
230+
c.ref.tenVad.model = tenVadModelPtr;
231+
232+
c.ref.tenVad.threshold = config.tenVad.threshold;
233+
c.ref.tenVad.minSilenceDuration = config.tenVad.minSilenceDuration;
234+
c.ref.tenVad.minSpeechDuration = config.tenVad.minSpeechDuration;
235+
c.ref.tenVad.windowSize = config.tenVad.windowSize;
236+
c.ref.tenVad.maxSpeechDuration = config.tenVad.maxSpeechDuration;
237+
180238
c.ref.sampleRate = config.sampleRate;
181239
c.ref.numThreads = config.numThreads;
182240

@@ -190,7 +248,8 @@ class VoiceActivityDetector {
190248
nullptr;
191249

192250
calloc.free(providerPtr);
193-
calloc.free(modelPtr);
251+
calloc.free(tenVadModelPtr);
252+
calloc.free(sileroVadModelPtr);
194253
calloc.free(c);
195254

196255
return VoiceActivityDetector._(ptr: ptr, config: config);

0 commit comments

Comments
 (0)