package api import ( "context" "encoding/json" "fmt" "io" "mime/multipart" "net/http" "os" "path/filepath" "runtime" "strconv" "strings" "time" "go-whisper-api/config" "go-whisper-api/diarization" "go-whisper-api/punctuation" "go-whisper-api/transcode" "go-whisper-api/whisper" "github.com/google/uuid" "github.com/rs/zerolog/log" ) type Server struct { cfg config.API punctCfg config.Punctuation diarCfg config.Diarization transcode *transcode.Engine modelPool *whisper.ModelPool punct punctuation.Restorer diarizer diarization.Engine models *Registry cache *DiskCache mux *http.ServeMux queueWake chan struct{} } func NewServer(cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) (*Server, error) { cfg = cfg.WithDefaults() tc = tc.WithDefaults() pc = pc.WithDefaults() dc = dc.WithDefaults() restorer, err := punctuation.New(pc) if err != nil { return nil, err } if pc.Active() && !restorer.Active() { return nil, fmt.Errorf("punctuation is enabled but engine %q is not available", pc.Engine) } if cfg.ModelsDir == "" { cfg.ModelsDir = "./models" } if cfg.Addr == "" { cfg.Addr = ":8080" } if cfg.Threads == 0 { cfg.Threads = uint(runtime.NumCPU()) } cfg = cfg.WithDefaults() if cfg.MaxContext == 0 { cfg.MaxContext = 32 } if cfg.BeamSize == 0 { cfg.BeamSize = 5 } if cfg.EntropyThold == 0 { cfg.EntropyThold = 2.4 } if err := os.MkdirAll(cfg.ModelsDir, 0o755); err != nil { return nil, err } cacheDir := cfg.CacheDir if cacheDir == "" { cacheDir = "./cache" } cache, err := NewDiskCache(cacheDir) if err != nil { return nil, err } cfg.CacheDir = cache.Root() if err := cache.RecoverInterrupted(); err != nil { return nil, err } diar, err := diarization.New(dc) if err != nil { return nil, err } s := &Server{ cfg: cfg, punctCfg: pc, diarCfg: dc, transcode: transcode.NewEngine(tc.FFmpegPath), modelPool: whisper.NewModelPool(), punct: restorer, diarizer: diar, models: NewRegistry(cfg.ModelsDir), cache: cache, mux: http.NewServeMux(), queueWake: make(chan struct{}, 1), } s.routes() go s.warmModels() return s, nil } func (s *Server) warmModels() { ids, err := s.models.List() if err != nil { return } for _, id := range ids { path, err := s.models.Path(id) if err != nil { continue } if err := s.modelPool.WithModel(path, func(whisper.Model) error { return nil }); err != nil { log.Warn().Err(err).Str("model", id).Msg("preload whisper model") } else { log.Info().Str("model", id).Msg("whisper model loaded") } } } func (s *Server) routes() { s.mux.HandleFunc("/", s.handleSwaggerUI) s.mux.HandleFunc("/swagger.json", s.handleSwaggerJSON) s.mux.HandleFunc("/spr/models", s.handleModels) s.mux.HandleFunc("/spr/hostname", s.handleHostname) s.mux.HandleFunc("/spr/queue", s.handleQueue) s.mux.HandleFunc("/spr/stt/", s.handleSTT) s.mux.HandleFunc("/spr/result/", s.handleResult) s.mux.HandleFunc("/spr/queue/", s.handleQueueItem) s.mux.HandleFunc("/spr/audio/", s.handleAudio) s.mux.HandleFunc("/spr/waveform/", s.handleWaveform) s.mux.HandleFunc("/spr/delete/", s.handleDeleteModel) s.mux.HandleFunc("/spr/export/", s.handleExportModel) s.mux.HandleFunc("/spr/import/", s.handleImportModel) s.mux.HandleFunc("/v1/audio/transcriptions", s.handleOpenAITranscriptions) s.mux.HandleFunc("/v1/audio/transcriptions/", s.handleOpenAITranscriptions) s.mux.HandleFunc("/v1/models", s.handleOpenAIModels) } func (s *Server) ListenAndServe() error { log.Info(). Str("addr", s.cfg.Addr). Str("models", s.cfg.ModelsDir). Str("cache", s.cache.Root()). Msg("starting API server") return http.ListenAndServe(s.cfg.Addr, s.mux) } func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { methodNotAllowed(w) return } models, err := s.models.List() if err != nil { writeError(w, http.StatusInternalServerError, err.Error()) return } writeJSON(w, http.StatusOK, map[string]any{"models": models}) } func (s *Server) handleHostname(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { methodNotAllowed(w) return } host, _ := os.Hostname() cwd, _ := os.Getwd() writeJSON(w, http.StatusOK, map[string]any{ "error": 0, "message": "Success", "hostname": host, "version": "go-whisper-api", "cwd": cwd, "models": s.cfg.ModelsDir, "cache": s.cache.Root(), }) } func (s *Server) handleQueue(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { methodNotAllowed(w) return } list, err := s.cache.List() if err != nil { writeError(w, http.StatusInternalServerError, err.Error()) return } writeJSON(w, http.StatusOK, list) } func (s *Server) handleQueueItem(w http.ResponseWriter, r *http.Request) { id := strings.TrimPrefix(r.URL.Path, "/spr/queue/") if id == "" || !isValidTaskID(id) { writeError(w, http.StatusBadRequest, "task id required") return } switch r.Method { case http.MethodGet: s.handleQueueGet(w, id) case http.MethodDelete: if !s.cache.Delete(id) { writeAPIError(w, http.StatusNotFound, "TaskNotFound") return } writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"}) default: methodNotAllowed(w) } } func (s *Server) handleQueueGet(w http.ResponseWriter, id string) { params, phase, err := s.cache.LoadParams(id) if err != nil { writeAPIError(w, http.StatusNotFound, "task not found") return } switch params.Status { case string(statusReady): if phase == cacheWaiting { if err := s.cache.PromoteToReady(id); err != nil { writeAPIError(w, http.StatusInternalServerError, err.Error()) return } } writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"}) case string(statusError): msg := params.Error if msg == "" { msg = "transcription failed" } writeAPIError(w, http.StatusNotFound, msg) default: writeJSON(w, http.StatusOK, map[string]any{ "error": 0, "message": params.Status, "status": params.Status, }) } } func (s *Server) handleSTT(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { methodNotAllowed(w) return } modelID := strings.TrimPrefix(r.URL.Path, "/spr/stt/") if modelID == "" { writeError(w, http.StatusBadRequest, "model id required") return } modelPath, err := s.models.Path(modelID) if err != nil { writeAPIError(w, http.StatusNotFound, err.Error()) return } audioPath, cleanup, err := s.saveUploadedWav(r) if err != nil { writeAPIError(w, http.StatusBadRequest, err.Error()) return } stt, err := s.parseSTTOptions(r) if err != nil { writeAPIError(w, http.StatusBadRequest, err.Error()) return } if queryAsync(r, s.cfg.DefaultAsync) { taskID, err := s.enqueueAsync(r, modelID, audioPath, stt) cleanup() if err != nil { writeAPIError(w, http.StatusBadRequest, err.Error()) return } writeJSON(w, http.StatusOK, map[string]string{"taskID": taskID}) return } defer cleanup() result, err := s.transcribe(r.Context(), modelPath, audioPath, stt) if err != nil { writeAPIError(w, http.StatusMethodNotAllowed, err.Error()) return } writeJSON(w, http.StatusOK, map[string]any{ "model": modelID, "text": result.Text, "words": result.Words, }) } func (s *Server) handleResult(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { methodNotAllowed(w) return } id := strings.TrimPrefix(r.URL.Path, "/spr/result/") if !isValidTaskID(id) { writeAPIError(w, http.StatusBadRequest, "task id required") return } params, _, err := s.cache.LoadParams(id) if err != nil { writeAPIError(w, http.StatusNotFound, "TaskNotFound") return } switch params.Status { case string(statusWaiting), string(statusProcessing): writeJSON(w, http.StatusOK, map[string]string{"status": params.Status}) case string(statusError): msg := params.Error if msg == "" { msg = "TaskNotFound" } writeAPIError(w, http.StatusNotFound, msg) case string(statusReady): writeJSON(w, http.StatusOK, sprResultReady(params)) default: writeJSON(w, http.StatusOK, map[string]string{"status": params.Status}) } } func (s *Server) handleAudio(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { methodNotAllowed(w) return } id := strings.TrimPrefix(r.URL.Path, "/spr/audio/") path, ok := s.cache.AudioPath(id) if !ok { writeError(w, http.StatusNotFound, "task not found") return } http.ServeFile(w, r, path) } func (s *Server) handleWaveform(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { methodNotAllowed(w) return } id := strings.TrimPrefix(r.URL.Path, "/spr/waveform/") if !isValidTaskID(id) { writeAPIError(w, http.StatusBadRequest, "task id required") return } wf, err := s.cache.Waveform(id) if err != nil { writeError(w, http.StatusInternalServerError, err.Error()) return } writeJSON(w, http.StatusOK, map[string]any{"error": 0, "waveform": wf}) } func (s *Server) handleDeleteModel(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodDelete { methodNotAllowed(w) return } id := strings.TrimPrefix(r.URL.Path, "/spr/delete/") if err := s.models.Delete(id); err != nil { writeAPIError(w, http.StatusNotFound, err.Error()) return } w.WriteHeader(http.StatusOK) } func (s *Server) handleExportModel(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { methodNotAllowed(w) return } id := strings.TrimPrefix(r.URL.Path, "/spr/export/") f, err := s.models.Open(id) if err != nil { writeAPIError(w, http.StatusNotFound, err.Error()) return } defer f.Close() w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q.bin", id)) w.Header().Set("Content-Type", "application/octet-stream") io.Copy(w, f) } func (s *Server) handleImportModel(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { methodNotAllowed(w) return } id := strings.TrimPrefix(r.URL.Path, "/spr/import/") if err := r.ParseMultipartForm(32 << 20); err != nil { writeAPIError(w, http.StatusBadRequest, err.Error()) return } file, header, err := r.FormFile("zip-model") if err != nil { file, header, err = r.FormFile("model") } if err != nil { writeAPIError(w, http.StatusBadRequest, "model file required") return } defer file.Close() src := io.Reader(file) if strings.HasSuffix(strings.ToLower(header.Filename), ".zip") { writeAPIError(w, http.StatusBadRequest, "zip import is not supported; upload .bin model file as zip-model field") return } if err := s.models.Import(id, src); err != nil { writeAPIError(w, http.StatusBadRequest, err.Error()) return } w.WriteHeader(http.StatusOK) } func (s *Server) enqueueAsync(r *http.Request, modelID, audioWavPath string, stt sttOptions) (string, error) { id := uuid.New().String() params := TaskParams{ ID: id, Created: time.Now().Format("2006-01-02 15:04:05"), Status: string(statusWaiting), Model: modelID, Language: stt.language, Punctuation: stt.punctuate, Speakers: stt.speakers, NumClusters: stt.numClusters, } if err := s.cache.Enqueue(id, params, audioWavPath); err != nil { return "", err } s.notifyQueue() log.Info(). Str("task", id). Str("model", modelID). Str("cache", s.cache.Root()). Msg("enqueued async task") return id, nil } func (s *Server) transcribe(ctx context.Context, modelPath, audioPath string, stt sttOptions) (whisper.TranscriptResult, error) { turns, err := s.runDiarization(ctx, audioPath, stt) if err != nil { return whisper.TranscriptResult{}, err } vad := s.cfg.VAD if vad.Enabled { vad.Model = vad.ResolveModelPath(s.cfg.ModelsDir) } cfg := &config.Whisper{ Model: modelPath, AudioPath: audioPath, Threads: s.cfg.Threads, Language: stt.language, Debug: s.cfg.Debug, SpeedUp: s.cfg.SpeedUp, Translate: s.cfg.Translate, Prompt: s.cfg.Prompt, MaxContext: s.cfg.MaxContext, BeamSize: s.cfg.BeamSize, EntropyThold: s.cfg.EntropyThold, VAD: vad, PrintProgress: false, PrintSegment: false, } runOpts := s.whisperRunOpts(stt, turns) if stt.punctuate && s.punct.Active() { runOpts.PunctuateRestore = func(text string) (string, error) { return punctuation.Apply(ctx, s.punct, true, text, stt.language) } } result, err := whisper.TranscribeWithPool(s.modelPool, cfg, runOpts) if err != nil { return whisper.TranscriptResult{}, err } return applyGarbage(result, s.cfg.GarbagePatterns()), nil } func (s *Server) runDiarization(ctx context.Context, audioPath string, stt sttOptions) ([]whisper.Turn, error) { if !stt.speakers { return nil, nil } samples, err := whisper.LoadPCM16Mono(audioPath) if err != nil { return nil, fmt.Errorf("diarization audio: %w", err) } return s.diarizer.Process(ctx, samples, stt.numClusters) } func saveUploadedRaw(r *http.Request) (path string, cleanup func(), err error) { return saveUploadedRawFields(r, []string{"audio", "wav", "file"}) } func saveUploadedRawFields(r *http.Request, fieldNames []string) (path string, cleanup func(), err error) { if r.MultipartForm == nil { if err := r.ParseMultipartForm(128 << 20); err != nil { return "", nil, err } } var ( file multipart.File header *multipart.FileHeader found bool ) for _, name := range fieldNames { file, header, err = r.FormFile(name) if err == nil { found = true break } } if !found { return "", nil, fmt.Errorf("audio file required (form field: %s)", strings.Join(fieldNames, ", ")) } defer file.Close() dir, err := config.MkdirTemp("go-whisper-api-upload-*") if err != nil { return "", nil, err } cleanup = func() { os.RemoveAll(dir) } base := "input" if header != nil { if ext := filepath.Ext(header.Filename); ext != "" { base += ext } } raw := filepath.Join(dir, base) out, err := os.Create(raw) if err != nil { cleanup() return "", nil, err } if _, err := io.Copy(out, file); err != nil { out.Close() cleanup() return "", nil, err } out.Close() return raw, cleanup, nil } func (s *Server) saveUploadedWav(r *http.Request) (path string, cleanup func(), err error) { raw, cleanup, err := saveUploadedRaw(r) if err != nil { return "", nil, err } dst := filepath.Join(filepath.Dir(raw), "audio.wav") if err := s.transcode.Transcode(r.Context(), raw, dst, transcode.WhisperOptions()); err != nil { cleanup() return "", nil, err } return dst, cleanup, nil } func queryBoolDefault(r *http.Request, key string, def bool) bool { v := r.URL.Query().Get(key) if v == "" { return def } b, err := strconv.ParseBool(v) if err != nil { return def } return b } func queryAsync(r *http.Request, defaultAsync bool) bool { v := r.URL.Query().Get("async") if v == "" { return defaultAsync } n, err := strconv.Atoi(v) if err != nil { return defaultAsync } return n == 1 } func queryInt(r *http.Request, key string, def int) int { v := r.URL.Query().Get(key) if v == "" { return def } n, err := strconv.Atoi(v) if err != nil { return def } return n } func writeJSON(w http.ResponseWriter, code int, v any) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(code) _ = json.NewEncoder(w).Encode(v) } func writeError(w http.ResponseWriter, code int, msg string) { http.Error(w, msg, code) } func writeAPIError(w http.ResponseWriter, code int, msg string) { writeJSON(w, code, map[string]any{"error": 1, "message": msg}) } func methodNotAllowed(w http.ResponseWriter) { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } func (s *Server) notifyQueue() { select { case s.queueWake <- struct{}{}: default: } } func Run(ctx context.Context, cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) error { srv, err := NewServer(cfg, tc, pc, dc) if err != nil { return err } defer srv.modelPool.Close() srv.StartWorker(ctx) hs := &http.Server{Addr: cfg.Addr, Handler: srv.mux} go func() { <-ctx.Done() shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() _ = hs.Shutdown(shutdownCtx) }() if err := hs.ListenAndServe(); err != nil && err != http.ErrServerClosed { return err } punctuation.Close(srv.punct) srv.diarizer.Close() return nil }