159 lines
4.0 KiB
Go
159 lines
4.0 KiB
Go
package punctuation
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
|
|
"go-whisper-api/config"
|
|
)
|
|
|
|
// Restorer is the common interface of every punctuation engine in this
// package.
type Restorer interface {
	// Restore returns text with punctuation applied for the given language.
	// Implementations may return text unchanged (Nop does).
	Restore(ctx context.Context, text, language string) (string, error)

	// Active reports whether the restorer does any work. Apply skips
	// restorers that report false.
	Active() bool
}
|
|
|
|
// Closer is an optional interface a Restorer may implement when it holds
// resources that need releasing. The package-level Close helper invokes it
// when present.
type Closer interface {
	// Close releases the implementation's resources.
	Close()
}
|
|
|
|
func New(cfg config.Punctuation) (Restorer, error) {
|
|
cfg = cfg.WithDefaults()
|
|
if !cfg.Active() {
|
|
return Nop{}, nil
|
|
}
|
|
engine := strings.ToLower(strings.TrimSpace(cfg.Engine))
|
|
switch engine {
|
|
case "heuristic":
|
|
return Heuristic{}, nil
|
|
case "sherpa", "sherpa-offline", "offline":
|
|
cfg.Engine = "sherpa"
|
|
return newSherpaRestorer(cfg)
|
|
case "sherpa-online", "online":
|
|
cfg.Engine = "sherpa-online"
|
|
return newSherpaRestorer(cfg)
|
|
case "xlm", "xlm-roberta", "roberta":
|
|
return newXLM(cfg)
|
|
case "http":
|
|
if cfg.HTTPURL == "" {
|
|
return nil, fmt.Errorf("punctuation.http_url is required when engine=http")
|
|
}
|
|
return HTTP{cfg: cfg}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported punctuation engine %q (use: off, heuristic, xlm, sherpa, sherpa-online, http)", engine)
|
|
}
|
|
}
|
|
|
|
// Nop is the Restorer returned by New and AutoSelect when the punctuation
// config is inactive. It is never active and returns text untouched.
type Nop struct{}

// Active always reports false.
func (n Nop) Active() bool { return false }

// Restore returns text exactly as given and never fails. ctx and language
// are ignored.
func (n Nop) Restore(ctx context.Context, text, language string) (string, error) {
	return text, nil
}
|
|
|
|
type HTTP struct {
|
|
cfg config.Punctuation
|
|
client *http.Client
|
|
}
|
|
|
|
func (h HTTP) Active() bool { return true }
|
|
|
|
func (h HTTP) Restore(ctx context.Context, text, language string) (string, error) {
|
|
body, err := json.Marshal(map[string]string{
|
|
"text": text,
|
|
"language": language,
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.cfg.HTTPURL, bytes.NewReader(body))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
client := h.client
|
|
if client == nil {
|
|
client = &http.Client{Timeout: h.cfg.Timeout()}
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
raw, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if resp.StatusCode >= 300 {
|
|
return "", fmt.Errorf("punctuation http %s: %s", resp.Status, strings.TrimSpace(string(raw)))
|
|
}
|
|
var out struct {
|
|
Text string `json:"text"`
|
|
}
|
|
if err := json.Unmarshal(raw, &out); err != nil {
|
|
return strings.TrimSpace(string(raw)), nil
|
|
}
|
|
if out.Text == "" {
|
|
return text, nil
|
|
}
|
|
return out.Text, nil
|
|
}
|
|
|
|
func Apply(ctx context.Context, r Restorer, enabled bool, text, language string) (string, error) {
|
|
if !enabled || r == nil || !r.Active() {
|
|
return text, nil
|
|
}
|
|
text = strings.TrimSpace(text)
|
|
if text == "" {
|
|
return text, nil
|
|
}
|
|
out, err := r.Restore(ctx, text, language)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return CleanExcessive(out), nil
|
|
}
|
|
|
|
func Close(r Restorer) {
|
|
if c, ok := r.(Closer); ok {
|
|
c.Close()
|
|
}
|
|
}
|
|
|
|
func AutoSelect(cfg config.Punctuation) (Restorer, error) {
|
|
cfg = cfg.WithDefaults()
|
|
if !cfg.Active() {
|
|
return Nop{}, nil
|
|
}
|
|
engine := strings.ToLower(cfg.Engine)
|
|
if engine == "heuristic" {
|
|
return Heuristic{}, nil
|
|
}
|
|
if engine == "http" {
|
|
return New(cfg)
|
|
}
|
|
if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" {
|
|
return newXLM(cfg)
|
|
}
|
|
if _, err := os.Stat(cfg.ModelPath()); err == nil {
|
|
cfg.Engine = engine
|
|
if engine == "sherpa" || engine == "sherpa-offline" || engine == "offline" || engine == "" {
|
|
cfg.Engine = "sherpa"
|
|
}
|
|
return newSherpaRestorer(cfg)
|
|
}
|
|
if engine == "sherpa" || engine == "sherpa-online" || engine == "online" {
|
|
return nil, fmt.Errorf("punctuation model not found at %s (run: make download-punctuation-model)", cfg.ModelPath())
|
|
}
|
|
return Heuristic{}, nil
|
|
}
|