//go:build sherpa package punctuation import ( "context" "fmt" "os" "strings" "sync" "go-whisper-api/config" sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" ) type Sherpa struct { offline *sherpa.OfflinePunctuation online *sherpa.OnlinePunctuation mu sync.Mutex } func newSherpaRestorer(cfg config.Punctuation) (Restorer, error) { return newSherpa(cfg) } func newSherpa(cfg config.Punctuation) (*Sherpa, error) { cfg = cfg.WithDefaults() engine := strings.ToLower(cfg.Engine) s := &Sherpa{} switch engine { case "sherpa-online", "online": modelPath := cfg.ModelPath() vocabPath := cfg.BpeVocabPath() if _, err := os.Stat(modelPath); err != nil { return nil, fmt.Errorf("sherpa online punctuation model %q: %w", modelPath, err) } if _, err := os.Stat(vocabPath); err != nil { return nil, fmt.Errorf("sherpa bpe vocab %q: %w", vocabPath, err) } conf := sherpa.OnlinePunctuationConfig{} conf.Model.CnnBilstm = modelPath conf.Model.BpeVocab = vocabPath conf.Model.NumThreads = cfg.NumThreads conf.Model.Provider = "cpu" s.online = sherpa.NewOnlinePunctuation(&conf) if s.online == nil { return nil, fmt.Errorf("failed to create sherpa online punctuation") } default: modelPath := cfg.ModelPath() if _, err := os.Stat(modelPath); err != nil { return nil, fmt.Errorf("sherpa offline punctuation model %q: %w (run: make download-punctuation-model)", modelPath, err) } conf := sherpa.OfflinePunctuationConfig{} conf.Model.CtTransformer = modelPath conf.Model.NumThreads = cfg.NumThreads conf.Model.Provider = "cpu" s.offline = sherpa.NewOfflinePunctuation(&conf) if s.offline == nil { return nil, fmt.Errorf("failed to create sherpa offline punctuation") } } return s, nil } func (s *Sherpa) Active() bool { return true } func (s *Sherpa) Restore(ctx context.Context, text, language string) (string, error) { _ = ctx _ = language text = strings.TrimSpace(text) if text == "" { return text, nil } s.mu.Lock() defer s.mu.Unlock() var out string switch { case s.offline != nil: out = s.offline.AddPunct(text) case s.online != nil: out = s.online.AddPunct(text) default: return text, fmt.Errorf("sherpa punctuation not initialized") } out = strings.TrimSpace(out) if out == "" { return text, nil } return out, nil } func (s *Sherpa) Close() { s.mu.Lock() defer s.mu.Unlock() if s.offline != nil { sherpa.DeleteOfflinePunc(s.offline) s.offline = nil } if s.online != nil { sherpa.DeleteOnlinePunctuation(s.online) s.online = nil } }