package config

import (
	"net/http"
	"path/filepath"
	"strconv"
	"strings"
	"time"
)

// Punctuation is the YAML-backed configuration for the punctuation stage.
//
// The zero value is usable: WithDefaults fills in the engine, model
// directory, model file, thread count and timeout.
type Punctuation struct {
	// Command is loaded from the "command" key.
	// NOTE(review): not read anywhere in this file — confirm how callers use it.
	Command []string `yaml:"command"`

	// NumThreads falls back to 2 when it is zero or negative.
	NumThreads int `yaml:"num_threads"`

	// TimeoutSec is a duration in whole seconds; it falls back to 120 when it
	// is zero or negative. See Timeout.
	TimeoutSec int `yaml:"timeout_sec"`

	// Engine names the punctuation backend. Blank means "heuristic", the
	// names "xlm", "xlm-roberta" and "roberta" (case-insensitive) select the
	// XLM model defaults, and "off" makes Active report false.
	Engine string `yaml:"engine"`

	// ModelDir is the directory that the *Path methods resolve files against.
	ModelDir string `yaml:"model_dir"`

	// ModelFile is the model file name inside ModelDir.
	ModelFile string `yaml:"model_file"`

	// SPModel is a file name inside ModelDir; defaults to "sp.model".
	SPModel string `yaml:"sp_model"`

	// ConfigFile is a file name inside ModelDir; defaults to "config.yaml".
	ConfigFile string `yaml:"config_file"`

	// BpeVocab is a file name inside ModelDir; defaults to "bpe.vocab".
	BpeVocab string `yaml:"bpe_vocab"`

	// HTTPURL is loaded from the "http_url" key.
	// NOTE(review): not read anywhere in this file — confirm how callers use it.
	HTTPURL string `yaml:"http_url"`

	// Enabled must be true for Active to report true.
	Enabled bool `yaml:"enabled"`

	// DefaultOn is loaded from the "default_on" key.
	// NOTE(review): not read anywhere in this file; see the note in
	// ShouldApplyAPI.
	DefaultOn bool `yaml:"default_on"`

	// ApplySBD is exposed through XLMJoinSentences.
	ApplySBD bool `yaml:"apply_sbd"`
}

// XLMJoinSentences reports the value of the ApplySBD setting.
func (p Punctuation) XLMJoinSentences() bool {
	return p.ApplySBD
}

// WithDefaults returns a copy of p with unset fields filled in.
//
// A blank (empty or whitespace-only) Engine becomes "heuristic". ModelDir and
// ModelFile, when empty, are chosen according to the engine: the XLM family
// gets the xlm-roberta directory and "model.onnx", every other engine gets
// the ct-transformer directory and "model.int8.onnx". Non-positive NumThreads
// and TimeoutSec become 2 and 120 respectively.
//
// A non-blank Engine is returned exactly as configured (it is only
// normalised internally for the comparison), so callers still see the
// original spelling.
func (p Punctuation) WithDefaults() Punctuation {
	// Normalise once and test the normalised form, so that a value such as
	// "  " is treated the same as an absent engine.
	engine := strings.ToLower(strings.TrimSpace(p.Engine))
	if engine == "" {
		p.Engine = "heuristic"
		engine = p.Engine
	}

	xlm := false
	switch engine {
	case "xlm", "xlm-roberta", "roberta":
		xlm = true
	}

	if p.ModelDir == "" {
		p.ModelDir = "./models/punctuation/ct-transformer-zh-en-int8"
		if xlm {
			p.ModelDir = "./models/punctuation/xlm-roberta"
		}
	}
	if p.ModelFile == "" {
		p.ModelFile = "model.int8.onnx"
		if xlm {
			p.ModelFile = "model.onnx"
		}
	}
	if p.NumThreads <= 0 {
		p.NumThreads = 2
	}
	if p.TimeoutSec <= 0 {
		p.TimeoutSec = 120
	}
	return p
}

// Timeout returns TimeoutSec (after defaults) as a time.Duration.
func (p Punctuation) Timeout() time.Duration {
	p = p.WithDefaults()
	return time.Duration(p.TimeoutSec) * time.Second
}

// Active reports whether punctuation is switched on: Enabled must be true and
// the engine must not be "off". The engine comparison ignores case and
// surrounding whitespace, matching the normalisation done by WithDefaults.
func (p Punctuation) Active() bool {
	p = p.WithDefaults()
	// WithDefaults guarantees a non-blank engine, so only "off" needs checking.
	return p.Enabled && !strings.EqualFold(strings.TrimSpace(p.Engine), "off")
}

// ModelPath returns ModelDir joined with ModelFile, after defaults.
func (p Punctuation) ModelPath() string {
	p = p.WithDefaults()
	return filepath.Join(p.ModelDir, p.ModelFile)
}

// SPModelPath returns ModelDir joined with SPModel, using "sp.model" when
// SPModel is empty.
func (p Punctuation) SPModelPath() string {
	p = p.WithDefaults()
	name := p.SPModel
	if name == "" {
		name = "sp.model"
	}
	return filepath.Join(p.ModelDir, name)
}

// XLMConfigPath returns ModelDir joined with ConfigFile, using "config.yaml"
// when ConfigFile is empty.
func (p Punctuation) XLMConfigPath() string {
	p = p.WithDefaults()
	name := p.ConfigFile
	if name == "" {
		name = "config.yaml"
	}
	return filepath.Join(p.ModelDir, name)
}

// BpeVocabPath returns ModelDir joined with BpeVocab, using "bpe.vocab" when
// BpeVocab is empty.
func (p Punctuation) BpeVocabPath() string {
	p = p.WithDefaults()
	name := p.BpeVocab
	if name == "" {
		name = "bpe.vocab"
	}
	return filepath.Join(p.ModelDir, name)
}

// ShouldApplyAPI decides whether punctuation applies to a single request.
//
// It returns false whenever the configuration is not Active. Otherwise a
// non-blank "punctuation" query parameter decides (unrecognised values count
// as true). A nil request, or a request without a URL, is treated as carrying
// no override instead of panicking.
func (p Punctuation) ShouldApplyAPI(r *http.Request, apiDefault bool) bool {
	if !p.Active() {
		return false
	}
	if r != nil && r.URL != nil {
		if q := strings.TrimSpace(r.URL.Query().Get("punctuation")); q != "" {
			return parsePunctuationQuery(q, true)
		}
	}
	// NOTE(review): Active() above already requires p.Enabled, so this
	// expression is always true here and apiDefault has no effect. DefaultOn
	// is never consulted; possibly `apiDefault || p.DefaultOn` was intended.
	// Left unchanged because callers may rely on the current behaviour —
	// confirm before altering.
	return apiDefault || p.Enabled
}

// parsePunctuationQuery interprets raw as a boolean switch.
//
// Surrounding whitespace is ignored. "1", "true", "yes" and "on" mean true and
// "0", "false", "no" and "off" mean false, in any letter case. Anything else
// is handed to strconv.ParseBool; def is returned when raw is blank or cannot
// be parsed.
func parsePunctuationQuery(raw string, def bool) bool {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return def
	}
	switch strings.ToLower(raw) {
	case "1", "true", "yes", "on":
		return true
	case "0", "false", "no", "off":
		return false
	}
	// ParseBool receives the original spelling so its own accepted forms
	// (for example "t" and "F") keep working.
	if b, err := strconv.ParseBool(raw); err == nil {
		return b
	}
	return def
}