Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
108 lines
3.0 KiB
Go
108 lines
3.0 KiB
Go
//go:build sherpa
|
|
|
|
package diarization
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"sync"
|
|
|
|
"go-whisper-api/config"
|
|
"go-whisper-api/whisper"
|
|
|
|
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
|
|
)
|
|
|
|
type sherpaEngine struct {
|
|
cfg config.Diarization
|
|
sd *sherpa.OfflineSpeakerDiarization
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func newEngine(cfg config.Diarization) (Engine, error) {
|
|
cfg = cfg.WithDefaults()
|
|
if !cfg.Active() {
|
|
return &noopEngine{}, nil
|
|
}
|
|
if !cfg.ModelsPresent() {
|
|
return nil, fmt.Errorf("diarization models missing (run: make download-diarization-models)")
|
|
}
|
|
conf := &sherpa.OfflineSpeakerDiarizationConfig{
|
|
Segmentation: sherpa.OfflineSpeakerSegmentationModelConfig{
|
|
Pyannote: sherpa.OfflineSpeakerSegmentationPyannoteModelConfig{
|
|
Model: cfg.SegmentationPath(),
|
|
},
|
|
NumThreads: cfg.NumThreads,
|
|
Debug: 0,
|
|
Provider: "cpu",
|
|
},
|
|
Embedding: sherpa.SpeakerEmbeddingExtractorConfig{
|
|
Model: cfg.EmbeddingPath(),
|
|
NumThreads: cfg.NumThreads,
|
|
Debug: 0,
|
|
Provider: "cpu",
|
|
},
|
|
Clustering: sherpa.FastClusteringConfig{
|
|
NumClusters: cfg.NumClusters,
|
|
Threshold: cfg.ClusteringThreshold,
|
|
},
|
|
MinDurationOn: cfg.MinDurationOn,
|
|
MinDurationOff: cfg.MinDurationOff,
|
|
}
|
|
sd := sherpa.NewOfflineSpeakerDiarization(conf)
|
|
if sd == nil {
|
|
return nil, fmt.Errorf("failed to create sherpa speaker diarization")
|
|
}
|
|
return &sherpaEngine{cfg: cfg, sd: sd}, nil
|
|
}
|
|
|
|
type noopEngine struct{}
|
|
|
|
func (noopEngine) Active() bool { return false }
|
|
func (noopEngine) Process(context.Context, []float32, int) ([]whisper.Turn, error) { return nil, nil }
|
|
func (noopEngine) Close() {}
|
|
|
|
func (e *sherpaEngine) Active() bool {
|
|
return e.sd != nil
|
|
}
|
|
|
|
func (e *sherpaEngine) Process(ctx context.Context, samples []float32, numClusters int) ([]whisper.Turn, error) {
|
|
_ = ctx
|
|
if len(samples) == 0 {
|
|
return nil, fmt.Errorf("empty audio for diarization")
|
|
}
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
if _, err := os.Stat(e.cfg.SegmentationPath()); err != nil {
|
|
return nil, err
|
|
}
|
|
clusters := numClusters
|
|
if clusters <= 0 {
|
|
clusters = e.cfg.NumClusters
|
|
}
|
|
e.sd.SetConfig(&sherpa.OfflineSpeakerDiarizationConfig{
|
|
Clustering: sherpa.FastClusteringConfig{
|
|
NumClusters: clusters,
|
|
Threshold: e.cfg.ClusteringThreshold,
|
|
},
|
|
})
|
|
segments := e.sd.Process(samples)
|
|
out := make([]whisper.Turn, len(segments))
|
|
for i, s := range segments {
|
|
out[i] = whisper.Turn{
|
|
Start: s.Start,
|
|
End: s.End,
|
|
Speaker: s.Speaker,
|
|
}
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (e *sherpaEngine) Close() {
|
|
if e.sd != nil {
|
|
sherpa.DeleteOfflineSpeakerDiarization(e.sd)
|
|
e.sd = nil
|
|
}
|
|
}
|