Skip to content

Commit 0b8ce73

Browse files
authored
Add Go API for ten-vad (#2384)
1 parent 2778498 commit 0b8ce73

File tree

3 files changed

+141
-64
lines changed

3 files changed

+141
-64
lines changed

go-api-examples/vad/main.go

Lines changed: 117 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ package main
22

33
import (
44
"fmt"
5-
portaudio "github.com/csukuangfj/portaudio-go"
5+
"github.com/gen2brain/malgo"
66
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
77
"log"
8+
"os"
89
)
910

1011
func main() {
@@ -13,100 +14,153 @@ func main() {
1314
config := sherpa.VadModelConfig{}
1415

1516
// Please download silero_vad.onnx from
16-
// https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx
17+
// https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
18+
// or ten-vad.onnx from
19+
// https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx
20+
21+
if FileExists("./silero_vad.onnx") {
22+
fmt.Println("Use silero-vad")
23+
config.SileroVad.Model = "./silero_vad.onnx"
24+
config.SileroVad.Threshold = 0.5
25+
config.SileroVad.MinSilenceDuration = 0.5
26+
config.SileroVad.MinSpeechDuration = 0.25
27+
config.SileroVad.MaxSpeechDuration = 10
28+
config.SileroVad.WindowSize = 512
29+
} else if FileExists("./ten-vad.onnx") {
30+
fmt.Println("Use ten-vad")
31+
config.TenVad.Model = "./ten-vad.onnx"
32+
config.TenVad.Threshold = 0.5
33+
config.TenVad.MinSilenceDuration = 0.5
34+
config.TenVad.MinSpeechDuration = 0.25
35+
config.TenVad.MaxSpeechDuration = 10
36+
config.TenVad.WindowSize = 256
37+
} else {
38+
fmt.Println("Please download either ./silero_vad.onnx or ./ten-vad.onnx")
39+
return
40+
}
1741

18-
config.SileroVad.Model = "./silero_vad.onnx"
19-
config.SileroVad.Threshold = 0.5
20-
config.SileroVad.MinSilenceDuration = 0.5
21-
config.SileroVad.MinSpeechDuration = 0.25
22-
config.SileroVad.WindowSize = 512
2342
config.SampleRate = 16000
2443
config.NumThreads = 1
2544
config.Provider = "cpu"
2645
config.Debug = 1
2746

47+
windowSize := config.SileroVad.WindowSize
48+
if config.TenVad.Model != "" {
49+
windowSize = config.TenVad.WindowSize
50+
}
51+
2852
var bufferSizeInSeconds float32 = 5
2953

3054
vad := sherpa.NewVoiceActivityDetector(&config, bufferSizeInSeconds)
3155
defer sherpa.DeleteVoiceActivityDetector(vad)
3256

33-
err := portaudio.Initialize()
34-
if err != nil {
35-
log.Fatalf("Unable to initialize portaudio: %v\n", err)
36-
}
37-
defer portaudio.Terminate()
57+
buffer := sherpa.NewCircularBuffer(10 * config.SampleRate)
58+
defer sherpa.DeleteCircularBuffer(buffer)
3859

39-
default_device, err := portaudio.DefaultInputDevice()
40-
if err != nil {
41-
log.Fatal("Failed to get default input device: %v\n", err)
42-
}
43-
log.Printf("Selected default input device: %s\n", default_device.Name)
44-
param := portaudio.StreamParameters{}
45-
param.Input.Device = default_device
46-
param.Input.Channels = 1
47-
param.Input.Latency = default_device.DefaultLowInputLatency
60+
ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, func(message string) {
61+
fmt.Printf("LOG <%v>", message)
62+
})
63+
chk(err)
4864

49-
param.SampleRate = float64(config.SampleRate)
50-
param.FramesPerBuffer = 0
51-
param.Flags = portaudio.ClipOff
65+
defer func() {
66+
_ = ctx.Uninit()
67+
ctx.Free()
68+
}()
5269

53-
// you can choose another value for 0.1 if you want
54-
samplesPerCall := int32(param.SampleRate * 0.1) // 0.1 second
55-
samples := make([]float32, samplesPerCall)
70+
deviceConfig := malgo.DefaultDeviceConfig(malgo.Duplex)
71+
deviceConfig.Capture.Format = malgo.FormatS16
72+
deviceConfig.Capture.Channels = 1
73+
deviceConfig.Playback.Format = malgo.FormatS16
74+
deviceConfig.Playback.Channels = 1
75+
deviceConfig.SampleRate = 16000
76+
deviceConfig.Alsa.NoMMap = 1
5677

57-
s, err := portaudio.OpenStream(param, samples)
58-
if err != nil {
59-
log.Fatalf("Failed to open the stream")
60-
}
61-
62-
defer s.Close()
63-
chk(s.Start())
64-
65-
log.Print("Started! Please speak")
6678
printed := false
67-
6879
k := 0
69-
for {
70-
chk(s.Read())
71-
vad.AcceptWaveform(samples)
7280

73-
if vad.IsSpeech() && !printed {
74-
printed = true
75-
log.Print("Detected speech\n")
76-
}
81+
onRecvFrames := func(_, pSample []byte, framecount uint32) {
82+
samples := samplesInt16ToFloat(pSample)
83+
buffer.Push(samples)
84+
for buffer.Size() >= windowSize {
85+
head := buffer.Head()
86+
s := buffer.Get(head, windowSize)
87+
buffer.Pop(windowSize)
7788

78-
if !vad.IsSpeech() {
79-
printed = false
80-
}
89+
vad.AcceptWaveform(s)
8190

82-
for !vad.IsEmpty() {
83-
speechSegment := vad.Front()
84-
vad.Pop()
91+
if vad.IsSpeech() && !printed {
92+
printed = true
93+
log.Print("Detected speech\n")
94+
}
8595

86-
duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate)
96+
if !vad.IsSpeech() {
97+
printed = false
98+
}
8799

88-
audio := sherpa.GeneratedAudio{}
89-
audio.Samples = speechSegment.Samples
90-
audio.SampleRate = config.SampleRate
100+
for !vad.IsEmpty() {
101+
speechSegment := vad.Front()
102+
vad.Pop()
91103

92-
filename := fmt.Sprintf("seg-%d-%.2f-seconds.wav", k, duration)
93-
ok := audio.Save(filename)
94-
if ok {
95-
log.Printf("Saved to %s", filename)
96-
}
104+
duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate)
97105

98-
k += 1
106+
audio := sherpa.GeneratedAudio{}
107+
audio.Samples = speechSegment.Samples
108+
audio.SampleRate = config.SampleRate
99109

100-
log.Printf("Duration: %.2f seconds\n", duration)
101-
log.Print("----------\n")
110+
filename := fmt.Sprintf("seg-%d-%.2f-seconds.wav", k, duration)
111+
ok := audio.Save(filename)
112+
if ok {
113+
log.Printf("Saved to %s", filename)
114+
}
115+
116+
k += 1
117+
118+
log.Printf("Duration: %.2f seconds\n", duration)
119+
log.Print("----------\n")
120+
}
102121
}
103122
}
104123

105-
chk(s.Stop())
124+
captureCallbacks := malgo.DeviceCallbacks{
125+
Data: onRecvFrames,
126+
}
127+
128+
device, err := malgo.InitDevice(ctx.Context, deviceConfig, captureCallbacks)
129+
chk(err)
130+
131+
err = device.Start()
132+
chk(err)
133+
134+
fmt.Println("Started. Please speak. Press ctrl + C to exit")
135+
fmt.Scanln()
136+
device.Uninit()
137+
106138
}
107139

108140
// chk panics with err when err is non-nil; a nil error is a no-op.
// It is used by this example to abort on setup failures.
func chk(err error) {
	if err == nil {
		return
	}
	panic(err)
}
145+
146+
// samplesInt16ToFloat converts raw 16-bit signed PCM bytes (low byte first)
// into float32 samples by dividing each decoded value by 32768.
//
// Two input bytes produce one output sample; a trailing odd byte is ignored.
// The returned slice is always non-nil, even for empty input.
func samplesInt16ToFloat(inSamples []byte) []float32 {
	outSamples := make([]float32, len(inSamples)/2)

	for i := range outSamples {
		lo := int16(inSamples[2*i])
		hi := int16(inSamples[2*i+1])
		// The shift deliberately wraps into the sign bit so that bytes
		// such as 0x00 0x80 decode to -32768.
		outSamples[i] = float32(hi<<8|lo) / 32768
	}

	return outSamples
}
158+
159+
// FileExists reports whether path names an existing file system entry that
// is not a directory.
//
// It returns false when os.Stat fails for any reason (the path is missing,
// permission is denied, ...) and also when the path is a directory, so that a
// directory that happens to be named like a model file is not mistaken for
// one.
func FileExists(path string) bool {
	info, err := os.Stat(path)
	if err != nil {
		return false
	}

	return !info.IsDir()
}

go-api-examples/vad/run.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
set -ex
44

55
if [ ! -f ./silero_vad.onnx ]; then
6-
curl -SL -O https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx
6+
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
7+
fi
8+
9+
if [ ! -f ./ten-vad.onnx ]; then
10+
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx
711
fi
812

913
go mod tidy

scripts/go/sherpa_onnx.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,8 +1142,18 @@ type SileroVadModelConfig struct {
11421142
MaxSpeechDuration float32
11431143
}
11441144

1145+
// TenVadModelConfig holds the settings for the ten-vad voice activity
// detection model. NewVoiceActivityDetector copies every field into the
// ten_vad section of the underlying C configuration.
type TenVadModelConfig struct {
	// Model is the path to the ten-vad model file, e.g. "./ten-vad.onnx".
	Model string

	// Threshold is passed through as ten_vad.threshold.
	// NOTE(review): presumably the speech-probability cutoff; confirm.
	Threshold float32

	// MinSilenceDuration is passed through as ten_vad.min_silence_duration.
	// NOTE(review): presumably in seconds; confirm.
	MinSilenceDuration float32

	// MinSpeechDuration is passed through as ten_vad.min_speech_duration.
	// NOTE(review): presumably in seconds; confirm.
	MinSpeechDuration float32

	// WindowSize is passed through as ten_vad.window_size.
	// NOTE(review): presumably the number of samples per window; confirm.
	WindowSize int

	// MaxSpeechDuration is passed through as ten_vad.max_speech_duration.
	// NOTE(review): presumably in seconds; confirm.
	MaxSpeechDuration float32
}
1153+
11451154
type VadModelConfig struct {
11461155
SileroVad SileroVadModelConfig
1156+
TenVad TenVadModelConfig
11471157
SampleRate int
11481158
NumThreads int
11491159
Provider string
@@ -1220,6 +1230,15 @@ func NewVoiceActivityDetector(config *VadModelConfig, bufferSizeInSeconds float3
12201230
c.silero_vad.window_size = C.int(config.SileroVad.WindowSize)
12211231
c.silero_vad.max_speech_duration = C.float(config.SileroVad.MaxSpeechDuration)
12221232

1233+
c.ten_vad.model = C.CString(config.TenVad.Model)
1234+
defer C.free(unsafe.Pointer(c.ten_vad.model))
1235+
1236+
c.ten_vad.threshold = C.float(config.TenVad.Threshold)
1237+
c.ten_vad.min_silence_duration = C.float(config.TenVad.MinSilenceDuration)
1238+
c.ten_vad.min_speech_duration = C.float(config.TenVad.MinSpeechDuration)
1239+
c.ten_vad.window_size = C.int(config.TenVad.WindowSize)
1240+
c.ten_vad.max_speech_duration = C.float(config.TenVad.MaxSpeechDuration)
1241+
12231242
c.sample_rate = C.int(config.SampleRate)
12241243
c.num_threads = C.int(config.NumThreads)
12251244
c.provider = C.CString(config.Provider)

0 commit comments

Comments
 (0)