package punctuation import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "os" "strings" "go-whisper-api/config" ) type Restorer interface { Active() bool Restore(ctx context.Context, text, language string) (string, error) } type Closer interface { Close() } func New(cfg config.Punctuation) (Restorer, error) { cfg = cfg.WithDefaults() if !cfg.Active() { return Nop{}, nil } engine := strings.ToLower(strings.TrimSpace(cfg.Engine)) switch engine { case "heuristic": return Heuristic{}, nil case "sherpa", "sherpa-offline", "offline": cfg.Engine = "sherpa" return newSherpaRestorer(cfg) case "sherpa-online", "online": cfg.Engine = "sherpa-online" return newSherpaRestorer(cfg) case "xlm", "xlm-roberta", "roberta": return newXLM(cfg) case "http": if cfg.HTTPURL == "" { return nil, fmt.Errorf("punctuation.http_url is required when engine=http") } return HTTP{cfg: cfg}, nil default: return nil, fmt.Errorf("unsupported punctuation engine %q (use: off, heuristic, xlm, sherpa, sherpa-online, http)", engine) } } type Nop struct{} func (Nop) Active() bool { return false } func (Nop) Restore(ctx context.Context, text, language string) (string, error) { return text, nil } type HTTP struct { cfg config.Punctuation client *http.Client } func (h HTTP) Active() bool { return true } func (h HTTP) Restore(ctx context.Context, text, language string) (string, error) { body, err := json.Marshal(map[string]string{ "text": text, "language": language, }) if err != nil { return "", err } req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.cfg.HTTPURL, bytes.NewReader(body)) if err != nil { return "", err } req.Header.Set("Content-Type", "application/json") client := h.client if client == nil { client = &http.Client{Timeout: h.cfg.Timeout()} } resp, err := client.Do(req) if err != nil { return "", err } defer resp.Body.Close() raw, err := io.ReadAll(resp.Body) if err != nil { return "", err } if resp.StatusCode >= 300 { return "", fmt.Errorf("punctuation http %s: %s", resp.Status, strings.TrimSpace(string(raw))) } var out struct { Text string `json:"text"` } if err := json.Unmarshal(raw, &out); err != nil { return strings.TrimSpace(string(raw)), nil } if out.Text == "" { return text, nil } return out.Text, nil } func Apply(ctx context.Context, r Restorer, enabled bool, text, language string) (string, error) { if !enabled || r == nil || !r.Active() { return text, nil } text = strings.TrimSpace(text) if text == "" { return text, nil } return r.Restore(ctx, text, language) } func Close(r Restorer) { if c, ok := r.(Closer); ok { c.Close() } } func AutoSelect(cfg config.Punctuation) (Restorer, error) { cfg = cfg.WithDefaults() if !cfg.Active() { return Nop{}, nil } engine := strings.ToLower(cfg.Engine) if engine == "heuristic" { return Heuristic{}, nil } if engine == "http" { return New(cfg) } if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" { return newXLM(cfg) } if _, err := os.Stat(cfg.ModelPath()); err == nil { cfg.Engine = engine if engine == "sherpa" || engine == "sherpa-offline" || engine == "offline" || engine == "" { cfg.Engine = "sherpa" } return newSherpaRestorer(cfg) } if engine == "sherpa" || engine == "sherpa-online" || engine == "online" { return nil, fmt.Errorf("punctuation model not found at %s (run: make download-punctuation-model)", cfg.ModelPath()) } return Heuristic{}, nil }