admin b5c083e06f
Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
first commit
2026-06-04 18:10:52 +07:00

108 lines
2.9 KiB
Go

//go:build sherpa
package punctuation
import (
"context"
"fmt"
"os"
"strings"
"sync"
"go-whisper-api/config"
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
)
// Sherpa restores punctuation with a sherpa-onnx model. After a successful
// newSherpa exactly one of offline or online is non-nil; Close clears both.
type Sherpa struct {
	// mu serializes all use of the native handles below: Restore holds it
	// while running inference and Close holds it while freeing them.
	mu sync.Mutex

	offline *sherpa.OfflinePunctuation // CT-Transformer model; set for offline engines
	online  *sherpa.OnlinePunctuation  // CNN-BiLSTM + BPE vocab model; set for online engines
}
// newSherpaRestorer constructs a sherpa-onnx backed Restorer from cfg.
//
// On failure it returns an untyped nil Restorer together with the error.
// Returning newSherpa's result directly would wrap a typed (*Sherpa)(nil)
// in the interface, so a caller checking `r != nil` would see a non-nil
// Restorer even though construction failed.
func newSherpaRestorer(cfg config.Punctuation) (Restorer, error) {
	s, err := newSherpa(cfg)
	if err != nil {
		return nil, err
	}
	return s, nil
}
// newSherpa loads the sherpa-onnx punctuation model described by cfg, after
// filling unset fields via cfg.WithDefaults.
//
// The engine name is compared case-insensitively:
//   - "sherpa-online" or "online" loads the online model. cfg.ModelPath is
//     used as the CNN-BiLSTM model and cfg.BpeVocabPath as its BPE vocabulary.
//   - any other value loads the offline CT-Transformer model at cfg.ModelPath.
//
// Every required file is checked with os.Stat first, so a missing file gives
// a descriptive error instead of a nil native handle. Both backends run with
// cfg.NumThreads threads on the "cpu" provider. On error the *Sherpa is nil.
func newSherpa(cfg config.Punctuation) (*Sherpa, error) {
	cfg = cfg.WithDefaults()
	mode := strings.ToLower(cfg.Engine)
	wantOnline := mode == "sherpa-online" || mode == "online"

	if !wantOnline {
		model := cfg.ModelPath()
		if _, statErr := os.Stat(model); statErr != nil {
			return nil, fmt.Errorf("sherpa offline punctuation model %q: %w (run: make download-punctuation-model)", model, statErr)
		}

		var offlineConf sherpa.OfflinePunctuationConfig
		offlineConf.Model.CtTransformer = model
		offlineConf.Model.NumThreads = cfg.NumThreads
		offlineConf.Model.Provider = "cpu"

		handle := sherpa.NewOfflinePunctuation(&offlineConf)
		if handle == nil {
			return nil, fmt.Errorf("failed to create sherpa offline punctuation")
		}
		return &Sherpa{offline: handle}, nil
	}

	// The online model needs a BPE vocabulary in addition to the model file.
	model, vocab := cfg.ModelPath(), cfg.BpeVocabPath()
	if _, statErr := os.Stat(model); statErr != nil {
		return nil, fmt.Errorf("sherpa online punctuation model %q: %w", model, statErr)
	}
	if _, statErr := os.Stat(vocab); statErr != nil {
		return nil, fmt.Errorf("sherpa bpe vocab %q: %w", vocab, statErr)
	}

	var onlineConf sherpa.OnlinePunctuationConfig
	onlineConf.Model.CnnBilstm = model
	onlineConf.Model.BpeVocab = vocab
	onlineConf.Model.NumThreads = cfg.NumThreads
	onlineConf.Model.Provider = "cpu"

	handle := sherpa.NewOnlinePunctuation(&onlineConf)
	if handle == nil {
		return nil, fmt.Errorf("failed to create sherpa online punctuation")
	}
	return &Sherpa{online: handle}, nil
}
// Active reports whether s currently holds a loaded punctuation model.
//
// It is true from a successful newSherpa until Close releases the model
// handles. After that, Restore can no longer punctuate and only returns an
// error. A nil *Sherpa is inactive. The handles are read under s.mu because
// Close clears them while holding that lock.
func (s *Sherpa) Active() bool {
	if s == nil {
		return false
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.offline != nil || s.online != nil
}
// Restore adds punctuation to text using whichever sherpa-onnx model s was
// built with.
//
// Leading and trailing whitespace is trimmed from text first. If nothing is
// left, the empty string is returned without touching the model. Calls are
// serialized on s.mu, which also keeps inference from overlapping with Close.
//
// If ctx is done by the time the lock is acquired, inference is skipped and
// the trimmed text is returned with ctx.Err(). If the model output is blank,
// the trimmed input is returned instead. If no model handle is set (for
// example after Close), the trimmed text is returned with an error.
// language is currently unused.
func (s *Sherpa) Restore(ctx context.Context, text, language string) (string, error) {
	_ = language // AddPunct takes only the text, so the language hint is unused

	text = strings.TrimSpace(text)
	if text == "" {
		return text, nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Check cancellation after acquiring the lock. Waiting behind other
	// requests may take a while, and AddPunct accepts no context, so once it
	// starts it runs to completion; this is the last point to skip the work.
	// A nil ctx is tolerated to match the previous behavior.
	if ctx != nil {
		if err := ctx.Err(); err != nil {
			return text, err
		}
	}

	var out string
	switch {
	case s.offline != nil:
		out = s.offline.AddPunct(text)
	case s.online != nil:
		out = s.online.AddPunct(text)
	default:
		return text, fmt.Errorf("sherpa punctuation not initialized")
	}

	// Blank model output is treated as "nothing to add", so the caller keeps
	// the original words rather than losing them.
	if out = strings.TrimSpace(out); out == "" {
		return text, nil
	}
	return out, nil
}
// Close frees the native sherpa-onnx punctuation handle(s) held by s.
//
// It takes s.mu, so it waits for any in-flight Restore to finish first.
// Each handle field is cleared right after it is deleted, which makes
// repeated calls harmless and leaves s with no model loaded.
func (s *Sherpa) Close() {
	s.mu.Lock()
	defer s.mu.Unlock()

	if off := s.offline; off != nil {
		sherpa.DeleteOfflinePunc(off)
		s.offline = nil
	}
	if on := s.online; on != nil {
		sherpa.DeleteOnlinePunctuation(on)
		s.online = nil
	}
}