//go:build sherpa package diarization import ( "context" "fmt" "os" "sync" "go-whisper-api/config" "go-whisper-api/whisper" sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" ) type sherpaEngine struct { cfg config.Diarization sd *sherpa.OfflineSpeakerDiarization mu sync.Mutex } func newEngine(cfg config.Diarization) (Engine, error) { cfg = cfg.WithDefaults() if !cfg.Active() { return &noopEngine{}, nil } if !cfg.ModelsPresent() { return nil, fmt.Errorf("diarization models missing (run: make download-diarization-models)") } conf := &sherpa.OfflineSpeakerDiarizationConfig{ Segmentation: sherpa.OfflineSpeakerSegmentationModelConfig{ Pyannote: sherpa.OfflineSpeakerSegmentationPyannoteModelConfig{ Model: cfg.SegmentationPath(), }, NumThreads: cfg.NumThreads, Debug: 0, Provider: "cpu", }, Embedding: sherpa.SpeakerEmbeddingExtractorConfig{ Model: cfg.EmbeddingPath(), NumThreads: cfg.NumThreads, Debug: 0, Provider: "cpu", }, Clustering: sherpa.FastClusteringConfig{ NumClusters: cfg.NumClusters, Threshold: cfg.ClusteringThreshold, }, MinDurationOn: cfg.MinDurationOn, MinDurationOff: cfg.MinDurationOff, } sd := sherpa.NewOfflineSpeakerDiarization(conf) if sd == nil { return nil, fmt.Errorf("failed to create sherpa speaker diarization") } return &sherpaEngine{cfg: cfg, sd: sd}, nil } type noopEngine struct{} func (noopEngine) Active() bool { return false } func (noopEngine) Process(context.Context, []float32, int) ([]whisper.Turn, error) { return nil, nil } func (noopEngine) Close() {} func (e *sherpaEngine) Active() bool { return e.sd != nil } func (e *sherpaEngine) Process(ctx context.Context, samples []float32, numClusters int) ([]whisper.Turn, error) { _ = ctx if len(samples) == 0 { return nil, fmt.Errorf("empty audio for diarization") } e.mu.Lock() defer e.mu.Unlock() if _, err := os.Stat(e.cfg.SegmentationPath()); err != nil { return nil, err } clusters := numClusters if clusters <= 0 { clusters = e.cfg.NumClusters } e.sd.SetConfig(&sherpa.OfflineSpeakerDiarizationConfig{ Clustering: sherpa.FastClusteringConfig{ NumClusters: clusters, Threshold: e.cfg.ClusteringThreshold, }, }) segments := e.sd.Process(samples) out := make([]whisper.Turn, len(segments)) for i, s := range segments { out[i] = whisper.Turn{ Start: s.Start, End: s.End, Speaker: s.Speaker, } } return out, nil } func (e *sherpaEngine) Close() { if e.sd != nil { sherpa.DeleteOfflineSpeakerDiarization(e.sd) e.sd = nil } }