commit b5c083e06f523460c47a02776567e475bc90180e Author: admin Date: Thu Jun 4 18:10:52 2026 +0700 first commit diff --git a/.gitea/FUNDING.yml b/.gitea/FUNDING.yml new file mode 100644 index 0000000..df9ae63 --- /dev/null +++ b/.gitea/FUNDING.yml @@ -0,0 +1,13 @@ +# These are supported funding model platforms + +github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry +custom: ['https://www.paypal.me/appleboy46'] diff --git a/.gitea/dependabot.yml b/.gitea/dependabot.yml new file mode 100644 index 0000000..632e8eb --- /dev/null +++ b/.gitea/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: github-actions + directory: / + schedule: + interval: weekly + - package-ecosystem: gomod + directory: / + schedule: + interval: weekly diff --git a/.gitea/workflows/README.md b/.gitea/workflows/README.md new file mode 100644 index 0000000..bb54ab0 --- /dev/null +++ b/.gitea/workflows/README.md @@ -0,0 +1,27 @@ +# Gitea Actions + +Workflows use [Gitea Actions](https://docs.gitea.com/usage/actions/overview) (compatible with GitHub Actions syntax). + +## Secrets + +| Secret | Used by | Purpose | +|--------|---------|---------| +| `GITEA_TOKEN` | `docker.yml` | Push images to the instance container registry | + +Create a token with **write:package** (and **read:repository** for checkout if needed). + +## Images + +- **CI / CPU**: `docker/Dockerfile.ci` — built on every push to `main` and tags `v*` +- **GPU**: `docker/Dockerfile` — build manually on NVIDIA hosts + +Registry URL: `${{ gitea.server_url }}/${{ gitea.repository }}` + +## Local parity + +```sh +sudo apt-get install -y build-essential cmake git +go test ./config/... ./punctuation/... ./transcode/... +make dependency && make test +docker build -f docker/Dockerfile.ci -t go-whisper-api:ci . +``` diff --git a/.gitea/workflows/codeql.yml b/.gitea/workflows/codeql.yml new file mode 100644 index 0000000..c2ffa10 --- /dev/null +++ b/.gitea/workflows/codeql.yml @@ -0,0 +1,32 @@ +# CodeQL requires GitHub.com or a Gitea instance with the codeql-action available. +# Disable this workflow on self-hosted Gitea if the action cannot be resolved. + +name: CodeQL + +on: + push: + branches: [main] + pull_request: + branches: [main] + schedule: + - cron: "41 23 * * 6" + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + contents: read + security-events: write + strategy: + fail-fast: false + matrix: + language: ["go"] + steps: + - uses: actions/checkout@v4 + - uses: github/codeql-action/init@v3 + with: + languages: ${{ matrix.language }} + continue-on-error: true + - uses: github/codeql-action/analyze@v3 + continue-on-error: true diff --git a/.gitea/workflows/docker.yml b/.gitea/workflows/docker.yml new file mode 100644 index 0000000..2a70dd3 --- /dev/null +++ b/.gitea/workflows/docker.yml @@ -0,0 +1,60 @@ +name: Docker Image + +on: + push: + branches: + - main + tags: + - "v*" + pull_request: + branches: + - main + +env: + # Gitea container registry: // + IMAGE: ${{ gitea.server_url }}/${{ gitea.repository }} + +jobs: + build-docker: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Gitea Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + registry: ${{ gitea.server_url }} + username: ${{ gitea.actor }} + password: ${{ secrets.GITEA_TOKEN }} + + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + images: | + ${{ env.IMAGE }} + tags: | + type=raw,value=latest,enable={{is_default_branch}} + type=ref,event=branch + type=ref,event=pr + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + + - name: Build and push (CPU) + uses: docker/build-push-action@v6 + with: + context: . + file: docker/Dockerfile.ci + platforms: linux/amd64 + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.gitea/workflows/goreleaser.yml b/.gitea/workflows/goreleaser.yml new file mode 100644 index 0000000..ff3c09e --- /dev/null +++ b/.gitea/workflows/goreleaser.yml @@ -0,0 +1,43 @@ +name: Release + +on: + push: + tags: + - "v*" + +permissions: + contents: write + packages: write + +jobs: + goreleaser: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + + - name: Install build dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends build-essential cmake git + + - name: Build whisper dependency + run: make dependency + + - name: Run GoReleaser + uses: goreleaser/goreleaser-action@v6 + with: + distribution: goreleaser + version: latest + args: release --clean + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }} diff --git a/.gitea/workflows/lint.yml b/.gitea/workflows/lint.yml new file mode 100644 index 0000000..6eda352 --- /dev/null +++ b/.gitea/workflows/lint.yml @@ -0,0 +1,62 @@ +name: Lint and Testing + +on: + push: + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Hadolint Dockerfile (CI) + uses: hadolint/hadolint-action@v3.1.0 + with: + dockerfile: docker/Dockerfile.ci + + - name: Hadolint Dockerfile (GPU) + uses: hadolint/hadolint-action@v3.1.0 + with: + dockerfile: docker/Dockerfile + continue-on-error: true + + test: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + + - name: Install build dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + build-essential cmake git libsentencepiece-dev pkg-config + + - name: Unit tests (no CGO whisper) + run: go test ./config/... ./punctuation/... ./transcode/... -count=1 + + - name: Build with xlm punctuation and run full tests + run: make dependency && make test TAGS=xlm + + golangci: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: latest + args: --timeout=10m + continue-on-error: true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..08d335d --- /dev/null +++ b/.gitignore @@ -0,0 +1,44 @@ +# Go build artifacts +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.test +*.out +go.work + +# Binaries and native libs (local build) +/bin/ +bin/ +/lib/ +lib/ + +# Runtime data (never commit) +config.yaml +.env +.env.* +cache/ + +# Packaged releases +dist/ +*.tar.gz + +# Downloaded models and ONNX weights +models/**/*.bin +models/**/*.onnx +models/**/*.model +models/*.bin +models/*.onnx +models/*.model + +# whisper.cpp submodule and its build tree +third_party/ + +# IDE and OS +.idea/ +.vscode/ +*.swp +*~ +.DS_Store +Thumbs.db diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..70e82c9 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,3 @@ +run: + skip-dirs: + - third_party/whisper.cpp diff --git a/.goreleaser.yaml b/.goreleaser.yaml new file mode 100644 index 0000000..cbed619 --- /dev/null +++ b/.goreleaser.yaml @@ -0,0 +1,3 @@ +builds: +- skip: true + diff --git a/.hadolint.yaml b/.hadolint.yaml new file mode 100644 index 0000000..502b578 --- /dev/null +++ b/.hadolint.yaml @@ -0,0 +1,3 @@ +ignored: + - DL3018 + - DL3008 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2f70a03 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Bo-Yi Wu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8aa83ef --- /dev/null +++ b/Makefile @@ -0,0 +1,280 @@ +EXECUTABLE := go-whisper-api +GO ?= go +GOFILES := $(shell find . -name "*.go" -type f) +HAS_GO = $(shell hash $(GO) > /dev/null 2>&1 && echo "GO" || echo "NOGO" ) + +WHISPER_CPP := $(abspath third_party/whisper.cpp) +WHISPER_BUILD := $(WHISPER_CPP)/build +WHISPER_VENDOR := $(WHISPER_CPP)/bindings/go +WHISPER_LIBDIR := $(WHISPER_BUILD)/src:$(WHISPER_BUILD)/ggml/src + +RUNTIME_LIB_DIR := $(abspath lib) +# $ORIGIN/lib — binary next to lib/ (e.g. ./go-whisper-api + ./lib/) +# $ORIGIN/../lib — binary in bin/ (e.g. bin/go-whisper-api + lib/) +RUNTIME_RPATH := -Wl,-rpath,$$ORIGIN/lib:$$ORIGIN/../lib + +ifneq ($(shell uname), Darwin) + EXTLDFLAGS = -extldflags "$(RUNTIME_RPATH)" +else + EXTLDFLAGS = +endif + +ifeq ($(HAS_GO), GO) + GOPATH ?= $(shell $(GO) env GOPATH) + export PATH := $(GOPATH)/bin:$(PATH) + + CGO_EXTRA_CFLAGS := -DSQLITE_MAX_VARIABLE_NUMBER=32766 + CGO_CFLAGS ?= $(shell $(GO) env CGO_CFLAGS) $(CGO_EXTRA_CFLAGS) +endif + +ifeq ($(OS), Windows_NT) + GOFLAGS := -v -buildmode=exe + EXECUTABLE ?= $(EXECUTABLE).exe +else ifeq ($(OS), Windows) + GOFLAGS := -v -buildmode=exe + EXECUTABLE ?= $(EXECUTABLE).exe +else + GOFLAGS := -v + EXECUTABLE ?= $(EXECUTABLE) +endif + +ifneq ($(DRONE_TAG),) + VERSION ?= $(DRONE_TAG) +else + VERSION ?= $(shell git describe --tags --always || git rev-parse --short HEAD) +endif + +TAGS ?= +UNAME_M := $(shell uname -m) +ifeq ($(UNAME_M),x86_64) +SHERPA_LIBARCH := x86_64-unknown-linux-gnu +endif +ifeq ($(UNAME_M),aarch64) +SHERPA_LIBARCH := aarch64-unknown-linux-gnu +endif +SHERPA_LINUX_VER := $(shell awk '/sherpa-onnx-go-linux/ {print $$2; exit}' go.mod) +SHERPA_LIBDIR := $(GOPATH)/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-linux@$(SHERPA_LINUX_VER)/lib/$(SHERPA_LIBARCH) +ifneq ($(shell uname), Darwin) +EXTLDFLAGS_SHERPA = -extldflags "$(RUNTIME_RPATH)" +EXTLDFLAGS_XLM = -extldflags "$(RUNTIME_RPATH)" +else +EXTLDFLAGS_SHERPA = +EXTLDFLAGS_XLM = +endif +GOLDFLAGS ?= -X 'main.Version=$(VERSION)' + +INCLUDE_PATH := $(WHISPER_CPP)/include:$(WHISPER_CPP)/ggml/include:$(WHISPER_VENDOR):$(INCLUDE_PATH) +LIBRARY_PATH := $(WHISPER_LIBDIR):$(LIBRARY_PATH) +export LD_LIBRARY_PATH := $(WHISPER_LIBDIR):$(LD_LIBRARY_PATH) + +ifdef WHISPER_CUBLAS + CGO_CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include + CGO_CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include + EXTLDFLAGS = -extldflags "-lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib" + +build: $(EXECUTABLE) + +$(EXECUTABLE): $(GOFILES) + CGO_CXXFLAGS=${CGO_CXXFLAGS} CGO_CFLAGS=${CGO_CFLAGS} C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)' -o bin/$@ +endif + +MODEL_URL ?= https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin +MODEL_PATH ?= models/ggml-tiny.en.bin +VAD_MODEL ?= silero-v6.2.0 +VAD_MODEL_PATH ?= models/ggml-silero-v6.2.0.bin + +all: build + +PUNCT_MODEL_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8.tar.bz2 +PUNCT_MODEL_DIR ?= models/punctuation/ct-transformer-zh-en-int8 + +XLM_PUNCT_DIR ?= models/punctuation/xlm-roberta +XLM_HF_REPO ?= Salama1429/xlm-roberta_punctuation_fullstop_truecase + +XLM_MODEL_CONFIG_SRC ?= config/xlm-roberta-model.yaml +ORT_LIB_SRC ?= $(shell $(GO) env GOMODCACHE 2>/dev/null)/github.com/k2-fsa/sherpa-onnx-go-linux@$(SHERPA_LINUX_VER)/lib/$(SHERPA_LIBARCH)/libonnxruntime.so + +# Copy runtime .so into ./lib/. Binary rpath: $ORIGIN/lib or $ORIGIN/../lib (see RUNTIME_RPATH). +# Use cp -n where possible: existing root-owned libs in lib/ must not break the build. +install-runtime-libs: dependency + @mkdir -p "$(RUNTIME_LIB_DIR)" + @cp -an "$(WHISPER_BUILD)/src"/libwhisper.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true + @cp -an "$(WHISPER_BUILD)/ggml/src"/libggml*.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true + @echo "Whisper/ggml libs ready in $(RUNTIME_LIB_DIR)/" + +install-ort-lib: + @mkdir -p "$(RUNTIME_LIB_DIR)" + @if [ ! -f "$(ORT_LIB_SRC)" ]; then echo "missing $(ORT_LIB_SRC); run: go mod download"; exit 1; fi + @dest="$(RUNTIME_LIB_DIR)/libonnxruntime.so"; \ + if [ -f "$$dest" ] && cmp -s "$(ORT_LIB_SRC)" "$$dest"; then \ + echo "libonnxruntime.so already up to date in $(RUNTIME_LIB_DIR)/"; \ + elif cp -f "$(ORT_LIB_SRC)" "$$dest" 2>/dev/null; then \ + echo "Installed $$dest"; \ + elif [ -f "$$dest" ] && cmp -s "$(ORT_LIB_SRC)" "$$dest"; then \ + echo "libonnxruntime.so present in $(RUNTIME_LIB_DIR)/ (unchanged, not writable)"; \ + else \ + echo "cannot install libonnxruntime.so to $$dest"; \ + echo "fix: sudo chown -R $$USER:$$(id -gn) $(RUNTIME_LIB_DIR)"; \ + exit 1; \ + fi + +# XLM punctuation links -lsentencepiece; bundle .so for hosts without libsentencepiece0 package. +SP_LIB_DIRS := /usr/lib/x86_64-linux-gnu /usr/lib/aarch64-linux-gnu /usr/lib64 /usr/lib +install-sp-lib: + @mkdir -p "$(RUNTIME_LIB_DIR)" + @found=0; \ + for d in $(SP_LIB_DIRS); do \ + if [ -e "$$d/libsentencepiece.so.0" ] || [ -L "$$d/libsentencepiece.so.0" ]; then \ + cp -an "$$d"/libsentencepiece.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true; \ + found=1; \ + break; \ + fi; \ + done; \ + if [ "$$found" = "0" ]; then \ + echo "libsentencepiece.so.0 not found; install: sudo apt-get install libsentencepiece0"; \ + exit 1; \ + fi + @test -e "$(RUNTIME_LIB_DIR)/libsentencepiece.so.0" || (echo "missing $(RUNTIME_LIB_DIR)/libsentencepiece.so.0 after install-sp-lib"; exit 1) + @echo "Sentencepiece libs ready in $(RUNTIME_LIB_DIR)/" + +# If lib/*.so were created as root (e.g. manual cp with sudo), reclaim ownership for builds. +fix-lib-perms: + @if [ -d "$(RUNTIME_LIB_DIR)" ]; then \ + chown -R "$$USER:$$(id -gn)" "$(RUNTIME_LIB_DIR)" 2>/dev/null || \ + sudo chown -R "$$USER:$$(id -gn)" "$(RUNTIME_LIB_DIR)"; \ + echo "Ownership of $(RUNTIME_LIB_DIR)/ updated"; \ + fi + +install-runtime-libs-xlm: install-runtime-libs install-ort-lib install-sp-lib + +# Fail fast before deploy if ./lib is incomplete (lib/ is not in git: *.so is gitignored). +verify-runtime-libs-xlm: + @test -f bin/$(EXECUTABLE) || (echo "missing bin/$(EXECUTABLE); run: make build-xlm"; exit 1) + @test -f "$(RUNTIME_LIB_DIR)/libonnxruntime.so" || (echo "missing $(RUNTIME_LIB_DIR)/libonnxruntime.so; run: make install-runtime-libs-xlm"; exit 1) + @test -e "$(RUNTIME_LIB_DIR)/libsentencepiece.so.0" || (echo "missing $(RUNTIME_LIB_DIR)/libsentencepiece.so.0; run: make install-sp-lib"; exit 1) + @test -e "$(RUNTIME_LIB_DIR)/libwhisper.so.1" || (echo "missing $(RUNTIME_LIB_DIR)/libwhisper.so.1; run: make install-runtime-libs"; exit 1) + @echo "Runtime libs OK in $(RUNTIME_LIB_DIR)/" + +RUNTIME_TARBALL := dist/go-whisper-api-runtime-$(shell uname -m).tar.gz +package-runtime-xlm: verify-runtime-libs-xlm + @mkdir -p dist + tar -czf "$(RUNTIME_TARBALL)" bin/$(EXECUTABLE) lib + @echo "Created $(RUNTIME_TARBALL) — on prod: tar -xzf ... -C /opt/go-whisper-api (keeps bin/ and lib/)" + +# Copy bundled label config (needs write access to $(XLM_PUNCT_DIR); fix with: sudo chown -R $$USER models/punctuation) +install-xlm-punctuation-config: + @mkdir -p $(XLM_PUNCT_DIR) + @cp "$(XLM_MODEL_CONFIG_SRC)" "$(XLM_PUNCT_DIR)/config.yaml" + @echo "Installed $(XLM_PUNCT_DIR)/config.yaml" + +download-xlm-punctuation-model: install-xlm-punctuation-config + @mkdir -p $(XLM_PUNCT_DIR) + @for f in model.onnx sp.model; do \ + if [ ! -f "$(XLM_PUNCT_DIR)/$$f" ]; then \ + echo "Downloading $$f from $(XLM_HF_REPO)..."; \ + curl -fL "https://huggingface.co/$(XLM_HF_REPO)/resolve/main/$$f" -o "$(XLM_PUNCT_DIR)/$$f"; \ + else \ + echo "Already have $(XLM_PUNCT_DIR)/$$f"; \ + fi; \ + done + +DIAR_SEG_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +DIAR_EMB_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx +DIAR_DIR ?= models/diarization + +download-diarization-models: + @mkdir -p $(DIAR_DIR) + @if [ ! -f "$(DIAR_DIR)/pyannote-segmentation-3-0/model.onnx" ]; then \ + echo "Downloading speaker segmentation model..."; \ + curl -fL "$(DIAR_SEG_URL)" -o /tmp/diar-seg.tar.bz2; \ + tar -xjf /tmp/diar-seg.tar.bz2 -C $(DIAR_DIR); \ + rm -f /tmp/diar-seg.tar.bz2; \ + else \ + echo "Segmentation model present"; \ + fi + @if [ ! -f "$(DIAR_DIR)/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" ]; then \ + echo "Downloading speaker embedding model..."; \ + curl -fL "$(DIAR_EMB_URL)" -o "$(DIAR_DIR)/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"; \ + else \ + echo "Embedding model present"; \ + fi + +download-punctuation-model: + @mkdir -p models/punctuation + @if [ ! -f "$(PUNCT_MODEL_DIR)/model.int8.onnx" ]; then \ + echo "Downloading punctuation model..."; \ + curl -fL "$(PUNCT_MODEL_URL)" -o /tmp/punct-model.tar.bz2; \ + tar -xjf /tmp/punct-model.tar.bz2 -C models/punctuation; \ + rm -f /tmp/punct-model.tar.bz2; \ + if [ -d models/punctuation/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8 ]; then \ + mv models/punctuation/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8 "$(PUNCT_MODEL_DIR)"; \ + fi; \ + else \ + echo "Punctuation model already exists: $(PUNCT_MODEL_DIR)/model.int8.onnx"; \ + fi + +download-model: + @mkdir -p models + @if [ ! -f "$(MODEL_PATH)" ]; then \ + echo "Downloading $(MODEL_PATH)..."; \ + curl -fL "$(MODEL_URL)" -o "$(MODEL_PATH)"; \ + else \ + echo "Model already exists: $(MODEL_PATH)"; \ + fi + +download-vad-model: + @mkdir -p models + @if [ ! -f "$(VAD_MODEL_PATH)" ]; then \ + echo "Downloading VAD model $(VAD_MODEL) to models/..."; \ + ./third_party/whisper.cpp/models/download-vad-model.sh $(VAD_MODEL) models; \ + else \ + echo "VAD model already exists: $(VAD_MODEL_PATH)"; \ + fi + +clone: + @[ -d third_party/whisper.cpp ] || git clone https://github.com/appleboy/whisper.cpp.git third_party/whisper.cpp + +dependency: clone + @echo Build whisper + @if [ ! -f "$(WHISPER_BUILD)/src/libwhisper.so" ] && [ ! -f "$(WHISPER_BUILD)/src/libwhisper.a" ]; then \ + cmake -S "$(WHISPER_CPP)" -B "$(WHISPER_BUILD)" -DCMAKE_BUILD_TYPE=Release && \ + cmake --build "$(WHISPER_BUILD)" --config Release -j$$(nproc 2>/dev/null || echo 4); \ + else \ + echo "whisper library already built in $(WHISPER_BUILD)"; \ + fi + +test: + @C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) test -v -cover -coverprofile coverage.txt ./... && echo "\n==>\033[32m Ok\033[m\n" || exit 1 + +install: $(GOFILES) + C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) install -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)' + +build: install-runtime-libs $(EXECUTABLE) + +# Build with sherpa-onnx (punctuation + speaker diarization) +build-sherpa: + @$(MAKE) build TAGS=sherpa + +# XLM-RoBERTa punctuation (47 languages); requires libsentencepiece-dev +build-xlm: + @$(MAKE) install-runtime-libs-xlm + @$(MAKE) build TAGS=xlm + +$(EXECUTABLE): $(GOFILES) +ifneq (,$(findstring xlm,$(TAGS))) + C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=$(SHERPA_LIBDIR):${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS_XLM) -s -w $(GOLDFLAGS)' -o bin/$@ +else ifneq (,$(findstring sherpa,$(TAGS))) + C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=$(SHERPA_LIBDIR):${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS_SHERPA) -s -w $(GOLDFLAGS)' -o bin/$@ +else + C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)' -o bin/$@ +endif + +clean: + $(GO) clean -x -i ./... + rm -rf coverage.txt $(EXECUTABLE) $(DIST) bin/$(EXECUTABLE) + +clean-whisper: + rm -rf "$(WHISPER_BUILD)" + +version: + @echo $(VERSION) diff --git a/README.md b/README.md new file mode 100644 index 0000000..5df4305 --- /dev/null +++ b/README.md @@ -0,0 +1,129 @@ +# go-whisper-api + +HTTP-сервер распознавания речи на базе [whisper.cpp](https://github.com/ggerganov/whisper.cpp): совместимость с **SPR** (`/spr/*`) и **OpenAI Whisper API** (`/v1/*`) для Open WebUI и других клиентов. + +## Модели Whisper (ggml) + +См. [каталог моделей на Hugging Face](https://huggingface.co/ggerganov/whisper.cpp/tree/main). + +```sh +make download-model +# или вручную: +curl -LJ https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin \ + --output models/ggml-small.bin +``` + +Веса STT: `*.bin` в корне `api.models_dir` (список — `GET /spr/models`). Подкаталоги `models/vad/`, `models/punctuation/` и т.п. в список моделей не входят. + +## Конфигурация + +Шаблон в репозитории — `config.yaml.example`. Рабочий файл создайте локально (он в `.gitignore`): + +```sh +cp config.yaml.example config.yaml +``` + +Если `--config` не указан и в текущем каталоге есть `config.yaml`, он подгружается автоматически. + +Промежуточные файлы (загрузки, конвертация) — только в `/tmp`, удаляются после обработки. Форматы **wav, mp3, flac, ogg, m4a, mp4, aac** декодируются на чистом Go (без ffmpeg). + +## Запуск + +```sh +go-whisper-api serve --config config.yaml +# или с флагами (перекрывают YAML): +go-whisper-api --models-dir ./models --addr :8080 +``` + +Docker: + +```sh +docker run -p 8080:8080 \ + -v $PWD/models:/app/models \ + -v $PWD/config.yaml:/app/config.yaml \ + ghcr.io/appleboy/go-whisper-api:latest \ + serve --config /app/config.yaml +``` + +Swagger UI: [http://localhost:8080/](http://localhost:8080/) + +## VAD (детекция речи) + +```sh +make download-vad-model +``` + +```yaml +api: + models_dir: ./models + vad: + enabled: true + model: ggml-silero-v6.2.0.bin +``` + +Полезно на длинных записях с тишиной или музыкой в начале/конце. + +## Пунктуация + +| Уровень | Параметр | Эффект | +|---------|----------|--------| +| Мастер | `punctuation.enabled: false` | Пунктуация отключена | +| API по умолчанию | `api.default_punctuation: true` | Если нет `?punctuation=` | +| На запрос | `?punctuation=1` / `?punctuation=0` | Только при `enabled: true` | + +Движки: `heuristic`, `xlm`, `sherpa`, `sherpa-online`, `http`, `off`. + +```sh +curl -X POST "http://localhost:8080/spr/stt/ggml-small?punctuation=1" -F "wav=@audio.wav" +``` + +Сборка XLM: `make download-xlm-punctuation-model && make build-xlm`. Продакшен: `make package-runtime-xlm` — в git не коммитятся `lib/*.so` (см. `.gitignore`). + +Фильтр артефактов: `api.garbage` (по умолчанию `*выбая*`), отключить: `garbage: []`. + +## HTTP API + +- **SPR** (`/spr/*`) — очередь, waveform, импорт/экспорт моделей +- **OpenAI** (`/v1/*`) — синхронный STT + +Пример SPR: + +```sh +curl -X POST "http://localhost:8080/spr/stt/ggml-small?async=0" -F "wav=@audio.wav" +``` + +По умолчанию STT **асинхронный**: `POST /spr/stt/{id}` → `taskID`, опрос `GET /spr/queue/{taskID}`, результат `GET /spr/result/{taskID}`. Кэш: `./cache/waiting/` и `./cache/ready/`. + +Язык по умолчанию — **`ru`** (`?language=en`, `?language=auto`). + +| Endpoint | Метод | Описание | +|----------|--------|-------------| +| `/spr/models` | GET | Список моделей | +| `/spr/stt/{id}` | POST | Транскрипция | +| `/spr/result/{taskID}` | GET | Результат | +| `/spr/queue` | GET | Все задачи | +| `/spr/queue/{taskID}` | GET / DELETE | Статус / удаление | +| `/v1/audio/transcriptions` | POST | OpenAI-совместимая транскрипция | +| `/v1/models` | GET | Список моделей | + +## Open WebUI + +| Параметр | Значение | +|----------|----------| +| Engine | `OpenAI` | +| API Base URL | `http://:6183/v1` | +| API Key | любая непустая строка | +| STT Model | `whisper-1` (см. `api.default_model` в config) | + +```yaml +api: + default_model: ggml-large-v3-turbo +``` + +## CI/CD + +Workflows в `.gitea/workflows/` (`lint.yml`, `docker.yml`). Секрет `GITEA_TOKEN` с `write:package` для push образов. + +```sh +docker build -f docker/Dockerfile.ci -t go-whisper-api . +``` diff --git a/api/cache.go b/api/cache.go new file mode 100644 index 0000000..1c435e0 --- /dev/null +++ b/api/cache.go @@ -0,0 +1,362 @@ +package api + +import ( + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "go-whisper-api/whisper" +) + +const ( + cacheWaiting = "waiting" + cacheReady = "ready" + fileParams = "params.conf" + fileAudio = "audio.wav" + fileAudioJSON = "audio.json" +) + +type TaskParams struct { + ID string `json:"id"` + Created string `json:"created"` + Processed string `json:"processed,omitempty"` + Status string `json:"status"` + Model string `json:"model"` + Language string `json:"language,omitempty"` + Punctuation bool `json:"punctuation,omitempty"` + Speakers bool `json:"speakers,omitempty"` + NumClusters int `json:"num_clusters,omitempty"` + Text string `json:"text,omitempty"` + Words []whisper.Word `json:"words,omitempty"` + Error string `json:"error,omitempty"` +} + +type AudioJSON struct { + Waveform []float64 `json:"waveform"` + Buckets int `json:"buckets"` +} + +type DiskCache struct { + root string +} + +func (c *DiskCache) Root() string { + return c.root +} + +func resolveCacheRoot(root string) (string, error) { + if root == "" { + root = "./cache" + } + abs, err := filepath.Abs(root) + if err != nil { + return "", fmt.Errorf("cache dir: %w", err) + } + return filepath.Clean(abs), nil +} + +func NewDiskCache(root string) (*DiskCache, error) { + abs, err := resolveCacheRoot(root) + if err != nil { + return nil, err + } + c := &DiskCache{root: abs} + for _, sub := range []string{cacheWaiting, cacheReady} { + if err := os.MkdirAll(filepath.Join(abs, sub), 0o755); err != nil { + return nil, err + } + } + return c, nil +} + +func (c *DiskCache) waitingDir(id string) string { + return filepath.Join(c.root, cacheWaiting, id) +} + +func (c *DiskCache) readyDir(id string) string { + return filepath.Join(c.root, cacheReady, id) +} + +func (c *DiskCache) locate(id string) (dir, phase string, ok bool) { + if id == "" { + return "", "", false + } + ready := c.readyDir(id) + if st, err := os.Stat(ready); err == nil && st.IsDir() { + return ready, cacheReady, true + } + waiting := c.waitingDir(id) + if st, err := os.Stat(waiting); err == nil && st.IsDir() { + return waiting, cacheWaiting, true + } + return "", "", false +} + +func (c *DiskCache) Enqueue(id string, params TaskParams, audioWavPath string) error { + dir := c.waitingDir(id) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + params.ID = id + if err := c.writeParams(dir, params); err != nil { + _ = os.RemoveAll(dir) + return err + } + dst := filepath.Join(dir, fileAudio) + if err := os.Rename(audioWavPath, dst); err != nil { + if err2 := copyFile(audioWavPath, dst); err2 != nil { + _ = os.RemoveAll(dir) + return fmt.Errorf("move audio to %s: %w", dst, err2) + } + _ = os.Remove(audioWavPath) + } + return nil +} + +func (c *DiskCache) writeParams(dir string, params TaskParams) error { + data, err := json.Marshal(params) + if err != nil { + return err + } + return os.WriteFile(filepath.Join(dir, fileParams), data, 0o644) +} + +func (c *DiskCache) LoadParams(id string) (TaskParams, string, error) { + dir, phase, ok := c.locate(id) + if !ok { + return TaskParams{}, "", fmt.Errorf("task not found") + } + params, err := c.readParams(dir) + if err != nil { + return TaskParams{}, "", err + } + return params, phase, nil +} + +func (c *DiskCache) readParams(dir string) (TaskParams, error) { + path := filepath.Join(dir, fileParams) + data, err := os.ReadFile(path) + if err != nil { + return TaskParams{}, fmt.Errorf("read %s: %w", path, err) + } + var params TaskParams + if err := json.Unmarshal(data, ¶ms); err != nil { + return TaskParams{}, err + } + return params, nil +} + +func (c *DiskCache) List() (map[string]map[string]string, error) { + out := make(map[string]map[string]string) + for _, phase := range []string{cacheWaiting, cacheReady} { + base := filepath.Join(c.root, phase) + entries, err := os.ReadDir(base) + if err != nil { + if os.IsNotExist(err) { + continue + } + return nil, err + } + for _, e := range entries { + if !e.IsDir() { + continue + } + id := e.Name() + params, err := c.readParams(filepath.Join(base, id)) + if err != nil { + continue + } + out[id] = map[string]string{ + "created": params.Created, + "status": params.Status, + } + } + } + return out, nil +} + +func (c *DiskCache) NextWaiting() (string, bool, error) { + base := filepath.Join(c.root, cacheWaiting) + entries, err := os.ReadDir(base) + if err != nil { + if os.IsNotExist(err) { + return "", false, nil + } + return "", false, err + } + type item struct { + id string + created time.Time + } + var pending []item + for _, e := range entries { + if !e.IsDir() { + continue + } + params, err := c.readParams(filepath.Join(base, e.Name())) + if err != nil || params.Status != string(statusWaiting) { + continue + } + t, _ := time.ParseInLocation("2006-01-02 15:04:05", params.Created, time.Local) + pending = append(pending, item{id: e.Name(), created: t}) + } + if len(pending) == 0 { + return "", false, nil + } + sort.Slice(pending, func(i, j int) bool { + return pending[i].created.Before(pending[j].created) + }) + return pending[0].id, true, nil +} + +func (c *DiskCache) SetStatus(id string, status taskStatus, mutate func(*TaskParams)) error { + dir := c.waitingDir(id) + if _, err := os.Stat(dir); err != nil { + return fmt.Errorf("task %s not in waiting", id) + } + params, err := c.readParams(dir) + if err != nil { + return err + } + params.Status = string(status) + if mutate != nil { + mutate(¶ms) + } + return c.writeParams(dir, params) +} + +func (c *DiskCache) FinishWaiting(id string, result whisper.TranscriptResult, errMsg string, waveform []float64) error { + dir := c.waitingDir(id) + params, err := c.readParams(dir) + if err != nil { + return err + } + params.Processed = time.Now().Format("2006-01-02 15:04:05") + if errMsg != "" { + params.Status = string(statusError) + params.Error = errMsg + } else { + params.Status = string(statusReady) + params.Text = result.Text + params.Words = result.Words + } + if err := c.writeParams(dir, params); err != nil { + return fmt.Errorf("update %s: %w", filepath.Join(dir, fileParams), err) + } + if len(waveform) > 0 { + aj := AudioJSON{Waveform: waveform, Buckets: len(waveform)} + data, err := json.Marshal(aj) + if err != nil { + return err + } + if err := os.WriteFile(filepath.Join(dir, fileAudioJSON), data, 0o644); err != nil { + return err + } + } + return nil +} + +func (c *DiskCache) PromoteToReady(id string) error { + if _, phase, ok := c.locate(id); ok && phase == cacheReady { + return nil + } + src := c.waitingDir(id) + dst := c.readyDir(id) + if _, err := os.Stat(src); err != nil { + return fmt.Errorf("task %s not in waiting", id) + } + if _, err := os.Stat(dst); err == nil { + return os.RemoveAll(src) + } + return os.Rename(src, dst) +} + +func (c *DiskCache) Delete(id string) bool { + dir, _, ok := c.locate(id) + if !ok { + return false + } + _ = os.RemoveAll(dir) + return true +} + +func (c *DiskCache) AudioPath(id string) (string, bool) { + dir, _, ok := c.locate(id) + if !ok { + return "", false + } + p := filepath.Join(dir, fileAudio) + if _, err := os.Stat(p); err != nil { + return "", false + } + return p, true +} + +func (c *DiskCache) Waveform(id string) ([]float64, error) { + dir, _, ok := c.locate(id) + if !ok { + return nil, fmt.Errorf("task not found") + } + data, err := os.ReadFile(filepath.Join(dir, fileAudioJSON)) + if err == nil { + var aj AudioJSON + if json.Unmarshal(data, &aj) == nil && len(aj.Waveform) > 0 { + return aj.Waveform, nil + } + } + return waveformFromWav(filepath.Join(dir, fileAudio), 512) +} + +func (c *DiskCache) RecoverInterrupted() error { + base := filepath.Join(c.root, cacheWaiting) + entries, err := os.ReadDir(base) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + for _, e := range entries { + if !e.IsDir() { + continue + } + id := e.Name() + dir := filepath.Join(base, id) + params, err := c.readParams(dir) + if err != nil { + continue + } + switch params.Status { + case string(statusProcessing): + params.Status = string(statusWaiting) + _ = c.writeParams(dir, params) + case string(statusReady), string(statusError): + _ = c.PromoteToReady(id) + } + } + return nil +} + +func copyFile(src, dst string) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + _, err = io.Copy(out, in) + return err +} + +func isValidTaskID(id string) bool { + return id != "" && !strings.Contains(id, "..") && !strings.ContainsAny(id, `/\`) +} diff --git a/api/cache_test.go b/api/cache_test.go new file mode 100644 index 0000000..2860c81 --- /dev/null +++ b/api/cache_test.go @@ -0,0 +1,127 @@ +package api + +import ( + "os" + "path/filepath" + "testing" +) + +func TestDiskCache_params_promote_list(t *testing.T) { + root := t.TempDir() + c, err := NewDiskCache(root) + if err != nil { + t.Fatal(err) + } + + id := "319d72c7-301d-44fd-935f-3526dfb70f9f" + dir := c.waitingDir(id) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + params := TaskParams{ + ID: id, + Created: "2026-03-31 21:37:46", + Status: string(statusReady), + Model: "ggml-small", + Text: "test", + } + if err := c.writeParams(dir, params); err != nil { + t.Fatal(err) + } + + list, err := c.List() + if err != nil { + t.Fatal(err) + } + if list[id]["status"] != string(statusReady) { + t.Fatalf("list: %v", list[id]) + } + + if err := c.PromoteToReady(id); err != nil { + t.Fatal(err) + } + _, phase, err := c.LoadParams(id) + if err != nil || phase != cacheReady { + t.Fatalf("phase=%s err=%v", phase, err) + } + if _, err := os.Stat(filepath.Join(c.readyDir(id), fileParams)); err != nil { + t.Fatal(err) + } +} + +func TestIsValidTaskID(t *testing.T) { + if !isValidTaskID("319d72c7-301d-44fd-935f-3526dfb70f9f") { + t.Fatal("uuid should be valid") + } + if isValidTaskID("../etc") { + t.Fatal("path traversal must be rejected") + } +} + +func TestResolveCacheRoot_absolute(t *testing.T) { + cwd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + abs, err := resolveCacheRoot("./cache") + if err != nil { + t.Fatal(err) + } + want := filepath.Join(cwd, "cache") + if abs != want { + t.Fatalf("got %q want %q", abs, want) + } +} + +func TestDiskCache_RecoverInterrupted(t *testing.T) { + root := t.TempDir() + c, err := NewDiskCache(root) + if err != nil { + t.Fatal(err) + } + id := "task-1" + dir := c.waitingDir(id) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + if err := c.writeParams(dir, TaskParams{ + ID: id, Created: "2026-01-01 00:00:00", Status: string(statusProcessing), Model: "m", + }); err != nil { + t.Fatal(err) + } + if err := c.RecoverInterrupted(); err != nil { + t.Fatal(err) + } + p, _, err := c.LoadParams(id) + if err != nil { + t.Fatal(err) + } + if p.Status != string(statusWaiting) { + t.Fatalf("got %q", p.Status) + } +} + +func TestDiskCache_RecoverInterrupted_promotesReady(t *testing.T) { + root := t.TempDir() + c, err := NewDiskCache(root) + if err != nil { + t.Fatal(err) + } + id := "done-task" + dir := c.waitingDir(id) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + if err := c.writeParams(dir, TaskParams{ + ID: id, Created: "2026-01-01 00:00:00", Status: string(statusReady), Model: "m", Text: "hi", + }); err != nil { + t.Fatal(err) + } + if err := c.RecoverInterrupted(); err != nil { + t.Fatal(err) + } + _, phase, err := c.LoadParams(id) + if err != nil || phase != cacheReady { + t.Fatalf("phase=%s err=%v", phase, err) + } +} diff --git a/api/garbage.go b/api/garbage.go new file mode 100644 index 0000000..7ec8330 --- /dev/null +++ b/api/garbage.go @@ -0,0 +1,26 @@ +package api + +import ( + "go-whisper-api/garbage" + "go-whisper-api/whisper" +) + +func applyGarbage(r whisper.TranscriptResult, patterns []string) whisper.TranscriptResult { + if len(patterns) == 0 { + return r + } + r.Text = garbage.FilterText(r.Text, patterns) + if len(r.Words) == 0 { + return r + } + gw := make([]garbage.Word, len(r.Words)) + for i, w := range r.Words { + gw[i] = garbage.Word{Word: w.Word, Start: w.Start, Stop: w.Stop} + } + gw = garbage.FilterWords(gw, patterns) + r.Words = make([]whisper.Word, len(gw)) + for i, w := range gw { + r.Words[i] = whisper.Word{Word: w.Word, Start: w.Start, Stop: w.Stop} + } + return r +} diff --git a/api/models.go b/api/models.go new file mode 100644 index 0000000..7345254 --- /dev/null +++ b/api/models.go @@ -0,0 +1,181 @@ +package api + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + "sync" +) + +// Subdirectories under models_dir that hold non-Whisper assets (VAD, punctuation, etc.). +var reservedModelSubdirs = map[string]struct{}{ + "vad": {}, + "punctuation": {}, +} + +// Top-level .bin files that are not Whisper STT models. +var excludedWhisperModelFiles = map[string]struct{}{ + "vad.bin": {}, +} + +type Registry struct { + dir string + mu sync.RWMutex +} + +func NewRegistry(dir string) *Registry { + return &Registry{dir: dir} +} + +func isReservedModelSubdir(name string) bool { + _, ok := reservedModelSubdirs[strings.ToLower(name)] + return ok +} + +func isWhisperModelFile(name string) bool { + if !strings.HasSuffix(strings.ToLower(name), ".bin") { + return false + } + _, excluded := excludedWhisperModelFiles[strings.ToLower(name)] + return !excluded +} + +func (r *Registry) List() ([]string, error) { + r.mu.RLock() + defer r.mu.RUnlock() + entries, err := os.ReadDir(r.dir) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + var models []string + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if !isWhisperModelFile(name) { + continue + } + models = append(models, strings.TrimSuffix(name, filepath.Ext(name))) + } + return models, nil +} + +func (r *Registry) Path(id string) (string, error) { + if id == "" { + return "", fmt.Errorf("model id is required") + } + if strings.Contains(id, "/") || strings.Contains(id, "..") { + return "", fmt.Errorf("invalid model id") + } + if isReservedModelSubdir(id) { + return "", fmt.Errorf("model %q not found", id) + } + r.mu.RLock() + defer r.mu.RUnlock() + candidates := []string{ + filepath.Join(r.dir, id+".bin"), + filepath.Join(r.dir, id), + filepath.Join(r.dir, "ggml-"+id+".bin"), + } + for _, p := range candidates { + if st, err := os.Stat(p); err == nil && !st.IsDir() && isWhisperModelFile(filepath.Base(p)) { + return p, nil + } + } + return "", fmt.Errorf("model %q not found", id) +} + +func (r *Registry) Delete(id string) error { + p, err := r.Path(id) + if err != nil { + return err + } + r.mu.Lock() + defer r.mu.Unlock() + return os.Remove(p) +} + +func (r *Registry) Import(id string, src io.Reader) error { + if id == "" { + return fmt.Errorf("model id is required") + } + if strings.Contains(id, "/") || strings.Contains(id, "..") { + return fmt.Errorf("invalid model id") + } + if isReservedModelSubdir(id) { + return fmt.Errorf("invalid model id") + } + r.mu.Lock() + defer r.mu.Unlock() + if err := os.MkdirAll(r.dir, 0o755); err != nil { + return err + } + dst := filepath.Join(r.dir, id+".bin") + f, err := os.Create(dst) + if err != nil { + return err + } + defer f.Close() + + if _, err := io.Copy(f, src); err != nil { + os.Remove(dst) + return err + } + return nil +} + +func (r *Registry) Open(id string) (*os.File, error) { + p, err := r.Path(id) + if err != nil { + return nil, err + } + return os.Open(p) +} + +// Resolve maps an OpenAI-style model name to a local whisper model id. +func (r *Registry) Resolve(id, defaultModel string) (string, error) { + id = strings.TrimSpace(id) + if id == "" { + id = strings.TrimSpace(defaultModel) + } + if id == "" { + models, err := r.List() + if err != nil { + return "", err + } + if len(models) == 0 { + return "", fmt.Errorf("model is required") + } + return models[0], nil + } + if id == "whisper-1" { + if dm := strings.TrimSpace(defaultModel); dm != "" { + if _, err := r.Path(dm); err == nil { + return dm, nil + } + } + models, err := r.List() + if err != nil { + return "", err + } + if len(models) == 0 { + return "", fmt.Errorf("no whisper models installed") + } + return models[0], nil + } + if _, err := r.Path(id); err == nil { + return id, nil + } + if alt := strings.TrimPrefix(id, "whisper-"); alt != id { + if _, err := r.Path(alt); err == nil { + return alt, nil + } + } + return "", fmt.Errorf("model %q not found", id) +} diff --git a/api/models_test.go b/api/models_test.go new file mode 100644 index 0000000..21a65ad --- /dev/null +++ b/api/models_test.go @@ -0,0 +1,55 @@ +package api + +import ( + "os" + "path/filepath" + "testing" +) + +func TestRegistry_List_excludesAuxiliaryDirs(t *testing.T) { + root := t.TempDir() + for _, name := range []string{"common.bin", "vad/vad.bin", "punctuation/model.onnx"} { + p := filepath.Join(root, name) + if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(p, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + } + r := NewRegistry(root) + list, err := r.List() + if err != nil { + t.Fatal(err) + } + if len(list) != 1 || list[0] != "common" { + t.Fatalf("list=%v want [common]", list) + } +} + +func TestRegistry_List_excludesVADAtRoot(t *testing.T) { + root := t.TempDir() + if err := os.WriteFile(filepath.Join(root, "vad.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(root, "ggml-small.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + list, err := NewRegistry(root).List() + if err != nil { + t.Fatal(err) + } + if len(list) != 1 || list[0] != "ggml-small" { + t.Fatalf("list=%v", list) + } +} + +func TestRegistry_Path_rejectsReservedIDs(t *testing.T) { + root := t.TempDir() + r := NewRegistry(root) + for _, id := range []string{"vad", "punctuation"} { + if _, err := r.Path(id); err == nil { + t.Fatalf("expected error for %q", id) + } + } +} diff --git a/api/openai.go b/api/openai.go new file mode 100644 index 0000000..f1658f0 --- /dev/null +++ b/api/openai.go @@ -0,0 +1,121 @@ +package api + +import ( + "encoding/json" + "net/http" + "strings" + "time" +) + +func (s *Server) handleOpenAITranscriptions(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeOpenAIError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + if err := r.ParseMultipartForm(128 << 20); err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + modelID, err := s.models.Resolve(r.FormValue("model"), s.cfg.DefaultModel) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + modelPath, err := s.models.Path(modelID) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + audioPath, cleanup, err := s.saveUploadedOpenAI(r) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + defer cleanup() + + stt := s.parseOpenAISTTOptions(r) + result, err := s.transcribe(r.Context(), modelPath, audioPath, stt) + if err != nil { + writeOpenAIError(w, http.StatusInternalServerError, err.Error()) + return + } + + switch strings.ToLower(strings.TrimSpace(r.FormValue("response_format"))) { + case "text": + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(result.Text)) + default: + writeJSON(w, http.StatusOK, map[string]string{"text": result.Text}) + } +} + +func (s *Server) handleOpenAIModels(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeOpenAIError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + ids, err := s.models.List() + if err != nil { + writeOpenAIError(w, http.StatusInternalServerError, err.Error()) + return + } + now := time.Now().Unix() + data := make([]map[string]any, 0, len(ids)) + for _, id := range ids { + data = append(data, map[string]any{ + "id": id, + "object": "model", + "created": now, + "owned_by": "go-whisper-api", + }) + } + writeJSON(w, http.StatusOK, map[string]any{ + "object": "list", + "data": data, + }) +} + +func (s *Server) parseOpenAISTTOptions(r *http.Request) sttOptions { + lang := strings.TrimSpace(r.FormValue("language")) + if lang == "" { + lang = s.cfg.Language + } + return sttOptions{ + language: lang, + punctuate: s.punctCfg.ShouldApplyAPI(r, s.cfg.DefaultPunctuation), + } +} + +func (s *Server) saveUploadedOpenAI(r *http.Request) (path string, cleanup func(), err error) { + if r.MultipartForm == nil { + if err := r.ParseMultipartForm(128 << 20); err != nil { + return "", nil, err + } + } + return saveUploadedRawFields(r, []string{"file", "audio", "wav"}) +} + +func writeOpenAIError(w http.ResponseWriter, code int, msg string) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(code) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": msg, + "type": openAIErrorType(code), + }, + }) +} + +func openAIErrorType(code int) string { + switch code { + case http.StatusBadRequest: + return "invalid_request_error" + case http.StatusUnauthorized: + return "authentication_error" + case http.StatusNotFound: + return "not_found_error" + default: + return "server_error" + } +} diff --git a/api/openai_test.go b/api/openai_test.go new file mode 100644 index 0000000..cd81175 --- /dev/null +++ b/api/openai_test.go @@ -0,0 +1,90 @@ +package api + +import ( + "bytes" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "go-whisper-api/config" +) + +func TestRegistryResolve(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "ggml-large-v3-turbo.bin"), []byte("y"), 0o644); err != nil { + t.Fatal(err) + } + reg := NewRegistry(dir) + + id, err := reg.Resolve("whisper-1", "ggml-large-v3-turbo") + if err != nil || id != "ggml-large-v3-turbo" { + t.Fatalf("whisper-1: id=%q err=%v", id, err) + } + id, err = reg.Resolve("whisper-large-v3-turbo", "") + if err != nil || id != "large-v3-turbo" { + t.Fatalf("whisper-large-v3-turbo: id=%q err=%v", id, err) + } + id, err = reg.Resolve("ggml-small", "") + if err != nil || id != "ggml-small" { + t.Fatalf("ggml-small: id=%q err=%v", id, err) + } + id, err = reg.Resolve("", "ggml-small") + if err != nil || id != "ggml-small" { + t.Fatalf("empty with default: id=%q err=%v", id, err) + } +} + +func TestHandleOpenAITranscriptionsMissingFile(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + srv := &Server{ + cfg: config.API{ModelsDir: dir, Language: "ru"}, + models: NewRegistry(dir), + } + body := &bytes.Buffer{} + w := multipart.NewWriter(body) + _ = w.WriteField("model", "ggml-small") + w.Close() + + req := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", body) + req.Header.Set("Content-Type", w.FormDataContentType()) + rec := httptest.NewRecorder() + srv.handleOpenAITranscriptions(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "file") { + t.Fatalf("expected file error, got %s", rec.Body.String()) + } +} + +func TestHandleOpenAIModels(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + srv := &Server{ + cfg: config.API{ModelsDir: dir, Language: "ru"}, + models: NewRegistry(dir), + } + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + rec := httptest.NewRecorder() + srv.handleOpenAIModels(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "ggml-small") { + t.Fatalf("expected model list, got %s", rec.Body.String()) + } +} diff --git a/api/queue_worker.go b/api/queue_worker.go new file mode 100644 index 0000000..a93737f --- /dev/null +++ b/api/queue_worker.go @@ -0,0 +1,98 @@ +package api + +import ( + "context" + "time" + + "go-whisper-api/whisper" + + "github.com/rs/zerolog/log" +) + +func (s *Server) StartWorker(ctx context.Context) { + go s.queueWorker(ctx) +} + +func (s *Server) queueWorker(ctx context.Context) { + const idlePoll = 2 * time.Second + for { + if ctx.Err() != nil { + return + } + id, ok, err := s.cache.NextWaiting() + if err != nil { + log.Error().Err(err).Msg("cache queue scan") + if !sleepOrWake(ctx, s.queueWake, time.Second) { + return + } + continue + } + if ok { + s.processCacheTask(ctx, id) + continue + } + if !sleepOrWake(ctx, s.queueWake, idlePoll) { + return + } + } +} + +func sleepOrWake(ctx context.Context, wake <-chan struct{}, d time.Duration) bool { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-wake: + return true + case <-timer.C: + return true + } +} + +func (s *Server) processCacheTask(ctx context.Context, id string) { + params, _, err := s.cache.LoadParams(id) + if err != nil { + return + } + modelPath, err := s.models.Path(params.Model) + if err != nil { + s.completeCacheTask(id, whisper.TranscriptResult{}, err.Error()) + return + } + if err := s.cache.SetStatus(id, statusProcessing, nil); err != nil { + log.Error().Err(err).Str("task", id).Msg("set processing") + return + } + audioPath, ok := s.cache.AudioPath(id) + if !ok { + s.completeCacheTask(id, whisper.TranscriptResult{}, "audio file missing") + return + } + stt := sttOptions{ + language: params.Language, + punctuate: params.Punctuation, + speakers: params.Speakers, + numClusters: params.NumClusters, + } + if stt.language == "" { + stt.language = s.cfg.Language + } + result, err := s.transcribe(ctx, modelPath, audioPath, stt) + if err != nil { + s.completeCacheTask(id, result, err.Error()) + log.Error().Err(err).Str("task", id).Msg("async transcribe") + return + } + s.completeCacheTask(id, result, "") +} + +func (s *Server) completeCacheTask(id string, result whisper.TranscriptResult, errMsg string) { + if err := s.cache.FinishWaiting(id, result, errMsg, nil); err != nil { + log.Error().Err(err).Str("task", id).Str("cache", s.cache.Root()).Msg("finish task") + return + } + if err := s.cache.PromoteToReady(id); err != nil { + log.Error().Err(err).Str("task", id).Str("cache", s.cache.Root()).Msg("promote to ready") + } +} diff --git a/api/result.go b/api/result.go new file mode 100644 index 0000000..180f9ba --- /dev/null +++ b/api/result.go @@ -0,0 +1,28 @@ +package api + +import "go-whisper-api/whisper" + +// sprResultReady builds GET /spr/result/{taskID} body for completed tasks (SPR-compatible). +func sprResultReady(params TaskParams) map[string]any { + words := params.Words + if words == nil { + words = []whisper.Word{} + } + return map[string]any{ + "model": params.Model, + "text": params.Text, + "words": words, + "toxicity": map[string]float64{ + "insult": 0, + "obscenity": 0, + "threat": 0, + "politeness": 0, + }, + "emotion": map[string]any{}, + "voice_analysis": map[string]any{}, + "status": "ready", + "taskID": params.ID, + "created": params.Created, + "processed": params.Processed, + } +} diff --git a/api/result_test.go b/api/result_test.go new file mode 100644 index 0000000..7b6adc9 --- /dev/null +++ b/api/result_test.go @@ -0,0 +1,26 @@ +package api + +import "testing" + +func TestSprResultReady(t *testing.T) { + body := sprResultReady(TaskParams{ + ID: "701b3e84-5815-4baf-97a2-933ff820f16d", + Created: "2026-06-03 10:06:35", + Processed: "2026-06-03 10:07:48", + Status: "ready", + Model: "common", + Text: "hello", + }) + for _, key := range []string{"model", "text", "words", "toxicity", "emotion", "voice_analysis", "status", "taskID", "created", "processed"} { + if body[key] == nil { + t.Fatalf("missing %q", key) + } + } + if body["status"] != "ready" { + t.Fatalf("status=%v", body["status"]) + } + tox, ok := body["toxicity"].(map[string]float64) + if !ok || len(tox) != 4 { + t.Fatalf("toxicity=%v", body["toxicity"]) + } +} diff --git a/api/server.go b/api/server.go new file mode 100644 index 0000000..cbcf66d --- /dev/null +++ b/api/server.go @@ -0,0 +1,640 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + "go-whisper-api/config" + "go-whisper-api/diarization" + "go-whisper-api/punctuation" + "go-whisper-api/transcode" + "go-whisper-api/whisper" + + "github.com/google/uuid" + "github.com/rs/zerolog/log" +) + +type Server struct { + cfg config.API + punctCfg config.Punctuation + diarCfg config.Diarization + transcode *transcode.Engine + modelPool *whisper.ModelPool + punct punctuation.Restorer + diarizer diarization.Engine + models *Registry + cache *DiskCache + mux *http.ServeMux + queueWake chan struct{} +} + +func NewServer(cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) (*Server, error) { + cfg = cfg.WithDefaults() + tc = tc.WithDefaults() + pc = pc.WithDefaults() + dc = dc.WithDefaults() + restorer, err := punctuation.New(pc) + if err != nil { + return nil, err + } + if pc.Active() && !restorer.Active() { + return nil, fmt.Errorf("punctuation is enabled but engine %q is not available", pc.Engine) + } + if cfg.ModelsDir == "" { + cfg.ModelsDir = "./models" + } + if cfg.Addr == "" { + cfg.Addr = ":8080" + } + if cfg.Threads == 0 { + cfg.Threads = uint(runtime.NumCPU()) + } + cfg = cfg.WithDefaults() + if cfg.MaxContext == 0 { + cfg.MaxContext = 32 + } + if cfg.BeamSize == 0 { + cfg.BeamSize = 5 + } + if cfg.EntropyThold == 0 { + cfg.EntropyThold = 2.4 + } + if err := os.MkdirAll(cfg.ModelsDir, 0o755); err != nil { + return nil, err + } + cacheDir := cfg.CacheDir + if cacheDir == "" { + cacheDir = "./cache" + } + cache, err := NewDiskCache(cacheDir) + if err != nil { + return nil, err + } + cfg.CacheDir = cache.Root() + if err := cache.RecoverInterrupted(); err != nil { + return nil, err + } + diar, err := diarization.New(dc) + if err != nil { + return nil, err + } + s := &Server{ + cfg: cfg, + punctCfg: pc, + diarCfg: dc, + transcode: transcode.NewEngine(tc.FFmpegPath), + modelPool: whisper.NewModelPool(), + punct: restorer, + diarizer: diar, + models: NewRegistry(cfg.ModelsDir), + cache: cache, + mux: http.NewServeMux(), + queueWake: make(chan struct{}, 1), + } + s.routes() + go s.warmModels() + return s, nil +} + +func (s *Server) warmModels() { + ids, err := s.models.List() + if err != nil { + return + } + for _, id := range ids { + path, err := s.models.Path(id) + if err != nil { + continue + } + if err := s.modelPool.WithModel(path, func(whisper.Model) error { return nil }); err != nil { + log.Warn().Err(err).Str("model", id).Msg("preload whisper model") + } else { + log.Info().Str("model", id).Msg("whisper model loaded") + } + } +} + +func (s *Server) routes() { + s.mux.HandleFunc("/", s.handleSwaggerUI) + s.mux.HandleFunc("/swagger.json", s.handleSwaggerJSON) + s.mux.HandleFunc("/spr/models", s.handleModels) + s.mux.HandleFunc("/spr/hostname", s.handleHostname) + s.mux.HandleFunc("/spr/queue", s.handleQueue) + s.mux.HandleFunc("/spr/stt/", s.handleSTT) + s.mux.HandleFunc("/spr/result/", s.handleResult) + s.mux.HandleFunc("/spr/queue/", s.handleQueueItem) + s.mux.HandleFunc("/spr/audio/", s.handleAudio) + s.mux.HandleFunc("/spr/waveform/", s.handleWaveform) + s.mux.HandleFunc("/spr/delete/", s.handleDeleteModel) + s.mux.HandleFunc("/spr/export/", s.handleExportModel) + s.mux.HandleFunc("/spr/import/", s.handleImportModel) + s.mux.HandleFunc("/v1/audio/transcriptions", s.handleOpenAITranscriptions) + s.mux.HandleFunc("/v1/audio/transcriptions/", s.handleOpenAITranscriptions) + s.mux.HandleFunc("/v1/models", s.handleOpenAIModels) +} + +func (s *Server) ListenAndServe() error { + log.Info(). + Str("addr", s.cfg.Addr). + Str("models", s.cfg.ModelsDir). + Str("cache", s.cache.Root()). + Msg("starting API server") + return http.ListenAndServe(s.cfg.Addr, s.mux) +} + +func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + models, err := s.models.List() + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]any{"models": models}) +} + +func (s *Server) handleHostname(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + host, _ := os.Hostname() + cwd, _ := os.Getwd() + writeJSON(w, http.StatusOK, map[string]any{ + "error": 0, + "message": "Success", + "hostname": host, + "version": "go-whisper-api", + "cwd": cwd, + "models": s.cfg.ModelsDir, + "cache": s.cache.Root(), + }) +} + +func (s *Server) handleQueue(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + list, err := s.cache.List() + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + writeJSON(w, http.StatusOK, list) +} + +func (s *Server) handleQueueItem(w http.ResponseWriter, r *http.Request) { + id := strings.TrimPrefix(r.URL.Path, "/spr/queue/") + if id == "" || !isValidTaskID(id) { + writeError(w, http.StatusBadRequest, "task id required") + return + } + switch r.Method { + case http.MethodGet: + s.handleQueueGet(w, id) + case http.MethodDelete: + if !s.cache.Delete(id) { + writeAPIError(w, http.StatusNotFound, "TaskNotFound") + return + } + writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"}) + default: + methodNotAllowed(w) + } +} + +func (s *Server) handleQueueGet(w http.ResponseWriter, id string) { + params, phase, err := s.cache.LoadParams(id) + if err != nil { + writeAPIError(w, http.StatusNotFound, "task not found") + return + } + switch params.Status { + case string(statusReady): + if phase == cacheWaiting { + if err := s.cache.PromoteToReady(id); err != nil { + writeAPIError(w, http.StatusInternalServerError, err.Error()) + return + } + } + writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"}) + case string(statusError): + msg := params.Error + if msg == "" { + msg = "transcription failed" + } + writeAPIError(w, http.StatusNotFound, msg) + default: + writeJSON(w, http.StatusOK, map[string]any{ + "error": 0, + "message": params.Status, + "status": params.Status, + }) + } +} + +func (s *Server) handleSTT(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + methodNotAllowed(w) + return + } + modelID := strings.TrimPrefix(r.URL.Path, "/spr/stt/") + if modelID == "" { + writeError(w, http.StatusBadRequest, "model id required") + return + } + modelPath, err := s.models.Path(modelID) + if err != nil { + writeAPIError(w, http.StatusNotFound, err.Error()) + return + } + audioPath, cleanup, err := s.saveUploadedWav(r) + if err != nil { + writeAPIError(w, http.StatusBadRequest, err.Error()) + return + } + stt, err := s.parseSTTOptions(r) + if err != nil { + writeAPIError(w, http.StatusBadRequest, err.Error()) + return + } + if queryAsync(r, s.cfg.DefaultAsync) { + taskID, err := s.enqueueAsync(r, modelID, audioPath, stt) + cleanup() + if err != nil { + writeAPIError(w, http.StatusBadRequest, err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]string{"taskID": taskID}) + return + } + defer cleanup() + result, err := s.transcribe(r.Context(), modelPath, audioPath, stt) + if err != nil { + writeAPIError(w, http.StatusMethodNotAllowed, err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]any{ + "model": modelID, + "text": result.Text, + "words": result.Words, + }) +} + +func (s *Server) handleResult(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/result/") + if !isValidTaskID(id) { + writeAPIError(w, http.StatusBadRequest, "task id required") + return + } + params, _, err := s.cache.LoadParams(id) + if err != nil { + writeAPIError(w, http.StatusNotFound, "TaskNotFound") + return + } + switch params.Status { + case string(statusWaiting), string(statusProcessing): + writeJSON(w, http.StatusOK, map[string]string{"status": params.Status}) + case string(statusError): + msg := params.Error + if msg == "" { + msg = "TaskNotFound" + } + writeAPIError(w, http.StatusNotFound, msg) + case string(statusReady): + writeJSON(w, http.StatusOK, sprResultReady(params)) + default: + writeJSON(w, http.StatusOK, map[string]string{"status": params.Status}) + } +} + +func (s *Server) handleAudio(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/audio/") + path, ok := s.cache.AudioPath(id) + if !ok { + writeError(w, http.StatusNotFound, "task not found") + return + } + http.ServeFile(w, r, path) +} + +func (s *Server) handleWaveform(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/waveform/") + if !isValidTaskID(id) { + writeAPIError(w, http.StatusBadRequest, "task id required") + return + } + wf, err := s.cache.Waveform(id) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]any{"error": 0, "waveform": wf}) +} + +func (s *Server) handleDeleteModel(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/delete/") + if err := s.models.Delete(id); err != nil { + writeAPIError(w, http.StatusNotFound, err.Error()) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *Server) handleExportModel(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/export/") + f, err := s.models.Open(id) + if err != nil { + writeAPIError(w, http.StatusNotFound, err.Error()) + return + } + defer f.Close() + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q.bin", id)) + w.Header().Set("Content-Type", "application/octet-stream") + io.Copy(w, f) +} + +func (s *Server) handleImportModel(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + methodNotAllowed(w) + return + } + id := strings.TrimPrefix(r.URL.Path, "/spr/import/") + if err := r.ParseMultipartForm(32 << 20); err != nil { + writeAPIError(w, http.StatusBadRequest, err.Error()) + return + } + file, header, err := r.FormFile("zip-model") + if err != nil { + file, header, err = r.FormFile("model") + } + if err != nil { + writeAPIError(w, http.StatusBadRequest, "model file required") + return + } + defer file.Close() + src := io.Reader(file) + if strings.HasSuffix(strings.ToLower(header.Filename), ".zip") { + writeAPIError(w, http.StatusBadRequest, "zip import is not supported; upload .bin model file as zip-model field") + return + } + if err := s.models.Import(id, src); err != nil { + writeAPIError(w, http.StatusBadRequest, err.Error()) + return + } + w.WriteHeader(http.StatusOK) +} + +func (s *Server) enqueueAsync(r *http.Request, modelID, audioWavPath string, stt sttOptions) (string, error) { + id := uuid.New().String() + params := TaskParams{ + ID: id, + Created: time.Now().Format("2006-01-02 15:04:05"), + Status: string(statusWaiting), + Model: modelID, + Language: stt.language, + Punctuation: stt.punctuate, + Speakers: stt.speakers, + NumClusters: stt.numClusters, + } + if err := s.cache.Enqueue(id, params, audioWavPath); err != nil { + return "", err + } + s.notifyQueue() + log.Info(). + Str("task", id). + Str("model", modelID). + Str("cache", s.cache.Root()). + Msg("enqueued async task") + return id, nil +} + +func (s *Server) transcribe(ctx context.Context, modelPath, audioPath string, stt sttOptions) (whisper.TranscriptResult, error) { + turns, err := s.runDiarization(ctx, audioPath, stt) + if err != nil { + return whisper.TranscriptResult{}, err + } + vad := s.cfg.VAD + if vad.Enabled { + vad.Model = vad.ResolveModelPath(s.cfg.ModelsDir) + } + cfg := &config.Whisper{ + Model: modelPath, + AudioPath: audioPath, + Threads: s.cfg.Threads, + Language: stt.language, + Debug: s.cfg.Debug, + SpeedUp: s.cfg.SpeedUp, + Translate: s.cfg.Translate, + Prompt: s.cfg.Prompt, + MaxContext: s.cfg.MaxContext, + BeamSize: s.cfg.BeamSize, + EntropyThold: s.cfg.EntropyThold, + VAD: vad, + PrintProgress: false, + PrintSegment: false, + } + runOpts := s.whisperRunOpts(stt, turns) + if stt.punctuate && s.punct.Active() { + runOpts.PunctuateRestore = func(text string) (string, error) { + return punctuation.Apply(ctx, s.punct, true, text, stt.language) + } + } + result, err := whisper.TranscribeWithPool(s.modelPool, cfg, runOpts) + if err != nil { + return whisper.TranscriptResult{}, err + } + return applyGarbage(result, s.cfg.GarbagePatterns()), nil +} + +func (s *Server) runDiarization(ctx context.Context, audioPath string, stt sttOptions) ([]whisper.Turn, error) { + if !stt.speakers { + return nil, nil + } + samples, err := whisper.LoadPCM16Mono(audioPath) + if err != nil { + return nil, fmt.Errorf("diarization audio: %w", err) + } + return s.diarizer.Process(ctx, samples, stt.numClusters) +} + +func saveUploadedRaw(r *http.Request) (path string, cleanup func(), err error) { + return saveUploadedRawFields(r, []string{"audio", "wav", "file"}) +} + +func saveUploadedRawFields(r *http.Request, fieldNames []string) (path string, cleanup func(), err error) { + if r.MultipartForm == nil { + if err := r.ParseMultipartForm(128 << 20); err != nil { + return "", nil, err + } + } + var ( + file multipart.File + header *multipart.FileHeader + found bool + ) + for _, name := range fieldNames { + file, header, err = r.FormFile(name) + if err == nil { + found = true + break + } + } + if !found { + return "", nil, fmt.Errorf("audio file required (form field: %s)", strings.Join(fieldNames, ", ")) + } + defer file.Close() + dir, err := config.MkdirTemp("go-whisper-api-upload-*") + if err != nil { + return "", nil, err + } + cleanup = func() { os.RemoveAll(dir) } + base := "input" + if header != nil { + if ext := filepath.Ext(header.Filename); ext != "" { + base += ext + } + } + raw := filepath.Join(dir, base) + out, err := os.Create(raw) + if err != nil { + cleanup() + return "", nil, err + } + if _, err := io.Copy(out, file); err != nil { + out.Close() + cleanup() + return "", nil, err + } + out.Close() + return raw, cleanup, nil +} + +func (s *Server) saveUploadedWav(r *http.Request) (path string, cleanup func(), err error) { + raw, cleanup, err := saveUploadedRaw(r) + if err != nil { + return "", nil, err + } + dst := filepath.Join(filepath.Dir(raw), "audio.wav") + if err := s.transcode.Transcode(r.Context(), raw, dst, transcode.WhisperOptions()); err != nil { + cleanup() + return "", nil, err + } + return dst, cleanup, nil +} + +func queryBoolDefault(r *http.Request, key string, def bool) bool { + v := r.URL.Query().Get(key) + if v == "" { + return def + } + b, err := strconv.ParseBool(v) + if err != nil { + return def + } + return b +} + +func queryAsync(r *http.Request, defaultAsync bool) bool { + v := r.URL.Query().Get("async") + if v == "" { + return defaultAsync + } + n, err := strconv.Atoi(v) + if err != nil { + return defaultAsync + } + return n == 1 +} + +func queryInt(r *http.Request, key string, def int) int { + v := r.URL.Query().Get(key) + if v == "" { + return def + } + n, err := strconv.Atoi(v) + if err != nil { + return def + } + return n +} + +func writeJSON(w http.ResponseWriter, code int, v any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(code) + _ = json.NewEncoder(w).Encode(v) +} + +func writeError(w http.ResponseWriter, code int, msg string) { + http.Error(w, msg, code) +} + +func writeAPIError(w http.ResponseWriter, code int, msg string) { + writeJSON(w, code, map[string]any{"error": 1, "message": msg}) +} + +func methodNotAllowed(w http.ResponseWriter) { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) +} + +func (s *Server) notifyQueue() { + select { + case s.queueWake <- struct{}{}: + default: + } +} + +func Run(ctx context.Context, cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) error { + srv, err := NewServer(cfg, tc, pc, dc) + if err != nil { + return err + } + defer srv.modelPool.Close() + srv.StartWorker(ctx) + hs := &http.Server{Addr: cfg.Addr, Handler: srv.mux} + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _ = hs.Shutdown(shutdownCtx) + }() + if err := hs.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return err + } + punctuation.Close(srv.punct) + srv.diarizer.Close() + return nil +} diff --git a/api/swagger-ui.html b/api/swagger-ui.html new file mode 100644 index 0000000..1bb5e40 --- /dev/null +++ b/api/swagger-ui.html @@ -0,0 +1,23 @@ + + + + + go-whisper-api + + + +
+ + + + diff --git a/api/swagger.go b/api/swagger.go new file mode 100644 index 0000000..fafeae1 --- /dev/null +++ b/api/swagger.go @@ -0,0 +1,34 @@ +package api + +import ( + _ "embed" + "net/http" +) + +//go:embed swagger.json +var swaggerSpec []byte + +//go:embed swagger-ui.html +var swaggerUI []byte + +func (s *Server) handleSwaggerJSON(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Write(swaggerSpec) +} + +func (s *Server) handleSwaggerUI(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + if r.Method != http.MethodGet { + methodNotAllowed(w) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Write(swaggerUI) +} diff --git a/api/swagger.json b/api/swagger.json new file mode 100644 index 0000000..188ddf7 --- /dev/null +++ b/api/swagger.json @@ -0,0 +1,457 @@ +{ + "swagger": "2.0", + "basePath": "/", + "paths": { + "/spr/audio/{taskID}": { + "parameters": [ + { + "name": "taskID", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "get_audio_stt", + "tags": [ + "spr" + ] + } + }, + "/spr/delete/{id}": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string" + } + ], + "delete": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "delete_model_delete", + "tags": [ + "spr" + ] + } + }, + "/spr/export/{id}": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "get_model_export", + "tags": [ + "spr" + ] + } + }, + "/spr/hostname": { + "get": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "get_hostname_class", + "tags": [ + "spr" + ] + } + }, + "/spr/import/{id}": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string" + } + ], + "post": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "post_model_import", + "parameters": [ + { + "name": "zip-model", + "in": "formData", + "type": "file", + "required": true, + "description": "prepared model zip file" + } + ], + "consumes": [ + "multipart/form-data" + ], + "tags": [ + "spr" + ] + } + }, + "/spr/models": { + "get": { + "responses": { + "200": { + "description": "Success", + "schema": { + "$ref": "#/definitions/modelList" + } + } + }, + "operationId": "get_model_list", + "tags": [ + "spr" + ] + } + }, + "/spr/queue": { + "get": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "get_queue_stt", + "tags": [ + "spr" + ] + } + }, + "/spr/queue/{taskID}": { + "parameters": [ + { + "name": "taskID", + "in": "path", + "required": true, + "type": "string" + } + ], + "delete": { + "responses": { + "200": { + "description": "Success", + "schema": { + "type": "object", + "properties": { + "error": { "type": "integer" }, + "message": { "type": "string" } + } + } + }, + "404": { + "description": "TaskNotFound" + } + }, + "operationId": "delete_queue_del_stt", + "tags": [ + "spr" + ] + } + }, + "/spr/result/{taskID}": { + "parameters": [ + { + "name": "taskID", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "responses": { + "404": { + "description": "Not found", + "schema": { + "$ref": "#/definitions/error" + } + }, + "200": { + "description": "Success", + "schema": { + "$ref": "#/definitions/resultTTS" + } + } + }, + "operationId": "get_result_stt", + "tags": [ + "spr" + ] + } + }, + "/spr/stt/{id}": { + "post": { + "responses": { + "405": { + "description": "Error", + "schema": { + "$ref": "#/definitions/error" + } + }, + "404": { + "description": "Not found", + "schema": { + "$ref": "#/definitions/error" + } + }, + "200": { + "description": "Success", + "schema": { + "$ref": "#/definitions/modelTTS" + } + } + }, + "operationId": "post_model_test", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "type": "string", + "description": "NN Model ID" + }, + { + "name": "wav", + "in": "formData", + "type": "file", + "description": "file to recognize" + }, + { + "name": "async", + "in": "query", + "type": "integer", + "description": "async mode (default 1: enqueue to cache/waiting; 0: sync text response)", + "default": 1, + "enum": [ + 0, + 1 + ] + }, + { + "name": "speakers", + "in": "query", + "type": "integer", + "description": "find speakers", + "default": 0, + "enum": [ + 0, + 1 + ] + }, + { + "name": "speaker_counter", + "in": "query", + "type": "integer", + "description": "number of speakers. 0 for autodetect. -1 disable speaker detection.", + "default": 0 + }, + { + "name": "normalization", + "in": "query", + "type": "integer", + "description": "normalize text", + "default": 1, + "enum": [ + 0, + 1 + ] + }, + { + "name": "punctuation", + "in": "query", + "type": "integer", + "description": "punctuate text", + "default": 1, + "enum": [ + 0, + 1 + ] + }, + { + "name": "toxicity", + "in": "query", + "type": "integer", + "description": "toxicity analyzer", + "default": 1, + "enum": [ + 0, + 1 + ] + }, + { + "name": "emotion", + "in": "query", + "type": "integer", + "description": "emotion analyzer", + "default": 0, + "enum": [ + 0, + 1 + ] + }, + { + "name": "voice_analyzer", + "in": "query", + "type": "integer", + "description": "voice analyzer", + "default": 1, + "enum": [ + 0, + 1 + ] + }, + { + "name": "vad", + "in": "query", + "type": "string", + "description": "VAD type", + "default": "webrtc", + "enum": [ + "webrtc" + ] + }, + { + "name": "classifiers", + "in": "query", + "type": "string", + "description": "JSON with classification models to process each sentence" + }, + { + "name": "webhook", + "in": "query", + "type": "string", + "description": "webhook url to send stt async result" + } + ], + "consumes": [ + "multipart/form-data" + ], + "tags": [ + "spr" + ] + } + }, + "/spr/waveform/{taskID}": { + "parameters": [ + { + "name": "taskID", + "in": "path", + "required": true, + "type": "string" + } + ], + "get": { + "responses": { + "200": { + "description": "Success" + } + }, + "operationId": "get_audioarray_stt", + "tags": [ + "spr" + ] + } + } + }, + "info": { + "title": "go-whisper-api", + "version": "5.008 release", + "description": "Whisper speech-to-text API (SPR-compatible)" + }, + "produces": [ + "application/json" + ], + "consumes": [ + "application/json" + ], + "tags": [ + { + "name": "spr", + "description": "NN Model operations" + } + ], + "definitions": { + "modelList": { + "properties": { + "models": { + "type": "array", + "items": { + "type": "string", + "description": "NN Model ID" + } + } + }, + "type": "object" + }, + "error": { + "properties": { + "error": { + "type": "integer", + "description": "Error flag" + }, + "message": { + "type": "string", + "description": "Error description" + } + }, + "type": "object" + }, + "modelTTS": { + "required": [ + "text" + ], + "properties": { + "text": { + "type": "string", + "description": "Recognized text" + } + }, + "type": "object" + }, + "resultTTS": { + "properties": { + "model": { "type": "string" }, + "text": { "type": "string" }, + "words": { "type": "array" }, + "toxicity": { "type": "object" }, + "emotion": { "type": "object" }, + "voice_analysis": { "type": "object" }, + "status": { "type": "string" }, + "taskID": { "type": "string" }, + "created": { "type": "string" }, + "processed": { "type": "string" } + }, + "type": "object" + } + }, + "responses": { + "ParseError": { + "description": "When a mask can't be parsed" + }, + "MaskError": { + "description": "When any error occurs on mask" + } + } +} \ No newline at end of file diff --git a/api/tasks.go b/api/tasks.go new file mode 100644 index 0000000..7f2a8f8 --- /dev/null +++ b/api/tasks.go @@ -0,0 +1,10 @@ +package api + +type taskStatus string + +const ( + statusWaiting taskStatus = "waiting" + statusProcessing taskStatus = "processing" + statusReady taskStatus = "ready" + statusError taskStatus = "error" +) diff --git a/api/transcribe_opts.go b/api/transcribe_opts.go new file mode 100644 index 0000000..cc923c9 --- /dev/null +++ b/api/transcribe_opts.go @@ -0,0 +1,77 @@ +package api + +import ( + "fmt" + "net/http" + "strings" + + "go-whisper-api/config" + "go-whisper-api/whisper" +) + +type sttOptions struct { + language string + punctuate bool + speakers bool + numClusters int +} + +func (s *Server) parseSTTOptions(r *http.Request) (sttOptions, error) { + opts := sttOptions{ + language: resolveLanguage(r, s.cfg.Language), + punctuate: s.punctCfg.ShouldApplyAPI(r, s.cfg.DefaultPunctuation), + } + sp, clusters, err := querySpeakers(r, s.cfg.DefaultSpeakers, s.diarCfg, s.diarizer.Active()) + if err != nil { + return opts, err + } + opts.speakers = sp + opts.numClusters = clusters + return opts, nil +} + +func resolveLanguage(r *http.Request, defaultLang string) string { + if v := strings.TrimSpace(r.URL.Query().Get("language")); v != "" { + return v + } + return strings.TrimSpace(defaultLang) +} + +func querySpeakers(r *http.Request, defaultOn bool, dc config.Diarization, diarizerActive bool) (enabled bool, numClusters int, err error) { + counter := queryInt(r, "speaker_counter", -999) + if counter == -1 { + return false, 0, nil + } + speakersQ := r.URL.Query().Get("speakers") + enabled = defaultOn + if speakersQ != "" { + enabled = queryInt(r, "speakers", 0) == 1 + } + if !enabled { + return false, 0, nil + } + if !dc.Active() { + return false, 0, fmt.Errorf("speaker diarization is disabled in config (diarization.enabled: true)") + } + if !diarizerActive { + return false, 0, fmt.Errorf("speaker diarization requires server built with -tags sherpa (make build-sherpa) and models (make download-diarization-models)") + } + if counter > 0 { + numClusters = counter + } else if dc.NumClusters > 0 { + numClusters = dc.NumClusters + } + return true, numClusters, nil +} + +func (s *Server) whisperRunOpts(stt sttOptions, turns []whisper.Turn) whisper.RunOptions { + t := s.cfg.Transcript.WithDefaults() + return whisper.RunOptions{ + Turns: turns, + Format: whisper.FormatOptions{ + PauseGap: t.PauseGapDuration(), + SpeakerLabel: t.SpeakerLabel, + UseSpeakers: stt.speakers && len(turns) > 0, + }, + } +} diff --git a/api/waveform.go b/api/waveform.go new file mode 100644 index 0000000..1b6f07c --- /dev/null +++ b/api/waveform.go @@ -0,0 +1,48 @@ +package api + +import ( + "math" + "os" + + "github.com/go-audio/wav" +) + +func waveformFromWav(path string, buckets int) ([]float64, error) { + if buckets <= 0 { + buckets = 256 + } + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + dec := wav.NewDecoder(f) + buf, err := dec.FullPCMBuffer() + if err != nil { + return nil, err + } + samples := buf.AsFloat32Buffer().Data + if len(samples) == 0 { + return make([]float64, buckets), nil + } + chunk := len(samples) / buckets + if chunk < 1 { + chunk = 1 + } + out := make([]float64, 0, buckets) + for i := 0; i < len(samples) && len(out) < buckets; i += chunk { + end := i + chunk + if end > len(samples) { + end = len(samples) + } + peak := 0.0 + for _, s := range samples[i:end] { + v := math.Abs(float64(s)) + if v > peak { + peak = v + } + } + out = append(out, math.Round(peak*1000)/1000) + } + return out, nil +} diff --git a/config.yaml.example b/config.yaml.example new file mode 100644 index 0000000..2a3fd24 --- /dev/null +++ b/config.yaml.example @@ -0,0 +1,55 @@ +api: + addr: "0.0.0.0:6183" + models_dir: "./models" + cache_dir: "./cache" + # Open WebUI / OpenAI STT: model id when client sends whisper-1 (e.g. ggml-large-v3-turbo) + default_model: ggml-large-v3-turbo + threads: 16 + language: ru + transcript: + pause_gap_sec: 1.5 + speaker_label: "Спикер" + default_speakers: false + debug: false + speedup: false + translate: false + prompt: "" + max_context: 32 + beam_size: 5 + entropy_thold: 2.4 + vad: + enabled: false + model: vad/vad.bin + threshold: 0.5 + min_speech_duration_ms: 250 + min_silence_duration_ms: 100 + speech_pad_ms: 30 + samples_overlap: 0.1 + default_punctuation: true + default_async: true + garbage: + - "*выбая*" + +# transcode: pure Go (wav, mp3, flac, ogg, m4a, mp4, aac) — no ffmpeg required +transcode: {} + +diarization: + enabled: false + model_dir: ./models/diarization + segmentation_model: pyannote-segmentation-3-0/model.onnx + embedding_model: 3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + num_threads: 2 + num_clusters: 0 + clustering_threshold: 0.5 + +punctuation: + enabled: true + default_on: true + # off | heuristic | xlm | sherpa | sherpa-online | http + engine: xlm + model_dir: ./models/punctuation/xlm-roberta + model_file: model.onnx + sp_model: sp.model + config_file: config.yaml + apply_sbd: true + num_threads: 2 diff --git a/config/api.go b/config/api.go new file mode 100644 index 0000000..8bb25b0 --- /dev/null +++ b/config/api.go @@ -0,0 +1,42 @@ +package config + +import "strings" + +type API struct { + Addr string `yaml:"addr"` + ModelsDir string `yaml:"models_dir"` + CacheDir string `yaml:"cache_dir"` + // DefaultModel: whisper model id for OpenAI /v1/audio/transcriptions (maps whisper-1). + DefaultModel string `yaml:"default_model"` + Language string `yaml:"language"` + Transcript Transcript `yaml:"transcript"` + DefaultSpeakers bool `yaml:"default_speakers"` + Prompt string `yaml:"prompt"` + Threads uint `yaml:"threads"` + MaxContext uint `yaml:"max_context"` + BeamSize uint `yaml:"beam_size"` + EntropyThold float64 `yaml:"entropy_thold"` + VAD VAD `yaml:"vad"` + Debug bool `yaml:"debug"` + SpeedUp bool `yaml:"speedup"` + Translate bool `yaml:"translate"` + DefaultPunctuation bool `yaml:"default_punctuation"` + // DefaultAsync: STT via API enqueues to cache/waiting and returns taskID (use ?async=0 for sync). + DefaultAsync bool `yaml:"default_async"` + // Garbage: artifact substrings removed from transcript text and words (default includes *выбая*). + Garbage []string `yaml:"garbage"` +} + +func (a API) WithDefaults() API { + a.VAD = a.VAD.WithDefaults() + a.Transcript = a.Transcript.WithDefaults() + if strings.TrimSpace(a.Language) == "" { + a.Language = "ru" + } + return a +} + +// GarbagePatterns returns garbage filter list (never empty unless explicitly set to [] in YAML). +func (a API) GarbagePatterns() []string { + return a.garbagePatterns() +} diff --git a/config/diarization.go b/config/diarization.go new file mode 100644 index 0000000..ae4858c --- /dev/null +++ b/config/diarization.go @@ -0,0 +1,72 @@ +package config + +import ( + "os" + "path/filepath" +) + +type Diarization struct { + Enabled bool `yaml:"enabled"` + ModelDir string `yaml:"model_dir"` + SegmentationModel string `yaml:"segmentation_model"` + EmbeddingModel string `yaml:"embedding_model"` + NumThreads int `yaml:"num_threads"` + NumClusters int `yaml:"num_clusters"` + ClusteringThreshold float32 `yaml:"clustering_threshold"` + MinDurationOn float32 `yaml:"min_duration_on"` + MinDurationOff float32 `yaml:"min_duration_off"` +} + +func (d Diarization) WithDefaults() Diarization { + if d.ModelDir == "" { + d.ModelDir = "./models/diarization" + } + if d.SegmentationModel == "" { + d.SegmentationModel = "pyannote-segmentation-3-0/model.onnx" + } + if d.EmbeddingModel == "" { + d.EmbeddingModel = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" + } + if d.NumThreads <= 0 { + d.NumThreads = 2 + } + if d.ClusteringThreshold <= 0 { + d.ClusteringThreshold = 0.5 + } + if d.MinDurationOn <= 0 { + d.MinDurationOn = 0.3 + } + if d.MinDurationOff <= 0 { + d.MinDurationOff = 0.5 + } + return d +} + +func (d Diarization) SegmentationPath() string { + return resolveModelPath(d.ModelDir, d.SegmentationModel) +} + +func (d Diarization) EmbeddingPath() string { + return resolveModelPath(d.ModelDir, d.EmbeddingModel) +} + +func (d Diarization) ModelsPresent() bool { + if _, err := os.Stat(d.SegmentationPath()); err != nil { + return false + } + if _, err := os.Stat(d.EmbeddingPath()); err != nil { + return false + } + return true +} + +func resolveModelPath(dir, name string) string { + if filepath.IsAbs(name) { + return name + } + return filepath.Join(dir, name) +} + +func (d Diarization) Active() bool { + return d.Enabled +} diff --git a/config/file.go b/config/file.go new file mode 100644 index 0000000..2cc513f --- /dev/null +++ b/config/file.go @@ -0,0 +1,53 @@ +package config + +import ( + "fmt" + "os" + "runtime" + + "gopkg.in/yaml.v3" +) + +type File struct { + API API `yaml:"api"` + Transcode Transcode `yaml:"transcode"` + Punctuation Punctuation `yaml:"punctuation"` + Diarization Diarization `yaml:"diarization"` +} + +func DefaultFile() File { + return File{ + API: API{ + Addr: ":8080", + ModelsDir: "./models", + CacheDir: "./cache", + Threads: uint(runtime.NumCPU()), + Language: "ru", + Transcript: Transcript{}.WithDefaults(), + MaxContext: 32, + BeamSize: 5, + EntropyThold: 2.4, + DefaultAsync: true, + Garbage: DefaultGarbage(), + }, + Diarization: Diarization{}.WithDefaults(), + Transcode: Transcode{}.WithDefaults(), + Punctuation: Punctuation{Enabled: false, Engine: "off"}.WithDefaults(), + } +} + +func LoadFile(path string) (File, error) { + data, err := os.ReadFile(path) + if err != nil { + return File{}, err + } + cfg := DefaultFile() + if err := yaml.Unmarshal(data, &cfg); err != nil { + return File{}, fmt.Errorf("parse config %s: %w", path, err) + } + return cfg, nil +} + +func (f File) APIConfig() API { + return f.API +} diff --git a/config/file_test.go b/config/file_test.go new file mode 100644 index 0000000..63b9f6f --- /dev/null +++ b/config/file_test.go @@ -0,0 +1,35 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + err := os.WriteFile(path, []byte(` +api: + addr: ":9090" + models_dir: "/data/models" + language: ru +`), 0o644) + if err != nil { + t.Fatal(err) + } + + cfg, err := LoadFile(path) + if err != nil { + t.Fatal(err) + } + if cfg.API.Addr != ":9090" { + t.Fatalf("addr: got %q", cfg.API.Addr) + } + if cfg.API.ModelsDir != "/data/models" { + t.Fatalf("models_dir: got %q", cfg.API.ModelsDir) + } + if cfg.API.Language != "ru" { + t.Fatalf("language: got %q", cfg.API.Language) + } +} diff --git a/config/garbage.go b/config/garbage.go new file mode 100644 index 0000000..4cfa7fe --- /dev/null +++ b/config/garbage.go @@ -0,0 +1,13 @@ +package config + +// DefaultGarbage returns STT artifact tokens filtered from API/CLI transcript output. +func DefaultGarbage() []string { + return []string{"*выбая*"} +} + +func (a API) garbagePatterns() []string { + if a.Garbage == nil { + return DefaultGarbage() + } + return a.Garbage +} diff --git a/config/garbage_test.go b/config/garbage_test.go new file mode 100644 index 0000000..2efa8d3 --- /dev/null +++ b/config/garbage_test.go @@ -0,0 +1,23 @@ +package config + +import "testing" + +func TestDefaultGarbage(t *testing.T) { + if len(DefaultGarbage()) == 0 || DefaultGarbage()[0] != "*выбая*" { + t.Fatalf("got %v", DefaultGarbage()) + } +} + +func TestAPI_GarbagePatterns_default(t *testing.T) { + p := (API{}).GarbagePatterns() + if len(p) != 1 || p[0] != "*выбая*" { + t.Fatalf("got %v", p) + } +} + +func TestAPI_GarbagePatterns_explicitEmpty(t *testing.T) { + p := (API{Garbage: []string{}}).GarbagePatterns() + if len(p) != 0 { + t.Fatalf("got %v", p) + } +} diff --git a/config/merge.go b/config/merge.go new file mode 100644 index 0000000..23407be --- /dev/null +++ b/config/merge.go @@ -0,0 +1,136 @@ +package config + +import ( + "os" + + "github.com/urfave/cli/v2" +) + +func LoadResolved(path string) (File, error) { + if path == "" { + if _, err := os.Stat("config.yaml"); err == nil { + path = "config.yaml" + } else { + return DefaultFile(), nil + } + } + return LoadFile(path) +} + +func mergeVAD(c *cli.Context, v VAD) VAD { + if c.IsSet("vad") { + v.Enabled = c.Bool("vad") + } + if c.IsSet("vad-model") { + v.Model = c.String("vad-model") + } + if c.IsSet("vad-threshold") { + v.Threshold = c.Float64("vad-threshold") + } + if c.IsSet("vad-min-speech-ms") { + v.MinSpeechMs = c.Int("vad-min-speech-ms") + } + if c.IsSet("vad-min-silence-ms") { + v.MinSilenceMs = c.Int("vad-min-silence-ms") + } + if c.IsSet("vad-max-speech-sec") { + v.MaxSpeechSec = c.Float64("vad-max-speech-sec") + } + if c.IsSet("vad-speech-pad-ms") { + v.SpeechPadMs = c.Int("vad-speech-pad-ms") + } + if c.IsSet("vad-samples-overlap") { + v.SamplesOverlap = c.Float64("vad-samples-overlap") + } + return v.WithDefaults() +} + +func mergeAPI(c *cli.Context, a API) API { + if c.IsSet("addr") { + a.Addr = c.String("addr") + } + if c.IsSet("models-dir") { + a.ModelsDir = c.String("models-dir") + } + if c.IsSet("cache-dir") { + a.CacheDir = c.String("cache-dir") + } + if c.IsSet("threads") { + a.Threads = c.Uint("threads") + } + if c.IsSet("language") { + a.Language = c.String("language") + } + if c.IsSet("debug") { + a.Debug = c.Bool("debug") + } + if c.IsSet("speedup") { + a.SpeedUp = c.Bool("speedup") + } + if c.IsSet("translate") { + a.Translate = c.Bool("translate") + } + if c.IsSet("prompt") { + a.Prompt = c.String("prompt") + } + if c.IsSet("max-context") { + a.MaxContext = c.Uint("max-context") + } + if c.IsSet("beam-size") { + a.BeamSize = c.Uint("beam-size") + } + if c.IsSet("entropy-thold") { + a.EntropyThold = c.Float64("entropy-thold") + } + a.VAD = mergeVAD(c, a.VAD) + if c.IsSet("default-punctuation") { + a.DefaultPunctuation = c.Bool("default-punctuation") + } + return a +} + +func APIFromCLI(c *cli.Context) (API, error) { + file, err := LoadResolved(c.String("config")) + if err != nil { + return API{}, err + } + return mergeAPI(c, file.API), nil +} + +func TranscodeFromCLI(c *cli.Context) (Transcode, error) { + file, err := LoadResolved(c.String("config")) + if err != nil { + return Transcode{}, err + } + return file.Transcode.WithDefaults(), nil +} + +func mergePunctuation(c *cli.Context, p Punctuation) Punctuation { + if c.IsSet("punctuation-enabled") { + p.Enabled = c.Bool("punctuation-enabled") + } + if c.IsSet("punctuation-engine") { + p.Engine = c.String("punctuation-engine") + } + if c.IsSet("punctuation-default-on") { + p.DefaultOn = c.Bool("punctuation-default-on") + } + return p.WithDefaults() +} + +func PunctuationFromCLI(c *cli.Context) (Punctuation, error) { + file, err := LoadResolved(c.String("config")) + if err != nil { + return Punctuation{}, err + } + return mergePunctuation(c, file.Punctuation), nil +} + +func DiarizationFromCLI(c *cli.Context) (Diarization, error) { + file, err := LoadResolved(c.String("config")) + if err != nil { + return Diarization{}, err + } + return file.Diarization.WithDefaults(), nil +} + diff --git a/config/merge_test.go b/config/merge_test.go new file mode 100644 index 0000000..c78d6ea --- /dev/null +++ b/config/merge_test.go @@ -0,0 +1,28 @@ +package config + +import ( + "os" + "strings" + "testing" +) + +func TestDefaultFile_apiDefaults(t *testing.T) { + f := DefaultFile() + if f.API.Addr == "" { + t.Fatal("expected API listen addr") + } + if f.API.ModelsDir == "" { + t.Fatal("expected models_dir") + } +} + +func TestMkdirTemp_usesTmpRoot(t *testing.T) { + dir, err := MkdirTemp("go-whisper-api-test-*") + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.RemoveAll(dir) }() + if !strings.HasPrefix(dir, TempRoot+"/") { + t.Fatalf("temp dir should be under %s, got %s", TempRoot, dir) + } +} diff --git a/config/punctuation.go b/config/punctuation.go new file mode 100644 index 0000000..d0683fb --- /dev/null +++ b/config/punctuation.go @@ -0,0 +1,125 @@ +package config + +import ( + "net/http" + "path/filepath" + "strconv" + "strings" + "time" +) + +func (p Punctuation) XLMJoinSentences() bool { + return p.ApplySBD +} + +type Punctuation struct { + Command []string `yaml:"command"` + NumThreads int `yaml:"num_threads"` + TimeoutSec int `yaml:"timeout_sec"` + Engine string `yaml:"engine"` + ModelDir string `yaml:"model_dir"` + ModelFile string `yaml:"model_file"` + SPModel string `yaml:"sp_model"` + ConfigFile string `yaml:"config_file"` + BpeVocab string `yaml:"bpe_vocab"` + HTTPURL string `yaml:"http_url"` + Enabled bool `yaml:"enabled"` + DefaultOn bool `yaml:"default_on"` + ApplySBD bool `yaml:"apply_sbd"` +} + +func (p Punctuation) WithDefaults() Punctuation { + if p.Engine == "" { + p.Engine = "heuristic" + } + engine := strings.ToLower(strings.TrimSpace(p.Engine)) + if p.ModelDir == "" { + if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" { + p.ModelDir = "./models/punctuation/xlm-roberta" + } else { + p.ModelDir = "./models/punctuation/ct-transformer-zh-en-int8" + } + } + if p.ModelFile == "" { + if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" { + p.ModelFile = "model.onnx" + } else { + p.ModelFile = "model.int8.onnx" + } + } + if p.NumThreads <= 0 { + p.NumThreads = 2 + } + if p.TimeoutSec <= 0 { + p.TimeoutSec = 120 + } + return p +} + +func (p Punctuation) Timeout() time.Duration { + p = p.WithDefaults() + return time.Duration(p.TimeoutSec) * time.Second +} + +func (p Punctuation) Active() bool { + p = p.WithDefaults() + return p.Enabled && p.Engine != "" && !strings.EqualFold(p.Engine, "off") +} + +func (p Punctuation) ModelPath() string { + p = p.WithDefaults() + return filepath.Join(p.ModelDir, p.ModelFile) +} + +func (p Punctuation) SPModelPath() string { + p = p.WithDefaults() + if p.SPModel == "" { + return filepath.Join(p.ModelDir, "sp.model") + } + return filepath.Join(p.ModelDir, p.SPModel) +} + +func (p Punctuation) XLMConfigPath() string { + p = p.WithDefaults() + if p.ConfigFile == "" { + return filepath.Join(p.ModelDir, "config.yaml") + } + return filepath.Join(p.ModelDir, p.ConfigFile) +} + +func (p Punctuation) BpeVocabPath() string { + p = p.WithDefaults() + if p.BpeVocab == "" { + return filepath.Join(p.ModelDir, "bpe.vocab") + } + return filepath.Join(p.ModelDir, p.BpeVocab) +} + +func (p Punctuation) ShouldApplyAPI(r *http.Request, apiDefault bool) bool { + if !p.Active() { + return false + } + q := strings.TrimSpace(r.URL.Query().Get("punctuation")) + if q != "" { + return parsePunctuationQuery(q, true) + } + return apiDefault || p.Enabled +} + +func parsePunctuationQuery(raw string, def bool) bool { + raw = strings.TrimSpace(raw) + if raw == "" { + return def + } + switch strings.ToLower(raw) { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + } + b, err := strconv.ParseBool(raw) + if err != nil { + return def + } + return b +} diff --git a/config/punctuation_test.go b/config/punctuation_test.go new file mode 100644 index 0000000..1072217 --- /dev/null +++ b/config/punctuation_test.go @@ -0,0 +1,58 @@ +package config + +import ( + "net/http/httptest" + "testing" + +) + +func TestPunctuation_Active(t *testing.T) { + if (Punctuation{Enabled: false, Engine: "heuristic"}).Active() { + t.Fatal("disabled should be inactive") + } + if (Punctuation{Enabled: true, Engine: "off"}).Active() { + t.Fatal("engine off should be inactive") + } + if !(Punctuation{Enabled: true, Engine: "heuristic"}).Active() { + t.Fatal("enabled heuristic should be active") + } +} + +func TestPunctuation_ShouldApplyAPI(t *testing.T) { + p := Punctuation{Enabled: true, Engine: "heuristic"} + + req := httptest.NewRequest("GET", "/?punctuation=1", nil) + if !p.ShouldApplyAPI(req, false) { + t.Fatal("query 1 should enable") + } + + req = httptest.NewRequest("GET", "/?punctuation=0", nil) + if p.ShouldApplyAPI(req, true) { + t.Fatal("query 0 should disable") + } + + req = httptest.NewRequest("GET", "/", nil) + if !p.ShouldApplyAPI(req, false) { + t.Fatal("enabled in config => apply when query omitted") + } + if !p.ShouldApplyAPI(req, true) { + t.Fatal("api default true") + } + + off := Punctuation{Enabled: false, Engine: "heuristic"} + if off.ShouldApplyAPI(req, true) { + t.Fatal("master disabled must never apply") + } +} + +func TestParsePunctuationQuery(t *testing.T) { + if !parsePunctuationQuery("1", false) { + t.Fatal("1 => true") + } + if parsePunctuationQuery("0", true) { + t.Fatal("0 => false") + } + if !parsePunctuationQuery("", true) { + t.Fatal("empty => default") + } +} diff --git a/config/tmp.go b/config/tmp.go new file mode 100644 index 0000000..305c5e1 --- /dev/null +++ b/config/tmp.go @@ -0,0 +1,9 @@ +package config + +import "os" + +const TempRoot = "/tmp" + +func MkdirTemp(prefix string) (string, error) { + return os.MkdirTemp(TempRoot, prefix) +} diff --git a/config/transcode.go b/config/transcode.go new file mode 100644 index 0000000..cd0a175 --- /dev/null +++ b/config/transcode.go @@ -0,0 +1,11 @@ +package config + +// Transcode holds audio normalization settings (pure Go decoders; no external ffmpeg). +type Transcode struct { + // FFmpegPath is deprecated and ignored; kept for backward-compatible YAML. + FFmpegPath string `yaml:"ffmpeg_path,omitempty"` +} + +func (t Transcode) WithDefaults() Transcode { + return t +} diff --git a/config/transcript.go b/config/transcript.go new file mode 100644 index 0000000..1f5a942 --- /dev/null +++ b/config/transcript.go @@ -0,0 +1,26 @@ +package config + +import ( + "strings" + "time" +) + +// Transcript controls how STT segments are joined into one text field (with embedded newlines). +type Transcript struct { + PauseGapSec float64 `yaml:"pause_gap_sec"` + SpeakerLabel string `yaml:"speaker_label"` +} + +func (t Transcript) WithDefaults() Transcript { + if t.PauseGapSec <= 0 { + t.PauseGapSec = 1.5 + } + if strings.TrimSpace(t.SpeakerLabel) == "" { + t.SpeakerLabel = "Спикер" + } + return t +} + +func (t Transcript) PauseGapDuration() time.Duration { + return time.Duration(t.PauseGapSec * float64(time.Second)) +} diff --git a/config/vad.go b/config/vad.go new file mode 100644 index 0000000..faf40d2 --- /dev/null +++ b/config/vad.go @@ -0,0 +1,88 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" +) + +type VAD struct { + Enabled bool `yaml:"enabled"` + Model string `yaml:"model"` + Threshold float64 `yaml:"threshold"` + MinSpeechMs int `yaml:"min_speech_duration_ms"` + MinSilenceMs int `yaml:"min_silence_duration_ms"` + MaxSpeechSec float64 `yaml:"max_speech_duration_s"` + SpeechPadMs int `yaml:"speech_pad_ms"` + SamplesOverlap float64 `yaml:"samples_overlap"` +} + +func DefaultVAD() VAD { + return VAD{ + Threshold: 0.5, + MinSpeechMs: 250, + MinSilenceMs: 100, + MaxSpeechSec: 0, + SpeechPadMs: 30, + SamplesOverlap: 0.1, + } +} + +func (v VAD) WithDefaults() VAD { + d := DefaultVAD() + if v.Threshold <= 0 { + v.Threshold = d.Threshold + } + if v.MinSpeechMs <= 0 { + v.MinSpeechMs = d.MinSpeechMs + } + if v.MinSilenceMs <= 0 { + v.MinSilenceMs = d.MinSilenceMs + } + if v.SpeechPadMs <= 0 { + v.SpeechPadMs = d.SpeechPadMs + } + if v.SamplesOverlap <= 0 { + v.SamplesOverlap = d.SamplesOverlap + } + return v +} + +func (v VAD) ResolveModelPath(modelsDir string) string { + if v.Model == "" { + return "" + } + if filepath.IsAbs(v.Model) { + return v.Model + } + if modelsDir == "" { + return v.Model + } + direct := filepath.Join(modelsDir, v.Model) + if _, err := os.Stat(direct); err == nil { + return direct + } + // VAD weights often live under models_dir/vad/ (not listed in /spr/models). + base := filepath.Base(v.Model) + inVAD := filepath.Join(modelsDir, "vad", base) + if _, err := os.Stat(inVAD); err == nil { + return inVAD + } + return direct +} + +func (v VAD) Validate() error { + if !v.Enabled { + return nil + } + if v.Model == "" { + return fmt.Errorf("vad.model is required when vad.enabled is true") + } + if _, err := os.Stat(v.Model); err != nil { + return fmt.Errorf("vad model %q: %w", v.Model, err) + } + if v.Threshold < 0 || v.Threshold > 1 { + return fmt.Errorf("vad.threshold must be between 0 and 1") + } + return nil +} diff --git a/config/vad_test.go b/config/vad_test.go new file mode 100644 index 0000000..d52179f --- /dev/null +++ b/config/vad_test.go @@ -0,0 +1,65 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestVAD_Validate_disabled(t *testing.T) { + if err := (VAD{}).Validate(); err != nil { + t.Fatal(err) + } +} + +func TestVAD_Validate_requiresModel(t *testing.T) { + err := (VAD{Enabled: true}).Validate() + if err == nil { + t.Fatal("expected error") + } +} + +func TestVAD_Validate_modelExists(t *testing.T) { + dir := t.TempDir() + model := filepath.Join(dir, "ggml-silero.bin") + if err := os.WriteFile(model, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + v := VAD{Enabled: true, Model: model, Threshold: 0.5} + if err := v.Validate(); err != nil { + t.Fatal(err) + } +} + +func TestVAD_ResolveModelPath(t *testing.T) { + v := VAD{Model: "ggml-silero-v6.2.0.bin"} + got := v.ResolveModelPath("/data/models") + want := filepath.Join("/data/models", "ggml-silero-v6.2.0.bin") + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} + +func TestVAD_ResolveModelPath_vadSubdir(t *testing.T) { + dir := t.TempDir() + vadDir := filepath.Join(dir, "vad") + if err := os.MkdirAll(vadDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(vadDir, "vad.bin"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + v := VAD{Model: "vad.bin"} + got := v.ResolveModelPath(dir) + want := filepath.Join(dir, "vad", "vad.bin") + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} + +func TestVAD_WithDefaults(t *testing.T) { + v := VAD{Enabled: true}.WithDefaults() + if v.Threshold != 0.5 || v.MinSpeechMs != 250 || v.MinSilenceMs != 100 { + t.Fatalf("unexpected defaults: %+v", v) + } +} diff --git a/config/whisper.go b/config/whisper.go new file mode 100644 index 0000000..624f0ee --- /dev/null +++ b/config/whisper.go @@ -0,0 +1,43 @@ +package config + +import "fmt" + +type Whisper struct { + OutputFormat []string `yaml:"output_format"` + Model string `yaml:"model"` + AudioPath string `yaml:"audio_path"` + Language string `yaml:"language"` + Prompt string `yaml:"prompt"` + OutputFolder string `yaml:"output_folder"` + OutputFilename string `yaml:"output_filename"` + Threads uint `yaml:"threads"` + MaxContext uint `yaml:"max_context"` + BeamSize uint `yaml:"beam_size"` + EntropyThold float64 `yaml:"entropy_thold"` + VAD VAD `yaml:"vad"` + Debug bool `yaml:"debug"` + SpeedUp bool `yaml:"speedup"` + Translate bool `yaml:"translate"` + PrintProgress bool `yaml:"print_progress"` + PrintSegment bool `yaml:"print_segment"` +} + +func (c *Whisper) ValidateModel() error { + if c.Model == "" { + return fmt.Errorf("model is required") + } + return nil +} + +func (c *Whisper) Validate() error { + if err := c.ValidateModel(); err != nil { + return err + } + if c.AudioPath == "" { + return fmt.Errorf("audio path is required") + } + if err := c.VAD.WithDefaults().Validate(); err != nil { + return err + } + return nil +} diff --git a/config/xlm-roberta-model.yaml b/config/xlm-roberta-model.yaml new file mode 100644 index 0000000..f6852e6 --- /dev/null +++ b/config/xlm-roberta-model.yaml @@ -0,0 +1,81 @@ +# Metadata for Salama1429/xlm-roberta_punctuation_fullstop_truecase (ONNX punctuation). +# Install into the model directory: +# cp config/xlm-roberta-model.yaml models/punctuation/xlm-roberta/config.yaml +# or: make install-xlm-punctuation-config + +languages: + - af + - am + - ar + - bg + - bn + - de + - el + - en + - es + - et + - fa + - fi + - fr + - gu + - hi + - hr + - hu + - id + - is + - it + - ja + - kk + - kn + - ko + - ky + - lt + - lv + - mk + - ml + - mr + - nl + - or + - pa + - pl + - ps + - pt + - ro + - ru + - rw + - so + - sr + - sw + - ta + - te + - tr + - uk + - zh + +max_length: 256 + +pre_labels: + - "" + - "¿" + +post_labels: + - "" + - "" + - "." + - "," + - "?" + - "?" + - "," + - "。" + - "、" + - "・" + - "।" + - "؟" + - "،" + - ";" + - "።" + - "፣" + - "፧" + +null_token: "" +acronym_token: "" diff --git a/diarization/diarization.go b/diarization/diarization.go new file mode 100644 index 0000000..a884d4a --- /dev/null +++ b/diarization/diarization.go @@ -0,0 +1,19 @@ +package diarization + +import ( + "context" + + "go-whisper-api/config" + "go-whisper-api/whisper" +) + +// Engine runs offline speaker diarization when built with -tags sherpa. +type Engine interface { + Active() bool + Process(ctx context.Context, samples []float32, numClusters int) ([]whisper.Turn, error) + Close() +} + +func New(cfg config.Diarization) (Engine, error) { + return newEngine(cfg) +} diff --git a/diarization/sherpa.go b/diarization/sherpa.go new file mode 100644 index 0000000..fd495c1 --- /dev/null +++ b/diarization/sherpa.go @@ -0,0 +1,107 @@ +//go:build sherpa + +package diarization + +import ( + "context" + "fmt" + "os" + "sync" + + "go-whisper-api/config" + "go-whisper-api/whisper" + + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" +) + +type sherpaEngine struct { + cfg config.Diarization + sd *sherpa.OfflineSpeakerDiarization + mu sync.Mutex +} + +func newEngine(cfg config.Diarization) (Engine, error) { + cfg = cfg.WithDefaults() + if !cfg.Active() { + return &noopEngine{}, nil + } + if !cfg.ModelsPresent() { + return nil, fmt.Errorf("diarization models missing (run: make download-diarization-models)") + } + conf := &sherpa.OfflineSpeakerDiarizationConfig{ + Segmentation: sherpa.OfflineSpeakerSegmentationModelConfig{ + Pyannote: sherpa.OfflineSpeakerSegmentationPyannoteModelConfig{ + Model: cfg.SegmentationPath(), + }, + NumThreads: cfg.NumThreads, + Debug: 0, + Provider: "cpu", + }, + Embedding: sherpa.SpeakerEmbeddingExtractorConfig{ + Model: cfg.EmbeddingPath(), + NumThreads: cfg.NumThreads, + Debug: 0, + Provider: "cpu", + }, + Clustering: sherpa.FastClusteringConfig{ + NumClusters: cfg.NumClusters, + Threshold: cfg.ClusteringThreshold, + }, + MinDurationOn: cfg.MinDurationOn, + MinDurationOff: cfg.MinDurationOff, + } + sd := sherpa.NewOfflineSpeakerDiarization(conf) + if sd == nil { + return nil, fmt.Errorf("failed to create sherpa speaker diarization") + } + return &sherpaEngine{cfg: cfg, sd: sd}, nil +} + +type noopEngine struct{} + +func (noopEngine) Active() bool { return false } +func (noopEngine) Process(context.Context, []float32, int) ([]whisper.Turn, error) { return nil, nil } +func (noopEngine) Close() {} + +func (e *sherpaEngine) Active() bool { + return e.sd != nil +} + +func (e *sherpaEngine) Process(ctx context.Context, samples []float32, numClusters int) ([]whisper.Turn, error) { + _ = ctx + if len(samples) == 0 { + return nil, fmt.Errorf("empty audio for diarization") + } + e.mu.Lock() + defer e.mu.Unlock() + if _, err := os.Stat(e.cfg.SegmentationPath()); err != nil { + return nil, err + } + clusters := numClusters + if clusters <= 0 { + clusters = e.cfg.NumClusters + } + e.sd.SetConfig(&sherpa.OfflineSpeakerDiarizationConfig{ + Clustering: sherpa.FastClusteringConfig{ + NumClusters: clusters, + Threshold: e.cfg.ClusteringThreshold, + }, + }) + segments := e.sd.Process(samples) + out := make([]whisper.Turn, len(segments)) + for i, s := range segments { + out[i] = whisper.Turn{ + Start: s.Start, + End: s.End, + Speaker: s.Speaker, + } + } + return out, nil +} + +func (e *sherpaEngine) Close() { + if e.sd != nil { + sherpa.DeleteOfflineSpeakerDiarization(e.sd) + e.sd = nil + } +} diff --git a/diarization/stub.go b/diarization/stub.go new file mode 100644 index 0000000..69383be --- /dev/null +++ b/diarization/stub.go @@ -0,0 +1,29 @@ +//go:build !sherpa + +package diarization + +import ( + "context" + "fmt" + + "go-whisper-api/config" + "go-whisper-api/whisper" +) + +type stubEngine struct { + cfg config.Diarization +} + +func newEngine(cfg config.Diarization) (Engine, error) { + return &stubEngine{cfg: cfg.WithDefaults()}, nil +} + +func (s *stubEngine) Active() bool { + return false +} + +func (s *stubEngine) Process(context.Context, []float32, int) ([]whisper.Turn, error) { + return nil, fmt.Errorf("speaker diarization requires build with -tags sherpa and models (make build-sherpa, make download-diarization-models)") +} + +func (s *stubEngine) Close() {} diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..651f897 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,55 @@ +ARG UBUNTU_VERSION=22.04 +# This needs to generally match the container host's environment. +ARG CUDA_VERSION=12.0.0 +# Target the CUDA build image +ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} +# Target the CUDA runtime image +ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} + +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} AS build +WORKDIR /app +# Unless otherwise specified, we make a fat build. +ARG CUDA_DOCKER_ARCH=all +# Set nvcc architecture +ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH} +# Enable cuBLAS +ENV WHISPER_CUBLAS=1 + +#apt-get +RUN apt-get update && \ + apt-get install -y --no-install-recommends build-essential git gcc g++ wget \ + && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +# install golang +RUN wget --progress=dot:giga https://go.dev/dl/go1.22.10.linux-amd64.tar.gz +RUN rm -rf /usr/local/go && tar -C /usr/local -xzf go1.22.10.linux-amd64.tar.gz +ENV PATH ${PATH}:/usr/local/go/bin + +# Ref: https://stackoverflow.com/a/53464012 +ENV CUDA_MAIN_VERSION=12.0 +ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH + +COPY ./ . +RUN make dependency && env && make build && \ + mv bin/go-whisper-api /bin/ && \ + rm -rf bin + +FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} AS runtime +WORKDIR /app + +LABEL maintainer="Bo-Yi Wu " \ + org.label-schema.name="Speech-to-Text" \ + org.label-schema.vendor="Bo-Yi Wu" \ + org.label-schema.schema-version="1.0" + +LABEL org.opencontainers.image.source=https://github.com/appleboy/go-whisper-api +LABEL org.opencontainers.image.description="Speech-to-Text." +LABEL org.opencontainers.image.licenses=MIT + +RUN apt-get update && \ + apt-get install -y --no-install-recommends curl \ + && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +COPY --from=build /bin/go-whisper-api /bin/go-whisper-api +EXPOSE 8080 +ENTRYPOINT ["/bin/go-whisper-api"] \ No newline at end of file diff --git a/docker/Dockerfile.ci b/docker/Dockerfile.ci new file mode 100644 index 0000000..1e1ff27 --- /dev/null +++ b/docker/Dockerfile.ci @@ -0,0 +1,27 @@ +# CPU image for Gitea CI and hosts without NVIDIA GPU. +# Production GPU image: docker/Dockerfile + +FROM golang:1.23-bookworm AS build +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git cmake pkg-config libsentencepiece-dev \ + && rm -rf /var/lib/apt/lists/* + +COPY . . +RUN make dependency && make build-xlm + +FROM debian:bookworm-slim AS runtime +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +COPY --from=build /app/bin/go-whisper-api /bin/go-whisper-api +COPY --from=build /app/third_party/whisper.cpp/build/src/libwhisper.so* /usr/local/lib/ +COPY --from=build /app/third_party/whisper.cpp/build/ggml/src/libggml*.so* /usr/local/lib/ +ENV LD_LIBRARY_PATH=/usr/local/lib + +EXPOSE 8080 +ENTRYPOINT ["/bin/go-whisper-api"] diff --git a/garbage/filter.go b/garbage/filter.go new file mode 100644 index 0000000..2e33098 --- /dev/null +++ b/garbage/filter.go @@ -0,0 +1,56 @@ +package garbage + +import ( + "regexp" + "strings" +) + +var spaceCollapse = regexp.MustCompile(`\s+`) + +// Word is a timed token for garbage filtering (mirrors whisper.Word JSON shape). +type Word struct { + Word string `json:"word"` + Start int `json:"start"` + Stop int `json:"stop"` +} + +// FilterText removes configured artifact substrings and normalizes whitespace. +func FilterText(text string, patterns []string) string { + for _, p := range patterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + text = strings.ReplaceAll(text, p, " ") + } + return strings.TrimSpace(spaceCollapse.ReplaceAllString(text, " ")) +} + +// FilterWords drops tokens that match any garbage pattern. +func FilterWords(words []Word, patterns []string) []Word { + if len(words) == 0 { + return words + } + out := make([]Word, 0, len(words)) + for _, w := range words { + if matchesGarbage(w.Word, patterns) { + continue + } + out = append(out, w) + } + return out +} + +func matchesGarbage(word string, patterns []string) bool { + word = strings.TrimSpace(word) + for _, p := range patterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if word == p || strings.Contains(word, p) { + return true + } + } + return false +} diff --git a/garbage/filter_test.go b/garbage/filter_test.go new file mode 100644 index 0000000..5996bae --- /dev/null +++ b/garbage/filter_test.go @@ -0,0 +1,30 @@ +package garbage + +import "testing" + +func TestFilterText(t *testing.T) { + in := "Привет *выбая* мир *выбая*" + got := FilterText(in, []string{"*выбая*"}) + want := "Привет мир" + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} + +func TestFilterWords(t *testing.T) { + words := []Word{ + {Word: "Что", Start: 0, Stop: 100}, + {Word: "*выбая*", Start: 100, Stop: 200}, + {Word: "мир", Start: 200, Stop: 300}, + } + got := FilterWords(words, []string{"*выбая*"}) + if len(got) != 2 || got[1].Word != "мир" { + t.Fatalf("got %+v", got) + } +} + +func TestFilterText_emptyPatterns(t *testing.T) { + if got := FilterText("a b", nil); got != "a b" { + t.Fatalf("got %q", got) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b6ab674 --- /dev/null +++ b/go.mod @@ -0,0 +1,48 @@ +module go-whisper-api + +go 1.26 + +require ( + github.com/Eyevinn/mp4ff v0.51.0 + github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230606002726-57543c169e27 + github.com/go-audio/audio v1.0.0 + github.com/go-audio/wav v1.1.0 + github.com/google/uuid v1.6.0 + github.com/gopxl/beep v1.4.1 + github.com/joho/godotenv v1.5.1 + github.com/k2-fsa/sherpa-onnx-go v1.13.2 + github.com/mattn/go-isatty v0.0.20 + github.com/olivier-w/climp-aac-decoder v0.1.0 + github.com/pion/opus v0.0.0-20260601214817-71d58474cec8 + github.com/rs/zerolog v1.35.0 + github.com/skrashevich/go-aac v0.1.0 + github.com/urfave/cli/v2 v2.27.7 + github.com/yalue/onnxruntime_go v1.24.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/go-audio/riff v1.0.0 // indirect + github.com/hajimehoshi/go-mp3 v0.3.4 // indirect + github.com/icza/bitio v1.1.0 // indirect + github.com/jfreymuth/oggvorbis v1.0.5 // indirect + github.com/jfreymuth/vorbis v1.0.2 // indirect + github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2 // indirect + github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2 // indirect + github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mewkiz/flac v1.0.8 // indirect + github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect + golang.org/x/sys v0.45.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect +) + +replace github.com/ggerganov/whisper.cpp/bindings/go => ./third_party/whisper.cpp/bindings/go diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7ae66e5 --- /dev/null +++ b/go.sum @@ -0,0 +1,127 @@ +github.com/Eyevinn/mp4ff v0.51.0 h1:ZYdHFXEcB3kJkCeCHMHl/tbCm64FJsD2XOU0Sj+ME2M= +github.com/Eyevinn/mp4ff v0.51.0/go.mod h1:hJNUUqOBryLAzUW9wpCJyw2HaI+TCd2rUPhafoS5lgg= +github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= +github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/d4l3k/messagediff v1.2.2-0.20190829033028-7e0a312ae40b/go.mod h1:Oozbb1TVXFac9FtSIxHBMnBCq2qeH/2KkEQxENCrlLo= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/ebitengine/oto/v3 v3.1.0 h1:9tChG6rizyeR2w3vsygTTTVVJ9QMMyu00m2yBOCch6U= +github.com/ebitengine/oto/v3 v3.1.0/go.mod h1:IK1QTnlfZK2GIB6ziyECm433hAdTaPpOsGMLhEyEGTg= +github.com/ebitengine/purego v0.7.1 h1:6/55d26lG3o9VCZX8lping+bZcmShseiqlh2bnUDiPA= +github.com/ebitengine/purego v0.7.1/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ= +github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= +github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs= +github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA= +github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498= +github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g= +github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE= +github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg= +github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopxl/beep v1.4.1 h1:WqNs9RsDAhG9M3khMyc1FaVY50dTdxG/6S6a3qsUHqE= +github.com/gopxl/beep v1.4.1/go.mod h1:A1dmiUkuY8kxsvcNJNUBIEcchmiP6eUyCHSxpXl0YO0= +github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= +github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo= +github.com/hajimehoshi/oto/v2 v2.3.1/go.mod h1:seWLbgHH7AyUMYKfKYT9pg7PhUu9/SisyJvNTT+ASQo= +github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0= +github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A= +github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k= +github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA= +github.com/jfreymuth/oggvorbis v1.0.5 h1:u+Ck+R0eLSRhgq8WTmffYnrVtSztJcYrl588DM4e3kQ= +github.com/jfreymuth/oggvorbis v1.0.5/go.mod h1:1U4pqWmghcoVsCJJ4fRBKv9peUJMBHixthRlBeD6uII= +github.com/jfreymuth/vorbis v1.0.2 h1:m1xH6+ZI4thH927pgKD8JOH4eaGRm18rEE9/0WKjvNE= +github.com/jfreymuth/vorbis v1.0.2/go.mod h1:DoftRo4AznKnShRl1GxiTFCseHr4zR9BN3TWXyuzrqQ= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/jszwec/csvutil v1.5.1/go.mod h1:Rpu7Uu9giO9subDyMCIQfHVDuLrcaC36UA4YcJjGBkg= +github.com/k2-fsa/sherpa-onnx-go v1.13.2 h1:1orlwwdJcVk37OoV1gi1L6fIdeZHKJcLQjm6S5h/WWs= +github.com/k2-fsa/sherpa-onnx-go v1.13.2/go.mod h1:NyA3OsF/hmcvk4FpbXpBO9Rl/I3dlnrOB9Ny+TnhvIc= +github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2 h1:w/RHaU9liD/ovA9Q5iI2lCoVjVEd4M8+uTrf54rjQPY= +github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2/go.mod h1:NXEH2rsBgTdqY59YpPq6CtSBlBAXy/8a9FmpLERU97I= +github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2 h1:oIIdSfU3NEMT7oq7yxoH7Rk37kekkbmr/b+7e4er1WE= +github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2/go.mod h1:ZOhUAXC62Unj0ZNfu6zxSFKcW96aXf7P3BsqiUyOBbE= +github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2 h1:tkfwXnmJRsChv59ZzLsycphwpvJpSkyd29aMG9kUoEE= +github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2/go.mod h1:5AX7TU8+P/gInjglY1ijtWUM2b8iyR0QX4yEngzMe64= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mewkiz/flac v1.0.8 h1:cophRjvafteDGmqsfXRK28YAX6l8wy19QxTHruEEg1s= +github.com/mewkiz/flac v1.0.8/go.mod h1:l7dt5uFY724eKVkHQtAJAQSkhpC3helU3RDxN0ESAqo= +github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14 h1:tnAPMExbRERsyEYkmR1YjhTgDM0iqyiBYf8ojRXxdbA= +github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14/go.mod h1:QYCFBiH5q6XTHEbWhR0uhR3M9qNPoD2CSQzr0g75kE4= +github.com/olivier-w/climp-aac-decoder v0.1.0 h1:jdwYeBlnfQlnm6KA/fG20s22HrlBlVVaaWQTj/tnb3A= +github.com/olivier-w/climp-aac-decoder v0.1.0/go.mod h1:Hpqi8cI4NeN4JEcgKxiAK0zEwx+hO+lWdg9Q3ruHouI= +github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e h1:s2RNOM/IGdY0Y6qfTeUKhDawdHDpK9RGBdx80qN4Ttw= +github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBdnFKj15wFbf94Rwfq4m30eAcyY9V/IyKAGQFtqkW0= +github.com/pion/opus v0.0.0-20260601214817-71d58474cec8 h1:THWkLaX+FzbTbHIo1irDpbBMXbNrqV6q2Xs+00p4d+0= +github.com/pion/opus v0.0.0-20260601214817-71d58474cec8/go.mod h1:t5Xog2n682JnawoykACE6nKVmupFvmJvkpM7x6bTv6g= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI= +github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= +github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/skrashevich/go-aac v0.1.0 h1:7oHNj1ADmgfjAHvi3wAIFbmbCpQBrcjZEVTLlRtAS1A= +github.com/skrashevich/go-aac v0.1.0/go.mod h1:Mj7r//4LDL4FC0ezORj+MnmQ+nDEkJhTOy2aMC8dzww= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= +github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= +github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= +github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= +github.com/yalue/onnxruntime_go v1.24.0 h1:IdgJLxxyotlsUTmL1UnHZgBzXJGgY51LZ4vQ5rZeOXU= +github.com/yalue/onnxruntime_go v1.24.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= +golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go new file mode 100644 index 0000000..14ff581 --- /dev/null +++ b/main.go @@ -0,0 +1,251 @@ +package main + +import ( + "os" + "runtime" + "strconv" + "time" + + "go-whisper-api/api" + "go-whisper-api/config" + + _ "github.com/joho/godotenv/autoload" + "github.com/mattn/go-isatty" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/urfave/cli/v2" +) + +var ( + Version string +) + +func main() { + isTerm := isatty.IsTerminal(os.Stdout.Fd()) + zerolog.SetGlobalLevel(zerolog.InfoLevel) + log.Logger = log.Output( + zerolog.ConsoleWriter{ + Out: os.Stderr, + NoColor: !isTerm, + }, + ) + zerolog.CallerMarshalFunc = func(pc uintptr, file string, line int) string { + short := file + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + short = file[i+1:] + break + } + } + file = short + return file + ":" + strconv.Itoa(line) + } + app := cli.NewApp() + app.Name = "go-whisper-api" + app.Usage = "HTTP API for speech-to-text (SPR + OpenAI-compatible)." + app.Copyright = "Copyright (c) " + strconv.Itoa(time.Now().Year()) + " Bo-Yi Wu" + app.Authors = []*cli.Author{ + { + Name: "Bo-Yi Wu", + Email: "appleboy.tw@gmail.com", + }, + } + app.Version = Version + app.Commands = []*cli.Command{ + { + Name: "serve", + Aliases: []string{"s", "api"}, + Usage: "start HTTP API server", + Flags: serveFlags(), + Action: runServe, + }, + } + app.Flags = append(append([]cli.Flag{configFlag()}, punctuationFlags()...), serveFlags()...) + app.Action = runServe + if err := app.Run(os.Args); err != nil { + log.Fatal().Err(err).Msg("can't run app") + } +} + +func configFlag() cli.Flag { + return &cli.StringFlag{ + Name: "config", + Usage: "path to YAML config file (default: config.yaml if present)", + EnvVars: []string{"CONFIG_PATH", "GO_WHISPER_CONFIG"}, + } +} + +func punctuationFlags() []cli.Flag { + return []cli.Flag{ + &cli.BoolFlag{ + Name: "punctuation-enabled", + Usage: "master switch: enable or disable punctuation (YAML: punctuation.enabled)", + EnvVars: []string{"PUNCTUATION_ENABLED"}, + }, + &cli.BoolFlag{ + Name: "punctuation-default-on", + Usage: "apply punctuation by default when query/flags omitted (YAML: punctuation.default_on)", + }, + &cli.StringFlag{ + Name: "punctuation-engine", + Usage: "punctuation engine: off, heuristic, sherpa, sherpa-online, http, xlm", + EnvVars: []string{"PUNCTUATION_ENGINE"}, + }, + } +} + +func vadFlags() []cli.Flag { + return []cli.Flag{ + &cli.BoolFlag{ + Name: "vad", + Usage: "enable voice activity detection (Silero VAD)", + EnvVars: []string{"API_VAD"}, + }, + &cli.StringFlag{ + Name: "vad-model", + Usage: "path to ggml Silero VAD model (e.g. models/ggml-silero-v6.2.0.bin)", + EnvVars: []string{"API_VAD_MODEL"}, + }, + &cli.Float64Flag{ + Name: "vad-threshold", + Usage: "VAD speech threshold (0.0-1.0)", + EnvVars: []string{"API_VAD_THRESHOLD"}, + }, + &cli.IntFlag{ + Name: "vad-min-speech-ms", + Usage: "minimum speech duration in ms", + EnvVars: []string{"API_VAD_MIN_SPEECH_MS"}, + }, + &cli.IntFlag{ + Name: "vad-min-silence-ms", + Usage: "minimum silence between segments in ms", + EnvVars: []string{"API_VAD_MIN_SILENCE_MS"}, + }, + &cli.Float64Flag{ + Name: "vad-max-speech-sec", + Usage: "maximum speech segment length in seconds (0 = unlimited)", + EnvVars: []string{"API_VAD_MAX_SPEECH_SEC"}, + }, + &cli.IntFlag{ + Name: "vad-speech-pad-ms", + Usage: "padding around speech segments in ms", + EnvVars: []string{"API_VAD_SPEECH_PAD_MS"}, + }, + &cli.Float64Flag{ + Name: "vad-samples-overlap", + Usage: "overlap between VAD segments in seconds", + EnvVars: []string{"API_VAD_SAMPLES_OVERLAP"}, + }, + } +} + +func serveFlags() []cli.Flag { + flags := []cli.Flag{ + &cli.StringFlag{ + Name: "addr", + Usage: "HTTP listen address", + EnvVars: []string{"API_ADDR"}, + Value: ":8080", + }, + &cli.StringFlag{ + Name: "models-dir", + Usage: "directory with ggml *.bin whisper models", + EnvVars: []string{"API_MODELS_DIR"}, + Value: "./models", + }, + &cli.StringFlag{ + Name: "cache-dir", + Usage: "directory for async task cache (waiting/ready)", + EnvVars: []string{"API_CACHE_DIR"}, + Value: "./cache", + }, + &cli.StringFlag{ + Name: "language", + Usage: "default language for speech recognition", + EnvVars: []string{"API_LANGUAGE"}, + Value: "auto", + }, + &cli.UintFlag{ + Name: "threads", + Usage: "number of threads for whisper", + EnvVars: []string{"API_THREADS"}, + Value: uint(runtime.NumCPU()), + }, + &cli.BoolFlag{ + Name: "debug", + Usage: "enable debug mode", + EnvVars: []string{"API_DEBUG"}, + }, + &cli.BoolFlag{ + Name: "speedup", + Usage: "speed up audio by x2", + EnvVars: []string{"API_SPEEDUP"}, + }, + &cli.BoolFlag{ + Name: "translate", + Usage: "translate to english", + EnvVars: []string{"API_TRANSLATE"}, + }, + &cli.StringFlag{ + Name: "prompt", + Usage: "initial prompt", + EnvVars: []string{"API_PROMPT"}, + }, + &cli.UintFlag{ + Name: "max-context", + Usage: "maximum text context tokens", + EnvVars: []string{"API_MAX_CONTEXT"}, + Value: 32, + }, + &cli.UintFlag{ + Name: "beam-size", + Usage: "beam size for beam search", + EnvVars: []string{"API_BEAM_SIZE"}, + Value: 5, + }, + &cli.Float64Flag{ + Name: "entropy-thold", + Usage: "entropy threshold", + EnvVars: []string{"API_ENTROPY_THOLD"}, + Value: 2.4, + }, + &cli.BoolFlag{ + Name: "default-punctuation", + Usage: "enable punctuation on STT by default", + EnvVars: []string{"API_DEFAULT_PUNCTUATION"}, + }, + } + return append(flags, vadFlags()...) +} + +func runServe(c *cli.Context) error { + apiCfg, err := config.APIFromCLI(c) + if err != nil { + return err + } + if apiCfg.Debug { + zerolog.SetGlobalLevel(zerolog.DebugLevel) + log.Logger = log.With().Caller().Logger() + } + vad := apiCfg.VAD.WithDefaults() + if vad.Enabled { + vad.Model = vad.ResolveModelPath(apiCfg.ModelsDir) + if err := vad.Validate(); err != nil { + return err + } + apiCfg.VAD = vad + } + tc, err := config.TranscodeFromCLI(c) + if err != nil { + return err + } + pc, err := config.PunctuationFromCLI(c) + if err != nil { + return err + } + dc, err := config.DiarizationFromCLI(c) + if err != nil { + return err + } + return api.Run(c.Context, apiCfg, tc, pc, dc) +} diff --git a/models/.gitkeep b/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/diarization/.gitkeep b/models/diarization/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/diarization/pyannote-segmentation-3-0/.gitkeep b/models/diarization/pyannote-segmentation-3-0/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/punctuation/.gitkeep b/models/punctuation/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/punctuation/xlm-roberta/.gitkeep b/models/punctuation/xlm-roberta/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/models/vad/.gitkeep b/models/vad/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/punctuation/heuristic.go b/punctuation/heuristic.go new file mode 100644 index 0000000..441b931 --- /dev/null +++ b/punctuation/heuristic.go @@ -0,0 +1,86 @@ +package punctuation + +import ( + "context" + "regexp" + "strings" + "unicode" + "unicode/utf8" +) + +type Heuristic struct{} + +func (Heuristic) Active() bool { + return true +} + +func (Heuristic) Restore(ctx context.Context, text, language string) (string, error) { + _ = ctx + text = strings.TrimSpace(text) + if text == "" { + return text, nil + } + text = normalizeSpaces(text) + text = capitalizeFirst(text) + lang := strings.ToLower(strings.TrimSpace(language)) + if lang == "ru" || lang == "rus" || lang == "russian" || lang == "auto" { + text = heuristicRU(text) + } else { + text = heuristicEN(text) + } + return ensureTerminalPunct(text), nil +} + +func normalizeSpaces(s string) string { + return strings.Join(strings.Fields(s), " ") +} + +func capitalizeFirst(s string) string { + r, size := utf8.DecodeRuneInString(s) + if r == utf8.RuneError { + return s + } + return string(unicode.ToUpper(r)) + s[size:] +} + +var ( + reQuestionRU = regexp.MustCompile(`(?i)(^|.*\s)(как|что|где|когда|почему|зачем|кто|чей|какой|какая|какое|какие|сколько|зачем|откуда|куда|ли)(\s+[^.?!]+)$`) + reQuestionEN = regexp.MustCompile(`(?i)^(who|what|when|where|why|how|which|whose|whom|is|are|am|was|were|do|does|did|can|could|would|will|shall|should)\b`) +) + +func heuristicRU(s string) string { + if reQuestionRU.MatchString(s) && !strings.HasSuffix(s, "?") { + return s + "?" + } + if !hasTerminalPunct(s) && len(strings.Fields(s)) <= 24 { + return s + "." + } + return s +} + +func heuristicEN(s string) string { + lower := strings.ToLower(s) + if reQuestionEN.MatchString(lower) && !strings.HasSuffix(s, "?") { + return s + "?" + } + if !hasTerminalPunct(s) && len(strings.Fields(s)) <= 24 { + return s + "." + } + return s +} + +func hasTerminalPunct(s string) bool { + s = strings.TrimSpace(s) + if s == "" { + return false + } + r, _ := utf8.DecodeLastRuneInString(s) + return r == '.' || r == '?' || r == '!' || r == '…' +} + +func ensureTerminalPunct(s string) string { + if hasTerminalPunct(s) { + return s + } + return s + "." +} diff --git a/punctuation/heuristic_test.go b/punctuation/heuristic_test.go new file mode 100644 index 0000000..b97fd8c --- /dev/null +++ b/punctuation/heuristic_test.go @@ -0,0 +1,32 @@ +package punctuation + +import ( + "context" + "testing" +) + +func TestHeuristicRU_question(t *testing.T) { + h := Heuristic{} + out, err := h.Restore(context.Background(), "как дела", "ru") + if err != nil { + t.Fatal(err) + } + if !stringsHasSuffix(out, "?") { + t.Fatalf("expected question mark, got %q", out) + } +} + +func TestHeuristicEN_period(t *testing.T) { + h := Heuristic{} + out, err := h.Restore(context.Background(), "hello world", "en") + if err != nil { + t.Fatal(err) + } + if !stringsHasSuffix(out, ".") { + t.Fatalf("expected period, got %q", out) + } +} + +func stringsHasSuffix(s, suffix string) bool { + return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix +} diff --git a/punctuation/internal/spwrap/sp.go b/punctuation/internal/spwrap/sp.go new file mode 100644 index 0000000..16a64b8 --- /dev/null +++ b/punctuation/internal/spwrap/sp.go @@ -0,0 +1,92 @@ +//go:build xlm + +package spwrap + +/* +#cgo CXXFLAGS: -std=c++17 +#cgo LDFLAGS: -lsentencepiece +#include +#include "sp_wrap.h" +*/ +import "C" +import ( + "fmt" + "unsafe" +) + +type Processor struct { + p *C.SPProcessor +} + +func Load(path string) (*Processor, error) { + cpath := C.CString(path) + defer C.free(unsafe.Pointer(cpath)) + var errMsg *C.char + p := C.sp_load(cpath, &errMsg) + if p == nil { + if errMsg != nil { + defer C.free(unsafe.Pointer(errMsg)) + return nil, fmt.Errorf("sentencepiece: %s", C.GoString(errMsg)) + } + return nil, fmt.Errorf("sentencepiece: failed to load %s", path) + } + return &Processor{p: p}, nil +} + +func (proc *Processor) Close() { + if proc.p != nil { + C.sp_free(proc.p) + proc.p = nil + } +} + +func (proc *Processor) BOSID() int { + return int(C.sp_bos_id(proc.p)) +} + +func (proc *Processor) EOSID() int { + return int(C.sp_eos_id(proc.p)) +} + +func (proc *Processor) PadID() int { + return int(C.sp_pad_id(proc.p)) +} + +func (proc *Processor) EncodeAsIDs(text string) ([]int, error) { + ctext := C.CString(text) + defer C.free(unsafe.Pointer(ctext)) + var ids *C.int + var n C.int + var errMsg *C.char + if C.sp_encode(proc.p, ctext, &ids, &n, &errMsg) == 0 { + if errMsg != nil { + defer C.free(unsafe.Pointer(errMsg)) + return nil, fmt.Errorf("sentencepiece encode: %s", C.GoString(errMsg)) + } + return nil, fmt.Errorf("sentencepiece encode failed") + } + if ids == nil || n == 0 { + return nil, nil + } + defer C.free(unsafe.Pointer(ids)) + out := make([]int, int(n)) + slice := unsafe.Slice(ids, int(n)) + for i := range out { + out[i] = int(slice[i]) + } + return out, nil +} + +func (proc *Processor) IDToPiece(id int) (string, error) { + var errMsg *C.char + piece := C.sp_id_to_piece(proc.p, C.int(id), &errMsg) + if piece == nil { + if errMsg != nil { + defer C.free(unsafe.Pointer(errMsg)) + return "", fmt.Errorf("sentencepiece id to piece: %s", C.GoString(errMsg)) + } + return "", fmt.Errorf("sentencepiece id to piece failed") + } + defer C.free(unsafe.Pointer(piece)) + return C.GoString(piece), nil +} diff --git a/punctuation/internal/spwrap/sp_wrap.cc b/punctuation/internal/spwrap/sp_wrap.cc new file mode 100644 index 0000000..f63ecb5 --- /dev/null +++ b/punctuation/internal/spwrap/sp_wrap.cc @@ -0,0 +1,88 @@ +#include "sp_wrap.h" + +#include + +#include +#include +#include +#include + +struct SPProcessor { + sentencepiece::SentencePieceProcessor proc; +}; + +static char *copy_err(const std::string &msg) { + char *out = static_cast(std::malloc(msg.size() + 1)); + if (out != nullptr) { + std::memcpy(out, msg.c_str(), msg.size() + 1); + } + return out; +} + +SPProcessor *sp_load(const char *path, char **err) { + if (err != nullptr) { + *err = nullptr; + } + auto *p = new SPProcessor(); + const auto status = p->proc.Load(path); + if (!status.ok()) { + if (err != nullptr) { + *err = copy_err(status.ToString()); + } + delete p; + return nullptr; + } + return p; +} + +void sp_free(SPProcessor *p) { delete p; } + +int sp_bos_id(const SPProcessor *p) { return p->proc.bos_id(); } + +int sp_eos_id(const SPProcessor *p) { return p->proc.eos_id(); } + +int sp_pad_id(const SPProcessor *p) { return p->proc.pad_id(); } + +int sp_encode(const SPProcessor *p, const char *text, int **out_ids, int *out_len, char **err) { + if (err != nullptr) { + *err = nullptr; + } + if (out_ids != nullptr) { + *out_ids = nullptr; + } + if (out_len != nullptr) { + *out_len = 0; + } + std::vector ids; + const auto status = p->proc.Encode(text, &ids); + if (!status.ok()) { + if (err != nullptr) { + *err = copy_err(status.ToString()); + } + return 0; + } + if (ids.empty()) { + return 1; + } + int *buf = static_cast(std::malloc(sizeof(int) * ids.size())); + if (buf == nullptr) { + if (err != nullptr) { + *err = copy_err("malloc failed"); + } + return 0; + } + for (size_t i = 0; i < ids.size(); i++) { + buf[i] = ids[i]; + } + *out_ids = buf; + *out_len = static_cast(ids.size()); + return 1; +} + +char *sp_id_to_piece(const SPProcessor *p, int id, char **err) { + if (err != nullptr) { + *err = nullptr; + } + const std::string piece = p->proc.IdToPiece(id); + return copy_err(piece); +} diff --git a/punctuation/internal/spwrap/sp_wrap.h b/punctuation/internal/spwrap/sp_wrap.h new file mode 100644 index 0000000..2695240 --- /dev/null +++ b/punctuation/internal/spwrap/sp_wrap.h @@ -0,0 +1,21 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct SPProcessor SPProcessor; + +SPProcessor *sp_load(const char *path, char **err); +void sp_free(SPProcessor *p); + +int sp_bos_id(const SPProcessor *p); +int sp_eos_id(const SPProcessor *p); +int sp_pad_id(const SPProcessor *p); + +int sp_encode(const SPProcessor *p, const char *text, int **out_ids, int *out_len, char **err); +char *sp_id_to_piece(const SPProcessor *p, int id, char **err); + +#ifdef __cplusplus +} +#endif diff --git a/punctuation/ort_env.go b/punctuation/ort_env.go new file mode 100644 index 0000000..ab0f97e --- /dev/null +++ b/punctuation/ort_env.go @@ -0,0 +1,177 @@ +//go:build xlm + +package punctuation + +import ( + "os" + "path/filepath" + "runtime" + "strings" + "sync" + + ort "github.com/yalue/onnxruntime_go" +) + +var ( + ortOnce sync.Once + ortErr error +) + +func ensureORT() error { + ortOnce.Do(func() { + if p := resolveONNXRuntimeLib(); p != "" { + ort.SetSharedLibraryPath(p) + } + ortErr = ort.InitializeEnvironment() + }) + return ortErr +} + +func resolveONNXRuntimeLib() string { + if p := strings.TrimSpace(os.Getenv("ONNXRUNTIME_SHARED_LIBRARY_PATH")); p != "" { + return p + } + for _, p := range onnxRuntimeCandidates() { + if st, err := os.Stat(p); err == nil && !st.IsDir() { + return p + } + } + return "" +} + +func onnxRuntimeCandidates() []string { + arch := sherpaLibArch() + ver := sherpaLinuxModuleVersion() + var out []string + for _, root := range goModCacheRoots() { + if ver != "" { + out = append(out, filepath.Join(root, + "github.com/k2-fsa/sherpa-onnx-go-linux@"+ver, + "lib", arch, "libonnxruntime.so")) + } + } + if exe, err := os.Executable(); err == nil { + exeDir := filepath.Dir(exe) + out = append(out, + filepath.Join(exeDir, "libonnxruntime.so"), + filepath.Join(exeDir, "lib", "libonnxruntime.so"), + filepath.Join(exeDir, "..", "lib", "libonnxruntime.so"), + ) + if modRoot := findModuleRoot(exeDir); modRoot != "" { + out = append(out, filepath.Join(modRoot, "lib", "libonnxruntime.so")) + } + } + return out +} + +func goModCacheRoots() []string { + var roots []string + seen := map[string]struct{}{} + add := func(p string) { + p = strings.TrimSpace(p) + if p == "" { + return + } + if _, ok := seen[p]; ok { + return + } + seen[p] = struct{}{} + roots = append(roots, p) + } + add(os.Getenv("GOMODCACHE")) + if gopath := os.Getenv("GOPATH"); gopath != "" { + for _, gp := range filepath.SplitList(gopath) { + add(filepath.Join(gp, "pkg", "mod")) + } + } + if home, err := os.UserHomeDir(); err == nil { + add(filepath.Join(home, "go", "pkg", "mod")) + } + return roots +} + +func findModuleRoot(start string) string { + dir := start + for i := 0; i < 8; i++ { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + if cwd, err := os.Getwd(); err == nil { + return findModuleRootFrom(cwd) + } + return "" +} + +func findModuleRootFrom(dir string) string { + for i := 0; i < 8; i++ { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + return "" +} + +func sherpaLinuxModuleVersion() string { + for _, dir := range []string{findModuleRootFrom(mustCwd()), ""} { + if dir == "" { + continue + } + if v := readSherpaVersion(filepath.Join(dir, "go.mod")); v != "" { + return v + } + } + if exe, err := os.Executable(); err == nil { + if root := findModuleRoot(filepath.Dir(exe)); root != "" { + return readSherpaVersion(filepath.Join(root, "go.mod")) + } + } + return "" +} + +func readSherpaVersion(goModPath string) string { + data, err := os.ReadFile(goModPath) + if err != nil { + return "" + } + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if !strings.Contains(line, "github.com/k2-fsa/sherpa-onnx-go-linux") { + continue + } + fields := strings.Fields(line) + if len(fields) >= 2 { + return fields[1] // e.g. v1.13.2 — must match pkg/mod path @v1.13.2 + } + } + return "" +} + +func sherpaLibArch() string { + switch runtime.GOARCH { + case "arm64": + return "aarch64-unknown-linux-gnu" + case "arm": + return "arm-unknown-linux-gnueabihf" + default: + return "x86_64-unknown-linux-gnu" + } +} + +func mustCwd() string { + cwd, err := os.Getwd() + if err != nil { + return "" + } + return cwd +} diff --git a/punctuation/punctuation.go b/punctuation/punctuation.go new file mode 100644 index 0000000..3a995fb --- /dev/null +++ b/punctuation/punctuation.go @@ -0,0 +1,154 @@ +package punctuation + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + + "go-whisper-api/config" +) + +type Restorer interface { + Active() bool + Restore(ctx context.Context, text, language string) (string, error) +} + +type Closer interface { + Close() +} + +func New(cfg config.Punctuation) (Restorer, error) { + cfg = cfg.WithDefaults() + if !cfg.Active() { + return Nop{}, nil + } + engine := strings.ToLower(strings.TrimSpace(cfg.Engine)) + switch engine { + case "heuristic": + return Heuristic{}, nil + case "sherpa", "sherpa-offline", "offline": + cfg.Engine = "sherpa" + return newSherpaRestorer(cfg) + case "sherpa-online", "online": + cfg.Engine = "sherpa-online" + return newSherpaRestorer(cfg) + case "xlm", "xlm-roberta", "roberta": + return newXLM(cfg) + case "http": + if cfg.HTTPURL == "" { + return nil, fmt.Errorf("punctuation.http_url is required when engine=http") + } + return HTTP{cfg: cfg}, nil + default: + return nil, fmt.Errorf("unsupported punctuation engine %q (use: off, heuristic, xlm, sherpa, sherpa-online, http)", engine) + } +} + +type Nop struct{} + +func (Nop) Active() bool { + return false +} + +func (Nop) Restore(ctx context.Context, text, language string) (string, error) { + return text, nil +} + +type HTTP struct { + cfg config.Punctuation + client *http.Client +} + +func (h HTTP) Active() bool { return true } + +func (h HTTP) Restore(ctx context.Context, text, language string) (string, error) { + body, err := json.Marshal(map[string]string{ + "text": text, + "language": language, + }) + if err != nil { + return "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.cfg.HTTPURL, bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + client := h.client + if client == nil { + client = &http.Client{Timeout: h.cfg.Timeout()} + } + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + raw, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + if resp.StatusCode >= 300 { + return "", fmt.Errorf("punctuation http %s: %s", resp.Status, strings.TrimSpace(string(raw))) + } + var out struct { + Text string `json:"text"` + } + if err := json.Unmarshal(raw, &out); err != nil { + return strings.TrimSpace(string(raw)), nil + } + if out.Text == "" { + return text, nil + } + return out.Text, nil +} + +func Apply(ctx context.Context, r Restorer, enabled bool, text, language string) (string, error) { + if !enabled || r == nil || !r.Active() { + return text, nil + } + text = strings.TrimSpace(text) + if text == "" { + return text, nil + } + return r.Restore(ctx, text, language) +} + +func Close(r Restorer) { + if c, ok := r.(Closer); ok { + c.Close() + } +} + +func AutoSelect(cfg config.Punctuation) (Restorer, error) { + cfg = cfg.WithDefaults() + if !cfg.Active() { + return Nop{}, nil + } + engine := strings.ToLower(cfg.Engine) + if engine == "heuristic" { + return Heuristic{}, nil + } + if engine == "http" { + return New(cfg) + } + if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" { + return newXLM(cfg) + } + if _, err := os.Stat(cfg.ModelPath()); err == nil { + cfg.Engine = engine + if engine == "sherpa" || engine == "sherpa-offline" || engine == "offline" || engine == "" { + cfg.Engine = "sherpa" + } + return newSherpaRestorer(cfg) + } + if engine == "sherpa" || engine == "sherpa-online" || engine == "online" { + return nil, fmt.Errorf("punctuation model not found at %s (run: make download-punctuation-model)", cfg.ModelPath()) + } + return Heuristic{}, nil +} diff --git a/punctuation/punctuation_test.go b/punctuation/punctuation_test.go new file mode 100644 index 0000000..6149fca --- /dev/null +++ b/punctuation/punctuation_test.go @@ -0,0 +1,40 @@ +package punctuation + +import ( + "context" + "testing" + + "go-whisper-api/config" +) + +func TestNop(t *testing.T) { + r, err := New(config.Punctuation{Engine: "off"}) + if err != nil { + t.Fatal(err) + } + out, err := Apply(context.Background(), r, true, "hello world", "en") + if err != nil { + t.Fatal(err) + } + if out != "hello world" { + t.Fatalf("got %q", out) + } +} + +func TestApply_disabled(t *testing.T) { + r := Heuristic{} + out, err := Apply(context.Background(), r, false, "hello", "en") + if err != nil || out != "hello" { + t.Fatalf("got %q err=%v", out, err) + } +} + +func TestNew_heuristic(t *testing.T) { + r, err := New(config.Punctuation{Enabled: true, Engine: "heuristic"}) + if err != nil { + t.Fatal(err) + } + if !r.Active() { + t.Fatal("expected active") + } +} diff --git a/punctuation/sherpa.go b/punctuation/sherpa.go new file mode 100644 index 0000000..ab5871b --- /dev/null +++ b/punctuation/sherpa.go @@ -0,0 +1,107 @@ +//go:build sherpa + +package punctuation + +import ( + "context" + "fmt" + "os" + "strings" + "sync" + + "go-whisper-api/config" + + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" +) + +type Sherpa struct { + offline *sherpa.OfflinePunctuation + online *sherpa.OnlinePunctuation + mu sync.Mutex +} + +func newSherpaRestorer(cfg config.Punctuation) (Restorer, error) { + return newSherpa(cfg) +} + +func newSherpa(cfg config.Punctuation) (*Sherpa, error) { + cfg = cfg.WithDefaults() + engine := strings.ToLower(cfg.Engine) + s := &Sherpa{} + switch engine { + case "sherpa-online", "online": + modelPath := cfg.ModelPath() + vocabPath := cfg.BpeVocabPath() + if _, err := os.Stat(modelPath); err != nil { + return nil, fmt.Errorf("sherpa online punctuation model %q: %w", modelPath, err) + } + if _, err := os.Stat(vocabPath); err != nil { + return nil, fmt.Errorf("sherpa bpe vocab %q: %w", vocabPath, err) + } + conf := sherpa.OnlinePunctuationConfig{} + conf.Model.CnnBilstm = modelPath + conf.Model.BpeVocab = vocabPath + conf.Model.NumThreads = cfg.NumThreads + conf.Model.Provider = "cpu" + s.online = sherpa.NewOnlinePunctuation(&conf) + if s.online == nil { + return nil, fmt.Errorf("failed to create sherpa online punctuation") + } + default: + modelPath := cfg.ModelPath() + if _, err := os.Stat(modelPath); err != nil { + return nil, fmt.Errorf("sherpa offline punctuation model %q: %w (run: make download-punctuation-model)", modelPath, err) + } + conf := sherpa.OfflinePunctuationConfig{} + conf.Model.CtTransformer = modelPath + conf.Model.NumThreads = cfg.NumThreads + conf.Model.Provider = "cpu" + s.offline = sherpa.NewOfflinePunctuation(&conf) + if s.offline == nil { + return nil, fmt.Errorf("failed to create sherpa offline punctuation") + } + } + return s, nil +} + +func (s *Sherpa) Active() bool { + return true +} + +func (s *Sherpa) Restore(ctx context.Context, text, language string) (string, error) { + _ = ctx + _ = language + text = strings.TrimSpace(text) + if text == "" { + return text, nil + } + s.mu.Lock() + defer s.mu.Unlock() + var out string + switch { + case s.offline != nil: + out = s.offline.AddPunct(text) + case s.online != nil: + out = s.online.AddPunct(text) + default: + return text, fmt.Errorf("sherpa punctuation not initialized") + } + out = strings.TrimSpace(out) + if out == "" { + return text, nil + } + return out, nil +} + +func (s *Sherpa) Close() { + s.mu.Lock() + defer s.mu.Unlock() + if s.offline != nil { + sherpa.DeleteOfflinePunc(s.offline) + s.offline = nil + } + if s.online != nil { + sherpa.DeleteOnlinePunctuation(s.online) + s.online = nil + } +} diff --git a/punctuation/sherpa_stub.go b/punctuation/sherpa_stub.go new file mode 100644 index 0000000..cc4f44a --- /dev/null +++ b/punctuation/sherpa_stub.go @@ -0,0 +1,13 @@ +//go:build !sherpa + +package punctuation + +import ( + "fmt" + + "go-whisper-api/config" +) + +func newSherpaRestorer(cfg config.Punctuation) (Restorer, error) { + return nil, fmt.Errorf("punctuation engine %q requires build tag sherpa (use: make build)", cfg.Engine) +} diff --git a/punctuation/xlm.go b/punctuation/xlm.go new file mode 100644 index 0000000..0bb4104 --- /dev/null +++ b/punctuation/xlm.go @@ -0,0 +1,319 @@ +//go:build xlm + +package punctuation + +import ( + "context" + "fmt" + "os" + "strings" + "unicode" + + "go-whisper-api/config" + "go-whisper-api/punctuation/internal/spwrap" + + ort "github.com/yalue/onnxruntime_go" +) + +type XLM struct { + cfg config.Punctuation + modelCfg xlmModelConfig + sp *spwrap.Processor + session *ort.DynamicAdvancedSession + inputName string + outputNames []string + joinSBD bool +} + +func newXLM(cfg config.Punctuation) (*XLM, error) { + if err := ensureORT(); err != nil { + return nil, fmt.Errorf("onnxruntime: %w (set ONNXRUNTIME_SHARED_LIBRARY_PATH or install sherpa-onnx libs)", err) + } + onnxPath := cfg.ModelPath() + if _, err := os.Stat(onnxPath); err != nil { + return nil, fmt.Errorf("xlm onnx model not found at %s (run: make download-xlm-punctuation-model)", onnxPath) + } + spPath := cfg.SPModelPath() + if _, err := os.Stat(spPath); err != nil { + return nil, fmt.Errorf("xlm sp.model not found at %s (run: make download-xlm-punctuation-model)", spPath) + } + cfgPath := cfg.XLMConfigPath() + modelCfg, err := loadXLMConfig(cfgPath) + if err != nil { + return nil, err + } + sp, err := spwrap.Load(spPath) + if err != nil { + return nil, err + } + inputs, outputs, err := ort.GetInputOutputInfo(onnxPath) + if err != nil { + sp.Close() + return nil, fmt.Errorf("xlm model io: %w", err) + } + if len(inputs) == 0 || len(outputs) < 4 { + sp.Close() + return nil, fmt.Errorf("xlm model: unexpected inputs/outputs") + } + inNames := make([]string, len(inputs)) + for i, in := range inputs { + inNames[i] = in.Name + } + outNames := make([]string, len(outputs)) + for i, out := range outputs { + outNames[i] = out.Name + } + session, err := ort.NewDynamicAdvancedSession(onnxPath, inNames, outNames, nil) + if err != nil { + sp.Close() + return nil, fmt.Errorf("xlm onnx session: %w", err) + } + return &XLM{ + cfg: cfg, + modelCfg: modelCfg, + sp: sp, + session: session, + inputName: inNames[0], + outputNames: outNames, + joinSBD: cfg.XLMJoinSentences(), + }, nil +} + +func (x *XLM) Active() bool { + return true +} + +func (x *XLM) Close() { + if x.session != nil { + _ = x.session.Destroy() + x.session = nil + } + if x.sp != nil { + x.sp.Close() + x.sp = nil + } +} + +func (x *XLM) Restore(ctx context.Context, text, language string) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + text = strings.TrimSpace(normalizeXLMSpaces(text)) + if text == "" { + return text, nil + } + + ids, err := x.sp.EncodeAsIDs(text) + if err != nil { + return "", err + } + full := make([]int, 0, len(ids)+2) + full = append(full, x.sp.BOSID()) + full = append(full, ids...) + full = append(full, x.sp.EOSID()) + maxLen := x.modelCfg.MaxLength + if maxLen <= 2 { + maxLen = 256 + } + if len(full) <= maxLen { + return x.inferIDs(full) + } + var parts []string + content := full[1 : len(full)-1] + step := maxLen - 2 + for start := 0; start < len(content); { + end := start + step + if end > len(content) { + end = len(content) + } + chunk := make([]int, 0, end-start+2) + chunk = append(chunk, x.sp.BOSID()) + chunk = append(chunk, content[start:end]...) + chunk = append(chunk, x.sp.EOSID()) + out, err := x.inferIDs(chunk) + if err != nil { + return "", err + } + if out != "" { + parts = append(parts, out) + } + if end >= len(content) { + break + } + start = end + } + return strings.TrimSpace(strings.Join(parts, " ")), nil +} + +func (x *XLM) inferIDs(inputIDs []int) (string, error) { + data := make([]int64, len(inputIDs)) + for i, id := range inputIDs { + data[i] = int64(id) + } + inputTensor, err := ort.NewTensor(ort.NewShape(1, int64(len(inputIDs))), data) + if err != nil { + return "", err + } + defer inputTensor.Destroy() + outputs := make([]ort.Value, len(x.outputNames)) + if err := x.session.Run([]ort.Value{inputTensor}, outputs); err != nil { + return "", err + } + defer destroyValues(outputs) + pre, err := int64Row(outputs[0], len(inputIDs)) + if err != nil { + return "", err + } + post, err := int64Row(outputs[1], len(inputIDs)) + if err != nil { + return "", err + } + cap, err := boolMatrix(outputs[2], len(inputIDs)) + if err != nil { + return "", err + } + sbd, err := boolRow(outputs[3], len(inputIDs)) + if err != nil { + return "", err + } + return decodeXLMSegment(x.sp, x.modelCfg, inputIDs, pre, post, cap, sbd, x.joinSBD) +} + +func destroyValues(vals []ort.Value) { + for _, v := range vals { + if v != nil { + _ = v.Destroy() + } + } +} + +func int64Row(v ort.Value, wantLen int) ([]int64, error) { + switch t := v.(type) { + case *ort.Tensor[int64]: + d := t.GetData() + if len(d) == wantLen { + return d, nil + } + if len(d) > wantLen { + return d[len(d)-wantLen:], nil + } + return nil, fmt.Errorf("int64 output short: %d < %d", len(d), wantLen) + case *ort.Tensor[int32]: + d := t.GetData() + if len(d) > wantLen { + d = d[len(d)-wantLen:] + } + out := make([]int64, wantLen) + for i := 0; i < wantLen && i < len(d); i++ { + out[i] = int64(d[i]) + } + return out, nil + default: + return nil, fmt.Errorf("unexpected int output type %T", v) + } +} + +func boolRow(v ort.Value, wantLen int) ([]bool, error) { + switch t := v.(type) { + case *ort.Tensor[bool]: + d := t.GetData() + if len(d) == wantLen { + return d, nil + } + if len(d) > wantLen { + return d[len(d)-wantLen:], nil + } + return nil, fmt.Errorf("bool output short") + case *ort.Tensor[float32]: + d := t.GetData() + out := make([]bool, wantLen) + for i := 0; i < wantLen && i < len(d); i++ { + out[i] = d[i] > 0.5 + } + return out, nil + default: + return nil, fmt.Errorf("unexpected bool output type %T", v) + } +} + +func boolMatrix(v ort.Value, seqLen int) ([][]bool, error) { + switch t := v.(type) { + case *ort.Tensor[bool]: + shape := t.GetShape() + d := t.GetData() + if len(shape) == 3 { + _, sl, width := shape[0], shape[1], shape[2] + out := make([][]bool, sl) + for i := 0; i < int(sl); i++ { + row := make([]bool, width) + base := int(i) * int(width) + copy(row, d[base:base+int(width)]) + out[i] = row + } + return out, nil + } + width := len(d) / seqLen + if width < 1 { + width = 1 + } + out := make([][]bool, seqLen) + for i := 0; i < seqLen; i++ { + row := make([]bool, width) + base := i * width + if base+width <= len(d) { + copy(row, d[base:base+width]) + } + out[i] = row + } + return out, nil + case *ort.Tensor[float32]: + shape := t.GetShape() + d := t.GetData() + if len(shape) == 3 { + _, sl, width := shape[0], shape[1], shape[2] + out := make([][]bool, sl) + for i := 0; i < int(sl); i++ { + row := make([]bool, width) + base := int(i) * int(width) + for j := 0; j < int(width); j++ { + row[j] = d[base+j] > 0.5 + } + out[i] = row + } + return out, nil + } + width := len(d) / seqLen + if width < 1 { + width = 1 + } + out := make([][]bool, seqLen) + for i := 0; i < seqLen; i++ { + row := make([]bool, width) + base := i * width + for j := 0; j < width && base+j < len(d); j++ { + row[j] = d[base+j] > 0.5 + } + out[i] = row + } + return out, nil + default: + return nil, fmt.Errorf("unexpected cap output type %T", v) + } +} + +func normalizeXLMSpaces(s string) string { + var b strings.Builder + prevSpace := false + for _, r := range s { + if unicode.IsSpace(r) { + if !prevSpace { + b.WriteRune(' ') + prevSpace = true + } + continue + } + prevSpace = false + b.WriteRune(r) + } + return b.String() +} diff --git a/punctuation/xlm_config.go b/punctuation/xlm_config.go new file mode 100644 index 0000000..59e1fe1 --- /dev/null +++ b/punctuation/xlm_config.go @@ -0,0 +1,43 @@ +package punctuation + +import ( + "fmt" + "os" + "path/filepath" + + "gopkg.in/yaml.v3" +) + +type xlmModelConfig struct { + Languages []string `yaml:"languages"` + MaxLength int `yaml:"max_length"` + PreLabels []string `yaml:"pre_labels"` + PostLabels []string `yaml:"post_labels"` + NullToken string `yaml:"null_token"` + Acronym string `yaml:"acronym_token"` +} + +func loadXLMConfig(path string) (xlmModelConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return xlmModelConfig{}, err + } + var cfg xlmModelConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + return xlmModelConfig{}, fmt.Errorf("parse xlm config %s: %w", path, err) + } + if cfg.MaxLength <= 0 { + cfg.MaxLength = 256 + } + if cfg.NullToken == "" { + cfg.NullToken = "" + } + if cfg.Acronym == "" { + cfg.Acronym = "" + } + return cfg, nil +} + +func defaultXLMConfigPath(modelDir string) string { + return filepath.Join(modelDir, "config.yaml") +} diff --git a/punctuation/xlm_decode.go b/punctuation/xlm_decode.go new file mode 100644 index 0000000..d0b17ea --- /dev/null +++ b/punctuation/xlm_decode.go @@ -0,0 +1,97 @@ +//go:build xlm + +package punctuation + +import ( + "strings" + + "go-whisper-api/punctuation/internal/spwrap" +) + +func decodeXLMSegment( + sp *spwrap.Processor, + cfg xlmModelConfig, + inputIDs []int, + prePred, postPred []int64, + capPred [][]bool, + sbdPred []bool, + joinSentences bool, +) (string, error) { + var outputTexts []string + current := make([]string, 0, len(inputIDs)*4) + for tokenIdx := 1; tokenIdx < len(inputIDs)-1; tokenIdx++ { + piece, err := sp.IDToPiece(inputIDs[tokenIdx]) + if err != nil { + return "", err + } + if strings.HasPrefix(piece, "▁") && len(current) > 0 { + current = append(current, " ") + } + preLabel := labelAt(cfg.PreLabels, prePred, tokenIdx) + postLabel := labelAt(cfg.PostLabels, postPred, tokenIdx) + if preLabel != cfg.NullToken { + current = append(current, preLabel) + } + charStart := 0 + if strings.HasPrefix(piece, "▁") { + charStart = 1 + } + runes := []rune(piece) + for tokenCharIdx := charStart; tokenCharIdx < len(runes); tokenCharIdx++ { + ch := string(runes[tokenCharIdx]) + if capAt(capPred, tokenIdx, tokenCharIdx) { + ch = strings.ToUpper(ch) + } + current = append(current, ch) + if postLabel == cfg.Acronym { + current = append(current, ".") + } + } + if postLabel != cfg.NullToken && postLabel != cfg.Acronym { + current = append(current, postLabel) + } + if sbdAt(sbdPred, tokenIdx) { + outputTexts = append(outputTexts, strings.Join(current, "")) + current = current[:0] + } + } + if len(current) > 0 { + outputTexts = append(outputTexts, strings.Join(current, "")) + } + if len(outputTexts) == 0 { + return "", nil + } + if joinSentences { + return strings.Join(outputTexts, " "), nil + } + return outputTexts[0], nil +} + +func labelAt(labels []string, preds []int64, idx int) string { + if idx < 0 || idx >= len(preds) { + return labels[0] + } + pi := int(preds[idx]) + if pi < 0 || pi >= len(labels) { + return labels[0] + } + return labels[pi] +} + +func capAt(capPred [][]bool, tokenIdx, charIdx int) bool { + if tokenIdx < 0 || tokenIdx >= len(capPred) { + return false + } + row := capPred[tokenIdx] + if charIdx < 0 || charIdx >= len(row) { + return false + } + return row[charIdx] +} + +func sbdAt(sbd []bool, tokenIdx int) bool { + if tokenIdx < 0 || tokenIdx >= len(sbd) { + return false + } + return sbd[tokenIdx] +} diff --git a/punctuation/xlm_stub.go b/punctuation/xlm_stub.go new file mode 100644 index 0000000..d12a476 --- /dev/null +++ b/punctuation/xlm_stub.go @@ -0,0 +1,13 @@ +//go:build !xlm + +package punctuation + +import ( + "fmt" + + "go-whisper-api/config" +) + +func newXLM(cfg config.Punctuation) (Restorer, error) { + return nil, fmt.Errorf("punctuation engine %q requires build tag xlm (run: make build-xlm)", cfg.Engine) +} diff --git a/transcode/aac_decode.go b/transcode/aac_decode.go new file mode 100644 index 0000000..e06ce9d --- /dev/null +++ b/transcode/aac_decode.go @@ -0,0 +1,132 @@ +package transcode + +import ( + "fmt" + "io" + "math" + "os" + "path/filepath" + "strings" + + "github.com/olivier-w/climp-aac-decoder/aacfile" +) + +func decodeAACPath(path, ext string) ([]float64, int, int, error) { + f, err := os.Open(path) + if err != nil { + return nil, 0, 0, err + } + defer f.Close() + st, err := f.Stat() + if err != nil { + return nil, 0, 0, err + } + // aacfile picks the parser from the *name* extension, not file content (.mp4 → .m4a). + name := aacOpenName(path, ext) + size := st.Size() + r, err := aacfile.Open(f, size, name) + if err != nil && isMP4SampleDeltaError(err) { + return decodeMP4AACRelaxed(f, size) + } + if err != nil { + return nil, 0, 0, err + } + defer r.Close() + sr := r.SampleRate() + ch := r.ChannelCount() + pcm, err := io.ReadAll(r) + if err != nil { + return nil, 0, 0, fmt.Errorf("read aac pcm: %w", err) + } + samples := pcm16LEToFloat(pcm, ch) + if ch > 1 { + samples = interleavedToMono(samples, ch) + ch = 1 + } + return samples, sr, ch, nil +} + +// aacContainerExt maps file extensions to a container name understood by aacfile. +func aacContainerExt(ext string) string { + switch strings.ToLower(ext) { + case ".mp4", ".m4v", ".mov", ".3gp", ".3g2": + return ".m4a" + case ".aac", ".m4a", ".m4b": + return ext + default: + return ".m4a" + } +} + +func aacOpenName(path, ext string) string { + containerExt := aacContainerExt(ext) + if containerExt == "" { + containerExt = ".m4a" + } + base := filepath.Base(path) + if e := strings.ToLower(filepath.Ext(base)); e == containerExt { + return base + } + stem := strings.TrimSuffix(base, filepath.Ext(base)) + if stem == "" || stem == base { + stem = "audio" + } + return stem + containerExt +} + +func pcm16LEToFloat(pcm []byte, channels int) []float64 { + if channels <= 0 { + channels = 1 + } + frameBytes := 2 * channels + nFrames := len(pcm) / frameBytes + out := make([]float64, nFrames*channels) + for i := 0; i < nFrames*channels; i++ { + off := i * 2 + if off+1 >= len(pcm) { + break + } + v := int16(pcm[off]) | int16(pcm[off+1])<<8 + out[i] = float64(v) / 32768.0 + } + return out +} + +func interleavedToMono(samples []float64, channels int) []float64 { + if channels <= 1 { + return samples + } + nFrames := len(samples) / channels + out := make([]float64, nFrames) + for i := 0; i < nFrames; i++ { + var sum float64 + for c := 0; c < channels; c++ { + sum += samples[i*channels+c] + } + out[i] = sum / float64(channels) + } + return out +} + +func resampleLinear(samples []float64, fromRate, toRate int) []float64 { + if fromRate <= 0 || toRate <= 0 || fromRate == toRate || len(samples) == 0 { + return samples + } + ratio := float64(fromRate) / float64(toRate) + outLen := int(math.Round(float64(len(samples)) / ratio)) + if outLen < 1 { + outLen = 1 + } + out := make([]float64, outLen) + for i := 0; i < outLen; i++ { + src := float64(i) * ratio + j := int(src) + if j >= len(samples)-1 { + out[i] = samples[len(samples)-1] + continue + } + frac := src - float64(j) + out[i] = samples[j]*(1-frac) + samples[j+1]*frac + } + return out +} diff --git a/transcode/aac_decode_test.go b/transcode/aac_decode_test.go new file mode 100644 index 0000000..04f7167 --- /dev/null +++ b/transcode/aac_decode_test.go @@ -0,0 +1,31 @@ +package transcode + +import "testing" + +func TestAacOpenName_mp4(t *testing.T) { + got := aacOpenName("/tmp/cache/input.mp4", ".mp4") + if got != "input.m4a" { + t.Fatalf("got %q want input.m4a", got) + } +} + +func TestAacOpenName_noExt(t *testing.T) { + got := aacOpenName("/tmp/input", ".m4a") + if got != "audio.m4a" { + t.Fatalf("got %q", got) + } +} + +func TestAacContainerExt(t *testing.T) { + cases := map[string]string{ + ".mp4": ".m4a", + ".mov": ".m4a", + ".aac": ".aac", + ".m4a": ".m4a", + } + for in, want := range cases { + if got := aacContainerExt(in); got != want { + t.Fatalf("%s: got %q want %q", in, got, want) + } + } +} diff --git a/transcode/decode.go b/transcode/decode.go new file mode 100644 index 0000000..c6fc875 --- /dev/null +++ b/transcode/decode.go @@ -0,0 +1,127 @@ +package transcode + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/gopxl/beep" + "github.com/gopxl/beep/flac" + "github.com/gopxl/beep/mp3" + beepwav "github.com/gopxl/beep/wav" +) + +var probeFormats = []string{ + ".wav", ".wave", ".mp3", ".flac", ".ogg", ".opus", + ".m4a", ".m4b", ".mp4", ".mov", ".m4v", ".3gp", ".3g2", ".aac", +} + +func supportedFormatsMessage() string { + return strings.Join(probeFormats, ", ") +} + +func resolveFormat(path string) (string, error) { + ext := strings.ToLower(filepath.Ext(path)) + if ext != "" { + return ext, nil + } + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + ext = sniffFormat(f) + if ext == "" { + return "", fmt.Errorf("could not detect audio format (supported: %s)", supportedFormatsMessage()) + } + return ext, nil +} + +func openDecoder(path string) (beep.Streamer, beep.Format, io.Closer, error) { + ext, err := resolveFormat(path) + if err != nil { + return nil, beep.Format{}, nil, err + } + streamer, format, closer, err := decodeByExt(path, ext) + if err == nil { + return streamer, format, closer, nil + } + for _, try := range probeFormats { + if try == ext { + continue + } + streamer, format, closer, tryErr := decodeByExt(path, try) + if tryErr == nil { + return streamer, format, closer, nil + } + } + return nil, beep.Format{}, nil, fmt.Errorf("unsupported audio format %q (supported: %s): %w", ext, supportedFormatsMessage(), err) +} + +func decodeByExt(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) { + switch ext { + case ".wav", ".wave": + return decodeBeepFile(path, ext) + case ".mp3": + return decodeBeepFile(path, ext) + case ".flac": + return decodeBeepFile(path, ext) + case ".ogg", ".opus": + return decodeOggFile(path) + case ".m4a", ".m4b", ".mp4", ".mov", ".m4v", ".aac": + return decodeAACAsStreamer(path, ext) + case ".webm": + return nil, beep.Format{}, nil, fmt.Errorf("webm is not supported yet") + default: + return nil, beep.Format{}, nil, fmt.Errorf("unsupported extension %q", ext) + } +} + +func decodeBeepFile(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) { + f, err := os.Open(path) + if err != nil { + return nil, beep.Format{}, nil, err + } + var ( + streamer beep.StreamSeekCloser + format beep.Format + decErr error + ) + switch ext { + case ".wav", ".wave": + streamer, format, decErr = beepwav.Decode(f) + case ".mp3": + streamer, format, decErr = mp3.Decode(f) + case ".flac": + streamer, format, decErr = flac.Decode(f) + default: + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("internal: beep decode for %q", ext) + } + if decErr != nil { + f.Close() + return nil, beep.Format{}, nil, decErr + } + return streamer, format, f, nil +} + +func decodeAACAsStreamer(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) { + samples, sr, ch, err := decodeAACPath(path, ext) + if err != nil { + return nil, beep.Format{}, nil, err + } + if ch <= 0 { + ch = 1 + } + return newSamplesStreamer(samples, sr), beep.Format{ + SampleRate: beep.SampleRate(sr), + NumChannels: ch, + Precision: 2, + }, noopCloser{}, nil +} + +type noopCloser struct{} + +func (noopCloser) Close() error { return nil } diff --git a/transcode/engine.go b/transcode/engine.go new file mode 100644 index 0000000..df9adad --- /dev/null +++ b/transcode/engine.go @@ -0,0 +1,119 @@ +package transcode + +import ( + "context" + "fmt" + "os" +) + +// Engine converts input audio to PCM WAV using pure Go decoders (no ffmpeg). +type Engine struct{} + +// NewEngine creates a transcoder. The ffmpegPath argument is ignored (kept for config compatibility). +func NewEngine(_ string) *Engine { + return &Engine{} +} + +func (e *Engine) Available() error { + return nil +} + +func (e *Engine) Transcode(ctx context.Context, src, dst string, opt Options) error { + if err := opt.Validate(); err != nil { + return err + } + spec, err := ResolveFormat(opt.Format) + if err != nil { + return err + } + dst, err = OutputPath(dst, spec.ID) + if err != nil { + return err + } + streamer, format, closer, err := openDecoder(src) + if err != nil { + return err + } + defer closer.Close() + s, format := buildPipeline(streamer, format, opt) + samples, err := drainSamples(ctx, s) + if err != nil { + return err + } + ch := format.NumChannels + if opt.Channels > 0 { + ch = opt.Channels + } + sr := int(format.SampleRate) + if opt.SampleRate > 0 { + sr = opt.SampleRate + } + if err := writePCM16WAV(dst, sr, ch, samples); err != nil { + return fmt.Errorf("write wav: %w", err) + } + return nil +} + +func Transcode(ctx context.Context, src, dst string, opt Options) error { + return NewEngine("").Transcode(ctx, src, dst, opt) +} + +func ToWhisperWAV(ctx context.Context, src, dst string) error { + return Transcode(ctx, src, dst, WhisperOptions()) +} + +// SupportedInputFormats lists file extensions decoded without external tools. +func SupportedInputFormats() []string { + return append([]string(nil), probeFormats...) +} + +func (e *Engine) Probe(ctx context.Context, path string) (MediaInfo, error) { + _ = ctx + ext, err := resolveFormat(path) + if err != nil { + return MediaInfo{}, err + } + streamer, format, closer, err := openDecoder(path) + if err != nil { + return MediaInfo{}, err + } + defer closer.Close() + info := MediaInfo{ + Format: ext, + Streams: []StreamInfo{{ + Index: 0, + Codec: extTrim(ext), + Type: "audio", + SampleRate: int(format.SampleRate), + Channels: format.NumChannels, + }}, + } + if st, err := os.Stat(path); err == nil { + info.BitRate = st.Size() * 8 + } + _ = streamer + return info, nil +} + +func extTrim(ext string) string { + if len(ext) > 0 && ext[0] == '.' { + return ext[1:] + } + return ext +} + +// MediaInfo describes decoded input (for optional diagnostics). +type MediaInfo struct { + Format string `json:"format"` + Duration float64 `json:"duration_seconds"` + BitRate int64 `json:"bit_rate"` + Streams []StreamInfo `json:"streams"` +} + +type StreamInfo struct { + Index int `json:"index"` + Codec string `json:"codec"` + Type string `json:"type"` + SampleRate int `json:"sample_rate,omitempty"` + Channels int `json:"channels,omitempty"` +} diff --git a/transcode/engine_test.go b/transcode/engine_test.go new file mode 100644 index 0000000..641056c --- /dev/null +++ b/transcode/engine_test.go @@ -0,0 +1,85 @@ +package transcode + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/go-audio/wav" +) + +func samplePath(t *testing.T, name string) string { + t.Helper() + p := filepath.Join("..", "third_party", "whisper.cpp", "samples", name) + if _, err := os.Stat(p); err != nil { + t.Skip("sample not found:", p) + } + return p +} + +func TestToWhisperWAV_mp3(t *testing.T) { + src := samplePath(t, "jfk.mp3") + dst := filepath.Join(t.TempDir(), "out.wav") + if err := ToWhisperWAV(context.Background(), src, dst); err != nil { + t.Fatal(err) + } + assertWhisperWAV(t, dst) +} + +func TestToWhisperWAV_wav(t *testing.T) { + src := samplePath(t, "jfk.wav") + dst := filepath.Join(t.TempDir(), "out.wav") + if err := ToWhisperWAV(context.Background(), src, dst); err != nil { + t.Fatal(err) + } + assertWhisperWAV(t, dst) +} + +func TestResolveFormat_noExtension_mp3(t *testing.T) { + src := samplePath(t, "jfk.mp3") + data, err := os.ReadFile(src) + if err != nil { + t.Fatal(err) + } + path := filepath.Join(t.TempDir(), "upload") + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatal(err) + } + ext, err := resolveFormat(path) + if err != nil || ext != ".mp3" { + t.Fatalf("ext=%q err=%v", ext, err) + } +} + +func TestEngine_Available(t *testing.T) { + if err := NewEngine("").Available(); err != nil { + t.Fatal(err) + } +} + +func assertWhisperWAV(t *testing.T, path string) { + t.Helper() + f, err := os.Open(path) + if err != nil { + t.Fatal(err) + } + defer f.Close() + dec := wav.NewDecoder(f) + if !dec.IsValidFile() { + t.Fatal("invalid wav") + } + buf, err := dec.FullPCMBuffer() + if err != nil { + t.Fatal(err) + } + if dec.SampleRate != 16000 { + t.Fatalf("sample rate %d", dec.SampleRate) + } + if dec.NumChans != 1 { + t.Fatalf("channels %d", dec.NumChans) + } + if len(buf.Data) == 0 { + t.Fatal("empty audio") + } +} diff --git a/transcode/format.go b/transcode/format.go new file mode 100644 index 0000000..fdc34bb --- /dev/null +++ b/transcode/format.go @@ -0,0 +1,47 @@ +package transcode + +import ( + "fmt" + "path/filepath" + "strings" +) + +const ( + FormatWAV = "wav" + FormatPCM = "pcm" +) + +type FormatSpec struct { + ID string + Extension string + Codec string +} + +var formats = map[string]FormatSpec{ + FormatWAV: {ID: FormatWAV, Extension: ".wav", Codec: "pcm_s16le"}, + FormatPCM: {ID: FormatPCM, Extension: ".wav", Codec: "pcm_s16le"}, +} + +func ResolveFormat(name string) (FormatSpec, error) { + name = strings.ToLower(strings.TrimSpace(name)) + if name == "" { + name = FormatWAV + } + spec, ok := formats[name] + if !ok { + return FormatSpec{}, fmt.Errorf("unsupported format %q (supported: wav)", name) + } + return spec, nil +} + +func OutputPath(dst, format string) (string, error) { + spec, err := ResolveFormat(format) + if err != nil { + return "", err + } + ext := filepath.Ext(dst) + if ext == "" { + return dst + spec.Extension, nil + } + return dst, nil +} diff --git a/transcode/mp4_aac_decode.go b/transcode/mp4_aac_decode.go new file mode 100644 index 0000000..ea39dfb --- /dev/null +++ b/transcode/mp4_aac_decode.go @@ -0,0 +1,288 @@ +package transcode + +import ( + "errors" + "fmt" + "io" + + "github.com/Eyevinn/mp4ff/mp4" + "github.com/olivier-w/climp-aac-decoder/aacfile" + aacdec "github.com/skrashevich/go-aac/pkg/decoder" +) + +type mp4AACSample struct { + offset int64 + size int +} + +func isMP4SampleDeltaError(err error) bool { + var uf *aacfile.UnsupportedFeatureError + if !errors.As(err, &uf) { + return false + } + return uf.Feature == "MP4 sample delta" || uf.Feature == "MP4 sample delta layout" +} + +// decodeMP4AACRelaxed demuxes MP4/M4A with mp4ff (ignoring stts sample deltas) and +// decodes raw AAC frames with go-aac. Used when climp-aac-decoder rejects stts +// entries whose delta is not exactly 1024 (common in ffmpeg/phone muxers). +func decodeMP4AACRelaxed(r io.ReaderAt, size int64) ([]float64, int, int, error) { + asc, samples, leading, err := demuxMP4AAC(r, size) + if err != nil { + return nil, 0, 0, err + } + dec := aacdec.New() + if err := dec.SetASC(asc); err != nil { + return nil, 0, 0, fmt.Errorf("aac config: %w", err) + } + ch := dec.Config.ChanConfig + if ch < 1 { + return nil, 0, 0, fmt.Errorf("aac config: invalid channel count %d", ch) + } + sr := dec.Config.SampleRate + if sr <= 0 { + return nil, 0, 0, fmt.Errorf("aac config: invalid sample rate %d", sr) + } + + maxSize := 0 + for _, s := range samples { + if s.size > maxSize { + maxSize = s.size + } + } + buf := make([]byte, maxSize) + var pcm []float32 + for i, s := range samples { + if cap(buf) < s.size { + buf = make([]byte, s.size) + } + frame := buf[:s.size] + if _, err := r.ReadAt(frame, s.offset); err != nil { + return nil, 0, 0, fmt.Errorf("read mp4 aac sample %d: %w", i, err) + } + out, err := dec.DecodeFrame(frame) + if err != nil { + return nil, 0, 0, fmt.Errorf("decode mp4 aac sample %d: %w", i, err) + } + pcm = append(pcm, out...) + } + + skipSamples := leading * ch + if skipSamples > len(pcm) { + skipSamples = len(pcm) + } + pcm = pcm[skipSamples:] + + samplesF64 := make([]float64, len(pcm)) + for i, v := range pcm { + samplesF64[i] = float64(v) + } + if ch > 1 { + samplesF64 = float32InterleavedToMono(samplesF64, ch) + ch = 1 + } + return samplesF64, sr, ch, nil +} + +func float32InterleavedToMono(samples []float64, channels int) []float64 { + if channels <= 1 { + return samples + } + nFrames := len(samples) / channels + out := make([]float64, nFrames) + for i := 0; i < nFrames; i++ { + var sum float64 + for c := 0; c < channels; c++ { + sum += samples[i*channels+c] + } + out[i] = sum / float64(channels) + } + return out +} + +func demuxMP4AAC(r io.ReaderAt, size int64) (asc []byte, samples []mp4AACSample, leading int, err error) { + file, err := mp4.DecodeFile(io.NewSectionReader(r, 0, size), mp4.WithDecodeMode(mp4.DecModeLazyMdat)) + if err != nil { + return nil, nil, 0, fmt.Errorf("mp4 decode: %w", err) + } + if file.IsFragmented() { + return nil, nil, 0, fmt.Errorf("fragmented mp4 is not supported") + } + if file.Moov == nil { + return nil, nil, 0, fmt.Errorf("mp4: missing moov") + } + + var audioTracks []*mp4.TrakBox + for _, trak := range file.Moov.Traks { + if trak != nil && trak.Mdia != nil && trak.Mdia.Hdlr != nil && trak.Mdia.Hdlr.HandlerType == "soun" { + audioTracks = append(audioTracks, trak) + } + } + if len(audioTracks) != 1 { + return nil, nil, 0, fmt.Errorf("mp4: expected one audio track, found %d", len(audioTracks)) + } + trak := audioTracks[0] + if trak.Mdia == nil || trak.Mdia.Minf == nil || trak.Mdia.Minf.Stbl == nil || trak.Mdia.Minf.Stbl.Stsd == nil { + return nil, nil, 0, fmt.Errorf("mp4: incomplete audio track") + } + stsd := trak.Mdia.Minf.Stbl.Stsd + if len(stsd.Children) != 1 { + return nil, nil, 0, fmt.Errorf("mp4: multiple sample descriptions") + } + if stsd.Enca != nil { + return nil, nil, 0, fmt.Errorf("mp4: encrypted audio") + } + sampleEntry := stsd.Mp4a + if sampleEntry == nil { + return nil, nil, 0, fmt.Errorf("mp4: unsupported audio sample entry %s", stsd.Children[0].Type()) + } + if sampleEntry.Sinf != nil { + return nil, nil, 0, fmt.Errorf("mp4: encrypted audio") + } + if sampleEntry.Esds == nil || + sampleEntry.Esds.DecConfigDescriptor == nil || + sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo == nil || + len(sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo.DecConfig) == 0 { + return nil, nil, 0, fmt.Errorf("mp4: missing AudioSpecificConfig") + } + asc = append([]byte(nil), sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo.DecConfig...) + + leading, _ = mp4LeadingTrimRelaxed(trak) + + samples, err = buildMP4AACSamples(trak, size) + if err != nil { + return nil, nil, 0, err + } + if len(samples) == 0 { + return nil, nil, 0, fmt.Errorf("mp4: no audio samples") + } + return asc, samples, leading, nil +} + +func mp4LeadingTrimRelaxed(trak *mp4.TrakBox) (int, error) { + if trak.Edts == nil || len(trak.Edts.Elst) == 0 { + return 0, nil + } + if len(trak.Edts.Elst) != 1 || len(trak.Edts.Elst[0].Entries) != 1 { + return 0, nil + } + entry := trak.Edts.Elst[0].Entries[0] + if entry.MediaRateInteger != 1 || entry.MediaRateFraction != 0 { + return 0, nil + } + if entry.MediaTime < 0 { + return 0, nil + } + return int(entry.MediaTime), nil +} + +func buildMP4AACSamples(trak *mp4.TrakBox, size int64) ([]mp4AACSample, error) { + if trak.Mdia == nil || trak.Mdia.Minf == nil || trak.Mdia.Minf.Stbl == nil { + return nil, fmt.Errorf("mp4: incomplete sample table") + } + stbl := trak.Mdia.Minf.Stbl + if stbl.Stsc == nil || stbl.Stsz == nil { + return nil, fmt.Errorf("mp4: incomplete sample table") + } + if stbl.Stco == nil && stbl.Co64 == nil { + return nil, fmt.Errorf("mp4: missing chunk offsets") + } + if len(stbl.Stsc.Entries) == 0 { + return nil, fmt.Errorf("mp4: empty chunk map") + } + + totalSamples := int(trak.GetNrSamples()) + if totalSamples <= 0 { + return nil, fmt.Errorf("mp4: empty sample table") + } + + sampleSizes, err := mp4AACSampleSizes(stbl.Stsz, totalSamples) + if err != nil { + return nil, err + } + chunkOffsets, err := mp4AACChunkOffsets(stbl) + if err != nil { + return nil, err + } + + out := make([]mp4AACSample, 0, totalSamples) + sampleIndex := 0 + entryIndex := 0 + entry := stbl.Stsc.Entries[entryIndex] + + for chunkIndex := 0; chunkIndex < len(chunkOffsets) && sampleIndex < totalSamples; chunkIndex++ { + chunkNr := uint32(chunkIndex + 1) + for entryIndex+1 < len(stbl.Stsc.Entries) && chunkNr >= stbl.Stsc.Entries[entryIndex+1].FirstChunk { + entryIndex++ + entry = stbl.Stsc.Entries[entryIndex] + } + if entry.SamplesPerChunk == 0 { + return nil, fmt.Errorf("mp4: zero samples per chunk") + } + + offset := chunkOffsets[chunkIndex] + samplesPerChunk := int(entry.SamplesPerChunk) + for i := 0; i < samplesPerChunk && sampleIndex < totalSamples; i++ { + sampleSize := sampleSizes[sampleIndex] + end := offset + int64(sampleSize) + if offset < 0 || end < offset || end > size { + return nil, fmt.Errorf("mp4: invalid sample bounds at sample %d", sampleIndex+1) + } + out = append(out, mp4AACSample{offset: offset, size: sampleSize}) + offset = end + sampleIndex++ + } + } + if sampleIndex != totalSamples { + return nil, fmt.Errorf("mp4: sample table mismatch") + } + return out, nil +} + +func mp4AACSampleSizes(stsz *mp4.StszBox, totalSamples int) ([]int, error) { + if stsz == nil { + return nil, fmt.Errorf("mp4: missing sample sizes") + } + if int(stsz.GetNrSamples()) != totalSamples { + return nil, fmt.Errorf("mp4: sample size count mismatch") + } + sizes := make([]int, totalSamples) + if stsz.SampleUniformSize != 0 { + sz := int(stsz.SampleUniformSize) + for i := range sizes { + sizes[i] = sz + } + return sizes, nil + } + if len(stsz.SampleSize) != totalSamples { + return nil, fmt.Errorf("mp4: sample size table mismatch") + } + for i, sz := range stsz.SampleSize { + sizes[i] = int(sz) + } + return sizes, nil +} + +func mp4AACChunkOffsets(stbl *mp4.StblBox) ([]int64, error) { + switch { + case stbl == nil: + return nil, fmt.Errorf("mp4: incomplete sample table") + case stbl.Stco != nil: + offsets := make([]int64, len(stbl.Stco.ChunkOffset)) + for i, off := range stbl.Stco.ChunkOffset { + offsets[i] = int64(off) + } + return offsets, nil + case stbl.Co64 != nil: + offsets := make([]int64, len(stbl.Co64.ChunkOffset)) + for i, off := range stbl.Co64.ChunkOffset { + if off > uint64(^uint64(0)>>1) { + return nil, fmt.Errorf("mp4: invalid chunk offset") + } + offsets[i] = int64(off) + } + return offsets, nil + default: + return nil, fmt.Errorf("mp4: missing chunk offsets") + } +} diff --git a/transcode/ogg_decode.go b/transcode/ogg_decode.go new file mode 100644 index 0000000..a60cbe1 --- /dev/null +++ b/transcode/ogg_decode.go @@ -0,0 +1,154 @@ +package transcode + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + "strings" + + "github.com/gopxl/beep" + "github.com/gopxl/beep/vorbis" + "github.com/pion/opus" + "github.com/pion/opus/pkg/oggreader" +) + +func decodeOggFile(path string) (beep.Streamer, beep.Format, io.Closer, error) { + switch sniffOggCodec(path) { + case "opus": + return decodeOggOpus(path) + case "vorbis": + return decodeOggVorbis(path) + default: + streamer, format, closer, err := decodeOggVorbis(path) + if err == nil { + return streamer, format, closer, nil + } + if isVorbisInvalidHeader(err) { + return decodeOggOpus(path) + } + return nil, beep.Format{}, nil, err + } +} + +func sniffOggCodec(path string) string { + f, err := os.Open(path) + if err != nil { + return "" + } + defer f.Close() + buf := make([]byte, 8192) + n, _ := io.ReadFull(f, buf) + buf = buf[:n] + if len(buf) < 4 || !bytes.HasPrefix(buf, []byte("OggS")) { + return "" + } + if bytes.Contains(buf, []byte("OpusHead")) { + return "opus" + } + // Vorbis ID packet: 0x01 + "vorbis" + if bytes.Contains(buf, []byte{0x01, 'v', 'o', 'r', 'b', 'i', 's'}) { + return "vorbis" + } + return "" +} + +func isVorbisInvalidHeader(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "invalid header") || strings.Contains(msg, "vorbis:") +} + +func decodeOggVorbis(path string) (beep.Streamer, beep.Format, io.Closer, error) { + f, err := os.Open(path) + if err != nil { + return nil, beep.Format{}, nil, err + } + streamer, format, decErr := vorbis.Decode(f) + if decErr != nil { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/vorbis: %w", decErr) + } + return streamer, format, f, nil +} + +func decodeOggOpus(path string) (beep.Streamer, beep.Format, io.Closer, error) { + f, err := os.Open(path) + if err != nil { + return nil, beep.Format{}, nil, err + } + ogg, header, err := oggreader.NewWith(f) + if err != nil { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus: %w", err) + } + + sr := int(header.SampleRate) + if sr <= 0 { + sr = 48000 + } + ch := int(header.Channels) + if ch <= 0 { + ch = 1 + } + + dec, err := opus.NewDecoderWithOutput(sr, ch) + if err != nil { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus decoder: %w", err) + } + + const maxFrameSamples = 5760 + pcmBuf := make([]float32, maxFrameSamples*ch) + var samples []float64 + + for { + pkt, _, err := ogg.ParseNextPacket() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus read: %w", err) + } + if len(pkt) == 0 || bytes.HasPrefix(pkt, []byte("OpusHead")) || bytes.HasPrefix(pkt, []byte("OpusTags")) { + continue + } + + n, err := dec.DecodeToFloat32(pkt, pcmBuf) + if err != nil { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus decode: %w", err) + } + if n <= 0 { + continue + } + total := n * ch + if total > len(pcmBuf) { + total = len(pcmBuf) + } + for i := 0; i < total; i++ { + samples = append(samples, float64(pcmBuf[i])) + } + } + + if len(samples) == 0 { + f.Close() + return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus: no audio samples") + } + + outCh := ch + if outCh > 1 { + samples = interleavedToMono(samples, outCh) + outCh = 1 + } + + return newSamplesStreamer(samples, sr), beep.Format{ + SampleRate: beep.SampleRate(sr), + NumChannels: outCh, + Precision: 2, + }, f, nil +} diff --git a/transcode/ogg_decode_test.go b/transcode/ogg_decode_test.go new file mode 100644 index 0000000..a0c2b1b --- /dev/null +++ b/transcode/ogg_decode_test.go @@ -0,0 +1,52 @@ +package transcode + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestSniffOggCodec_opus(t *testing.T) { + path := pionTinyOggPath(t) + if got := sniffOggCodec(path); got != "opus" { + t.Fatalf("sniffOggCodec() = %q want opus", got) + } +} + +func TestDecodeOggOpus_pionTiny(t *testing.T) { + path := pionTinyOggPath(t) + streamer, format, closer, err := decodeOggOpus(path) + if err != nil { + t.Fatal(err) + } + defer closer.Close() + if format.SampleRate == 0 { + t.Fatal("zero sample rate") + } + buf := make([][2]float64, 4096) + n, ok := streamer.Stream(buf) + if !ok || n == 0 { + t.Fatal("expected pcm samples") + } +} + +func pionTinyOggPath(t *testing.T) string { + t.Helper() + _, file, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("runtime.Caller failed") + } + modRoot := filepath.Clean(filepath.Join(filepath.Dir(file), "..")) + cache := os.Getenv("GOMODCACHE") + if cache == "" { + home, _ := os.UserHomeDir() + cache = filepath.Join(home, "go", "pkg", "mod") + } + path := filepath.Join(cache, "github.com/pion/opus@v0.0.0-20260601214817-71d58474cec8/testdata/tiny.ogg") + if _, err := os.Stat(path); err != nil { + _ = modRoot + t.Skipf("pion testdata not in module cache: %v", err) + } + return path +} diff --git a/transcode/options.go b/transcode/options.go new file mode 100644 index 0000000..a1025ce --- /dev/null +++ b/transcode/options.go @@ -0,0 +1,43 @@ +package transcode + +import "fmt" + +type Options struct { + Format string + SampleRate int + Channels int + Codec string +} + +func WhisperOptions() Options { + return Options{ + Format: FormatWAV, + SampleRate: 16000, + Channels: 1, + Codec: "pcm_s16le", + } +} + +func (o *Options) ApplyDefaults() error { + spec, err := ResolveFormat(o.Format) + if err != nil { + return err + } + if o.Codec == "" { + o.Codec = spec.Codec + } + return nil +} + +func (o *Options) Validate() error { + if err := o.ApplyDefaults(); err != nil { + return err + } + if o.SampleRate < 0 || o.Channels < 0 { + return fmt.Errorf("sample_rate and channels must be >= 0") + } + if o.Channels > 8 { + return fmt.Errorf("channels must be <= 8") + } + return nil +} diff --git a/transcode/options_test.go b/transcode/options_test.go new file mode 100644 index 0000000..637c962 --- /dev/null +++ b/transcode/options_test.go @@ -0,0 +1,20 @@ +package transcode + +import "testing" + +func TestWhisperOptions(t *testing.T) { + o := WhisperOptions() + if err := o.Validate(); err != nil { + t.Fatal(err) + } + if o.SampleRate != 16000 || o.Channels != 1 { + t.Fatalf("unexpected whisper opts: %+v", o) + } +} + +func TestResolveFormat_unknown(t *testing.T) { + _, err := ResolveFormat("xyz") + if err == nil { + t.Fatal("expected error") + } +} diff --git a/transcode/samples_stream.go b/transcode/samples_stream.go new file mode 100644 index 0000000..9a7a2c8 --- /dev/null +++ b/transcode/samples_stream.go @@ -0,0 +1,38 @@ +package transcode + +import "github.com/gopxl/beep" + +type samplesStreamer struct { + samples []float64 + pos int + sampleRate beep.SampleRate +} + +func newSamplesStreamer(samples []float64, sampleRate int) *samplesStreamer { + return &samplesStreamer{ + samples: samples, + sampleRate: beep.SampleRate(sampleRate), + } +} + +func (s *samplesStreamer) Stream(buf [][2]float64) (int, bool) { + if s.pos >= len(s.samples) { + return 0, false + } + n := 0 + for i := range buf { + if s.pos >= len(s.samples) { + return n, n > 0 + } + v := s.samples[s.pos] + buf[i][0] = v + buf[i][1] = v + s.pos++ + n++ + } + return n, true +} + +func (s *samplesStreamer) Err() error { + return nil +} diff --git a/transcode/sniff.go b/transcode/sniff.go new file mode 100644 index 0000000..d16e2c4 --- /dev/null +++ b/transcode/sniff.go @@ -0,0 +1,44 @@ +package transcode + +import ( + "bytes" + "io" +) + +// sniffFormat detects container/codec from file header when the path has no extension. +func sniffFormat(r io.Reader) string { + head := make([]byte, 32) + n, _ := io.ReadFull(r, head) + head = head[:n] + if len(head) < 4 { + return "" + } + if bytes.HasPrefix(head, []byte("RIFF")) && len(head) >= 12 && bytes.Equal(head[8:12], []byte("WAVE")) { + return ".wav" + } + if bytes.HasPrefix(head, []byte("ID3")) { + return ".mp3" + } + if len(head) >= 2 && head[0] == 0xFF && (head[1]&0xE0) == 0xE0 { + return ".mp3" + } + if bytes.HasPrefix(head, []byte("fLaC")) { + return ".flac" + } + if bytes.HasPrefix(head, []byte("OggS")) { + return ".ogg" + } + if len(head) >= 8 && bytes.Equal(head[4:8], []byte("ftyp")) { + return ".m4a" + } + if bytes.HasPrefix(head, []byte{0xFF, 0xF1}) || bytes.HasPrefix(head, []byte{0xFF, 0xF9}) { + return ".aac" + } + if bytes.HasPrefix(head, []byte("FORM")) && len(head) >= 12 && bytes.Equal(head[8:12], []byte("AIFF")) { + return ".aiff" + } + if bytes.HasPrefix(head, []byte{0x1A, 0x45, 0xDF, 0xA3}) { + return ".webm" + } + return "" +} diff --git a/transcode/stream.go b/transcode/stream.go new file mode 100644 index 0000000..870e95e --- /dev/null +++ b/transcode/stream.go @@ -0,0 +1,42 @@ +package transcode + +import ( + "context" + + "github.com/gopxl/beep" + "github.com/gopxl/beep/effects" +) + +func buildPipeline(streamer beep.Streamer, format beep.Format, opt Options) (beep.Streamer, beep.Format) { + out := streamer + if opt.SampleRate > 0 && format.SampleRate != beep.SampleRate(opt.SampleRate) { + out = beep.Resample(4, format.SampleRate, beep.SampleRate(opt.SampleRate), out) + format.SampleRate = beep.SampleRate(opt.SampleRate) + } + if opt.Channels == 1 { + out = effects.Mono(out) + format.NumChannels = 1 + } + return out, format +} + +func drainSamples(ctx context.Context, s beep.Streamer) ([]float64, error) { + buf := make([][2]float64, 4096) + var samples []float64 + for { + if err := ctx.Err(); err != nil { + return nil, err + } + n, ok := s.Stream(buf) + if !ok { + if err := s.Err(); err != nil { + return nil, err + } + break + } + for i := 0; i < n; i++ { + samples = append(samples, buf[i][0]) + } + } + return samples, nil +} diff --git a/transcode/wav_out.go b/transcode/wav_out.go new file mode 100644 index 0000000..ffe0021 --- /dev/null +++ b/transcode/wav_out.go @@ -0,0 +1,53 @@ +package transcode + +import ( + "os" + + "github.com/go-audio/audio" + "github.com/go-audio/wav" +) + +func writePCM16WAV(path string, sampleRate int, channels int, samples []float64) error { + if channels <= 0 { + channels = 1 + } + if err := os.MkdirAll(dirOf(path), 0o755); err != nil && dirOf(path) != "." { + return err + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + enc := wav.NewEncoder(f, sampleRate, 16, channels, 1) + data := make([]int, len(samples)) + for i, s := range samples { + data[i] = floatToInt16(s) + } + if err := enc.Write(&audio.IntBuffer{ + Format: &audio.Format{SampleRate: sampleRate, NumChannels: channels}, + Data: data, + }); err != nil { + return err + } + return enc.Close() +} + +func floatToInt16(f float64) int { + if f > 1 { + f = 1 + } + if f < -1 { + f = -1 + } + return int(f * 32767) +} + +func dirOf(path string) string { + for i := len(path) - 1; i >= 0; i-- { + if path[i] == '/' || path[i] == '\\' { + return path[:i] + } + } + return "." +} diff --git a/whisper/audio.go b/whisper/audio.go new file mode 100644 index 0000000..1cc9e50 --- /dev/null +++ b/whisper/audio.go @@ -0,0 +1,11 @@ +package whisper + +import ( + "context" + + "go-whisper-api/transcode" +) + +func AudioToWav(src, dst string) error { + return transcode.ToWhisperWAV(context.Background(), src, dst) +} diff --git a/whisper/audio_load.go b/whisper/audio_load.go new file mode 100644 index 0000000..be2aa4f --- /dev/null +++ b/whisper/audio_load.go @@ -0,0 +1,60 @@ +package whisper + +import ( + "fmt" + "os" + "path/filepath" + + "go-whisper-api/config" + + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/go-audio/wav" +) + +// LoadPCM16Mono reads 16 kHz mono WAV into float32 samples (for diarization). +func LoadPCM16Mono(path string) ([]float32, error) { + return loadPCM16Mono(path) +} + +func loadPCM16Mono(path string) ([]float32, error) { + fh, err := os.Open(path) + if err != nil { + return nil, err + } + defer fh.Close() + dec := wav.NewDecoder(fh) + buf, err := dec.FullPCMBuffer() + if err != nil { + return nil, err + } + if dec.SampleRate != whisper.SampleRate { + return nil, fmt.Errorf("unsupported sample rate: %d", dec.SampleRate) + } + if dec.NumChans != 1 { + return nil, fmt.Errorf("unsupported number of channels: %d", dec.NumChans) + } + return buf.AsFloat32Buffer().Data, nil +} + +func prepareAudioPCM(sourcePath string) (data []float32, cleanup func(), err error) { + cleanup = func() {} + if data, err = loadPCM16Mono(sourcePath); err == nil { + return data, cleanup, nil + } + dir, err := config.MkdirTemp("go-whisper-api-whisper-*") + if err != nil { + return nil, nil, err + } + cleanup = func() { os.RemoveAll(dir) } + converted := filepath.Join(dir, "converted.wav") + if err := AudioToWav(sourcePath, converted); err != nil { + cleanup() + return nil, nil, err + } + data, err = loadPCM16Mono(converted) + if err != nil { + cleanup() + return nil, nil, err + } + return data, cleanup, nil +} diff --git a/whisper/format.go b/whisper/format.go new file mode 100644 index 0000000..4dfd7f7 --- /dev/null +++ b/whisper/format.go @@ -0,0 +1,154 @@ +package whisper + +import ( + "fmt" + "strings" + "time" + + wpkg "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +// Turn is a speaker-active time range from diarization (seconds). +type Turn struct { + Start float32 + End float32 + Speaker int +} + +// FormatOptions controls joining Whisper segments into one string with newlines. +type FormatOptions struct { + PauseGap time.Duration + SpeakerLabel string + UseSpeakers bool +} + +// FormatSegments joins segment texts with \n on long pauses and optional speaker labels. +func FormatSegments(segments []wpkg.Segment, turns []Turn, opts FormatOptions) string { + if len(segments) == 0 { + return "" + } + if opts.PauseGap <= 0 { + opts.PauseGap = 1500 * time.Millisecond + } + label := strings.TrimSpace(opts.SpeakerLabel) + if label == "" { + label = "Спикер" + } + + lines := make([]segmentLine, len(segments)) + for i, seg := range segments { + lines[i] = segmentLine{ + Text: strings.TrimSpace(seg.Text), + Start: seg.Start, + End: seg.End, + Speaker: -1, + } + } + if opts.UseSpeakers && len(turns) > 0 { + assignSpeakers(lines, turns) + } + + var b strings.Builder + prevSpeaker := -2 + prevIdx := -1 + for i, line := range lines { + if line.Text == "" { + continue + } + if b.Len() > 0 && prevIdx >= 0 { + speakerBreak := opts.UseSpeakers && line.Speaker >= 0 && line.Speaker != prevSpeaker + pauseBreak := line.Start-lines[prevIdx].End >= opts.PauseGap + switch { + case speakerBreak: + b.WriteString("\n\n") + fmt.Fprintf(&b, "%s %d: ", label, line.Speaker+1) + case pauseBreak: + b.WriteString("\n") + default: + if !strings.HasSuffix(b.String(), " ") && !strings.HasSuffix(b.String(), "\n") { + b.WriteByte(' ') + } + } + } else if opts.UseSpeakers && line.Speaker >= 0 { + fmt.Fprintf(&b, "%s %d: ", label, line.Speaker+1) + } + b.WriteString(line.Text) + if line.Speaker >= 0 { + prevSpeaker = line.Speaker + } + prevIdx = i + } + return strings.TrimSpace(b.String()) +} + +type segmentLine struct { + Text string + Start time.Duration + End time.Duration + Speaker int +} + +func assignSpeakers(lines []segmentLine, turns []Turn) { + for i := range lines { + mid := lines[i].Start + (lines[i].End-lines[i].Start)/2 + lines[i].Speaker = speakerAt(mid, turns) + } +} + +func speakerAt(t time.Duration, turns []Turn) int { + sec := float32(t.Seconds()) + bestSpeaker := -1 + bestOverlap := float32(0) + for _, tr := range turns { + if sec >= tr.Start && sec < tr.End { + return tr.Speaker + } + overlap := intervalOverlap(sec, sec, tr.Start, tr.End) + if overlap > bestOverlap { + bestOverlap = overlap + bestSpeaker = tr.Speaker + } + } + return bestSpeaker +} + +func intervalOverlap(a0, a1, b0, b1 float32) float32 { + start := max32(a0, b0) + end := min32(a1, b1) + if end <= start { + return 0 + } + return end - start +} + +func max32(a, b float32) float32 { + if a > b { + return a + } + return b +} + +func min32(a, b float32) float32 { + if a < b { + return a + } + return b +} + +// PunctuateSegments runs punctuation on each segment separately (preserves line breaks). +func PunctuateSegments(segments []wpkg.Segment, restore func(text string) (string, error)) ([]wpkg.Segment, error) { + out := make([]wpkg.Segment, len(segments)) + copy(out, segments) + for i := range out { + t := strings.TrimSpace(out[i].Text) + if t == "" { + continue + } + p, err := restore(t) + if err != nil { + return nil, err + } + out[i].Text = " " + strings.TrimSpace(p) + } + return out, nil +} diff --git a/whisper/format_test.go b/whisper/format_test.go new file mode 100644 index 0000000..7490313 --- /dev/null +++ b/whisper/format_test.go @@ -0,0 +1,40 @@ +package whisper + +import ( + "testing" + "time" + + wpkg "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +func TestFormatSegments_pauseAndSpeaker(t *testing.T) { + segments := []wpkg.Segment{ + {Text: " привет", Start: 0, End: 2 * time.Second}, + {Text: " мир", Start: 4 * time.Second, End: 5 * time.Second}, + {Text: " ответ", Start: 6 * time.Second, End: 8 * time.Second}, + } + turns := []Turn{ + {Start: 0, End: 5.5, Speaker: 0}, + {Start: 5.5, End: 10, Speaker: 1}, + } + got := FormatSegments(segments, turns, FormatOptions{ + PauseGap: 1500 * time.Millisecond, + SpeakerLabel: "Спикер", + UseSpeakers: true, + }) + want := "Спикер 1: привет\nмир\n\nСпикер 2: ответ" + if got != want { + t.Fatalf("got %q\nwant %q", got, want) + } +} + +func TestFormatSegments_noSpeakers(t *testing.T) { + segments := []wpkg.Segment{ + {Text: "a", Start: 0, End: time.Second}, + {Text: "b", Start: 3 * time.Second, End: 4 * time.Second}, + } + got := FormatSegments(segments, nil, FormatOptions{PauseGap: time.Second, UseSpeakers: false}) + if got != "a\nb" { + t.Fatalf("got %q", got) + } +} diff --git a/whisper/helper.go b/whisper/helper.go new file mode 100644 index 0000000..8f98e6a --- /dev/null +++ b/whisper/helper.go @@ -0,0 +1,15 @@ +package whisper + +import ( + "fmt" + "time" +) + +func srtTimestamp(t time.Duration) string { + return fmt.Sprintf("%02d:%02d:%02d,%03d", + t/time.Hour, + (t%time.Hour)/time.Minute, + (t%time.Minute)/time.Second, + (t%time.Second)/time.Millisecond, + ) +} diff --git a/whisper/helper_test.go b/whisper/helper_test.go new file mode 100644 index 0000000..3f6b29d --- /dev/null +++ b/whisper/helper_test.go @@ -0,0 +1,39 @@ +package whisper + +import ( + "testing" + "time" +) + +func TestSrtTimestamp(t *testing.T) { + type args struct { + t time.Duration + } + tests := []struct { + name string + args args + want string + }{ + { + name: "test 1", + args: args{ + t: time.Duration(1*time.Hour + 2*time.Minute + 3*time.Second + 4*time.Millisecond), + }, + want: "01:02:03,004", + }, + { + name: "test 2", + args: args{ + t: time.Duration(10*time.Hour + 20*time.Minute + 30*time.Second + 40*time.Millisecond), + }, + want: "10:20:30,040", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := srtTimestamp(tt.args.t); got != tt.want { + t.Errorf("srtTimestamp() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/whisper/model_pool.go b/whisper/model_pool.go new file mode 100644 index 0000000..9a96da4 --- /dev/null +++ b/whisper/model_pool.go @@ -0,0 +1,89 @@ +package whisper + +import ( + "sync" + + "go-whisper-api/config" + + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +// Model is a loaded whisper.cpp weights file (re-export for API callers). +type Model = whisper.Model + +// ModelPool keeps whisper models loaded in memory and serializes inference per model path. +type ModelPool struct { + mu sync.Mutex + entries map[string]*pooledModel +} + +type pooledModel struct { + model whisper.Model + mu sync.Mutex +} + +func NewModelPool() *ModelPool { + return &ModelPool{entries: make(map[string]*pooledModel)} +} + +func (p *ModelPool) WithModel(path string, fn func(whisper.Model) error) error { + e, err := p.entry(path) + if err != nil { + return err + } + e.mu.Lock() + defer e.mu.Unlock() + return fn(e.model) +} + +func (p *ModelPool) entry(path string) (*pooledModel, error) { + p.mu.Lock() + defer p.mu.Unlock() + if e, ok := p.entries[path]; ok { + return e, nil + } + m, err := whisper.New(path) + if err != nil { + return nil, err + } + e := &pooledModel{model: m} + p.entries[path] = e + return e, nil +} + +func (p *ModelPool) Close() { + p.mu.Lock() + defer p.mu.Unlock() + for _, e := range p.entries { + _ = e.model.Close() + } + p.entries = make(map[string]*pooledModel) +} + +var defaultPool = NewModelPool() + +func DefaultPool() *ModelPool { + return defaultPool +} + +// Transcribe runs speech recognition using a cached model. +func Transcribe(cfg *config.Whisper) (TranscriptResult, error) { + return TranscribeWithPool(defaultPool, cfg, RunOptions{}) +} + +func TranscribeWithPool(pool *ModelPool, cfg *config.Whisper, opts RunOptions) (TranscriptResult, error) { + if pool == nil { + pool = defaultPool + } + if err := cfg.Validate(); err != nil { + return TranscriptResult{}, err + } + eng := &Engine{cfg: cfg, runOpts: opts} + err := pool.WithModel(cfg.Model, func(m whisper.Model) error { + return eng.transcribeWithModel(m) + }) + if err != nil { + return TranscriptResult{}, err + } + return eng.Result(), nil +} diff --git a/whisper/options.go b/whisper/options.go new file mode 100644 index 0000000..3da6037 --- /dev/null +++ b/whisper/options.go @@ -0,0 +1,8 @@ +package whisper + +// RunOptions affects transcription output formatting and optional diarization hints. +type RunOptions struct { + Format FormatOptions + Turns []Turn + PunctuateRestore func(text string) (string, error) +} diff --git a/whisper/vad.go b/whisper/vad.go new file mode 100644 index 0000000..39bff45 --- /dev/null +++ b/whisper/vad.go @@ -0,0 +1,43 @@ +package whisper + +import ( + "fmt" + "math" + + "go-whisper-api/config" + + pkg "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +func ApplyVAD(ctx pkg.Context, vad config.VAD) { + if !vad.Enabled { + return + } + vad = vad.WithDefaults() + ctx.SetVAD(true) + ctx.SetVADModelPath(vad.Model) + ctx.SetVADThreshold(float32(vad.Threshold)) + ctx.SetVADMinSpeechMs(vad.MinSpeechMs) + ctx.SetVADMinSilenceMs(vad.MinSilenceMs) + if vad.MaxSpeechSec > 0 { + ctx.SetVADMaxSpeechSec(float32(vad.MaxSpeechSec)) + } else { + ctx.SetVADMaxSpeechSec(float32(math.MaxFloat32)) + } + ctx.SetVADSpeechPadMs(vad.SpeechPadMs) + ctx.SetVADSamplesOverlap(float32(vad.SamplesOverlap)) +} + +func prepareVAD(vad *config.VAD, modelsDir string) error { + if vad == nil || !vad.Enabled { + return nil + } + if modelsDir != "" { + vad.Model = vad.ResolveModelPath(modelsDir) + } + *vad = vad.WithDefaults() + if err := vad.Validate(); err != nil { + return fmt.Errorf("vad: %w", err) + } + return nil +} diff --git a/whisper/whisper.go b/whisper/whisper.go new file mode 100644 index 0000000..24c7b9f --- /dev/null +++ b/whisper/whisper.go @@ -0,0 +1,245 @@ +package whisper + +import ( + "fmt" + "os" + "path" + "path/filepath" + "strings" + "time" + + "go-whisper-api/config" + + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/rs/zerolog/log" +) + +type OutputFormat string + +func (f OutputFormat) String() string { + return string(f) +} + +var ( + FormatTxt OutputFormat = "txt" + FormatSrt OutputFormat = "srt" + FormatCSV OutputFormat = "csv" +) + +type Engine struct { + cfg *config.Whisper + ctx whisper.Context + model whisper.Model + segments []whisper.Segment + progress int + runOpts RunOptions +} + +func (e *Engine) Transcript() error { + return defaultPool.WithModel(e.cfg.Model, func(m whisper.Model) error { + return e.transcribeWithModel(m) + }) +} + +func (e *Engine) transcribeWithModel(model whisper.Model) error { + data, cleanup, err := prepareAudioPCM(e.cfg.AudioPath) + if err != nil { + return err + } + defer cleanup() + e.model = model + e.ctx, err = e.model.NewContext() + if err != nil { + return err + } + e.ctx.SetThreads(e.cfg.Threads) + if e.cfg.SpeedUp { + e.ctx.SetAudioCtx(750) + } + e.ctx.SetTranslate(e.cfg.Translate) + if e.cfg.Prompt != "" { + e.ctx.SetInitialPrompt(e.cfg.Prompt) + } + e.ctx.SetMaxContext(int(e.cfg.MaxContext)) + if e.cfg.Debug { + log.Info().Msgf("%s", e.ctx.SystemInfo()) + } + if e.cfg.Language != "" { + _ = e.ctx.SetLanguage(e.cfg.Language) + } + if e.cfg.BeamSize > 0 { + e.ctx.SetBeamSize(int(e.cfg.BeamSize)) + } + if e.cfg.EntropyThold > 0 { + e.ctx.SetEntropyThold(float32(e.cfg.EntropyThold)) + } + if err := prepareVAD(&e.cfg.VAD, ""); err != nil { + return err + } + ApplyVAD(e.ctx, e.cfg.VAD) + log.Debug().Msg("start transcribe process") + e.ctx.ResetTimings() + if err := e.ctx.Process(data, e.cbEncoderBegin(), e.cbSegment(), e.cbProgress()); err != nil { + return err + } + if e.cfg.Debug { + e.ctx.PrintTimings() + } + return nil +} + +func (e *Engine) cbEncoderBegin() func() bool { + return func() bool { return true } +} + +func (e *Engine) cbSegment() func(segment whisper.Segment) { + return func(segment whisper.Segment) { + e.segments = append(e.segments, segment) + if !e.cfg.PrintSegment { + return + } + log.Info().Msgf( + "[%6s -> %6s] %s", + segment.Start.Truncate(time.Millisecond), + segment.End.Truncate(time.Millisecond), + segment.Text, + ) + } +} + +func (e *Engine) cbProgress() func(progress int) { + return func(progress int) { + if progress > 100 { + progress = 100 + } + if e.progress == progress { + return + } + e.progress = progress + if e.cfg.PrintProgress { + log.Info().Msgf("current progress: %d%%", progress) + } + } +} + +func (e *Engine) getOutputPath(format string) string { + ext := filepath.Ext(e.cfg.AudioPath) + filename := filepath.Base(e.cfg.AudioPath) + if e.cfg.OutputFilename != "" { + filename = e.cfg.OutputFilename + } + folder := filepath.Dir(e.cfg.AudioPath) + if e.cfg.OutputFolder != "" { + folder = e.cfg.OutputFolder + } + return path.Join(folder, strings.TrimSuffix(filename, ext)+"."+format) +} + +func (e *Engine) Save(format string) error { + outputPath := e.getOutputPath(format) + log.Info().Str("output-path", outputPath).Str("output-format", format).Msg("save text to file") + text := "" + switch OutputFormat(format) { + case FormatSrt: + for i, segment := range e.segments { + text += fmt.Sprintf("%d\n", i+1) + text += fmt.Sprintf("%s --> %s\n", srtTimestamp(segment.Start), srtTimestamp(segment.End)) + text += segment.Text + "\n\n" + + } + case FormatTxt: + for _, segment := range e.segments { + text += segment.Text + } + case FormatCSV: + text = "start,end,text\n" + for _, segment := range e.segments { + text += fmt.Sprintf("%s,%s,\"%s\"\n", segment.Start, segment.End, segment.Text) + } + } + if err := os.WriteFile(outputPath, []byte(text), 0o644); err != nil { + return err + } + return nil +} + +type Word struct { + Word string `json:"word"` + Start int `json:"start"` + Stop int `json:"stop"` +} + +type TranscriptResult struct { + Text string `json:"text"` + Words []Word `json:"words,omitempty"` +} + +func (e *Engine) SetTranscriptText(text string) { + if len(e.segments) == 0 { + e.segments = []whisper.Segment{{Text: text}} + return + } + start := e.segments[0].Start + end := e.segments[len(e.segments)-1].End + e.segments = []whisper.Segment{{Text: text, Start: start, End: end}} +} + +func (e *Engine) Result() TranscriptResult { + segments := e.segments + if e.runOpts.PunctuateRestore != nil { + updated, err := PunctuateSegments(segments, e.runOpts.PunctuateRestore) + if err == nil { + segments = updated + } + } + text := FormatSegments(segments, e.runOpts.Turns, e.runOpts.Format) + var words []Word + for _, segment := range segments { + words = append(words, segmentWords(segment)...) + } + return TranscriptResult{ + Text: text, + Words: words, + } +} + +func segmentWords(segment whisper.Segment) []Word { + parts := strings.Fields(strings.TrimSpace(segment.Text)) + if len(parts) == 0 { + return nil + } + startMs := int(segment.Start / time.Millisecond) + endMs := int(segment.End / time.Millisecond) + if endMs < startMs { + endMs = startMs + } + span := endMs - startMs + if span <= 0 { + span = 1 + } + step := span / len(parts) + if step < 1 { + step = 1 + } + out := make([]Word, 0, len(parts)) + for i, part := range parts { + wStart := startMs + i*step + wStop := wStart + step + if i == len(parts)-1 { + wStop = endMs + } + out = append(out, Word{ + Word: part, + Start: wStart, + Stop: wStop, + }) + } + return out +} + +func (e *Engine) Close() error { + // Models are owned by ModelPool; do not close shared weights here. + e.ctx = nil + e.model = nil + return nil +} diff --git a/whisper/whisper_test.go b/whisper/whisper_test.go new file mode 100644 index 0000000..a1078fb --- /dev/null +++ b/whisper/whisper_test.go @@ -0,0 +1,66 @@ +package whisper + +import ( + "testing" + + "go-whisper-api/config" + + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" +) + +func TestEngine_getOutputPath(t *testing.T) { + type fields struct { + cfg *config.Whisper + ctx whisper.Context + model whisper.Model + segments []whisper.Segment + } + type args struct { + format string + } + tests := []struct { + name string + fields fields + args args + want string + }{ + { + name: "change wav to txt", + fields: fields{ + cfg: &config.Whisper{ + AudioPath: "/test/1234/foo.wav", + }, + }, + args: args{ + format: "txt", + }, + want: "/test/1234/foo.txt", + }, + { + name: "change output folder", + fields: fields{ + cfg: &config.Whisper{ + AudioPath: "/test/1234/foo.wav", + OutputFolder: "/foo/bar", + }, + }, + args: args{ + format: "txt", + }, + want: "/foo/bar/foo.txt", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Engine{ + cfg: tt.fields.cfg, + ctx: tt.fields.ctx, + model: tt.fields.model, + segments: tt.fields.segments, + } + if got := e.getOutputPath(tt.args.format); got != tt.want { + t.Errorf("Engine.getOutputPath() = %v, want %v", got, tt.want) + } + }) + } +}