Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
108 lines
2.9 KiB
Go
108 lines
2.9 KiB
Go
//go:build sherpa
|
|
|
|
package punctuation
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
|
|
"go-whisper-api/config"
|
|
|
|
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
|
|
)
|
|
|
|
type Sherpa struct {
|
|
offline *sherpa.OfflinePunctuation
|
|
online *sherpa.OnlinePunctuation
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func newSherpaRestorer(cfg config.Punctuation) (Restorer, error) {
|
|
return newSherpa(cfg)
|
|
}
|
|
|
|
func newSherpa(cfg config.Punctuation) (*Sherpa, error) {
|
|
cfg = cfg.WithDefaults()
|
|
engine := strings.ToLower(cfg.Engine)
|
|
s := &Sherpa{}
|
|
switch engine {
|
|
case "sherpa-online", "online":
|
|
modelPath := cfg.ModelPath()
|
|
vocabPath := cfg.BpeVocabPath()
|
|
if _, err := os.Stat(modelPath); err != nil {
|
|
return nil, fmt.Errorf("sherpa online punctuation model %q: %w", modelPath, err)
|
|
}
|
|
if _, err := os.Stat(vocabPath); err != nil {
|
|
return nil, fmt.Errorf("sherpa bpe vocab %q: %w", vocabPath, err)
|
|
}
|
|
conf := sherpa.OnlinePunctuationConfig{}
|
|
conf.Model.CnnBilstm = modelPath
|
|
conf.Model.BpeVocab = vocabPath
|
|
conf.Model.NumThreads = cfg.NumThreads
|
|
conf.Model.Provider = "cpu"
|
|
s.online = sherpa.NewOnlinePunctuation(&conf)
|
|
if s.online == nil {
|
|
return nil, fmt.Errorf("failed to create sherpa online punctuation")
|
|
}
|
|
default:
|
|
modelPath := cfg.ModelPath()
|
|
if _, err := os.Stat(modelPath); err != nil {
|
|
return nil, fmt.Errorf("sherpa offline punctuation model %q: %w (run: make download-punctuation-model)", modelPath, err)
|
|
}
|
|
conf := sherpa.OfflinePunctuationConfig{}
|
|
conf.Model.CtTransformer = modelPath
|
|
conf.Model.NumThreads = cfg.NumThreads
|
|
conf.Model.Provider = "cpu"
|
|
s.offline = sherpa.NewOfflinePunctuation(&conf)
|
|
if s.offline == nil {
|
|
return nil, fmt.Errorf("failed to create sherpa offline punctuation")
|
|
}
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Sherpa) Active() bool {
|
|
return true
|
|
}
|
|
|
|
func (s *Sherpa) Restore(ctx context.Context, text, language string) (string, error) {
|
|
_ = ctx
|
|
_ = language
|
|
text = strings.TrimSpace(text)
|
|
if text == "" {
|
|
return text, nil
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
var out string
|
|
switch {
|
|
case s.offline != nil:
|
|
out = s.offline.AddPunct(text)
|
|
case s.online != nil:
|
|
out = s.online.AddPunct(text)
|
|
default:
|
|
return text, fmt.Errorf("sherpa punctuation not initialized")
|
|
}
|
|
out = strings.TrimSpace(out)
|
|
if out == "" {
|
|
return text, nil
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Sherpa) Close() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.offline != nil {
|
|
sherpa.DeleteOfflinePunc(s.offline)
|
|
s.offline = nil
|
|
}
|
|
if s.online != nil {
|
|
sherpa.DeleteOnlinePunctuation(s.online)
|
|
s.online = nil
|
|
}
|
|
}
|