From b5c083e06f523460c47a02776567e475bc90180e Mon Sep 17 00:00:00 2001 From: admin Date: Thu, 4 Jun 2026 18:10:52 +0700 Subject: [PATCH] first commit --- .gitea/FUNDING.yml | 13 + .gitea/dependabot.yml | 10 + .gitea/workflows/README.md | 27 + .gitea/workflows/codeql.yml | 32 + .gitea/workflows/docker.yml | 60 ++ .gitea/workflows/goreleaser.yml | 43 ++ .gitea/workflows/lint.yml | 62 ++ .gitignore | 44 ++ .golangci.yml | 3 + .goreleaser.yaml | 3 + .hadolint.yaml | 3 + LICENSE | 21 + Makefile | 280 ++++++++ README.md | 129 ++++ api/cache.go | 362 ++++++++++ api/cache_test.go | 127 ++++ api/garbage.go | 26 + api/models.go | 181 +++++ api/models_test.go | 55 ++ api/openai.go | 121 ++++ api/openai_test.go | 90 +++ api/queue_worker.go | 98 +++ api/result.go | 28 + api/result_test.go | 26 + api/server.go | 640 ++++++++++++++++++ api/swagger-ui.html | 23 + api/swagger.go | 34 + api/swagger.json | 457 +++++++++++++ api/tasks.go | 10 + api/transcribe_opts.go | 77 +++ api/waveform.go | 48 ++ config.yaml.example | 55 ++ config/api.go | 42 ++ config/diarization.go | 72 ++ config/file.go | 53 ++ config/file_test.go | 35 + config/garbage.go | 13 + config/garbage_test.go | 23 + config/merge.go | 136 ++++ config/merge_test.go | 28 + config/punctuation.go | 125 ++++ config/punctuation_test.go | 58 ++ config/tmp.go | 9 + config/transcode.go | 11 + config/transcript.go | 26 + config/vad.go | 88 +++ config/vad_test.go | 65 ++ config/whisper.go | 43 ++ config/xlm-roberta-model.yaml | 81 +++ diarization/diarization.go | 19 + diarization/sherpa.go | 107 +++ diarization/stub.go | 29 + docker/Dockerfile | 55 ++ docker/Dockerfile.ci | 27 + garbage/filter.go | 56 ++ garbage/filter_test.go | 30 + go.mod | 48 ++ go.sum | 127 ++++ main.go | 251 +++++++ models/.gitkeep | 0 models/diarization/.gitkeep | 0 .../pyannote-segmentation-3-0/.gitkeep | 0 models/punctuation/.gitkeep | 0 models/punctuation/xlm-roberta/.gitkeep | 0 models/vad/.gitkeep | 0 punctuation/heuristic.go | 86 +++ punctuation/heuristic_test.go | 32 + punctuation/internal/spwrap/sp.go | 92 +++ punctuation/internal/spwrap/sp_wrap.cc | 88 +++ punctuation/internal/spwrap/sp_wrap.h | 21 + punctuation/ort_env.go | 177 +++++ punctuation/punctuation.go | 154 +++++ punctuation/punctuation_test.go | 40 ++ punctuation/sherpa.go | 107 +++ punctuation/sherpa_stub.go | 13 + punctuation/xlm.go | 319 +++++++++ punctuation/xlm_config.go | 43 ++ punctuation/xlm_decode.go | 97 +++ punctuation/xlm_stub.go | 13 + transcode/aac_decode.go | 132 ++++ transcode/aac_decode_test.go | 31 + transcode/decode.go | 127 ++++ transcode/engine.go | 119 ++++ transcode/engine_test.go | 85 +++ transcode/format.go | 47 ++ transcode/mp4_aac_decode.go | 288 ++++++++ transcode/ogg_decode.go | 154 +++++ transcode/ogg_decode_test.go | 52 ++ transcode/options.go | 43 ++ transcode/options_test.go | 20 + transcode/samples_stream.go | 38 ++ transcode/sniff.go | 44 ++ transcode/stream.go | 42 ++ transcode/wav_out.go | 53 ++ whisper/audio.go | 11 + whisper/audio_load.go | 60 ++ whisper/format.go | 154 +++++ whisper/format_test.go | 40 ++ whisper/helper.go | 15 + whisper/helper_test.go | 39 ++ whisper/model_pool.go | 89 +++ whisper/options.go | 8 + whisper/vad.go | 43 ++ whisper/whisper.go | 245 +++++++ whisper/whisper_test.go | 66 ++ 105 files changed, 8172 insertions(+) create mode 100644 .gitea/FUNDING.yml create mode 100644 .gitea/dependabot.yml create mode 100644 .gitea/workflows/README.md create mode 100644 .gitea/workflows/codeql.yml create mode 100644 .gitea/workflows/docker.yml create mode 100644 .gitea/workflows/goreleaser.yml create mode 100644 .gitea/workflows/lint.yml create mode 100644 .gitignore create mode 100644 .golangci.yml create mode 100644 .goreleaser.yaml create mode 100644 .hadolint.yaml create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 api/cache.go create mode 100644 api/cache_test.go create mode 100644 api/garbage.go create mode 100644 api/models.go create mode 100644 api/models_test.go create mode 100644 api/openai.go create mode 100644 api/openai_test.go create mode 100644 api/queue_worker.go create mode 100644 api/result.go create mode 100644 api/result_test.go create mode 100644 api/server.go create mode 100644 api/swagger-ui.html create mode 100644 api/swagger.go create mode 100644 api/swagger.json create mode 100644 api/tasks.go create mode 100644 api/transcribe_opts.go create mode 100644 api/waveform.go create mode 100644 config.yaml.example create mode 100644 config/api.go create mode 100644 config/diarization.go create mode 100644 config/file.go create mode 100644 config/file_test.go create mode 100644 config/garbage.go create mode 100644 config/garbage_test.go create mode 100644 config/merge.go create mode 100644 config/merge_test.go create mode 100644 config/punctuation.go create mode 100644 config/punctuation_test.go create mode 100644 config/tmp.go create mode 100644 config/transcode.go create mode 100644 config/transcript.go create mode 100644 config/vad.go create mode 100644 config/vad_test.go create mode 100644 config/whisper.go create mode 100644 config/xlm-roberta-model.yaml create mode 100644 diarization/diarization.go create mode 100644 diarization/sherpa.go create mode 100644 diarization/stub.go create mode 100644 docker/Dockerfile create mode 100644 docker/Dockerfile.ci create mode 100644 garbage/filter.go create mode 100644 garbage/filter_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go create mode 100644 models/.gitkeep create mode 100644 models/diarization/.gitkeep create mode 100644 models/diarization/pyannote-segmentation-3-0/.gitkeep create mode 100644 models/punctuation/.gitkeep create mode 100644 models/punctuation/xlm-roberta/.gitkeep create mode 100644 models/vad/.gitkeep create mode 100644 punctuation/heuristic.go create mode 100644 punctuation/heuristic_test.go create mode 100644 punctuation/internal/spwrap/sp.go create mode 100644 punctuation/internal/spwrap/sp_wrap.cc create mode 100644 punctuation/internal/spwrap/sp_wrap.h create mode 100644 punctuation/ort_env.go create mode 100644 punctuation/punctuation.go create mode 100644 punctuation/punctuation_test.go create mode 100644 punctuation/sherpa.go create mode 100644 punctuation/sherpa_stub.go create mode 100644 punctuation/xlm.go create mode 100644 punctuation/xlm_config.go create mode 100644 punctuation/xlm_decode.go create mode 100644 punctuation/xlm_stub.go create mode 100644 transcode/aac_decode.go create mode 100644 transcode/aac_decode_test.go create mode 100644 transcode/decode.go create mode 100644 transcode/engine.go create mode 100644 transcode/engine_test.go create mode 100644 transcode/format.go create mode 100644 transcode/mp4_aac_decode.go create mode 100644 transcode/ogg_decode.go create mode 100644 transcode/ogg_decode_test.go create mode 100644 transcode/options.go create mode 100644 transcode/options_test.go create mode 100644 transcode/samples_stream.go create mode 100644 transcode/sniff.go create mode 100644 transcode/stream.go create mode 100644 transcode/wav_out.go create mode 100644 whisper/audio.go create mode 100644 whisper/audio_load.go create mode 100644 whisper/format.go create mode 100644 whisper/format_test.go create mode 100644 whisper/helper.go create mode 100644 whisper/helper_test.go create mode 100644 whisper/model_pool.go create mode 100644 whisper/options.go create mode 100644 whisper/vad.go create mode 100644 whisper/whisper.go create mode 100644 whisper/whisper_test.go diff --git a/.gitea/FUNDING.yml b/.gitea/FUNDING.yml new file mode 100644 index 0000000..df9ae63 --- /dev/null +++ b/.gitea/FUNDING.yml @@ -0,0 +1,13 @@ +# These are supported funding model platforms + +github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry +custom: ['https://www.paypal.me/appleboy46'] diff --git a/.gitea/dependabot.yml b/.gitea/dependabot.yml new file mode 100644 index 0000000..632e8eb --- /dev/null +++ b/.gitea/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: github-actions + directory: / + schedule: + interval: weekly + - package-ecosystem: gomod + directory: / + schedule: + interval: weekly diff --git a/.gitea/workflows/README.md b/.gitea/workflows/README.md new file mode 100644 index 0000000..bb54ab0 --- /dev/null +++ b/.gitea/workflows/README.md @@ -0,0 +1,27 @@ +# Gitea Actions + +Workflows use [Gitea Actions](https://docs.gitea.com/usage/actions/overview) (compatible with GitHub Actions syntax). + +## Secrets + +| Secret | Used by | Purpose | +|--------|---------|---------| +| `GITEA_TOKEN` | `docker.yml` | Push images to the instance container registry | + +Create a token with **write:package** (and **read:repository** for checkout if needed). + +## Images + +- **CI / CPU**: `docker/Dockerfile.ci` — built on every push to `main` and tags `v*` +- **GPU**: `docker/Dockerfile` — build manually on NVIDIA hosts + +Registry URL: `${{ gitea.server_url }}/${{ gitea.repository }}` + +## Local parity + +```sh +sudo apt-get install -y build-essential cmake git +go test ./config/... ./punctuation/... ./transcode/... +make dependency && make test +docker build -f docker/Dockerfile.ci -t go-whisper-api:ci . +``` diff --git a/.gitea/workflows/codeql.yml b/.gitea/workflows/codeql.yml new file mode 100644 index 0000000..c2ffa10 --- /dev/null +++ b/.gitea/workflows/codeql.yml @@ -0,0 +1,32 @@ +# CodeQL requires GitHub.com or a Gitea instance with the codeql-action available. +# Disable this workflow on self-hosted Gitea if the action cannot be resolved. + +name: CodeQL + +on: + push: + branches: [main] + pull_request: + branches: [main] + schedule: + - cron: "41 23 * * 6" + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + contents: read + security-events: write + strategy: + fail-fast: false + matrix: + language: ["go"] + steps: + - uses: actions/checkout@v4 + - uses: github/codeql-action/init@v3 + with: + languages: ${{ matrix.language }} + continue-on-error: true + - uses: github/codeql-action/analyze@v3 + continue-on-error: true diff --git a/.gitea/workflows/docker.yml b/.gitea/workflows/docker.yml new file mode 100644 index 0000000..2a70dd3 --- /dev/null +++ b/.gitea/workflows/docker.yml @@ -0,0 +1,60 @@ +name: Docker Image + +on: + push: + branches: + - main + tags: + - "v*" + pull_request: + branches: + - main + +env: + # Gitea container registry: // + IMAGE: ${{ gitea.server_url }}/${{ gitea.repository }} + +jobs: + build-docker: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Gitea Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + registry: ${{ gitea.server_url }} + username: ${{ gitea.actor }} + password: ${{ secrets.GITEA_TOKEN }} + + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + images: | + ${{ env.IMAGE }} + tags: | + type=raw,value=latest,enable={{is_default_branch}} + type=ref,event=branch + type=ref,event=pr + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + + - name: Build and push (CPU) + uses: docker/build-push-action@v6 + with: + context: . + file: docker/Dockerfile.ci + platforms: linux/amd64 + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.gitea/workflows/goreleaser.yml b/.gitea/workflows/goreleaser.yml new file mode 100644 index 0000000..ff3c09e --- /dev/null +++ b/.gitea/workflows/goreleaser.yml @@ -0,0 +1,43 @@ +name: Release + +on: + push: + tags: + - "v*" + +permissions: + contents: write + packages: write + +jobs: + goreleaser: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + + - name: Install build dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends build-essential cmake git + + - name: Build whisper dependency + run: make dependency + + - name: Run GoReleaser + uses: goreleaser/goreleaser-action@v6 + with: + distribution: goreleaser + version: latest + args: release --clean + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }} diff --git a/.gitea/workflows/lint.yml b/.gitea/workflows/lint.yml new file mode 100644 index 0000000..6eda352 --- /dev/null +++ b/.gitea/workflows/lint.yml @@ -0,0 +1,62 @@ +name: Lint and Testing + +on: + push: + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Hadolint Dockerfile (CI) + uses: hadolint/hadolint-action@v3.1.0 + with: + dockerfile: docker/Dockerfile.ci + + - name: Hadolint Dockerfile (GPU) + uses: hadolint/hadolint-action@v3.1.0 + with: + dockerfile: docker/Dockerfile + continue-on-error: true + + test: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + + - name: Install build dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + build-essential cmake git libsentencepiece-dev pkg-config + + - name: Unit tests (no CGO whisper) + run: go test ./config/... ./punctuation/... ./transcode/... -count=1 + + - name: Build with xlm punctuation and run full tests + run: make dependency && make test TAGS=xlm + + golangci: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: latest + args: --timeout=10m + continue-on-error: true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..08d335d --- /dev/null +++ b/.gitignore @@ -0,0 +1,44 @@ +# Go build artifacts +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.test +*.out +go.work + +# Binaries and native libs (local build) +/bin/ +bin/ +/lib/ +lib/ + +# Runtime data (never commit) +config.yaml +.env +.env.* +cache/ + +# Packaged releases +dist/ +*.tar.gz + +# Downloaded models and ONNX weights +models/**/*.bin +models/**/*.onnx +models/**/*.model +models/*.bin +models/*.onnx +models/*.model + +# whisper.cpp submodule and its build tree +third_party/ + +# IDE and OS +.idea/ +.vscode/ +*.swp +*~ +.DS_Store +Thumbs.db diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..70e82c9 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,3 @@ +run: + skip-dirs: + - third_party/whisper.cpp diff --git a/.goreleaser.yaml b/.goreleaser.yaml new file mode 100644 index 0000000..cbed619 --- /dev/null +++ b/.goreleaser.yaml @@ -0,0 +1,3 @@ +builds: +- skip: true + diff --git a/.hadolint.yaml b/.hadolint.yaml new file mode 100644 index 0000000..502b578 --- /dev/null +++ b/.hadolint.yaml @@ -0,0 +1,3 @@ +ignored: + - DL3018 + - DL3008 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2f70a03 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Bo-Yi Wu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8aa83ef --- /dev/null +++ b/Makefile @@ -0,0 +1,280 @@ +EXECUTABLE := go-whisper-api +GO ?= go +GOFILES := $(shell find . -name "*.go" -type f) +HAS_GO = $(shell hash $(GO) > /dev/null 2>&1 && echo "GO" || echo "NOGO" ) + +WHISPER_CPP := $(abspath third_party/whisper.cpp) +WHISPER_BUILD := $(WHISPER_CPP)/build +WHISPER_VENDOR := $(WHISPER_CPP)/bindings/go +WHISPER_LIBDIR := $(WHISPER_BUILD)/src:$(WHISPER_BUILD)/ggml/src + +RUNTIME_LIB_DIR := $(abspath lib) +# $ORIGIN/lib — binary next to lib/ (e.g. ./go-whisper-api + ./lib/) +# $ORIGIN/../lib — binary in bin/ (e.g. bin/go-whisper-api + lib/) +RUNTIME_RPATH := -Wl,-rpath,$$ORIGIN/lib:$$ORIGIN/../lib + +ifneq ($(shell uname), Darwin) + EXTLDFLAGS = -extldflags "$(RUNTIME_RPATH)" +else + EXTLDFLAGS = +endif + +ifeq ($(HAS_GO), GO) + GOPATH ?= $(shell $(GO) env GOPATH) + export PATH := $(GOPATH)/bin:$(PATH) + + CGO_EXTRA_CFLAGS := -DSQLITE_MAX_VARIABLE_NUMBER=32766 + CGO_CFLAGS ?= $(shell $(GO) env CGO_CFLAGS) $(CGO_EXTRA_CFLAGS) +endif + +ifeq ($(OS), Windows_NT) + GOFLAGS := -v -buildmode=exe + EXECUTABLE ?= $(EXECUTABLE).exe +else ifeq ($(OS), Windows) + GOFLAGS := -v -buildmode=exe + EXECUTABLE ?= $(EXECUTABLE).exe +else + GOFLAGS := -v + EXECUTABLE ?= $(EXECUTABLE) +endif + +ifneq ($(DRONE_TAG),) + VERSION ?= $(DRONE_TAG) +else + VERSION ?= $(shell git describe --tags --always || git rev-parse --short HEAD) +endif + +TAGS ?= +UNAME_M := $(shell uname -m) +ifeq ($(UNAME_M),x86_64) +SHERPA_LIBARCH := x86_64-unknown-linux-gnu +endif +ifeq ($(UNAME_M),aarch64) +SHERPA_LIBARCH := aarch64-unknown-linux-gnu +endif +SHERPA_LINUX_VER := $(shell awk '/sherpa-onnx-go-linux/ {print $$2; exit}' go.mod) +SHERPA_LIBDIR := $(GOPATH)/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-linux@$(SHERPA_LINUX_VER)/lib/$(SHERPA_LIBARCH) +ifneq ($(shell uname), Darwin) +EXTLDFLAGS_SHERPA = -extldflags "$(RUNTIME_RPATH)" +EXTLDFLAGS_XLM = -extldflags "$(RUNTIME_RPATH)" +else +EXTLDFLAGS_SHERPA = +EXTLDFLAGS_XLM = +endif +GOLDFLAGS ?= -X 'main.Version=$(VERSION)' + +INCLUDE_PATH := $(WHISPER_CPP)/include:$(WHISPER_CPP)/ggml/include:$(WHISPER_VENDOR):$(INCLUDE_PATH) +LIBRARY_PATH := $(WHISPER_LIBDIR):$(LIBRARY_PATH) +export LD_LIBRARY_PATH := $(WHISPER_LIBDIR):$(LD_LIBRARY_PATH) + +ifdef WHISPER_CUBLAS + CGO_CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include + CGO_CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include + EXTLDFLAGS = -extldflags "-lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib" + +build: $(EXECUTABLE) + +$(EXECUTABLE): $(GOFILES) + CGO_CXXFLAGS=${CGO_CXXFLAGS} CGO_CFLAGS=${CGO_CFLAGS} C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)' -o bin/$@ +endif + +MODEL_URL ?= https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin +MODEL_PATH ?= models/ggml-tiny.en.bin +VAD_MODEL ?= silero-v6.2.0 +VAD_MODEL_PATH ?= models/ggml-silero-v6.2.0.bin + +all: build + +PUNCT_MODEL_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8.tar.bz2 +PUNCT_MODEL_DIR ?= models/punctuation/ct-transformer-zh-en-int8 + +XLM_PUNCT_DIR ?= models/punctuation/xlm-roberta +XLM_HF_REPO ?= Salama1429/xlm-roberta_punctuation_fullstop_truecase + +XLM_MODEL_CONFIG_SRC ?= config/xlm-roberta-model.yaml +ORT_LIB_SRC ?= $(shell $(GO) env GOMODCACHE 2>/dev/null)/github.com/k2-fsa/sherpa-onnx-go-linux@$(SHERPA_LINUX_VER)/lib/$(SHERPA_LIBARCH)/libonnxruntime.so + +# Copy runtime .so into ./lib/. Binary rpath: $ORIGIN/lib or $ORIGIN/../lib (see RUNTIME_RPATH). +# Use cp -n where possible: existing root-owned libs in lib/ must not break the build. +install-runtime-libs: dependency + @mkdir -p "$(RUNTIME_LIB_DIR)" + @cp -an "$(WHISPER_BUILD)/src"/libwhisper.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true + @cp -an "$(WHISPER_BUILD)/ggml/src"/libggml*.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true + @echo "Whisper/ggml libs ready in $(RUNTIME_LIB_DIR)/" + +install-ort-lib: + @mkdir -p "$(RUNTIME_LIB_DIR)" + @if [ ! -f "$(ORT_LIB_SRC)" ]; then echo "missing $(ORT_LIB_SRC); run: go mod download"; exit 1; fi + @dest="$(RUNTIME_LIB_DIR)/libonnxruntime.so"; \ + if [ -f "$$dest" ] && cmp -s "$(ORT_LIB_SRC)" "$$dest"; then \ + echo "libonnxruntime.so already up to date in $(RUNTIME_LIB_DIR)/"; \ + elif cp -f "$(ORT_LIB_SRC)" "$$dest" 2>/dev/null; then \ + echo "Installed $$dest"; \ + elif [ -f "$$dest" ] && cmp -s "$(ORT_LIB_SRC)" "$$dest"; then \ + echo "libonnxruntime.so present in $(RUNTIME_LIB_DIR)/ (unchanged, not writable)"; \ + else \ + echo "cannot install libonnxruntime.so to $$dest"; \ + echo "fix: sudo chown -R $$USER:$$(id -gn) $(RUNTIME_LIB_DIR)"; \ + exit 1; \ + fi + +# XLM punctuation links -lsentencepiece; bundle .so for hosts without libsentencepiece0 package. +SP_LIB_DIRS := /usr/lib/x86_64-linux-gnu /usr/lib/aarch64-linux-gnu /usr/lib64 /usr/lib +install-sp-lib: + @mkdir -p "$(RUNTIME_LIB_DIR)" + @found=0; \ + for d in $(SP_LIB_DIRS); do \ + if [ -e "$$d/libsentencepiece.so.0" ] || [ -L "$$d/libsentencepiece.so.0" ]; then \ + cp -an "$$d"/libsentencepiece.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true; \ + found=1; \ + break; \ + fi; \ + done; \ + if [ "$$found" = "0" ]; then \ + echo "libsentencepiece.so.0 not found; install: sudo apt-get install libsentencepiece0"; \ + exit 1; \ + fi + @test -e "$(RUNTIME_LIB_DIR)/libsentencepiece.so.0" || (echo "missing $(RUNTIME_LIB_DIR)/libsentencepiece.so.0 after install-sp-lib"; exit 1) + @echo "Sentencepiece libs ready in $(RUNTIME_LIB_DIR)/" + +# If lib/*.so were created as root (e.g. manual cp with sudo), reclaim ownership for builds. +fix-lib-perms: + @if [ -d "$(RUNTIME_LIB_DIR)" ]; then \ + chown -R "$$USER:$$(id -gn)" "$(RUNTIME_LIB_DIR)" 2>/dev/null || \ + sudo chown -R "$$USER:$$(id -gn)" "$(RUNTIME_LIB_DIR)"; \ + echo "Ownership of $(RUNTIME_LIB_DIR)/ updated"; \ + fi + +install-runtime-libs-xlm: install-runtime-libs install-ort-lib install-sp-lib + +# Fail fast before deploy if ./lib is incomplete (lib/ is not in git: *.so is gitignored). +verify-runtime-libs-xlm: + @test -f bin/$(EXECUTABLE) || (echo "missing bin/$(EXECUTABLE); run: make build-xlm"; exit 1) + @test -f "$(RUNTIME_LIB_DIR)/libonnxruntime.so" || (echo "missing $(RUNTIME_LIB_DIR)/libonnxruntime.so; run: make install-runtime-libs-xlm"; exit 1) + @test -e "$(RUNTIME_LIB_DIR)/libsentencepiece.so.0" || (echo "missing $(RUNTIME_LIB_DIR)/libsentencepiece.so.0; run: make install-sp-lib"; exit 1) + @test -e "$(RUNTIME_LIB_DIR)/libwhisper.so.1" || (echo "missing $(RUNTIME_LIB_DIR)/libwhisper.so.1; run: make install-runtime-libs"; exit 1) + @echo "Runtime libs OK in $(RUNTIME_LIB_DIR)/" + +RUNTIME_TARBALL := dist/go-whisper-api-runtime-$(shell uname -m).tar.gz +package-runtime-xlm: verify-runtime-libs-xlm + @mkdir -p dist + tar -czf "$(RUNTIME_TARBALL)" bin/$(EXECUTABLE) lib + @echo "Created $(RUNTIME_TARBALL) — on prod: tar -xzf ... -C /opt/go-whisper-api (keeps bin/ and lib/)" + +# Copy bundled label config (needs write access to $(XLM_PUNCT_DIR); fix with: sudo chown -R $$USER models/punctuation) +install-xlm-punctuation-config: + @mkdir -p $(XLM_PUNCT_DIR) + @cp "$(XLM_MODEL_CONFIG_SRC)" "$(XLM_PUNCT_DIR)/config.yaml" + @echo "Installed $(XLM_PUNCT_DIR)/config.yaml" + +download-xlm-punctuation-model: install-xlm-punctuation-config + @mkdir -p $(XLM_PUNCT_DIR) + @for f in model.onnx sp.model; do \ + if [ ! -f "$(XLM_PUNCT_DIR)/$$f" ]; then \ + echo "Downloading $$f from $(XLM_HF_REPO)..."; \ + curl -fL "https://huggingface.co/$(XLM_HF_REPO)/resolve/main/$$f" -o "$(XLM_PUNCT_DIR)/$$f"; \ + else \ + echo "Already have $(XLM_PUNCT_DIR)/$$f"; \ + fi; \ + done + +DIAR_SEG_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +DIAR_EMB_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx +DIAR_DIR ?= models/diarization + +download-diarization-models: + @mkdir -p $(DIAR_DIR) + @if [ ! -f "$(DIAR_DIR)/pyannote-segmentation-3-0/model.onnx" ]; then \ + echo "Downloading speaker segmentation model..."; \ + curl -fL "$(DIAR_SEG_URL)" -o /tmp/diar-seg.tar.bz2; \ + tar -xjf /tmp/diar-seg.tar.bz2 -C $(DIAR_DIR); \ + rm -f /tmp/diar-seg.tar.bz2; \ + else \ + echo "Segmentation model present"; \ + fi + @if [ ! -f "$(DIAR_DIR)/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" ]; then \ + echo "Downloading speaker embedding model..."; \ + curl -fL "$(DIAR_EMB_URL)" -o "$(DIAR_DIR)/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"; \ + else \ + echo "Embedding model present"; \ + fi + +download-punctuation-model: + @mkdir -p models/punctuation + @if [ ! -f "$(PUNCT_MODEL_DIR)/model.int8.onnx" ]; then \ + echo "Downloading punctuation model..."; \ + curl -fL "$(PUNCT_MODEL_URL)" -o /tmp/punct-model.tar.bz2; \ + tar -xjf /tmp/punct-model.tar.bz2 -C models/punctuation; \ + rm -f /tmp/punct-model.tar.bz2; \ + if [ -d models/punctuation/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8 ]; then \ + mv models/punctuation/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8 "$(PUNCT_MODEL_DIR)"; \ + fi; \ + else \ + echo "Punctuation model already exists: $(PUNCT_MODEL_DIR)/model.int8.onnx"; \ + fi + +download-model: + @mkdir -p models + @if [ ! -f "$(MODEL_PATH)" ]; then \ + echo "Downloading $(MODEL_PATH)..."; \ + curl -fL "$(MODEL_URL)" -o "$(MODEL_PATH)"; \ + else \ + echo "Model already exists: $(MODEL_PATH)"; \ + fi + +download-vad-model: + @mkdir -p models + @if [ ! -f "$(VAD_MODEL_PATH)" ]; then \ + echo "Downloading VAD model $(VAD_MODEL) to models/..."; \ + ./third_party/whisper.cpp/models/download-vad-model.sh $(VAD_MODEL) models; \ + else \ + echo "VAD model already exists: $(VAD_MODEL_PATH)"; \ + fi + +clone: + @[ -d third_party/whisper.cpp ] || git clone https://github.com/appleboy/whisper.cpp.git third_party/whisper.cpp + +dependency: clone + @echo Build whisper + @if [ ! -f "$(WHISPER_BUILD)/src/libwhisper.so" ] && [ ! -f "$(WHISPER_BUILD)/src/libwhisper.a" ]; then \ + cmake -S "$(WHISPER_CPP)" -B "$(WHISPER_BUILD)" -DCMAKE_BUILD_TYPE=Release && \ + cmake --build "$(WHISPER_BUILD)" --config Release -j$$(nproc 2>/dev/null || echo 4); \ + else \ + echo "whisper library already built in $(WHISPER_BUILD)"; \ + fi + +test: + @C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) test -v -cover -coverprofile coverage.txt ./... && echo "\n==>\033[32m Ok\033[m\n" || exit 1 + +install: $(GOFILES) + C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) install -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)' + +build: install-runtime-libs $(EXECUTABLE) + +# Build with sherpa-onnx (punctuation + speaker diarization) +build-sherpa: + @$(MAKE) build TAGS=sherpa + +# XLM-RoBERTa punctuation (47 languages); requires libsentencepiece-dev +build-xlm: + @$(MAKE) install-runtime-libs-xlm + @$(MAKE) build TAGS=xlm + +$(EXECUTABLE): $(GOFILES) +ifneq (,$(findstring xlm,$(TAGS))) + C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=$(SHERPA_LIBDIR):${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS_XLM) -s -w $(GOLDFLAGS)' -o bin/$@ +else ifneq (,$(findstring sherpa,$(TAGS))) + C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=$(SHERPA_LIBDIR):${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS_SHERPA) -s -w $(GOLDFLAGS)' -o bin/$@ +else + C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)' -o bin/$@ +endif + +clean: + $(GO) clean -x -i ./... + rm -rf coverage.txt $(EXECUTABLE) $(DIST) bin/$(EXECUTABLE) + +clean-whisper: + rm -rf "$(WHISPER_BUILD)" + +version: + @echo $(VERSION) diff --git a/README.md b/README.md new file mode 100644 index 0000000..5df4305 --- /dev/null +++ b/README.md @@ -0,0 +1,129 @@ +# go-whisper-api + +HTTP-сервер распознавания речи на базе [whisper.cpp](https://github.com/ggerganov/whisper.cpp): совместимость с **SPR** (`/spr/*`) и **OpenAI Whisper API** (`/v1/*`) для Open WebUI и других клиентов. + +## Модели Whisper (ggml) + +См. [каталог моделей на Hugging Face](https://huggingface.co/ggerganov/whisper.cpp/tree/main). + +```sh +make download-model +# или вручную: +curl -LJ https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin \ + --output models/ggml-small.bin +``` + +Веса STT: `*.bin` в корне `api.models_dir` (список — `GET /spr/models`). Подкаталоги `models/vad/`, `models/punctuation/` и т.п. в список моделей не входят. + +## Конфигурация + +Шаблон в репозитории — `config.yaml.example`. Рабочий файл создайте локально (он в `.gitignore`): + +```sh +cp config.yaml.example config.yaml +``` + +Если `--config` не указан и в текущем каталоге есть `config.yaml`, он подгружается автоматически. + +Промежуточные файлы (загрузки, конвертация) — только в `/tmp`, удаляются после обработки. Форматы **wav, mp3, flac, ogg, m4a, mp4, aac** декодируются на чистом Go (без ffmpeg). + +## Запуск + +```sh +go-whisper-api serve --config config.yaml +# или с флагами (перекрывают YAML): +go-whisper-api --models-dir ./models --addr :8080 +``` + +Docker: + +```sh +docker run -p 8080:8080 \ + -v $PWD/models:/app/models \ + -v $PWD/config.yaml:/app/config.yaml \ + ghcr.io/appleboy/go-whisper-api:latest \ + serve --config /app/config.yaml +``` + +Swagger UI: [http://localhost:8080/](http://localhost:8080/) + +## VAD (детекция речи) + +```sh +make download-vad-model +``` + +```yaml +api: + models_dir: ./models + vad: + enabled: true + model: ggml-silero-v6.2.0.bin +``` + +Полезно на длинных записях с тишиной или музыкой в начале/конце. + +## Пунктуация + +| Уровень | Параметр | Эффект | +|---------|----------|--------| +| Мастер | `punctuation.enabled: false` | Пунктуация отключена | +| API по умолчанию | `api.default_punctuation: true` | Если нет `?punctuation=` | +| На запрос | `?punctuation=1` / `?punctuation=0` | Только при `enabled: true` | + +Движки: `heuristic`, `xlm`, `sherpa`, `sherpa-online`, `http`, `off`. + +```sh +curl -X POST "http://localhost:8080/spr/stt/ggml-small?punctuation=1" -F "wav=@audio.wav" +``` + +Сборка XLM: `make download-xlm-punctuation-model && make build-xlm`. Продакшен: `make package-runtime-xlm` — в git не коммитятся `lib/*.so` (см. `.gitignore`). + +Фильтр артефактов: `api.garbage` (по умолчанию `*выбая*`), отключить: `garbage: []`. + +## HTTP API + +- **SPR** (`/spr/*`) — очередь, waveform, импорт/экспорт моделей +- **OpenAI** (`/v1/*`) — синхронный STT + +Пример SPR: + +```sh +curl -X POST "http://localhost:8080/spr/stt/ggml-small?async=0" -F "wav=@audio.wav" +``` + +По умолчанию STT **асинхронный**: `POST /spr/stt/{id}` → `taskID`, опрос `GET /spr/queue/{taskID}`, результат `GET /spr/result/{taskID}`. Кэш: `./cache/waiting/` и `./cache/ready/`. + +Язык по умолчанию — **`ru`** (`?language=en`, `?language=auto`). + +| Endpoint | Метод | Описание | +|----------|--------|-------------| +| `/spr/models` | GET | Список моделей | +| `/spr/stt/{id}` | POST | Транскрипция | +| `/spr/result/{taskID}` | GET | Результат | +| `/spr/queue` | GET | Все задачи | +| `/spr/queue/{taskID}` | GET / DELETE | Статус / удаление | +| `/v1/audio/transcriptions` | POST | OpenAI-совместимая транскрипция | +| `/v1/models` | GET | Список моделей | + +## Open WebUI + +| Параметр | Значение | +|----------|----------| +| Engine | `OpenAI` | +| API Base URL | `http://:6183/v1` | +| API Key | любая непустая строка | +| STT Model | `whisper-1` (см. `api.default_model` в config) | + +```yaml +api: + default_model: ggml-large-v3-turbo +``` + +## CI/CD + +Workflows в `.gitea/workflows/` (`lint.yml`, `docker.yml`). Секрет `GITEA_TOKEN` с `write:package` для push образов. + +```sh +docker build -f docker/Dockerfile.ci -t go-whisper-api . +``` diff --git a/api/cache.go b/api/cache.go new file mode 100644 index 0000000..1c435e0 --- /dev/null +++ b/api/cache.go @@ -0,0 +1,362 @@ +package api + +import ( + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "go-whisper-api/whisper" +) + +const ( + cacheWaiting = "waiting" + cacheReady = "ready" + fileParams = "params.conf" + fileAudio = "audio.wav" + fileAudioJSON = "audio.json" +) + +type TaskParams struct { + ID string `json:"id"` + Created string `json:"created"` + Processed string `json:"processed,omitempty"` + Status string `json:"status"` + Model string `json:"model"` + Language string `json:"language,omitempty"` + Punctuation bool `json:"punctuation,omitempty"` + Speakers bool `json:"speakers,omitempty"` + NumClusters int `json:"num_clusters,omitempty"` + Text string `json:"text,omitempty"` + Words []whisper.Word `json:"words,omitempty"` + Error string `json:"error,omitempty"` +} + +type AudioJSON struct { + Waveform []float64 `json:"waveform"` + Buckets int `json:"buckets"` +} + +type DiskCache struct { + root string +} + +func (c *DiskCache) Root() string { + return c.root +} + +func resolveCacheRoot(root string) (string, error) { + if root == "" { + root = "./cache" + } + abs, err := filepath.Abs(root) + if err != nil { + return "", fmt.Errorf("cache dir: %w", err) + } + return filepath.Clean(abs), nil +} + +func NewDiskCache(root string) (*DiskCache, error) { + abs, err := resolveCacheRoot(root) + if err != nil { + return nil, err + } + c := &DiskCache{root: abs} + for _, sub := range []string{cacheWaiting, cacheReady} { + if err := os.MkdirAll(filepath.Join(abs, sub), 0o755); err != nil { + return nil, err + } + } + return c, nil +} + +func (c *DiskCache) waitingDir(id string) string { + return filepath.Join(c.root, cacheWaiting, id) +} + +func (c *DiskCache) readyDir(id string) string { + return filepath.Join(c.root, cacheReady, id) +} + +func (c *DiskCache) locate(id string) (dir, phase string, ok bool) { + if id == "" { + return "", "", false + } + ready := c.readyDir(id) + if st, err := os.Stat(ready); err == nil && st.IsDir() { + return ready, cacheReady, true + } + waiting := c.waitingDir(id) + if st, err := os.Stat(waiting); err == nil && st.IsDir() { + return waiting, cacheWaiting, true + } + return "", "", false +} + +func (c *DiskCache) Enqueue(id string, params TaskParams, audioWavPath string) error { + dir := c.waitingDir(id) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + params.ID = id + if err := c.writeParams(dir, params); err != nil { + _ = os.RemoveAll(dir) + return err + } + dst := filepath.Join(dir, fileAudio) + if err := os.Rename(audioWavPath, dst); err != nil { + if err2 := copyFile(audioWavPath, dst); err2 != nil { + _ = os.RemoveAll(dir) + return fmt.Errorf("move audio to %s: %w", dst, err2) + } + _ = os.Remove(audioWavPath) + } + return nil +} + +func (c *DiskCache) writeParams(dir string, params TaskParams) error { + data, err := json.Marshal(params) + if err != nil { + return err + } + return os.WriteFile(filepath.Join(dir, fileParams), data, 0o644) +} + +func (c *DiskCache) LoadParams(id string) (TaskParams, string, error) { + dir, phase, ok := c.locate(id) + if !ok { + return TaskParams{}, "", fmt.Errorf("task not found") + } + params, err := c.readParams(dir) + if err != nil { + return TaskParams{}, "", err + } + return params, phase, nil +} + +func (c *DiskCache) readParams(dir string) (TaskParams, error) { + path := filepath.Join(dir, fileParams) + data, err := os.ReadFile(path) + if err != nil { + return TaskParams{}, fmt.Errorf("read %s: %w", path, err) + } + var params TaskParams + if err := json.Unmarshal(data, ¶ms); err != nil { + return TaskParams{}, err + } + return params, nil +} + +func (c *DiskCache) List() (map[string]map[string]string, error) { + out := make(map[string]map[string]string) + for _, phase := range []string{cacheWaiting, cacheReady} { + base := filepath.Join(c.root, phase) + entries, err := os.ReadDir(base) + if err != nil { + if os.IsNotExist(err) { + continue + } + return nil, err + } + for _, e := range entries { + if !e.IsDir() { + continue + } + id := e.Name() + params, err := c.readParams(filepath.Join(base, id)) + if err != nil { + continue + } + out[id] = map[string]string{ + "created": params.Created, + "status": params.Status, + } + } + } + return out, nil +} + +func (c *DiskCache) NextWaiting() (string, bool, error) { + base := filepath.Join(c.root, cacheWaiting) + entries, err := os.ReadDir(base) + if err != nil { + if os.IsNotExist(err) { + return "", false, nil + } + return "", false, err + } + type item struct { + id string + created time.Time + } + var pending []item + for _, e := range entries { + if !e.IsDir() { + continue + } + params, err := c.readParams(filepath.Join(base, e.Name())) + if err != nil || params.Status != string(statusWaiting) { + continue + } + t, _ := time.ParseInLocation("2006-01-02 15:04:05", params.Created, time.Local) + pending = append(pending, item{id: e.Name(), created: t}) + } + if len(pending) == 0 { + return "", false, nil + } + sort.Slice(pending, func(i, j int) bool { + return pending[i].created.Before(pending[j].created) + }) + return pending[0].id, true, nil +} + +func (c *DiskCache) SetStatus(id string, status taskStatus, mutate func(*TaskParams)) error { + dir := c.waitingDir(id) + if _, err := os.Stat(dir); err != nil { + return fmt.Errorf("task %s not in waiting", id) + } + params, err := c.readParams(dir) + if err != nil { + return err + } + params.Status = string(status) + if mutate != nil { + mutate(¶ms) + } + return c.writeParams(dir, params) +} + +func (c *DiskCache) FinishWaiting(id string, result whisper.TranscriptResult, errMsg string, waveform []float64) error { + dir := c.waitingDir(id) + params, err := c.readParams(dir) + if err != nil { + return err + } + params.Processed = time.Now().Format("2006-01-02 15:04:05") + if errMsg != "" { + params.Status = string(statusError) + params.Error = errMsg + } else { + params.Status = string(statusReady) + params.Text = result.Text + params.Words = result.Words + } + if err := c.writeParams(dir, params); err != nil { + return fmt.Errorf("update %s: %w", filepath.Join(dir, fileParams), err) + } + if len(waveform) > 0 { + aj := AudioJSON{Waveform: waveform, Buckets: len(waveform)} + data, err := json.Marshal(aj) + if err != nil { + return err + } + if err := os.WriteFile(filepath.Join(dir, fileAudioJSON), data, 0o644); err != nil { + return err + } + } + return nil +} + +func (c *DiskCache) PromoteToReady(id string) error { + if _, phase, ok := c.locate(id); ok && phase == cacheReady { + return nil + } + src := c.waitingDir(id) + dst := c.readyDir(id) + if _, err := os.Stat(src); err != nil { + return fmt.Errorf("task %s not in waiting", id) + } + if _, err := os.Stat(dst); err == nil { + return os.RemoveAll(src) + } + return os.Rename(src, dst) +} + +func (c *DiskCache) Delete(id string) bool { + dir, _, ok := c.locate(id) + if !ok { + return false + } + _ = os.RemoveAll(dir) + return true +} + +func (c *DiskCache) AudioPath(id string) (string, bool) { + dir, _, ok := c.locate(id) + if !ok { + return "", false + } + p := filepath.Join(dir, fileAudio) + if _, err := os.Stat(p); err != nil { + return "", false + } + return p, true +} + +func (c *DiskCache) Waveform(id string) ([]float64, error) { + dir, _, ok := c.locate(id) + if !ok { + return nil, fmt.Errorf("task not found") + } + data, err := os.ReadFile(filepath.Join(dir, fileAudioJSON)) + if err == nil { + var aj AudioJSON + if json.Unmarshal(data, &aj) == nil && len(aj.Waveform) > 0 { + return aj.Waveform, nil + } + } + return waveformFromWav(filepath.Join(dir, fileAudio), 512) +} + +func (c *DiskCache) RecoverInterrupted() error { + base := filepath.Join(c.root, cacheWaiting) + entries, err := os.ReadDir(base) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + for _, e := range entries { + if !e.IsDir() { + continue + } + id := e.Name() + dir := filepath.Join(base, id) + params, err := c.readParams(dir) + if err != nil { + continue + } + switch params.Status { + case string(statusProcessing): + params.Status = string(statusWaiting) + _ = c.writeParams(dir, params) + case string(statusReady), string(statusError): + _ = c.PromoteToReady(id) + } + } + return nil +} + +func copyFile(src, dst string) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + _, err = io.Copy(out, in) + return err +} + +func isValidTaskID(id string) bool { + return id != "" && !strings.Contains(id, "..") && !strings.ContainsAny(id, `/\`) +} diff --git a/api/cache_test.go b/api/cache_test.go new file mode 100644 index 0000000..2860c81 --- /dev/null +++ b/api/cache_test.go @@ -0,0 +1,127 @@ +package api + +import ( + "os" + "path/filepath" + "testing" +) + +func TestDiskCache_params_promote_list(t *testing.T) { + root := t.TempDir() + c, err := NewDiskCache(root) + if err != nil { + t.Fatal(err) + } + + id := "319d72c7-301d-44fd-935f-3526dfb70f9f" + dir := c.waitingDir(id) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + params := TaskParams{ + ID: id, + Created: "2026-03-31 21:37:46", + Status: string(statusReady), + Model: "ggml-small", + Text: "test", + } + if err := c.writeParams(dir, params); err != nil { + t.Fatal(err) + } + + list, err := c.List() + if err != nil { + t.Fatal(err) + } + if list[id]["status"] != string(statusReady) { + t.Fatalf("list: %v", list[id]) + } + + if err := c.PromoteToReady(id); err != nil { + t.Fatal(err) + } + _, phase, err := c.LoadParams(id) + if err != nil || phase != cacheReady { + t.Fatalf("phase=%s err=%v", phase, err) + } + if _, err := os.Stat(filepath.Join(c.readyDir(id), fileParams)); err != nil { + t.Fatal(err) + } +} + +func TestIsValidTaskID(t *testing.T) { + if !isValidTaskID("319d72c7-301d-44fd-935f-3526dfb70f9f") { + t.Fatal("uuid should be valid") + } + if isValidTaskID("../etc") { + t.Fatal("path traversal must be rejected") + } +} + +func TestResolveCacheRoot_absolute(t *testing.T) { + cwd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + abs, err := resolveCacheRoot("./cache") + if err != nil { + t.Fatal(err) + } + want := filepath.Join(cwd, "cache") + if abs != want { + t.Fatalf("got %q want %q", abs, want) + } +} + +func TestDiskCache_RecoverInterrupted(t *testing.T) { + root := t.TempDir() + c, err := NewDiskCache(root) + if err != nil { + t.Fatal(err) + } + id := "task-1" + dir := c.waitingDir(id) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + if err := c.writeParams(dir, TaskParams{ + ID: id, Created: "2026-01-01 00:00:00", Status: string(statusProcessing), Model: "m", + }); err != nil { + t.Fatal(err) + } + if err := c.RecoverInterrupted(); err != nil { + t.Fatal(err) + } + p, _, err := c.LoadParams(id) + if err != nil { + t.Fatal(err) + } + if p.Status != string(statusWaiting) { + t.Fatalf("got %q", p.Status) + } +} + +func TestDiskCache_RecoverInterrupted_promotesReady(t *testing.T) { + root := t.TempDir() + c, err := NewDiskCache(root) + if err != nil { + t.Fatal(err) + } + id := "done-task" + dir := c.waitingDir(id) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + if err := c.writeParams(dir, TaskParams{ + ID: id, Created: "2026-01-01 00:00:00", Status: string(statusReady), Model: "m", Text: "hi", + }); err != nil { + t.Fatal(err) + } + if err := c.RecoverInterrupted(); err != nil { + t.Fatal(err) + } + _, phase, err := c.LoadParams(id) + if err != nil || phase != cacheReady { + t.Fatalf("phase=%s err=%v", phase, err) + } +} diff --git a/api/garbage.go b/api/garbage.go new file mode 100644 index 0000000..7ec8330 --- /dev/null +++ b/api/garbage.go @@ -0,0 +1,26 @@ +package api + +import ( + "go-whisper-api/garbage" + "go-whisper-api/whisper" +) + +func applyGarbage(r whisper.TranscriptResult, patterns []string) whisper.TranscriptResult { + if len(patterns) == 0 { + return r + } + r.Text = garbage.FilterText(r.Text, patterns) + if len(r.Words) == 0 { + return r + } + gw := make([]garbage.Word, len(r.Words)) + for i, w := range r.Words { + gw[i] = garbage.Word{Word: w.Word, Start: w.Start, Stop: w.Stop} + } + gw = garbage.FilterWords(gw, patterns) + r.Words = make([]whisper.Word, len(gw)) + for i, w := range gw { + r.Words[i] = whisper.Word{Word: w.Word, Start: w.Start, Stop: w.Stop} + } + return r +} diff --git a/api/models.go b/api/models.go new file mode 100644 index 0000000..7345254 --- /dev/null +++ b/api/models.go @@ -0,0 +1,181 @@ +package api + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + "sync" +) + +// Subdirectories under models_dir that hold non-Whisper assets (VAD, punctuation, etc.). +var reservedModelSubdirs = map[string]struct{}{ + "vad": {}, + "punctuation": {}, +} + +// Top-level .bin files that are not Whisper STT models. +var excludedWhisperModelFiles = map[string]struct{}{ + "vad.bin": {}, +} + +type Registry struct { + dir string + mu sync.RWMutex +} + +func NewRegistry(dir string) *Registry { + return &Registry{dir: dir} +} + +func isReservedModelSubdir(name string) bool { + _, ok := reservedModelSubdirs[strings.ToLower(name)] + return ok +} + +func isWhisperModelFile(name string) bool { + if !strings.HasSuffix(strings.ToLower(name), ".bin") { + return false + } + _, excluded := excludedWhisperModelFiles[strings.ToLower(name)] + return !excluded +} + +func (r *Registry) List() ([]string, error) { + r.mu.RLock() + defer r.mu.RUnlock() + entries, err := os.ReadDir(r.dir) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + var models []string + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if !isWhisperModelFile(name) { + continue + } + models = append(models, strings.TrimSuffix(name, filepath.Ext(name))) + } + return models, nil +} + +func (r *Registry) Path(id string) (string, error) { + if id == "" { + return "", fmt.Errorf("model id is required") + } + if strings.Contains(id, "/") || strings.Contains(id, "..") { + return "", fmt.Errorf("invalid model id") + } + if isReservedModelSubdir(id) { + return "", fmt.Errorf("model %q not found", id) + } + r.mu.RLock() + defer r.mu.RUnlock() + candidates := []string{ + filepath.Join(r.dir, id+".bin"), + filepath.Join(r.dir, id), + filepath.Join(r.dir, "ggml-"+id+".bin"), + } + for _, p := range candidates { + if st, err := os.Stat(p); err == nil && !st.IsDir() && isWhisperModelFile(filepath.Base(p)) { + return p, nil + } + } + return "", fmt.Errorf("model %q not found", id) +} + +func (r *Registry) Delete(id string) error { + p, err := r.Path(id) + if err != nil { + return err + } + r.mu.Lock() + defer r.mu.Unlock() + return os.Remove(p) +} + +func (r *Registry) Import(id string, src io.Reader) error { + if id == "" { + return fmt.Errorf("model id is required") + } + if strings.Contains(id, "/") || strings.Contains(id, "..") { + return fmt.Errorf("invalid model id") + } + if isReservedModelSubdir(id) { + return fmt.Errorf("invalid model id") + } + r.mu.Lock() + defer r.mu.Unlock() + if err := os.MkdirAll(r.dir, 0o755); err != nil { + return err + } + dst := filepath.Join(r.dir, id+".bin") + f, err := os.Create(dst) + if err != nil { + return err + } + defer f.Close() + + if _, err := io.Copy(f, src); err != nil { + os.Remove(dst) + return err + } + return nil +} + +func (r *Registry) Open(id string) (*os.File, error) { + p, err := r.Path(id) + if err != nil { + return nil, err + } + return os.Open(p) +} + +// Resolve maps an OpenAI-style model name to a local whisper model id. +func (r *Registry) Resolve(id, defaultModel string) (string, error) { + id = strings.TrimSpace(id) + if id == "" { + id = strings.TrimSpace(defaultModel) + } + if id == "" { + models, err := r.List() + if err != nil { + return "", err + } + if len(models) == 0 { + return "", fmt.Errorf("model is required") + } + return models[0], nil + } + if id == "whisper-1" { + if dm := strings.TrimSpace(defaultModel); dm != "" { + if _, err := r.Path(dm); err == nil { + return dm, nil + } + } + models, err := r.List() + if err != nil { + return "", err + } + if len(models) == 0 { + return "", fmt.Errorf("no whisper models installed") + } + return models[0], nil + } + if _, err := r.Path(id); err == nil { + return id, nil + } + if alt := strings.TrimPrefix(id, "whisper-"); alt != id { + if _, err := r.Path(alt); err == nil { + return alt, nil + } + } + return "", fmt.Errorf("model %q not found", id) +} diff --git a/api/models_test.go b/api/models_test.go new file mode 100644 index 0000000..21a65ad --- /dev/null +++ b/api/models_test.go @@ -0,0 +1,55 @@ +package api + +import ( + "os" + "path/filepath" + "testing" +) + +func TestRegistry_List_excludesAuxiliaryDirs(t *testing.T) { + root := t.TempDir() + for _, name := range []string{"common.bin", "vad/vad.bin", "punctuation/model.onnx"} { + p := filepath.Join(root, name) + if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(p, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + } + r := NewRegistry(root) + list, err := r.List() + if err != nil { + t.Fatal(err) + } + if len(list) != 1 || list[0] != "common" { + t.Fatalf("list=%v want [common]", list) + } +} + +func TestRegistry_List_excludesVADAtRoot(t *testing.T) { + root := t.TempDir() + if err := os.WriteFile(filepath.Join(root, "vad.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(root, "ggml-small.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + list, err := NewRegistry(root).List() + if err != nil { + t.Fatal(err) + } + if len(list) != 1 || list[0] != "ggml-small" { + t.Fatalf("list=%v", list) + } +} + +func TestRegistry_Path_rejectsReservedIDs(t *testing.T) { + root := t.TempDir() + r := NewRegistry(root) + for _, id := range []string{"vad", "punctuation"} { + if _, err := r.Path(id); err == nil { + t.Fatalf("expected error for %q", id) + } + } +} diff --git a/api/openai.go b/api/openai.go new file mode 100644 index 0000000..f1658f0 --- /dev/null +++ b/api/openai.go @@ -0,0 +1,121 @@ +package api + +import ( + "encoding/json" + "net/http" + "strings" + "time" +) + +func (s *Server) handleOpenAITranscriptions(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeOpenAIError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if err := r.ParseMultipartForm(128 << 20); err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + modelID, err := s.models.Resolve(r.FormValue("model"), s.cfg.DefaultModel) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + modelPath, err := s.models.Path(modelID) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + audioPath, cleanup, err := s.saveUploadedOpenAI(r) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + defer cleanup() + + stt := s.parseOpenAISTTOptions(r) + result, err := s.transcribe(r.Context(), modelPath, audioPath, stt) + if err != nil { + writeOpenAIError(w, http.StatusInternalServerError, err.Error()) + return + } + + switch strings.ToLower(strings.TrimSpace(r.FormValue("response_format"))) { + case "text": + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(result.Text)) + default: + writeJSON(w, http.StatusOK, map[string]string{"text": result.Text}) + } +} + +func (s *Server) handleOpenAIModels(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeOpenAIError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + ids, err := s.models.List() + if err != nil { + writeOpenAIError(w, http.StatusInternalServerError, err.Error()) + return + } + now := time.Now().Unix() + data := make([]map[string]any, 0, len(ids)) + for _, id := range ids { + data = append(data, map[string]any{ + "id": id, + "object": "model", + "created": now, + "owned_by": "go-whisper-api", + }) + } + writeJSON(w, http.StatusOK, map[string]any{ + "object": "list", + "data": data, + }) +} + +func (s *Server) parseOpenAISTTOptions(r *http.Request) sttOptions { + lang := strings.TrimSpace(r.FormValue("language")) + if lang == "" { + lang = s.cfg.Language + } + return sttOptions{ + language: lang, + punctuate: s.punctCfg.ShouldApplyAPI(r, s.cfg.DefaultPunctuation), + } +} + +func (s *Server) saveUploadedOpenAI(r *http.Request) (path string, cleanup func(), err error) { + if r.MultipartForm == nil { + if err := r.ParseMultipartForm(128 << 20); err != nil { + return "", nil, err + } + } + return saveUploadedRawFields(r, []string{"file", "audio", "wav"}) +} + +func writeOpenAIError(w http.ResponseWriter, code int, msg string) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(code) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": msg, + "type": openAIErrorType(code), + }, + }) +} + +func openAIErrorType(code int) string { + switch code { + case http.StatusBadRequest: + return "invalid_request_error" + case http.StatusUnauthorized: + return "authentication_error" + case http.StatusNotFound: + return "not_found_error" + default: + return "server_error" + } +} diff --git a/api/openai_test.go b/api/openai_test.go new file mode 100644 index 0000000..cd81175 --- /dev/null +++ b/api/openai_test.go @@ -0,0 +1,90 @@ +package api + +import ( + "bytes" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "go-whisper-api/config" +) + +func TestRegistryResolve(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "ggml-large-v3-turbo.bin"), []byte("y"), 0o644); err != nil { + t.Fatal(err) + } + reg := NewRegistry(dir) + + id, err := reg.Resolve("whisper-1", "ggml-large-v3-turbo") + if err != nil || id != "ggml-large-v3-turbo" { + t.Fatalf("whisper-1: id=%q err=%v", id, err) + } + id, err = reg.Resolve("whisper-large-v3-turbo", "") + if err != nil || id != "large-v3-turbo" { + t.Fatalf("whisper-large-v3-turbo: id=%q err=%v", id, err) + } + id, err = reg.Resolve("ggml-small", "") + if err != nil || id != "ggml-small" { + t.Fatalf("ggml-small: id=%q err=%v", id, err) + } + id, err = reg.Resolve("", "ggml-small") + if err != nil || id != "ggml-small" { + t.Fatalf("empty with default: id=%q err=%v", id, err) + } +} + +func TestHandleOpenAITranscriptionsMissingFile(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + srv := &Server{ + cfg: config.API{ModelsDir: dir, Language: "ru"}, + models: NewRegistry(dir), + } + body := &bytes.Buffer{} + w := multipart.NewWriter(body) + _ = w.WriteField("model", "ggml-small") + w.Close() + + req := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", body) + req.Header.Set("Content-Type", w.FormDataContentType()) + rec := httptest.NewRecorder() + srv.handleOpenAITranscriptions(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "file") { + t.Fatalf("expected file error, got %s", rec.Body.String()) + } +} + +func TestHandleOpenAIModels(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + srv := &Server{ + cfg: config.API{ModelsDir: dir, Language: "ru"}, + models: NewRegistry(dir), + } + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + rec := httptest.NewRecorder() + srv.handleOpenAIModels(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "ggml-small") { + t.Fatalf("expected model list, got %s", rec.Body.String()) + } +} diff --git a/api/queue_worker.go b/api/queue_worker.go new file mode 100644 index 0000000..a93737f --- /dev/null +++ b/api/queue_worker.go @@ -0,0 +1,98 @@ +package api + +import ( + "context" + "time" + + "go-whisper-api/whisper" + + "github.com/rs/zerolog/log" +) + +func (s *Server) StartWorker(ctx context.Context) { + go s.queueWorker(ctx) +} + +func (s *Server) queueWorker(ctx context.Context) { + const idlePoll = 2 * time.Second + for { + if ctx.Err() != nil { + return + } + id, ok, err := s.cache.NextWaiting() + if err != nil { + log.Error().Err(err).Msg("cache queue scan") + if !sleepOrWake(ctx, s.queueWake, time.Second) { + return + } + continue + } + if ok { + s.processCacheTask(ctx, id) + continue + } + if !sleepOrWake(ctx, s.queueWake, idlePoll) { + return + } + } +} + +func sleepOrWake(ctx context.Context, wake <-chan struct{}, d time.Duration) bool { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-wake: + return true + case <-timer.C: + return true + } +} + +func (s *Server) processCacheTask(ctx context.Context, id string) { + params, _, err := s.cache.LoadParams(id) + if err != nil { + return + } + modelPath, err := s.models.Path(params.Model) + if err != nil { + s.completeCacheTask(id, whisper.TranscriptResult{}, err.Error()) + return + } + if err := s.cache.SetStatus(id, statusProcessing, nil); err != nil { + log.Error().Err(err).Str("task", id).Msg("set processing") + return + } + audioPath, ok := s.cache.AudioPath(id) + if !ok { + s.completeCacheTask(id, whisper.TranscriptResult{}, "audio file missing") + return + } + stt := sttOptions{ + language: params.Language, + punctuate: params.Punctuation, + speakers: params.Speakers, + numClusters: params.NumClusters, + } + if stt.language == "" { + stt.language = s.cfg.Language + } + result, err := s.transcribe(ctx, modelPath, audioPath, stt) + if err != nil { + s.completeCacheTask(id, result, err.Error()) + log.Error().Err(err).Str("task", id).Msg("async transcribe") + return + } + s.completeCacheTask(id, result, "") +} + +func (s *Server) completeCacheTask(id string, result whisper.TranscriptResult, errMsg string) { + if err := s.cache.FinishWaiting(id, result, errMsg, nil); err != nil { + log.Error().Err(err).Str("task", id).Str("cache", s.cache.Root()).Msg("finish task") + return + } + if err := s.cache.PromoteToReady(id); err != nil { + log.Error().Err(err).Str("task", id).Str("cache", s.cache.Root()).Msg("promote to ready") + } +} diff --git a/api/result.go b/api/result.go new file mode 100644 index 0000000..180f9ba --- /dev/null +++ b/api/result.go @@ -0,0 +1,28 @@ +package api + +import "go-whisper-api/whisper" + +// sprResultReady builds GET /spr/result/{taskID} body for completed tasks (SPR-compatible). +func sprResultReady(params TaskParams) map[string]any { + words := params.Words + if words == nil { + words = []whisper.Word{} + } + return map[string]any{ + "model": params.Model, + "text": params.Text, + "words": words, + "toxicity": map[string]float64{ + "insult": 0, + "obscenity": 0, + "threat": 0, + "politeness": 0, + }, + "emotion": map[string]any{}, + "voice_analysis": map[string]any{}, + "status": "ready", + "taskID": params.ID, + "created": params.Created, + "processed": params.Processed, + } +} diff --git a/api/result_test.go b/api/result_test.go new file mode 100644 index 0000000..7b6adc9 --- /dev/null +++ b/api/result_test.go @@ -0,0 +1,26 @@ +package api + +import "testing" + +func TestSprResultReady(t *testing.T) { + body := sprResultReady(TaskParams{ + ID: "701b3e84-5815-4baf-97a2-933ff820f16d", + Created: "2026-06-03 10:06:35", + Processed: "2026-06-03 10:07:48", + Status: "ready", + Model: "common", + Text: "hello", + }) + for _, key := range []string{"model", "text", "words", "toxicity", "emotion", "voice_analysis", "status", "taskID", "created", "processed"} { + if body[key] == nil { + t.Fatalf("missing %q", key) + } + } + if body["status"] != "ready" { + t.Fatalf("status=%v", body["status"]) + } + tox, ok := body["toxicity"].(map[string]float64) + if !ok || len(tox) != 4 { + t.Fatalf("toxicity=%v", body["toxicity"]) + } +} diff --git a/api/server.go b/api/server.go new file mode 100644 index 0000000..cbcf66d --- /dev/null +++ b/api/server.go @@ -0,0 +1,640 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + "go-whisper-api/config" + "go-whisper-api/diarization" + "go-whisper-api/punctuation" + "go-whisper-api/transcode" + "go-whisper-api/whisper" + + "github.com/google/uuid" + "github.com/rs/zerolog/log" +) + +type Server struct { + cfg config.API + punctCfg config.Punctuation + diarCfg config.Diarization + transcode *transcode.Engine + modelPool *whisper.ModelPool + punct punctuation.Restorer + diarizer diarization.Engine + models *Registry + cache *DiskCache + mux *http.ServeMux + queueWake chan struct{} +} + +func NewServer(cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) (*Server, error) { + cfg = cfg.WithDefaults() + tc = tc.WithDefaults() + pc = pc.WithDefaults() + dc = dc.WithDefaults() + restorer, err := punctuation.New(pc) + if err != nil { + return nil, err + } + if pc.Active() && !restorer.Active() { + return nil, fmt.Errorf("punctuation is enabled but engine %q is not available", pc.Engine) + } + if cfg.ModelsDir == "" { + cfg.ModelsDir = "./models" + } + if cfg.Addr == "" { + cfg.Addr = ":8080" + } + if cfg.Threads == 0 { + cfg.Threads = uint(runtime.NumCPU()) + } + cfg = cfg.WithDefaults() + if cfg.MaxContext == 0 { + cfg.MaxContext = 32 + } + if cfg.BeamSize == 0 { + cfg.BeamSize = 5 + } + if cfg.EntropyThold == 0 { + cfg.EntropyThold = 2.4 + } + if err := os.MkdirAll(cfg.ModelsDir, 0o755); err != nil { + return nil, err + } + cacheDir := cfg.CacheDir + if cacheDir == "" { + cacheDir = "./cache" + } + cache, err := NewDiskCache(cacheDir) + if err != nil { + return nil, err + } + cfg.CacheDir = cache.Root() + if err := cache.RecoverInterrupted(); err != nil { + return nil, err + } + diar, err := diarization.New(dc) + if err != nil { + return nil, err + } + s := &Server{ + cfg: cfg, + punctCfg: pc, + diarCfg: dc, + transcode: transcode.NewEngine(tc.FFmpegPath), + modelPool: whisper.NewModelPool(), + punct: restorer, + diarizer: diar, + models: NewRegistry(cfg.ModelsDir), + cache: cache, + mux: http.NewServeMux(), + queueWake: make(chan struct{}, 1), + } + s.routes() + go s.warmModels() + return s, nil +} + +func (s *Server) warmModels() { + ids, err := s.models.List() + if err != nil { + return + } + for _, id := range ids { + path, err := s.models.Path(id) + if err != nil { + continue + } + if err := s.modelPool.WithModel(path, func(whisper.Model) error { return nil }); err != nil { + log.Warn().Err(err).Str("model", id).Msg("preload whisper model") + } else { + log.Info().Str("model", id).Msg("whisper model loaded") + } + } +} + +func (s *Server) routes() { + s.mux.HandleFunc("/", s.handleSwaggerUI) + s.mux.HandleFunc("/swagger.json", s.handleSwaggerJSON) + s.mux.HandleFunc("/spr/models", s.handleModels) + s.mux.HandleFunc("/spr/hostname", s.handleHostname) + s.mux.HandleFunc("/spr/queue", s.handleQueue) + s.mux.HandleFunc("/spr/stt/", s.handleSTT) + s.mux.HandleFunc("/spr/result/", s.handleResult) + s.mux.HandleFunc("/spr/queue/", s.handleQueueItem) + s.mux.HandleFunc("/spr/audio/", s.handleAudio) + s.mux.HandleFunc("/spr/waveform/", s.handleWaveform) + s.mux.HandleFunc("/spr/delete/", s.handleDeleteModel) + s.mux.HandleFunc("/spr/export/", s.handleExportModel) + s.mux.HandleFunc("/spr/import/", s.handleImportModel) + s.mux.HandleFunc("/v1/audio/transcriptions", s.handleOpenAITranscriptions) + s.mux.HandleFunc("/v1/audio/transcriptions/", s.handleOpenAITranscriptions) + s.mux.HandleFunc("/v1/models", s.handleOpenAIModels) +} + +func (s *Server) ListenAndServe() error { + log.Info(). + Str("addr", s.cfg.Addr). + Str("models", s.cfg.ModelsDir). + Str("cache", s.cache.Root()). + Msg("starting API server") + return http.ListenAndServe(s.cfg.Addr, s.mux) +} + +func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + models, err := s.models.List() + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]any{"models": models}) +} + +func (s *Server) handleHostname(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + host, _ := os.Hostname() + cwd, _ := os.Getwd() + writeJSON(w, http.StatusOK, map[string]any{ + "error": 0, + "message": "Success", + "hostname": host, + "version": "go-whisper-api", + "cwd": cwd, + "models": s.cfg.ModelsDir, + "cache": s.cache.Root(), + }) +} + +func (s *Server) handleQueue(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + list, err := s.cache.List() + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + writeJSON(w, http.StatusOK, list) +} + +func (s *Server) handleQueueItem(w http.ResponseWriter, r *http.Request) { + id := strings.TrimPrefix(r.URL.Path, "/spr/queue/") + if id == "" || !isValidTaskID(id) { + writeError(w, http.StatusBadRequest, "task id required") + return + } + switch r.Method { + case http.MethodGet: + s.handleQueueGet(w, id) + case http.MethodDelete: + if !s.cache.Delete(id) { + writeAPIError(w, http.StatusNotFound, "TaskNotFound") + return + } + writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"}) + default: + methodNotAllowed(w) + } +} + +func (s *Server) handleQueueGet(w http.ResponseWriter, id string) { + params, phase, err := s.cache.LoadParams(id) + if err != nil { + writeAPIError(w, http.StatusNotFound, "task not found") + return + } + switch params.Status { + case string(statusReady): + if phase == cacheWaiting { + if err := s.cache.PromoteToReady(id); err != nil { + writeAPIError(w, http.StatusInternalServerError, err.Error()) + return + } + } + writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"}) + case string(statusError): + msg := params.Error + if msg == "" { + msg = "transcription failed" + } + writeAPIError(w, http.StatusNotFound, msg) + default: + writeJSON(w, http.StatusOK, map[string]any{ + "error": 0, + "message": params.Status, + "status": params.Status, + }) + } +} + +func (s *Server) handleSTT(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + methodNotAllowed(w) + return + } + modelID := strings.TrimPrefix(r.URL.Path, "/spr/stt/") + if modelID == "" { + writeError(w, http.StatusBadRequest, "model id required") + return + } + modelPath, err := s.models.Path(modelID) + if err != nil { + writeAPIError(w, http.StatusNotFound, err.Error()) + return + } + audioPath, cleanup, err := s.saveUploadedWav(r) + if err != nil { + writeAPIError(w, http.StatusBadRequest, err.Error()) + return + } + stt, err := s.parseSTTOptions(r) + if err != nil { + writeAPIError(w, http.StatusBadRequest, err.Error()) + return + } + if queryAsync(r, s.cfg.DefaultAsync) { + taskID, err := s.enqueueAsync(r, modelID, audioPath, stt) + cleanup() + if err != nil { + writeAPIError(w, http.StatusBadRequest, err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]string{"taskID": taskID}) + return + } + defer cleanup() + result, err := s.transcribe(r.Context(), modelPath, audioPath, stt) + if err != nil { + writeAPIError(w, http.StatusMethodNotAllowed, err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]any{ + "model": modelID, + "text": result.Text, + "words": result.Words, + }) +} + +func (s *Server) handleResult(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/result/") + if !isValidTaskID(id) { + writeAPIError(w, http.StatusBadRequest, "task id required") + return + } + params, _, err := s.cache.LoadParams(id) + if err != nil { + writeAPIError(w, http.StatusNotFound, "TaskNotFound") + return + } + switch params.Status { + case string(statusWaiting), string(statusProcessing): + writeJSON(w, http.StatusOK, map[string]string{"status": params.Status}) + case string(statusError): + msg := params.Error + if msg == "" { + msg = "TaskNotFound" + } + writeAPIError(w, http.StatusNotFound, msg) + case string(statusReady): + writeJSON(w, http.StatusOK, sprResultReady(params)) + default: + writeJSON(w, http.StatusOK, map[string]string{"status": params.Status}) + } +} + +func (s *Server) handleAudio(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/audio/") + path, ok := s.cache.AudioPath(id) + if !ok { + writeError(w, http.StatusNotFound, "task not found") + return + } + http.ServeFile(w, r, path) +} + +func (s *Server) handleWaveform(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/waveform/") + if !isValidTaskID(id) { + writeAPIError(w, http.StatusBadRequest, "task id required") + return + } + wf, err := s.cache.Waveform(id) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]any{"error": 0, "waveform": wf}) +} + +func (s *Server) handleDeleteModel(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/delete/") + if err := s.models.Delete(id); err != nil { + writeAPIError(w, http.StatusNotFound, err.Error()) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *Server) handleExportModel(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/export/") + f, err := s.models.Open(id) + if err != nil { + writeAPIError(w, http.StatusNotFound, err.Error()) + return + } + defer f.Close() + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q.bin", id)) + w.Header().Set("Content-Type", "application/octet-stream") + io.Copy(w, f) +} + +func (s *Server) handleImportModel(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/import/") + if err := r.ParseMultipartForm(32 << 20); err != nil { + writeAPIError(w, http.StatusBadRequest, err.Error()) + return + } + file, header, err := r.FormFile("zip-model") + if err != nil { + file, header, err = r.FormFile("model") + } + if err != nil { + writeAPIError(w, http.StatusBadRequest, "model file required") + return + } + defer file.Close() + src := io.Reader(file) + if strings.HasSuffix(strings.ToLower(header.Filename), ".zip") { + writeAPIError(w, http.StatusBadRequest, "zip import is not supported; upload .bin model file as zip-model field") + return + } + if err := s.models.Import(id, src); err != nil { + writeAPIError(w, http.StatusBadRequest, err.Error()) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *Server) enqueueAsync(r *http.Request, modelID, audioWavPath string, stt sttOptions) (string, error) { + id := uuid.New().String() + params := TaskParams{ + ID: id, + Created: time.Now().Format("2006-01-02 15:04:05"), + Status: string(statusWaiting), + Model: modelID, + Language: stt.language, + Punctuation: stt.punctuate, + Speakers: stt.speakers, + NumClusters: stt.numClusters, + } + if err := s.cache.Enqueue(id, params, audioWavPath); err != nil { + return "", err + } + s.notifyQueue() + log.Info(). + Str("task", id). + Str("model", modelID). + Str("cache", s.cache.Root()). + Msg("enqueued async task") + return id, nil +} + +func (s *Server) transcribe(ctx context.Context, modelPath, audioPath string, stt sttOptions) (whisper.TranscriptResult, error) { + turns, err := s.runDiarization(ctx, audioPath, stt) + if err != nil { + return whisper.TranscriptResult{}, err + } + vad := s.cfg.VAD + if vad.Enabled { + vad.Model = vad.ResolveModelPath(s.cfg.ModelsDir) + } + cfg := &config.Whisper{ + Model: modelPath, + AudioPath: audioPath, + Threads: s.cfg.Threads, + Language: stt.language, + Debug: s.cfg.Debug, + SpeedUp: s.cfg.SpeedUp, + Translate: s.cfg.Translate, + Prompt: s.cfg.Prompt, + MaxContext: s.cfg.MaxContext, + BeamSize: s.cfg.BeamSize, + EntropyThold: s.cfg.EntropyThold, + VAD: vad, + PrintProgress: false, + PrintSegment: false, + } + runOpts := s.whisperRunOpts(stt, turns) + if stt.punctuate && s.punct.Active() { + runOpts.PunctuateRestore = func(text string) (string, error) { + return punctuation.Apply(ctx, s.punct, true, text, stt.language) + } + } + result, err := whisper.TranscribeWithPool(s.modelPool, cfg, runOpts) + if err != nil { + return whisper.TranscriptResult{}, err + } + return applyGarbage(result, s.cfg.GarbagePatterns()), nil +} + +func (s *Server) runDiarization(ctx context.Context, audioPath string, stt sttOptions) ([]whisper.Turn, error) { + if !stt.speakers { + return nil, nil + } + samples, err := whisper.LoadPCM16Mono(audioPath) + if err != nil { + return nil, fmt.Errorf("diarization audio: %w", err) + } + return s.diarizer.Process(ctx, samples, stt.numClusters) +} + +func saveUploadedRaw(r *http.Request) (path string, cleanup func(), err error) { + return saveUploadedRawFields(r, []string{"audio", "wav", "file"}) +} + +func saveUploadedRawFields(r *http.Request, fieldNames []string) (path string, cleanup func(), err error) { + if r.MultipartForm == nil { + if err := r.ParseMultipartForm(128 << 20); err != nil { + return "", nil, err + } + } + var ( + file multipart.File + header *multipart.FileHeader + found bool + ) + for _, name := range fieldNames { + file, header, err = r.FormFile(name) + if err == nil { + found = true + break + } + } + if !found { + return "", nil, fmt.Errorf("audio file required (form field: %s)", strings.Join(fieldNames, ", ")) + } + defer file.Close() + dir, err := config.MkdirTemp("go-whisper-api-upload-*") + if err != nil { + return "", nil, err + } + cleanup = func() { os.RemoveAll(dir) } + base := "input" + if header != nil { + if ext := filepath.Ext(header.Filename); ext != "" { + base += ext + } + } + raw := filepath.Join(dir, base) + out, err := os.Create(raw) + if err != nil { + cleanup() + return "", nil, err + } + if _, err := io.Copy(out, file); err != nil { + out.Close() + cleanup() + return "", nil, err + } + out.Close() + return raw, cleanup, nil +} + +func (s *Server) saveUploadedWav(r *http.Request) (path string, cleanup func(), err error) { + raw, cleanup, err := saveUploadedRaw(r) + if err != nil { + return "", nil, err + } + dst := filepath.Join(filepath.Dir(raw), "audio.wav") + if err := s.transcode.Transcode(r.Context(), raw, dst, transcode.WhisperOptions()); err != nil { + cleanup() + return "", nil, err + } + return dst, cleanup, nil +} + +func queryBoolDefault(r *http.Request, key string, def bool) bool { + v := r.URL.Query().Get(key) + if v == "" { + return def + } + b, err := strconv.ParseBool(v) + if err != nil { + return def + } + return b +} + +func queryAsync(r *http.Request, defaultAsync bool) bool { + v := r.URL.Query().Get("async") + if v == "" { + return defaultAsync + } + n, err := strconv.Atoi(v) + if err != nil { + return defaultAsync + } + return n == 1 +} + +func queryInt(r *http.Request, key string, def int) int { + v := r.URL.Query().Get(key) + if v == "" { + return def + } + n, err := strconv.Atoi(v) + if err != nil { + return def + } + return n +} + +func writeJSON(w http.ResponseWriter, code int, v any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(code) + _ = json.NewEncoder(w).Encode(v) +} + +func writeError(w http.ResponseWriter, code int, msg string) { + http.Error(w, msg, code) +} + +func writeAPIError(w http.ResponseWriter, code int, msg string) { + writeJSON(w, code, map[string]any{"error": 1, "message": msg}) +} + +func methodNotAllowed(w http.ResponseWriter) { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) +} + +func (s *Server) notifyQueue() { + select { + case s.queueWake <- struct{}{}: + default: + } +} + +func Run(ctx context.Context, cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) error { + srv, err := NewServer(cfg, tc, pc, dc) + if err != nil { + return err + } + defer srv.modelPool.Close() + srv.StartWorker(ctx) + hs := &http.Server{Addr: cfg.Addr, Handler: srv.mux} + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _ = hs.Shutdown(shutdownCtx) + }() + if err := hs.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return err + } + punctuation.Close(srv.punct) + srv.diarizer.Close() + return nil +} diff --git a/api/swagger-ui.html b/api/swagger-ui.html new file mode 100644 index 0000000..1bb5e40 --- /dev/null +++ b/api/swagger-ui.html @@ -0,0 +1,23 @@ + + + + + go-whisper-api + + + +
+ + + + diff --git a/api/swagger.go b/api/swagger.go new file mode 100644 index 0000000..fafeae1 --- /dev/null +++ b/api/swagger.go @@ -0,0 +1,34 @@ +package api + +import ( + _ "embed" + "net/http" +) + +//go:embed swagger.json +var swaggerSpec []byte + +//go:embed swagger-ui.html +var swaggerUI []byte + +func (s *Server) handleSwaggerJSON(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(swaggerSpec) +} + +func (s *Server) handleSwaggerUI(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Write(swaggerUI) +} diff --git a/api/swagger.json b/api/swagger.json new file mode 100644 index 0000000..188ddf7 --- /dev/null +++ b/api/swagger.json @@ -0,0 +1,457 @@ +{ + "swagger": "2.0", + "basePath": "/", + "paths": { + "/spr/audio/{taskID}": { + "parameters": [ + { + "name": "taskID", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "get_audio_stt", + "tags": [ + "spr" + ] + } + }, + "/spr/delete/{id}": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string" + } + ], + "delete": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "delete_model_delete", + "tags": [ + "spr" + ] + } + }, + "/spr/export/{id}": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "get_model_export", + "tags": [ + "spr" + ] + } + }, + "/spr/hostname": { + "get": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "get_hostname_class", + "tags": [ + "spr" + ] + } + }, + "/spr/import/{id}": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string" + } + ], + "post": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "post_model_import", + "parameters": [ + { + "name": "zip-model", + "in": "formData", + "type": "file", + "required": true, + "description": "prepared model zip file" + } + ], + "consumes": [ + "multipart/form-data" + ], + "tags": [ + "spr" + ] + } + }, + "/spr/models": { + "get": { + "responses": { + "200": { + "description": "Success", + "schema": { + "$ref": "#/definitions/modelList" + } + } + }, + "operationId": "get_model_list", + "tags": [ + "spr" + ] + } + }, + "/spr/queue": { + "get": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "get_queue_stt", + "tags": [ + "spr" + ] + } + }, + "/spr/queue/{taskID}": { + "parameters": [ + { + "name": "taskID", + "in": "path", + "required": true, + "type": "string" + } + ], + "delete": { + "responses": { + "200": { + "description": "Success", + "schema": { + "type": "object", + "properties": { + "error": { "type": "integer" }, + "message": { "type": "string" } + } + } + }, + "404": { + "description": "TaskNotFound" + } + }, + "operationId": "delete_queue_del_stt", + "tags": [ + "spr" + ] + } + }, + "/spr/result/{taskID}": { + "parameters": [ + { + "name": "taskID", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "responses": { + "404": { + "description": "Not found", + "schema": { + "$ref": "#/definitions/error" + } + }, + "200": { + "description": "Success", + "schema": { + "$ref": "#/definitions/resultTTS" + } + } + }, + "operationId": "get_result_stt", + "tags": [ + "spr" + ] + } + }, + "/spr/stt/{id}": { + "post": { + "responses": { + "405": { + "description": "Error", + "schema": { + "$ref": "#/definitions/error" + } + }, + "404": { + "description": "Not found", + "schema": { + "$ref": "#/definitions/error" + } + }, + "200": { + "description": "Success", + "schema": { + "$ref": "#/definitions/modelTTS" + } + } + }, + "operationId": "post_model_test", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string", + "description": "NN Model ID" + }, + { + "name": "wav", + "in": "formData", + "type": "file", + "description": "file to recognize" + }, + { + "name": "async", + "in": "query", + "type": "integer", + "description": "async mode (default 1: enqueue to cache/waiting; 0: sync text response)", + "default": 1, + "enum": [ + 0, + 1 + ] + }, + { + "name": "speakers", + "in": "query", + "type": "integer", + "description": "find speakers", + "default": 0, + "enum": [ + 0, + 1 + ] + }, + { + "name": "speaker_counter", + "in": "query", + "type": "integer", + "description": "number of speakers. 0 for autodetect. -1 disable speaker detection.", + "default": 0 + }, + { + "name": "normalization", + "in": "query", + "type": "integer", + "description": "normalize text", + "default": 1, + "enum": [ + 0, + 1 + ] + }, + { + "name": "punctuation", + "in": "query", + "type": "integer", + "description": "punctuate text", + "default": 1, + "enum": [ + 0, + 1 + ] + }, + { + "name": "toxicity", + "in": "query", + "type": "integer", + "description": "toxicity analyzer", + "default": 1, + "enum": [ + 0, + 1 + ] + }, + { + "name": "emotion", + "in": "query", + "type": "integer", + "description": "emotion analyzer", + "default": 0, + "enum": [ + 0, + 1 + ] + }, + { + "name": "voice_analyzer", + "in": "query", + "type": "integer", + "description": "voice analyzer", + "default": 1, + "enum": [ + 0, + 1 + ] + }, + { + "name": "vad", + "in": "query", + "type": "string", + "description": "VAD type", + "default": "webrtc", + "enum": [ + "webrtc" + ] + }, + { + "name": "classifiers", + "in": "query", + "type": "string", + "description": "JSON with classification models to process each sentence" + }, + { + "name": "webhook", + "in": "query", + "type": "string", + "description": "webhook url to send stt async result" + } + ], + "consumes": [ + "multipart/form-data" + ], + "tags": [ + "spr" + ] + } + }, + "/spr/waveform/{taskID}": { + "parameters": [ + { + "name": "taskID", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "get_audioarray_stt", + "tags": [ + "spr" + ] + } + } + }, + "info": { + "title": "go-whisper-api", + "version": "5.008 release", + "description": "Whisper speech-to-text API (SPR-compatible)" + }, + "produces": [ + "application/json" + ], + "consumes": [ + "application/json" + ], + "tags": [ + { + "name": "spr", + "description": "NN Model operations" + } + ], + "definitions": { + "modelList": { + "properties": { + "models": { + "type": "array", + "items": { + "type": "string", + "description": "NN Model ID" + } + } + }, + "type": "object" + }, + "error": { + "properties": { + "error": { + "type": "integer", + "description": "Error flag" + }, + "message": { + "type": "string", + "description": "Error description" + } + }, + "type": "object" + }, + "modelTTS": { + "required": [ + "text" + ], + "properties": { + "text": { + "type": "string", + "description": "Recognized text" + } + }, + "type": "object" + }, + "resultTTS": { + "properties": { + "model": { "type": "string" }, + "text": { "type": "string" }, + "words": { "type": "array" }, + "toxicity": { "type": "object" }, + "emotion": { "type": "object" }, + "voice_analysis": { "type": "object" }, + "status": { "type": "string" }, + "taskID": { "type": "string" }, + "created": { "type": "string" }, + "processed": { "type": "string" } + }, + "type": "object" + } + }, + "responses": { + "ParseError": { + "description": "When a mask can't be parsed" + }, + "MaskError": { + "description": "When any error occurs on mask" + } + } +} \ No newline at end of file diff --git a/api/tasks.go b/api/tasks.go new file mode 100644 index 0000000..7f2a8f8 --- /dev/null +++ b/api/tasks.go @@ -0,0 +1,10 @@ +package api + +type taskStatus string + +const ( + statusWaiting taskStatus = "waiting" + statusProcessing taskStatus = "processing" + statusReady taskStatus = "ready" + statusError taskStatus = "error" +) diff --git a/api/transcribe_opts.go b/api/transcribe_opts.go new file mode 100644 index 0000000..cc923c9 --- /dev/null +++ b/api/transcribe_opts.go @@ -0,0 +1,77 @@ +package api + +import ( + "fmt" + "net/http" + "strings" + + "go-whisper-api/config" + "go-whisper-api/whisper" +) + +type sttOptions struct { + language string + punctuate bool + speakers bool + numClusters int +} + +func (s *Server) parseSTTOptions(r *http.Request) (sttOptions, error) { + opts := sttOptions{ + language: resolveLanguage(r, s.cfg.Language), + punctuate: s.punctCfg.ShouldApplyAPI(r, s.cfg.DefaultPunctuation), + } + sp, clusters, err := querySpeakers(r, s.cfg.DefaultSpeakers, s.diarCfg, s.diarizer.Active()) + if err != nil { + return opts, err + } + opts.speakers = sp + opts.numClusters = clusters + return opts, nil +} + +func resolveLanguage(r *http.Request, defaultLang string) string { + if v := strings.TrimSpace(r.URL.Query().Get("language")); v != "" { + return v + } + return strings.TrimSpace(defaultLang) +} + +func querySpeakers(r *http.Request, defaultOn bool, dc config.Diarization, diarizerActive bool) (enabled bool, numClusters int, err error) { + counter := queryInt(r, "speaker_counter", -999) + if counter == -1 { + return false, 0, nil + } + speakersQ := r.URL.Query().Get("speakers") + enabled = defaultOn + if speakersQ != "" { + enabled = queryInt(r, "speakers", 0) == 1 + } + if !enabled { + return false, 0, nil + } + if !dc.Active() { + return false, 0, fmt.Errorf("speaker diarization is disabled in config (diarization.enabled: true)") + } + if !diarizerActive { + return false, 0, fmt.Errorf("speaker diarization requires server built with -tags sherpa (make build-sherpa) and models (make download-diarization-models)") + } + if counter > 0 { + numClusters = counter + } else if dc.NumClusters > 0 { + numClusters = dc.NumClusters + } + return true, numClusters, nil +} + +func (s *Server) whisperRunOpts(stt sttOptions, turns []whisper.Turn) whisper.RunOptions { + t := s.cfg.Transcript.WithDefaults() + return whisper.RunOptions{ + Turns: turns, + Format: whisper.FormatOptions{ + PauseGap: t.PauseGapDuration(), + SpeakerLabel: t.SpeakerLabel, + UseSpeakers: stt.speakers && len(turns) > 0, + }, + } +} diff --git a/api/waveform.go b/api/waveform.go new file mode 100644 index 0000000..1b6f07c --- /dev/null +++ b/api/waveform.go @@ -0,0 +1,48 @@ +package api + +import ( + "math" + "os" + + "github.com/go-audio/wav" +) + +func waveformFromWav(path string, buckets int) ([]float64, error) { + if buckets <= 0 { + buckets = 256 + } + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + dec := wav.NewDecoder(f) + buf, err := dec.FullPCMBuffer() + if err != nil { + return nil, err + } + samples := buf.AsFloat32Buffer().Data + if len(samples) == 0 { + return make([]float64, buckets), nil + } + chunk := len(samples) / buckets + if chunk < 1 { + chunk = 1 + } + out := make([]float64, 0, buckets) + for i := 0; i < len(samples) && len(out) < buckets; i += chunk { + end := i + chunk + if end > len(samples) { + end = len(samples) + } + peak := 0.0 + for _, s := range samples[i:end] { + v := math.Abs(float64(s)) + if v > peak { + peak = v + } + } + out = append(out, math.Round(peak*1000)/1000) + } + return out, nil +} diff --git a/config.yaml.example b/config.yaml.example new file mode 100644 index 0000000..2a3fd24 --- /dev/null +++ b/config.yaml.example @@ -0,0 +1,55 @@ +api: + addr: "0.0.0.0:6183" + models_dir: "./models" + cache_dir: "./cache" + # Open WebUI / OpenAI STT: model id when client sends whisper-1 (e.g. ggml-large-v3-turbo) + default_model: ggml-large-v3-turbo + threads: 16 + language: ru + transcript: + pause_gap_sec: 1.5 + speaker_label: "Спикер" + default_speakers: false + debug: false + speedup: false + translate: false + prompt: "" + max_context: 32 + beam_size: 5 + entropy_thold: 2.4 + vad: + enabled: false + model: vad/vad.bin + threshold: 0.5 + min_speech_duration_ms: 250 + min_silence_duration_ms: 100 + speech_pad_ms: 30 + samples_overlap: 0.1 + default_punctuation: true + default_async: true + garbage: + - "*выбая*" + +# transcode: pure Go (wav, mp3, flac, ogg, m4a, mp4, aac) — no ffmpeg required +transcode: {} + +diarization: + enabled: false + model_dir: ./models/diarization + segmentation_model: pyannote-segmentation-3-0/model.onnx + embedding_model: 3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + num_threads: 2 + num_clusters: 0 + clustering_threshold: 0.5 + +punctuation: + enabled: true + default_on: true + # off | heuristic | xlm | sherpa | sherpa-online | http + engine: xlm + model_dir: ./models/punctuation/xlm-roberta + model_file: model.onnx + sp_model: sp.model + config_file: config.yaml + apply_sbd: true + num_threads: 2 diff --git a/config/api.go b/config/api.go new file mode 100644 index 0000000..8bb25b0 --- /dev/null +++ b/config/api.go @@ -0,0 +1,42 @@ +package config + +import "strings" + +type API struct { + Addr string `yaml:"addr"` + ModelsDir string `yaml:"models_dir"` + CacheDir string `yaml:"cache_dir"` + // DefaultModel: whisper model id for OpenAI /v1/audio/transcriptions (maps whisper-1). + DefaultModel string `yaml:"default_model"` + Language string `yaml:"language"` + Transcript Transcript `yaml:"transcript"` + DefaultSpeakers bool `yaml:"default_speakers"` + Prompt string `yaml:"prompt"` + Threads uint `yaml:"threads"` + MaxContext uint `yaml:"max_context"` + BeamSize uint `yaml:"beam_size"` + EntropyThold float64 `yaml:"entropy_thold"` + VAD VAD `yaml:"vad"` + Debug bool `yaml:"debug"` + SpeedUp bool `yaml:"speedup"` + Translate bool `yaml:"translate"` + DefaultPunctuation bool `yaml:"default_punctuation"` + // DefaultAsync: STT via API enqueues to cache/waiting and returns taskID (use ?async=0 for sync). + DefaultAsync bool `yaml:"default_async"` + // Garbage: artifact substrings removed from transcript text and words (default includes *выбая*). + Garbage []string `yaml:"garbage"` +} + +func (a API) WithDefaults() API { + a.VAD = a.VAD.WithDefaults() + a.Transcript = a.Transcript.WithDefaults() + if strings.TrimSpace(a.Language) == "" { + a.Language = "ru" + } + return a +} + +// GarbagePatterns returns garbage filter list (never empty unless explicitly set to [] in YAML). +func (a API) GarbagePatterns() []string { + return a.garbagePatterns() +} diff --git a/config/diarization.go b/config/diarization.go new file mode 100644 index 0000000..ae4858c --- /dev/null +++ b/config/diarization.go @@ -0,0 +1,72 @@ +package config + +import ( + "os" + "path/filepath" +) + +type Diarization struct { + Enabled bool `yaml:"enabled"` + ModelDir string `yaml:"model_dir"` + SegmentationModel string `yaml:"segmentation_model"` + EmbeddingModel string `yaml:"embedding_model"` + NumThreads int `yaml:"num_threads"` + NumClusters int `yaml:"num_clusters"` + ClusteringThreshold float32 `yaml:"clustering_threshold"` + MinDurationOn float32 `yaml:"min_duration_on"` + MinDurationOff float32 `yaml:"min_duration_off"` +} + +func (d Diarization) WithDefaults() Diarization { + if d.ModelDir == "" { + d.ModelDir = "./models/diarization" + } + if d.SegmentationModel == "" { + d.SegmentationModel = "pyannote-segmentation-3-0/model.onnx" + } + if d.EmbeddingModel == "" { + d.EmbeddingModel = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" + } + if d.NumThreads <= 0 { + d.NumThreads = 2 + } + if d.ClusteringThreshold <= 0 { + d.ClusteringThreshold = 0.5 + } + if d.MinDurationOn <= 0 { + d.MinDurationOn = 0.3 + } + if d.MinDurationOff <= 0 { + d.MinDurationOff = 0.5 + } + return d +} + +func (d Diarization) SegmentationPath() string { + return resolveModelPath(d.ModelDir, d.SegmentationModel) +} + +func (d Diarization) EmbeddingPath() string { + return resolveModelPath(d.ModelDir, d.EmbeddingModel) +} + +func (d Diarization) ModelsPresent() bool { + if _, err := os.Stat(d.SegmentationPath()); err != nil { + return false + } + if _, err := os.Stat(d.EmbeddingPath()); err != nil { + return false + } + return true +} + +func resolveModelPath(dir, name string) string { + if filepath.IsAbs(name) { + return name + } + return filepath.Join(dir, name) +} + +func (d Diarization) Active() bool { + return d.Enabled +} diff --git a/config/file.go b/config/file.go new file mode 100644 index 0000000..2cc513f --- /dev/null +++ b/config/file.go @@ -0,0 +1,53 @@ +package config + +import ( + "fmt" + "os" + "runtime" + + "gopkg.in/yaml.v3" +) + +type File struct { + API API `yaml:"api"` + Transcode Transcode `yaml:"transcode"` + Punctuation Punctuation `yaml:"punctuation"` + Diarization Diarization `yaml:"diarization"` +} + +func DefaultFile() File { + return File{ + API: API{ + Addr: ":8080", + ModelsDir: "./models", + CacheDir: "./cache", + Threads: uint(runtime.NumCPU()), + Language: "ru", + Transcript: Transcript{}.WithDefaults(), + MaxContext: 32, + BeamSize: 5, + EntropyThold: 2.4, + DefaultAsync: true, + Garbage: DefaultGarbage(), + }, + Diarization: Diarization{}.WithDefaults(), + Transcode: Transcode{}.WithDefaults(), + Punctuation: Punctuation{Enabled: false, Engine: "off"}.WithDefaults(), + } +} + +func LoadFile(path string) (File, error) { + data, err := os.ReadFile(path) + if err != nil { + return File{}, err + } + cfg := DefaultFile() + if err := yaml.Unmarshal(data, &cfg); err != nil { + return File{}, fmt.Errorf("parse config %s: %w", path, err) + } + return cfg, nil +} + +func (f File) APIConfig() API { + return f.API +} diff --git a/config/file_test.go b/config/file_test.go new file mode 100644 index 0000000..63b9f6f --- /dev/null +++ b/config/file_test.go @@ -0,0 +1,35 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + err := os.WriteFile(path, []byte(` +api: + addr: ":9090" + models_dir: "/data/models" + language: ru +`), 0o644) + if err != nil { + t.Fatal(err) + } + + cfg, err := LoadFile(path) + if err != nil { + t.Fatal(err) + } + if cfg.API.Addr != ":9090" { + t.Fatalf("addr: got %q", cfg.API.Addr) + } + if cfg.API.ModelsDir != "/data/models" { + t.Fatalf("models_dir: got %q", cfg.API.ModelsDir) + } + if cfg.API.Language != "ru" { + t.Fatalf("language: got %q", cfg.API.Language) + } +} diff --git a/config/garbage.go b/config/garbage.go new file mode 100644 index 0000000..4cfa7fe --- /dev/null +++ b/config/garbage.go @@ -0,0 +1,13 @@ +package config + +// DefaultGarbage returns STT artifact tokens filtered from API/CLI transcript output. +func DefaultGarbage() []string { + return []string{"*выбая*"} +} + +func (a API) garbagePatterns() []string { + if a.Garbage == nil { + return DefaultGarbage() + } + return a.Garbage +} diff --git a/config/garbage_test.go b/config/garbage_test.go new file mode 100644 index 0000000..2efa8d3 --- /dev/null +++ b/config/garbage_test.go @@ -0,0 +1,23 @@ +package config + +import "testing" + +func TestDefaultGarbage(t *testing.T) { + if len(DefaultGarbage()) == 0 || DefaultGarbage()[0] != "*выбая*" { + t.Fatalf("got %v", DefaultGarbage()) + } +} + +func TestAPI_GarbagePatterns_default(t *testing.T) { + p := (API{}).GarbagePatterns() + if len(p) != 1 || p[0] != "*выбая*" { + t.Fatalf("got %v", p) + } +} + +func TestAPI_GarbagePatterns_explicitEmpty(t *testing.T) { + p := (API{Garbage: []string{}}).GarbagePatterns() + if len(p) != 0 { + t.Fatalf("got %v", p) + } +} diff --git a/config/merge.go b/config/merge.go new file mode 100644 index 0000000..23407be --- /dev/null +++ b/config/merge.go @@ -0,0 +1,136 @@ +package config + +import ( + "os" + + "github.com/urfave/cli/v2" +) + +func LoadResolved(path string) (File, error) { + if path == "" { + if _, err := os.Stat("config.yaml"); err == nil { + path = "config.yaml" + } else { + return DefaultFile(), nil + } + } + return LoadFile(path) +} + +func mergeVAD(c *cli.Context, v VAD) VAD { + if c.IsSet("vad") { + v.Enabled = c.Bool("vad") + } + if c.IsSet("vad-model") { + v.Model = c.String("vad-model") + } + if c.IsSet("vad-threshold") { + v.Threshold = c.Float64("vad-threshold") + } + if c.IsSet("vad-min-speech-ms") { + v.MinSpeechMs = c.Int("vad-min-speech-ms") + } + if c.IsSet("vad-min-silence-ms") { + v.MinSilenceMs = c.Int("vad-min-silence-ms") + } + if c.IsSet("vad-max-speech-sec") { + v.MaxSpeechSec = c.Float64("vad-max-speech-sec") + } + if c.IsSet("vad-speech-pad-ms") { + v.SpeechPadMs = c.Int("vad-speech-pad-ms") + } + if c.IsSet("vad-samples-overlap") { + v.SamplesOverlap = c.Float64("vad-samples-overlap") + } + return v.WithDefaults() +} + +func mergeAPI(c *cli.Context, a API) API { + if c.IsSet("addr") { + a.Addr = c.String("addr") + } + if c.IsSet("models-dir") { + a.ModelsDir = c.String("models-dir") + } + if c.IsSet("cache-dir") { + a.CacheDir = c.String("cache-dir") + } + if c.IsSet("threads") { + a.Threads = c.Uint("threads") + } + if c.IsSet("language") { + a.Language = c.String("language") + } + if c.IsSet("debug") { + a.Debug = c.Bool("debug") + } + if c.IsSet("speedup") { + a.SpeedUp = c.Bool("speedup") + } + if c.IsSet("translate") { + a.Translate = c.Bool("translate") + } + if c.IsSet("prompt") { + a.Prompt = c.String("prompt") + } + if c.IsSet("max-context") { + a.MaxContext = c.Uint("max-context") + } + if c.IsSet("beam-size") { + a.BeamSize = c.Uint("beam-size") + } + if c.IsSet("entropy-thold") { + a.EntropyThold = c.Float64("entropy-thold") + } + a.VAD = mergeVAD(c, a.VAD) + if c.IsSet("default-punctuation") { + a.DefaultPunctuation = c.Bool("default-punctuation") + } + return a +} + +func APIFromCLI(c *cli.Context) (API, error) { + file, err := LoadResolved(c.String("config")) + if err != nil { + return API{}, err + } + return mergeAPI(c, file.API), nil +} + +func TranscodeFromCLI(c *cli.Context) (Transcode, error) { + file, err := LoadResolved(c.String("config")) + if err != nil { + return Transcode{}, err + } + return file.Transcode.WithDefaults(), nil +} + +func mergePunctuation(c *cli.Context, p Punctuation) Punctuation { + if c.IsSet("punctuation-enabled") { + p.Enabled = c.Bool("punctuation-enabled") + } + if c.IsSet("punctuation-engine") { + p.Engine = c.String("punctuation-engine") + } + if c.IsSet("punctuation-default-on") { + p.DefaultOn = c.Bool("punctuation-default-on") + } + return p.WithDefaults() +} + +func PunctuationFromCLI(c *cli.Context) (Punctuation, error) { + file, err := LoadResolved(c.String("config")) + if err != nil { + return Punctuation{}, err + } + return mergePunctuation(c, file.Punctuation), nil +} + +func DiarizationFromCLI(c *cli.Context) (Diarization, error) { + file, err := LoadResolved(c.String("config")) + if err != nil { + return Diarization{}, err + } + return file.Diarization.WithDefaults(), nil +} + diff --git a/config/merge_test.go b/config/merge_test.go new file mode 100644 index 0000000..c78d6ea --- /dev/null +++ b/config/merge_test.go @@ -0,0 +1,28 @@ +package config + +import ( + "os" + "strings" + "testing" +) + +func TestDefaultFile_apiDefaults(t *testing.T) { + f := DefaultFile() + if f.API.Addr == "" { + t.Fatal("expected API listen addr") + } + if f.API.ModelsDir == "" { + t.Fatal("expected models_dir") + } +} + +func TestMkdirTemp_usesTmpRoot(t *testing.T) { + dir, err := MkdirTemp("go-whisper-api-test-*") + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.RemoveAll(dir) }() + if !strings.HasPrefix(dir, TempRoot+"/") { + t.Fatalf("temp dir should be under %s, got %s", TempRoot, dir) + } +} diff --git a/config/punctuation.go b/config/punctuation.go new file mode 100644 index 0000000..d0683fb --- /dev/null +++ b/config/punctuation.go @@ -0,0 +1,125 @@ +package config + +import ( + "net/http" + "path/filepath" + "strconv" + "strings" + "time" +) + +func (p Punctuation) XLMJoinSentences() bool { + return p.ApplySBD +} + +type Punctuation struct { + Command []string `yaml:"command"` + NumThreads int `yaml:"num_threads"` + TimeoutSec int `yaml:"timeout_sec"` + Engine string `yaml:"engine"` + ModelDir string `yaml:"model_dir"` + ModelFile string `yaml:"model_file"` + SPModel string `yaml:"sp_model"` + ConfigFile string `yaml:"config_file"` + BpeVocab string `yaml:"bpe_vocab"` + HTTPURL string `yaml:"http_url"` + Enabled bool `yaml:"enabled"` + DefaultOn bool `yaml:"default_on"` + ApplySBD bool `yaml:"apply_sbd"` +} + +func (p Punctuation) WithDefaults() Punctuation { + if p.Engine == "" { + p.Engine = "heuristic" + } + engine := strings.ToLower(strings.TrimSpace(p.Engine)) + if p.ModelDir == "" { + if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" { + p.ModelDir = "./models/punctuation/xlm-roberta" + } else { + p.ModelDir = "./models/punctuation/ct-transformer-zh-en-int8" + } + } + if p.ModelFile == "" { + if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" { + p.ModelFile = "model.onnx" + } else { + p.ModelFile = "model.int8.onnx" + } + } + if p.NumThreads <= 0 { + p.NumThreads = 2 + } + if p.TimeoutSec <= 0 { + p.TimeoutSec = 120 + } + return p +} + +func (p Punctuation) Timeout() time.Duration { + p = p.WithDefaults() + return time.Duration(p.TimeoutSec) * time.Second +} + +func (p Punctuation) Active() bool { + p = p.WithDefaults() + return p.Enabled && p.Engine != "" && !strings.EqualFold(p.Engine, "off") +} + +func (p Punctuation) ModelPath() string { + p = p.WithDefaults() + return filepath.Join(p.ModelDir, p.ModelFile) +} + +func (p Punctuation) SPModelPath() string { + p = p.WithDefaults() + if p.SPModel == "" { + return filepath.Join(p.ModelDir, "sp.model") + } + return filepath.Join(p.ModelDir, p.SPModel) +} + +func (p Punctuation) XLMConfigPath() string { + p = p.WithDefaults() + if p.ConfigFile == "" { + return filepath.Join(p.ModelDir, "config.yaml") + } + return filepath.Join(p.ModelDir, p.ConfigFile) +} + +func (p Punctuation) BpeVocabPath() string { + p = p.WithDefaults() + if p.BpeVocab == "" { + return filepath.Join(p.ModelDir, "bpe.vocab") + } + return filepath.Join(p.ModelDir, p.BpeVocab) +} + +func (p Punctuation) ShouldApplyAPI(r *http.Request, apiDefault bool) bool { + if !p.Active() { + return false + } + q := strings.TrimSpace(r.URL.Query().Get("punctuation")) + if q != "" { + return parsePunctuationQuery(q, true) + } + return apiDefault || p.Enabled +} + +func parsePunctuationQuery(raw string, def bool) bool { + raw = strings.TrimSpace(raw) + if raw == "" { + return def + } + switch strings.ToLower(raw) { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + } + b, err := strconv.ParseBool(raw) + if err != nil { + return def + } + return b +} diff --git a/config/punctuation_test.go b/config/punctuation_test.go new file mode 100644 index 0000000..1072217 --- /dev/null +++ b/config/punctuation_test.go @@ -0,0 +1,58 @@ +package config + +import ( + "net/http/httptest" + "testing" + +) + +func TestPunctuation_Active(t *testing.T) { + if (Punctuation{Enabled: false, Engine: "heuristic"}).Active() { + t.Fatal("disabled should be inactive") + } + if (Punctuation{Enabled: true, Engine: "off"}).Active() { + t.Fatal("engine off should be inactive") + } + if !(Punctuation{Enabled: true, Engine: "heuristic"}).Active() { + t.Fatal("enabled heuristic should be active") + } +} + +func TestPunctuation_ShouldApplyAPI(t *testing.T) { + p := Punctuation{Enabled: true, Engine: "heuristic"} + + req := httptest.NewRequest("GET", "/?punctuation=1", nil) + if !p.ShouldApplyAPI(req, false) { + t.Fatal("query 1 should enable") + } + + req = httptest.NewRequest("GET", "/?punctuation=0", nil) + if p.ShouldApplyAPI(req, true) { + t.Fatal("query 0 should disable") + } + + req = httptest.NewRequest("GET", "/", nil) + if !p.ShouldApplyAPI(req, false) { + t.Fatal("enabled in config => apply when query omitted") + } + if !p.ShouldApplyAPI(req, true) { + t.Fatal("api default true") + } + + off := Punctuation{Enabled: false, Engine: "heuristic"} + if off.ShouldApplyAPI(req, true) { + t.Fatal("master disabled must never apply") + } +} + +func TestParsePunctuationQuery(t *testing.T) { + if !parsePunctuationQuery("1", false) { + t.Fatal("1 => true") + } + if parsePunctuationQuery("0", true) { + t.Fatal("0 => false") + } + if !parsePunctuationQuery("", true) { + t.Fatal("empty => default") + } +} diff --git a/config/tmp.go b/config/tmp.go new file mode 100644 index 0000000..305c5e1 --- /dev/null +++ b/config/tmp.go @@ -0,0 +1,9 @@ +package config + +import "os" + +const TempRoot = "/tmp" + +func MkdirTemp(prefix string) (string, error) { + return os.MkdirTemp(TempRoot, prefix) +} diff --git a/config/transcode.go b/config/transcode.go new file mode 100644 index 0000000..cd0a175 --- /dev/null +++ b/config/transcode.go @@ -0,0 +1,11 @@ +package config + +// Transcode holds audio normalization settings (pure Go decoders; no external ffmpeg). +type Transcode struct { + // FFmpegPath is deprecated and ignored; kept for backward-compatible YAML. + FFmpegPath string `yaml:"ffmpeg_path,omitempty"` +} + +func (t Transcode) WithDefaults() Transcode { + return t +} diff --git a/config/transcript.go b/config/transcript.go new file mode 100644 index 0000000..1f5a942 --- /dev/null +++ b/config/transcript.go @@ -0,0 +1,26 @@ +package config + +import ( + "strings" + "time" +) + +// Transcript controls how STT segments are joined into one text field (with embedded newlines). +type Transcript struct { + PauseGapSec float64 `yaml:"pause_gap_sec"` + SpeakerLabel string `yaml:"speaker_label"` +} + +func (t Transcript) WithDefaults() Transcript { + if t.PauseGapSec <= 0 { + t.PauseGapSec = 1.5 + } + if strings.TrimSpace(t.SpeakerLabel) == "" { + t.SpeakerLabel = "Спикер" + } + return t +} + +func (t Transcript) PauseGapDuration() time.Duration { + return time.Duration(t.PauseGapSec * float64(time.Second)) +} diff --git a/config/vad.go b/config/vad.go new file mode 100644 index 0000000..faf40d2 --- /dev/null +++ b/config/vad.go @@ -0,0 +1,88 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" +) + +type VAD struct { + Enabled bool `yaml:"enabled"` + Model string `yaml:"model"` + Threshold float64 `yaml:"threshold"` + MinSpeechMs int `yaml:"min_speech_duration_ms"` + MinSilenceMs int `yaml:"min_silence_duration_ms"` + MaxSpeechSec float64 `yaml:"max_speech_duration_s"` + SpeechPadMs int `yaml:"speech_pad_ms"` + SamplesOverlap float64 `yaml:"samples_overlap"` +} + +func DefaultVAD() VAD { + return VAD{ + Threshold: 0.5, + MinSpeechMs: 250, + MinSilenceMs: 100, + MaxSpeechSec: 0, + SpeechPadMs: 30, + SamplesOverlap: 0.1, + } +} + +func (v VAD) WithDefaults() VAD { + d := DefaultVAD() + if v.Threshold <= 0 { + v.Threshold = d.Threshold + } + if v.MinSpeechMs <= 0 { + v.MinSpeechMs = d.MinSpeechMs + } + if v.MinSilenceMs <= 0 { + v.MinSilenceMs = d.MinSilenceMs + } + if v.SpeechPadMs <= 0 { + v.SpeechPadMs = d.SpeechPadMs + } + if v.SamplesOverlap <= 0 { + v.SamplesOverlap = d.SamplesOverlap + } + return v +} + +func (v VAD) ResolveModelPath(modelsDir string) string { + if v.Model == "" { + return "" + } + if filepath.IsAbs(v.Model) { + return v.Model + } + if modelsDir == "" { + return v.Model + } + direct := filepath.Join(modelsDir, v.Model) + if _, err := os.Stat(direct); err == nil { + return direct + } + // VAD weights often live under models_dir/vad/ (not listed in /spr/models). + base := filepath.Base(v.Model) + inVAD := filepath.Join(modelsDir, "vad", base) + if _, err := os.Stat(inVAD); err == nil { + return inVAD + } + return direct +} + +func (v VAD) Validate() error { + if !v.Enabled { + return nil + } + if v.Model == "" { + return fmt.Errorf("vad.model is required when vad.enabled is true") + } + if _, err := os.Stat(v.Model); err != nil { + return fmt.Errorf("vad model %q: %w", v.Model, err) + } + if v.Threshold < 0 || v.Threshold > 1 { + return fmt.Errorf("vad.threshold must be between 0 and 1") + } + return nil +} diff --git a/config/vad_test.go b/config/vad_test.go new file mode 100644 index 0000000..d52179f --- /dev/null +++ b/config/vad_test.go @@ -0,0 +1,65 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestVAD_Validate_disabled(t *testing.T) { + if err := (VAD{}).Validate(); err != nil { + t.Fatal(err) + } +} + +func TestVAD_Validate_requiresModel(t *testing.T) { + err := (VAD{Enabled: true}).Validate() + if err == nil { + t.Fatal("expected error") + } +} + +func TestVAD_Validate_modelExists(t *testing.T) { + dir := t.TempDir() + model := filepath.Join(dir, "ggml-silero.bin") + if err := os.WriteFile(model, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + v := VAD{Enabled: true, Model: model, Threshold: 0.5} + if err := v.Validate(); err != nil { + t.Fatal(err) + } +} + +func TestVAD_ResolveModelPath(t *testing.T) { + v := VAD{Model: "ggml-silero-v6.2.0.bin"} + got := v.ResolveModelPath("/data/models") + want := filepath.Join("/data/models", "ggml-silero-v6.2.0.bin") + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} + +func TestVAD_ResolveModelPath_vadSubdir(t *testing.T) { + dir := t.TempDir() + vadDir := filepath.Join(dir, "vad") + if err := os.MkdirAll(vadDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(vadDir, "vad.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + v := VAD{Model: "vad.bin"} + got := v.ResolveModelPath(dir) + want := filepath.Join(dir, "vad", "vad.bin") + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} + +func TestVAD_WithDefaults(t *testing.T) { + v := VAD{Enabled: true}.WithDefaults() + if v.Threshold != 0.5 || v.MinSpeechMs != 250 || v.MinSilenceMs != 100 { + t.Fatalf("unexpected defaults: %+v", v) + } +} diff --git a/config/whisper.go b/config/whisper.go new file mode 100644 index 0000000..624f0ee --- /dev/null +++ b/config/whisper.go @@ -0,0 +1,43 @@ +package config + +import "fmt" + +type Whisper struct { + OutputFormat []string `yaml:"output_format"` + Model string `yaml:"model"` + AudioPath string `yaml:"audio_path"` + Language string `yaml:"language"` + Prompt string `yaml:"prompt"` + OutputFolder string `yaml:"output_folder"` + OutputFilename string `yaml:"output_filename"` + Threads uint `yaml:"threads"` + MaxContext uint `yaml:"max_context"` + BeamSize uint `yaml:"beam_size"` + EntropyThold float64 `yaml:"entropy_thold"` + VAD VAD `yaml:"vad"` + Debug bool `yaml:"debug"` + SpeedUp bool `yaml:"speedup"` + Translate bool `yaml:"translate"` + PrintProgress bool `yaml:"print_progress"` + PrintSegment bool `yaml:"print_segment"` +} + +func (c *Whisper) ValidateModel() error { + if c.Model == "" { + return fmt.Errorf("model is required") + } + return nil +} + +func (c *Whisper) Validate() error { + if err := c.ValidateModel(); err != nil { + return err + } + if c.AudioPath == "" { + return fmt.Errorf("audio path is required") + } + if err := c.VAD.WithDefaults().Validate(); err != nil { + return err + } + return nil +} diff --git a/config/xlm-roberta-model.yaml b/config/xlm-roberta-model.yaml new file mode 100644 index 0000000..f6852e6 --- /dev/null +++ b/config/xlm-roberta-model.yaml @@ -0,0 +1,81 @@ +# Metadata for Salama1429/xlm-roberta_punctuation_fullstop_truecase (ONNX punctuation). +# Install into the model directory: +# cp config/xlm-roberta-model.yaml models/punctuation/xlm-roberta/config.yaml +# or: make install-xlm-punctuation-config + +languages: + - af + - am + - ar + - bg + - bn + - de + - el + - en + - es + - et + - fa + - fi + - fr + - gu + - hi + - hr + - hu + - id + - is + - it + - ja + - kk + - kn + - ko + - ky + - lt + - lv + - mk + - ml + - mr + - nl + - or + - pa + - pl + - ps + - pt + - ro + - ru + - rw + - so + - sr + - sw + - ta + - te + - tr + - uk + - zh + +max_length: 256 + +pre_labels: + - "" + - "¿" + +post_labels: + - "" + - "" + - "." + - "," + - "?" + - "?" + - "," + - "。" + - "、" + - "・" + - "।" + - "؟" + - "،" + - ";" + - "።" + - "፣" + - "፧" + +null_token: "" +acronym_token: "" diff --git a/diarization/diarization.go b/diarization/diarization.go new file mode 100644 index 0000000..a884d4a --- /dev/null +++ b/diarization/diarization.go @@ -0,0 +1,19 @@ +package diarization + +import ( + "context" + + "go-whisper-api/config" + "go-whisper-api/whisper" +) + +// Engine runs offline speaker diarization when built with -tags sherpa. +type Engine interface { + Active() bool + Process(ctx context.Context, samples []float32, numClusters int) ([]whisper.Turn, error) + Close() +} + +func New(cfg config.Diarization) (Engine, error) { + return newEngine(cfg) +} diff --git a/diarization/sherpa.go b/diarization/sherpa.go new file mode 100644 index 0000000..fd495c1 --- /dev/null +++ b/diarization/sherpa.go @@ -0,0 +1,107 @@ +//go:build sherpa + +package diarization + +import ( + "context" + "fmt" + "os" + "sync" + + "go-whisper-api/config" + "go-whisper-api/whisper" + + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" +) + +type sherpaEngine struct { + cfg config.Diarization + sd *sherpa.OfflineSpeakerDiarization + mu sync.Mutex +} + +func newEngine(cfg config.Diarization) (Engine, error) { + cfg = cfg.WithDefaults() + if !cfg.Active() { + return &noopEngine{}, nil + } + if !cfg.ModelsPresent() { + return nil, fmt.Errorf("diarization models missing (run: make download-diarization-models)") + } + conf := &sherpa.OfflineSpeakerDiarizationConfig{ + Segmentation: sherpa.OfflineSpeakerSegmentationModelConfig{ + Pyannote: sherpa.OfflineSpeakerSegmentationPyannoteModelConfig{ + Model: cfg.SegmentationPath(), + }, + NumThreads: cfg.NumThreads, + Debug: 0, + Provider: "cpu", + }, + Embedding: sherpa.SpeakerEmbeddingExtractorConfig{ + Model: cfg.EmbeddingPath(), + NumThreads: cfg.NumThreads, + Debug: 0, + Provider: "cpu", + }, + Clustering: sherpa.FastClusteringConfig{ + NumClusters: cfg.NumClusters, + Threshold: cfg.ClusteringThreshold, + }, + MinDurationOn: cfg.MinDurationOn, + MinDurationOff: cfg.MinDurationOff, + } + sd := sherpa.NewOfflineSpeakerDiarization(conf) + if sd == nil { + return nil, fmt.Errorf("failed to create sherpa speaker diarization") + } + return &sherpaEngine{cfg: cfg, sd: sd}, nil +} + +type noopEngine struct{} + +func (noopEngine) Active() bool { return false } +func (noopEngine) Process(context.Context, []float32, int) ([]whisper.Turn, error) { return nil, nil } +func (noopEngine) Close() {} + +func (e *sherpaEngine) Active() bool { + return e.sd != nil +} + +func (e *sherpaEngine) Process(ctx context.Context, samples []float32, numClusters int) ([]whisper.Turn, error) { + _ = ctx + if len(samples) == 0 { + return nil, fmt.Errorf("empty audio for diarization") + } + e.mu.Lock() + defer e.mu.Unlock() + if _, err := os.Stat(e.cfg.SegmentationPath()); err != nil { + return nil, err + } + clusters := numClusters + if clusters <= 0 { + clusters = e.cfg.NumClusters + } + e.sd.SetConfig(&sherpa.OfflineSpeakerDiarizationConfig{ + Clustering: sherpa.FastClusteringConfig{ + NumClusters: clusters, + Threshold: e.cfg.ClusteringThreshold, + }, + }) + segments := e.sd.Process(samples) + out := make([]whisper.Turn, len(segments)) + for i, s := range segments { + out[i] = whisper.Turn{ + Start: s.Start, + End: s.End, + Speaker: s.Speaker, + } + } + return out, nil +} + +func (e *sherpaEngine) Close() { + if e.sd != nil { + sherpa.DeleteOfflineSpeakerDiarization(e.sd) + e.sd = nil + } +} diff --git a/diarization/stub.go b/diarization/stub.go new file mode 100644 index 0000000..69383be --- /dev/null +++ b/diarization/stub.go @@ -0,0 +1,29 @@ +//go:build !sherpa + +package diarization + +import ( + "context" + "fmt" + + "go-whisper-api/config" + "go-whisper-api/whisper" +) + +type stubEngine struct { + cfg config.Diarization +} + +func newEngine(cfg config.Diarization) (Engine, error) { + return &stubEngine{cfg: cfg.WithDefaults()}, nil +} + +func (s *stubEngine) Active() bool { + return false +} + +func (s *stubEngine) Process(context.Context, []float32, int) ([]whisper.Turn, error) { + return nil, fmt.Errorf("speaker diarization requires build with -tags sherpa and models (make build-sherpa, make download-diarization-models)") +} + +func (s *stubEngine) Close() {} diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..651f897 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,55 @@ +ARG UBUNTU_VERSION=22.04 +# This needs to generally match the container host's environment. +ARG CUDA_VERSION=12.0.0 +# Target the CUDA build image +ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} +# Target the CUDA runtime image +ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} + +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} AS build +WORKDIR /app +# Unless otherwise specified, we make a fat build. +ARG CUDA_DOCKER_ARCH=all +# Set nvcc architecture +ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH} +# Enable cuBLAS +ENV WHISPER_CUBLAS=1 + +#apt-get +RUN apt-get update && \ + apt-get install -y --no-install-recommends build-essential git gcc g++ wget \ + && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +# install golang +RUN wget --progress=dot:giga https://go.dev/dl/go1.22.10.linux-amd64.tar.gz +RUN rm -rf /usr/local/go && tar -C /usr/local -xzf go1.22.10.linux-amd64.tar.gz +ENV PATH ${PATH}:/usr/local/go/bin + +# Ref: https://stackoverflow.com/a/53464012 +ENV CUDA_MAIN_VERSION=12.0 +ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH + +COPY ./ . +RUN make dependency && env && make build && \ + mv bin/go-whisper-api /bin/ && \ + rm -rf bin + +FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} AS runtime +WORKDIR /app + +LABEL maintainer="Bo-Yi Wu " \ + org.label-schema.name="Speech-to-Text" \ + org.label-schema.vendor="Bo-Yi Wu" \ + org.label-schema.schema-version="1.0" + +LABEL org.opencontainers.image.source=https://github.com/appleboy/go-whisper-api +LABEL org.opencontainers.image.description="Speech-to-Text." +LABEL org.opencontainers.image.licenses=MIT + +RUN apt-get update && \ + apt-get install -y --no-install-recommends curl \ + && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +COPY --from=build /bin/go-whisper-api /bin/go-whisper-api +EXPOSE 8080 +ENTRYPOINT ["/bin/go-whisper-api"] \ No newline at end of file diff --git a/docker/Dockerfile.ci b/docker/Dockerfile.ci new file mode 100644 index 0000000..1e1ff27 --- /dev/null +++ b/docker/Dockerfile.ci @@ -0,0 +1,27 @@ +# CPU image for Gitea CI and hosts without NVIDIA GPU. +# Production GPU image: docker/Dockerfile + +FROM golang:1.23-bookworm AS build +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git cmake pkg-config libsentencepiece-dev \ + && rm -rf /var/lib/apt/lists/* + +COPY . . +RUN make dependency && make build-xlm + +FROM debian:bookworm-slim AS runtime +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +COPY --from=build /app/bin/go-whisper-api /bin/go-whisper-api +COPY --from=build /app/third_party/whisper.cpp/build/src/libwhisper.so* /usr/local/lib/ +COPY --from=build /app/third_party/whisper.cpp/build/ggml/src/libggml*.so* /usr/local/lib/ +ENV LD_LIBRARY_PATH=/usr/local/lib + +EXPOSE 8080 +ENTRYPOINT ["/bin/go-whisper-api"] diff --git a/garbage/filter.go b/garbage/filter.go new file mode 100644 index 0000000..2e33098 --- /dev/null +++ b/garbage/filter.go @@ -0,0 +1,56 @@ +package garbage + +import ( + "regexp" + "strings" +) + +var spaceCollapse = regexp.MustCompile(`\s+`) + +// Word is a timed token for garbage filtering (mirrors whisper.Word JSON shape). +type Word struct { + Word string `json:"word"` + Start int `json:"start"` + Stop int `json:"stop"` +} + +// FilterText removes configured artifact substrings and normalizes whitespace. +func FilterText(text string, patterns []string) string { + for _, p := range patterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + text = strings.ReplaceAll(text, p, " ") + } + return strings.TrimSpace(spaceCollapse.ReplaceAllString(text, " ")) +} + +// FilterWords drops tokens that match any garbage pattern. +func FilterWords(words []Word, patterns []string) []Word { + if len(words) == 0 { + return words + } + out := make([]Word, 0, len(words)) + for _, w := range words { + if matchesGarbage(w.Word, patterns) { + continue + } + out = append(out, w) + } + return out +} + +func matchesGarbage(word string, patterns []string) bool { + word = strings.TrimSpace(word) + for _, p := range patterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if word == p || strings.Contains(word, p) { + return true + } + } + return false +} diff --git a/garbage/filter_test.go b/garbage/filter_test.go new file mode 100644 index 0000000..5996bae --- /dev/null +++ b/garbage/filter_test.go @@ -0,0 +1,30 @@ +package garbage + +import "testing" + +func TestFilterText(t *testing.T) { + in := "Привет *выбая* мир *выбая*" + got := FilterText(in, []string{"*выбая*"}) + want := "Привет мир" + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} + +func TestFilterWords(t *testing.T) { + words := []Word{ + {Word: "Что", Start: 0, Stop: 100}, + {Word: "*выбая*", Start: 100, Stop: 200}, + {Word: "мир", Start: 200, Stop: 300}, + } + got := FilterWords(words, []string{"*выбая*"}) + if len(got) != 2 || got[1].Word != "мир" { + t.Fatalf("got %+v", got) + } +} + +func TestFilterText_emptyPatterns(t *testing.T) { + if got := FilterText("a b", nil); got != "a b" { + t.Fatalf("got %q", got) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b6ab674 --- /dev/null +++ b/go.mod @@ -0,0 +1,48 @@ +module go-whisper-api + +go 1.26 + +require ( + github.com/Eyevinn/mp4ff v0.51.0 + github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230606002726-57543c169e27 + github.com/go-audio/audio v1.0.0 + github.com/go-audio/wav v1.1.0 + github.com/google/uuid v1.6.0 + github.com/gopxl/beep v1.4.1 + github.com/joho/godotenv v1.5.1 + github.com/k2-fsa/sherpa-onnx-go v1.13.2 + github.com/mattn/go-isatty v0.0.20 + github.com/olivier-w/climp-aac-decoder v0.1.0 + github.com/pion/opus v0.0.0-20260601214817-71d58474cec8 + github.com/rs/zerolog v1.35.0 + github.com/skrashevich/go-aac v0.1.0 + github.com/urfave/cli/v2 v2.27.7 + github.com/yalue/onnxruntime_go v1.24.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/go-audio/riff v1.0.0 // indirect + github.com/hajimehoshi/go-mp3 v0.3.4 // indirect + github.com/icza/bitio v1.1.0 // indirect + github.com/jfreymuth/oggvorbis v1.0.5 // indirect + github.com/jfreymuth/vorbis v1.0.2 // indirect + github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2 // indirect + github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2 // indirect + github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mewkiz/flac v1.0.8 // indirect + github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect + golang.org/x/sys v0.45.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect +) + +replace github.com/ggerganov/whisper.cpp/bindings/go => ./third_party/whisper.cpp/bindings/go diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7ae66e5 --- /dev/null +++ b/go.sum @@ -0,0 +1,127 @@ +github.com/Eyevinn/mp4ff v0.51.0 h1:ZYdHFXEcB3kJkCeCHMHl/tbCm64FJsD2XOU0Sj+ME2M= +github.com/Eyevinn/mp4ff v0.51.0/go.mod h1:hJNUUqOBryLAzUW9wpCJyw2HaI+TCd2rUPhafoS5lgg= +github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= +github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/d4l3k/messagediff v1.2.2-0.20190829033028-7e0a312ae40b/go.mod h1:Oozbb1TVXFac9FtSIxHBMnBCq2qeH/2KkEQxENCrlLo= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/ebitengine/oto/v3 v3.1.0 h1:9tChG6rizyeR2w3vsygTTTVVJ9QMMyu00m2yBOCch6U= +github.com/ebitengine/oto/v3 v3.1.0/go.mod h1:IK1QTnlfZK2GIB6ziyECm433hAdTaPpOsGMLhEyEGTg= +github.com/ebitengine/purego v0.7.1 h1:6/55d26lG3o9VCZX8lping+bZcmShseiqlh2bnUDiPA= +github.com/ebitengine/purego v0.7.1/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ= +github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= +github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs= +github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA= +github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498= +github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g= +github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE= +github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg= +github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopxl/beep v1.4.1 h1:WqNs9RsDAhG9M3khMyc1FaVY50dTdxG/6S6a3qsUHqE= +github.com/gopxl/beep v1.4.1/go.mod h1:A1dmiUkuY8kxsvcNJNUBIEcchmiP6eUyCHSxpXl0YO0= +github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= +github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo= +github.com/hajimehoshi/oto/v2 v2.3.1/go.mod h1:seWLbgHH7AyUMYKfKYT9pg7PhUu9/SisyJvNTT+ASQo= +github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0= +github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A= +github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k= +github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA= +github.com/jfreymuth/oggvorbis v1.0.5 h1:u+Ck+R0eLSRhgq8WTmffYnrVtSztJcYrl588DM4e3kQ= +github.com/jfreymuth/oggvorbis v1.0.5/go.mod h1:1U4pqWmghcoVsCJJ4fRBKv9peUJMBHixthRlBeD6uII= +github.com/jfreymuth/vorbis v1.0.2 h1:m1xH6+ZI4thH927pgKD8JOH4eaGRm18rEE9/0WKjvNE= +github.com/jfreymuth/vorbis v1.0.2/go.mod h1:DoftRo4AznKnShRl1GxiTFCseHr4zR9BN3TWXyuzrqQ= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/jszwec/csvutil v1.5.1/go.mod h1:Rpu7Uu9giO9subDyMCIQfHVDuLrcaC36UA4YcJjGBkg= +github.com/k2-fsa/sherpa-onnx-go v1.13.2 h1:1orlwwdJcVk37OoV1gi1L6fIdeZHKJcLQjm6S5h/WWs= +github.com/k2-fsa/sherpa-onnx-go v1.13.2/go.mod h1:NyA3OsF/hmcvk4FpbXpBO9Rl/I3dlnrOB9Ny+TnhvIc= +github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2 h1:w/RHaU9liD/ovA9Q5iI2lCoVjVEd4M8+uTrf54rjQPY= +github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2/go.mod h1:NXEH2rsBgTdqY59YpPq6CtSBlBAXy/8a9FmpLERU97I= +github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2 h1:oIIdSfU3NEMT7oq7yxoH7Rk37kekkbmr/b+7e4er1WE= +github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2/go.mod h1:ZOhUAXC62Unj0ZNfu6zxSFKcW96aXf7P3BsqiUyOBbE= +github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2 h1:tkfwXnmJRsChv59ZzLsycphwpvJpSkyd29aMG9kUoEE= +github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2/go.mod h1:5AX7TU8+P/gInjglY1ijtWUM2b8iyR0QX4yEngzMe64= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mewkiz/flac v1.0.8 h1:cophRjvafteDGmqsfXRK28YAX6l8wy19QxTHruEEg1s= +github.com/mewkiz/flac v1.0.8/go.mod h1:l7dt5uFY724eKVkHQtAJAQSkhpC3helU3RDxN0ESAqo= +github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14 h1:tnAPMExbRERsyEYkmR1YjhTgDM0iqyiBYf8ojRXxdbA= +github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14/go.mod h1:QYCFBiH5q6XTHEbWhR0uhR3M9qNPoD2CSQzr0g75kE4= +github.com/olivier-w/climp-aac-decoder v0.1.0 h1:jdwYeBlnfQlnm6KA/fG20s22HrlBlVVaaWQTj/tnb3A= +github.com/olivier-w/climp-aac-decoder v0.1.0/go.mod h1:Hpqi8cI4NeN4JEcgKxiAK0zEwx+hO+lWdg9Q3ruHouI= +github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e h1:s2RNOM/IGdY0Y6qfTeUKhDawdHDpK9RGBdx80qN4Ttw= +github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBdnFKj15wFbf94Rwfq4m30eAcyY9V/IyKAGQFtqkW0= +github.com/pion/opus v0.0.0-20260601214817-71d58474cec8 h1:THWkLaX+FzbTbHIo1irDpbBMXbNrqV6q2Xs+00p4d+0= +github.com/pion/opus v0.0.0-20260601214817-71d58474cec8/go.mod h1:t5Xog2n682JnawoykACE6nKVmupFvmJvkpM7x6bTv6g= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI= +github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= +github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/skrashevich/go-aac v0.1.0 h1:7oHNj1ADmgfjAHvi3wAIFbmbCpQBrcjZEVTLlRtAS1A= +github.com/skrashevich/go-aac v0.1.0/go.mod h1:Mj7r//4LDL4FC0ezORj+MnmQ+nDEkJhTOy2aMC8dzww= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= +github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= +github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= +github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= +github.com/yalue/onnxruntime_go v1.24.0 h1:IdgJLxxyotlsUTmL1UnHZgBzXJGgY51LZ4vQ5rZeOXU= +github.com/yalue/onnxruntime_go v1.24.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= +golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go new file mode 100644 index 0000000..14ff581 --- /dev/null +++ b/main.go @@ -0,0 +1,251 @@ +package main + +import ( + "os" + "runtime" + "strconv" + "time" + + "go-whisper-api/api" + "go-whisper-api/config" + + _ "github.com/joho/godotenv/autoload" + "github.com/mattn/go-isatty" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/urfave/cli/v2" +) + +var ( + Version string +) + +func main() { + isTerm := isatty.IsTerminal(os.Stdout.Fd()) + zerolog.SetGlobalLevel(zerolog.InfoLevel) + log.Logger = log.Output( + zerolog.ConsoleWriter{ + Out: os.Stderr, + NoColor: !isTerm, + }, + ) + zerolog.CallerMarshalFunc = func(pc uintptr, file string, line int) string { + short := file + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + short = file[i+1:] + break + } + } + file = short + return file + ":" + strconv.Itoa(line) + } + app := cli.NewApp() + app.Name = "go-whisper-api" + app.Usage = "HTTP API for speech-to-text (SPR + OpenAI-compatible)." + app.Copyright = "Copyright (c) " + strconv.Itoa(time.Now().Year()) + " Bo-Yi Wu" + app.Authors = []*cli.Author{ + { + Name: "Bo-Yi Wu", + Email: "appleboy.tw@gmail.com", + }, + } + app.Version = Version + app.Commands = []*cli.Command{ + { + Name: "serve", + Aliases: []string{"s", "api"}, + Usage: "start HTTP API server", + Flags: serveFlags(), + Action: runServe, + }, + } + app.Flags = append(append([]cli.Flag{configFlag()}, punctuationFlags()...), serveFlags()...) + app.Action = runServe + if err := app.Run(os.Args); err != nil { + log.Fatal().Err(err).Msg("can't run app") + } +} + +func configFlag() cli.Flag { + return &cli.StringFlag{ + Name: "config", + Usage: "path to YAML config file (default: config.yaml if present)", + EnvVars: []string{"CONFIG_PATH", "GO_WHISPER_CONFIG"}, + } +} + +func punctuationFlags() []cli.Flag { + return []cli.Flag{ + &cli.BoolFlag{ + Name: "punctuation-enabled", + Usage: "master switch: enable or disable punctuation (YAML: punctuation.enabled)", + EnvVars: []string{"PUNCTUATION_ENABLED"}, + }, + &cli.BoolFlag{ + Name: "punctuation-default-on", + Usage: "apply punctuation by default when query/flags omitted (YAML: punctuation.default_on)", + }, + &cli.StringFlag{ + Name: "punctuation-engine", + Usage: "punctuation engine: off, heuristic, sherpa, sherpa-online, http, xlm", + EnvVars: []string{"PUNCTUATION_ENGINE"}, + }, + } +} + +func vadFlags() []cli.Flag { + return []cli.Flag{ + &cli.BoolFlag{ + Name: "vad", + Usage: "enable voice activity detection (Silero VAD)", + EnvVars: []string{"API_VAD"}, + }, + &cli.StringFlag{ + Name: "vad-model", + Usage: "path to ggml Silero VAD model (e.g. models/ggml-silero-v6.2.0.bin)", + EnvVars: []string{"API_VAD_MODEL"}, + }, + &cli.Float64Flag{ + Name: "vad-threshold", + Usage: "VAD speech threshold (0.0-1.0)", + EnvVars: []string{"API_VAD_THRESHOLD"}, + }, + &cli.IntFlag{ + Name: "vad-min-speech-ms", + Usage: "minimum speech duration in ms", + EnvVars: []string{"API_VAD_MIN_SPEECH_MS"}, + }, + &cli.IntFlag{ + Name: "vad-min-silence-ms", + Usage: "minimum silence between segments in ms", + EnvVars: []string{"API_VAD_MIN_SILENCE_MS"}, + }, + &cli.Float64Flag{ + Name: "vad-max-speech-sec", + Usage: "maximum speech segment length in seconds (0 = unlimited)", + EnvVars: []string{"API_VAD_MAX_SPEECH_SEC"}, + }, + &cli.IntFlag{ + Name: "vad-speech-pad-ms", + Usage: "padding around speech segments in ms", + EnvVars: []string{"API_VAD_SPEECH_PAD_MS"}, + }, + &cli.Float64Flag{ + Name: "vad-samples-overlap", + Usage: "overlap between VAD segments in seconds", + EnvVars: []string{"API_VAD_SAMPLES_OVERLAP"}, + }, + } +} + +func serveFlags() []cli.Flag { + flags := []cli.Flag{ + &cli.StringFlag{ + Name: "addr", + Usage: "HTTP listen address", + EnvVars: []string{"API_ADDR"}, + Value: ":8080", + }, + &cli.StringFlag{ + Name: "models-dir", + Usage: "directory with ggml *.bin whisper models", + EnvVars: []string{"API_MODELS_DIR"}, + Value: "./models", + }, + &cli.StringFlag{ + Name: "cache-dir", + Usage: "directory for async task cache (waiting/ready)", + EnvVars: []string{"API_CACHE_DIR"}, + Value: "./cache", + }, + &cli.StringFlag{ + Name: "language", + Usage: "default language for speech recognition", + EnvVars: []string{"API_LANGUAGE"}, + Value: "auto", + }, + &cli.UintFlag{ + Name: "threads", + Usage: "number of threads for whisper", + EnvVars: []string{"API_THREADS"}, + Value: uint(runtime.NumCPU()), + }, + &cli.BoolFlag{ + Name: "debug", + Usage: "enable debug mode", + EnvVars: []string{"API_DEBUG"}, + }, + &cli.BoolFlag{ + Name: "speedup", + Usage: "speed up audio by x2", + EnvVars: []string{"API_SPEEDUP"}, + }, + &cli.BoolFlag{ + Name: "translate", + Usage: "translate to english", + EnvVars: []string{"API_TRANSLATE"}, + }, + &cli.StringFlag{ + Name: "prompt", + Usage: "initial prompt", + EnvVars: []string{"API_PROMPT"}, + }, + &cli.UintFlag{ + Name: "max-context", + Usage: "maximum text context tokens", + EnvVars: []string{"API_MAX_CONTEXT"}, + Value: 32, + }, + &cli.UintFlag{ + Name: "beam-size", + Usage: "beam size for beam search", + EnvVars: []string{"API_BEAM_SIZE"}, + Value: 5, + }, + &cli.Float64Flag{ + Name: "entropy-thold", + Usage: "entropy threshold", + EnvVars: []string{"API_ENTROPY_THOLD"}, + Value: 2.4, + }, + &cli.BoolFlag{ + Name: "default-punctuation", + Usage: "enable punctuation on STT by default", + EnvVars: []string{"API_DEFAULT_PUNCTUATION"}, + }, + } + return append(flags, vadFlags()...) +} + +func runServe(c *cli.Context) error { + apiCfg, err := config.APIFromCLI(c) + if err != nil { + return err + } + if apiCfg.Debug { + zerolog.SetGlobalLevel(zerolog.DebugLevel) + log.Logger = log.With().Caller().Logger() + } + vad := apiCfg.VAD.WithDefaults() + if vad.Enabled { + vad.Model = vad.ResolveModelPath(apiCfg.ModelsDir) + if err := vad.Validate(); err != nil { + return err + } + apiCfg.VAD = vad + } + tc, err := config.TranscodeFromCLI(c) + if err != nil { + return err + } + pc, err := config.PunctuationFromCLI(c) + if err != nil { + return err + } + dc, err := config.DiarizationFromCLI(c) + if err != nil { + return err + } + return api.Run(c.Context, apiCfg, tc, pc, dc) +} diff --git a/models/.gitkeep b/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/diarization/.gitkeep b/models/diarization/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/diarization/pyannote-segmentation-3-0/.gitkeep b/models/diarization/pyannote-segmentation-3-0/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/punctuation/.gitkeep b/models/punctuation/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/punctuation/xlm-roberta/.gitkeep b/models/punctuation/xlm-roberta/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/vad/.gitkeep b/models/vad/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/punctuation/heuristic.go b/punctuation/heuristic.go new file mode 100644 index 0000000..441b931 --- /dev/null +++ b/punctuation/heuristic.go @@ -0,0 +1,86 @@ +package punctuation + +import ( + "context" + "regexp" + "strings" + "unicode" + "unicode/utf8" +) + +type Heuristic struct{} + +func (Heuristic) Active() bool { + return true +} + +func (Heuristic) Restore(ctx context.Context, text, language string) (string, error) { + _ = ctx + text = strings.TrimSpace(text) + if text == "" { + return text, nil + } + text = normalizeSpaces(text) + text = capitalizeFirst(text) + lang := strings.ToLower(strings.TrimSpace(language)) + if lang == "ru" || lang == "rus" || lang == "russian" || lang == "auto" { + text = heuristicRU(text) + } else { + text = heuristicEN(text) + } + return ensureTerminalPunct(text), nil +} + +func normalizeSpaces(s string) string { + return strings.Join(strings.Fields(s), " ") +} + +func capitalizeFirst(s string) string { + r, size := utf8.DecodeRuneInString(s) + if r == utf8.RuneError { + return s + } + return string(unicode.ToUpper(r)) + s[size:] +} + +var ( + reQuestionRU = regexp.MustCompile(`(?i)(^|.*\s)(как|что|где|когда|почему|зачем|кто|чей|какой|какая|какое|какие|сколько|зачем|откуда|куда|ли)(\s+[^.?!]+)$`) + reQuestionEN = regexp.MustCompile(`(?i)^(who|what|when|where|why|how|which|whose|whom|is|are|am|was|were|do|does|did|can|could|would|will|shall|should)\b`) +) + +func heuristicRU(s string) string { + if reQuestionRU.MatchString(s) && !strings.HasSuffix(s, "?") { + return s + "?" + } + if !hasTerminalPunct(s) && len(strings.Fields(s)) <= 24 { + return s + "." + } + return s +} + +func heuristicEN(s string) string { + lower := strings.ToLower(s) + if reQuestionEN.MatchString(lower) && !strings.HasSuffix(s, "?") { + return s + "?" + } + if !hasTerminalPunct(s) && len(strings.Fields(s)) <= 24 { + return s + "." + } + return s +} + +func hasTerminalPunct(s string) bool { + s = strings.TrimSpace(s) + if s == "" { + return false + } + r, _ := utf8.DecodeLastRuneInString(s) + return r == '.' || r == '?' || r == '!' || r == '…' +} + +func ensureTerminalPunct(s string) string { + if hasTerminalPunct(s) { + return s + } + return s + "." +} diff --git a/punctuation/heuristic_test.go b/punctuation/heuristic_test.go new file mode 100644 index 0000000..b97fd8c --- /dev/null +++ b/punctuation/heuristic_test.go @@ -0,0 +1,32 @@ +package punctuation + +import ( + "context" + "testing" +) + +func TestHeuristicRU_question(t *testing.T) { + h := Heuristic{} + out, err := h.Restore(context.Background(), "как дела", "ru") + if err != nil { + t.Fatal(err) + } + if !stringsHasSuffix(out, "?") { + t.Fatalf("expected question mark, got %q", out) + } +} + +func TestHeuristicEN_period(t *testing.T) { + h := Heuristic{} + out, err := h.Restore(context.Background(), "hello world", "en") + if err != nil { + t.Fatal(err) + } + if !stringsHasSuffix(out, ".") { + t.Fatalf("expected period, got %q", out) + } +} + +func stringsHasSuffix(s, suffix string) bool { + return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix +} diff --git a/punctuation/internal/spwrap/sp.go b/punctuation/internal/spwrap/sp.go new file mode 100644 index 0000000..16a64b8 --- /dev/null +++ b/punctuation/internal/spwrap/sp.go @@ -0,0 +1,92 @@ +//go:build xlm + +package spwrap + +/* +#cgo CXXFLAGS: -std=c++17 +#cgo LDFLAGS: -lsentencepiece +#include +#include "sp_wrap.h" +*/ +import "C" +import ( + "fmt" + "unsafe" +) + +type Processor struct { + p *C.SPProcessor +} + +func Load(path string) (*Processor, error) { + cpath := C.CString(path) + defer C.free(unsafe.Pointer(cpath)) + var errMsg *C.char + p := C.sp_load(cpath, &errMsg) + if p == nil { + if errMsg != nil { + defer C.free(unsafe.Pointer(errMsg)) + return nil, fmt.Errorf("sentencepiece: %s", C.GoString(errMsg)) + } + return nil, fmt.Errorf("sentencepiece: failed to load %s", path) + } + return &Processor{p: p}, nil +} + +func (proc *Processor) Close() { + if proc.p != nil { + C.sp_free(proc.p) + proc.p = nil + } +} + +func (proc *Processor) BOSID() int { + return int(C.sp_bos_id(proc.p)) +} + +func (proc *Processor) EOSID() int { + return int(C.sp_eos_id(proc.p)) +} + +func (proc *Processor) PadID() int { + return int(C.sp_pad_id(proc.p)) +} + +func (proc *Processor) EncodeAsIDs(text string) ([]int, error) { + ctext := C.CString(text) + defer C.free(unsafe.Pointer(ctext)) + var ids *C.int + var n C.int + var errMsg *C.char + if C.sp_encode(proc.p, ctext, &ids, &n, &errMsg) == 0 { + if errMsg != nil { + defer C.free(unsafe.Pointer(errMsg)) + return nil, fmt.Errorf("sentencepiece encode: %s", C.GoString(errMsg)) + } + return nil, fmt.Errorf("sentencepiece encode failed") + } + if ids == nil || n == 0 { + return nil, nil + } + defer C.free(unsafe.Pointer(ids)) + out := make([]int, int(n)) + slice := unsafe.Slice(ids, int(n)) + for i := range out { + out[i] = int(slice[i]) + } + return out, nil +} + +func (proc *Processor) IDToPiece(id int) (string, error) { + var errMsg *C.char + piece := C.sp_id_to_piece(proc.p, C.int(id), &errMsg) + if piece == nil { + if errMsg != nil { + defer C.free(unsafe.Pointer(errMsg)) + return "", fmt.Errorf("sentencepiece id to piece: %s", C.GoString(errMsg)) + } + return "", fmt.Errorf("sentencepiece id to piece failed") + } + defer C.free(unsafe.Pointer(piece)) + return C.GoString(piece), nil +} diff --git a/punctuation/internal/spwrap/sp_wrap.cc b/punctuation/internal/spwrap/sp_wrap.cc new file mode 100644 index 0000000..f63ecb5 --- /dev/null +++ b/punctuation/internal/spwrap/sp_wrap.cc @@ -0,0 +1,88 @@ +#include "sp_wrap.h" + +#include + +#include +#include +#include +#include + +struct SPProcessor { + sentencepiece::SentencePieceProcessor proc; +}; + +static char *copy_err(const std::string &msg) { + char *out = static_cast(std::malloc(msg.size() + 1)); + if (out != nullptr) { + std::memcpy(out, msg.c_str(), msg.size() + 1); + } + return out; +} + +SPProcessor *sp_load(const char *path, char **err) { + if (err != nullptr) { + *err = nullptr; + } + auto *p = new SPProcessor(); + const auto status = p->proc.Load(path); + if (!status.ok()) { + if (err != nullptr) { + *err = copy_err(status.ToString()); + } + delete p; + return nullptr; + } + return p; +} + +void sp_free(SPProcessor *p) { delete p; } + +int sp_bos_id(const SPProcessor *p) { return p->proc.bos_id(); } + +int sp_eos_id(const SPProcessor *p) { return p->proc.eos_id(); } + +int sp_pad_id(const SPProcessor *p) { return p->proc.pad_id(); } + +int sp_encode(const SPProcessor *p, const char *text, int **out_ids, int *out_len, char **err) { + if (err != nullptr) { + *err = nullptr; + } + if (out_ids != nullptr) { + *out_ids = nullptr; + } + if (out_len != nullptr) { + *out_len = 0; + } + std::vector ids; + const auto status = p->proc.Encode(text, &ids); + if (!status.ok()) { + if (err != nullptr) { + *err = copy_err(status.ToString()); + } + return 0; + } + if (ids.empty()) { + return 1; + } + int *buf = static_cast(std::malloc(sizeof(int) * ids.size())); + if (buf == nullptr) { + if (err != nullptr) { + *err = copy_err("malloc failed"); + } + return 0; + } + for (size_t i = 0; i < ids.size(); i++) { + buf[i] = ids[i]; + } + *out_ids = buf; + *out_len = static_cast(ids.size()); + return 1; +} + +char *sp_id_to_piece(const SPProcessor *p, int id, char **err) { + if (err != nullptr) { + *err = nullptr; + } + const std::string piece = p->proc.IdToPiece(id); + return copy_err(piece); +} diff --git a/punctuation/internal/spwrap/sp_wrap.h b/punctuation/internal/spwrap/sp_wrap.h new file mode 100644 index 0000000..2695240 --- /dev/null +++ b/punctuation/internal/spwrap/sp_wrap.h @@ -0,0 +1,21 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct SPProcessor SPProcessor; + +SPProcessor *sp_load(const char *path, char **err); +void sp_free(SPProcessor *p); + +int sp_bos_id(const SPProcessor *p); +int sp_eos_id(const SPProcessor *p); +int sp_pad_id(const SPProcessor *p); + +int sp_encode(const SPProcessor *p, const char *text, int **out_ids, int *out_len, char **err); +char *sp_id_to_piece(const SPProcessor *p, int id, char **err); + +#ifdef __cplusplus +} +#endif diff --git a/punctuation/ort_env.go b/punctuation/ort_env.go new file mode 100644 index 0000000..ab0f97e --- /dev/null +++ b/punctuation/ort_env.go @@ -0,0 +1,177 @@ +//go:build xlm + +package punctuation + +import ( + "os" + "path/filepath" + "runtime" + "strings" + "sync" + + ort "github.com/yalue/onnxruntime_go" +) + +var ( + ortOnce sync.Once + ortErr error +) + +func ensureORT() error { + ortOnce.Do(func() { + if p := resolveONNXRuntimeLib(); p != "" { + ort.SetSharedLibraryPath(p) + } + ortErr = ort.InitializeEnvironment() + }) + return ortErr +} + +func resolveONNXRuntimeLib() string { + if p := strings.TrimSpace(os.Getenv("ONNXRUNTIME_SHARED_LIBRARY_PATH")); p != "" { + return p + } + for _, p := range onnxRuntimeCandidates() { + if st, err := os.Stat(p); err == nil && !st.IsDir() { + return p + } + } + return "" +} + +func onnxRuntimeCandidates() []string { + arch := sherpaLibArch() + ver := sherpaLinuxModuleVersion() + var out []string + for _, root := range goModCacheRoots() { + if ver != "" { + out = append(out, filepath.Join(root, + "github.com/k2-fsa/sherpa-onnx-go-linux@"+ver, + "lib", arch, "libonnxruntime.so")) + } + } + if exe, err := os.Executable(); err == nil { + exeDir := filepath.Dir(exe) + out = append(out, + filepath.Join(exeDir, "libonnxruntime.so"), + filepath.Join(exeDir, "lib", "libonnxruntime.so"), + filepath.Join(exeDir, "..", "lib", "libonnxruntime.so"), + ) + if modRoot := findModuleRoot(exeDir); modRoot != "" { + out = append(out, filepath.Join(modRoot, "lib", "libonnxruntime.so")) + } + } + return out +} + +func goModCacheRoots() []string { + var roots []string + seen := map[string]struct{}{} + add := func(p string) { + p = strings.TrimSpace(p) + if p == "" { + return + } + if _, ok := seen[p]; ok { + return + } + seen[p] = struct{}{} + roots = append(roots, p) + } + add(os.Getenv("GOMODCACHE")) + if gopath := os.Getenv("GOPATH"); gopath != "" { + for _, gp := range filepath.SplitList(gopath) { + add(filepath.Join(gp, "pkg", "mod")) + } + } + if home, err := os.UserHomeDir(); err == nil { + add(filepath.Join(home, "go", "pkg", "mod")) + } + return roots +} + +func findModuleRoot(start string) string { + dir := start + for i := 0; i < 8; i++ { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + if cwd, err := os.Getwd(); err == nil { + return findModuleRootFrom(cwd) + } + return "" +} + +func findModuleRootFrom(dir string) string { + for i := 0; i < 8; i++ { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + return "" +} + +func sherpaLinuxModuleVersion() string { + for _, dir := range []string{findModuleRootFrom(mustCwd()), ""} { + if dir == "" { + continue + } + if v := readSherpaVersion(filepath.Join(dir, "go.mod")); v != "" { + return v + } + } + if exe, err := os.Executable(); err == nil { + if root := findModuleRoot(filepath.Dir(exe)); root != "" { + return readSherpaVersion(filepath.Join(root, "go.mod")) + } + } + return "" +} + +func readSherpaVersion(goModPath string) string { + data, err := os.ReadFile(goModPath) + if err != nil { + return "" + } + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if !strings.Contains(line, "github.com/k2-fsa/sherpa-onnx-go-linux") { + continue + } + fields := strings.Fields(line) + if len(fields) >= 2 { + return fields[1] // e.g. v1.13.2 — must match pkg/mod path @v1.13.2 + } + } + return "" +} + +func sherpaLibArch() string { + switch runtime.GOARCH { + case "arm64": + return "aarch64-unknown-linux-gnu" + case "arm": + return "arm-unknown-linux-gnueabihf" + default: + return "x86_64-unknown-linux-gnu" + } +} + +func mustCwd() string { + cwd, err := os.Getwd() + if err != nil { + return "" + } + return cwd +} diff --git a/punctuation/punctuation.go b/punctuation/punctuation.go new file mode 100644 index 0000000..3a995fb --- /dev/null +++ b/punctuation/punctuation.go @@ -0,0 +1,154 @@ +package punctuation + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + + "go-whisper-api/config" +) + +type Restorer interface { + Active() bool + Restore(ctx context.Context, text, language string) (string, error) +} + +type Closer interface { + Close() +} + +func New(cfg config.Punctuation) (Restorer, error) { + cfg = cfg.WithDefaults() + if !cfg.Active() { + return Nop{}, nil + } + engine := strings.ToLower(strings.TrimSpace(cfg.Engine)) + switch engine { + case "heuristic": + return Heuristic{}, nil + case "sherpa", "sherpa-offline", "offline": + cfg.Engine = "sherpa" + return newSherpaRestorer(cfg) + case "sherpa-online", "online": + cfg.Engine = "sherpa-online" + return newSherpaRestorer(cfg) + case "xlm", "xlm-roberta", "roberta": + return newXLM(cfg) + case "http": + if cfg.HTTPURL == "" { + return nil, fmt.Errorf("punctuation.http_url is required when engine=http") + } + return HTTP{cfg: cfg}, nil + default: + return nil, fmt.Errorf("unsupported punctuation engine %q (use: off, heuristic, xlm, sherpa, sherpa-online, http)", engine) + } +} + +type Nop struct{} + +func (Nop) Active() bool { + return false +} + +func (Nop) Restore(ctx context.Context, text, language string) (string, error) { + return text, nil +} + +type HTTP struct { + cfg config.Punctuation + client *http.Client +} + +func (h HTTP) Active() bool { return true } + +func (h HTTP) Restore(ctx context.Context, text, language string) (string, error) { + body, err := json.Marshal(map[string]string{ + "text": text, + "language": language, + }) + if err != nil { + return "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.cfg.HTTPURL, bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + client := h.client + if client == nil { + client = &http.Client{Timeout: h.cfg.Timeout()} + } + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + raw, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + if resp.StatusCode >= 300 { + return "", fmt.Errorf("punctuation http %s: %s", resp.Status, strings.TrimSpace(string(raw))) + } + var out struct { + Text string `json:"text"` + } + if err := json.Unmarshal(raw, &out); err != nil { + return strings.TrimSpace(string(raw)), nil + } + if out.Text == "" { + return text, nil + } + return out.Text, nil +} + +func Apply(ctx context.Context, r Restorer, enabled bool, text, language string) (string, error) { + if !enabled || r == nil || !r.Active() { + return text, nil + } + text = strings.TrimSpace(text) + if text == "" { + return text, nil + } + return r.Restore(ctx, text, language) +} + +func Close(r Restorer) { + if c, ok := r.(Closer); ok { + c.Close() + } +} + +func AutoSelect(cfg config.Punctuation) (Restorer, error) { + cfg = cfg.WithDefaults() + if !cfg.Active() { + return Nop{}, nil + } + engine := strings.ToLower(cfg.Engine) + if engine == "heuristic" { + return Heuristic{}, nil + } + if engine == "http" { + return New(cfg) + } + if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" { + return newXLM(cfg) + } + if _, err := os.Stat(cfg.ModelPath()); err == nil { + cfg.Engine = engine + if engine == "sherpa" || engine == "sherpa-offline" || engine == "offline" || engine == "" { + cfg.Engine = "sherpa" + } + return newSherpaRestorer(cfg) + } + if engine == "sherpa" || engine == "sherpa-online" || engine == "online" { + return nil, fmt.Errorf("punctuation model not found at %s (run: make download-punctuation-model)", cfg.ModelPath()) + } + return Heuristic{}, nil +} diff --git a/punctuation/punctuation_test.go b/punctuation/punctuation_test.go new file mode 100644 index 0000000..6149fca --- /dev/null +++ b/punctuation/punctuation_test.go @@ -0,0 +1,40 @@ +package punctuation + +import ( + "context" + "testing" + + "go-whisper-api/config" +) + +func TestNop(t *testing.T) { + r, err := New(config.Punctuation{Engine: "off"}) + if err != nil { + t.Fatal(err) + } + out, err := Apply(context.Background(), r, true, "hello world", "en") + if err != nil { + t.Fatal(err) + } + if out != "hello world" { + t.Fatalf("got %q", out) + } +} + +func TestApply_disabled(t *testing.T) { + r := Heuristic{} + out, err := Apply(context.Background(), r, false, "hello", "en") + if err != nil || out != "hello" { + t.Fatalf("got %q err=%v", out, err) + } +} + +func TestNew_heuristic(t *testing.T) { + r, err := New(config.Punctuation{Enabled: true, Engine: "heuristic"}) + if err != nil { + t.Fatal(err) + } + if !r.Active() { + t.Fatal("expected active") + } +} diff --git a/punctuation/sherpa.go b/punctuation/sherpa.go new file mode 100644 index 0000000..ab5871b --- /dev/null +++ b/punctuation/sherpa.go @@ -0,0 +1,107 @@ +//go:build sherpa + +package punctuation + +import ( + "context" + "fmt" + "os" + "strings" + "sync" + + "go-whisper-api/config" + + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" +) + +type Sherpa struct { + offline *sherpa.OfflinePunctuation + online *sherpa.OnlinePunctuation + mu sync.Mutex +} + +func newSherpaRestorer(cfg config.Punctuation) (Restorer, error) { + return newSherpa(cfg) +} + +func newSherpa(cfg config.Punctuation) (*Sherpa, error) { + cfg = cfg.WithDefaults() + engine := strings.ToLower(cfg.Engine) + s := &Sherpa{} + switch engine { + case "sherpa-online", "online": + modelPath := cfg.ModelPath() + vocabPath := cfg.BpeVocabPath() + if _, err := os.Stat(modelPath); err != nil { + return nil, fmt.Errorf("sherpa online punctuation model %q: %w", modelPath, err) + } + if _, err := os.Stat(vocabPath); err != nil { + return nil, fmt.Errorf("sherpa bpe vocab %q: %w", vocabPath, err) + } + conf := sherpa.OnlinePunctuationConfig{} + conf.Model.CnnBilstm = modelPath + conf.Model.BpeVocab = vocabPath + conf.Model.NumThreads = cfg.NumThreads + conf.Model.Provider = "cpu" + s.online = sherpa.NewOnlinePunctuation(&conf) + if s.online == nil { + return nil, fmt.Errorf("failed to create sherpa online punctuation") + } + default: + modelPath := cfg.ModelPath() + if _, err := os.Stat(modelPath); err != nil { + return nil, fmt.Errorf("sherpa offline punctuation model %q: %w (run: make download-punctuation-model)", modelPath, err) + } + conf := sherpa.OfflinePunctuationConfig{} + conf.Model.CtTransformer = modelPath + conf.Model.NumThreads = cfg.NumThreads + conf.Model.Provider = "cpu" + s.offline = sherpa.NewOfflinePunctuation(&conf) + if s.offline == nil { + return nil, fmt.Errorf("failed to create sherpa offline punctuation") + } + } + return s, nil +} + +func (s *Sherpa) Active() bool { + return true +} + +func (s *Sherpa) Restore(ctx context.Context, text, language string) (string, error) { + _ = ctx + _ = language + text = strings.TrimSpace(text) + if text == "" { + return text, nil + } + s.mu.Lock() + defer s.mu.Unlock() + var out string + switch { + case s.offline != nil: + out = s.offline.AddPunct(text) + case s.online != nil: + out = s.online.AddPunct(text) + default: + return text, fmt.Errorf("sherpa punctuation not initialized") + } + out = strings.TrimSpace(out) + if out == "" { + return text, nil + } + return out, nil +} + +func (s *Sherpa) Close() { + s.mu.Lock() + defer s.mu.Unlock() + if s.offline != nil { + sherpa.DeleteOfflinePunc(s.offline) + s.offline = nil + } + if s.online != nil { + sherpa.DeleteOnlinePunctuation(s.online) + s.online = nil + } +} diff --git a/punctuation/sherpa_stub.go b/punctuation/sherpa_stub.go new file mode 100644 index 0000000..cc4f44a --- /dev/null +++ b/punctuation/sherpa_stub.go @@ -0,0 +1,13 @@ +//go:build !sherpa + +package punctuation + +import ( + "fmt" + + "go-whisper-api/config" +) + +func newSherpaRestorer(cfg config.Punctuation) (Restorer, error) { + return nil, fmt.Errorf("punctuation engine %q requires build tag sherpa (use: make build)", cfg.Engine) +} diff --git a/punctuation/xlm.go b/punctuation/xlm.go new file mode 100644 index 0000000..0bb4104 --- /dev/null +++ b/punctuation/xlm.go @@ -0,0 +1,319 @@ +//go:build xlm + +package punctuation + +import ( + "context" + "fmt" + "os" + "strings" + "unicode" + + "go-whisper-api/config" + "go-whisper-api/punctuation/internal/spwrap" + + ort "github.com/yalue/onnxruntime_go" +) + +type XLM struct { + cfg config.Punctuation + modelCfg xlmModelConfig + sp *spwrap.Processor + session *ort.DynamicAdvancedSession + inputName string + outputNames []string + joinSBD bool +} + +func newXLM(cfg config.Punctuation) (*XLM, error) { + if err := ensureORT(); err != nil { + return nil, fmt.Errorf("onnxruntime: %w (set ONNXRUNTIME_SHARED_LIBRARY_PATH or install sherpa-onnx libs)", err) + } + onnxPath := cfg.ModelPath() + if _, err := os.Stat(onnxPath); err != nil { + return nil, fmt.Errorf("xlm onnx model not found at %s (run: make download-xlm-punctuation-model)", onnxPath) + } + spPath := cfg.SPModelPath() + if _, err := os.Stat(spPath); err != nil { + return nil, fmt.Errorf("xlm sp.model not found at %s (run: make download-xlm-punctuation-model)", spPath) + } + cfgPath := cfg.XLMConfigPath() + modelCfg, err := loadXLMConfig(cfgPath) + if err != nil { + return nil, err + } + sp, err := spwrap.Load(spPath) + if err != nil { + return nil, err + } + inputs, outputs, err := ort.GetInputOutputInfo(onnxPath) + if err != nil { + sp.Close() + return nil, fmt.Errorf("xlm model io: %w", err) + } + if len(inputs) == 0 || len(outputs) < 4 { + sp.Close() + return nil, fmt.Errorf("xlm model: unexpected inputs/outputs") + } + inNames := make([]string, len(inputs)) + for i, in := range inputs { + inNames[i] = in.Name + } + outNames := make([]string, len(outputs)) + for i, out := range outputs { + outNames[i] = out.Name + } + session, err := ort.NewDynamicAdvancedSession(onnxPath, inNames, outNames, nil) + if err != nil { + sp.Close() + return nil, fmt.Errorf("xlm onnx session: %w", err) + } + return &XLM{ + cfg: cfg, + modelCfg: modelCfg, + sp: sp, + session: session, + inputName: inNames[0], + outputNames: outNames, + joinSBD: cfg.XLMJoinSentences(), + }, nil +} + +func (x *XLM) Active() bool { + return true +} + +func (x *XLM) Close() { + if x.session != nil { + _ = x.session.Destroy() + x.session = nil + } + if x.sp != nil { + x.sp.Close() + x.sp = nil + } +} + +func (x *XLM) Restore(ctx context.Context, text, language string) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + text = strings.TrimSpace(normalizeXLMSpaces(text)) + if text == "" { + return text, nil + } + + ids, err := x.sp.EncodeAsIDs(text) + if err != nil { + return "", err + } + full := make([]int, 0, len(ids)+2) + full = append(full, x.sp.BOSID()) + full = append(full, ids...) + full = append(full, x.sp.EOSID()) + maxLen := x.modelCfg.MaxLength + if maxLen <= 2 { + maxLen = 256 + } + if len(full) <= maxLen { + return x.inferIDs(full) + } + var parts []string + content := full[1 : len(full)-1] + step := maxLen - 2 + for start := 0; start < len(content); { + end := start + step + if end > len(content) { + end = len(content) + } + chunk := make([]int, 0, end-start+2) + chunk = append(chunk, x.sp.BOSID()) + chunk = append(chunk, content[start:end]...) + chunk = append(chunk, x.sp.EOSID()) + out, err := x.inferIDs(chunk) + if err != nil { + return "", err + } + if out != "" { + parts = append(parts, out) + } + if end >= len(content) { + break + } + start = end + } + return strings.TrimSpace(strings.Join(parts, " ")), nil +} + +func (x *XLM) inferIDs(inputIDs []int) (string, error) { + data := make([]int64, len(inputIDs)) + for i, id := range inputIDs { + data[i] = int64(id) + } + inputTensor, err := ort.NewTensor(ort.NewShape(1, int64(len(inputIDs))), data) + if err != nil { + return "", err + } + defer inputTensor.Destroy() + outputs := make([]ort.Value, len(x.outputNames)) + if err := x.session.Run([]ort.Value{inputTensor}, outputs); err != nil { + return "", err + } + defer destroyValues(outputs) + pre, err := int64Row(outputs[0], len(inputIDs)) + if err != nil { + return "", err + } + post, err := int64Row(outputs[1], len(inputIDs)) + if err != nil { + return "", err + } + cap, err := boolMatrix(outputs[2], len(inputIDs)) + if err != nil { + return "", err + } + sbd, err := boolRow(outputs[3], len(inputIDs)) + if err != nil { + return "", err + } + return decodeXLMSegment(x.sp, x.modelCfg, inputIDs, pre, post, cap, sbd, x.joinSBD) +} + +func destroyValues(vals []ort.Value) { + for _, v := range vals { + if v != nil { + _ = v.Destroy() + } + } +} + +func int64Row(v ort.Value, wantLen int) ([]int64, error) { + switch t := v.(type) { + case *ort.Tensor[int64]: + d := t.GetData() + if len(d) == wantLen { + return d, nil + } + if len(d) > wantLen { + return d[len(d)-wantLen:], nil + } + return nil, fmt.Errorf("int64 output short: %d < %d", len(d), wantLen) + case *ort.Tensor[int32]: + d := t.GetData() + if len(d) > wantLen { + d = d[len(d)-wantLen:] + } + out := make([]int64, wantLen) + for i := 0; i < wantLen && i < len(d); i++ { + out[i] = int64(d[i]) + } + return out, nil + default: + return nil, fmt.Errorf("unexpected int output type %T", v) + } +} + +func boolRow(v ort.Value, wantLen int) ([]bool, error) { + switch t := v.(type) { + case *ort.Tensor[bool]: + d := t.GetData() + if len(d) == wantLen { + return d, nil + } + if len(d) > wantLen { + return d[len(d)-wantLen:], nil + } + return nil, fmt.Errorf("bool output short") + case *ort.Tensor[float32]: + d := t.GetData() + out := make([]bool, wantLen) + for i := 0; i < wantLen && i < len(d); i++ { + out[i] = d[i] > 0.5 + } + return out, nil + default: + return nil, fmt.Errorf("unexpected bool output type %T", v) + } +} + +func boolMatrix(v ort.Value, seqLen int) ([][]bool, error) { + switch t := v.(type) { + case *ort.Tensor[bool]: + shape := t.GetShape() + d := t.GetData() + if len(shape) == 3 { + _, sl, width := shape[0], shape[1], shape[2] + out := make([][]bool, sl) + for i := 0; i < int(sl); i++ { + row := make([]bool, width) + base := int(i) * int(width) + copy(row, d[base:base+int(width)]) + out[i] = row + } + return out, nil + } + width := len(d) / seqLen + if width < 1 { + width = 1 + } + out := make([][]bool, seqLen) + for i := 0; i < seqLen; i++ { + row := make([]bool, width) + base := i * width + if base+width <= len(d) { + copy(row, d[base:base+width]) + } + out[i] = row + } + return out, nil + case *ort.Tensor[float32]: + shape := t.GetShape() + d := t.GetData() + if len(shape) == 3 { + _, sl, width := shape[0], shape[1], shape[2] + out := make([][]bool, sl) + for i := 0; i < int(sl); i++ { + row := make([]bool, width) + base := int(i) * int(width) + for j := 0; j < int(width); j++ { + row[j] = d[base+j] > 0.5 + } + out[i] = row + } + return out, nil + } + width := len(d) / seqLen + if width < 1 { + width = 1 + } + out := make([][]bool, seqLen) + for i := 0; i < seqLen; i++ { + row := make([]bool, width) + base := i * width + for j := 0; j < width && base+j < len(d); j++ { + row[j] = d[base+j] > 0.5 + } + out[i] = row + } + return out, nil + default: + return nil, fmt.Errorf("unexpected cap output type %T", v) + } +} + +func normalizeXLMSpaces(s string) string { + var b strings.Builder + prevSpace := false + for _, r := range s { + if unicode.IsSpace(r) { + if !prevSpace { + b.WriteRune(' ') + prevSpace = true + } + continue + } + prevSpace = false + b.WriteRune(r) + } + return b.String() +} diff --git a/punctuation/xlm_config.go b/punctuation/xlm_config.go new file mode 100644 index 0000000..59e1fe1 --- /dev/null +++ b/punctuation/xlm_config.go @@ -0,0 +1,43 @@ +package punctuation + +import ( + "fmt" + "os" + "path/filepath" + + "gopkg.in/yaml.v3" +) + +type xlmModelConfig struct { + Languages []string `yaml:"languages"` + MaxLength int `yaml:"max_length"` + PreLabels []string `yaml:"pre_labels"` + PostLabels []string `yaml:"post_labels"` + NullToken string `yaml:"null_token"` + Acronym string `yaml:"acronym_token"` +} + +func loadXLMConfig(path string) (xlmModelConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return xlmModelConfig{}, err + } + var cfg xlmModelConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + return xlmModelConfig{}, fmt.Errorf("parse xlm config %s: %w", path, err) + } + if cfg.MaxLength <= 0 { + cfg.MaxLength = 256 + } + if cfg.NullToken == "" { + cfg.NullToken = "" + } + if cfg.Acronym == "" { + cfg.Acronym = "" + } + return cfg, nil +} + +func defaultXLMConfigPath(modelDir string) string { + return filepath.Join(modelDir, "config.yaml") +} diff --git a/punctuation/xlm_decode.go b/punctuation/xlm_decode.go new file mode 100644 index 0000000..d0b17ea --- /dev/null +++ b/punctuation/xlm_decode.go @@ -0,0 +1,97 @@ +//go:build xlm + +package punctuation + +import ( + "strings" + + "go-whisper-api/punctuation/internal/spwrap" +) + +func decodeXLMSegment( + sp *spwrap.Processor, + cfg xlmModelConfig, + inputIDs []int, + prePred, postPred []int64, + capPred [][]bool, + sbdPred []bool, + joinSentences bool, +) (string, error) { + var outputTexts []string + current := make([]string, 0, len(inputIDs)*4) + for tokenIdx := 1; tokenIdx < len(inputIDs)-1; tokenIdx++ { + piece, err := sp.IDToPiece(inputIDs[tokenIdx]) + if err != nil { + return "", err + } + if strings.HasPrefix(piece, "▁") && len(current) > 0 { + current = append(current, " ") + } + preLabel := labelAt(cfg.PreLabels, prePred, tokenIdx) + postLabel := labelAt(cfg.PostLabels, postPred, tokenIdx) + if preLabel != cfg.NullToken { + current = append(current, preLabel) + } + charStart := 0 + if strings.HasPrefix(piece, "▁") { + charStart = 1 + } + runes := []rune(piece) + for tokenCharIdx := charStart; tokenCharIdx < len(runes); tokenCharIdx++ { + ch := string(runes[tokenCharIdx]) + if capAt(capPred, tokenIdx, tokenCharIdx) { + ch = strings.ToUpper(ch) + } + current = append(current, ch) + if postLabel == cfg.Acronym { + current = append(current, ".") + } + } + if postLabel != cfg.NullToken && postLabel != cfg.Acronym { + current = append(current, postLabel) + } + if sbdAt(sbdPred, tokenIdx) { + outputTexts = append(outputTexts, strings.Join(current, "")) + current = current[:0] + } + } + if len(current) > 0 { + outputTexts = append(outputTexts, strings.Join(current, "")) + } + if len(outputTexts) == 0 { + return "", nil + } + if joinSentences { + return strings.Join(outputTexts, " "), nil + } + return outputTexts[0], nil +} + +func labelAt(labels []string, preds []int64, idx int) string { + if idx < 0 || idx >= len(preds) { + return labels[0] + } + pi := int(preds[idx]) + if pi < 0 || pi >= len(labels) { + return labels[0] + } + return labels[pi] +} + +func capAt(capPred [][]bool, tokenIdx, charIdx int) bool { + if tokenIdx < 0 || tokenIdx >= len(capPred) { + return false + } + row := capPred[tokenIdx] + if charIdx < 0 || charIdx >= len(row) { + return false + } + return row[charIdx] +} + +func sbdAt(sbd []bool, tokenIdx int) bool { + if tokenIdx < 0 || tokenIdx >= len(sbd) { + return false + } + return sbd[tokenIdx] +} diff --git a/punctuation/xlm_stub.go b/punctuation/xlm_stub.go new file mode 100644 index 0000000..d12a476 --- /dev/null +++ b/punctuation/xlm_stub.go @@ -0,0 +1,13 @@ +//go:build !xlm + +package punctuation + +import ( + "fmt" + + "go-whisper-api/config" +) + +func newXLM(cfg config.Punctuation) (Restorer, error) { + return nil, fmt.Errorf("punctuation engine %q requires build tag xlm (run: make build-xlm)", cfg.Engine) +} diff --git a/transcode/aac_decode.go b/transcode/aac_decode.go new file mode 100644 index 0000000..e06ce9d --- /dev/null +++ b/transcode/aac_decode.go @@ -0,0 +1,132 @@ +package transcode + +import ( + "fmt" + "io" + "math" + "os" + "path/filepath" + "strings" + + "github.com/olivier-w/climp-aac-decoder/aacfile" +) + +func decodeAACPath(path, ext string) ([]float64, int, int, error) { + f, err := os.Open(path) + if err != nil { + return nil, 0, 0, err + } + defer f.Close() + st, err := f.Stat() + if err != nil { + return nil, 0, 0, err + } + // aacfile picks the parser from the *name* extension, not file content (.mp4 → .m4a). + name := aacOpenName(path, ext) + size := st.Size() + r, err := aacfile.Open(f, size, name) + if err != nil && isMP4SampleDeltaError(err) { + return decodeMP4AACRelaxed(f, size) + } + if err != nil { + return nil, 0, 0, err + } + defer r.Close() + sr := r.SampleRate() + ch := r.ChannelCount() + pcm, err := io.ReadAll(r) + if err != nil { + return nil, 0, 0, fmt.Errorf("read aac pcm: %w", err) + } + samples := pcm16LEToFloat(pcm, ch) + if ch > 1 { + samples = interleavedToMono(samples, ch) + ch = 1 + } + return samples, sr, ch, nil +} + +// aacContainerExt maps file extensions to a container name understood by aacfile. +func aacContainerExt(ext string) string { + switch strings.ToLower(ext) { + case ".mp4", ".m4v", ".mov", ".3gp", ".3g2": + return ".m4a" + case ".aac", ".m4a", ".m4b": + return ext + default: + return ".m4a" + } +} + +func aacOpenName(path, ext string) string { + containerExt := aacContainerExt(ext) + if containerExt == "" { + containerExt = ".m4a" + } + base := filepath.Base(path) + if e := strings.ToLower(filepath.Ext(base)); e == containerExt { + return base + } + stem := strings.TrimSuffix(base, filepath.Ext(base)) + if stem == "" || stem == base { + stem = "audio" + } + return stem + containerExt +} + +func pcm16LEToFloat(pcm []byte, channels int) []float64 { + if channels <= 0 { + channels = 1 + } + frameBytes := 2 * channels + nFrames := len(pcm) / frameBytes + out := make([]float64, nFrames*channels) + for i := 0; i < nFrames*channels; i++ { + off := i * 2 + if off+1 >= len(pcm) { + break + } + v := int16(pcm[off]) | int16(pcm[off+1])<<8 + out[i] = float64(v) / 32768.0 + } + return out +} + +func interleavedToMono(samples []float64, channels int) []float64 { + if channels <= 1 { + return samples + } + nFrames := len(samples) / channels + out := make([]float64, nFrames) + for i := 0; i < nFrames; i++ { + var sum float64 + for c := 0; c < channels; c++ { + sum += samples[i*channels+c] + } + out[i] = sum / float64(channels) + } + return out +} + +func resampleLinear(samples []float64, fromRate, toRate int) []float64 { + if fromRate <= 0 || toRate <= 0 || fromRate == toRate || len(samples) == 0 { + return samples + } + ratio := float64(fromRate) / float64(toRate) + outLen := int(math.Round(float64(len(samples)) / ratio)) + if outLen < 1 { + outLen = 1 + } + out := make([]float64, outLen) + for i := 0; i < outLen; i++ { + src := float64(i) * ratio + j := int(src) + if j >= len(samples)-1 { + out[i] = samples[len(samples)-1] + continue + } + frac := src - float64(j) + out[i] = samples[j]*(1-frac) + samples[j+1]*frac + } + return out +} diff --git a/transcode/aac_decode_test.go b/transcode/aac_decode_test.go new file mode 100644 index 0000000..04f7167 --- /dev/null +++ b/transcode/aac_decode_test.go @@ -0,0 +1,31 @@ +package transcode + +import "testing" + +func TestAacOpenName_mp4(t *testing.T) { + got := aacOpenName("/tmp/cache/input.mp4", ".mp4") + if got != "input.m4a" { + t.Fatalf("got %q want input.m4a", got) + } +} + +func TestAacOpenName_noExt(t *testing.T) { + got := aacOpenName("/tmp/input", ".m4a") + if got != "audio.m4a" { + t.Fatalf("got %q", got) + } +} + +func TestAacContainerExt(t *testing.T) { + cases := map[string]string{ + ".mp4": ".m4a", + ".mov": ".m4a", + ".aac": ".aac", + ".m4a": ".m4a", + } + for in, want := range cases { + if got := aacContainerExt(in); got != want { + t.Fatalf("%s: got %q want %q", in, got, want) + } + } +} diff --git a/transcode/decode.go b/transcode/decode.go new file mode 100644 index 0000000..c6fc875 --- /dev/null +++ b/transcode/decode.go @@ -0,0 +1,127 @@ +package transcode + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/gopxl/beep" + "github.com/gopxl/beep/flac" + "github.com/gopxl/beep/mp3" + beepwav "github.com/gopxl/beep/wav" +) + +var probeFormats = []string{ + ".wav", ".wave", ".mp3", ".flac", ".ogg", ".opus", + ".m4a", ".m4b", ".mp4", ".mov", ".m4v", ".3gp", ".3g2", ".aac", +} + +func supportedFormatsMessage() string { + return strings.Join(probeFormats, ", ") +} + +func resolveFormat(path string) (string, error) { + ext := strings.ToLower(filepath.Ext(path)) + if ext != "" { + return ext, nil + } + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + ext = sniffFormat(f) + if ext == "" { + return "", fmt.Errorf("could not detect audio format (supported: %s)", supportedFormatsMessage()) + } + return ext, nil +} + +func openDecoder(path string) (beep.Streamer, beep.Format, io.Closer, error) { + ext, err := resolveFormat(path) + if err != nil { + return nil, beep.Format{}, nil, err + } + streamer, format, closer, err := decodeByExt(path, ext) + if err == nil { + return streamer, format, closer, nil + } + for _, try := range probeFormats { + if try == ext { + continue + } + streamer, format, closer, tryErr := decodeByExt(path, try) + if tryErr == nil { + return streamer, format, closer, nil + } + } + return nil, beep.Format{}, nil, fmt.Errorf("unsupported audio format %q (supported: %s): %w", ext, supportedFormatsMessage(), err) +} + +func decodeByExt(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) { + switch ext { + case ".wav", ".wave": + return decodeBeepFile(path, ext) + case ".mp3": + return decodeBeepFile(path, ext) + case ".flac": + return decodeBeepFile(path, ext) + case ".ogg", ".opus": + return decodeOggFile(path) + case ".m4a", ".m4b", ".mp4", ".mov", ".m4v", ".aac": + return decodeAACAsStreamer(path, ext) + case ".webm": + return nil, beep.Format{}, nil, fmt.Errorf("webm is not supported yet") + default: + return nil, beep.Format{}, nil, fmt.Errorf("unsupported extension %q", ext) + } +} + +func decodeBeepFile(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) { + f, err := os.Open(path) + if err != nil { + return nil, beep.Format{}, nil, err + } + var ( + streamer beep.StreamSeekCloser + format beep.Format + decErr error + ) + switch ext { + case ".wav", ".wave": + streamer, format, decErr = beepwav.Decode(f) + case ".mp3": + streamer, format, decErr = mp3.Decode(f) + case ".flac": + streamer, format, decErr = flac.Decode(f) + default: + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("internal: beep decode for %q", ext) + } + if decErr != nil { + f.Close() + return nil, beep.Format{}, nil, decErr + } + return streamer, format, f, nil +} + +func decodeAACAsStreamer(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) { + samples, sr, ch, err := decodeAACPath(path, ext) + if err != nil { + return nil, beep.Format{}, nil, err + } + if ch <= 0 { + ch = 1 + } + return newSamplesStreamer(samples, sr), beep.Format{ + SampleRate: beep.SampleRate(sr), + NumChannels: ch, + Precision: 2, + }, noopCloser{}, nil +} + +type noopCloser struct{} + +func (noopCloser) Close() error { return nil } diff --git a/transcode/engine.go b/transcode/engine.go new file mode 100644 index 0000000..df9adad --- /dev/null +++ b/transcode/engine.go @@ -0,0 +1,119 @@ +package transcode + +import ( + "context" + "fmt" + "os" +) + +// Engine converts input audio to PCM WAV using pure Go decoders (no ffmpeg). +type Engine struct{} + +// NewEngine creates a transcoder. The ffmpegPath argument is ignored (kept for config compatibility). +func NewEngine(_ string) *Engine { + return &Engine{} +} + +func (e *Engine) Available() error { + return nil +} + +func (e *Engine) Transcode(ctx context.Context, src, dst string, opt Options) error { + if err := opt.Validate(); err != nil { + return err + } + spec, err := ResolveFormat(opt.Format) + if err != nil { + return err + } + dst, err = OutputPath(dst, spec.ID) + if err != nil { + return err + } + streamer, format, closer, err := openDecoder(src) + if err != nil { + return err + } + defer closer.Close() + s, format := buildPipeline(streamer, format, opt) + samples, err := drainSamples(ctx, s) + if err != nil { + return err + } + ch := format.NumChannels + if opt.Channels > 0 { + ch = opt.Channels + } + sr := int(format.SampleRate) + if opt.SampleRate > 0 { + sr = opt.SampleRate + } + if err := writePCM16WAV(dst, sr, ch, samples); err != nil { + return fmt.Errorf("write wav: %w", err) + } + return nil +} + +func Transcode(ctx context.Context, src, dst string, opt Options) error { + return NewEngine("").Transcode(ctx, src, dst, opt) +} + +func ToWhisperWAV(ctx context.Context, src, dst string) error { + return Transcode(ctx, src, dst, WhisperOptions()) +} + +// SupportedInputFormats lists file extensions decoded without external tools. +func SupportedInputFormats() []string { + return append([]string(nil), probeFormats...) +} + +func (e *Engine) Probe(ctx context.Context, path string) (MediaInfo, error) { + _ = ctx + ext, err := resolveFormat(path) + if err != nil { + return MediaInfo{}, err + } + streamer, format, closer, err := openDecoder(path) + if err != nil { + return MediaInfo{}, err + } + defer closer.Close() + info := MediaInfo{ + Format: ext, + Streams: []StreamInfo{{ + Index: 0, + Codec: extTrim(ext), + Type: "audio", + SampleRate: int(format.SampleRate), + Channels: format.NumChannels, + }}, + } + if st, err := os.Stat(path); err == nil { + info.BitRate = st.Size() * 8 + } + _ = streamer + return info, nil +} + +func extTrim(ext string) string { + if len(ext) > 0 && ext[0] == '.' { + return ext[1:] + } + return ext +} + +// MediaInfo describes decoded input (for optional diagnostics). +type MediaInfo struct { + Format string `json:"format"` + Duration float64 `json:"duration_seconds"` + BitRate int64 `json:"bit_rate"` + Streams []StreamInfo `json:"streams"` +} + +type StreamInfo struct { + Index int `json:"index"` + Codec string `json:"codec"` + Type string `json:"type"` + SampleRate int `json:"sample_rate,omitempty"` + Channels int `json:"channels,omitempty"` +} diff --git a/transcode/engine_test.go b/transcode/engine_test.go new file mode 100644 index 0000000..641056c --- /dev/null +++ b/transcode/engine_test.go @@ -0,0 +1,85 @@ +package transcode + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/go-audio/wav" +) + +func samplePath(t *testing.T, name string) string { + t.Helper() + p := filepath.Join("..", "third_party", "whisper.cpp", "samples", name) + if _, err := os.Stat(p); err != nil { + t.Skip("sample not found:", p) + } + return p +} + +func TestToWhisperWAV_mp3(t *testing.T) { + src := samplePath(t, "jfk.mp3") + dst := filepath.Join(t.TempDir(), "out.wav") + if err := ToWhisperWAV(context.Background(), src, dst); err != nil { + t.Fatal(err) + } + assertWhisperWAV(t, dst) +} + +func TestToWhisperWAV_wav(t *testing.T) { + src := samplePath(t, "jfk.wav") + dst := filepath.Join(t.TempDir(), "out.wav") + if err := ToWhisperWAV(context.Background(), src, dst); err != nil { + t.Fatal(err) + } + assertWhisperWAV(t, dst) +} + +func TestResolveFormat_noExtension_mp3(t *testing.T) { + src := samplePath(t, "jfk.mp3") + data, err := os.ReadFile(src) + if err != nil { + t.Fatal(err) + } + path := filepath.Join(t.TempDir(), "upload") + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatal(err) + } + ext, err := resolveFormat(path) + if err != nil || ext != ".mp3" { + t.Fatalf("ext=%q err=%v", ext, err) + } +} + +func TestEngine_Available(t *testing.T) { + if err := NewEngine("").Available(); err != nil { + t.Fatal(err) + } +} + +func assertWhisperWAV(t *testing.T, path string) { + t.Helper() + f, err := os.Open(path) + if err != nil { + t.Fatal(err) + } + defer f.Close() + dec := wav.NewDecoder(f) + if !dec.IsValidFile() { + t.Fatal("invalid wav") + } + buf, err := dec.FullPCMBuffer() + if err != nil { + t.Fatal(err) + } + if dec.SampleRate != 16000 { + t.Fatalf("sample rate %d", dec.SampleRate) + } + if dec.NumChans != 1 { + t.Fatalf("channels %d", dec.NumChans) + } + if len(buf.Data) == 0 { + t.Fatal("empty audio") + } +} diff --git a/transcode/format.go b/transcode/format.go new file mode 100644 index 0000000..fdc34bb --- /dev/null +++ b/transcode/format.go @@ -0,0 +1,47 @@ +package transcode + +import ( + "fmt" + "path/filepath" + "strings" +) + +const ( + FormatWAV = "wav" + FormatPCM = "pcm" +) + +type FormatSpec struct { + ID string + Extension string + Codec string +} + +var formats = map[string]FormatSpec{ + FormatWAV: {ID: FormatWAV, Extension: ".wav", Codec: "pcm_s16le"}, + FormatPCM: {ID: FormatPCM, Extension: ".wav", Codec: "pcm_s16le"}, +} + +func ResolveFormat(name string) (FormatSpec, error) { + name = strings.ToLower(strings.TrimSpace(name)) + if name == "" { + name = FormatWAV + } + spec, ok := formats[name] + if !ok { + return FormatSpec{}, fmt.Errorf("unsupported format %q (supported: wav)", name) + } + return spec, nil +} + +func OutputPath(dst, format string) (string, error) { + spec, err := ResolveFormat(format) + if err != nil { + return "", err + } + ext := filepath.Ext(dst) + if ext == "" { + return dst + spec.Extension, nil + } + return dst, nil +} diff --git a/transcode/mp4_aac_decode.go b/transcode/mp4_aac_decode.go new file mode 100644 index 0000000..ea39dfb --- /dev/null +++ b/transcode/mp4_aac_decode.go @@ -0,0 +1,288 @@ +package transcode + +import ( + "errors" + "fmt" + "io" + + "github.com/Eyevinn/mp4ff/mp4" + "github.com/olivier-w/climp-aac-decoder/aacfile" + aacdec "github.com/skrashevich/go-aac/pkg/decoder" +) + +type mp4AACSample struct { + offset int64 + size int +} + +func isMP4SampleDeltaError(err error) bool { + var uf *aacfile.UnsupportedFeatureError + if !errors.As(err, &uf) { + return false + } + return uf.Feature == "MP4 sample delta" || uf.Feature == "MP4 sample delta layout" +} + +// decodeMP4AACRelaxed demuxes MP4/M4A with mp4ff (ignoring stts sample deltas) and +// decodes raw AAC frames with go-aac. Used when climp-aac-decoder rejects stts +// entries whose delta is not exactly 1024 (common in ffmpeg/phone muxers). +func decodeMP4AACRelaxed(r io.ReaderAt, size int64) ([]float64, int, int, error) { + asc, samples, leading, err := demuxMP4AAC(r, size) + if err != nil { + return nil, 0, 0, err + } + dec := aacdec.New() + if err := dec.SetASC(asc); err != nil { + return nil, 0, 0, fmt.Errorf("aac config: %w", err) + } + ch := dec.Config.ChanConfig + if ch < 1 { + return nil, 0, 0, fmt.Errorf("aac config: invalid channel count %d", ch) + } + sr := dec.Config.SampleRate + if sr <= 0 { + return nil, 0, 0, fmt.Errorf("aac config: invalid sample rate %d", sr) + } + + maxSize := 0 + for _, s := range samples { + if s.size > maxSize { + maxSize = s.size + } + } + buf := make([]byte, maxSize) + var pcm []float32 + for i, s := range samples { + if cap(buf) < s.size { + buf = make([]byte, s.size) + } + frame := buf[:s.size] + if _, err := r.ReadAt(frame, s.offset); err != nil { + return nil, 0, 0, fmt.Errorf("read mp4 aac sample %d: %w", i, err) + } + out, err := dec.DecodeFrame(frame) + if err != nil { + return nil, 0, 0, fmt.Errorf("decode mp4 aac sample %d: %w", i, err) + } + pcm = append(pcm, out...) + } + + skipSamples := leading * ch + if skipSamples > len(pcm) { + skipSamples = len(pcm) + } + pcm = pcm[skipSamples:] + + samplesF64 := make([]float64, len(pcm)) + for i, v := range pcm { + samplesF64[i] = float64(v) + } + if ch > 1 { + samplesF64 = float32InterleavedToMono(samplesF64, ch) + ch = 1 + } + return samplesF64, sr, ch, nil +} + +func float32InterleavedToMono(samples []float64, channels int) []float64 { + if channels <= 1 { + return samples + } + nFrames := len(samples) / channels + out := make([]float64, nFrames) + for i := 0; i < nFrames; i++ { + var sum float64 + for c := 0; c < channels; c++ { + sum += samples[i*channels+c] + } + out[i] = sum / float64(channels) + } + return out +} + +func demuxMP4AAC(r io.ReaderAt, size int64) (asc []byte, samples []mp4AACSample, leading int, err error) { + file, err := mp4.DecodeFile(io.NewSectionReader(r, 0, size), mp4.WithDecodeMode(mp4.DecModeLazyMdat)) + if err != nil { + return nil, nil, 0, fmt.Errorf("mp4 decode: %w", err) + } + if file.IsFragmented() { + return nil, nil, 0, fmt.Errorf("fragmented mp4 is not supported") + } + if file.Moov == nil { + return nil, nil, 0, fmt.Errorf("mp4: missing moov") + } + + var audioTracks []*mp4.TrakBox + for _, trak := range file.Moov.Traks { + if trak != nil && trak.Mdia != nil && trak.Mdia.Hdlr != nil && trak.Mdia.Hdlr.HandlerType == "soun" { + audioTracks = append(audioTracks, trak) + } + } + if len(audioTracks) != 1 { + return nil, nil, 0, fmt.Errorf("mp4: expected one audio track, found %d", len(audioTracks)) + } + trak := audioTracks[0] + if trak.Mdia == nil || trak.Mdia.Minf == nil || trak.Mdia.Minf.Stbl == nil || trak.Mdia.Minf.Stbl.Stsd == nil { + return nil, nil, 0, fmt.Errorf("mp4: incomplete audio track") + } + stsd := trak.Mdia.Minf.Stbl.Stsd + if len(stsd.Children) != 1 { + return nil, nil, 0, fmt.Errorf("mp4: multiple sample descriptions") + } + if stsd.Enca != nil { + return nil, nil, 0, fmt.Errorf("mp4: encrypted audio") + } + sampleEntry := stsd.Mp4a + if sampleEntry == nil { + return nil, nil, 0, fmt.Errorf("mp4: unsupported audio sample entry %s", stsd.Children[0].Type()) + } + if sampleEntry.Sinf != nil { + return nil, nil, 0, fmt.Errorf("mp4: encrypted audio") + } + if sampleEntry.Esds == nil || + sampleEntry.Esds.DecConfigDescriptor == nil || + sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo == nil || + len(sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo.DecConfig) == 0 { + return nil, nil, 0, fmt.Errorf("mp4: missing AudioSpecificConfig") + } + asc = append([]byte(nil), sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo.DecConfig...) + + leading, _ = mp4LeadingTrimRelaxed(trak) + + samples, err = buildMP4AACSamples(trak, size) + if err != nil { + return nil, nil, 0, err + } + if len(samples) == 0 { + return nil, nil, 0, fmt.Errorf("mp4: no audio samples") + } + return asc, samples, leading, nil +} + +func mp4LeadingTrimRelaxed(trak *mp4.TrakBox) (int, error) { + if trak.Edts == nil || len(trak.Edts.Elst) == 0 { + return 0, nil + } + if len(trak.Edts.Elst) != 1 || len(trak.Edts.Elst[0].Entries) != 1 { + return 0, nil + } + entry := trak.Edts.Elst[0].Entries[0] + if entry.MediaRateInteger != 1 || entry.MediaRateFraction != 0 { + return 0, nil + } + if entry.MediaTime < 0 { + return 0, nil + } + return int(entry.MediaTime), nil +} + +func buildMP4AACSamples(trak *mp4.TrakBox, size int64) ([]mp4AACSample, error) { + if trak.Mdia == nil || trak.Mdia.Minf == nil || trak.Mdia.Minf.Stbl == nil { + return nil, fmt.Errorf("mp4: incomplete sample table") + } + stbl := trak.Mdia.Minf.Stbl + if stbl.Stsc == nil || stbl.Stsz == nil { + return nil, fmt.Errorf("mp4: incomplete sample table") + } + if stbl.Stco == nil && stbl.Co64 == nil { + return nil, fmt.Errorf("mp4: missing chunk offsets") + } + if len(stbl.Stsc.Entries) == 0 { + return nil, fmt.Errorf("mp4: empty chunk map") + } + + totalSamples := int(trak.GetNrSamples()) + if totalSamples <= 0 { + return nil, fmt.Errorf("mp4: empty sample table") + } + + sampleSizes, err := mp4AACSampleSizes(stbl.Stsz, totalSamples) + if err != nil { + return nil, err + } + chunkOffsets, err := mp4AACChunkOffsets(stbl) + if err != nil { + return nil, err + } + + out := make([]mp4AACSample, 0, totalSamples) + sampleIndex := 0 + entryIndex := 0 + entry := stbl.Stsc.Entries[entryIndex] + + for chunkIndex := 0; chunkIndex < len(chunkOffsets) && sampleIndex < totalSamples; chunkIndex++ { + chunkNr := uint32(chunkIndex + 1) + for entryIndex+1 < len(stbl.Stsc.Entries) && chunkNr >= stbl.Stsc.Entries[entryIndex+1].FirstChunk { + entryIndex++ + entry = stbl.Stsc.Entries[entryIndex] + } + if entry.SamplesPerChunk == 0 { + return nil, fmt.Errorf("mp4: zero samples per chunk") + } + + offset := chunkOffsets[chunkIndex] + samplesPerChunk := int(entry.SamplesPerChunk) + for i := 0; i < samplesPerChunk && sampleIndex < totalSamples; i++ { + sampleSize := sampleSizes[sampleIndex] + end := offset + int64(sampleSize) + if offset < 0 || end < offset || end > size { + return nil, fmt.Errorf("mp4: invalid sample bounds at sample %d", sampleIndex+1) + } + out = append(out, mp4AACSample{offset: offset, size: sampleSize}) + offset = end + sampleIndex++ + } + } + if sampleIndex != totalSamples { + return nil, fmt.Errorf("mp4: sample table mismatch") + } + return out, nil +} + +func mp4AACSampleSizes(stsz *mp4.StszBox, totalSamples int) ([]int, error) { + if stsz == nil { + return nil, fmt.Errorf("mp4: missing sample sizes") + } + if int(stsz.GetNrSamples()) != totalSamples { + return nil, fmt.Errorf("mp4: sample size count mismatch") + } + sizes := make([]int, totalSamples) + if stsz.SampleUniformSize != 0 { + sz := int(stsz.SampleUniformSize) + for i := range sizes { + sizes[i] = sz + } + return sizes, nil + } + if len(stsz.SampleSize) != totalSamples { + return nil, fmt.Errorf("mp4: sample size table mismatch") + } + for i, sz := range stsz.SampleSize { + sizes[i] = int(sz) + } + return sizes, nil +} + +func mp4AACChunkOffsets(stbl *mp4.StblBox) ([]int64, error) { + switch { + case stbl == nil: + return nil, fmt.Errorf("mp4: incomplete sample table") + case stbl.Stco != nil: + offsets := make([]int64, len(stbl.Stco.ChunkOffset)) + for i, off := range stbl.Stco.ChunkOffset { + offsets[i] = int64(off) + } + return offsets, nil + case stbl.Co64 != nil: + offsets := make([]int64, len(stbl.Co64.ChunkOffset)) + for i, off := range stbl.Co64.ChunkOffset { + if off > uint64(^uint64(0)>>1) { + return nil, fmt.Errorf("mp4: invalid chunk offset") + } + offsets[i] = int64(off) + } + return offsets, nil + default: + return nil, fmt.Errorf("mp4: missing chunk offsets") + } +} diff --git a/transcode/ogg_decode.go b/transcode/ogg_decode.go new file mode 100644 index 0000000..a60cbe1 --- /dev/null +++ b/transcode/ogg_decode.go @@ -0,0 +1,154 @@ +package transcode + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "strings" + + "github.com/gopxl/beep" + "github.com/gopxl/beep/vorbis" + "github.com/pion/opus" + "github.com/pion/opus/pkg/oggreader" +) + +func decodeOggFile(path string) (beep.Streamer, beep.Format, io.Closer, error) { + switch sniffOggCodec(path) { + case "opus": + return decodeOggOpus(path) + case "vorbis": + return decodeOggVorbis(path) + default: + streamer, format, closer, err := decodeOggVorbis(path) + if err == nil { + return streamer, format, closer, nil + } + if isVorbisInvalidHeader(err) { + return decodeOggOpus(path) + } + return nil, beep.Format{}, nil, err + } +} + +func sniffOggCodec(path string) string { + f, err := os.Open(path) + if err != nil { + return "" + } + defer f.Close() + buf := make([]byte, 8192) + n, _ := io.ReadFull(f, buf) + buf = buf[:n] + if len(buf) < 4 || !bytes.HasPrefix(buf, []byte("OggS")) { + return "" + } + if bytes.Contains(buf, []byte("OpusHead")) { + return "opus" + } + // Vorbis ID packet: 0x01 + "vorbis" + if bytes.Contains(buf, []byte{0x01, 'v', 'o', 'r', 'b', 'i', 's'}) { + return "vorbis" + } + return "" +} + +func isVorbisInvalidHeader(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "invalid header") || strings.Contains(msg, "vorbis:") +} + +func decodeOggVorbis(path string) (beep.Streamer, beep.Format, io.Closer, error) { + f, err := os.Open(path) + if err != nil { + return nil, beep.Format{}, nil, err + } + streamer, format, decErr := vorbis.Decode(f) + if decErr != nil { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/vorbis: %w", decErr) + } + return streamer, format, f, nil +} + +func decodeOggOpus(path string) (beep.Streamer, beep.Format, io.Closer, error) { + f, err := os.Open(path) + if err != nil { + return nil, beep.Format{}, nil, err + } + ogg, header, err := oggreader.NewWith(f) + if err != nil { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus: %w", err) + } + + sr := int(header.SampleRate) + if sr <= 0 { + sr = 48000 + } + ch := int(header.Channels) + if ch <= 0 { + ch = 1 + } + + dec, err := opus.NewDecoderWithOutput(sr, ch) + if err != nil { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus decoder: %w", err) + } + + const maxFrameSamples = 5760 + pcmBuf := make([]float32, maxFrameSamples*ch) + var samples []float64 + + for { + pkt, _, err := ogg.ParseNextPacket() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus read: %w", err) + } + if len(pkt) == 0 || bytes.HasPrefix(pkt, []byte("OpusHead")) || bytes.HasPrefix(pkt, []byte("OpusTags")) { + continue + } + + n, err := dec.DecodeToFloat32(pkt, pcmBuf) + if err != nil { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus decode: %w", err) + } + if n <= 0 { + continue + } + total := n * ch + if total > len(pcmBuf) { + total = len(pcmBuf) + } + for i := 0; i < total; i++ { + samples = append(samples, float64(pcmBuf[i])) + } + } + + if len(samples) == 0 { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus: no audio samples") + } + + outCh := ch + if outCh > 1 { + samples = interleavedToMono(samples, outCh) + outCh = 1 + } + + return newSamplesStreamer(samples, sr), beep.Format{ + SampleRate: beep.SampleRate(sr), + NumChannels: outCh, + Precision: 2, + }, f, nil +} diff --git a/transcode/ogg_decode_test.go b/transcode/ogg_decode_test.go new file mode 100644 index 0000000..a0c2b1b --- /dev/null +++ b/transcode/ogg_decode_test.go @@ -0,0 +1,52 @@ +package transcode + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestSniffOggCodec_opus(t *testing.T) { + path := pionTinyOggPath(t) + if got := sniffOggCodec(path); got != "opus" { + t.Fatalf("sniffOggCodec() = %q want opus", got) + } +} + +func TestDecodeOggOpus_pionTiny(t *testing.T) { + path := pionTinyOggPath(t) + streamer, format, closer, err := decodeOggOpus(path) + if err != nil { + t.Fatal(err) + } + defer closer.Close() + if format.SampleRate == 0 { + t.Fatal("zero sample rate") + } + buf := make([][2]float64, 4096) + n, ok := streamer.Stream(buf) + if !ok || n == 0 { + t.Fatal("expected pcm samples") + } +} + +func pionTinyOggPath(t *testing.T) string { + t.Helper() + _, file, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("runtime.Caller failed") + } + modRoot := filepath.Clean(filepath.Join(filepath.Dir(file), "..")) + cache := os.Getenv("GOMODCACHE") + if cache == "" { + home, _ := os.UserHomeDir() + cache = filepath.Join(home, "go", "pkg", "mod") + } + path := filepath.Join(cache, "github.com/pion/opus@v0.0.0-20260601214817-71d58474cec8/testdata/tiny.ogg") + if _, err := os.Stat(path); err != nil { + _ = modRoot + t.Skipf("pion testdata not in module cache: %v", err) + } + return path +} diff --git a/transcode/options.go b/transcode/options.go new file mode 100644 index 0000000..a1025ce --- /dev/null +++ b/transcode/options.go @@ -0,0 +1,43 @@ +package transcode + +import "fmt" + +type Options struct { + Format string + SampleRate int + Channels int + Codec string +} + +func WhisperOptions() Options { + return Options{ + Format: FormatWAV, + SampleRate: 16000, + Channels: 1, + Codec: "pcm_s16le", + } +} + +func (o *Options) ApplyDefaults() error { + spec, err := ResolveFormat(o.Format) + if err != nil { + return err + } + if o.Codec == "" { + o.Codec = spec.Codec + } + return nil +} + +func (o *Options) Validate() error { + if err := o.ApplyDefaults(); err != nil { + return err + } + if o.SampleRate < 0 || o.Channels < 0 { + return fmt.Errorf("sample_rate and channels must be >= 0") + } + if o.Channels > 8 { + return fmt.Errorf("channels must be <= 8") + } + return nil +} diff --git a/transcode/options_test.go b/transcode/options_test.go new file mode 100644 index 0000000..637c962 --- /dev/null +++ b/transcode/options_test.go @@ -0,0 +1,20 @@ +package transcode + +import "testing" + +func TestWhisperOptions(t *testing.T) { + o := WhisperOptions() + if err := o.Validate(); err != nil { + t.Fatal(err) + } + if o.SampleRate != 16000 || o.Channels != 1 { + t.Fatalf("unexpected whisper opts: %+v", o) + } +} + +func TestResolveFormat_unknown(t *testing.T) { + _, err := ResolveFormat("xyz") + if err == nil { + t.Fatal("expected error") + } +} diff --git a/transcode/samples_stream.go b/transcode/samples_stream.go new file mode 100644 index 0000000..9a7a2c8 --- /dev/null +++ b/transcode/samples_stream.go @@ -0,0 +1,38 @@ +package transcode + +import "github.com/gopxl/beep" + +type samplesStreamer struct { + samples []float64 + pos int + sampleRate beep.SampleRate +} + +func newSamplesStreamer(samples []float64, sampleRate int) *samplesStreamer { + return &samplesStreamer{ + samples: samples, + sampleRate: beep.SampleRate(sampleRate), + } +} + +func (s *samplesStreamer) Stream(buf [][2]float64) (int, bool) { + if s.pos >= len(s.samples) { + return 0, false + } + n := 0 + for i := range buf { + if s.pos >= len(s.samples) { + return n, n > 0 + } + v := s.samples[s.pos] + buf[i][0] = v + buf[i][1] = v + s.pos++ + n++ + } + return n, true +} + +func (s *samplesStreamer) Err() error { + return nil +} diff --git a/transcode/sniff.go b/transcode/sniff.go new file mode 100644 index 0000000..d16e2c4 --- /dev/null +++ b/transcode/sniff.go @@ -0,0 +1,44 @@ +package transcode + +import ( + "bytes" + "io" +) + +// sniffFormat detects container/codec from file header when the path has no extension. +func sniffFormat(r io.Reader) string { + head := make([]byte, 32) + n, _ := io.ReadFull(r, head) + head = head[:n] + if len(head) < 4 { + return "" + } + if bytes.HasPrefix(head, []byte("RIFF")) && len(head) >= 12 && bytes.Equal(head[8:12], []byte("WAVE")) { + return ".wav" + } + if bytes.HasPrefix(head, []byte("ID3")) { + return ".mp3" + } + if len(head) >= 2 && head[0] == 0xFF && (head[1]&0xE0) == 0xE0 { + return ".mp3" + } + if bytes.HasPrefix(head, []byte("fLaC")) { + return ".flac" + } + if bytes.HasPrefix(head, []byte("OggS")) { + return ".ogg" + } + if len(head) >= 8 && bytes.Equal(head[4:8], []byte("ftyp")) { + return ".m4a" + } + if bytes.HasPrefix(head, []byte{0xFF, 0xF1}) || bytes.HasPrefix(head, []byte{0xFF, 0xF9}) { + return ".aac" + } + if bytes.HasPrefix(head, []byte("FORM")) && len(head) >= 12 && bytes.Equal(head[8:12], []byte("AIFF")) { + return ".aiff" + } + if bytes.HasPrefix(head, []byte{0x1A, 0x45, 0xDF, 0xA3}) { + return ".webm" + } + return "" +} diff --git a/transcode/stream.go b/transcode/stream.go new file mode 100644 index 0000000..870e95e --- /dev/null +++ b/transcode/stream.go @@ -0,0 +1,42 @@ +package transcode + +import ( + "context" + + "github.com/gopxl/beep" + "github.com/gopxl/beep/effects" +) + +func buildPipeline(streamer beep.Streamer, format beep.Format, opt Options) (beep.Streamer, beep.Format) { + out := streamer + if opt.SampleRate > 0 && format.SampleRate != beep.SampleRate(opt.SampleRate) { + out = beep.Resample(4, format.SampleRate, beep.SampleRate(opt.SampleRate), out) + format.SampleRate = beep.SampleRate(opt.SampleRate) + } + if opt.Channels == 1 { + out = effects.Mono(out) + format.NumChannels = 1 + } + return out, format +} + +func drainSamples(ctx context.Context, s beep.Streamer) ([]float64, error) { + buf := make([][2]float64, 4096) + var samples []float64 + for { + if err := ctx.Err(); err != nil { + return nil, err + } + n, ok := s.Stream(buf) + if !ok { + if err := s.Err(); err != nil { + return nil, err + } + break + } + for i := 0; i < n; i++ { + samples = append(samples, buf[i][0]) + } + } + return samples, nil +} diff --git a/transcode/wav_out.go b/transcode/wav_out.go new file mode 100644 index 0000000..ffe0021 --- /dev/null +++ b/transcode/wav_out.go @@ -0,0 +1,53 @@ +package transcode + +import ( + "os" + + "github.com/go-audio/audio" + "github.com/go-audio/wav" +) + +func writePCM16WAV(path string, sampleRate int, channels int, samples []float64) error { + if channels <= 0 { + channels = 1 + } + if err := os.MkdirAll(dirOf(path), 0o755); err != nil && dirOf(path) != "." { + return err + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + enc := wav.NewEncoder(f, sampleRate, 16, channels, 1) + data := make([]int, len(samples)) + for i, s := range samples { + data[i] = floatToInt16(s) + } + if err := enc.Write(&audio.IntBuffer{ + Format: &audio.Format{SampleRate: sampleRate, NumChannels: channels}, + Data: data, + }); err != nil { + return err + } + return enc.Close() +} + +func floatToInt16(f float64) int { + if f > 1 { + f = 1 + } + if f < -1 { + f = -1 + } + return int(f * 32767) +} + +func dirOf(path string) string { + for i := len(path) - 1; i >= 0; i-- { + if path[i] == '/' || path[i] == '\\' { + return path[:i] + } + } + return "." +} diff --git a/whisper/audio.go b/whisper/audio.go new file mode 100644 index 0000000..1cc9e50 --- /dev/null +++ b/whisper/audio.go @@ -0,0 +1,11 @@ +package whisper + +import ( + "context" + + "go-whisper-api/transcode" +) + +func AudioToWav(src, dst string) error { + return transcode.ToWhisperWAV(context.Background(), src, dst) +} diff --git a/whisper/audio_load.go b/whisper/audio_load.go new file mode 100644 index 0000000..be2aa4f --- /dev/null +++ b/whisper/audio_load.go @@ -0,0 +1,60 @@ +package whisper + +import ( + "fmt" + "os" + "path/filepath" + + "go-whisper-api/config" + + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/go-audio/wav" +) + +// LoadPCM16Mono reads 16 kHz mono WAV into float32 samples (for diarization). +func LoadPCM16Mono(path string) ([]float32, error) { + return loadPCM16Mono(path) +} + +func loadPCM16Mono(path string) ([]float32, error) { + fh, err := os.Open(path) + if err != nil { + return nil, err + } + defer fh.Close() + dec := wav.NewDecoder(fh) + buf, err := dec.FullPCMBuffer() + if err != nil { + return nil, err + } + if dec.SampleRate != whisper.SampleRate { + return nil, fmt.Errorf("unsupported sample rate: %d", dec.SampleRate) + } + if dec.NumChans != 1 { + return nil, fmt.Errorf("unsupported number of channels: %d", dec.NumChans) + } + return buf.AsFloat32Buffer().Data, nil +} + +func prepareAudioPCM(sourcePath string) (data []float32, cleanup func(), err error) { + cleanup = func() {} + if data, err = loadPCM16Mono(sourcePath); err == nil { + return data, cleanup, nil + } + dir, err := config.MkdirTemp("go-whisper-api-whisper-*") + if err != nil { + return nil, nil, err + } + cleanup = func() { os.RemoveAll(dir) } + converted := filepath.Join(dir, "converted.wav") + if err := AudioToWav(sourcePath, converted); err != nil { + cleanup() + return nil, nil, err + } + data, err = loadPCM16Mono(converted) + if err != nil { + cleanup() + return nil, nil, err + } + return data, cleanup, nil +} diff --git a/whisper/format.go b/whisper/format.go new file mode 100644 index 0000000..4dfd7f7 --- /dev/null +++ b/whisper/format.go @@ -0,0 +1,154 @@ +package whisper + +import ( + "fmt" + "strings" + "time" + + wpkg "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +// Turn is a speaker-active time range from diarization (seconds). +type Turn struct { + Start float32 + End float32 + Speaker int +} + +// FormatOptions controls joining Whisper segments into one string with newlines. +type FormatOptions struct { + PauseGap time.Duration + SpeakerLabel string + UseSpeakers bool +} + +// FormatSegments joins segment texts with \n on long pauses and optional speaker labels. +func FormatSegments(segments []wpkg.Segment, turns []Turn, opts FormatOptions) string { + if len(segments) == 0 { + return "" + } + if opts.PauseGap <= 0 { + opts.PauseGap = 1500 * time.Millisecond + } + label := strings.TrimSpace(opts.SpeakerLabel) + if label == "" { + label = "Спикер" + } + + lines := make([]segmentLine, len(segments)) + for i, seg := range segments { + lines[i] = segmentLine{ + Text: strings.TrimSpace(seg.Text), + Start: seg.Start, + End: seg.End, + Speaker: -1, + } + } + if opts.UseSpeakers && len(turns) > 0 { + assignSpeakers(lines, turns) + } + + var b strings.Builder + prevSpeaker := -2 + prevIdx := -1 + for i, line := range lines { + if line.Text == "" { + continue + } + if b.Len() > 0 && prevIdx >= 0 { + speakerBreak := opts.UseSpeakers && line.Speaker >= 0 && line.Speaker != prevSpeaker + pauseBreak := line.Start-lines[prevIdx].End >= opts.PauseGap + switch { + case speakerBreak: + b.WriteString("\n\n") + fmt.Fprintf(&b, "%s %d: ", label, line.Speaker+1) + case pauseBreak: + b.WriteString("\n") + default: + if !strings.HasSuffix(b.String(), " ") && !strings.HasSuffix(b.String(), "\n") { + b.WriteByte(' ') + } + } + } else if opts.UseSpeakers && line.Speaker >= 0 { + fmt.Fprintf(&b, "%s %d: ", label, line.Speaker+1) + } + b.WriteString(line.Text) + if line.Speaker >= 0 { + prevSpeaker = line.Speaker + } + prevIdx = i + } + return strings.TrimSpace(b.String()) +} + +type segmentLine struct { + Text string + Start time.Duration + End time.Duration + Speaker int +} + +func assignSpeakers(lines []segmentLine, turns []Turn) { + for i := range lines { + mid := lines[i].Start + (lines[i].End-lines[i].Start)/2 + lines[i].Speaker = speakerAt(mid, turns) + } +} + +func speakerAt(t time.Duration, turns []Turn) int { + sec := float32(t.Seconds()) + bestSpeaker := -1 + bestOverlap := float32(0) + for _, tr := range turns { + if sec >= tr.Start && sec < tr.End { + return tr.Speaker + } + overlap := intervalOverlap(sec, sec, tr.Start, tr.End) + if overlap > bestOverlap { + bestOverlap = overlap + bestSpeaker = tr.Speaker + } + } + return bestSpeaker +} + +func intervalOverlap(a0, a1, b0, b1 float32) float32 { + start := max32(a0, b0) + end := min32(a1, b1) + if end <= start { + return 0 + } + return end - start +} + +func max32(a, b float32) float32 { + if a > b { + return a + } + return b +} + +func min32(a, b float32) float32 { + if a < b { + return a + } + return b +} + +// PunctuateSegments runs punctuation on each segment separately (preserves line breaks). +func PunctuateSegments(segments []wpkg.Segment, restore func(text string) (string, error)) ([]wpkg.Segment, error) { + out := make([]wpkg.Segment, len(segments)) + copy(out, segments) + for i := range out { + t := strings.TrimSpace(out[i].Text) + if t == "" { + continue + } + p, err := restore(t) + if err != nil { + return nil, err + } + out[i].Text = " " + strings.TrimSpace(p) + } + return out, nil +} diff --git a/whisper/format_test.go b/whisper/format_test.go new file mode 100644 index 0000000..7490313 --- /dev/null +++ b/whisper/format_test.go @@ -0,0 +1,40 @@ +package whisper + +import ( + "testing" + "time" + + wpkg "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +func TestFormatSegments_pauseAndSpeaker(t *testing.T) { + segments := []wpkg.Segment{ + {Text: " привет", Start: 0, End: 2 * time.Second}, + {Text: " мир", Start: 4 * time.Second, End: 5 * time.Second}, + {Text: " ответ", Start: 6 * time.Second, End: 8 * time.Second}, + } + turns := []Turn{ + {Start: 0, End: 5.5, Speaker: 0}, + {Start: 5.5, End: 10, Speaker: 1}, + } + got := FormatSegments(segments, turns, FormatOptions{ + PauseGap: 1500 * time.Millisecond, + SpeakerLabel: "Спикер", + UseSpeakers: true, + }) + want := "Спикер 1: привет\nмир\n\nСпикер 2: ответ" + if got != want { + t.Fatalf("got %q\nwant %q", got, want) + } +} + +func TestFormatSegments_noSpeakers(t *testing.T) { + segments := []wpkg.Segment{ + {Text: "a", Start: 0, End: time.Second}, + {Text: "b", Start: 3 * time.Second, End: 4 * time.Second}, + } + got := FormatSegments(segments, nil, FormatOptions{PauseGap: time.Second, UseSpeakers: false}) + if got != "a\nb" { + t.Fatalf("got %q", got) + } +} diff --git a/whisper/helper.go b/whisper/helper.go new file mode 100644 index 0000000..8f98e6a --- /dev/null +++ b/whisper/helper.go @@ -0,0 +1,15 @@ +package whisper + +import ( + "fmt" + "time" +) + +func srtTimestamp(t time.Duration) string { + return fmt.Sprintf("%02d:%02d:%02d,%03d", + t/time.Hour, + (t%time.Hour)/time.Minute, + (t%time.Minute)/time.Second, + (t%time.Second)/time.Millisecond, + ) +} diff --git a/whisper/helper_test.go b/whisper/helper_test.go new file mode 100644 index 0000000..3f6b29d --- /dev/null +++ b/whisper/helper_test.go @@ -0,0 +1,39 @@ +package whisper + +import ( + "testing" + "time" +) + +func TestSrtTimestamp(t *testing.T) { + type args struct { + t time.Duration + } + tests := []struct { + name string + args args + want string + }{ + { + name: "test 1", + args: args{ + t: time.Duration(1*time.Hour + 2*time.Minute + 3*time.Second + 4*time.Millisecond), + }, + want: "01:02:03,004", + }, + { + name: "test 2", + args: args{ + t: time.Duration(10*time.Hour + 20*time.Minute + 30*time.Second + 40*time.Millisecond), + }, + want: "10:20:30,040", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := srtTimestamp(tt.args.t); got != tt.want { + t.Errorf("srtTimestamp() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/whisper/model_pool.go b/whisper/model_pool.go new file mode 100644 index 0000000..9a96da4 --- /dev/null +++ b/whisper/model_pool.go @@ -0,0 +1,89 @@ +package whisper + +import ( + "sync" + + "go-whisper-api/config" + + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +// Model is a loaded whisper.cpp weights file (re-export for API callers). +type Model = whisper.Model + +// ModelPool keeps whisper models loaded in memory and serializes inference per model path. +type ModelPool struct { + mu sync.Mutex + entries map[string]*pooledModel +} + +type pooledModel struct { + model whisper.Model + mu sync.Mutex +} + +func NewModelPool() *ModelPool { + return &ModelPool{entries: make(map[string]*pooledModel)} +} + +func (p *ModelPool) WithModel(path string, fn func(whisper.Model) error) error { + e, err := p.entry(path) + if err != nil { + return err + } + e.mu.Lock() + defer e.mu.Unlock() + return fn(e.model) +} + +func (p *ModelPool) entry(path string) (*pooledModel, error) { + p.mu.Lock() + defer p.mu.Unlock() + if e, ok := p.entries[path]; ok { + return e, nil + } + m, err := whisper.New(path) + if err != nil { + return nil, err + } + e := &pooledModel{model: m} + p.entries[path] = e + return e, nil +} + +func (p *ModelPool) Close() { + p.mu.Lock() + defer p.mu.Unlock() + for _, e := range p.entries { + _ = e.model.Close() + } + p.entries = make(map[string]*pooledModel) +} + +var defaultPool = NewModelPool() + +func DefaultPool() *ModelPool { + return defaultPool +} + +// Transcribe runs speech recognition using a cached model. +func Transcribe(cfg *config.Whisper) (TranscriptResult, error) { + return TranscribeWithPool(defaultPool, cfg, RunOptions{}) +} + +func TranscribeWithPool(pool *ModelPool, cfg *config.Whisper, opts RunOptions) (TranscriptResult, error) { + if pool == nil { + pool = defaultPool + } + if err := cfg.Validate(); err != nil { + return TranscriptResult{}, err + } + eng := &Engine{cfg: cfg, runOpts: opts} + err := pool.WithModel(cfg.Model, func(m whisper.Model) error { + return eng.transcribeWithModel(m) + }) + if err != nil { + return TranscriptResult{}, err + } + return eng.Result(), nil +} diff --git a/whisper/options.go b/whisper/options.go new file mode 100644 index 0000000..3da6037 --- /dev/null +++ b/whisper/options.go @@ -0,0 +1,8 @@ +package whisper + +// RunOptions affects transcription output formatting and optional diarization hints. +type RunOptions struct { + Format FormatOptions + Turns []Turn + PunctuateRestore func(text string) (string, error) +} diff --git a/whisper/vad.go b/whisper/vad.go new file mode 100644 index 0000000..39bff45 --- /dev/null +++ b/whisper/vad.go @@ -0,0 +1,43 @@ +package whisper + +import ( + "fmt" + "math" + + "go-whisper-api/config" + + pkg "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +func ApplyVAD(ctx pkg.Context, vad config.VAD) { + if !vad.Enabled { + return + } + vad = vad.WithDefaults() + ctx.SetVAD(true) + ctx.SetVADModelPath(vad.Model) + ctx.SetVADThreshold(float32(vad.Threshold)) + ctx.SetVADMinSpeechMs(vad.MinSpeechMs) + ctx.SetVADMinSilenceMs(vad.MinSilenceMs) + if vad.MaxSpeechSec > 0 { + ctx.SetVADMaxSpeechSec(float32(vad.MaxSpeechSec)) + } else { + ctx.SetVADMaxSpeechSec(float32(math.MaxFloat32)) + } + ctx.SetVADSpeechPadMs(vad.SpeechPadMs) + ctx.SetVADSamplesOverlap(float32(vad.SamplesOverlap)) +} + +func prepareVAD(vad *config.VAD, modelsDir string) error { + if vad == nil || !vad.Enabled { + return nil + } + if modelsDir != "" { + vad.Model = vad.ResolveModelPath(modelsDir) + } + *vad = vad.WithDefaults() + if err := vad.Validate(); err != nil { + return fmt.Errorf("vad: %w", err) + } + return nil +} diff --git a/whisper/whisper.go b/whisper/whisper.go new file mode 100644 index 0000000..24c7b9f --- /dev/null +++ b/whisper/whisper.go @@ -0,0 +1,245 @@ +package whisper + +import ( + "fmt" + "os" + "path" + "path/filepath" + "strings" + "time" + + "go-whisper-api/config" + + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/rs/zerolog/log" +) + +type OutputFormat string + +func (f OutputFormat) String() string { + return string(f) +} + +var ( + FormatTxt OutputFormat = "txt" + FormatSrt OutputFormat = "srt" + FormatCSV OutputFormat = "csv" +) + +type Engine struct { + cfg *config.Whisper + ctx whisper.Context + model whisper.Model + segments []whisper.Segment + progress int + runOpts RunOptions +} + +func (e *Engine) Transcript() error { + return defaultPool.WithModel(e.cfg.Model, func(m whisper.Model) error { + return e.transcribeWithModel(m) + }) +} + +func (e *Engine) transcribeWithModel(model whisper.Model) error { + data, cleanup, err := prepareAudioPCM(e.cfg.AudioPath) + if err != nil { + return err + } + defer cleanup() + e.model = model + e.ctx, err = e.model.NewContext() + if err != nil { + return err + } + e.ctx.SetThreads(e.cfg.Threads) + if e.cfg.SpeedUp { + e.ctx.SetAudioCtx(750) + } + e.ctx.SetTranslate(e.cfg.Translate) + if e.cfg.Prompt != "" { + e.ctx.SetInitialPrompt(e.cfg.Prompt) + } + e.ctx.SetMaxContext(int(e.cfg.MaxContext)) + if e.cfg.Debug { + log.Info().Msgf("%s", e.ctx.SystemInfo()) + } + if e.cfg.Language != "" { + _ = e.ctx.SetLanguage(e.cfg.Language) + } + if e.cfg.BeamSize > 0 { + e.ctx.SetBeamSize(int(e.cfg.BeamSize)) + } + if e.cfg.EntropyThold > 0 { + e.ctx.SetEntropyThold(float32(e.cfg.EntropyThold)) + } + if err := prepareVAD(&e.cfg.VAD, ""); err != nil { + return err + } + ApplyVAD(e.ctx, e.cfg.VAD) + log.Debug().Msg("start transcribe process") + e.ctx.ResetTimings() + if err := e.ctx.Process(data, e.cbEncoderBegin(), e.cbSegment(), e.cbProgress()); err != nil { + return err + } + if e.cfg.Debug { + e.ctx.PrintTimings() + } + return nil +} + +func (e *Engine) cbEncoderBegin() func() bool { + return func() bool { return true } +} + +func (e *Engine) cbSegment() func(segment whisper.Segment) { + return func(segment whisper.Segment) { + e.segments = append(e.segments, segment) + if !e.cfg.PrintSegment { + return + } + log.Info().Msgf( + "[%6s -> %6s] %s", + segment.Start.Truncate(time.Millisecond), + segment.End.Truncate(time.Millisecond), + segment.Text, + ) + } +} + +func (e *Engine) cbProgress() func(progress int) { + return func(progress int) { + if progress > 100 { + progress = 100 + } + if e.progress == progress { + return + } + e.progress = progress + if e.cfg.PrintProgress { + log.Info().Msgf("current progress: %d%%", progress) + } + } +} + +func (e *Engine) getOutputPath(format string) string { + ext := filepath.Ext(e.cfg.AudioPath) + filename := filepath.Base(e.cfg.AudioPath) + if e.cfg.OutputFilename != "" { + filename = e.cfg.OutputFilename + } + folder := filepath.Dir(e.cfg.AudioPath) + if e.cfg.OutputFolder != "" { + folder = e.cfg.OutputFolder + } + return path.Join(folder, strings.TrimSuffix(filename, ext)+"."+format) +} + +func (e *Engine) Save(format string) error { + outputPath := e.getOutputPath(format) + log.Info().Str("output-path", outputPath).Str("output-format", format).Msg("save text to file") + text := "" + switch OutputFormat(format) { + case FormatSrt: + for i, segment := range e.segments { + text += fmt.Sprintf("%d\n", i+1) + text += fmt.Sprintf("%s --> %s\n", srtTimestamp(segment.Start), srtTimestamp(segment.End)) + text += segment.Text + "\n\n" + + } + case FormatTxt: + for _, segment := range e.segments { + text += segment.Text + } + case FormatCSV: + text = "start,end,text\n" + for _, segment := range e.segments { + text += fmt.Sprintf("%s,%s,\"%s\"\n", segment.Start, segment.End, segment.Text) + } + } + if err := os.WriteFile(outputPath, []byte(text), 0o644); err != nil { + return err + } + return nil +} + +type Word struct { + Word string `json:"word"` + Start int `json:"start"` + Stop int `json:"stop"` +} + +type TranscriptResult struct { + Text string `json:"text"` + Words []Word `json:"words,omitempty"` +} + +func (e *Engine) SetTranscriptText(text string) { + if len(e.segments) == 0 { + e.segments = []whisper.Segment{{Text: text}} + return + } + start := e.segments[0].Start + end := e.segments[len(e.segments)-1].End + e.segments = []whisper.Segment{{Text: text, Start: start, End: end}} +} + +func (e *Engine) Result() TranscriptResult { + segments := e.segments + if e.runOpts.PunctuateRestore != nil { + updated, err := PunctuateSegments(segments, e.runOpts.PunctuateRestore) + if err == nil { + segments = updated + } + } + text := FormatSegments(segments, e.runOpts.Turns, e.runOpts.Format) + var words []Word + for _, segment := range segments { + words = append(words, segmentWords(segment)...) + } + return TranscriptResult{ + Text: text, + Words: words, + } +} + +func segmentWords(segment whisper.Segment) []Word { + parts := strings.Fields(strings.TrimSpace(segment.Text)) + if len(parts) == 0 { + return nil + } + startMs := int(segment.Start / time.Millisecond) + endMs := int(segment.End / time.Millisecond) + if endMs < startMs { + endMs = startMs + } + span := endMs - startMs + if span <= 0 { + span = 1 + } + step := span / len(parts) + if step < 1 { + step = 1 + } + out := make([]Word, 0, len(parts)) + for i, part := range parts { + wStart := startMs + i*step + wStop := wStart + step + if i == len(parts)-1 { + wStop = endMs + } + out = append(out, Word{ + Word: part, + Start: wStart, + Stop: wStop, + }) + } + return out +} + +func (e *Engine) Close() error { + // Models are owned by ModelPool; do not close shared weights here. + e.ctx = nil + e.model = nil + return nil +} diff --git a/whisper/whisper_test.go b/whisper/whisper_test.go new file mode 100644 index 0000000..a1078fb --- /dev/null +++ b/whisper/whisper_test.go @@ -0,0 +1,66 @@ +package whisper + +import ( + "testing" + + "go-whisper-api/config" + + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +func TestEngine_getOutputPath(t *testing.T) { + type fields struct { + cfg *config.Whisper + ctx whisper.Context + model whisper.Model + segments []whisper.Segment + } + type args struct { + format string + } + tests := []struct { + name string + fields fields + args args + want string + }{ + { + name: "change wav to txt", + fields: fields{ + cfg: &config.Whisper{ + AudioPath: "/test/1234/foo.wav", + }, + }, + args: args{ + format: "txt", + }, + want: "/test/1234/foo.txt", + }, + { + name: "change output folder", + fields: fields{ + cfg: &config.Whisper{ + AudioPath: "/test/1234/foo.wav", + OutputFolder: "/foo/bar", + }, + }, + args: args{ + format: "txt", + }, + want: "/foo/bar/foo.txt", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Engine{ + cfg: tt.fields.cfg, + ctx: tt.fields.ctx, + model: tt.fields.model, + segments: tt.fields.segments, + } + if got := e.getOutputPath(tt.args.format); got != tt.want { + t.Errorf("Engine.getOutputPath() = %v, want %v", got, tt.want) + } + }) + } +}