go-whisper-api/api/server.go
admin b5c083e06f
Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
first commit
2026-06-04 18:10:52 +07:00

641 lines
18 KiB
Go

// Package api implements the HTTP server of go-whisper-api: model management,
// synchronous and queued speech-to-text under /spr/, and OpenAI-style
// transcription endpoints under /v1/.
package api
import (
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
"go-whisper-api/config"
"go-whisper-api/diarization"
"go-whisper-api/punctuation"
"go-whisper-api/transcode"
"go-whisper-api/whisper"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
// Server bundles the configuration, processing engines and HTTP routing used
// to serve the API. Instances are created with NewServer.
type Server struct {
// cfg is the API configuration as stored by NewServer after defaults were applied.
cfg config.API
// punctCfg is the punctuation configuration (defaults applied).
punctCfg config.Punctuation
// diarCfg is the diarization configuration (defaults applied).
diarCfg config.Diarization
// transcode converts uploaded audio to WAV (see saveUploadedWav).
transcode *transcode.Engine
// modelPool holds whisper models; it is used by warmModels and transcribe.
modelPool *whisper.ModelPool
// punct restores punctuation in transcripts when requested and active.
punct punctuation.Restorer
// diarizer produces speaker turns when speaker separation is requested.
diarizer diarization.Engine
// models is the registry rooted at cfg.ModelsDir.
models *Registry
// cache is the on-disk store for queued tasks and their results.
cache *DiskCache
// mux routes incoming requests to the handle* methods (see routes).
mux *http.ServeMux
// queueWake has capacity 1 and is signalled by notifyQueue; the receiver
// is not visible in this part of the file (presumably the queue worker).
queueWake chan struct{}
}
// NewServer builds a Server from the given configuration sections.
//
// Defaults are applied to every section, the punctuation and diarization
// engines are created, the models directory and the disk cache are prepared
// (including recovery of interrupted tasks), routes are registered and model
// preloading is started in the background.
//
// It returns an error when an engine cannot be created, when punctuation is
// enabled but its engine is unavailable, or when the models/cache directories
// cannot be prepared. On any error the punctuation restorer created here is
// closed again so a failed construction does not leak it.
func NewServer(cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) (*Server, error) {
	cfg = cfg.WithDefaults()
	tc = tc.WithDefaults()
	pc = pc.WithDefaults()
	dc = dc.WithDefaults()

	restorer, err := punctuation.New(pc)
	if err != nil {
		return nil, err
	}
	// Until the restorer is handed over to the Server, this function owns it
	// and must release it on every failure path below.
	handedOver := false
	defer func() {
		if !handedOver {
			punctuation.Close(restorer)
		}
	}()

	if pc.Active() && !restorer.Active() {
		return nil, fmt.Errorf("punctuation is enabled but engine %q is not available", pc.Engine)
	}

	// Fallbacks for values that are still unset after WithDefaults.
	if cfg.ModelsDir == "" {
		cfg.ModelsDir = "./models"
	}
	if cfg.Addr == "" {
		cfg.Addr = ":8080"
	}
	if cfg.Threads == 0 {
		cfg.Threads = uint(runtime.NumCPU())
	}
	// NOTE(review): WithDefaults was already applied above; the second call is
	// kept to preserve behavior in case it derives values from the fields just
	// set — confirm and drop if it is idempotent.
	cfg = cfg.WithDefaults()
	if cfg.MaxContext == 0 {
		cfg.MaxContext = 32
	}
	if cfg.BeamSize == 0 {
		cfg.BeamSize = 5
	}
	if cfg.EntropyThold == 0 {
		cfg.EntropyThold = 2.4
	}

	if err := os.MkdirAll(cfg.ModelsDir, 0o755); err != nil {
		return nil, err
	}

	cacheDir := cfg.CacheDir
	if cacheDir == "" {
		cacheDir = "./cache"
	}
	cache, err := NewDiskCache(cacheDir)
	if err != nil {
		return nil, err
	}
	// Store the root as reported by the cache rather than the raw setting.
	cfg.CacheDir = cache.Root()
	if err := cache.RecoverInterrupted(); err != nil {
		return nil, err
	}

	diar, err := diarization.New(dc)
	if err != nil {
		return nil, err
	}

	s := &Server{
		cfg:       cfg,
		punctCfg:  pc,
		diarCfg:   dc,
		transcode: transcode.NewEngine(tc.FFmpegPath),
		modelPool: whisper.NewModelPool(),
		punct:     restorer,
		diarizer:  diar,
		models:    NewRegistry(cfg.ModelsDir),
		cache:     cache,
		mux:       http.NewServeMux(),
		queueWake: make(chan struct{}, 1),
	}
	handedOver = true
	s.routes()
	// Preloading runs in the background so construction does not block on it.
	go s.warmModels()
	return s, nil
}
// warmModels preloads every model known to the registry into the model pool
// by running a no-op callback against each one. Failures are logged and never
// abort the loop: a model that cannot be listed, resolved or loaded is simply
// skipped.
func (s *Server) warmModels() {
	ids, err := s.models.List()
	if err != nil {
		log.Warn().Err(err).Msg("list models for preload")
		return
	}
	for _, id := range ids {
		path, err := s.models.Path(id)
		if err != nil {
			log.Warn().Err(err).Str("model", id).Msg("resolve whisper model path")
			continue
		}
		// The callback does nothing; calling WithModel is what triggers the load.
		if err := s.modelPool.WithModel(path, func(whisper.Model) error { return nil }); err != nil {
			log.Warn().Err(err).Str("model", id).Msg("preload whisper model")
			continue
		}
		log.Info().Str("model", id).Msg("whisper model loaded")
	}
}
// routes registers every HTTP endpoint of the server on s.mux. Patterns that
// end in "/" match the whole subtree; their handlers read the trailing path
// segment as an identifier.
func (s *Server) routes() {
	endpoints := []struct {
		pattern string
		handler func(http.ResponseWriter, *http.Request)
	}{
		// Documentation.
		{"/", s.handleSwaggerUI},
		{"/swagger.json", s.handleSwaggerJSON},
		// Native API.
		{"/spr/models", s.handleModels},
		{"/spr/hostname", s.handleHostname},
		{"/spr/queue", s.handleQueue},
		{"/spr/stt/", s.handleSTT},
		{"/spr/result/", s.handleResult},
		{"/spr/queue/", s.handleQueueItem},
		{"/spr/audio/", s.handleAudio},
		{"/spr/waveform/", s.handleWaveform},
		{"/spr/delete/", s.handleDeleteModel},
		{"/spr/export/", s.handleExportModel},
		{"/spr/import/", s.handleImportModel},
		// OpenAI-style API.
		{"/v1/audio/transcriptions", s.handleOpenAITranscriptions},
		{"/v1/audio/transcriptions/", s.handleOpenAITranscriptions},
		{"/v1/models", s.handleOpenAIModels},
	}
	for _, ep := range endpoints {
		s.mux.HandleFunc(ep.pattern, ep.handler)
	}
}
// ListenAndServe logs the effective address, models directory and cache root,
// then serves the API on s.cfg.Addr until http.ListenAndServe returns.
func (s *Server) ListenAndServe() error {
	addr := s.cfg.Addr

	startup := log.Info()
	startup = startup.Str("addr", addr)
	startup = startup.Str("models", s.cfg.ModelsDir)
	startup = startup.Str("cache", s.cache.Root())
	startup.Msg("starting API server")

	return http.ListenAndServe(addr, s.mux)
}
// handleModels answers GET requests with the registry's model list as
// {"models": [...]}. Other methods get 405; a listing failure gets 500.
func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		methodNotAllowed(w)
		return
	}
	available, listErr := s.models.List()
	if listErr == nil {
		writeJSON(w, http.StatusOK, map[string]any{"models": available})
		return
	}
	writeError(w, http.StatusInternalServerError, listErr.Error())
}
// handleHostname answers GET requests with basic facts about the running
// process: host name, working directory, models directory and cache root.
func (s *Server) handleHostname(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		methodNotAllowed(w)
		return
	}

	// Lookup errors are deliberately ignored; the value is then reported empty.
	hostname, _ := os.Hostname()
	workDir, _ := os.Getwd()

	info := make(map[string]any, 7)
	info["error"] = 0
	info["message"] = "Success"
	info["hostname"] = hostname
	info["version"] = "go-whisper-api"
	info["cwd"] = workDir
	info["models"] = s.cfg.ModelsDir
	info["cache"] = s.cache.Root()

	writeJSON(w, http.StatusOK, info)
}
// handleQueue answers GET requests with the cache's task listing, encoded as
// JSON exactly as the cache returns it. A listing failure gets 500.
func (s *Server) handleQueue(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		methodNotAllowed(w)
		return
	}
	tasks, listErr := s.cache.List()
	if listErr == nil {
		writeJSON(w, http.StatusOK, tasks)
		return
	}
	writeError(w, http.StatusInternalServerError, listErr.Error())
}
// handleQueueItem serves /spr/queue/{id}: GET reports the task state (see
// handleQueueGet) and DELETE removes the task from the cache. A missing or
// malformed id gets 400, any other method 405.
func (s *Server) handleQueueItem(w http.ResponseWriter, r *http.Request) {
	taskID := strings.TrimPrefix(r.URL.Path, "/spr/queue/")
	if usable := taskID != "" && isValidTaskID(taskID); !usable {
		writeError(w, http.StatusBadRequest, "task id required")
		return
	}

	if r.Method == http.MethodGet {
		s.handleQueueGet(w, taskID)
		return
	}
	if r.Method != http.MethodDelete {
		methodNotAllowed(w)
		return
	}

	if removed := s.cache.Delete(taskID); !removed {
		writeAPIError(w, http.StatusNotFound, "TaskNotFound")
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"})
}
// handleQueueGet reports the state of task id.
//
//   - unknown task: 404 API error
//   - ready: the task is promoted out of the waiting phase if it is still
//     there, then a plain success body is returned
//   - error: 404 API error carrying the stored message (or a generic one)
//   - anything else: 200 with the status echoed in "message" and "status"
func (s *Server) handleQueueGet(w http.ResponseWriter, id string) {
	params, phase, err := s.cache.LoadParams(id)
	if err != nil {
		writeAPIError(w, http.StatusNotFound, "task not found")
		return
	}

	state := params.Status

	if state == string(statusReady) {
		if phase == cacheWaiting {
			if promoteErr := s.cache.PromoteToReady(id); promoteErr != nil {
				writeAPIError(w, http.StatusInternalServerError, promoteErr.Error())
				return
			}
		}
		writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"})
		return
	}

	if state == string(statusError) {
		reason := params.Error
		if reason == "" {
			reason = "transcription failed"
		}
		writeAPIError(w, http.StatusNotFound, reason)
		return
	}

	writeJSON(w, http.StatusOK, map[string]any{
		"error":   0,
		"message": state,
		"status":  state,
	})
}
// handleSTT serves POST /spr/stt/{model}: it stores the uploaded audio as a
// temporary WAV file and either enqueues an asynchronous task (answering with
// {"taskID": ...}) or transcribes synchronously (answering with model, text
// and words).
//
// The temporary upload directory is removed when the handler returns, on
// every path, including option-parsing failures.
func (s *Server) handleSTT(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		methodNotAllowed(w)
		return
	}
	modelID := strings.TrimPrefix(r.URL.Path, "/spr/stt/")
	if modelID == "" {
		writeError(w, http.StatusBadRequest, "model id required")
		return
	}
	modelPath, err := s.models.Path(modelID)
	if err != nil {
		writeAPIError(w, http.StatusNotFound, err.Error())
		return
	}

	audioPath, cleanup, err := s.saveUploadedWav(r)
	if err != nil {
		writeAPIError(w, http.StatusBadRequest, err.Error())
		return
	}
	// A single deferred cleanup covers every return below; previously the
	// temp directory leaked when parseSTTOptions failed.
	defer cleanup()

	stt, err := s.parseSTTOptions(r)
	if err != nil {
		writeAPIError(w, http.StatusBadRequest, err.Error())
		return
	}

	if queryAsync(r, s.cfg.DefaultAsync) {
		taskID, err := s.enqueueAsync(r, modelID, audioPath, stt)
		if err != nil {
			writeAPIError(w, http.StatusBadRequest, err.Error())
			return
		}
		writeJSON(w, http.StatusOK, map[string]string{"taskID": taskID})
		return
	}

	result, err := s.transcribe(r.Context(), modelPath, audioPath, stt)
	if err != nil {
		// NOTE(review): 405 is an odd status for a failed transcription (500
		// would be the usual choice). It is kept because clients may depend
		// on it — confirm before changing.
		writeAPIError(w, http.StatusMethodNotAllowed, err.Error())
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{
		"model": modelID,
		"text":  result.Text,
		"words": result.Words,
	})
}
// handleResult serves GET /spr/result/{id}. Finished tasks return the ready
// payload, failed tasks a 404 API error with the stored message, and every
// other state a 200 body of the form {"status": <state>}.
func (s *Server) handleResult(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		methodNotAllowed(w)
		return
	}
	taskID := strings.TrimPrefix(r.URL.Path, "/spr/result/")
	if !isValidTaskID(taskID) {
		writeAPIError(w, http.StatusBadRequest, "task id required")
		return
	}
	params, _, loadErr := s.cache.LoadParams(taskID)
	if loadErr != nil {
		writeAPIError(w, http.StatusNotFound, "TaskNotFound")
		return
	}

	state := params.Status
	// Shared reply for "not finished yet" and for unrecognised states.
	echoState := func() {
		writeJSON(w, http.StatusOK, map[string]string{"status": state})
	}

	switch {
	case state == string(statusWaiting) || state == string(statusProcessing):
		echoState()
	case state == string(statusError):
		reason := params.Error
		if reason == "" {
			reason = "TaskNotFound"
		}
		writeAPIError(w, http.StatusNotFound, reason)
	case state == string(statusReady):
		writeJSON(w, http.StatusOK, sprResultReady(params))
	default:
		echoState()
	}
}
// handleAudio serves GET /spr/audio/{id} by streaming the cached audio file
// of the task. Unknown tasks get 404.
//
// The id comes straight from the URL, so it is checked before it reaches the
// cache path lookup: an empty id or one containing a path separator can never
// name a task (ids are generated as UUIDs in enqueueAsync) and is answered as
// "not found" without touching the file system.
func (s *Server) handleAudio(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		methodNotAllowed(w)
		return
	}
	id := strings.TrimPrefix(r.URL.Path, "/spr/audio/")
	// NOTE(review): sibling handlers use isValidTaskID for this; switching to
	// it here would make the validation uniform.
	if id == "" || strings.ContainsAny(id, `/\`) {
		writeError(w, http.StatusNotFound, "task not found")
		return
	}
	path, ok := s.cache.AudioPath(id)
	if !ok {
		writeError(w, http.StatusNotFound, "task not found")
		return
	}
	http.ServeFile(w, r, path)
}
// handleWaveform serves GET /spr/waveform/{id} with the waveform the cache
// returns for the task, wrapped as {"error": 0, "waveform": ...}. An invalid
// id gets 400; a cache failure gets 500.
func (s *Server) handleWaveform(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		methodNotAllowed(w)
		return
	}
	taskID := strings.TrimPrefix(r.URL.Path, "/spr/waveform/")
	if !isValidTaskID(taskID) {
		writeAPIError(w, http.StatusBadRequest, "task id required")
		return
	}
	waveform, wfErr := s.cache.Waveform(taskID)
	if wfErr == nil {
		writeJSON(w, http.StatusOK, map[string]any{"error": 0, "waveform": waveform})
		return
	}
	writeError(w, http.StatusInternalServerError, wfErr.Error())
}
// handleDeleteModel serves DELETE /spr/delete/{model}. Success is an empty
// 200 response; a registry error is reported as a 404 API error.
func (s *Server) handleDeleteModel(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodDelete {
		methodNotAllowed(w)
		return
	}
	modelID := strings.TrimPrefix(r.URL.Path, "/spr/delete/")
	delErr := s.models.Delete(modelID)
	if delErr == nil {
		w.WriteHeader(http.StatusOK)
		return
	}
	writeAPIError(w, http.StatusNotFound, delErr.Error())
}
// handleExportModel serves GET /spr/export/{model} by streaming the model
// file as a binary attachment named "<model>.bin". A model that cannot be
// opened is reported as a 404 API error.
func (s *Server) handleExportModel(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		methodNotAllowed(w)
		return
	}
	id := strings.TrimPrefix(r.URL.Path, "/spr/export/")
	f, err := s.models.Open(id)
	if err != nil {
		writeAPIError(w, http.StatusNotFound, err.Error())
		return
	}
	defer f.Close()

	// The extension must be inside the quoted string: the previous format
	// produced filename="id".bin, which is not a valid header value.
	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", id+".bin"))
	w.Header().Set("Content-Type", "application/octet-stream")

	// The status line and headers are sent with the first write, so a copy
	// failure (typically a client that went away) cannot be reported anymore.
	_, _ = io.Copy(w, f)
}
// handleImportModel serves POST /spr/import/{model}. The model file is read
// from the multipart field "zip-model", falling back to "model". Despite the
// field name, uploads whose file name ends in ".zip" are rejected. Success is
// an empty 200 response; every failure is a 400 API error.
func (s *Server) handleImportModel(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		methodNotAllowed(w)
		return
	}
	modelID := strings.TrimPrefix(r.URL.Path, "/spr/import/")

	if parseErr := r.ParseMultipartForm(32 << 20); parseErr != nil {
		writeAPIError(w, http.StatusBadRequest, parseErr.Error())
		return
	}

	upload, meta, err := r.FormFile("zip-model")
	if err != nil {
		// Second chance under the alternative field name.
		upload, meta, err = r.FormFile("model")
		if err != nil {
			writeAPIError(w, http.StatusBadRequest, "model file required")
			return
		}
	}
	defer upload.Close()

	lowerName := strings.ToLower(meta.Filename)
	if strings.HasSuffix(lowerName, ".zip") {
		writeAPIError(w, http.StatusBadRequest, "zip import is not supported; upload .bin model file as zip-model field")
		return
	}

	if importErr := s.models.Import(modelID, upload); importErr != nil {
		writeAPIError(w, http.StatusBadRequest, importErr.Error())
		return
	}
	w.WriteHeader(http.StatusOK)
}
// enqueueAsync registers a new waiting task for the given model and audio
// file in the cache, signals the queue and returns the generated task id.
// The request parameter r is currently not used.
func (s *Server) enqueueAsync(r *http.Request, modelID, audioWavPath string, stt sttOptions) (string, error) {
	taskID := uuid.New().String()

	task := TaskParams{
		ID:          taskID,
		Created:     time.Now().Format("2006-01-02 15:04:05"),
		Status:      string(statusWaiting),
		Model:       modelID,
		Language:    stt.language,
		Punctuation: stt.punctuate,
		Speakers:    stt.speakers,
		NumClusters: stt.numClusters,
	}

	err := s.cache.Enqueue(taskID, task, audioWavPath)
	if err != nil {
		return "", err
	}

	s.notifyQueue()

	entry := log.Info()
	entry = entry.Str("task", taskID)
	entry = entry.Str("model", modelID)
	entry = entry.Str("cache", s.cache.Root())
	entry.Msg("enqueued async task")

	return taskID, nil
}
// transcribe runs the full pipeline for one audio file: optional diarization,
// whisper transcription through the shared model pool, optional punctuation
// restoration and finally garbage-pattern filtering of the result.
func (s *Server) transcribe(ctx context.Context, modelPath, audioPath string, stt sttOptions) (whisper.TranscriptResult, error) {
	var none whisper.TranscriptResult

	turns, err := s.runDiarization(ctx, audioPath, stt)
	if err != nil {
		return none, err
	}

	// The VAD model path is resolved against the models directory only when
	// VAD is switched on.
	vadCfg := s.cfg.VAD
	if vadCfg.Enabled {
		vadCfg.Model = vadCfg.ResolveModelPath(s.cfg.ModelsDir)
	}

	whisperCfg := &config.Whisper{
		Model:         modelPath,
		AudioPath:     audioPath,
		Threads:       s.cfg.Threads,
		Language:      stt.language,
		Debug:         s.cfg.Debug,
		SpeedUp:       s.cfg.SpeedUp,
		Translate:     s.cfg.Translate,
		Prompt:        s.cfg.Prompt,
		MaxContext:    s.cfg.MaxContext,
		BeamSize:      s.cfg.BeamSize,
		EntropyThold:  s.cfg.EntropyThold,
		VAD:           vadCfg,
		PrintProgress: false,
		PrintSegment:  false,
	}

	opts := s.whisperRunOpts(stt, turns)
	if wantPunct := stt.punctuate && s.punct.Active(); wantPunct {
		opts.PunctuateRestore = func(text string) (string, error) {
			return punctuation.Apply(ctx, s.punct, true, text, stt.language)
		}
	}

	transcript, err := whisper.TranscribeWithPool(s.modelPool, whisperCfg, opts)
	if err != nil {
		return none, err
	}
	return applyGarbage(transcript, s.cfg.GarbagePatterns()), nil
}
// runDiarization returns the speaker turns for audioPath. When speaker
// separation was not requested it returns (nil, nil) without reading the
// audio; a failure to load the samples is wrapped as "diarization audio".
func (s *Server) runDiarization(ctx context.Context, audioPath string, stt sttOptions) ([]whisper.Turn, error) {
	if stt.speakers {
		pcm, loadErr := whisper.LoadPCM16Mono(audioPath)
		if loadErr != nil {
			return nil, fmt.Errorf("diarization audio: %w", loadErr)
		}
		return s.diarizer.Process(ctx, pcm, stt.numClusters)
	}
	return nil, nil
}
// saveUploadedRaw stores the uploaded audio file of r in a temporary
// directory, looking for it under the form fields "audio", "wav" and "file"
// in that order. See saveUploadedRawFields for the returned values.
func saveUploadedRaw(r *http.Request) (path string, cleanup func(), err error) {
return saveUploadedRawFields(r, []string{"audio", "wav", "file"})
}
// saveUploadedRawFields copies the first uploaded file found under one of
// fieldNames (tried in order) into a fresh temporary directory.
//
// It returns the path of the stored file, named "input" plus the extension of
// the uploaded file name, and a cleanup function that removes the temporary
// directory. On error nothing is left on disk and cleanup is nil.
//
// The multipart form is parsed here unless the request already carries a
// parsed form.
func saveUploadedRawFields(r *http.Request, fieldNames []string) (path string, cleanup func(), err error) {
	if r.MultipartForm == nil {
		if err := r.ParseMultipartForm(128 << 20); err != nil {
			return "", nil, err
		}
	}

	var (
		file   multipart.File
		header *multipart.FileHeader
		found  bool
	)
	for _, name := range fieldNames {
		file, header, err = r.FormFile(name)
		if err == nil {
			found = true
			break
		}
	}
	if !found {
		return "", nil, fmt.Errorf("audio file required (form field: %s)", strings.Join(fieldNames, ", "))
	}
	defer file.Close()

	dir, err := config.MkdirTemp("go-whisper-api-upload-*")
	if err != nil {
		return "", nil, err
	}
	cleanup = func() { os.RemoveAll(dir) }

	// Keep the original extension so later stages can see the source format.
	base := "input"
	if header != nil {
		if ext := filepath.Ext(header.Filename); ext != "" {
			base += ext
		}
	}
	raw := filepath.Join(dir, base)

	out, err := os.Create(raw)
	if err != nil {
		cleanup()
		return "", nil, err
	}
	_, copyErr := io.Copy(out, file)
	// Close can be the call that reports a failed write, so its error must
	// not be dropped: a truncated file would otherwise be handed on.
	closeErr := out.Close()
	if copyErr != nil {
		cleanup()
		return "", nil, copyErr
	}
	if closeErr != nil {
		cleanup()
		return "", nil, closeErr
	}
	return raw, cleanup, nil
}
// saveUploadedWav stores the uploaded audio of r and transcodes it to
// "audio.wav" in the same temporary directory, using the whisper transcode
// options. It returns the WAV path and a cleanup function for the directory;
// when transcoding fails the directory is removed before returning.
func (s *Server) saveUploadedWav(r *http.Request) (path string, cleanup func(), err error) {
	rawPath, removeTmp, err := saveUploadedRaw(r)
	if err != nil {
		return "", nil, err
	}

	wavPath := filepath.Join(filepath.Dir(rawPath), "audio.wav")
	convErr := s.transcode.Transcode(r.Context(), rawPath, wavPath, transcode.WhisperOptions())
	if convErr != nil {
		removeTmp()
		return "", nil, convErr
	}
	return wavPath, removeTmp, nil
}
// queryBoolDefault reads query parameter key of r as a boolean in any form
// strconv.ParseBool accepts. A missing, empty or unparsable value yields def.
func queryBoolDefault(r *http.Request, key string, def bool) bool {
	raw := r.URL.Query().Get(key)
	// ParseBool rejects the empty string, so the "absent" case needs no
	// separate check.
	if parsed, err := strconv.ParseBool(raw); err == nil {
		return parsed
	}
	return def
}
func queryAsync(r *http.Request, defaultAsync bool) bool {
v := r.URL.Query().Get("async")
if v == "" {
return defaultAsync
}
n, err := strconv.Atoi(v)
if err != nil {
return defaultAsync
}
return n == 1
}
// queryInt reads query parameter key of r as a decimal integer. A missing,
// empty or unparsable value yields def.
func queryInt(r *http.Request, key string, def int) int {
	// Atoi rejects the empty string, which covers the "absent" case too.
	if value, err := strconv.Atoi(r.URL.Query().Get(key)); err == nil {
		return value
	}
	return def
}
// writeJSON sends v as a JSON body with the given status code and a UTF-8
// JSON content type.
func writeJSON(w http.ResponseWriter, code int, v any) {
	headers := w.Header()
	headers.Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(code)

	// The status line is already out, so an encoding failure cannot be
	// turned into an error response anymore; it is intentionally dropped.
	enc := json.NewEncoder(w)
	_ = enc.Encode(v)
}
// writeError replies with msg and the given status code via http.Error,
// i.e. as a plain-text body rather than JSON.
func writeError(w http.ResponseWriter, code int, msg string) {
http.Error(w, msg, code)
}
// writeAPIError replies with the given status code and the JSON body
// {"error": 1, "message": msg}.
func writeAPIError(w http.ResponseWriter, code int, msg string) {
writeJSON(w, code, map[string]any{"error": 1, "message": msg})
}
// methodNotAllowed replies with 405 and the plain-text body
// "method not allowed".
func methodNotAllowed(w http.ResponseWriter) {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
// notifyQueue signals s.queueWake without blocking: the send is attempted
// once and silently skipped when it cannot proceed immediately (for example
// because a signal is already buffered).
func (s *Server) notifyQueue() {
select {
case s.queueWake <- struct{}{}:
default:
// A wake-up is already pending; nothing more to do.
}
}
// Run builds a Server from the given configuration, starts the queue worker
// and serves HTTP until ctx is cancelled or the listener fails.
//
// On cancellation the HTTP server is shut down gracefully with a 10 second
// limit, and Run waits for that shutdown to finish before it releases the
// punctuation engine, the diarizer and the model pool. It returns nil after a
// clean shutdown and the listener error otherwise.
func Run(ctx context.Context, cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) error {
	srv, err := NewServer(cfg, tc, pc, dc)
	if err != nil {
		return err
	}
	// Deferred calls run last-in-first-out, so the engines are released in
	// the order punctuation, diarizer, model pool — on every return path,
	// including a listener failure.
	defer srv.modelPool.Close()
	defer srv.diarizer.Close()
	defer punctuation.Close(srv.punct)

	srv.StartWorker(ctx)

	hs := &http.Server{
		// Use the address with defaults applied by NewServer; the raw cfg
		// may still carry an empty Addr, which net/http would turn into ":http".
		Addr:    srv.cfg.Addr,
		Handler: srv.mux,
		// Bounds only the time to read request headers, so slow uploads and
		// long transcriptions are unaffected.
		ReadHeaderTimeout: 10 * time.Second,
	}

	serveReturned := make(chan struct{})
	shutdownDone := make(chan struct{})
	go func() {
		defer close(shutdownDone)
		select {
		case <-ctx.Done():
		case <-serveReturned:
			// The listener stopped on its own; nothing to shut down.
			return
		}
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		// Best effort: if the deadline passes, the process is going down
		// anyway and the remaining connections are abandoned.
		_ = hs.Shutdown(shutdownCtx)
	}()

	serveErr := hs.ListenAndServe()
	close(serveReturned)
	// ListenAndServe returns as soon as Shutdown is called; wait for
	// Shutdown itself so in-flight requests are not cut off by the deferred
	// engine teardown.
	<-shutdownDone

	if serveErr != nil && serveErr != http.ErrServerClosed {
		return serveErr
	}
	return nil
}