Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
126 lines
3.2 KiB
Go
126 lines
3.2 KiB
Go
package config
|
|
|
|
import (
|
|
"net/http"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
func (p Punctuation) XLMJoinSentences() bool {
|
|
return p.ApplySBD
|
|
}
|
|
|
|
// Punctuation holds the punctuation settings decoded from YAML. Fields left at
// their zero value are filled in by WithDefaults where a default exists.
type Punctuation struct {
	// Command is a list of strings that the methods visible here never read.
	// NOTE(review): presumably an external command line — confirm at the call site.
	Command []string `yaml:"command"`

	// NumThreads falls back to 2 in WithDefaults when it is zero or negative.
	NumThreads int `yaml:"num_threads"`

	// TimeoutSec is a duration in whole seconds (see Timeout); WithDefaults
	// replaces zero or negative values with 120.
	TimeoutSec int `yaml:"timeout_sec"`

	// Engine names the punctuation backend. An empty value becomes "heuristic"
	// in WithDefaults, "off" makes Active report false, and "xlm",
	// "xlm-roberta" or "roberta" select the XLM-RoBERTa model defaults.
	Engine string `yaml:"engine"`

	// ModelDir is the directory every model-related path is joined onto. Its
	// default depends on Engine (see WithDefaults).
	ModelDir string `yaml:"model_dir"`

	// ModelFile is the model file name inside ModelDir. Its default depends on
	// Engine (see WithDefaults).
	ModelFile string `yaml:"model_file"`

	// SPModel is a file name inside ModelDir; SPModelPath uses "sp.model" when
	// it is empty.
	SPModel string `yaml:"sp_model"`

	// ConfigFile is a file name inside ModelDir; XLMConfigPath uses
	// "config.yaml" when it is empty.
	ConfigFile string `yaml:"config_file"`

	// BpeVocab is a file name inside ModelDir; BpeVocabPath uses "bpe.vocab"
	// when it is empty.
	BpeVocab string `yaml:"bpe_vocab"`

	// HTTPURL is not read by the methods visible here.
	// NOTE(review): presumably the endpoint of an HTTP-based engine — confirm.
	HTTPURL string `yaml:"http_url"`

	// Enabled must be true for Active to report true.
	Enabled bool `yaml:"enabled"`

	// DefaultOn is not read by the methods visible here.
	// NOTE(review): presumably the default when a caller gives no explicit
	// choice — confirm where it is consumed.
	DefaultOn bool `yaml:"default_on"`

	// ApplySBD is the value returned by XLMJoinSentences.
	ApplySBD bool `yaml:"apply_sbd"`
}
|
|
|
|
func (p Punctuation) WithDefaults() Punctuation {
|
|
if p.Engine == "" {
|
|
p.Engine = "heuristic"
|
|
}
|
|
engine := strings.ToLower(strings.TrimSpace(p.Engine))
|
|
if p.ModelDir == "" {
|
|
if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" {
|
|
p.ModelDir = "./models/punctuation/xlm-roberta"
|
|
} else {
|
|
p.ModelDir = "./models/punctuation/ct-transformer-zh-en-int8"
|
|
}
|
|
}
|
|
if p.ModelFile == "" {
|
|
if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" {
|
|
p.ModelFile = "model.onnx"
|
|
} else {
|
|
p.ModelFile = "model.int8.onnx"
|
|
}
|
|
}
|
|
if p.NumThreads <= 0 {
|
|
p.NumThreads = 2
|
|
}
|
|
if p.TimeoutSec <= 0 {
|
|
p.TimeoutSec = 120
|
|
}
|
|
return p
|
|
}
|
|
|
|
func (p Punctuation) Timeout() time.Duration {
|
|
p = p.WithDefaults()
|
|
return time.Duration(p.TimeoutSec) * time.Second
|
|
}
|
|
|
|
func (p Punctuation) Active() bool {
|
|
p = p.WithDefaults()
|
|
return p.Enabled && p.Engine != "" && !strings.EqualFold(p.Engine, "off")
|
|
}
|
|
|
|
func (p Punctuation) ModelPath() string {
|
|
p = p.WithDefaults()
|
|
return filepath.Join(p.ModelDir, p.ModelFile)
|
|
}
|
|
|
|
func (p Punctuation) SPModelPath() string {
|
|
p = p.WithDefaults()
|
|
if p.SPModel == "" {
|
|
return filepath.Join(p.ModelDir, "sp.model")
|
|
}
|
|
return filepath.Join(p.ModelDir, p.SPModel)
|
|
}
|
|
|
|
func (p Punctuation) XLMConfigPath() string {
|
|
p = p.WithDefaults()
|
|
if p.ConfigFile == "" {
|
|
return filepath.Join(p.ModelDir, "config.yaml")
|
|
}
|
|
return filepath.Join(p.ModelDir, p.ConfigFile)
|
|
}
|
|
|
|
func (p Punctuation) BpeVocabPath() string {
|
|
p = p.WithDefaults()
|
|
if p.BpeVocab == "" {
|
|
return filepath.Join(p.ModelDir, "bpe.vocab")
|
|
}
|
|
return filepath.Join(p.ModelDir, p.BpeVocab)
|
|
}
|
|
|
|
func (p Punctuation) ShouldApplyAPI(r *http.Request, apiDefault bool) bool {
|
|
if !p.Active() {
|
|
return false
|
|
}
|
|
q := strings.TrimSpace(r.URL.Query().Get("punctuation"))
|
|
if q != "" {
|
|
return parsePunctuationQuery(q, true)
|
|
}
|
|
return apiDefault || p.Enabled
|
|
}
|
|
|
|
// parsePunctuationQuery interprets raw as an on/off switch.
//
// Surrounding whitespace is ignored. "1", "true", "yes" and "on" yield true,
// and "0", "false", "no" and "off" yield false, in any letter case. Any other
// spelling accepted by strconv.ParseBool (such as "t" or "F") yields its
// parsed value. Blank or unrecognised input yields def.
func parsePunctuationQuery(raw string, def bool) bool {
	trimmed := strings.TrimSpace(raw)

	switch strings.ToLower(trimmed) {
	case "":
		return def
	case "1", "true", "yes", "on":
		return true
	case "0", "false", "no", "off":
		return false
	}

	// ParseBool sees the original letter case, exactly as the caller sent it.
	if parsed, err := strconv.ParseBool(trimmed); err == nil {
		return parsed
	}
	return def
}
|