go-whisper-api/punctuation/punctuation.go
admin 318b736244
Some checks failed
Docker Image / build-docker (push) Failing after 1m26s
Lint and Testing / lint (push) Successful in 43s
Lint and Testing / test (push) Successful in 5m38s
Lint and Testing / golangci (push) Successful in 1m14s
CodeQL / Analyze (go) (push) Successful in 6m23s
first commit
2026-06-04 19:25:56 +07:00

159 lines
4.0 KiB
Go

package punctuation
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"go-whisper-api/config"
)
// Restorer is the common interface of the punctuation engines returned by
// New and AutoSelect.
type Restorer interface {
// Active reports whether the engine should be used; Apply returns the input
// text untouched when it is false.
Active() bool
// Restore processes text for the given language and returns the resulting
// text, or an error.
Restore(ctx context.Context, text, language string) (string, error)
}
// Closer is an optional interface for a Restorer. The package-level Close
// function calls Close on any Restorer that also implements it.
type Closer interface {
Close()
}
// New builds the Restorer selected by cfg.Engine.
//
// Defaults are applied with cfg.WithDefaults first; if the resulting
// configuration is not active, a Nop is returned. The engine name is matched
// case-insensitively after trimming surrounding whitespace:
//
//   - "heuristic": Heuristic
//   - "sherpa", "sherpa-offline", "offline": newSherpaRestorer with Engine "sherpa"
//   - "sherpa-online", "online": newSherpaRestorer with Engine "sherpa-online"
//   - "xlm", "xlm-roberta", "roberta": newXLM
//   - "http": HTTP, which requires a non-blank cfg.HTTPURL
//
// Any other engine name yields an error.
func New(cfg config.Punctuation) (Restorer, error) {
	cfg = cfg.WithDefaults()
	if !cfg.Active() {
		return Nop{}, nil
	}
	engine := strings.ToLower(strings.TrimSpace(cfg.Engine))
	switch engine {
	case "heuristic":
		return Heuristic{}, nil
	case "sherpa", "sherpa-offline", "offline":
		cfg.Engine = "sherpa"
		return newSherpaRestorer(cfg)
	case "sherpa-online", "online":
		cfg.Engine = "sherpa-online"
		return newSherpaRestorer(cfg)
	case "xlm", "xlm-roberta", "roberta":
		return newXLM(cfg)
	case "http":
		// A whitespace-only URL is as unusable as an empty one; reject it here
		// instead of failing on the first request.
		cfg.HTTPURL = strings.TrimSpace(cfg.HTTPURL)
		if cfg.HTTPURL == "" {
			return nil, fmt.Errorf("punctuation.http_url is required when engine=http")
		}
		// Create the client once so Restore does not allocate one per call.
		return HTTP{cfg: cfg, client: &http.Client{Timeout: cfg.Timeout()}}, nil
	default:
		return nil, fmt.Errorf("unsupported punctuation engine %q (use: off, heuristic, xlm, sherpa, sherpa-online, http)", engine)
	}
}
// Nop is the disabled Restorer: it is never active and hands text back
// unchanged.
type Nop struct{}

// Active always reports false.
func (Nop) Active() bool { return false }

// Restore returns text exactly as given and never fails; ctx and language
// are ignored.
func (Nop) Restore(ctx context.Context, text, language string) (string, error) { return text, nil }
// HTTP is a Restorer that delegates to a remote service over HTTP POST.
type HTTP struct {
// cfg supplies the target URL (HTTPURL) and the timeout used when Restore
// has to create its own client.
cfg config.Punctuation
// client is optional; when nil, Restore builds a client for the call.
client *http.Client
}
// Active always reports true: an HTTP restorer is in use whenever it exists.
func (h HTTP) Active() bool {
	return true
}
// Restore posts text and language to the configured URL and returns the text
// produced by the service.
//
// The request body is the JSON object {"text": ..., "language": ...}. A
// response with status 300 or above is reported as an error that includes the
// status and the trimmed body. Otherwise the body is decoded as JSON with a
// "text" field; a body that does not decode is treated as plain text and
// returned with surrounding whitespace trimmed. When the service yields no
// text at all (an empty "text" field, or an empty/whitespace-only body), the
// input text is returned unchanged instead of an empty string.
//
// Failures while building, sending, or reading the exchange are wrapped with
// %w, so errors.Is / errors.As still see the underlying error.
func (h HTTP) Restore(ctx context.Context, text, language string) (string, error) {
	payload, err := json.Marshal(map[string]string{
		"text":     text,
		"language": language,
	})
	if err != nil {
		return "", fmt.Errorf("punctuation http: encode request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.cfg.HTTPURL, bytes.NewReader(payload))
	if err != nil {
		return "", fmt.Errorf("punctuation http: build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	client := h.client
	if client == nil {
		// Keep an HTTP value built without a client usable; it gets the
		// configured timeout.
		client = &http.Client{Timeout: h.cfg.Timeout()}
	}
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("punctuation http: send request: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check because it is part of the
	// error message for non-success responses.
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("punctuation http: read response: %w", err)
	}
	if resp.StatusCode >= 300 {
		return "", fmt.Errorf("punctuation http %s: %s", resp.Status, strings.TrimSpace(string(raw)))
	}

	var out struct {
		Text string `json:"text"`
	}
	if err := json.Unmarshal(raw, &out); err != nil {
		// Not the expected JSON object: accept a plain-text reply, but never
		// replace the caller's text with nothing.
		if plain := strings.TrimSpace(string(raw)); plain != "" {
			return plain, nil
		}
		return text, nil
	}
	if out.Text == "" {
		return text, nil
	}
	return out.Text, nil
}
// Apply runs r over text when punctuation is wanted and available.
//
// The input is returned untouched when enabled is false, r is nil, or r is
// not active. Otherwise the text is trimmed of surrounding whitespace; an
// empty result short-circuits to "" without calling r. The output of
// r.Restore is passed through CleanExcessive before being returned. A Restore
// failure is returned as-is together with an empty string.
func Apply(ctx context.Context, r Restorer, enabled bool, text, language string) (string, error) {
	if usable := enabled && r != nil && r.Active(); !usable {
		return text, nil
	}

	trimmed := strings.TrimSpace(text)
	if len(trimmed) == 0 {
		return "", nil
	}

	restored, err := r.Restore(ctx, trimmed, language)
	if err != nil {
		return "", err
	}
	return CleanExcessive(restored), nil
}
// Close releases r if it implements Closer; for any other Restorer,
// including nil, it does nothing.
func Close(r Restorer) {
	closer, closable := r.(Closer)
	if !closable {
		return
	}
	closer.Close()
}
// AutoSelect picks a Restorer like New, but falls back to Heuristic when no
// specific engine was requested and the sherpa model file is missing.
//
// Defaults are applied with cfg.WithDefaults first; an inactive configuration
// yields Nop. The engine name is matched case-insensitively after trimming
// surrounding whitespace, the same way New does:
//
//   - "heuristic": Heuristic
//   - "http": delegated to New
//   - "xlm", "xlm-roberta", "roberta": newXLM
//
// For every other name, the file at cfg.ModelPath() decides. If it exists,
// newSherpaRestorer is used with Engine normalized to "sherpa" (for "sherpa",
// "sherpa-offline", "offline", or an empty name) or "sherpa-online" (for
// "sherpa-online", "online"); an unrecognized name is passed through as-is.
// If it does not exist, an explicitly requested sherpa engine (any of the
// aliases above) is an error, and anything else falls back to Heuristic.
func AutoSelect(cfg config.Punctuation) (Restorer, error) {
	cfg = cfg.WithDefaults()
	if !cfg.Active() {
		return Nop{}, nil
	}

	engine := strings.ToLower(strings.TrimSpace(cfg.Engine))
	switch engine {
	case "heuristic":
		return Heuristic{}, nil
	case "http":
		return New(cfg)
	case "xlm", "xlm-roberta", "roberta":
		return newXLM(cfg)
	}

	// Normalize the sherpa aliases to the same canonical names New uses, and
	// remember whether the caller asked for sherpa explicitly.
	explicitSherpa := false
	switch engine {
	case "sherpa", "sherpa-offline", "offline":
		engine, explicitSherpa = "sherpa", true
	case "sherpa-online", "online":
		engine, explicitSherpa = "sherpa-online", true
	case "":
		engine = "sherpa"
	}

	if _, err := os.Stat(cfg.ModelPath()); err == nil {
		cfg.Engine = engine
		return newSherpaRestorer(cfg)
	}
	if explicitSherpa {
		// The caller named a sherpa engine; silently degrading to the
		// heuristic would hide a missing model.
		return nil, fmt.Errorf("punctuation model not found at %s (run: make download-punctuation-model)", cfg.ModelPath())
	}
	return Heuristic{}, nil
}