go-whisper-api/config/punctuation.go
admin b5c083e06f
Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
first commit
2026-06-04 18:10:52 +07:00

126 lines
3.2 KiB
Go

package config
import (
"net/http"
"path/filepath"
"strconv"
"strings"
"time"
)
// XLMJoinSentences reports the value of the ApplySBD setting.
func (p Punctuation) XLMJoinSentences() bool {
	return p.ApplySBD
}
// Punctuation holds the punctuation settings decoded from YAML.
//
// Zero values are legal: WithDefaults fills in the engine, model location,
// thread count and timeout when they are left unset.
type Punctuation struct {
	// Command is not read by any method in this file.
	// NOTE(review): presumably an external command line; confirm against callers.
	Command []string `yaml:"command"`
	// NumThreads is replaced by 2 in WithDefaults when it is zero or negative.
	NumThreads int `yaml:"num_threads"`
	// TimeoutSec is a duration in seconds (see Timeout); WithDefaults replaces
	// a zero or negative value with 120.
	TimeoutSec int `yaml:"timeout_sec"`

	// Engine names the punctuation engine. An empty value becomes "heuristic"
	// in WithDefaults; "xlm", "xlm-roberta" and "roberta" select the
	// XLM-RoBERTa model defaults; "off" (any case) makes Active return false.
	Engine string `yaml:"engine"`
	// ModelDir is the directory every model-related path is joined onto.
	ModelDir string `yaml:"model_dir"`
	// ModelFile is the model file name inside ModelDir (see ModelPath).
	ModelFile string `yaml:"model_file"`
	// SPModel is a file name inside ModelDir; SPModelPath uses "sp.model"
	// when it is empty.
	SPModel string `yaml:"sp_model"`
	// ConfigFile is a file name inside ModelDir; XLMConfigPath uses
	// "config.yaml" when it is empty.
	ConfigFile string `yaml:"config_file"`
	// BpeVocab is a file name inside ModelDir; BpeVocabPath uses "bpe.vocab"
	// when it is empty.
	BpeVocab string `yaml:"bpe_vocab"`
	// HTTPURL is not read by any method in this file.
	// NOTE(review): presumably the endpoint of a remote engine; confirm.
	HTTPURL string `yaml:"http_url"`

	// Enabled must be true for Active to report true.
	Enabled bool `yaml:"enabled"`
	// DefaultOn is presumably whether punctuation is applied when a request
	// does not ask for it explicitly — confirm against callers.
	DefaultOn bool `yaml:"default_on"`
	// ApplySBD is the value returned by XLMJoinSentences.
	ApplySBD bool `yaml:"apply_sbd"`
}
// WithDefaults returns a copy of p with unset fields filled in.
//
// An empty Engine becomes "heuristic". ModelDir and ModelFile, when empty,
// are chosen by engine family: the XLM-RoBERTa names ("xlm", "xlm-roberta",
// "roberta", compared case-insensitively after trimming spaces) get the
// xlm-roberta directory and "model.onnx"; every other engine gets the
// ct-transformer directory and "model.int8.onnx". Non-positive NumThreads and
// TimeoutSec become 2 and 120. The receiver is a value, so the caller's copy
// is left untouched.
func (p Punctuation) WithDefaults() Punctuation {
	if p.Engine == "" {
		p.Engine = "heuristic"
	}

	// Resolve both model defaults in one place; they are applied below only
	// where the corresponding field was left empty.
	fallbackDir := "./models/punctuation/ct-transformer-zh-en-int8"
	fallbackFile := "model.int8.onnx"
	switch strings.ToLower(strings.TrimSpace(p.Engine)) {
	case "xlm", "xlm-roberta", "roberta":
		fallbackDir = "./models/punctuation/xlm-roberta"
		fallbackFile = "model.onnx"
	}
	if p.ModelDir == "" {
		p.ModelDir = fallbackDir
	}
	if p.ModelFile == "" {
		p.ModelFile = fallbackFile
	}

	if p.NumThreads <= 0 {
		p.NumThreads = 2
	}
	if p.TimeoutSec <= 0 {
		p.TimeoutSec = 120
	}
	return p
}
// Timeout returns TimeoutSec as a time.Duration, after WithDefaults has
// replaced a non-positive value with its default.
func (p Punctuation) Timeout() time.Duration {
	seconds := p.WithDefaults().TimeoutSec
	return time.Second * time.Duration(seconds)
}
// Active reports whether punctuation is switched on: Enabled must be set and
// the engine, after defaults are applied, must be non-empty and not "off"
// (compared case-insensitively).
func (p Punctuation) Active() bool {
	cfg := p.WithDefaults()
	if !cfg.Enabled || cfg.Engine == "" {
		return false
	}
	return !strings.EqualFold(cfg.Engine, "off")
}
// ModelPath returns ModelFile joined onto ModelDir, with both taken after
// WithDefaults has filled in any empty value.
func (p Punctuation) ModelPath() string {
	cfg := p.WithDefaults()
	return filepath.Join(cfg.ModelDir, cfg.ModelFile)
}
// SPModelPath returns the SPModel file joined onto the defaulted ModelDir.
// An empty SPModel falls back to "sp.model".
func (p Punctuation) SPModelPath() string {
	cfg := p.WithDefaults()
	name := cfg.SPModel
	if name == "" {
		name = "sp.model"
	}
	return filepath.Join(cfg.ModelDir, name)
}
// XLMConfigPath returns the ConfigFile joined onto the defaulted ModelDir.
// An empty ConfigFile falls back to "config.yaml".
func (p Punctuation) XLMConfigPath() string {
	cfg := p.WithDefaults()
	name := cfg.ConfigFile
	if name == "" {
		name = "config.yaml"
	}
	return filepath.Join(cfg.ModelDir, name)
}
// BpeVocabPath returns the BpeVocab file joined onto the defaulted ModelDir.
// An empty BpeVocab falls back to "bpe.vocab".
func (p Punctuation) BpeVocabPath() string {
	cfg := p.WithDefaults()
	name := cfg.BpeVocab
	if name == "" {
		name = "bpe.vocab"
	}
	return filepath.Join(cfg.ModelDir, name)
}
// ShouldApplyAPI decides whether punctuation should be applied for the API
// request r.
//
// It returns false whenever the configuration is not Active. Otherwise an
// explicit "punctuation" query parameter wins: the values understood by
// parsePunctuationQuery switch it on or off, and any other non-empty value
// counts as on. When the parameter is absent or blank (or r carries no URL),
// the result is apiDefault or the configured DefaultOn flag.
func (p Punctuation) ShouldApplyAPI(r *http.Request, apiDefault bool) bool {
	if !p.Active() {
		return false
	}
	// Tolerate a nil request or URL instead of panicking; such a request
	// simply carries no explicit choice.
	if r != nil && r.URL != nil {
		if q := strings.TrimSpace(r.URL.Query().Get("punctuation")); q != "" {
			return parsePunctuationQuery(q, true)
		}
	}
	// Active() has already established that Enabled is true, so falling back
	// on Enabled would make this default unconditionally true and ignore
	// apiDefault. DefaultOn is the flag that expresses the default.
	return apiDefault || p.DefaultOn
}
// parsePunctuationQuery interprets raw as an on/off switch.
//
// Surrounding whitespace is ignored. "1", "true", "yes" and "on" mean true;
// "0", "false", "no" and "off" mean false (all case-insensitive). Any other
// spelling accepted by strconv.ParseBool is honored as well. An empty or
// unrecognized value yields def.
func parsePunctuationQuery(raw string, def bool) bool {
	value := strings.TrimSpace(raw)
	if value == "" {
		return def
	}

	lowered := strings.ToLower(value)
	if lowered == "1" || lowered == "true" || lowered == "yes" || lowered == "on" {
		return true
	}
	if lowered == "0" || lowered == "false" || lowered == "no" || lowered == "off" {
		return false
	}

	// ParseBool sees the original casing so that only its own spellings
	// (such as "t" or "F") are accepted here.
	if parsed, err := strconv.ParseBool(value); err == nil {
		return parsed
	}
	return def
}