first commit
Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s

This commit is contained in:
admin 2026-06-04 18:10:52 +07:00
commit b5c083e06f
105 changed files with 8172 additions and 0 deletions

13
.gitea/FUNDING.yml Normal file
View File

@ -0,0 +1,13 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: ['https://www.paypal.me/appleboy46']

10
.gitea/dependabot.yml Normal file
View File

@ -0,0 +1,10 @@
version: 2
updates:
- package-ecosystem: github-actions
directory: /
schedule:
interval: weekly
- package-ecosystem: gomod
directory: /
schedule:
interval: weekly

View File

@ -0,0 +1,27 @@
# Gitea Actions
Workflows use [Gitea Actions](https://docs.gitea.com/usage/actions/overview) (compatible with GitHub Actions syntax).
## Secrets
| Secret | Used by | Purpose |
|--------|---------|---------|
| `GITEA_TOKEN` | `docker.yml` | Push images to the instance container registry |
Create a token with **write:package** (and **read:repository** for checkout if needed).
## Images
- **CI / CPU**: `docker/Dockerfile.ci` — built on every push to `main` and tags `v*`
- **GPU**: `docker/Dockerfile` — build manually on NVIDIA hosts
Registry URL: `${{ gitea.server_url }}/${{ gitea.repository }}`
## Local parity
```sh
sudo apt-get install -y build-essential cmake git
go test ./config/... ./punctuation/... ./transcode/...
make dependency && make test
docker build -f docker/Dockerfile.ci -t go-whisper-api:ci .
```

View File

@ -0,0 +1,32 @@
# CodeQL requires GitHub.com or a Gitea instance with the codeql-action available.
# Disable this workflow on self-hosted Gitea if the action cannot be resolved.
name: CodeQL
on:
push:
branches: [main]
pull_request:
branches: [main]
schedule:
- cron: "41 23 * * 6"
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: ["go"]
steps:
- uses: actions/checkout@v4
- uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
continue-on-error: true
- uses: github/codeql-action/analyze@v3
continue-on-error: true

View File

@ -0,0 +1,60 @@
name: Docker Image
on:
push:
branches:
- main
tags:
- "v*"
pull_request:
branches:
- main
env:
# Gitea container registry: <server>/<owner>/<repo>
IMAGE: ${{ gitea.server_url }}/${{ gitea.repository }}
jobs:
build-docker:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Gitea Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ${{ gitea.server_url }}
username: ${{ gitea.actor }}
password: ${{ secrets.GITEA_TOKEN }}
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: |
${{ env.IMAGE }}
tags: |
type=raw,value=latest,enable={{is_default_branch}}
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
- name: Build and push (CPU)
uses: docker/build-push-action@v6
with:
context: .
file: docker/Dockerfile.ci
platforms: linux/amd64
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max

View File

@ -0,0 +1,43 @@
name: Release
on:
push:
tags:
- "v*"
permissions:
contents: write
packages: write
jobs:
goreleaser:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache: true
- name: Install build dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends build-essential cmake git
- name: Build whisper dependency
run: make dependency
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
with:
distribution: goreleaser
version: latest
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }}

62
.gitea/workflows/lint.yml Normal file
View File

@ -0,0 +1,62 @@
name: Lint and Testing
on:
push:
pull_request:
jobs:
lint:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Hadolint Dockerfile (CI)
uses: hadolint/hadolint-action@v3.1.0
with:
dockerfile: docker/Dockerfile.ci
- name: Hadolint Dockerfile (GPU)
uses: hadolint/hadolint-action@v3.1.0
with:
dockerfile: docker/Dockerfile
continue-on-error: true
test:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache: true
- name: Install build dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends \
build-essential cmake git libsentencepiece-dev pkg-config
- name: Unit tests (no CGO whisper)
run: go test ./config/... ./punctuation/... ./transcode/... -count=1
- name: Build with xlm punctuation and run full tests
run: make dependency && make test TAGS=xlm
golangci:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache: true
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: latest
args: --timeout=10m
continue-on-error: true

44
.gitignore vendored Normal file
View File

@ -0,0 +1,44 @@
# Go build artifacts
*.exe
*.exe~
*.dll
*.so
*.dylib
*.test
*.out
go.work
# Binaries and native libs (local build)
/bin/
bin/
/lib/
lib/
# Runtime data (never commit)
config.yaml
.env
.env.*
cache/
# Packaged releases
dist/
*.tar.gz
# Downloaded models and ONNX weights
models/**/*.bin
models/**/*.onnx
models/**/*.model
models/*.bin
models/*.onnx
models/*.model
# whisper.cpp submodule and its build tree
third_party/
# IDE and OS
.idea/
.vscode/
*.swp
*~
.DS_Store
Thumbs.db

3
.golangci.yml Normal file
View File

@ -0,0 +1,3 @@
run:
skip-dirs:
- third_party/whisper.cpp

3
.goreleaser.yaml Normal file
View File

@ -0,0 +1,3 @@
builds:
- skip: true

3
.hadolint.yaml Normal file
View File

@ -0,0 +1,3 @@
ignored:
- DL3018
- DL3008

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Bo-Yi Wu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

280
Makefile Normal file
View File

@ -0,0 +1,280 @@
EXECUTABLE := go-whisper-api
GO ?= go
GOFILES := $(shell find . -name "*.go" -type f)
HAS_GO = $(shell hash $(GO) > /dev/null 2>&1 && echo "GO" || echo "NOGO" )
WHISPER_CPP := $(abspath third_party/whisper.cpp)
WHISPER_BUILD := $(WHISPER_CPP)/build
WHISPER_VENDOR := $(WHISPER_CPP)/bindings/go
WHISPER_LIBDIR := $(WHISPER_BUILD)/src:$(WHISPER_BUILD)/ggml/src
RUNTIME_LIB_DIR := $(abspath lib)
# $ORIGIN/lib — binary next to lib/ (e.g. ./go-whisper-api + ./lib/)
# $ORIGIN/../lib — binary in bin/ (e.g. bin/go-whisper-api + lib/)
RUNTIME_RPATH := -Wl,-rpath,$$ORIGIN/lib:$$ORIGIN/../lib
ifneq ($(shell uname), Darwin)
EXTLDFLAGS = -extldflags "$(RUNTIME_RPATH)"
else
EXTLDFLAGS =
endif
ifeq ($(HAS_GO), GO)
GOPATH ?= $(shell $(GO) env GOPATH)
export PATH := $(GOPATH)/bin:$(PATH)
CGO_EXTRA_CFLAGS := -DSQLITE_MAX_VARIABLE_NUMBER=32766
CGO_CFLAGS ?= $(shell $(GO) env CGO_CFLAGS) $(CGO_EXTRA_CFLAGS)
endif
ifeq ($(OS), Windows_NT)
GOFLAGS := -v -buildmode=exe
EXECUTABLE ?= $(EXECUTABLE).exe
else ifeq ($(OS), Windows)
GOFLAGS := -v -buildmode=exe
EXECUTABLE ?= $(EXECUTABLE).exe
else
GOFLAGS := -v
EXECUTABLE ?= $(EXECUTABLE)
endif
ifneq ($(DRONE_TAG),)
VERSION ?= $(DRONE_TAG)
else
VERSION ?= $(shell git describe --tags --always || git rev-parse --short HEAD)
endif
TAGS ?=
UNAME_M := $(shell uname -m)
ifeq ($(UNAME_M),x86_64)
SHERPA_LIBARCH := x86_64-unknown-linux-gnu
endif
ifeq ($(UNAME_M),aarch64)
SHERPA_LIBARCH := aarch64-unknown-linux-gnu
endif
SHERPA_LINUX_VER := $(shell awk '/sherpa-onnx-go-linux/ {print $$2; exit}' go.mod)
SHERPA_LIBDIR := $(GOPATH)/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-linux@$(SHERPA_LINUX_VER)/lib/$(SHERPA_LIBARCH)
ifneq ($(shell uname), Darwin)
EXTLDFLAGS_SHERPA = -extldflags "$(RUNTIME_RPATH)"
EXTLDFLAGS_XLM = -extldflags "$(RUNTIME_RPATH)"
else
EXTLDFLAGS_SHERPA =
EXTLDFLAGS_XLM =
endif
GOLDFLAGS ?= -X 'main.Version=$(VERSION)'
INCLUDE_PATH := $(WHISPER_CPP)/include:$(WHISPER_CPP)/ggml/include:$(WHISPER_VENDOR):$(INCLUDE_PATH)
LIBRARY_PATH := $(WHISPER_LIBDIR):$(LIBRARY_PATH)
export LD_LIBRARY_PATH := $(WHISPER_LIBDIR):$(LD_LIBRARY_PATH)
ifdef WHISPER_CUBLAS
CGO_CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
CGO_CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
EXTLDFLAGS = -extldflags "-lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib"
build: $(EXECUTABLE)
$(EXECUTABLE): $(GOFILES)
CGO_CXXFLAGS=${CGO_CXXFLAGS} CGO_CFLAGS=${CGO_CFLAGS} C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)' -o bin/$@
endif
MODEL_URL ?= https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin
MODEL_PATH ?= models/ggml-tiny.en.bin
VAD_MODEL ?= silero-v6.2.0
VAD_MODEL_PATH ?= models/ggml-silero-v6.2.0.bin
all: build
PUNCT_MODEL_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8.tar.bz2
PUNCT_MODEL_DIR ?= models/punctuation/ct-transformer-zh-en-int8
XLM_PUNCT_DIR ?= models/punctuation/xlm-roberta
XLM_HF_REPO ?= Salama1429/xlm-roberta_punctuation_fullstop_truecase
XLM_MODEL_CONFIG_SRC ?= config/xlm-roberta-model.yaml
ORT_LIB_SRC ?= $(shell $(GO) env GOMODCACHE 2>/dev/null)/github.com/k2-fsa/sherpa-onnx-go-linux@$(SHERPA_LINUX_VER)/lib/$(SHERPA_LIBARCH)/libonnxruntime.so
# Copy runtime .so into ./lib/. Binary rpath: $ORIGIN/lib or $ORIGIN/../lib (see RUNTIME_RPATH).
# Use cp -n where possible: existing root-owned libs in lib/ must not break the build.
install-runtime-libs: dependency
@mkdir -p "$(RUNTIME_LIB_DIR)"
@cp -an "$(WHISPER_BUILD)/src"/libwhisper.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true
@cp -an "$(WHISPER_BUILD)/ggml/src"/libggml*.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true
@echo "Whisper/ggml libs ready in $(RUNTIME_LIB_DIR)/"
install-ort-lib:
@mkdir -p "$(RUNTIME_LIB_DIR)"
@if [ ! -f "$(ORT_LIB_SRC)" ]; then echo "missing $(ORT_LIB_SRC); run: go mod download"; exit 1; fi
@dest="$(RUNTIME_LIB_DIR)/libonnxruntime.so"; \
if [ -f "$$dest" ] && cmp -s "$(ORT_LIB_SRC)" "$$dest"; then \
echo "libonnxruntime.so already up to date in $(RUNTIME_LIB_DIR)/"; \
elif cp -f "$(ORT_LIB_SRC)" "$$dest" 2>/dev/null; then \
echo "Installed $$dest"; \
elif [ -f "$$dest" ] && cmp -s "$(ORT_LIB_SRC)" "$$dest"; then \
echo "libonnxruntime.so present in $(RUNTIME_LIB_DIR)/ (unchanged, not writable)"; \
else \
echo "cannot install libonnxruntime.so to $$dest"; \
echo "fix: sudo chown -R $$USER:$$(id -gn) $(RUNTIME_LIB_DIR)"; \
exit 1; \
fi
# XLM punctuation links -lsentencepiece; bundle .so for hosts without libsentencepiece0 package.
SP_LIB_DIRS := /usr/lib/x86_64-linux-gnu /usr/lib/aarch64-linux-gnu /usr/lib64 /usr/lib
install-sp-lib:
@mkdir -p "$(RUNTIME_LIB_DIR)"
@found=0; \
for d in $(SP_LIB_DIRS); do \
if [ -e "$$d/libsentencepiece.so.0" ] || [ -L "$$d/libsentencepiece.so.0" ]; then \
cp -an "$$d"/libsentencepiece.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true; \
found=1; \
break; \
fi; \
done; \
if [ "$$found" = "0" ]; then \
echo "libsentencepiece.so.0 not found; install: sudo apt-get install libsentencepiece0"; \
exit 1; \
fi
@test -e "$(RUNTIME_LIB_DIR)/libsentencepiece.so.0" || (echo "missing $(RUNTIME_LIB_DIR)/libsentencepiece.so.0 after install-sp-lib"; exit 1)
@echo "Sentencepiece libs ready in $(RUNTIME_LIB_DIR)/"
# If lib/*.so were created as root (e.g. manual cp with sudo), reclaim ownership for builds.
fix-lib-perms:
@if [ -d "$(RUNTIME_LIB_DIR)" ]; then \
chown -R "$$USER:$$(id -gn)" "$(RUNTIME_LIB_DIR)" 2>/dev/null || \
sudo chown -R "$$USER:$$(id -gn)" "$(RUNTIME_LIB_DIR)"; \
echo "Ownership of $(RUNTIME_LIB_DIR)/ updated"; \
fi
install-runtime-libs-xlm: install-runtime-libs install-ort-lib install-sp-lib
# Fail fast before deploy if ./lib is incomplete (lib/ is not in git: *.so is gitignored).
verify-runtime-libs-xlm:
@test -f bin/$(EXECUTABLE) || (echo "missing bin/$(EXECUTABLE); run: make build-xlm"; exit 1)
@test -f "$(RUNTIME_LIB_DIR)/libonnxruntime.so" || (echo "missing $(RUNTIME_LIB_DIR)/libonnxruntime.so; run: make install-runtime-libs-xlm"; exit 1)
@test -e "$(RUNTIME_LIB_DIR)/libsentencepiece.so.0" || (echo "missing $(RUNTIME_LIB_DIR)/libsentencepiece.so.0; run: make install-sp-lib"; exit 1)
@test -e "$(RUNTIME_LIB_DIR)/libwhisper.so.1" || (echo "missing $(RUNTIME_LIB_DIR)/libwhisper.so.1; run: make install-runtime-libs"; exit 1)
@echo "Runtime libs OK in $(RUNTIME_LIB_DIR)/"
RUNTIME_TARBALL := dist/go-whisper-api-runtime-$(shell uname -m).tar.gz
package-runtime-xlm: verify-runtime-libs-xlm
@mkdir -p dist
tar -czf "$(RUNTIME_TARBALL)" bin/$(EXECUTABLE) lib
@echo "Created $(RUNTIME_TARBALL) — on prod: tar -xzf ... -C /opt/go-whisper-api (keeps bin/ and lib/)"
# Copy bundled label config (needs write access to $(XLM_PUNCT_DIR); fix with: sudo chown -R $$USER models/punctuation)
install-xlm-punctuation-config:
@mkdir -p $(XLM_PUNCT_DIR)
@cp "$(XLM_MODEL_CONFIG_SRC)" "$(XLM_PUNCT_DIR)/config.yaml"
@echo "Installed $(XLM_PUNCT_DIR)/config.yaml"
download-xlm-punctuation-model: install-xlm-punctuation-config
@mkdir -p $(XLM_PUNCT_DIR)
@for f in model.onnx sp.model; do \
if [ ! -f "$(XLM_PUNCT_DIR)/$$f" ]; then \
echo "Downloading $$f from $(XLM_HF_REPO)..."; \
curl -fL "https://huggingface.co/$(XLM_HF_REPO)/resolve/main/$$f" -o "$(XLM_PUNCT_DIR)/$$f"; \
else \
echo "Already have $(XLM_PUNCT_DIR)/$$f"; \
fi; \
done
DIAR_SEG_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
DIAR_EMB_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
DIAR_DIR ?= models/diarization
download-diarization-models:
@mkdir -p $(DIAR_DIR)
@if [ ! -f "$(DIAR_DIR)/pyannote-segmentation-3-0/model.onnx" ]; then \
echo "Downloading speaker segmentation model..."; \
curl -fL "$(DIAR_SEG_URL)" -o /tmp/diar-seg.tar.bz2; \
tar -xjf /tmp/diar-seg.tar.bz2 -C $(DIAR_DIR); \
rm -f /tmp/diar-seg.tar.bz2; \
else \
echo "Segmentation model present"; \
fi
@if [ ! -f "$(DIAR_DIR)/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" ]; then \
echo "Downloading speaker embedding model..."; \
curl -fL "$(DIAR_EMB_URL)" -o "$(DIAR_DIR)/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"; \
else \
echo "Embedding model present"; \
fi
download-punctuation-model:
@mkdir -p models/punctuation
@if [ ! -f "$(PUNCT_MODEL_DIR)/model.int8.onnx" ]; then \
echo "Downloading punctuation model..."; \
curl -fL "$(PUNCT_MODEL_URL)" -o /tmp/punct-model.tar.bz2; \
tar -xjf /tmp/punct-model.tar.bz2 -C models/punctuation; \
rm -f /tmp/punct-model.tar.bz2; \
if [ -d models/punctuation/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8 ]; then \
mv models/punctuation/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8 "$(PUNCT_MODEL_DIR)"; \
fi; \
else \
echo "Punctuation model already exists: $(PUNCT_MODEL_DIR)/model.int8.onnx"; \
fi
download-model:
@mkdir -p models
@if [ ! -f "$(MODEL_PATH)" ]; then \
echo "Downloading $(MODEL_PATH)..."; \
curl -fL "$(MODEL_URL)" -o "$(MODEL_PATH)"; \
else \
echo "Model already exists: $(MODEL_PATH)"; \
fi
download-vad-model:
@mkdir -p models
@if [ ! -f "$(VAD_MODEL_PATH)" ]; then \
echo "Downloading VAD model $(VAD_MODEL) to models/..."; \
./third_party/whisper.cpp/models/download-vad-model.sh $(VAD_MODEL) models; \
else \
echo "VAD model already exists: $(VAD_MODEL_PATH)"; \
fi
clone:
@[ -d third_party/whisper.cpp ] || git clone https://github.com/appleboy/whisper.cpp.git third_party/whisper.cpp
dependency: clone
@echo Build whisper
@if [ ! -f "$(WHISPER_BUILD)/src/libwhisper.so" ] && [ ! -f "$(WHISPER_BUILD)/src/libwhisper.a" ]; then \
cmake -S "$(WHISPER_CPP)" -B "$(WHISPER_BUILD)" -DCMAKE_BUILD_TYPE=Release && \
cmake --build "$(WHISPER_BUILD)" --config Release -j$$(nproc 2>/dev/null || echo 4); \
else \
echo "whisper library already built in $(WHISPER_BUILD)"; \
fi
test:
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) test -v -cover -coverprofile coverage.txt ./... && echo "\n==>\033[32m Ok\033[m\n" || exit 1
install: $(GOFILES)
C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) install -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)'
build: install-runtime-libs $(EXECUTABLE)
# Build with sherpa-onnx (punctuation + speaker diarization)
build-sherpa:
@$(MAKE) build TAGS=sherpa
# XLM-RoBERTa punctuation (47 languages); requires libsentencepiece-dev
build-xlm:
@$(MAKE) install-runtime-libs-xlm
@$(MAKE) build TAGS=xlm
$(EXECUTABLE): $(GOFILES)
ifneq (,$(findstring xlm,$(TAGS)))
C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=$(SHERPA_LIBDIR):${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS_XLM) -s -w $(GOLDFLAGS)' -o bin/$@
else ifneq (,$(findstring sherpa,$(TAGS)))
C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=$(SHERPA_LIBDIR):${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS_SHERPA) -s -w $(GOLDFLAGS)' -o bin/$@
else
C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)' -o bin/$@
endif
clean:
$(GO) clean -x -i ./...
rm -rf coverage.txt $(EXECUTABLE) $(DIST) bin/$(EXECUTABLE)
clean-whisper:
rm -rf "$(WHISPER_BUILD)"
version:
@echo $(VERSION)

129
README.md Normal file
View File

@ -0,0 +1,129 @@
# go-whisper-api
HTTP-сервер распознавания речи на базе [whisper.cpp](https://github.com/ggerganov/whisper.cpp): совместимость с **SPR** (`/spr/*`) и **OpenAI Whisper API** (`/v1/*`) для Open WebUI и других клиентов.
## Модели Whisper (ggml)
См. [каталог моделей на Hugging Face](https://huggingface.co/ggerganov/whisper.cpp/tree/main).
```sh
make download-model
# или вручную:
curl -LJ https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin \
--output models/ggml-small.bin
```
Веса STT: `*.bin` в корне `api.models_dir` (список — `GET /spr/models`). Подкаталоги `models/vad/`, `models/punctuation/` и т.п. в список моделей не входят.
## Конфигурация
Шаблон в репозитории — `config.yaml.example`. Рабочий файл создайте локально (он в `.gitignore`):
```sh
cp config.yaml.example config.yaml
```
Если `--config` не указан и в текущем каталоге есть `config.yaml`, он подгружается автоматически.
Промежуточные файлы (загрузки, конвертация) — только в `/tmp`, удаляются после обработки. Форматы **wav, mp3, flac, ogg, m4a, mp4, aac** декодируются на чистом Go (без ffmpeg).
## Запуск
```sh
go-whisper-api serve --config config.yaml
# или с флагами (перекрывают YAML):
go-whisper-api --models-dir ./models --addr :8080
```
Docker:
```sh
docker run -p 8080:8080 \
-v $PWD/models:/app/models \
-v $PWD/config.yaml:/app/config.yaml \
ghcr.io/appleboy/go-whisper-api:latest \
serve --config /app/config.yaml
```
Swagger UI: [http://localhost:8080/](http://localhost:8080/)
## VAD (детекция речи)
```sh
make download-vad-model
```
```yaml
api:
models_dir: ./models
vad:
enabled: true
model: ggml-silero-v6.2.0.bin
```
Полезно на длинных записях с тишиной или музыкой в начале/конце.
## Пунктуация
| Уровень | Параметр | Эффект |
|---------|----------|--------|
| Мастер | `punctuation.enabled: false` | Пунктуация отключена |
| API по умолчанию | `api.default_punctuation: true` | Если нет `?punctuation=` |
| На запрос | `?punctuation=1` / `?punctuation=0` | Только при `enabled: true` |
Движки: `heuristic`, `xlm`, `sherpa`, `sherpa-online`, `http`, `off`.
```sh
curl -X POST "http://localhost:8080/spr/stt/ggml-small?punctuation=1" -F "wav=@audio.wav"
```
Сборка XLM: `make download-xlm-punctuation-model && make build-xlm`. Продакшен: `make package-runtime-xlm` — в git не коммитятся `lib/*.so` (см. `.gitignore`).
Фильтр артефактов: `api.garbage` (по умолчанию `*выбая*`), отключить: `garbage: []`.
## HTTP API
- **SPR** (`/spr/*`) — очередь, waveform, импорт/экспорт моделей
- **OpenAI** (`/v1/*`) — синхронный STT
Пример SPR:
```sh
curl -X POST "http://localhost:8080/spr/stt/ggml-small?async=0" -F "wav=@audio.wav"
```
По умолчанию STT **асинхронный**: `POST /spr/stt/{id}``taskID`, опрос `GET /spr/queue/{taskID}`, результат `GET /spr/result/{taskID}`. Кэш: `./cache/waiting/` и `./cache/ready/`.
Язык по умолчанию — **`ru`** (`?language=en`, `?language=auto`).
| Endpoint | Метод | Описание |
|----------|--------|-------------|
| `/spr/models` | GET | Список моделей |
| `/spr/stt/{id}` | POST | Транскрипция |
| `/spr/result/{taskID}` | GET | Результат |
| `/spr/queue` | GET | Все задачи |
| `/spr/queue/{taskID}` | GET / DELETE | Статус / удаление |
| `/v1/audio/transcriptions` | POST | OpenAI-совместимая транскрипция |
| `/v1/models` | GET | Список моделей |
## Open WebUI
| Параметр | Значение |
|----------|----------|
| Engine | `OpenAI` |
| API Base URL | `http://<host>:6183/v1` |
| API Key | любая непустая строка |
| STT Model | `whisper-1` (см. `api.default_model` в config) |
```yaml
api:
default_model: ggml-large-v3-turbo
```
## CI/CD
Workflows в `.gitea/workflows/` (`lint.yml`, `docker.yml`). Секрет `GITEA_TOKEN` с `write:package` для push образов.
```sh
docker build -f docker/Dockerfile.ci -t go-whisper-api .
```

362
api/cache.go Normal file
View File

@ -0,0 +1,362 @@
package api
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strings"
"time"
"go-whisper-api/whisper"
)
const (
cacheWaiting = "waiting"
cacheReady = "ready"
fileParams = "params.conf"
fileAudio = "audio.wav"
fileAudioJSON = "audio.json"
)
type TaskParams struct {
ID string `json:"id"`
Created string `json:"created"`
Processed string `json:"processed,omitempty"`
Status string `json:"status"`
Model string `json:"model"`
Language string `json:"language,omitempty"`
Punctuation bool `json:"punctuation,omitempty"`
Speakers bool `json:"speakers,omitempty"`
NumClusters int `json:"num_clusters,omitempty"`
Text string `json:"text,omitempty"`
Words []whisper.Word `json:"words,omitempty"`
Error string `json:"error,omitempty"`
}
type AudioJSON struct {
Waveform []float64 `json:"waveform"`
Buckets int `json:"buckets"`
}
type DiskCache struct {
root string
}
func (c *DiskCache) Root() string {
return c.root
}
func resolveCacheRoot(root string) (string, error) {
if root == "" {
root = "./cache"
}
abs, err := filepath.Abs(root)
if err != nil {
return "", fmt.Errorf("cache dir: %w", err)
}
return filepath.Clean(abs), nil
}
func NewDiskCache(root string) (*DiskCache, error) {
abs, err := resolveCacheRoot(root)
if err != nil {
return nil, err
}
c := &DiskCache{root: abs}
for _, sub := range []string{cacheWaiting, cacheReady} {
if err := os.MkdirAll(filepath.Join(abs, sub), 0o755); err != nil {
return nil, err
}
}
return c, nil
}
func (c *DiskCache) waitingDir(id string) string {
return filepath.Join(c.root, cacheWaiting, id)
}
func (c *DiskCache) readyDir(id string) string {
return filepath.Join(c.root, cacheReady, id)
}
func (c *DiskCache) locate(id string) (dir, phase string, ok bool) {
if id == "" {
return "", "", false
}
ready := c.readyDir(id)
if st, err := os.Stat(ready); err == nil && st.IsDir() {
return ready, cacheReady, true
}
waiting := c.waitingDir(id)
if st, err := os.Stat(waiting); err == nil && st.IsDir() {
return waiting, cacheWaiting, true
}
return "", "", false
}
func (c *DiskCache) Enqueue(id string, params TaskParams, audioWavPath string) error {
dir := c.waitingDir(id)
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
params.ID = id
if err := c.writeParams(dir, params); err != nil {
_ = os.RemoveAll(dir)
return err
}
dst := filepath.Join(dir, fileAudio)
if err := os.Rename(audioWavPath, dst); err != nil {
if err2 := copyFile(audioWavPath, dst); err2 != nil {
_ = os.RemoveAll(dir)
return fmt.Errorf("move audio to %s: %w", dst, err2)
}
_ = os.Remove(audioWavPath)
}
return nil
}
func (c *DiskCache) writeParams(dir string, params TaskParams) error {
data, err := json.Marshal(params)
if err != nil {
return err
}
return os.WriteFile(filepath.Join(dir, fileParams), data, 0o644)
}
func (c *DiskCache) LoadParams(id string) (TaskParams, string, error) {
dir, phase, ok := c.locate(id)
if !ok {
return TaskParams{}, "", fmt.Errorf("task not found")
}
params, err := c.readParams(dir)
if err != nil {
return TaskParams{}, "", err
}
return params, phase, nil
}
func (c *DiskCache) readParams(dir string) (TaskParams, error) {
path := filepath.Join(dir, fileParams)
data, err := os.ReadFile(path)
if err != nil {
return TaskParams{}, fmt.Errorf("read %s: %w", path, err)
}
var params TaskParams
if err := json.Unmarshal(data, &params); err != nil {
return TaskParams{}, err
}
return params, nil
}
func (c *DiskCache) List() (map[string]map[string]string, error) {
out := make(map[string]map[string]string)
for _, phase := range []string{cacheWaiting, cacheReady} {
base := filepath.Join(c.root, phase)
entries, err := os.ReadDir(base)
if err != nil {
if os.IsNotExist(err) {
continue
}
return nil, err
}
for _, e := range entries {
if !e.IsDir() {
continue
}
id := e.Name()
params, err := c.readParams(filepath.Join(base, id))
if err != nil {
continue
}
out[id] = map[string]string{
"created": params.Created,
"status": params.Status,
}
}
}
return out, nil
}
func (c *DiskCache) NextWaiting() (string, bool, error) {
base := filepath.Join(c.root, cacheWaiting)
entries, err := os.ReadDir(base)
if err != nil {
if os.IsNotExist(err) {
return "", false, nil
}
return "", false, err
}
type item struct {
id string
created time.Time
}
var pending []item
for _, e := range entries {
if !e.IsDir() {
continue
}
params, err := c.readParams(filepath.Join(base, e.Name()))
if err != nil || params.Status != string(statusWaiting) {
continue
}
t, _ := time.ParseInLocation("2006-01-02 15:04:05", params.Created, time.Local)
pending = append(pending, item{id: e.Name(), created: t})
}
if len(pending) == 0 {
return "", false, nil
}
sort.Slice(pending, func(i, j int) bool {
return pending[i].created.Before(pending[j].created)
})
return pending[0].id, true, nil
}
func (c *DiskCache) SetStatus(id string, status taskStatus, mutate func(*TaskParams)) error {
dir := c.waitingDir(id)
if _, err := os.Stat(dir); err != nil {
return fmt.Errorf("task %s not in waiting", id)
}
params, err := c.readParams(dir)
if err != nil {
return err
}
params.Status = string(status)
if mutate != nil {
mutate(&params)
}
return c.writeParams(dir, params)
}
func (c *DiskCache) FinishWaiting(id string, result whisper.TranscriptResult, errMsg string, waveform []float64) error {
dir := c.waitingDir(id)
params, err := c.readParams(dir)
if err != nil {
return err
}
params.Processed = time.Now().Format("2006-01-02 15:04:05")
if errMsg != "" {
params.Status = string(statusError)
params.Error = errMsg
} else {
params.Status = string(statusReady)
params.Text = result.Text
params.Words = result.Words
}
if err := c.writeParams(dir, params); err != nil {
return fmt.Errorf("update %s: %w", filepath.Join(dir, fileParams), err)
}
if len(waveform) > 0 {
aj := AudioJSON{Waveform: waveform, Buckets: len(waveform)}
data, err := json.Marshal(aj)
if err != nil {
return err
}
if err := os.WriteFile(filepath.Join(dir, fileAudioJSON), data, 0o644); err != nil {
return err
}
}
return nil
}
func (c *DiskCache) PromoteToReady(id string) error {
if _, phase, ok := c.locate(id); ok && phase == cacheReady {
return nil
}
src := c.waitingDir(id)
dst := c.readyDir(id)
if _, err := os.Stat(src); err != nil {
return fmt.Errorf("task %s not in waiting", id)
}
if _, err := os.Stat(dst); err == nil {
return os.RemoveAll(src)
}
return os.Rename(src, dst)
}
func (c *DiskCache) Delete(id string) bool {
dir, _, ok := c.locate(id)
if !ok {
return false
}
_ = os.RemoveAll(dir)
return true
}
func (c *DiskCache) AudioPath(id string) (string, bool) {
dir, _, ok := c.locate(id)
if !ok {
return "", false
}
p := filepath.Join(dir, fileAudio)
if _, err := os.Stat(p); err != nil {
return "", false
}
return p, true
}
func (c *DiskCache) Waveform(id string) ([]float64, error) {
dir, _, ok := c.locate(id)
if !ok {
return nil, fmt.Errorf("task not found")
}
data, err := os.ReadFile(filepath.Join(dir, fileAudioJSON))
if err == nil {
var aj AudioJSON
if json.Unmarshal(data, &aj) == nil && len(aj.Waveform) > 0 {
return aj.Waveform, nil
}
}
return waveformFromWav(filepath.Join(dir, fileAudio), 512)
}
func (c *DiskCache) RecoverInterrupted() error {
base := filepath.Join(c.root, cacheWaiting)
entries, err := os.ReadDir(base)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
for _, e := range entries {
if !e.IsDir() {
continue
}
id := e.Name()
dir := filepath.Join(base, id)
params, err := c.readParams(dir)
if err != nil {
continue
}
switch params.Status {
case string(statusProcessing):
params.Status = string(statusWaiting)
_ = c.writeParams(dir, params)
case string(statusReady), string(statusError):
_ = c.PromoteToReady(id)
}
}
return nil
}
func copyFile(src, dst string) error {
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()
out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, in)
return err
}
func isValidTaskID(id string) bool {
return id != "" && !strings.Contains(id, "..") && !strings.ContainsAny(id, `/\`)
}

127
api/cache_test.go Normal file
View File

@ -0,0 +1,127 @@
package api
import (
"os"
"path/filepath"
"testing"
)
func TestDiskCache_params_promote_list(t *testing.T) {
root := t.TempDir()
c, err := NewDiskCache(root)
if err != nil {
t.Fatal(err)
}
id := "319d72c7-301d-44fd-935f-3526dfb70f9f"
dir := c.waitingDir(id)
if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatal(err)
}
params := TaskParams{
ID: id,
Created: "2026-03-31 21:37:46",
Status: string(statusReady),
Model: "ggml-small",
Text: "test",
}
if err := c.writeParams(dir, params); err != nil {
t.Fatal(err)
}
list, err := c.List()
if err != nil {
t.Fatal(err)
}
if list[id]["status"] != string(statusReady) {
t.Fatalf("list: %v", list[id])
}
if err := c.PromoteToReady(id); err != nil {
t.Fatal(err)
}
_, phase, err := c.LoadParams(id)
if err != nil || phase != cacheReady {
t.Fatalf("phase=%s err=%v", phase, err)
}
if _, err := os.Stat(filepath.Join(c.readyDir(id), fileParams)); err != nil {
t.Fatal(err)
}
}
func TestIsValidTaskID(t *testing.T) {
if !isValidTaskID("319d72c7-301d-44fd-935f-3526dfb70f9f") {
t.Fatal("uuid should be valid")
}
if isValidTaskID("../etc") {
t.Fatal("path traversal must be rejected")
}
}
func TestResolveCacheRoot_absolute(t *testing.T) {
cwd, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
abs, err := resolveCacheRoot("./cache")
if err != nil {
t.Fatal(err)
}
want := filepath.Join(cwd, "cache")
if abs != want {
t.Fatalf("got %q want %q", abs, want)
}
}
func TestDiskCache_RecoverInterrupted(t *testing.T) {
root := t.TempDir()
c, err := NewDiskCache(root)
if err != nil {
t.Fatal(err)
}
id := "task-1"
dir := c.waitingDir(id)
if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatal(err)
}
if err := c.writeParams(dir, TaskParams{
ID: id, Created: "2026-01-01 00:00:00", Status: string(statusProcessing), Model: "m",
}); err != nil {
t.Fatal(err)
}
if err := c.RecoverInterrupted(); err != nil {
t.Fatal(err)
}
p, _, err := c.LoadParams(id)
if err != nil {
t.Fatal(err)
}
if p.Status != string(statusWaiting) {
t.Fatalf("got %q", p.Status)
}
}
func TestDiskCache_RecoverInterrupted_promotesReady(t *testing.T) {
root := t.TempDir()
c, err := NewDiskCache(root)
if err != nil {
t.Fatal(err)
}
id := "done-task"
dir := c.waitingDir(id)
if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatal(err)
}
if err := c.writeParams(dir, TaskParams{
ID: id, Created: "2026-01-01 00:00:00", Status: string(statusReady), Model: "m", Text: "hi",
}); err != nil {
t.Fatal(err)
}
if err := c.RecoverInterrupted(); err != nil {
t.Fatal(err)
}
_, phase, err := c.LoadParams(id)
if err != nil || phase != cacheReady {
t.Fatalf("phase=%s err=%v", phase, err)
}
}

26
api/garbage.go Normal file
View File

@ -0,0 +1,26 @@
package api
import (
"go-whisper-api/garbage"
"go-whisper-api/whisper"
)
func applyGarbage(r whisper.TranscriptResult, patterns []string) whisper.TranscriptResult {
if len(patterns) == 0 {
return r
}
r.Text = garbage.FilterText(r.Text, patterns)
if len(r.Words) == 0 {
return r
}
gw := make([]garbage.Word, len(r.Words))
for i, w := range r.Words {
gw[i] = garbage.Word{Word: w.Word, Start: w.Start, Stop: w.Stop}
}
gw = garbage.FilterWords(gw, patterns)
r.Words = make([]whisper.Word, len(gw))
for i, w := range gw {
r.Words[i] = whisper.Word{Word: w.Word, Start: w.Start, Stop: w.Stop}
}
return r
}

181
api/models.go Normal file
View File

@ -0,0 +1,181 @@
package api
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
)
// Subdirectories under models_dir that hold non-Whisper assets (VAD, punctuation, etc.).
var reservedModelSubdirs = map[string]struct{}{
"vad": {},
"punctuation": {},
}
// Top-level .bin files that are not Whisper STT models.
var excludedWhisperModelFiles = map[string]struct{}{
"vad.bin": {},
}
type Registry struct {
dir string
mu sync.RWMutex
}
func NewRegistry(dir string) *Registry {
return &Registry{dir: dir}
}
func isReservedModelSubdir(name string) bool {
_, ok := reservedModelSubdirs[strings.ToLower(name)]
return ok
}
func isWhisperModelFile(name string) bool {
if !strings.HasSuffix(strings.ToLower(name), ".bin") {
return false
}
_, excluded := excludedWhisperModelFiles[strings.ToLower(name)]
return !excluded
}
func (r *Registry) List() ([]string, error) {
r.mu.RLock()
defer r.mu.RUnlock()
entries, err := os.ReadDir(r.dir)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
var models []string
for _, e := range entries {
if e.IsDir() {
continue
}
name := e.Name()
if !isWhisperModelFile(name) {
continue
}
models = append(models, strings.TrimSuffix(name, filepath.Ext(name)))
}
return models, nil
}
func (r *Registry) Path(id string) (string, error) {
if id == "" {
return "", fmt.Errorf("model id is required")
}
if strings.Contains(id, "/") || strings.Contains(id, "..") {
return "", fmt.Errorf("invalid model id")
}
if isReservedModelSubdir(id) {
return "", fmt.Errorf("model %q not found", id)
}
r.mu.RLock()
defer r.mu.RUnlock()
candidates := []string{
filepath.Join(r.dir, id+".bin"),
filepath.Join(r.dir, id),
filepath.Join(r.dir, "ggml-"+id+".bin"),
}
for _, p := range candidates {
if st, err := os.Stat(p); err == nil && !st.IsDir() && isWhisperModelFile(filepath.Base(p)) {
return p, nil
}
}
return "", fmt.Errorf("model %q not found", id)
}
func (r *Registry) Delete(id string) error {
p, err := r.Path(id)
if err != nil {
return err
}
r.mu.Lock()
defer r.mu.Unlock()
return os.Remove(p)
}
func (r *Registry) Import(id string, src io.Reader) error {
if id == "" {
return fmt.Errorf("model id is required")
}
if strings.Contains(id, "/") || strings.Contains(id, "..") {
return fmt.Errorf("invalid model id")
}
if isReservedModelSubdir(id) {
return fmt.Errorf("invalid model id")
}
r.mu.Lock()
defer r.mu.Unlock()
if err := os.MkdirAll(r.dir, 0o755); err != nil {
return err
}
dst := filepath.Join(r.dir, id+".bin")
f, err := os.Create(dst)
if err != nil {
return err
}
defer f.Close()
if _, err := io.Copy(f, src); err != nil {
os.Remove(dst)
return err
}
return nil
}
func (r *Registry) Open(id string) (*os.File, error) {
p, err := r.Path(id)
if err != nil {
return nil, err
}
return os.Open(p)
}
// Resolve maps an OpenAI-style model name to a local whisper model id.
func (r *Registry) Resolve(id, defaultModel string) (string, error) {
id = strings.TrimSpace(id)
if id == "" {
id = strings.TrimSpace(defaultModel)
}
if id == "" {
models, err := r.List()
if err != nil {
return "", err
}
if len(models) == 0 {
return "", fmt.Errorf("model is required")
}
return models[0], nil
}
if id == "whisper-1" {
if dm := strings.TrimSpace(defaultModel); dm != "" {
if _, err := r.Path(dm); err == nil {
return dm, nil
}
}
models, err := r.List()
if err != nil {
return "", err
}
if len(models) == 0 {
return "", fmt.Errorf("no whisper models installed")
}
return models[0], nil
}
if _, err := r.Path(id); err == nil {
return id, nil
}
if alt := strings.TrimPrefix(id, "whisper-"); alt != id {
if _, err := r.Path(alt); err == nil {
return alt, nil
}
}
return "", fmt.Errorf("model %q not found", id)
}

55
api/models_test.go Normal file
View File

@ -0,0 +1,55 @@
package api
import (
"os"
"path/filepath"
"testing"
)
func TestRegistry_List_excludesAuxiliaryDirs(t *testing.T) {
root := t.TempDir()
for _, name := range []string{"common.bin", "vad/vad.bin", "punctuation/model.onnx"} {
p := filepath.Join(root, name)
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(p, []byte("x"), 0o644); err != nil {
t.Fatal(err)
}
}
r := NewRegistry(root)
list, err := r.List()
if err != nil {
t.Fatal(err)
}
if len(list) != 1 || list[0] != "common" {
t.Fatalf("list=%v want [common]", list)
}
}
func TestRegistry_List_excludesVADAtRoot(t *testing.T) {
root := t.TempDir()
if err := os.WriteFile(filepath.Join(root, "vad.bin"), []byte("x"), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(root, "ggml-small.bin"), []byte("x"), 0o644); err != nil {
t.Fatal(err)
}
list, err := NewRegistry(root).List()
if err != nil {
t.Fatal(err)
}
if len(list) != 1 || list[0] != "ggml-small" {
t.Fatalf("list=%v", list)
}
}
func TestRegistry_Path_rejectsReservedIDs(t *testing.T) {
root := t.TempDir()
r := NewRegistry(root)
for _, id := range []string{"vad", "punctuation"} {
if _, err := r.Path(id); err == nil {
t.Fatalf("expected error for %q", id)
}
}
}

121
api/openai.go Normal file
View File

@ -0,0 +1,121 @@
package api
import (
"encoding/json"
"net/http"
"strings"
"time"
)
func (s *Server) handleOpenAITranscriptions(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeOpenAIError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
if err := r.ParseMultipartForm(128 << 20); err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
modelID, err := s.models.Resolve(r.FormValue("model"), s.cfg.DefaultModel)
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
modelPath, err := s.models.Path(modelID)
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
audioPath, cleanup, err := s.saveUploadedOpenAI(r)
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
defer cleanup()
stt := s.parseOpenAISTTOptions(r)
result, err := s.transcribe(r.Context(), modelPath, audioPath, stt)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, err.Error())
return
}
switch strings.ToLower(strings.TrimSpace(r.FormValue("response_format"))) {
case "text":
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(result.Text))
default:
writeJSON(w, http.StatusOK, map[string]string{"text": result.Text})
}
}
func (s *Server) handleOpenAIModels(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeOpenAIError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
ids, err := s.models.List()
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, err.Error())
return
}
now := time.Now().Unix()
data := make([]map[string]any, 0, len(ids))
for _, id := range ids {
data = append(data, map[string]any{
"id": id,
"object": "model",
"created": now,
"owned_by": "go-whisper-api",
})
}
writeJSON(w, http.StatusOK, map[string]any{
"object": "list",
"data": data,
})
}
func (s *Server) parseOpenAISTTOptions(r *http.Request) sttOptions {
lang := strings.TrimSpace(r.FormValue("language"))
if lang == "" {
lang = s.cfg.Language
}
return sttOptions{
language: lang,
punctuate: s.punctCfg.ShouldApplyAPI(r, s.cfg.DefaultPunctuation),
}
}
func (s *Server) saveUploadedOpenAI(r *http.Request) (path string, cleanup func(), err error) {
if r.MultipartForm == nil {
if err := r.ParseMultipartForm(128 << 20); err != nil {
return "", nil, err
}
}
return saveUploadedRawFields(r, []string{"file", "audio", "wav"})
}
func writeOpenAIError(w http.ResponseWriter, code int, msg string) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(code)
_ = json.NewEncoder(w).Encode(map[string]any{
"error": map[string]any{
"message": msg,
"type": openAIErrorType(code),
},
})
}
func openAIErrorType(code int) string {
switch code {
case http.StatusBadRequest:
return "invalid_request_error"
case http.StatusUnauthorized:
return "authentication_error"
case http.StatusNotFound:
return "not_found_error"
default:
return "server_error"
}
}

90
api/openai_test.go Normal file
View File

@ -0,0 +1,90 @@
package api
import (
"bytes"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"go-whisper-api/config"
)
func TestRegistryResolve(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "ggml-large-v3-turbo.bin"), []byte("y"), 0o644); err != nil {
t.Fatal(err)
}
reg := NewRegistry(dir)
id, err := reg.Resolve("whisper-1", "ggml-large-v3-turbo")
if err != nil || id != "ggml-large-v3-turbo" {
t.Fatalf("whisper-1: id=%q err=%v", id, err)
}
id, err = reg.Resolve("whisper-large-v3-turbo", "")
if err != nil || id != "large-v3-turbo" {
t.Fatalf("whisper-large-v3-turbo: id=%q err=%v", id, err)
}
id, err = reg.Resolve("ggml-small", "")
if err != nil || id != "ggml-small" {
t.Fatalf("ggml-small: id=%q err=%v", id, err)
}
id, err = reg.Resolve("", "ggml-small")
if err != nil || id != "ggml-small" {
t.Fatalf("empty with default: id=%q err=%v", id, err)
}
}
func TestHandleOpenAITranscriptionsMissingFile(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil {
t.Fatal(err)
}
srv := &Server{
cfg: config.API{ModelsDir: dir, Language: "ru"},
models: NewRegistry(dir),
}
body := &bytes.Buffer{}
w := multipart.NewWriter(body)
_ = w.WriteField("model", "ggml-small")
w.Close()
req := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", body)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
srv.handleOpenAITranscriptions(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "file") {
t.Fatalf("expected file error, got %s", rec.Body.String())
}
}
func TestHandleOpenAIModels(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil {
t.Fatal(err)
}
srv := &Server{
cfg: config.API{ModelsDir: dir, Language: "ru"},
models: NewRegistry(dir),
}
req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
rec := httptest.NewRecorder()
srv.handleOpenAIModels(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "ggml-small") {
t.Fatalf("expected model list, got %s", rec.Body.String())
}
}

98
api/queue_worker.go Normal file
View File

@ -0,0 +1,98 @@
package api
import (
"context"
"time"
"go-whisper-api/whisper"
"github.com/rs/zerolog/log"
)
func (s *Server) StartWorker(ctx context.Context) {
go s.queueWorker(ctx)
}
func (s *Server) queueWorker(ctx context.Context) {
const idlePoll = 2 * time.Second
for {
if ctx.Err() != nil {
return
}
id, ok, err := s.cache.NextWaiting()
if err != nil {
log.Error().Err(err).Msg("cache queue scan")
if !sleepOrWake(ctx, s.queueWake, time.Second) {
return
}
continue
}
if ok {
s.processCacheTask(ctx, id)
continue
}
if !sleepOrWake(ctx, s.queueWake, idlePoll) {
return
}
}
}
func sleepOrWake(ctx context.Context, wake <-chan struct{}, d time.Duration) bool {
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return false
case <-wake:
return true
case <-timer.C:
return true
}
}
func (s *Server) processCacheTask(ctx context.Context, id string) {
params, _, err := s.cache.LoadParams(id)
if err != nil {
return
}
modelPath, err := s.models.Path(params.Model)
if err != nil {
s.completeCacheTask(id, whisper.TranscriptResult{}, err.Error())
return
}
if err := s.cache.SetStatus(id, statusProcessing, nil); err != nil {
log.Error().Err(err).Str("task", id).Msg("set processing")
return
}
audioPath, ok := s.cache.AudioPath(id)
if !ok {
s.completeCacheTask(id, whisper.TranscriptResult{}, "audio file missing")
return
}
stt := sttOptions{
language: params.Language,
punctuate: params.Punctuation,
speakers: params.Speakers,
numClusters: params.NumClusters,
}
if stt.language == "" {
stt.language = s.cfg.Language
}
result, err := s.transcribe(ctx, modelPath, audioPath, stt)
if err != nil {
s.completeCacheTask(id, result, err.Error())
log.Error().Err(err).Str("task", id).Msg("async transcribe")
return
}
s.completeCacheTask(id, result, "")
}
func (s *Server) completeCacheTask(id string, result whisper.TranscriptResult, errMsg string) {
if err := s.cache.FinishWaiting(id, result, errMsg, nil); err != nil {
log.Error().Err(err).Str("task", id).Str("cache", s.cache.Root()).Msg("finish task")
return
}
if err := s.cache.PromoteToReady(id); err != nil {
log.Error().Err(err).Str("task", id).Str("cache", s.cache.Root()).Msg("promote to ready")
}
}

28
api/result.go Normal file
View File

@ -0,0 +1,28 @@
package api
import "go-whisper-api/whisper"
// sprResultReady builds GET /spr/result/{taskID} body for completed tasks (SPR-compatible).
func sprResultReady(params TaskParams) map[string]any {
words := params.Words
if words == nil {
words = []whisper.Word{}
}
return map[string]any{
"model": params.Model,
"text": params.Text,
"words": words,
"toxicity": map[string]float64{
"insult": 0,
"obscenity": 0,
"threat": 0,
"politeness": 0,
},
"emotion": map[string]any{},
"voice_analysis": map[string]any{},
"status": "ready",
"taskID": params.ID,
"created": params.Created,
"processed": params.Processed,
}
}

26
api/result_test.go Normal file
View File

@ -0,0 +1,26 @@
package api
import "testing"
func TestSprResultReady(t *testing.T) {
body := sprResultReady(TaskParams{
ID: "701b3e84-5815-4baf-97a2-933ff820f16d",
Created: "2026-06-03 10:06:35",
Processed: "2026-06-03 10:07:48",
Status: "ready",
Model: "common",
Text: "hello",
})
for _, key := range []string{"model", "text", "words", "toxicity", "emotion", "voice_analysis", "status", "taskID", "created", "processed"} {
if body[key] == nil {
t.Fatalf("missing %q", key)
}
}
if body["status"] != "ready" {
t.Fatalf("status=%v", body["status"])
}
tox, ok := body["toxicity"].(map[string]float64)
if !ok || len(tox) != 4 {
t.Fatalf("toxicity=%v", body["toxicity"])
}
}

640
api/server.go Normal file
View File

@ -0,0 +1,640 @@
package api
import (
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
"go-whisper-api/config"
"go-whisper-api/diarization"
"go-whisper-api/punctuation"
"go-whisper-api/transcode"
"go-whisper-api/whisper"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
)
type Server struct {
cfg config.API
punctCfg config.Punctuation
diarCfg config.Diarization
transcode *transcode.Engine
modelPool *whisper.ModelPool
punct punctuation.Restorer
diarizer diarization.Engine
models *Registry
cache *DiskCache
mux *http.ServeMux
queueWake chan struct{}
}
func NewServer(cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) (*Server, error) {
cfg = cfg.WithDefaults()
tc = tc.WithDefaults()
pc = pc.WithDefaults()
dc = dc.WithDefaults()
restorer, err := punctuation.New(pc)
if err != nil {
return nil, err
}
if pc.Active() && !restorer.Active() {
return nil, fmt.Errorf("punctuation is enabled but engine %q is not available", pc.Engine)
}
if cfg.ModelsDir == "" {
cfg.ModelsDir = "./models"
}
if cfg.Addr == "" {
cfg.Addr = ":8080"
}
if cfg.Threads == 0 {
cfg.Threads = uint(runtime.NumCPU())
}
cfg = cfg.WithDefaults()
if cfg.MaxContext == 0 {
cfg.MaxContext = 32
}
if cfg.BeamSize == 0 {
cfg.BeamSize = 5
}
if cfg.EntropyThold == 0 {
cfg.EntropyThold = 2.4
}
if err := os.MkdirAll(cfg.ModelsDir, 0o755); err != nil {
return nil, err
}
cacheDir := cfg.CacheDir
if cacheDir == "" {
cacheDir = "./cache"
}
cache, err := NewDiskCache(cacheDir)
if err != nil {
return nil, err
}
cfg.CacheDir = cache.Root()
if err := cache.RecoverInterrupted(); err != nil {
return nil, err
}
diar, err := diarization.New(dc)
if err != nil {
return nil, err
}
s := &Server{
cfg: cfg,
punctCfg: pc,
diarCfg: dc,
transcode: transcode.NewEngine(tc.FFmpegPath),
modelPool: whisper.NewModelPool(),
punct: restorer,
diarizer: diar,
models: NewRegistry(cfg.ModelsDir),
cache: cache,
mux: http.NewServeMux(),
queueWake: make(chan struct{}, 1),
}
s.routes()
go s.warmModels()
return s, nil
}
func (s *Server) warmModels() {
ids, err := s.models.List()
if err != nil {
return
}
for _, id := range ids {
path, err := s.models.Path(id)
if err != nil {
continue
}
if err := s.modelPool.WithModel(path, func(whisper.Model) error { return nil }); err != nil {
log.Warn().Err(err).Str("model", id).Msg("preload whisper model")
} else {
log.Info().Str("model", id).Msg("whisper model loaded")
}
}
}
func (s *Server) routes() {
s.mux.HandleFunc("/", s.handleSwaggerUI)
s.mux.HandleFunc("/swagger.json", s.handleSwaggerJSON)
s.mux.HandleFunc("/spr/models", s.handleModels)
s.mux.HandleFunc("/spr/hostname", s.handleHostname)
s.mux.HandleFunc("/spr/queue", s.handleQueue)
s.mux.HandleFunc("/spr/stt/", s.handleSTT)
s.mux.HandleFunc("/spr/result/", s.handleResult)
s.mux.HandleFunc("/spr/queue/", s.handleQueueItem)
s.mux.HandleFunc("/spr/audio/", s.handleAudio)
s.mux.HandleFunc("/spr/waveform/", s.handleWaveform)
s.mux.HandleFunc("/spr/delete/", s.handleDeleteModel)
s.mux.HandleFunc("/spr/export/", s.handleExportModel)
s.mux.HandleFunc("/spr/import/", s.handleImportModel)
s.mux.HandleFunc("/v1/audio/transcriptions", s.handleOpenAITranscriptions)
s.mux.HandleFunc("/v1/audio/transcriptions/", s.handleOpenAITranscriptions)
s.mux.HandleFunc("/v1/models", s.handleOpenAIModels)
}
func (s *Server) ListenAndServe() error {
log.Info().
Str("addr", s.cfg.Addr).
Str("models", s.cfg.ModelsDir).
Str("cache", s.cache.Root()).
Msg("starting API server")
return http.ListenAndServe(s.cfg.Addr, s.mux)
}
func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
methodNotAllowed(w)
return
}
models, err := s.models.List()
if err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]any{"models": models})
}
func (s *Server) handleHostname(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
methodNotAllowed(w)
return
}
host, _ := os.Hostname()
cwd, _ := os.Getwd()
writeJSON(w, http.StatusOK, map[string]any{
"error": 0,
"message": "Success",
"hostname": host,
"version": "go-whisper-api",
"cwd": cwd,
"models": s.cfg.ModelsDir,
"cache": s.cache.Root(),
})
}
func (s *Server) handleQueue(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
methodNotAllowed(w)
return
}
list, err := s.cache.List()
if err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusOK, list)
}
func (s *Server) handleQueueItem(w http.ResponseWriter, r *http.Request) {
id := strings.TrimPrefix(r.URL.Path, "/spr/queue/")
if id == "" || !isValidTaskID(id) {
writeError(w, http.StatusBadRequest, "task id required")
return
}
switch r.Method {
case http.MethodGet:
s.handleQueueGet(w, id)
case http.MethodDelete:
if !s.cache.Delete(id) {
writeAPIError(w, http.StatusNotFound, "TaskNotFound")
return
}
writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"})
default:
methodNotAllowed(w)
}
}
func (s *Server) handleQueueGet(w http.ResponseWriter, id string) {
params, phase, err := s.cache.LoadParams(id)
if err != nil {
writeAPIError(w, http.StatusNotFound, "task not found")
return
}
switch params.Status {
case string(statusReady):
if phase == cacheWaiting {
if err := s.cache.PromoteToReady(id); err != nil {
writeAPIError(w, http.StatusInternalServerError, err.Error())
return
}
}
writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"})
case string(statusError):
msg := params.Error
if msg == "" {
msg = "transcription failed"
}
writeAPIError(w, http.StatusNotFound, msg)
default:
writeJSON(w, http.StatusOK, map[string]any{
"error": 0,
"message": params.Status,
"status": params.Status,
})
}
}
func (s *Server) handleSTT(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
methodNotAllowed(w)
return
}
modelID := strings.TrimPrefix(r.URL.Path, "/spr/stt/")
if modelID == "" {
writeError(w, http.StatusBadRequest, "model id required")
return
}
modelPath, err := s.models.Path(modelID)
if err != nil {
writeAPIError(w, http.StatusNotFound, err.Error())
return
}
audioPath, cleanup, err := s.saveUploadedWav(r)
if err != nil {
writeAPIError(w, http.StatusBadRequest, err.Error())
return
}
stt, err := s.parseSTTOptions(r)
if err != nil {
writeAPIError(w, http.StatusBadRequest, err.Error())
return
}
if queryAsync(r, s.cfg.DefaultAsync) {
taskID, err := s.enqueueAsync(r, modelID, audioPath, stt)
cleanup()
if err != nil {
writeAPIError(w, http.StatusBadRequest, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]string{"taskID": taskID})
return
}
defer cleanup()
result, err := s.transcribe(r.Context(), modelPath, audioPath, stt)
if err != nil {
writeAPIError(w, http.StatusMethodNotAllowed, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]any{
"model": modelID,
"text": result.Text,
"words": result.Words,
})
}
func (s *Server) handleResult(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
methodNotAllowed(w)
return
}
id := strings.TrimPrefix(r.URL.Path, "/spr/result/")
if !isValidTaskID(id) {
writeAPIError(w, http.StatusBadRequest, "task id required")
return
}
params, _, err := s.cache.LoadParams(id)
if err != nil {
writeAPIError(w, http.StatusNotFound, "TaskNotFound")
return
}
switch params.Status {
case string(statusWaiting), string(statusProcessing):
writeJSON(w, http.StatusOK, map[string]string{"status": params.Status})
case string(statusError):
msg := params.Error
if msg == "" {
msg = "TaskNotFound"
}
writeAPIError(w, http.StatusNotFound, msg)
case string(statusReady):
writeJSON(w, http.StatusOK, sprResultReady(params))
default:
writeJSON(w, http.StatusOK, map[string]string{"status": params.Status})
}
}
func (s *Server) handleAudio(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
methodNotAllowed(w)
return
}
id := strings.TrimPrefix(r.URL.Path, "/spr/audio/")
path, ok := s.cache.AudioPath(id)
if !ok {
writeError(w, http.StatusNotFound, "task not found")
return
}
http.ServeFile(w, r, path)
}
func (s *Server) handleWaveform(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
methodNotAllowed(w)
return
}
id := strings.TrimPrefix(r.URL.Path, "/spr/waveform/")
if !isValidTaskID(id) {
writeAPIError(w, http.StatusBadRequest, "task id required")
return
}
wf, err := s.cache.Waveform(id)
if err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]any{"error": 0, "waveform": wf})
}
func (s *Server) handleDeleteModel(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
methodNotAllowed(w)
return
}
id := strings.TrimPrefix(r.URL.Path, "/spr/delete/")
if err := s.models.Delete(id); err != nil {
writeAPIError(w, http.StatusNotFound, err.Error())
return
}
w.WriteHeader(http.StatusOK)
}
func (s *Server) handleExportModel(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
methodNotAllowed(w)
return
}
id := strings.TrimPrefix(r.URL.Path, "/spr/export/")
f, err := s.models.Open(id)
if err != nil {
writeAPIError(w, http.StatusNotFound, err.Error())
return
}
defer f.Close()
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q.bin", id))
w.Header().Set("Content-Type", "application/octet-stream")
io.Copy(w, f)
}
func (s *Server) handleImportModel(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
methodNotAllowed(w)
return
}
id := strings.TrimPrefix(r.URL.Path, "/spr/import/")
if err := r.ParseMultipartForm(32 << 20); err != nil {
writeAPIError(w, http.StatusBadRequest, err.Error())
return
}
file, header, err := r.FormFile("zip-model")
if err != nil {
file, header, err = r.FormFile("model")
}
if err != nil {
writeAPIError(w, http.StatusBadRequest, "model file required")
return
}
defer file.Close()
src := io.Reader(file)
if strings.HasSuffix(strings.ToLower(header.Filename), ".zip") {
writeAPIError(w, http.StatusBadRequest, "zip import is not supported; upload .bin model file as zip-model field")
return
}
if err := s.models.Import(id, src); err != nil {
writeAPIError(w, http.StatusBadRequest, err.Error())
return
}
w.WriteHeader(http.StatusOK)
}
func (s *Server) enqueueAsync(r *http.Request, modelID, audioWavPath string, stt sttOptions) (string, error) {
id := uuid.New().String()
params := TaskParams{
ID: id,
Created: time.Now().Format("2006-01-02 15:04:05"),
Status: string(statusWaiting),
Model: modelID,
Language: stt.language,
Punctuation: stt.punctuate,
Speakers: stt.speakers,
NumClusters: stt.numClusters,
}
if err := s.cache.Enqueue(id, params, audioWavPath); err != nil {
return "", err
}
s.notifyQueue()
log.Info().
Str("task", id).
Str("model", modelID).
Str("cache", s.cache.Root()).
Msg("enqueued async task")
return id, nil
}
func (s *Server) transcribe(ctx context.Context, modelPath, audioPath string, stt sttOptions) (whisper.TranscriptResult, error) {
turns, err := s.runDiarization(ctx, audioPath, stt)
if err != nil {
return whisper.TranscriptResult{}, err
}
vad := s.cfg.VAD
if vad.Enabled {
vad.Model = vad.ResolveModelPath(s.cfg.ModelsDir)
}
cfg := &config.Whisper{
Model: modelPath,
AudioPath: audioPath,
Threads: s.cfg.Threads,
Language: stt.language,
Debug: s.cfg.Debug,
SpeedUp: s.cfg.SpeedUp,
Translate: s.cfg.Translate,
Prompt: s.cfg.Prompt,
MaxContext: s.cfg.MaxContext,
BeamSize: s.cfg.BeamSize,
EntropyThold: s.cfg.EntropyThold,
VAD: vad,
PrintProgress: false,
PrintSegment: false,
}
runOpts := s.whisperRunOpts(stt, turns)
if stt.punctuate && s.punct.Active() {
runOpts.PunctuateRestore = func(text string) (string, error) {
return punctuation.Apply(ctx, s.punct, true, text, stt.language)
}
}
result, err := whisper.TranscribeWithPool(s.modelPool, cfg, runOpts)
if err != nil {
return whisper.TranscriptResult{}, err
}
return applyGarbage(result, s.cfg.GarbagePatterns()), nil
}
func (s *Server) runDiarization(ctx context.Context, audioPath string, stt sttOptions) ([]whisper.Turn, error) {
if !stt.speakers {
return nil, nil
}
samples, err := whisper.LoadPCM16Mono(audioPath)
if err != nil {
return nil, fmt.Errorf("diarization audio: %w", err)
}
return s.diarizer.Process(ctx, samples, stt.numClusters)
}
func saveUploadedRaw(r *http.Request) (path string, cleanup func(), err error) {
return saveUploadedRawFields(r, []string{"audio", "wav", "file"})
}
func saveUploadedRawFields(r *http.Request, fieldNames []string) (path string, cleanup func(), err error) {
if r.MultipartForm == nil {
if err := r.ParseMultipartForm(128 << 20); err != nil {
return "", nil, err
}
}
var (
file multipart.File
header *multipart.FileHeader
found bool
)
for _, name := range fieldNames {
file, header, err = r.FormFile(name)
if err == nil {
found = true
break
}
}
if !found {
return "", nil, fmt.Errorf("audio file required (form field: %s)", strings.Join(fieldNames, ", "))
}
defer file.Close()
dir, err := config.MkdirTemp("go-whisper-api-upload-*")
if err != nil {
return "", nil, err
}
cleanup = func() { os.RemoveAll(dir) }
base := "input"
if header != nil {
if ext := filepath.Ext(header.Filename); ext != "" {
base += ext
}
}
raw := filepath.Join(dir, base)
out, err := os.Create(raw)
if err != nil {
cleanup()
return "", nil, err
}
if _, err := io.Copy(out, file); err != nil {
out.Close()
cleanup()
return "", nil, err
}
out.Close()
return raw, cleanup, nil
}
func (s *Server) saveUploadedWav(r *http.Request) (path string, cleanup func(), err error) {
raw, cleanup, err := saveUploadedRaw(r)
if err != nil {
return "", nil, err
}
dst := filepath.Join(filepath.Dir(raw), "audio.wav")
if err := s.transcode.Transcode(r.Context(), raw, dst, transcode.WhisperOptions()); err != nil {
cleanup()
return "", nil, err
}
return dst, cleanup, nil
}
func queryBoolDefault(r *http.Request, key string, def bool) bool {
v := r.URL.Query().Get(key)
if v == "" {
return def
}
b, err := strconv.ParseBool(v)
if err != nil {
return def
}
return b
}
func queryAsync(r *http.Request, defaultAsync bool) bool {
v := r.URL.Query().Get("async")
if v == "" {
return defaultAsync
}
n, err := strconv.Atoi(v)
if err != nil {
return defaultAsync
}
return n == 1
}
func queryInt(r *http.Request, key string, def int) int {
v := r.URL.Query().Get(key)
if v == "" {
return def
}
n, err := strconv.Atoi(v)
if err != nil {
return def
}
return n
}
func writeJSON(w http.ResponseWriter, code int, v any) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(code)
_ = json.NewEncoder(w).Encode(v)
}
func writeError(w http.ResponseWriter, code int, msg string) {
http.Error(w, msg, code)
}
func writeAPIError(w http.ResponseWriter, code int, msg string) {
writeJSON(w, code, map[string]any{"error": 1, "message": msg})
}
func methodNotAllowed(w http.ResponseWriter) {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
func (s *Server) notifyQueue() {
select {
case s.queueWake <- struct{}{}:
default:
}
}
func Run(ctx context.Context, cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) error {
srv, err := NewServer(cfg, tc, pc, dc)
if err != nil {
return err
}
defer srv.modelPool.Close()
srv.StartWorker(ctx)
hs := &http.Server{Addr: cfg.Addr, Handler: srv.mux}
go func() {
<-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
_ = hs.Shutdown(shutdownCtx)
}()
if err := hs.ListenAndServe(); err != nil && err != http.ErrServerClosed {
return err
}
punctuation.Close(srv.punct)
srv.diarizer.Close()
return nil
}

23
api/swagger-ui.html Normal file
View File

@ -0,0 +1,23 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>go-whisper-api</title>
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
<script>
window.onload = function () {
window.ui = SwaggerUIBundle({
url: '/swagger.json',
dom_id: '#swagger-ui',
presets: [SwaggerUIBundle.presets.apis],
docExpansion: 'list',
displayRequestDuration: true,
});
};
</script>
</body>
</html>

34
api/swagger.go Normal file
View File

@ -0,0 +1,34 @@
package api
import (
_ "embed"
"net/http"
)
//go:embed swagger.json
var swaggerSpec []byte
//go:embed swagger-ui.html
var swaggerUI []byte
func (s *Server) handleSwaggerJSON(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
methodNotAllowed(w)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(swaggerSpec)
}
func (s *Server) handleSwaggerUI(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
if r.Method != http.MethodGet {
methodNotAllowed(w)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Write(swaggerUI)
}

457
api/swagger.json Normal file
View File

@ -0,0 +1,457 @@
{
"swagger": "2.0",
"basePath": "/",
"paths": {
"/spr/audio/{taskID}": {
"parameters": [
{
"name": "taskID",
"in": "path",
"required": true,
"type": "string"
}
],
"get": {
"responses": {
"200": {
"description": "Success"
}
},
"operationId": "get_audio_stt",
"tags": [
"spr"
]
}
},
"/spr/delete/{id}": {
"parameters": [
{
"name": "id",
"in": "path",
"required": true,
"type": "string"
}
],
"delete": {
"responses": {
"200": {
"description": "Success"
}
},
"operationId": "delete_model_delete",
"tags": [
"spr"
]
}
},
"/spr/export/{id}": {
"parameters": [
{
"name": "id",
"in": "path",
"required": true,
"type": "string"
}
],
"get": {
"responses": {
"200": {
"description": "Success"
}
},
"operationId": "get_model_export",
"tags": [
"spr"
]
}
},
"/spr/hostname": {
"get": {
"responses": {
"200": {
"description": "Success"
}
},
"operationId": "get_hostname_class",
"tags": [
"spr"
]
}
},
"/spr/import/{id}": {
"parameters": [
{
"name": "id",
"in": "path",
"required": true,
"type": "string"
}
],
"post": {
"responses": {
"200": {
"description": "Success"
}
},
"operationId": "post_model_import",
"parameters": [
{
"name": "zip-model",
"in": "formData",
"type": "file",
"required": true,
"description": "prepared model zip file"
}
],
"consumes": [
"multipart/form-data"
],
"tags": [
"spr"
]
}
},
"/spr/models": {
"get": {
"responses": {
"200": {
"description": "Success",
"schema": {
"$ref": "#/definitions/modelList"
}
}
},
"operationId": "get_model_list",
"tags": [
"spr"
]
}
},
"/spr/queue": {
"get": {
"responses": {
"200": {
"description": "Success"
}
},
"operationId": "get_queue_stt",
"tags": [
"spr"
]
}
},
"/spr/queue/{taskID}": {
"parameters": [
{
"name": "taskID",
"in": "path",
"required": true,
"type": "string"
}
],
"delete": {
"responses": {
"200": {
"description": "Success",
"schema": {
"type": "object",
"properties": {
"error": { "type": "integer" },
"message": { "type": "string" }
}
}
},
"404": {
"description": "TaskNotFound"
}
},
"operationId": "delete_queue_del_stt",
"tags": [
"spr"
]
}
},
"/spr/result/{taskID}": {
"parameters": [
{
"name": "taskID",
"in": "path",
"required": true,
"type": "string"
}
],
"get": {
"responses": {
"404": {
"description": "Not found",
"schema": {
"$ref": "#/definitions/error"
}
},
"200": {
"description": "Success",
"schema": {
"$ref": "#/definitions/resultTTS"
}
}
},
"operationId": "get_result_stt",
"tags": [
"spr"
]
}
},
"/spr/stt/{id}": {
"post": {
"responses": {
"405": {
"description": "Error",
"schema": {
"$ref": "#/definitions/error"
}
},
"404": {
"description": "Not found",
"schema": {
"$ref": "#/definitions/error"
}
},
"200": {
"description": "Success",
"schema": {
"$ref": "#/definitions/modelTTS"
}
}
},
"operationId": "post_model_test",
"parameters": [
{
"name": "id",
"in": "path",
"required": true,
"type": "string",
"description": "NN Model ID"
},
{
"name": "wav",
"in": "formData",
"type": "file",
"description": "file to recognize"
},
{
"name": "async",
"in": "query",
"type": "integer",
"description": "async mode (default 1: enqueue to cache/waiting; 0: sync text response)",
"default": 1,
"enum": [
0,
1
]
},
{
"name": "speakers",
"in": "query",
"type": "integer",
"description": "find speakers",
"default": 0,
"enum": [
0,
1
]
},
{
"name": "speaker_counter",
"in": "query",
"type": "integer",
"description": "number of speakers. 0 for autodetect. -1 disable speaker detection.",
"default": 0
},
{
"name": "normalization",
"in": "query",
"type": "integer",
"description": "normalize text",
"default": 1,
"enum": [
0,
1
]
},
{
"name": "punctuation",
"in": "query",
"type": "integer",
"description": "punctuate text",
"default": 1,
"enum": [
0,
1
]
},
{
"name": "toxicity",
"in": "query",
"type": "integer",
"description": "toxicity analyzer",
"default": 1,
"enum": [
0,
1
]
},
{
"name": "emotion",
"in": "query",
"type": "integer",
"description": "emotion analyzer",
"default": 0,
"enum": [
0,
1
]
},
{
"name": "voice_analyzer",
"in": "query",
"type": "integer",
"description": "voice analyzer",
"default": 1,
"enum": [
0,
1
]
},
{
"name": "vad",
"in": "query",
"type": "string",
"description": "VAD type",
"default": "webrtc",
"enum": [
"webrtc"
]
},
{
"name": "classifiers",
"in": "query",
"type": "string",
"description": "JSON with classification models to process each sentence"
},
{
"name": "webhook",
"in": "query",
"type": "string",
"description": "webhook url to send stt async result"
}
],
"consumes": [
"multipart/form-data"
],
"tags": [
"spr"
]
}
},
"/spr/waveform/{taskID}": {
"parameters": [
{
"name": "taskID",
"in": "path",
"required": true,
"type": "string"
}
],
"get": {
"responses": {
"200": {
"description": "Success"
}
},
"operationId": "get_audioarray_stt",
"tags": [
"spr"
]
}
}
},
"info": {
"title": "go-whisper-api",
"version": "5.008 release",
"description": "Whisper speech-to-text API (SPR-compatible)"
},
"produces": [
"application/json"
],
"consumes": [
"application/json"
],
"tags": [
{
"name": "spr",
"description": "NN Model operations"
}
],
"definitions": {
"modelList": {
"properties": {
"models": {
"type": "array",
"items": {
"type": "string",
"description": "NN Model ID"
}
}
},
"type": "object"
},
"error": {
"properties": {
"error": {
"type": "integer",
"description": "Error flag"
},
"message": {
"type": "string",
"description": "Error description"
}
},
"type": "object"
},
"modelTTS": {
"required": [
"text"
],
"properties": {
"text": {
"type": "string",
"description": "Recognized text"
}
},
"type": "object"
},
"resultTTS": {
"properties": {
"model": { "type": "string" },
"text": { "type": "string" },
"words": { "type": "array" },
"toxicity": { "type": "object" },
"emotion": { "type": "object" },
"voice_analysis": { "type": "object" },
"status": { "type": "string" },
"taskID": { "type": "string" },
"created": { "type": "string" },
"processed": { "type": "string" }
},
"type": "object"
}
},
"responses": {
"ParseError": {
"description": "When a mask can't be parsed"
},
"MaskError": {
"description": "When any error occurs on mask"
}
}
}

10
api/tasks.go Normal file
View File

@ -0,0 +1,10 @@
package api
type taskStatus string
const (
statusWaiting taskStatus = "waiting"
statusProcessing taskStatus = "processing"
statusReady taskStatus = "ready"
statusError taskStatus = "error"
)

77
api/transcribe_opts.go Normal file
View File

@ -0,0 +1,77 @@
package api
import (
"fmt"
"net/http"
"strings"
"go-whisper-api/config"
"go-whisper-api/whisper"
)
type sttOptions struct {
language string
punctuate bool
speakers bool
numClusters int
}
func (s *Server) parseSTTOptions(r *http.Request) (sttOptions, error) {
opts := sttOptions{
language: resolveLanguage(r, s.cfg.Language),
punctuate: s.punctCfg.ShouldApplyAPI(r, s.cfg.DefaultPunctuation),
}
sp, clusters, err := querySpeakers(r, s.cfg.DefaultSpeakers, s.diarCfg, s.diarizer.Active())
if err != nil {
return opts, err
}
opts.speakers = sp
opts.numClusters = clusters
return opts, nil
}
func resolveLanguage(r *http.Request, defaultLang string) string {
if v := strings.TrimSpace(r.URL.Query().Get("language")); v != "" {
return v
}
return strings.TrimSpace(defaultLang)
}
func querySpeakers(r *http.Request, defaultOn bool, dc config.Diarization, diarizerActive bool) (enabled bool, numClusters int, err error) {
counter := queryInt(r, "speaker_counter", -999)
if counter == -1 {
return false, 0, nil
}
speakersQ := r.URL.Query().Get("speakers")
enabled = defaultOn
if speakersQ != "" {
enabled = queryInt(r, "speakers", 0) == 1
}
if !enabled {
return false, 0, nil
}
if !dc.Active() {
return false, 0, fmt.Errorf("speaker diarization is disabled in config (diarization.enabled: true)")
}
if !diarizerActive {
return false, 0, fmt.Errorf("speaker diarization requires server built with -tags sherpa (make build-sherpa) and models (make download-diarization-models)")
}
if counter > 0 {
numClusters = counter
} else if dc.NumClusters > 0 {
numClusters = dc.NumClusters
}
return true, numClusters, nil
}
func (s *Server) whisperRunOpts(stt sttOptions, turns []whisper.Turn) whisper.RunOptions {
t := s.cfg.Transcript.WithDefaults()
return whisper.RunOptions{
Turns: turns,
Format: whisper.FormatOptions{
PauseGap: t.PauseGapDuration(),
SpeakerLabel: t.SpeakerLabel,
UseSpeakers: stt.speakers && len(turns) > 0,
},
}
}

48
api/waveform.go Normal file
View File

@ -0,0 +1,48 @@
package api
import (
"math"
"os"
"github.com/go-audio/wav"
)
func waveformFromWav(path string, buckets int) ([]float64, error) {
if buckets <= 0 {
buckets = 256
}
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
dec := wav.NewDecoder(f)
buf, err := dec.FullPCMBuffer()
if err != nil {
return nil, err
}
samples := buf.AsFloat32Buffer().Data
if len(samples) == 0 {
return make([]float64, buckets), nil
}
chunk := len(samples) / buckets
if chunk < 1 {
chunk = 1
}
out := make([]float64, 0, buckets)
for i := 0; i < len(samples) && len(out) < buckets; i += chunk {
end := i + chunk
if end > len(samples) {
end = len(samples)
}
peak := 0.0
for _, s := range samples[i:end] {
v := math.Abs(float64(s))
if v > peak {
peak = v
}
}
out = append(out, math.Round(peak*1000)/1000)
}
return out, nil
}

55
config.yaml.example Normal file
View File

@ -0,0 +1,55 @@
api:
addr: "0.0.0.0:6183"
models_dir: "./models"
cache_dir: "./cache"
# Open WebUI / OpenAI STT: model id when client sends whisper-1 (e.g. ggml-large-v3-turbo)
default_model: ggml-large-v3-turbo
threads: 16
language: ru
transcript:
pause_gap_sec: 1.5
speaker_label: "Спикер"
default_speakers: false
debug: false
speedup: false
translate: false
prompt: ""
max_context: 32
beam_size: 5
entropy_thold: 2.4
vad:
enabled: false
model: vad/vad.bin
threshold: 0.5
min_speech_duration_ms: 250
min_silence_duration_ms: 100
speech_pad_ms: 30
samples_overlap: 0.1
default_punctuation: true
default_async: true
garbage:
- "*выбая*"
# transcode: pure Go (wav, mp3, flac, ogg, m4a, mp4, aac) — no ffmpeg required
transcode: {}
diarization:
enabled: false
model_dir: ./models/diarization
segmentation_model: pyannote-segmentation-3-0/model.onnx
embedding_model: 3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
num_threads: 2
num_clusters: 0
clustering_threshold: 0.5
punctuation:
enabled: true
default_on: true
# off | heuristic | xlm | sherpa | sherpa-online | http
engine: xlm
model_dir: ./models/punctuation/xlm-roberta
model_file: model.onnx
sp_model: sp.model
config_file: config.yaml
apply_sbd: true
num_threads: 2

42
config/api.go Normal file
View File

@ -0,0 +1,42 @@
package config
import "strings"
type API struct {
Addr string `yaml:"addr"`
ModelsDir string `yaml:"models_dir"`
CacheDir string `yaml:"cache_dir"`
// DefaultModel: whisper model id for OpenAI /v1/audio/transcriptions (maps whisper-1).
DefaultModel string `yaml:"default_model"`
Language string `yaml:"language"`
Transcript Transcript `yaml:"transcript"`
DefaultSpeakers bool `yaml:"default_speakers"`
Prompt string `yaml:"prompt"`
Threads uint `yaml:"threads"`
MaxContext uint `yaml:"max_context"`
BeamSize uint `yaml:"beam_size"`
EntropyThold float64 `yaml:"entropy_thold"`
VAD VAD `yaml:"vad"`
Debug bool `yaml:"debug"`
SpeedUp bool `yaml:"speedup"`
Translate bool `yaml:"translate"`
DefaultPunctuation bool `yaml:"default_punctuation"`
// DefaultAsync: STT via API enqueues to cache/waiting and returns taskID (use ?async=0 for sync).
DefaultAsync bool `yaml:"default_async"`
// Garbage: artifact substrings removed from transcript text and words (default includes *выбая*).
Garbage []string `yaml:"garbage"`
}
func (a API) WithDefaults() API {
a.VAD = a.VAD.WithDefaults()
a.Transcript = a.Transcript.WithDefaults()
if strings.TrimSpace(a.Language) == "" {
a.Language = "ru"
}
return a
}
// GarbagePatterns returns garbage filter list (never empty unless explicitly set to [] in YAML).
func (a API) GarbagePatterns() []string {
return a.garbagePatterns()
}

72
config/diarization.go Normal file
View File

@ -0,0 +1,72 @@
package config
import (
"os"
"path/filepath"
)
type Diarization struct {
Enabled bool `yaml:"enabled"`
ModelDir string `yaml:"model_dir"`
SegmentationModel string `yaml:"segmentation_model"`
EmbeddingModel string `yaml:"embedding_model"`
NumThreads int `yaml:"num_threads"`
NumClusters int `yaml:"num_clusters"`
ClusteringThreshold float32 `yaml:"clustering_threshold"`
MinDurationOn float32 `yaml:"min_duration_on"`
MinDurationOff float32 `yaml:"min_duration_off"`
}
func (d Diarization) WithDefaults() Diarization {
if d.ModelDir == "" {
d.ModelDir = "./models/diarization"
}
if d.SegmentationModel == "" {
d.SegmentationModel = "pyannote-segmentation-3-0/model.onnx"
}
if d.EmbeddingModel == "" {
d.EmbeddingModel = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
}
if d.NumThreads <= 0 {
d.NumThreads = 2
}
if d.ClusteringThreshold <= 0 {
d.ClusteringThreshold = 0.5
}
if d.MinDurationOn <= 0 {
d.MinDurationOn = 0.3
}
if d.MinDurationOff <= 0 {
d.MinDurationOff = 0.5
}
return d
}
func (d Diarization) SegmentationPath() string {
return resolveModelPath(d.ModelDir, d.SegmentationModel)
}
func (d Diarization) EmbeddingPath() string {
return resolveModelPath(d.ModelDir, d.EmbeddingModel)
}
func (d Diarization) ModelsPresent() bool {
if _, err := os.Stat(d.SegmentationPath()); err != nil {
return false
}
if _, err := os.Stat(d.EmbeddingPath()); err != nil {
return false
}
return true
}
func resolveModelPath(dir, name string) string {
if filepath.IsAbs(name) {
return name
}
return filepath.Join(dir, name)
}
func (d Diarization) Active() bool {
return d.Enabled
}

53
config/file.go Normal file
View File

@ -0,0 +1,53 @@
package config
import (
"fmt"
"os"
"runtime"
"gopkg.in/yaml.v3"
)
type File struct {
API API `yaml:"api"`
Transcode Transcode `yaml:"transcode"`
Punctuation Punctuation `yaml:"punctuation"`
Diarization Diarization `yaml:"diarization"`
}
func DefaultFile() File {
return File{
API: API{
Addr: ":8080",
ModelsDir: "./models",
CacheDir: "./cache",
Threads: uint(runtime.NumCPU()),
Language: "ru",
Transcript: Transcript{}.WithDefaults(),
MaxContext: 32,
BeamSize: 5,
EntropyThold: 2.4,
DefaultAsync: true,
Garbage: DefaultGarbage(),
},
Diarization: Diarization{}.WithDefaults(),
Transcode: Transcode{}.WithDefaults(),
Punctuation: Punctuation{Enabled: false, Engine: "off"}.WithDefaults(),
}
}
func LoadFile(path string) (File, error) {
data, err := os.ReadFile(path)
if err != nil {
return File{}, err
}
cfg := DefaultFile()
if err := yaml.Unmarshal(data, &cfg); err != nil {
return File{}, fmt.Errorf("parse config %s: %w", path, err)
}
return cfg, nil
}
func (f File) APIConfig() API {
return f.API
}

35
config/file_test.go Normal file
View File

@ -0,0 +1,35 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
err := os.WriteFile(path, []byte(`
api:
addr: ":9090"
models_dir: "/data/models"
language: ru
`), 0o644)
if err != nil {
t.Fatal(err)
}
cfg, err := LoadFile(path)
if err != nil {
t.Fatal(err)
}
if cfg.API.Addr != ":9090" {
t.Fatalf("addr: got %q", cfg.API.Addr)
}
if cfg.API.ModelsDir != "/data/models" {
t.Fatalf("models_dir: got %q", cfg.API.ModelsDir)
}
if cfg.API.Language != "ru" {
t.Fatalf("language: got %q", cfg.API.Language)
}
}

13
config/garbage.go Normal file
View File

@ -0,0 +1,13 @@
package config
// DefaultGarbage returns STT artifact tokens filtered from API/CLI transcript output.
func DefaultGarbage() []string {
return []string{"*выбая*"}
}
func (a API) garbagePatterns() []string {
if a.Garbage == nil {
return DefaultGarbage()
}
return a.Garbage
}

23
config/garbage_test.go Normal file
View File

@ -0,0 +1,23 @@
package config
import "testing"
func TestDefaultGarbage(t *testing.T) {
if len(DefaultGarbage()) == 0 || DefaultGarbage()[0] != "*выбая*" {
t.Fatalf("got %v", DefaultGarbage())
}
}
func TestAPI_GarbagePatterns_default(t *testing.T) {
p := (API{}).GarbagePatterns()
if len(p) != 1 || p[0] != "*выбая*" {
t.Fatalf("got %v", p)
}
}
func TestAPI_GarbagePatterns_explicitEmpty(t *testing.T) {
p := (API{Garbage: []string{}}).GarbagePatterns()
if len(p) != 0 {
t.Fatalf("got %v", p)
}
}

136
config/merge.go Normal file
View File

@ -0,0 +1,136 @@
package config
import (
"os"
"github.com/urfave/cli/v2"
)
func LoadResolved(path string) (File, error) {
if path == "" {
if _, err := os.Stat("config.yaml"); err == nil {
path = "config.yaml"
} else {
return DefaultFile(), nil
}
}
return LoadFile(path)
}
func mergeVAD(c *cli.Context, v VAD) VAD {
if c.IsSet("vad") {
v.Enabled = c.Bool("vad")
}
if c.IsSet("vad-model") {
v.Model = c.String("vad-model")
}
if c.IsSet("vad-threshold") {
v.Threshold = c.Float64("vad-threshold")
}
if c.IsSet("vad-min-speech-ms") {
v.MinSpeechMs = c.Int("vad-min-speech-ms")
}
if c.IsSet("vad-min-silence-ms") {
v.MinSilenceMs = c.Int("vad-min-silence-ms")
}
if c.IsSet("vad-max-speech-sec") {
v.MaxSpeechSec = c.Float64("vad-max-speech-sec")
}
if c.IsSet("vad-speech-pad-ms") {
v.SpeechPadMs = c.Int("vad-speech-pad-ms")
}
if c.IsSet("vad-samples-overlap") {
v.SamplesOverlap = c.Float64("vad-samples-overlap")
}
return v.WithDefaults()
}
func mergeAPI(c *cli.Context, a API) API {
if c.IsSet("addr") {
a.Addr = c.String("addr")
}
if c.IsSet("models-dir") {
a.ModelsDir = c.String("models-dir")
}
if c.IsSet("cache-dir") {
a.CacheDir = c.String("cache-dir")
}
if c.IsSet("threads") {
a.Threads = c.Uint("threads")
}
if c.IsSet("language") {
a.Language = c.String("language")
}
if c.IsSet("debug") {
a.Debug = c.Bool("debug")
}
if c.IsSet("speedup") {
a.SpeedUp = c.Bool("speedup")
}
if c.IsSet("translate") {
a.Translate = c.Bool("translate")
}
if c.IsSet("prompt") {
a.Prompt = c.String("prompt")
}
if c.IsSet("max-context") {
a.MaxContext = c.Uint("max-context")
}
if c.IsSet("beam-size") {
a.BeamSize = c.Uint("beam-size")
}
if c.IsSet("entropy-thold") {
a.EntropyThold = c.Float64("entropy-thold")
}
a.VAD = mergeVAD(c, a.VAD)
if c.IsSet("default-punctuation") {
a.DefaultPunctuation = c.Bool("default-punctuation")
}
return a
}
func APIFromCLI(c *cli.Context) (API, error) {
file, err := LoadResolved(c.String("config"))
if err != nil {
return API{}, err
}
return mergeAPI(c, file.API), nil
}
func TranscodeFromCLI(c *cli.Context) (Transcode, error) {
file, err := LoadResolved(c.String("config"))
if err != nil {
return Transcode{}, err
}
return file.Transcode.WithDefaults(), nil
}
func mergePunctuation(c *cli.Context, p Punctuation) Punctuation {
if c.IsSet("punctuation-enabled") {
p.Enabled = c.Bool("punctuation-enabled")
}
if c.IsSet("punctuation-engine") {
p.Engine = c.String("punctuation-engine")
}
if c.IsSet("punctuation-default-on") {
p.DefaultOn = c.Bool("punctuation-default-on")
}
return p.WithDefaults()
}
func PunctuationFromCLI(c *cli.Context) (Punctuation, error) {
file, err := LoadResolved(c.String("config"))
if err != nil {
return Punctuation{}, err
}
return mergePunctuation(c, file.Punctuation), nil
}
func DiarizationFromCLI(c *cli.Context) (Diarization, error) {
file, err := LoadResolved(c.String("config"))
if err != nil {
return Diarization{}, err
}
return file.Diarization.WithDefaults(), nil
}

28
config/merge_test.go Normal file
View File

@ -0,0 +1,28 @@
package config
import (
"os"
"strings"
"testing"
)
func TestDefaultFile_apiDefaults(t *testing.T) {
f := DefaultFile()
if f.API.Addr == "" {
t.Fatal("expected API listen addr")
}
if f.API.ModelsDir == "" {
t.Fatal("expected models_dir")
}
}
func TestMkdirTemp_usesTmpRoot(t *testing.T) {
dir, err := MkdirTemp("go-whisper-api-test-*")
if err != nil {
t.Fatal(err)
}
defer func() { _ = os.RemoveAll(dir) }()
if !strings.HasPrefix(dir, TempRoot+"/") {
t.Fatalf("temp dir should be under %s, got %s", TempRoot, dir)
}
}

125
config/punctuation.go Normal file
View File

@ -0,0 +1,125 @@
package config
import (
"net/http"
"path/filepath"
"strconv"
"strings"
"time"
)
func (p Punctuation) XLMJoinSentences() bool {
return p.ApplySBD
}
type Punctuation struct {
Command []string `yaml:"command"`
NumThreads int `yaml:"num_threads"`
TimeoutSec int `yaml:"timeout_sec"`
Engine string `yaml:"engine"`
ModelDir string `yaml:"model_dir"`
ModelFile string `yaml:"model_file"`
SPModel string `yaml:"sp_model"`
ConfigFile string `yaml:"config_file"`
BpeVocab string `yaml:"bpe_vocab"`
HTTPURL string `yaml:"http_url"`
Enabled bool `yaml:"enabled"`
DefaultOn bool `yaml:"default_on"`
ApplySBD bool `yaml:"apply_sbd"`
}
func (p Punctuation) WithDefaults() Punctuation {
if p.Engine == "" {
p.Engine = "heuristic"
}
engine := strings.ToLower(strings.TrimSpace(p.Engine))
if p.ModelDir == "" {
if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" {
p.ModelDir = "./models/punctuation/xlm-roberta"
} else {
p.ModelDir = "./models/punctuation/ct-transformer-zh-en-int8"
}
}
if p.ModelFile == "" {
if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" {
p.ModelFile = "model.onnx"
} else {
p.ModelFile = "model.int8.onnx"
}
}
if p.NumThreads <= 0 {
p.NumThreads = 2
}
if p.TimeoutSec <= 0 {
p.TimeoutSec = 120
}
return p
}
func (p Punctuation) Timeout() time.Duration {
p = p.WithDefaults()
return time.Duration(p.TimeoutSec) * time.Second
}
func (p Punctuation) Active() bool {
p = p.WithDefaults()
return p.Enabled && p.Engine != "" && !strings.EqualFold(p.Engine, "off")
}
func (p Punctuation) ModelPath() string {
p = p.WithDefaults()
return filepath.Join(p.ModelDir, p.ModelFile)
}
func (p Punctuation) SPModelPath() string {
p = p.WithDefaults()
if p.SPModel == "" {
return filepath.Join(p.ModelDir, "sp.model")
}
return filepath.Join(p.ModelDir, p.SPModel)
}
func (p Punctuation) XLMConfigPath() string {
p = p.WithDefaults()
if p.ConfigFile == "" {
return filepath.Join(p.ModelDir, "config.yaml")
}
return filepath.Join(p.ModelDir, p.ConfigFile)
}
func (p Punctuation) BpeVocabPath() string {
p = p.WithDefaults()
if p.BpeVocab == "" {
return filepath.Join(p.ModelDir, "bpe.vocab")
}
return filepath.Join(p.ModelDir, p.BpeVocab)
}
func (p Punctuation) ShouldApplyAPI(r *http.Request, apiDefault bool) bool {
if !p.Active() {
return false
}
q := strings.TrimSpace(r.URL.Query().Get("punctuation"))
if q != "" {
return parsePunctuationQuery(q, true)
}
return apiDefault || p.Enabled
}
func parsePunctuationQuery(raw string, def bool) bool {
raw = strings.TrimSpace(raw)
if raw == "" {
return def
}
switch strings.ToLower(raw) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
}
b, err := strconv.ParseBool(raw)
if err != nil {
return def
}
return b
}

View File

@ -0,0 +1,58 @@
package config
import (
"net/http/httptest"
"testing"
)
func TestPunctuation_Active(t *testing.T) {
if (Punctuation{Enabled: false, Engine: "heuristic"}).Active() {
t.Fatal("disabled should be inactive")
}
if (Punctuation{Enabled: true, Engine: "off"}).Active() {
t.Fatal("engine off should be inactive")
}
if !(Punctuation{Enabled: true, Engine: "heuristic"}).Active() {
t.Fatal("enabled heuristic should be active")
}
}
func TestPunctuation_ShouldApplyAPI(t *testing.T) {
p := Punctuation{Enabled: true, Engine: "heuristic"}
req := httptest.NewRequest("GET", "/?punctuation=1", nil)
if !p.ShouldApplyAPI(req, false) {
t.Fatal("query 1 should enable")
}
req = httptest.NewRequest("GET", "/?punctuation=0", nil)
if p.ShouldApplyAPI(req, true) {
t.Fatal("query 0 should disable")
}
req = httptest.NewRequest("GET", "/", nil)
if !p.ShouldApplyAPI(req, false) {
t.Fatal("enabled in config => apply when query omitted")
}
if !p.ShouldApplyAPI(req, true) {
t.Fatal("api default true")
}
off := Punctuation{Enabled: false, Engine: "heuristic"}
if off.ShouldApplyAPI(req, true) {
t.Fatal("master disabled must never apply")
}
}
func TestParsePunctuationQuery(t *testing.T) {
if !parsePunctuationQuery("1", false) {
t.Fatal("1 => true")
}
if parsePunctuationQuery("0", true) {
t.Fatal("0 => false")
}
if !parsePunctuationQuery("", true) {
t.Fatal("empty => default")
}
}

9
config/tmp.go Normal file
View File

@ -0,0 +1,9 @@
package config
import "os"
const TempRoot = "/tmp"
func MkdirTemp(prefix string) (string, error) {
return os.MkdirTemp(TempRoot, prefix)
}

11
config/transcode.go Normal file
View File

@ -0,0 +1,11 @@
package config
// Transcode holds audio normalization settings (pure Go decoders; no external ffmpeg).
type Transcode struct {
// FFmpegPath is deprecated and ignored; kept for backward-compatible YAML.
FFmpegPath string `yaml:"ffmpeg_path,omitempty"`
}
func (t Transcode) WithDefaults() Transcode {
return t
}

26
config/transcript.go Normal file
View File

@ -0,0 +1,26 @@
package config
import (
"strings"
"time"
)
// Transcript controls how STT segments are joined into one text field (with embedded newlines).
type Transcript struct {
PauseGapSec float64 `yaml:"pause_gap_sec"`
SpeakerLabel string `yaml:"speaker_label"`
}
func (t Transcript) WithDefaults() Transcript {
if t.PauseGapSec <= 0 {
t.PauseGapSec = 1.5
}
if strings.TrimSpace(t.SpeakerLabel) == "" {
t.SpeakerLabel = "Спикер"
}
return t
}
func (t Transcript) PauseGapDuration() time.Duration {
return time.Duration(t.PauseGapSec * float64(time.Second))
}

88
config/vad.go Normal file
View File

@ -0,0 +1,88 @@
package config
import (
"fmt"
"os"
"path/filepath"
)
type VAD struct {
Enabled bool `yaml:"enabled"`
Model string `yaml:"model"`
Threshold float64 `yaml:"threshold"`
MinSpeechMs int `yaml:"min_speech_duration_ms"`
MinSilenceMs int `yaml:"min_silence_duration_ms"`
MaxSpeechSec float64 `yaml:"max_speech_duration_s"`
SpeechPadMs int `yaml:"speech_pad_ms"`
SamplesOverlap float64 `yaml:"samples_overlap"`
}
func DefaultVAD() VAD {
return VAD{
Threshold: 0.5,
MinSpeechMs: 250,
MinSilenceMs: 100,
MaxSpeechSec: 0,
SpeechPadMs: 30,
SamplesOverlap: 0.1,
}
}
func (v VAD) WithDefaults() VAD {
d := DefaultVAD()
if v.Threshold <= 0 {
v.Threshold = d.Threshold
}
if v.MinSpeechMs <= 0 {
v.MinSpeechMs = d.MinSpeechMs
}
if v.MinSilenceMs <= 0 {
v.MinSilenceMs = d.MinSilenceMs
}
if v.SpeechPadMs <= 0 {
v.SpeechPadMs = d.SpeechPadMs
}
if v.SamplesOverlap <= 0 {
v.SamplesOverlap = d.SamplesOverlap
}
return v
}
func (v VAD) ResolveModelPath(modelsDir string) string {
if v.Model == "" {
return ""
}
if filepath.IsAbs(v.Model) {
return v.Model
}
if modelsDir == "" {
return v.Model
}
direct := filepath.Join(modelsDir, v.Model)
if _, err := os.Stat(direct); err == nil {
return direct
}
// VAD weights often live under models_dir/vad/ (not listed in /spr/models).
base := filepath.Base(v.Model)
inVAD := filepath.Join(modelsDir, "vad", base)
if _, err := os.Stat(inVAD); err == nil {
return inVAD
}
return direct
}
func (v VAD) Validate() error {
if !v.Enabled {
return nil
}
if v.Model == "" {
return fmt.Errorf("vad.model is required when vad.enabled is true")
}
if _, err := os.Stat(v.Model); err != nil {
return fmt.Errorf("vad model %q: %w", v.Model, err)
}
if v.Threshold < 0 || v.Threshold > 1 {
return fmt.Errorf("vad.threshold must be between 0 and 1")
}
return nil
}

65
config/vad_test.go Normal file
View File

@ -0,0 +1,65 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestVAD_Validate_disabled(t *testing.T) {
if err := (VAD{}).Validate(); err != nil {
t.Fatal(err)
}
}
func TestVAD_Validate_requiresModel(t *testing.T) {
err := (VAD{Enabled: true}).Validate()
if err == nil {
t.Fatal("expected error")
}
}
func TestVAD_Validate_modelExists(t *testing.T) {
dir := t.TempDir()
model := filepath.Join(dir, "ggml-silero.bin")
if err := os.WriteFile(model, []byte("x"), 0o644); err != nil {
t.Fatal(err)
}
v := VAD{Enabled: true, Model: model, Threshold: 0.5}
if err := v.Validate(); err != nil {
t.Fatal(err)
}
}
func TestVAD_ResolveModelPath(t *testing.T) {
v := VAD{Model: "ggml-silero-v6.2.0.bin"}
got := v.ResolveModelPath("/data/models")
want := filepath.Join("/data/models", "ggml-silero-v6.2.0.bin")
if got != want {
t.Fatalf("got %q want %q", got, want)
}
}
func TestVAD_ResolveModelPath_vadSubdir(t *testing.T) {
dir := t.TempDir()
vadDir := filepath.Join(dir, "vad")
if err := os.MkdirAll(vadDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(vadDir, "vad.bin"), []byte("x"), 0o644); err != nil {
t.Fatal(err)
}
v := VAD{Model: "vad.bin"}
got := v.ResolveModelPath(dir)
want := filepath.Join(dir, "vad", "vad.bin")
if got != want {
t.Fatalf("got %q want %q", got, want)
}
}
func TestVAD_WithDefaults(t *testing.T) {
v := VAD{Enabled: true}.WithDefaults()
if v.Threshold != 0.5 || v.MinSpeechMs != 250 || v.MinSilenceMs != 100 {
t.Fatalf("unexpected defaults: %+v", v)
}
}

43
config/whisper.go Normal file
View File

@ -0,0 +1,43 @@
package config
import "fmt"
type Whisper struct {
OutputFormat []string `yaml:"output_format"`
Model string `yaml:"model"`
AudioPath string `yaml:"audio_path"`
Language string `yaml:"language"`
Prompt string `yaml:"prompt"`
OutputFolder string `yaml:"output_folder"`
OutputFilename string `yaml:"output_filename"`
Threads uint `yaml:"threads"`
MaxContext uint `yaml:"max_context"`
BeamSize uint `yaml:"beam_size"`
EntropyThold float64 `yaml:"entropy_thold"`
VAD VAD `yaml:"vad"`
Debug bool `yaml:"debug"`
SpeedUp bool `yaml:"speedup"`
Translate bool `yaml:"translate"`
PrintProgress bool `yaml:"print_progress"`
PrintSegment bool `yaml:"print_segment"`
}
func (c *Whisper) ValidateModel() error {
if c.Model == "" {
return fmt.Errorf("model is required")
}
return nil
}
func (c *Whisper) Validate() error {
if err := c.ValidateModel(); err != nil {
return err
}
if c.AudioPath == "" {
return fmt.Errorf("audio path is required")
}
if err := c.VAD.WithDefaults().Validate(); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,81 @@
# Metadata for Salama1429/xlm-roberta_punctuation_fullstop_truecase (ONNX punctuation).
# Install into the model directory:
# cp config/xlm-roberta-model.yaml models/punctuation/xlm-roberta/config.yaml
# or: make install-xlm-punctuation-config
languages:
- af
- am
- ar
- bg
- bn
- de
- el
- en
- es
- et
- fa
- fi
- fr
- gu
- hi
- hr
- hu
- id
- is
- it
- ja
- kk
- kn
- ko
- ky
- lt
- lv
- mk
- ml
- mr
- nl
- or
- pa
- pl
- ps
- pt
- ro
- ru
- rw
- so
- sr
- sw
- ta
- te
- tr
- uk
- zh
max_length: 256
pre_labels:
- "<NULL>"
- "¿"
post_labels:
- "<NULL>"
- "<ACRONYM>"
- "."
- ","
- "?"
- ""
- ""
- "。"
- "、"
- "・"
- "।"
- "؟"
- "،"
- ";"
- "።"
- "፣"
- "፧"
null_token: "<NULL>"
acronym_token: "<ACRONYM>"

View File

@ -0,0 +1,19 @@
package diarization
import (
"context"
"go-whisper-api/config"
"go-whisper-api/whisper"
)
// Engine runs offline speaker diarization when built with -tags sherpa.
type Engine interface {
Active() bool
Process(ctx context.Context, samples []float32, numClusters int) ([]whisper.Turn, error)
Close()
}
func New(cfg config.Diarization) (Engine, error) {
return newEngine(cfg)
}

107
diarization/sherpa.go Normal file
View File

@ -0,0 +1,107 @@
//go:build sherpa
package diarization
import (
"context"
"fmt"
"os"
"sync"
"go-whisper-api/config"
"go-whisper-api/whisper"
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
)
type sherpaEngine struct {
cfg config.Diarization
sd *sherpa.OfflineSpeakerDiarization
mu sync.Mutex
}
func newEngine(cfg config.Diarization) (Engine, error) {
cfg = cfg.WithDefaults()
if !cfg.Active() {
return &noopEngine{}, nil
}
if !cfg.ModelsPresent() {
return nil, fmt.Errorf("diarization models missing (run: make download-diarization-models)")
}
conf := &sherpa.OfflineSpeakerDiarizationConfig{
Segmentation: sherpa.OfflineSpeakerSegmentationModelConfig{
Pyannote: sherpa.OfflineSpeakerSegmentationPyannoteModelConfig{
Model: cfg.SegmentationPath(),
},
NumThreads: cfg.NumThreads,
Debug: 0,
Provider: "cpu",
},
Embedding: sherpa.SpeakerEmbeddingExtractorConfig{
Model: cfg.EmbeddingPath(),
NumThreads: cfg.NumThreads,
Debug: 0,
Provider: "cpu",
},
Clustering: sherpa.FastClusteringConfig{
NumClusters: cfg.NumClusters,
Threshold: cfg.ClusteringThreshold,
},
MinDurationOn: cfg.MinDurationOn,
MinDurationOff: cfg.MinDurationOff,
}
sd := sherpa.NewOfflineSpeakerDiarization(conf)
if sd == nil {
return nil, fmt.Errorf("failed to create sherpa speaker diarization")
}
return &sherpaEngine{cfg: cfg, sd: sd}, nil
}
type noopEngine struct{}
func (noopEngine) Active() bool { return false }
func (noopEngine) Process(context.Context, []float32, int) ([]whisper.Turn, error) { return nil, nil }
func (noopEngine) Close() {}
func (e *sherpaEngine) Active() bool {
return e.sd != nil
}
func (e *sherpaEngine) Process(ctx context.Context, samples []float32, numClusters int) ([]whisper.Turn, error) {
_ = ctx
if len(samples) == 0 {
return nil, fmt.Errorf("empty audio for diarization")
}
e.mu.Lock()
defer e.mu.Unlock()
if _, err := os.Stat(e.cfg.SegmentationPath()); err != nil {
return nil, err
}
clusters := numClusters
if clusters <= 0 {
clusters = e.cfg.NumClusters
}
e.sd.SetConfig(&sherpa.OfflineSpeakerDiarizationConfig{
Clustering: sherpa.FastClusteringConfig{
NumClusters: clusters,
Threshold: e.cfg.ClusteringThreshold,
},
})
segments := e.sd.Process(samples)
out := make([]whisper.Turn, len(segments))
for i, s := range segments {
out[i] = whisper.Turn{
Start: s.Start,
End: s.End,
Speaker: s.Speaker,
}
}
return out, nil
}
func (e *sherpaEngine) Close() {
if e.sd != nil {
sherpa.DeleteOfflineSpeakerDiarization(e.sd)
e.sd = nil
}
}

29
diarization/stub.go Normal file
View File

@ -0,0 +1,29 @@
//go:build !sherpa
package diarization
import (
"context"
"fmt"
"go-whisper-api/config"
"go-whisper-api/whisper"
)
type stubEngine struct {
cfg config.Diarization
}
func newEngine(cfg config.Diarization) (Engine, error) {
return &stubEngine{cfg: cfg.WithDefaults()}, nil
}
func (s *stubEngine) Active() bool {
return false
}
func (s *stubEngine) Process(context.Context, []float32, int) ([]whisper.Turn, error) {
return nil, fmt.Errorf("speaker diarization requires build with -tags sherpa and models (make build-sherpa, make download-diarization-models)")
}
func (s *stubEngine) Close() {}

55
docker/Dockerfile Normal file
View File

@ -0,0 +1,55 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG CUDA_VERSION=12.0.0
# Target the CUDA build image
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
# Target the CUDA runtime image
ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} AS build
WORKDIR /app
# Unless otherwise specified, we make a fat build.
ARG CUDA_DOCKER_ARCH=all
# Set nvcc architecture
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
# Enable cuBLAS
ENV WHISPER_CUBLAS=1
#apt-get
RUN apt-get update && \
apt-get install -y --no-install-recommends build-essential git gcc g++ wget \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
# install golang
RUN wget --progress=dot:giga https://go.dev/dl/go1.22.10.linux-amd64.tar.gz
RUN rm -rf /usr/local/go && tar -C /usr/local -xzf go1.22.10.linux-amd64.tar.gz
ENV PATH ${PATH}:/usr/local/go/bin
# Ref: https://stackoverflow.com/a/53464012
ENV CUDA_MAIN_VERSION=12.0
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
COPY ./ .
RUN make dependency && env && make build && \
mv bin/go-whisper-api /bin/ && \
rm -rf bin
FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} AS runtime
WORKDIR /app
LABEL maintainer="Bo-Yi Wu <appleboy.tw@gmail.com>" \
org.label-schema.name="Speech-to-Text" \
org.label-schema.vendor="Bo-Yi Wu" \
org.label-schema.schema-version="1.0"
LABEL org.opencontainers.image.source=https://github.com/appleboy/go-whisper-api
LABEL org.opencontainers.image.description="Speech-to-Text."
LABEL org.opencontainers.image.licenses=MIT
RUN apt-get update && \
apt-get install -y --no-install-recommends curl \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /bin/go-whisper-api /bin/go-whisper-api
EXPOSE 8080
ENTRYPOINT ["/bin/go-whisper-api"]

27
docker/Dockerfile.ci Normal file
View File

@ -0,0 +1,27 @@
# CPU image for Gitea CI and hosts without NVIDIA GPU.
# Production GPU image: docker/Dockerfile
FROM golang:1.23-bookworm AS build
WORKDIR /app
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential git cmake pkg-config libsentencepiece-dev \
&& rm -rf /var/lib/apt/lists/*
COPY . .
RUN make dependency && make build-xlm
FROM debian:bookworm-slim AS runtime
WORKDIR /app
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
COPY --from=build /app/bin/go-whisper-api /bin/go-whisper-api
COPY --from=build /app/third_party/whisper.cpp/build/src/libwhisper.so* /usr/local/lib/
COPY --from=build /app/third_party/whisper.cpp/build/ggml/src/libggml*.so* /usr/local/lib/
ENV LD_LIBRARY_PATH=/usr/local/lib
EXPOSE 8080
ENTRYPOINT ["/bin/go-whisper-api"]

56
garbage/filter.go Normal file
View File

@ -0,0 +1,56 @@
package garbage
import (
"regexp"
"strings"
)
var spaceCollapse = regexp.MustCompile(`\s+`)
// Word is a timed token for garbage filtering (mirrors whisper.Word JSON shape).
type Word struct {
Word string `json:"word"`
Start int `json:"start"`
Stop int `json:"stop"`
}
// FilterText removes configured artifact substrings and normalizes whitespace.
func FilterText(text string, patterns []string) string {
for _, p := range patterns {
p = strings.TrimSpace(p)
if p == "" {
continue
}
text = strings.ReplaceAll(text, p, " ")
}
return strings.TrimSpace(spaceCollapse.ReplaceAllString(text, " "))
}
// FilterWords drops tokens that match any garbage pattern.
func FilterWords(words []Word, patterns []string) []Word {
if len(words) == 0 {
return words
}
out := make([]Word, 0, len(words))
for _, w := range words {
if matchesGarbage(w.Word, patterns) {
continue
}
out = append(out, w)
}
return out
}
func matchesGarbage(word string, patterns []string) bool {
word = strings.TrimSpace(word)
for _, p := range patterns {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if word == p || strings.Contains(word, p) {
return true
}
}
return false
}

30
garbage/filter_test.go Normal file
View File

@ -0,0 +1,30 @@
package garbage
import "testing"
func TestFilterText(t *testing.T) {
in := "Привет *выбая* мир *выбая*"
got := FilterText(in, []string{"*выбая*"})
want := "Привет мир"
if got != want {
t.Fatalf("got %q want %q", got, want)
}
}
func TestFilterWords(t *testing.T) {
words := []Word{
{Word: "Что", Start: 0, Stop: 100},
{Word: "*выбая*", Start: 100, Stop: 200},
{Word: "мир", Start: 200, Stop: 300},
}
got := FilterWords(words, []string{"*выбая*"})
if len(got) != 2 || got[1].Word != "мир" {
t.Fatalf("got %+v", got)
}
}
func TestFilterText_emptyPatterns(t *testing.T) {
if got := FilterText("a b", nil); got != "a b" {
t.Fatalf("got %q", got)
}
}

48
go.mod Normal file
View File

@ -0,0 +1,48 @@
module go-whisper-api
go 1.26
require (
github.com/Eyevinn/mp4ff v0.51.0
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230606002726-57543c169e27
github.com/go-audio/audio v1.0.0
github.com/go-audio/wav v1.1.0
github.com/google/uuid v1.6.0
github.com/gopxl/beep v1.4.1
github.com/joho/godotenv v1.5.1
github.com/k2-fsa/sherpa-onnx-go v1.13.2
github.com/mattn/go-isatty v0.0.20
github.com/olivier-w/climp-aac-decoder v0.1.0
github.com/pion/opus v0.0.0-20260601214817-71d58474cec8
github.com/rs/zerolog v1.35.0
github.com/skrashevich/go-aac v0.1.0
github.com/urfave/cli/v2 v2.27.7
github.com/yalue/onnxruntime_go v1.24.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/go-audio/riff v1.0.0 // indirect
github.com/hajimehoshi/go-mp3 v0.3.4 // indirect
github.com/icza/bitio v1.1.0 // indirect
github.com/jfreymuth/oggvorbis v1.0.5 // indirect
github.com/jfreymuth/vorbis v1.0.2 // indirect
github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2 // indirect
github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2 // indirect
github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mewkiz/flac v1.0.8 // indirect
github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect
golang.org/x/sys v0.45.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
)
replace github.com/ggerganov/whisper.cpp/bindings/go => ./third_party/whisper.cpp/bindings/go

127
go.sum Normal file
View File

@ -0,0 +1,127 @@
github.com/Eyevinn/mp4ff v0.51.0 h1:ZYdHFXEcB3kJkCeCHMHl/tbCm64FJsD2XOU0Sj+ME2M=
github.com/Eyevinn/mp4ff v0.51.0/go.mod h1:hJNUUqOBryLAzUW9wpCJyw2HaI+TCd2rUPhafoS5lgg=
github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo=
github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/d4l3k/messagediff v1.2.2-0.20190829033028-7e0a312ae40b/go.mod h1:Oozbb1TVXFac9FtSIxHBMnBCq2qeH/2KkEQxENCrlLo=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/ebitengine/oto/v3 v3.1.0 h1:9tChG6rizyeR2w3vsygTTTVVJ9QMMyu00m2yBOCch6U=
github.com/ebitengine/oto/v3 v3.1.0/go.mod h1:IK1QTnlfZK2GIB6ziyECm433hAdTaPpOsGMLhEyEGTg=
github.com/ebitengine/purego v0.7.1 h1:6/55d26lG3o9VCZX8lping+bZcmShseiqlh2bnUDiPA=
github.com/ebitengine/purego v0.7.1/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ=
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg=
github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gopxl/beep v1.4.1 h1:WqNs9RsDAhG9M3khMyc1FaVY50dTdxG/6S6a3qsUHqE=
github.com/gopxl/beep v1.4.1/go.mod h1:A1dmiUkuY8kxsvcNJNUBIEcchmiP6eUyCHSxpXl0YO0=
github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68=
github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo=
github.com/hajimehoshi/oto/v2 v2.3.1/go.mod h1:seWLbgHH7AyUMYKfKYT9pg7PhUu9/SisyJvNTT+ASQo=
github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0=
github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A=
github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k=
github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA=
github.com/jfreymuth/oggvorbis v1.0.5 h1:u+Ck+R0eLSRhgq8WTmffYnrVtSztJcYrl588DM4e3kQ=
github.com/jfreymuth/oggvorbis v1.0.5/go.mod h1:1U4pqWmghcoVsCJJ4fRBKv9peUJMBHixthRlBeD6uII=
github.com/jfreymuth/vorbis v1.0.2 h1:m1xH6+ZI4thH927pgKD8JOH4eaGRm18rEE9/0WKjvNE=
github.com/jfreymuth/vorbis v1.0.2/go.mod h1:DoftRo4AznKnShRl1GxiTFCseHr4zR9BN3TWXyuzrqQ=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/jszwec/csvutil v1.5.1/go.mod h1:Rpu7Uu9giO9subDyMCIQfHVDuLrcaC36UA4YcJjGBkg=
github.com/k2-fsa/sherpa-onnx-go v1.13.2 h1:1orlwwdJcVk37OoV1gi1L6fIdeZHKJcLQjm6S5h/WWs=
github.com/k2-fsa/sherpa-onnx-go v1.13.2/go.mod h1:NyA3OsF/hmcvk4FpbXpBO9Rl/I3dlnrOB9Ny+TnhvIc=
github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2 h1:w/RHaU9liD/ovA9Q5iI2lCoVjVEd4M8+uTrf54rjQPY=
github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2/go.mod h1:NXEH2rsBgTdqY59YpPq6CtSBlBAXy/8a9FmpLERU97I=
github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2 h1:oIIdSfU3NEMT7oq7yxoH7Rk37kekkbmr/b+7e4er1WE=
github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2/go.mod h1:ZOhUAXC62Unj0ZNfu6zxSFKcW96aXf7P3BsqiUyOBbE=
github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2 h1:tkfwXnmJRsChv59ZzLsycphwpvJpSkyd29aMG9kUoEE=
github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2/go.mod h1:5AX7TU8+P/gInjglY1ijtWUM2b8iyR0QX4yEngzMe64=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mewkiz/flac v1.0.8 h1:cophRjvafteDGmqsfXRK28YAX6l8wy19QxTHruEEg1s=
github.com/mewkiz/flac v1.0.8/go.mod h1:l7dt5uFY724eKVkHQtAJAQSkhpC3helU3RDxN0ESAqo=
github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14 h1:tnAPMExbRERsyEYkmR1YjhTgDM0iqyiBYf8ojRXxdbA=
github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14/go.mod h1:QYCFBiH5q6XTHEbWhR0uhR3M9qNPoD2CSQzr0g75kE4=
github.com/olivier-w/climp-aac-decoder v0.1.0 h1:jdwYeBlnfQlnm6KA/fG20s22HrlBlVVaaWQTj/tnb3A=
github.com/olivier-w/climp-aac-decoder v0.1.0/go.mod h1:Hpqi8cI4NeN4JEcgKxiAK0zEwx+hO+lWdg9Q3ruHouI=
github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e h1:s2RNOM/IGdY0Y6qfTeUKhDawdHDpK9RGBdx80qN4Ttw=
github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBdnFKj15wFbf94Rwfq4m30eAcyY9V/IyKAGQFtqkW0=
github.com/pion/opus v0.0.0-20260601214817-71d58474cec8 h1:THWkLaX+FzbTbHIo1irDpbBMXbNrqV6q2Xs+00p4d+0=
github.com/pion/opus v0.0.0-20260601214817-71d58474cec8/go.mod h1:t5Xog2n682JnawoykACE6nKVmupFvmJvkpM7x6bTv6g=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI=
github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/skrashevich/go-aac v0.1.0 h1:7oHNj1ADmgfjAHvi3wAIFbmbCpQBrcjZEVTLlRtAS1A=
github.com/skrashevich/go-aac v0.1.0/go.mod h1:Mj7r//4LDL4FC0ezORj+MnmQ+nDEkJhTOy2aMC8dzww=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU=
github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4=
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg=
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
github.com/yalue/onnxruntime_go v1.24.0 h1:IdgJLxxyotlsUTmL1UnHZgBzXJGgY51LZ4vQ5rZeOXU=
github.com/yalue/onnxruntime_go v1.24.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

251
main.go Normal file
View File

@ -0,0 +1,251 @@
package main
import (
"os"
"runtime"
"strconv"
"time"
"go-whisper-api/api"
"go-whisper-api/config"
_ "github.com/joho/godotenv/autoload"
"github.com/mattn/go-isatty"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/urfave/cli/v2"
)
var (
Version string
)
func main() {
isTerm := isatty.IsTerminal(os.Stdout.Fd())
zerolog.SetGlobalLevel(zerolog.InfoLevel)
log.Logger = log.Output(
zerolog.ConsoleWriter{
Out: os.Stderr,
NoColor: !isTerm,
},
)
zerolog.CallerMarshalFunc = func(pc uintptr, file string, line int) string {
short := file
for i := len(file) - 1; i > 0; i-- {
if file[i] == '/' {
short = file[i+1:]
break
}
}
file = short
return file + ":" + strconv.Itoa(line)
}
app := cli.NewApp()
app.Name = "go-whisper-api"
app.Usage = "HTTP API for speech-to-text (SPR + OpenAI-compatible)."
app.Copyright = "Copyright (c) " + strconv.Itoa(time.Now().Year()) + " Bo-Yi Wu"
app.Authors = []*cli.Author{
{
Name: "Bo-Yi Wu",
Email: "appleboy.tw@gmail.com",
},
}
app.Version = Version
app.Commands = []*cli.Command{
{
Name: "serve",
Aliases: []string{"s", "api"},
Usage: "start HTTP API server",
Flags: serveFlags(),
Action: runServe,
},
}
app.Flags = append(append([]cli.Flag{configFlag()}, punctuationFlags()...), serveFlags()...)
app.Action = runServe
if err := app.Run(os.Args); err != nil {
log.Fatal().Err(err).Msg("can't run app")
}
}
func configFlag() cli.Flag {
return &cli.StringFlag{
Name: "config",
Usage: "path to YAML config file (default: config.yaml if present)",
EnvVars: []string{"CONFIG_PATH", "GO_WHISPER_CONFIG"},
}
}
func punctuationFlags() []cli.Flag {
return []cli.Flag{
&cli.BoolFlag{
Name: "punctuation-enabled",
Usage: "master switch: enable or disable punctuation (YAML: punctuation.enabled)",
EnvVars: []string{"PUNCTUATION_ENABLED"},
},
&cli.BoolFlag{
Name: "punctuation-default-on",
Usage: "apply punctuation by default when query/flags omitted (YAML: punctuation.default_on)",
},
&cli.StringFlag{
Name: "punctuation-engine",
Usage: "punctuation engine: off, heuristic, sherpa, sherpa-online, http, xlm",
EnvVars: []string{"PUNCTUATION_ENGINE"},
},
}
}
func vadFlags() []cli.Flag {
return []cli.Flag{
&cli.BoolFlag{
Name: "vad",
Usage: "enable voice activity detection (Silero VAD)",
EnvVars: []string{"API_VAD"},
},
&cli.StringFlag{
Name: "vad-model",
Usage: "path to ggml Silero VAD model (e.g. models/ggml-silero-v6.2.0.bin)",
EnvVars: []string{"API_VAD_MODEL"},
},
&cli.Float64Flag{
Name: "vad-threshold",
Usage: "VAD speech threshold (0.0-1.0)",
EnvVars: []string{"API_VAD_THRESHOLD"},
},
&cli.IntFlag{
Name: "vad-min-speech-ms",
Usage: "minimum speech duration in ms",
EnvVars: []string{"API_VAD_MIN_SPEECH_MS"},
},
&cli.IntFlag{
Name: "vad-min-silence-ms",
Usage: "minimum silence between segments in ms",
EnvVars: []string{"API_VAD_MIN_SILENCE_MS"},
},
&cli.Float64Flag{
Name: "vad-max-speech-sec",
Usage: "maximum speech segment length in seconds (0 = unlimited)",
EnvVars: []string{"API_VAD_MAX_SPEECH_SEC"},
},
&cli.IntFlag{
Name: "vad-speech-pad-ms",
Usage: "padding around speech segments in ms",
EnvVars: []string{"API_VAD_SPEECH_PAD_MS"},
},
&cli.Float64Flag{
Name: "vad-samples-overlap",
Usage: "overlap between VAD segments in seconds",
EnvVars: []string{"API_VAD_SAMPLES_OVERLAP"},
},
}
}
func serveFlags() []cli.Flag {
flags := []cli.Flag{
&cli.StringFlag{
Name: "addr",
Usage: "HTTP listen address",
EnvVars: []string{"API_ADDR"},
Value: ":8080",
},
&cli.StringFlag{
Name: "models-dir",
Usage: "directory with ggml *.bin whisper models",
EnvVars: []string{"API_MODELS_DIR"},
Value: "./models",
},
&cli.StringFlag{
Name: "cache-dir",
Usage: "directory for async task cache (waiting/ready)",
EnvVars: []string{"API_CACHE_DIR"},
Value: "./cache",
},
&cli.StringFlag{
Name: "language",
Usage: "default language for speech recognition",
EnvVars: []string{"API_LANGUAGE"},
Value: "auto",
},
&cli.UintFlag{
Name: "threads",
Usage: "number of threads for whisper",
EnvVars: []string{"API_THREADS"},
Value: uint(runtime.NumCPU()),
},
&cli.BoolFlag{
Name: "debug",
Usage: "enable debug mode",
EnvVars: []string{"API_DEBUG"},
},
&cli.BoolFlag{
Name: "speedup",
Usage: "speed up audio by x2",
EnvVars: []string{"API_SPEEDUP"},
},
&cli.BoolFlag{
Name: "translate",
Usage: "translate to english",
EnvVars: []string{"API_TRANSLATE"},
},
&cli.StringFlag{
Name: "prompt",
Usage: "initial prompt",
EnvVars: []string{"API_PROMPT"},
},
&cli.UintFlag{
Name: "max-context",
Usage: "maximum text context tokens",
EnvVars: []string{"API_MAX_CONTEXT"},
Value: 32,
},
&cli.UintFlag{
Name: "beam-size",
Usage: "beam size for beam search",
EnvVars: []string{"API_BEAM_SIZE"},
Value: 5,
},
&cli.Float64Flag{
Name: "entropy-thold",
Usage: "entropy threshold",
EnvVars: []string{"API_ENTROPY_THOLD"},
Value: 2.4,
},
&cli.BoolFlag{
Name: "default-punctuation",
Usage: "enable punctuation on STT by default",
EnvVars: []string{"API_DEFAULT_PUNCTUATION"},
},
}
return append(flags, vadFlags()...)
}
func runServe(c *cli.Context) error {
apiCfg, err := config.APIFromCLI(c)
if err != nil {
return err
}
if apiCfg.Debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
log.Logger = log.With().Caller().Logger()
}
vad := apiCfg.VAD.WithDefaults()
if vad.Enabled {
vad.Model = vad.ResolveModelPath(apiCfg.ModelsDir)
if err := vad.Validate(); err != nil {
return err
}
apiCfg.VAD = vad
}
tc, err := config.TranscodeFromCLI(c)
if err != nil {
return err
}
pc, err := config.PunctuationFromCLI(c)
if err != nil {
return err
}
dc, err := config.DiarizationFromCLI(c)
if err != nil {
return err
}
return api.Run(c.Context, apiCfg, tc, pc, dc)
}

0
models/.gitkeep Normal file
View File

View File

View File

View File

0
models/vad/.gitkeep Normal file
View File

86
punctuation/heuristic.go Normal file
View File

@ -0,0 +1,86 @@
package punctuation
import (
"context"
"regexp"
"strings"
"unicode"
"unicode/utf8"
)
type Heuristic struct{}
func (Heuristic) Active() bool {
return true
}
func (Heuristic) Restore(ctx context.Context, text, language string) (string, error) {
_ = ctx
text = strings.TrimSpace(text)
if text == "" {
return text, nil
}
text = normalizeSpaces(text)
text = capitalizeFirst(text)
lang := strings.ToLower(strings.TrimSpace(language))
if lang == "ru" || lang == "rus" || lang == "russian" || lang == "auto" {
text = heuristicRU(text)
} else {
text = heuristicEN(text)
}
return ensureTerminalPunct(text), nil
}
func normalizeSpaces(s string) string {
return strings.Join(strings.Fields(s), " ")
}
func capitalizeFirst(s string) string {
r, size := utf8.DecodeRuneInString(s)
if r == utf8.RuneError {
return s
}
return string(unicode.ToUpper(r)) + s[size:]
}
var (
reQuestionRU = regexp.MustCompile(`(?i)(^|.*\s)(как|что|где|когда|почему|зачем|кто|чей|какой|какая|какое|какие|сколько|зачем|откуда|куда|ли)(\s+[^.?!]+)$`)
reQuestionEN = regexp.MustCompile(`(?i)^(who|what|when|where|why|how|which|whose|whom|is|are|am|was|were|do|does|did|can|could|would|will|shall|should)\b`)
)
func heuristicRU(s string) string {
if reQuestionRU.MatchString(s) && !strings.HasSuffix(s, "?") {
return s + "?"
}
if !hasTerminalPunct(s) && len(strings.Fields(s)) <= 24 {
return s + "."
}
return s
}
func heuristicEN(s string) string {
lower := strings.ToLower(s)
if reQuestionEN.MatchString(lower) && !strings.HasSuffix(s, "?") {
return s + "?"
}
if !hasTerminalPunct(s) && len(strings.Fields(s)) <= 24 {
return s + "."
}
return s
}
func hasTerminalPunct(s string) bool {
s = strings.TrimSpace(s)
if s == "" {
return false
}
r, _ := utf8.DecodeLastRuneInString(s)
return r == '.' || r == '?' || r == '!' || r == '…'
}
func ensureTerminalPunct(s string) string {
if hasTerminalPunct(s) {
return s
}
return s + "."
}

View File

@ -0,0 +1,32 @@
package punctuation
import (
"context"
"testing"
)
func TestHeuristicRU_question(t *testing.T) {
h := Heuristic{}
out, err := h.Restore(context.Background(), "как дела", "ru")
if err != nil {
t.Fatal(err)
}
if !stringsHasSuffix(out, "?") {
t.Fatalf("expected question mark, got %q", out)
}
}
func TestHeuristicEN_period(t *testing.T) {
h := Heuristic{}
out, err := h.Restore(context.Background(), "hello world", "en")
if err != nil {
t.Fatal(err)
}
if !stringsHasSuffix(out, ".") {
t.Fatalf("expected period, got %q", out)
}
}
func stringsHasSuffix(s, suffix string) bool {
return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix
}

View File

@ -0,0 +1,92 @@
//go:build xlm
package spwrap
/*
#cgo CXXFLAGS: -std=c++17
#cgo LDFLAGS: -lsentencepiece
#include <stdlib.h>
#include "sp_wrap.h"
*/
import "C"
import (
"fmt"
"unsafe"
)
type Processor struct {
p *C.SPProcessor
}
func Load(path string) (*Processor, error) {
cpath := C.CString(path)
defer C.free(unsafe.Pointer(cpath))
var errMsg *C.char
p := C.sp_load(cpath, &errMsg)
if p == nil {
if errMsg != nil {
defer C.free(unsafe.Pointer(errMsg))
return nil, fmt.Errorf("sentencepiece: %s", C.GoString(errMsg))
}
return nil, fmt.Errorf("sentencepiece: failed to load %s", path)
}
return &Processor{p: p}, nil
}
func (proc *Processor) Close() {
if proc.p != nil {
C.sp_free(proc.p)
proc.p = nil
}
}
func (proc *Processor) BOSID() int {
return int(C.sp_bos_id(proc.p))
}
func (proc *Processor) EOSID() int {
return int(C.sp_eos_id(proc.p))
}
func (proc *Processor) PadID() int {
return int(C.sp_pad_id(proc.p))
}
func (proc *Processor) EncodeAsIDs(text string) ([]int, error) {
ctext := C.CString(text)
defer C.free(unsafe.Pointer(ctext))
var ids *C.int
var n C.int
var errMsg *C.char
if C.sp_encode(proc.p, ctext, &ids, &n, &errMsg) == 0 {
if errMsg != nil {
defer C.free(unsafe.Pointer(errMsg))
return nil, fmt.Errorf("sentencepiece encode: %s", C.GoString(errMsg))
}
return nil, fmt.Errorf("sentencepiece encode failed")
}
if ids == nil || n == 0 {
return nil, nil
}
defer C.free(unsafe.Pointer(ids))
out := make([]int, int(n))
slice := unsafe.Slice(ids, int(n))
for i := range out {
out[i] = int(slice[i])
}
return out, nil
}
func (proc *Processor) IDToPiece(id int) (string, error) {
var errMsg *C.char
piece := C.sp_id_to_piece(proc.p, C.int(id), &errMsg)
if piece == nil {
if errMsg != nil {
defer C.free(unsafe.Pointer(errMsg))
return "", fmt.Errorf("sentencepiece id to piece: %s", C.GoString(errMsg))
}
return "", fmt.Errorf("sentencepiece id to piece failed")
}
defer C.free(unsafe.Pointer(piece))
return C.GoString(piece), nil
}

View File

@ -0,0 +1,88 @@
#include "sp_wrap.h"
#include <sentencepiece_processor.h>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
struct SPProcessor {
sentencepiece::SentencePieceProcessor proc;
};
static char *copy_err(const std::string &msg) {
char *out = static_cast<char *>(std::malloc(msg.size() + 1));
if (out != nullptr) {
std::memcpy(out, msg.c_str(), msg.size() + 1);
}
return out;
}
SPProcessor *sp_load(const char *path, char **err) {
if (err != nullptr) {
*err = nullptr;
}
auto *p = new SPProcessor();
const auto status = p->proc.Load(path);
if (!status.ok()) {
if (err != nullptr) {
*err = copy_err(status.ToString());
}
delete p;
return nullptr;
}
return p;
}
void sp_free(SPProcessor *p) { delete p; }
int sp_bos_id(const SPProcessor *p) { return p->proc.bos_id(); }
int sp_eos_id(const SPProcessor *p) { return p->proc.eos_id(); }
int sp_pad_id(const SPProcessor *p) { return p->proc.pad_id(); }
int sp_encode(const SPProcessor *p, const char *text, int **out_ids, int *out_len, char **err) {
if (err != nullptr) {
*err = nullptr;
}
if (out_ids != nullptr) {
*out_ids = nullptr;
}
if (out_len != nullptr) {
*out_len = 0;
}
std::vector<int> ids;
const auto status = p->proc.Encode(text, &ids);
if (!status.ok()) {
if (err != nullptr) {
*err = copy_err(status.ToString());
}
return 0;
}
if (ids.empty()) {
return 1;
}
int *buf = static_cast<int *>(std::malloc(sizeof(int) * ids.size()));
if (buf == nullptr) {
if (err != nullptr) {
*err = copy_err("malloc failed");
}
return 0;
}
for (size_t i = 0; i < ids.size(); i++) {
buf[i] = ids[i];
}
*out_ids = buf;
*out_len = static_cast<int>(ids.size());
return 1;
}
char *sp_id_to_piece(const SPProcessor *p, int id, char **err) {
if (err != nullptr) {
*err = nullptr;
}
const std::string piece = p->proc.IdToPiece(id);
return copy_err(piece);
}

View File

@ -0,0 +1,21 @@
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
typedef struct SPProcessor SPProcessor;
SPProcessor *sp_load(const char *path, char **err);
void sp_free(SPProcessor *p);
int sp_bos_id(const SPProcessor *p);
int sp_eos_id(const SPProcessor *p);
int sp_pad_id(const SPProcessor *p);
int sp_encode(const SPProcessor *p, const char *text, int **out_ids, int *out_len, char **err);
char *sp_id_to_piece(const SPProcessor *p, int id, char **err);
#ifdef __cplusplus
}
#endif

177
punctuation/ort_env.go Normal file
View File

@ -0,0 +1,177 @@
//go:build xlm
package punctuation
import (
"os"
"path/filepath"
"runtime"
"strings"
"sync"
ort "github.com/yalue/onnxruntime_go"
)
var (
ortOnce sync.Once
ortErr error
)
func ensureORT() error {
ortOnce.Do(func() {
if p := resolveONNXRuntimeLib(); p != "" {
ort.SetSharedLibraryPath(p)
}
ortErr = ort.InitializeEnvironment()
})
return ortErr
}
func resolveONNXRuntimeLib() string {
if p := strings.TrimSpace(os.Getenv("ONNXRUNTIME_SHARED_LIBRARY_PATH")); p != "" {
return p
}
for _, p := range onnxRuntimeCandidates() {
if st, err := os.Stat(p); err == nil && !st.IsDir() {
return p
}
}
return ""
}
func onnxRuntimeCandidates() []string {
arch := sherpaLibArch()
ver := sherpaLinuxModuleVersion()
var out []string
for _, root := range goModCacheRoots() {
if ver != "" {
out = append(out, filepath.Join(root,
"github.com/k2-fsa/sherpa-onnx-go-linux@"+ver,
"lib", arch, "libonnxruntime.so"))
}
}
if exe, err := os.Executable(); err == nil {
exeDir := filepath.Dir(exe)
out = append(out,
filepath.Join(exeDir, "libonnxruntime.so"),
filepath.Join(exeDir, "lib", "libonnxruntime.so"),
filepath.Join(exeDir, "..", "lib", "libonnxruntime.so"),
)
if modRoot := findModuleRoot(exeDir); modRoot != "" {
out = append(out, filepath.Join(modRoot, "lib", "libonnxruntime.so"))
}
}
return out
}
func goModCacheRoots() []string {
var roots []string
seen := map[string]struct{}{}
add := func(p string) {
p = strings.TrimSpace(p)
if p == "" {
return
}
if _, ok := seen[p]; ok {
return
}
seen[p] = struct{}{}
roots = append(roots, p)
}
add(os.Getenv("GOMODCACHE"))
if gopath := os.Getenv("GOPATH"); gopath != "" {
for _, gp := range filepath.SplitList(gopath) {
add(filepath.Join(gp, "pkg", "mod"))
}
}
if home, err := os.UserHomeDir(); err == nil {
add(filepath.Join(home, "go", "pkg", "mod"))
}
return roots
}
func findModuleRoot(start string) string {
dir := start
for i := 0; i < 8; i++ {
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
return dir
}
parent := filepath.Dir(dir)
if parent == dir {
break
}
dir = parent
}
if cwd, err := os.Getwd(); err == nil {
return findModuleRootFrom(cwd)
}
return ""
}
func findModuleRootFrom(dir string) string {
for i := 0; i < 8; i++ {
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
return dir
}
parent := filepath.Dir(dir)
if parent == dir {
break
}
dir = parent
}
return ""
}
func sherpaLinuxModuleVersion() string {
for _, dir := range []string{findModuleRootFrom(mustCwd()), ""} {
if dir == "" {
continue
}
if v := readSherpaVersion(filepath.Join(dir, "go.mod")); v != "" {
return v
}
}
if exe, err := os.Executable(); err == nil {
if root := findModuleRoot(filepath.Dir(exe)); root != "" {
return readSherpaVersion(filepath.Join(root, "go.mod"))
}
}
return ""
}
func readSherpaVersion(goModPath string) string {
data, err := os.ReadFile(goModPath)
if err != nil {
return ""
}
for _, line := range strings.Split(string(data), "\n") {
line = strings.TrimSpace(line)
if !strings.Contains(line, "github.com/k2-fsa/sherpa-onnx-go-linux") {
continue
}
fields := strings.Fields(line)
if len(fields) >= 2 {
return fields[1] // e.g. v1.13.2 — must match pkg/mod path @v1.13.2
}
}
return ""
}
func sherpaLibArch() string {
switch runtime.GOARCH {
case "arm64":
return "aarch64-unknown-linux-gnu"
case "arm":
return "arm-unknown-linux-gnueabihf"
default:
return "x86_64-unknown-linux-gnu"
}
}
func mustCwd() string {
cwd, err := os.Getwd()
if err != nil {
return ""
}
return cwd
}

154
punctuation/punctuation.go Normal file
View File

@ -0,0 +1,154 @@
package punctuation
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"go-whisper-api/config"
)
type Restorer interface {
Active() bool
Restore(ctx context.Context, text, language string) (string, error)
}
type Closer interface {
Close()
}
func New(cfg config.Punctuation) (Restorer, error) {
cfg = cfg.WithDefaults()
if !cfg.Active() {
return Nop{}, nil
}
engine := strings.ToLower(strings.TrimSpace(cfg.Engine))
switch engine {
case "heuristic":
return Heuristic{}, nil
case "sherpa", "sherpa-offline", "offline":
cfg.Engine = "sherpa"
return newSherpaRestorer(cfg)
case "sherpa-online", "online":
cfg.Engine = "sherpa-online"
return newSherpaRestorer(cfg)
case "xlm", "xlm-roberta", "roberta":
return newXLM(cfg)
case "http":
if cfg.HTTPURL == "" {
return nil, fmt.Errorf("punctuation.http_url is required when engine=http")
}
return HTTP{cfg: cfg}, nil
default:
return nil, fmt.Errorf("unsupported punctuation engine %q (use: off, heuristic, xlm, sherpa, sherpa-online, http)", engine)
}
}
type Nop struct{}
func (Nop) Active() bool {
return false
}
func (Nop) Restore(ctx context.Context, text, language string) (string, error) {
return text, nil
}
type HTTP struct {
cfg config.Punctuation
client *http.Client
}
func (h HTTP) Active() bool { return true }
func (h HTTP) Restore(ctx context.Context, text, language string) (string, error) {
body, err := json.Marshal(map[string]string{
"text": text,
"language": language,
})
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.cfg.HTTPURL, bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
client := h.client
if client == nil {
client = &http.Client{Timeout: h.cfg.Timeout()}
}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
if resp.StatusCode >= 300 {
return "", fmt.Errorf("punctuation http %s: %s", resp.Status, strings.TrimSpace(string(raw)))
}
var out struct {
Text string `json:"text"`
}
if err := json.Unmarshal(raw, &out); err != nil {
return strings.TrimSpace(string(raw)), nil
}
if out.Text == "" {
return text, nil
}
return out.Text, nil
}
func Apply(ctx context.Context, r Restorer, enabled bool, text, language string) (string, error) {
if !enabled || r == nil || !r.Active() {
return text, nil
}
text = strings.TrimSpace(text)
if text == "" {
return text, nil
}
return r.Restore(ctx, text, language)
}
func Close(r Restorer) {
if c, ok := r.(Closer); ok {
c.Close()
}
}
func AutoSelect(cfg config.Punctuation) (Restorer, error) {
cfg = cfg.WithDefaults()
if !cfg.Active() {
return Nop{}, nil
}
engine := strings.ToLower(cfg.Engine)
if engine == "heuristic" {
return Heuristic{}, nil
}
if engine == "http" {
return New(cfg)
}
if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" {
return newXLM(cfg)
}
if _, err := os.Stat(cfg.ModelPath()); err == nil {
cfg.Engine = engine
if engine == "sherpa" || engine == "sherpa-offline" || engine == "offline" || engine == "" {
cfg.Engine = "sherpa"
}
return newSherpaRestorer(cfg)
}
if engine == "sherpa" || engine == "sherpa-online" || engine == "online" {
return nil, fmt.Errorf("punctuation model not found at %s (run: make download-punctuation-model)", cfg.ModelPath())
}
return Heuristic{}, nil
}

View File

@ -0,0 +1,40 @@
package punctuation
import (
"context"
"testing"
"go-whisper-api/config"
)
func TestNop(t *testing.T) {
r, err := New(config.Punctuation{Engine: "off"})
if err != nil {
t.Fatal(err)
}
out, err := Apply(context.Background(), r, true, "hello world", "en")
if err != nil {
t.Fatal(err)
}
if out != "hello world" {
t.Fatalf("got %q", out)
}
}
func TestApply_disabled(t *testing.T) {
r := Heuristic{}
out, err := Apply(context.Background(), r, false, "hello", "en")
if err != nil || out != "hello" {
t.Fatalf("got %q err=%v", out, err)
}
}
func TestNew_heuristic(t *testing.T) {
r, err := New(config.Punctuation{Enabled: true, Engine: "heuristic"})
if err != nil {
t.Fatal(err)
}
if !r.Active() {
t.Fatal("expected active")
}
}

107
punctuation/sherpa.go Normal file
View File

@ -0,0 +1,107 @@
//go:build sherpa
package punctuation
import (
"context"
"fmt"
"os"
"strings"
"sync"
"go-whisper-api/config"
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
)
type Sherpa struct {
offline *sherpa.OfflinePunctuation
online *sherpa.OnlinePunctuation
mu sync.Mutex
}
func newSherpaRestorer(cfg config.Punctuation) (Restorer, error) {
return newSherpa(cfg)
}
func newSherpa(cfg config.Punctuation) (*Sherpa, error) {
cfg = cfg.WithDefaults()
engine := strings.ToLower(cfg.Engine)
s := &Sherpa{}
switch engine {
case "sherpa-online", "online":
modelPath := cfg.ModelPath()
vocabPath := cfg.BpeVocabPath()
if _, err := os.Stat(modelPath); err != nil {
return nil, fmt.Errorf("sherpa online punctuation model %q: %w", modelPath, err)
}
if _, err := os.Stat(vocabPath); err != nil {
return nil, fmt.Errorf("sherpa bpe vocab %q: %w", vocabPath, err)
}
conf := sherpa.OnlinePunctuationConfig{}
conf.Model.CnnBilstm = modelPath
conf.Model.BpeVocab = vocabPath
conf.Model.NumThreads = cfg.NumThreads
conf.Model.Provider = "cpu"
s.online = sherpa.NewOnlinePunctuation(&conf)
if s.online == nil {
return nil, fmt.Errorf("failed to create sherpa online punctuation")
}
default:
modelPath := cfg.ModelPath()
if _, err := os.Stat(modelPath); err != nil {
return nil, fmt.Errorf("sherpa offline punctuation model %q: %w (run: make download-punctuation-model)", modelPath, err)
}
conf := sherpa.OfflinePunctuationConfig{}
conf.Model.CtTransformer = modelPath
conf.Model.NumThreads = cfg.NumThreads
conf.Model.Provider = "cpu"
s.offline = sherpa.NewOfflinePunctuation(&conf)
if s.offline == nil {
return nil, fmt.Errorf("failed to create sherpa offline punctuation")
}
}
return s, nil
}
func (s *Sherpa) Active() bool {
return true
}
func (s *Sherpa) Restore(ctx context.Context, text, language string) (string, error) {
_ = ctx
_ = language
text = strings.TrimSpace(text)
if text == "" {
return text, nil
}
s.mu.Lock()
defer s.mu.Unlock()
var out string
switch {
case s.offline != nil:
out = s.offline.AddPunct(text)
case s.online != nil:
out = s.online.AddPunct(text)
default:
return text, fmt.Errorf("sherpa punctuation not initialized")
}
out = strings.TrimSpace(out)
if out == "" {
return text, nil
}
return out, nil
}
func (s *Sherpa) Close() {
s.mu.Lock()
defer s.mu.Unlock()
if s.offline != nil {
sherpa.DeleteOfflinePunc(s.offline)
s.offline = nil
}
if s.online != nil {
sherpa.DeleteOnlinePunctuation(s.online)
s.online = nil
}
}

View File

@ -0,0 +1,13 @@
//go:build !sherpa
package punctuation
import (
"fmt"
"go-whisper-api/config"
)
func newSherpaRestorer(cfg config.Punctuation) (Restorer, error) {
return nil, fmt.Errorf("punctuation engine %q requires build tag sherpa (use: make build)", cfg.Engine)
}

319
punctuation/xlm.go Normal file
View File

@ -0,0 +1,319 @@
//go:build xlm
package punctuation
import (
"context"
"fmt"
"os"
"strings"
"unicode"
"go-whisper-api/config"
"go-whisper-api/punctuation/internal/spwrap"
ort "github.com/yalue/onnxruntime_go"
)
type XLM struct {
cfg config.Punctuation
modelCfg xlmModelConfig
sp *spwrap.Processor
session *ort.DynamicAdvancedSession
inputName string
outputNames []string
joinSBD bool
}
func newXLM(cfg config.Punctuation) (*XLM, error) {
if err := ensureORT(); err != nil {
return nil, fmt.Errorf("onnxruntime: %w (set ONNXRUNTIME_SHARED_LIBRARY_PATH or install sherpa-onnx libs)", err)
}
onnxPath := cfg.ModelPath()
if _, err := os.Stat(onnxPath); err != nil {
return nil, fmt.Errorf("xlm onnx model not found at %s (run: make download-xlm-punctuation-model)", onnxPath)
}
spPath := cfg.SPModelPath()
if _, err := os.Stat(spPath); err != nil {
return nil, fmt.Errorf("xlm sp.model not found at %s (run: make download-xlm-punctuation-model)", spPath)
}
cfgPath := cfg.XLMConfigPath()
modelCfg, err := loadXLMConfig(cfgPath)
if err != nil {
return nil, err
}
sp, err := spwrap.Load(spPath)
if err != nil {
return nil, err
}
inputs, outputs, err := ort.GetInputOutputInfo(onnxPath)
if err != nil {
sp.Close()
return nil, fmt.Errorf("xlm model io: %w", err)
}
if len(inputs) == 0 || len(outputs) < 4 {
sp.Close()
return nil, fmt.Errorf("xlm model: unexpected inputs/outputs")
}
inNames := make([]string, len(inputs))
for i, in := range inputs {
inNames[i] = in.Name
}
outNames := make([]string, len(outputs))
for i, out := range outputs {
outNames[i] = out.Name
}
session, err := ort.NewDynamicAdvancedSession(onnxPath, inNames, outNames, nil)
if err != nil {
sp.Close()
return nil, fmt.Errorf("xlm onnx session: %w", err)
}
return &XLM{
cfg: cfg,
modelCfg: modelCfg,
sp: sp,
session: session,
inputName: inNames[0],
outputNames: outNames,
joinSBD: cfg.XLMJoinSentences(),
}, nil
}
func (x *XLM) Active() bool {
return true
}
func (x *XLM) Close() {
if x.session != nil {
_ = x.session.Destroy()
x.session = nil
}
if x.sp != nil {
x.sp.Close()
x.sp = nil
}
}
func (x *XLM) Restore(ctx context.Context, text, language string) (string, error) {
if err := ctx.Err(); err != nil {
return "", err
}
text = strings.TrimSpace(normalizeXLMSpaces(text))
if text == "" {
return text, nil
}
ids, err := x.sp.EncodeAsIDs(text)
if err != nil {
return "", err
}
full := make([]int, 0, len(ids)+2)
full = append(full, x.sp.BOSID())
full = append(full, ids...)
full = append(full, x.sp.EOSID())
maxLen := x.modelCfg.MaxLength
if maxLen <= 2 {
maxLen = 256
}
if len(full) <= maxLen {
return x.inferIDs(full)
}
var parts []string
content := full[1 : len(full)-1]
step := maxLen - 2
for start := 0; start < len(content); {
end := start + step
if end > len(content) {
end = len(content)
}
chunk := make([]int, 0, end-start+2)
chunk = append(chunk, x.sp.BOSID())
chunk = append(chunk, content[start:end]...)
chunk = append(chunk, x.sp.EOSID())
out, err := x.inferIDs(chunk)
if err != nil {
return "", err
}
if out != "" {
parts = append(parts, out)
}
if end >= len(content) {
break
}
start = end
}
return strings.TrimSpace(strings.Join(parts, " ")), nil
}
func (x *XLM) inferIDs(inputIDs []int) (string, error) {
data := make([]int64, len(inputIDs))
for i, id := range inputIDs {
data[i] = int64(id)
}
inputTensor, err := ort.NewTensor(ort.NewShape(1, int64(len(inputIDs))), data)
if err != nil {
return "", err
}
defer inputTensor.Destroy()
outputs := make([]ort.Value, len(x.outputNames))
if err := x.session.Run([]ort.Value{inputTensor}, outputs); err != nil {
return "", err
}
defer destroyValues(outputs)
pre, err := int64Row(outputs[0], len(inputIDs))
if err != nil {
return "", err
}
post, err := int64Row(outputs[1], len(inputIDs))
if err != nil {
return "", err
}
cap, err := boolMatrix(outputs[2], len(inputIDs))
if err != nil {
return "", err
}
sbd, err := boolRow(outputs[3], len(inputIDs))
if err != nil {
return "", err
}
return decodeXLMSegment(x.sp, x.modelCfg, inputIDs, pre, post, cap, sbd, x.joinSBD)
}
func destroyValues(vals []ort.Value) {
for _, v := range vals {
if v != nil {
_ = v.Destroy()
}
}
}
func int64Row(v ort.Value, wantLen int) ([]int64, error) {
switch t := v.(type) {
case *ort.Tensor[int64]:
d := t.GetData()
if len(d) == wantLen {
return d, nil
}
if len(d) > wantLen {
return d[len(d)-wantLen:], nil
}
return nil, fmt.Errorf("int64 output short: %d < %d", len(d), wantLen)
case *ort.Tensor[int32]:
d := t.GetData()
if len(d) > wantLen {
d = d[len(d)-wantLen:]
}
out := make([]int64, wantLen)
for i := 0; i < wantLen && i < len(d); i++ {
out[i] = int64(d[i])
}
return out, nil
default:
return nil, fmt.Errorf("unexpected int output type %T", v)
}
}
func boolRow(v ort.Value, wantLen int) ([]bool, error) {
switch t := v.(type) {
case *ort.Tensor[bool]:
d := t.GetData()
if len(d) == wantLen {
return d, nil
}
if len(d) > wantLen {
return d[len(d)-wantLen:], nil
}
return nil, fmt.Errorf("bool output short")
case *ort.Tensor[float32]:
d := t.GetData()
out := make([]bool, wantLen)
for i := 0; i < wantLen && i < len(d); i++ {
out[i] = d[i] > 0.5
}
return out, nil
default:
return nil, fmt.Errorf("unexpected bool output type %T", v)
}
}
func boolMatrix(v ort.Value, seqLen int) ([][]bool, error) {
switch t := v.(type) {
case *ort.Tensor[bool]:
shape := t.GetShape()
d := t.GetData()
if len(shape) == 3 {
_, sl, width := shape[0], shape[1], shape[2]
out := make([][]bool, sl)
for i := 0; i < int(sl); i++ {
row := make([]bool, width)
base := int(i) * int(width)
copy(row, d[base:base+int(width)])
out[i] = row
}
return out, nil
}
width := len(d) / seqLen
if width < 1 {
width = 1
}
out := make([][]bool, seqLen)
for i := 0; i < seqLen; i++ {
row := make([]bool, width)
base := i * width
if base+width <= len(d) {
copy(row, d[base:base+width])
}
out[i] = row
}
return out, nil
case *ort.Tensor[float32]:
shape := t.GetShape()
d := t.GetData()
if len(shape) == 3 {
_, sl, width := shape[0], shape[1], shape[2]
out := make([][]bool, sl)
for i := 0; i < int(sl); i++ {
row := make([]bool, width)
base := int(i) * int(width)
for j := 0; j < int(width); j++ {
row[j] = d[base+j] > 0.5
}
out[i] = row
}
return out, nil
}
width := len(d) / seqLen
if width < 1 {
width = 1
}
out := make([][]bool, seqLen)
for i := 0; i < seqLen; i++ {
row := make([]bool, width)
base := i * width
for j := 0; j < width && base+j < len(d); j++ {
row[j] = d[base+j] > 0.5
}
out[i] = row
}
return out, nil
default:
return nil, fmt.Errorf("unexpected cap output type %T", v)
}
}
func normalizeXLMSpaces(s string) string {
var b strings.Builder
prevSpace := false
for _, r := range s {
if unicode.IsSpace(r) {
if !prevSpace {
b.WriteRune(' ')
prevSpace = true
}
continue
}
prevSpace = false
b.WriteRune(r)
}
return b.String()
}

43
punctuation/xlm_config.go Normal file
View File

@ -0,0 +1,43 @@
package punctuation
import (
"fmt"
"os"
"path/filepath"
"gopkg.in/yaml.v3"
)
type xlmModelConfig struct {
Languages []string `yaml:"languages"`
MaxLength int `yaml:"max_length"`
PreLabels []string `yaml:"pre_labels"`
PostLabels []string `yaml:"post_labels"`
NullToken string `yaml:"null_token"`
Acronym string `yaml:"acronym_token"`
}
func loadXLMConfig(path string) (xlmModelConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return xlmModelConfig{}, err
}
var cfg xlmModelConfig
if err := yaml.Unmarshal(data, &cfg); err != nil {
return xlmModelConfig{}, fmt.Errorf("parse xlm config %s: %w", path, err)
}
if cfg.MaxLength <= 0 {
cfg.MaxLength = 256
}
if cfg.NullToken == "" {
cfg.NullToken = "<NULL>"
}
if cfg.Acronym == "" {
cfg.Acronym = "<ACRONYM>"
}
return cfg, nil
}
func defaultXLMConfigPath(modelDir string) string {
return filepath.Join(modelDir, "config.yaml")
}

97
punctuation/xlm_decode.go Normal file
View File

@ -0,0 +1,97 @@
//go:build xlm
package punctuation
import (
"strings"
"go-whisper-api/punctuation/internal/spwrap"
)
func decodeXLMSegment(
sp *spwrap.Processor,
cfg xlmModelConfig,
inputIDs []int,
prePred, postPred []int64,
capPred [][]bool,
sbdPred []bool,
joinSentences bool,
) (string, error) {
var outputTexts []string
current := make([]string, 0, len(inputIDs)*4)
for tokenIdx := 1; tokenIdx < len(inputIDs)-1; tokenIdx++ {
piece, err := sp.IDToPiece(inputIDs[tokenIdx])
if err != nil {
return "", err
}
if strings.HasPrefix(piece, "▁") && len(current) > 0 {
current = append(current, " ")
}
preLabel := labelAt(cfg.PreLabels, prePred, tokenIdx)
postLabel := labelAt(cfg.PostLabels, postPred, tokenIdx)
if preLabel != cfg.NullToken {
current = append(current, preLabel)
}
charStart := 0
if strings.HasPrefix(piece, "▁") {
charStart = 1
}
runes := []rune(piece)
for tokenCharIdx := charStart; tokenCharIdx < len(runes); tokenCharIdx++ {
ch := string(runes[tokenCharIdx])
if capAt(capPred, tokenIdx, tokenCharIdx) {
ch = strings.ToUpper(ch)
}
current = append(current, ch)
if postLabel == cfg.Acronym {
current = append(current, ".")
}
}
if postLabel != cfg.NullToken && postLabel != cfg.Acronym {
current = append(current, postLabel)
}
if sbdAt(sbdPred, tokenIdx) {
outputTexts = append(outputTexts, strings.Join(current, ""))
current = current[:0]
}
}
if len(current) > 0 {
outputTexts = append(outputTexts, strings.Join(current, ""))
}
if len(outputTexts) == 0 {
return "", nil
}
if joinSentences {
return strings.Join(outputTexts, " "), nil
}
return outputTexts[0], nil
}
func labelAt(labels []string, preds []int64, idx int) string {
if idx < 0 || idx >= len(preds) {
return labels[0]
}
pi := int(preds[idx])
if pi < 0 || pi >= len(labels) {
return labels[0]
}
return labels[pi]
}
func capAt(capPred [][]bool, tokenIdx, charIdx int) bool {
if tokenIdx < 0 || tokenIdx >= len(capPred) {
return false
}
row := capPred[tokenIdx]
if charIdx < 0 || charIdx >= len(row) {
return false
}
return row[charIdx]
}
func sbdAt(sbd []bool, tokenIdx int) bool {
if tokenIdx < 0 || tokenIdx >= len(sbd) {
return false
}
return sbd[tokenIdx]
}

13
punctuation/xlm_stub.go Normal file
View File

@ -0,0 +1,13 @@
//go:build !xlm
package punctuation
import (
"fmt"
"go-whisper-api/config"
)
func newXLM(cfg config.Punctuation) (Restorer, error) {
return nil, fmt.Errorf("punctuation engine %q requires build tag xlm (run: make build-xlm)", cfg.Engine)
}

132
transcode/aac_decode.go Normal file
View File

@ -0,0 +1,132 @@
package transcode
import (
"fmt"
"io"
"math"
"os"
"path/filepath"
"strings"
"github.com/olivier-w/climp-aac-decoder/aacfile"
)
func decodeAACPath(path, ext string) ([]float64, int, int, error) {
f, err := os.Open(path)
if err != nil {
return nil, 0, 0, err
}
defer f.Close()
st, err := f.Stat()
if err != nil {
return nil, 0, 0, err
}
// aacfile picks the parser from the *name* extension, not file content (.mp4 → .m4a).
name := aacOpenName(path, ext)
size := st.Size()
r, err := aacfile.Open(f, size, name)
if err != nil && isMP4SampleDeltaError(err) {
return decodeMP4AACRelaxed(f, size)
}
if err != nil {
return nil, 0, 0, err
}
defer r.Close()
sr := r.SampleRate()
ch := r.ChannelCount()
pcm, err := io.ReadAll(r)
if err != nil {
return nil, 0, 0, fmt.Errorf("read aac pcm: %w", err)
}
samples := pcm16LEToFloat(pcm, ch)
if ch > 1 {
samples = interleavedToMono(samples, ch)
ch = 1
}
return samples, sr, ch, nil
}
// aacContainerExt maps file extensions to a container name understood by aacfile.
func aacContainerExt(ext string) string {
switch strings.ToLower(ext) {
case ".mp4", ".m4v", ".mov", ".3gp", ".3g2":
return ".m4a"
case ".aac", ".m4a", ".m4b":
return ext
default:
return ".m4a"
}
}
func aacOpenName(path, ext string) string {
containerExt := aacContainerExt(ext)
if containerExt == "" {
containerExt = ".m4a"
}
base := filepath.Base(path)
if e := strings.ToLower(filepath.Ext(base)); e == containerExt {
return base
}
stem := strings.TrimSuffix(base, filepath.Ext(base))
if stem == "" || stem == base {
stem = "audio"
}
return stem + containerExt
}
func pcm16LEToFloat(pcm []byte, channels int) []float64 {
if channels <= 0 {
channels = 1
}
frameBytes := 2 * channels
nFrames := len(pcm) / frameBytes
out := make([]float64, nFrames*channels)
for i := 0; i < nFrames*channels; i++ {
off := i * 2
if off+1 >= len(pcm) {
break
}
v := int16(pcm[off]) | int16(pcm[off+1])<<8
out[i] = float64(v) / 32768.0
}
return out
}
func interleavedToMono(samples []float64, channels int) []float64 {
if channels <= 1 {
return samples
}
nFrames := len(samples) / channels
out := make([]float64, nFrames)
for i := 0; i < nFrames; i++ {
var sum float64
for c := 0; c < channels; c++ {
sum += samples[i*channels+c]
}
out[i] = sum / float64(channels)
}
return out
}
func resampleLinear(samples []float64, fromRate, toRate int) []float64 {
if fromRate <= 0 || toRate <= 0 || fromRate == toRate || len(samples) == 0 {
return samples
}
ratio := float64(fromRate) / float64(toRate)
outLen := int(math.Round(float64(len(samples)) / ratio))
if outLen < 1 {
outLen = 1
}
out := make([]float64, outLen)
for i := 0; i < outLen; i++ {
src := float64(i) * ratio
j := int(src)
if j >= len(samples)-1 {
out[i] = samples[len(samples)-1]
continue
}
frac := src - float64(j)
out[i] = samples[j]*(1-frac) + samples[j+1]*frac
}
return out
}

View File

@ -0,0 +1,31 @@
package transcode
import "testing"
func TestAacOpenName_mp4(t *testing.T) {
got := aacOpenName("/tmp/cache/input.mp4", ".mp4")
if got != "input.m4a" {
t.Fatalf("got %q want input.m4a", got)
}
}
func TestAacOpenName_noExt(t *testing.T) {
got := aacOpenName("/tmp/input", ".m4a")
if got != "audio.m4a" {
t.Fatalf("got %q", got)
}
}
func TestAacContainerExt(t *testing.T) {
cases := map[string]string{
".mp4": ".m4a",
".mov": ".m4a",
".aac": ".aac",
".m4a": ".m4a",
}
for in, want := range cases {
if got := aacContainerExt(in); got != want {
t.Fatalf("%s: got %q want %q", in, got, want)
}
}
}

127
transcode/decode.go Normal file
View File

@ -0,0 +1,127 @@
package transcode
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/gopxl/beep"
"github.com/gopxl/beep/flac"
"github.com/gopxl/beep/mp3"
beepwav "github.com/gopxl/beep/wav"
)
var probeFormats = []string{
".wav", ".wave", ".mp3", ".flac", ".ogg", ".opus",
".m4a", ".m4b", ".mp4", ".mov", ".m4v", ".3gp", ".3g2", ".aac",
}
func supportedFormatsMessage() string {
return strings.Join(probeFormats, ", ")
}
func resolveFormat(path string) (string, error) {
ext := strings.ToLower(filepath.Ext(path))
if ext != "" {
return ext, nil
}
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
ext = sniffFormat(f)
if ext == "" {
return "", fmt.Errorf("could not detect audio format (supported: %s)", supportedFormatsMessage())
}
return ext, nil
}
func openDecoder(path string) (beep.Streamer, beep.Format, io.Closer, error) {
ext, err := resolveFormat(path)
if err != nil {
return nil, beep.Format{}, nil, err
}
streamer, format, closer, err := decodeByExt(path, ext)
if err == nil {
return streamer, format, closer, nil
}
for _, try := range probeFormats {
if try == ext {
continue
}
streamer, format, closer, tryErr := decodeByExt(path, try)
if tryErr == nil {
return streamer, format, closer, nil
}
}
return nil, beep.Format{}, nil, fmt.Errorf("unsupported audio format %q (supported: %s): %w", ext, supportedFormatsMessage(), err)
}
func decodeByExt(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) {
switch ext {
case ".wav", ".wave":
return decodeBeepFile(path, ext)
case ".mp3":
return decodeBeepFile(path, ext)
case ".flac":
return decodeBeepFile(path, ext)
case ".ogg", ".opus":
return decodeOggFile(path)
case ".m4a", ".m4b", ".mp4", ".mov", ".m4v", ".aac":
return decodeAACAsStreamer(path, ext)
case ".webm":
return nil, beep.Format{}, nil, fmt.Errorf("webm is not supported yet")
default:
return nil, beep.Format{}, nil, fmt.Errorf("unsupported extension %q", ext)
}
}
func decodeBeepFile(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) {
f, err := os.Open(path)
if err != nil {
return nil, beep.Format{}, nil, err
}
var (
streamer beep.StreamSeekCloser
format beep.Format
decErr error
)
switch ext {
case ".wav", ".wave":
streamer, format, decErr = beepwav.Decode(f)
case ".mp3":
streamer, format, decErr = mp3.Decode(f)
case ".flac":
streamer, format, decErr = flac.Decode(f)
default:
f.Close()
return nil, beep.Format{}, nil, fmt.Errorf("internal: beep decode for %q", ext)
}
if decErr != nil {
f.Close()
return nil, beep.Format{}, nil, decErr
}
return streamer, format, f, nil
}
func decodeAACAsStreamer(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) {
samples, sr, ch, err := decodeAACPath(path, ext)
if err != nil {
return nil, beep.Format{}, nil, err
}
if ch <= 0 {
ch = 1
}
return newSamplesStreamer(samples, sr), beep.Format{
SampleRate: beep.SampleRate(sr),
NumChannels: ch,
Precision: 2,
}, noopCloser{}, nil
}
type noopCloser struct{}
func (noopCloser) Close() error { return nil }

119
transcode/engine.go Normal file
View File

@ -0,0 +1,119 @@
package transcode
import (
"context"
"fmt"
"os"
)
// Engine converts input audio to PCM WAV using pure Go decoders (no ffmpeg).
type Engine struct{}
// NewEngine creates a transcoder. The ffmpegPath argument is ignored (kept for config compatibility).
func NewEngine(_ string) *Engine {
return &Engine{}
}
func (e *Engine) Available() error {
return nil
}
func (e *Engine) Transcode(ctx context.Context, src, dst string, opt Options) error {
if err := opt.Validate(); err != nil {
return err
}
spec, err := ResolveFormat(opt.Format)
if err != nil {
return err
}
dst, err = OutputPath(dst, spec.ID)
if err != nil {
return err
}
streamer, format, closer, err := openDecoder(src)
if err != nil {
return err
}
defer closer.Close()
s, format := buildPipeline(streamer, format, opt)
samples, err := drainSamples(ctx, s)
if err != nil {
return err
}
ch := format.NumChannels
if opt.Channels > 0 {
ch = opt.Channels
}
sr := int(format.SampleRate)
if opt.SampleRate > 0 {
sr = opt.SampleRate
}
if err := writePCM16WAV(dst, sr, ch, samples); err != nil {
return fmt.Errorf("write wav: %w", err)
}
return nil
}
func Transcode(ctx context.Context, src, dst string, opt Options) error {
return NewEngine("").Transcode(ctx, src, dst, opt)
}
func ToWhisperWAV(ctx context.Context, src, dst string) error {
return Transcode(ctx, src, dst, WhisperOptions())
}
// SupportedInputFormats lists file extensions decoded without external tools.
func SupportedInputFormats() []string {
return append([]string(nil), probeFormats...)
}
func (e *Engine) Probe(ctx context.Context, path string) (MediaInfo, error) {
_ = ctx
ext, err := resolveFormat(path)
if err != nil {
return MediaInfo{}, err
}
streamer, format, closer, err := openDecoder(path)
if err != nil {
return MediaInfo{}, err
}
defer closer.Close()
info := MediaInfo{
Format: ext,
Streams: []StreamInfo{{
Index: 0,
Codec: extTrim(ext),
Type: "audio",
SampleRate: int(format.SampleRate),
Channels: format.NumChannels,
}},
}
if st, err := os.Stat(path); err == nil {
info.BitRate = st.Size() * 8
}
_ = streamer
return info, nil
}
func extTrim(ext string) string {
if len(ext) > 0 && ext[0] == '.' {
return ext[1:]
}
return ext
}
// MediaInfo describes decoded input (for optional diagnostics).
type MediaInfo struct {
Format string `json:"format"`
Duration float64 `json:"duration_seconds"`
BitRate int64 `json:"bit_rate"`
Streams []StreamInfo `json:"streams"`
}
type StreamInfo struct {
Index int `json:"index"`
Codec string `json:"codec"`
Type string `json:"type"`
SampleRate int `json:"sample_rate,omitempty"`
Channels int `json:"channels,omitempty"`
}

85
transcode/engine_test.go Normal file
View File

@ -0,0 +1,85 @@
package transcode
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/go-audio/wav"
)
func samplePath(t *testing.T, name string) string {
t.Helper()
p := filepath.Join("..", "third_party", "whisper.cpp", "samples", name)
if _, err := os.Stat(p); err != nil {
t.Skip("sample not found:", p)
}
return p
}
func TestToWhisperWAV_mp3(t *testing.T) {
src := samplePath(t, "jfk.mp3")
dst := filepath.Join(t.TempDir(), "out.wav")
if err := ToWhisperWAV(context.Background(), src, dst); err != nil {
t.Fatal(err)
}
assertWhisperWAV(t, dst)
}
func TestToWhisperWAV_wav(t *testing.T) {
src := samplePath(t, "jfk.wav")
dst := filepath.Join(t.TempDir(), "out.wav")
if err := ToWhisperWAV(context.Background(), src, dst); err != nil {
t.Fatal(err)
}
assertWhisperWAV(t, dst)
}
func TestResolveFormat_noExtension_mp3(t *testing.T) {
src := samplePath(t, "jfk.mp3")
data, err := os.ReadFile(src)
if err != nil {
t.Fatal(err)
}
path := filepath.Join(t.TempDir(), "upload")
if err := os.WriteFile(path, data, 0o644); err != nil {
t.Fatal(err)
}
ext, err := resolveFormat(path)
if err != nil || ext != ".mp3" {
t.Fatalf("ext=%q err=%v", ext, err)
}
}
func TestEngine_Available(t *testing.T) {
if err := NewEngine("").Available(); err != nil {
t.Fatal(err)
}
}
func assertWhisperWAV(t *testing.T, path string) {
t.Helper()
f, err := os.Open(path)
if err != nil {
t.Fatal(err)
}
defer f.Close()
dec := wav.NewDecoder(f)
if !dec.IsValidFile() {
t.Fatal("invalid wav")
}
buf, err := dec.FullPCMBuffer()
if err != nil {
t.Fatal(err)
}
if dec.SampleRate != 16000 {
t.Fatalf("sample rate %d", dec.SampleRate)
}
if dec.NumChans != 1 {
t.Fatalf("channels %d", dec.NumChans)
}
if len(buf.Data) == 0 {
t.Fatal("empty audio")
}
}

47
transcode/format.go Normal file
View File

@ -0,0 +1,47 @@
package transcode
import (
"fmt"
"path/filepath"
"strings"
)
const (
FormatWAV = "wav"
FormatPCM = "pcm"
)
type FormatSpec struct {
ID string
Extension string
Codec string
}
var formats = map[string]FormatSpec{
FormatWAV: {ID: FormatWAV, Extension: ".wav", Codec: "pcm_s16le"},
FormatPCM: {ID: FormatPCM, Extension: ".wav", Codec: "pcm_s16le"},
}
func ResolveFormat(name string) (FormatSpec, error) {
name = strings.ToLower(strings.TrimSpace(name))
if name == "" {
name = FormatWAV
}
spec, ok := formats[name]
if !ok {
return FormatSpec{}, fmt.Errorf("unsupported format %q (supported: wav)", name)
}
return spec, nil
}
func OutputPath(dst, format string) (string, error) {
spec, err := ResolveFormat(format)
if err != nil {
return "", err
}
ext := filepath.Ext(dst)
if ext == "" {
return dst + spec.Extension, nil
}
return dst, nil
}

288
transcode/mp4_aac_decode.go Normal file
View File

@ -0,0 +1,288 @@
package transcode
import (
"errors"
"fmt"
"io"
"github.com/Eyevinn/mp4ff/mp4"
"github.com/olivier-w/climp-aac-decoder/aacfile"
aacdec "github.com/skrashevich/go-aac/pkg/decoder"
)
type mp4AACSample struct {
offset int64
size int
}
func isMP4SampleDeltaError(err error) bool {
var uf *aacfile.UnsupportedFeatureError
if !errors.As(err, &uf) {
return false
}
return uf.Feature == "MP4 sample delta" || uf.Feature == "MP4 sample delta layout"
}
// decodeMP4AACRelaxed demuxes MP4/M4A with mp4ff (ignoring stts sample deltas) and
// decodes raw AAC frames with go-aac. Used when climp-aac-decoder rejects stts
// entries whose delta is not exactly 1024 (common in ffmpeg/phone muxers).
func decodeMP4AACRelaxed(r io.ReaderAt, size int64) ([]float64, int, int, error) {
asc, samples, leading, err := demuxMP4AAC(r, size)
if err != nil {
return nil, 0, 0, err
}
dec := aacdec.New()
if err := dec.SetASC(asc); err != nil {
return nil, 0, 0, fmt.Errorf("aac config: %w", err)
}
ch := dec.Config.ChanConfig
if ch < 1 {
return nil, 0, 0, fmt.Errorf("aac config: invalid channel count %d", ch)
}
sr := dec.Config.SampleRate
if sr <= 0 {
return nil, 0, 0, fmt.Errorf("aac config: invalid sample rate %d", sr)
}
maxSize := 0
for _, s := range samples {
if s.size > maxSize {
maxSize = s.size
}
}
buf := make([]byte, maxSize)
var pcm []float32
for i, s := range samples {
if cap(buf) < s.size {
buf = make([]byte, s.size)
}
frame := buf[:s.size]
if _, err := r.ReadAt(frame, s.offset); err != nil {
return nil, 0, 0, fmt.Errorf("read mp4 aac sample %d: %w", i, err)
}
out, err := dec.DecodeFrame(frame)
if err != nil {
return nil, 0, 0, fmt.Errorf("decode mp4 aac sample %d: %w", i, err)
}
pcm = append(pcm, out...)
}
skipSamples := leading * ch
if skipSamples > len(pcm) {
skipSamples = len(pcm)
}
pcm = pcm[skipSamples:]
samplesF64 := make([]float64, len(pcm))
for i, v := range pcm {
samplesF64[i] = float64(v)
}
if ch > 1 {
samplesF64 = float32InterleavedToMono(samplesF64, ch)
ch = 1
}
return samplesF64, sr, ch, nil
}
func float32InterleavedToMono(samples []float64, channels int) []float64 {
if channels <= 1 {
return samples
}
nFrames := len(samples) / channels
out := make([]float64, nFrames)
for i := 0; i < nFrames; i++ {
var sum float64
for c := 0; c < channels; c++ {
sum += samples[i*channels+c]
}
out[i] = sum / float64(channels)
}
return out
}
func demuxMP4AAC(r io.ReaderAt, size int64) (asc []byte, samples []mp4AACSample, leading int, err error) {
file, err := mp4.DecodeFile(io.NewSectionReader(r, 0, size), mp4.WithDecodeMode(mp4.DecModeLazyMdat))
if err != nil {
return nil, nil, 0, fmt.Errorf("mp4 decode: %w", err)
}
if file.IsFragmented() {
return nil, nil, 0, fmt.Errorf("fragmented mp4 is not supported")
}
if file.Moov == nil {
return nil, nil, 0, fmt.Errorf("mp4: missing moov")
}
var audioTracks []*mp4.TrakBox
for _, trak := range file.Moov.Traks {
if trak != nil && trak.Mdia != nil && trak.Mdia.Hdlr != nil && trak.Mdia.Hdlr.HandlerType == "soun" {
audioTracks = append(audioTracks, trak)
}
}
if len(audioTracks) != 1 {
return nil, nil, 0, fmt.Errorf("mp4: expected one audio track, found %d", len(audioTracks))
}
trak := audioTracks[0]
if trak.Mdia == nil || trak.Mdia.Minf == nil || trak.Mdia.Minf.Stbl == nil || trak.Mdia.Minf.Stbl.Stsd == nil {
return nil, nil, 0, fmt.Errorf("mp4: incomplete audio track")
}
stsd := trak.Mdia.Minf.Stbl.Stsd
if len(stsd.Children) != 1 {
return nil, nil, 0, fmt.Errorf("mp4: multiple sample descriptions")
}
if stsd.Enca != nil {
return nil, nil, 0, fmt.Errorf("mp4: encrypted audio")
}
sampleEntry := stsd.Mp4a
if sampleEntry == nil {
return nil, nil, 0, fmt.Errorf("mp4: unsupported audio sample entry %s", stsd.Children[0].Type())
}
if sampleEntry.Sinf != nil {
return nil, nil, 0, fmt.Errorf("mp4: encrypted audio")
}
if sampleEntry.Esds == nil ||
sampleEntry.Esds.DecConfigDescriptor == nil ||
sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo == nil ||
len(sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo.DecConfig) == 0 {
return nil, nil, 0, fmt.Errorf("mp4: missing AudioSpecificConfig")
}
asc = append([]byte(nil), sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo.DecConfig...)
leading, _ = mp4LeadingTrimRelaxed(trak)
samples, err = buildMP4AACSamples(trak, size)
if err != nil {
return nil, nil, 0, err
}
if len(samples) == 0 {
return nil, nil, 0, fmt.Errorf("mp4: no audio samples")
}
return asc, samples, leading, nil
}
func mp4LeadingTrimRelaxed(trak *mp4.TrakBox) (int, error) {
if trak.Edts == nil || len(trak.Edts.Elst) == 0 {
return 0, nil
}
if len(trak.Edts.Elst) != 1 || len(trak.Edts.Elst[0].Entries) != 1 {
return 0, nil
}
entry := trak.Edts.Elst[0].Entries[0]
if entry.MediaRateInteger != 1 || entry.MediaRateFraction != 0 {
return 0, nil
}
if entry.MediaTime < 0 {
return 0, nil
}
return int(entry.MediaTime), nil
}
func buildMP4AACSamples(trak *mp4.TrakBox, size int64) ([]mp4AACSample, error) {
if trak.Mdia == nil || trak.Mdia.Minf == nil || trak.Mdia.Minf.Stbl == nil {
return nil, fmt.Errorf("mp4: incomplete sample table")
}
stbl := trak.Mdia.Minf.Stbl
if stbl.Stsc == nil || stbl.Stsz == nil {
return nil, fmt.Errorf("mp4: incomplete sample table")
}
if stbl.Stco == nil && stbl.Co64 == nil {
return nil, fmt.Errorf("mp4: missing chunk offsets")
}
if len(stbl.Stsc.Entries) == 0 {
return nil, fmt.Errorf("mp4: empty chunk map")
}
totalSamples := int(trak.GetNrSamples())
if totalSamples <= 0 {
return nil, fmt.Errorf("mp4: empty sample table")
}
sampleSizes, err := mp4AACSampleSizes(stbl.Stsz, totalSamples)
if err != nil {
return nil, err
}
chunkOffsets, err := mp4AACChunkOffsets(stbl)
if err != nil {
return nil, err
}
out := make([]mp4AACSample, 0, totalSamples)
sampleIndex := 0
entryIndex := 0
entry := stbl.Stsc.Entries[entryIndex]
for chunkIndex := 0; chunkIndex < len(chunkOffsets) && sampleIndex < totalSamples; chunkIndex++ {
chunkNr := uint32(chunkIndex + 1)
for entryIndex+1 < len(stbl.Stsc.Entries) && chunkNr >= stbl.Stsc.Entries[entryIndex+1].FirstChunk {
entryIndex++
entry = stbl.Stsc.Entries[entryIndex]
}
if entry.SamplesPerChunk == 0 {
return nil, fmt.Errorf("mp4: zero samples per chunk")
}
offset := chunkOffsets[chunkIndex]
samplesPerChunk := int(entry.SamplesPerChunk)
for i := 0; i < samplesPerChunk && sampleIndex < totalSamples; i++ {
sampleSize := sampleSizes[sampleIndex]
end := offset + int64(sampleSize)
if offset < 0 || end < offset || end > size {
return nil, fmt.Errorf("mp4: invalid sample bounds at sample %d", sampleIndex+1)
}
out = append(out, mp4AACSample{offset: offset, size: sampleSize})
offset = end
sampleIndex++
}
}
if sampleIndex != totalSamples {
return nil, fmt.Errorf("mp4: sample table mismatch")
}
return out, nil
}
func mp4AACSampleSizes(stsz *mp4.StszBox, totalSamples int) ([]int, error) {
if stsz == nil {
return nil, fmt.Errorf("mp4: missing sample sizes")
}
if int(stsz.GetNrSamples()) != totalSamples {
return nil, fmt.Errorf("mp4: sample size count mismatch")
}
sizes := make([]int, totalSamples)
if stsz.SampleUniformSize != 0 {
sz := int(stsz.SampleUniformSize)
for i := range sizes {
sizes[i] = sz
}
return sizes, nil
}
if len(stsz.SampleSize) != totalSamples {
return nil, fmt.Errorf("mp4: sample size table mismatch")
}
for i, sz := range stsz.SampleSize {
sizes[i] = int(sz)
}
return sizes, nil
}
func mp4AACChunkOffsets(stbl *mp4.StblBox) ([]int64, error) {
switch {
case stbl == nil:
return nil, fmt.Errorf("mp4: incomplete sample table")
case stbl.Stco != nil:
offsets := make([]int64, len(stbl.Stco.ChunkOffset))
for i, off := range stbl.Stco.ChunkOffset {
offsets[i] = int64(off)
}
return offsets, nil
case stbl.Co64 != nil:
offsets := make([]int64, len(stbl.Co64.ChunkOffset))
for i, off := range stbl.Co64.ChunkOffset {
if off > uint64(^uint64(0)>>1) {
return nil, fmt.Errorf("mp4: invalid chunk offset")
}
offsets[i] = int64(off)
}
return offsets, nil
default:
return nil, fmt.Errorf("mp4: missing chunk offsets")
}
}

154
transcode/ogg_decode.go Normal file
View File

@ -0,0 +1,154 @@
package transcode
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"strings"
"github.com/gopxl/beep"
"github.com/gopxl/beep/vorbis"
"github.com/pion/opus"
"github.com/pion/opus/pkg/oggreader"
)
func decodeOggFile(path string) (beep.Streamer, beep.Format, io.Closer, error) {
switch sniffOggCodec(path) {
case "opus":
return decodeOggOpus(path)
case "vorbis":
return decodeOggVorbis(path)
default:
streamer, format, closer, err := decodeOggVorbis(path)
if err == nil {
return streamer, format, closer, nil
}
if isVorbisInvalidHeader(err) {
return decodeOggOpus(path)
}
return nil, beep.Format{}, nil, err
}
}
func sniffOggCodec(path string) string {
f, err := os.Open(path)
if err != nil {
return ""
}
defer f.Close()
buf := make([]byte, 8192)
n, _ := io.ReadFull(f, buf)
buf = buf[:n]
if len(buf) < 4 || !bytes.HasPrefix(buf, []byte("OggS")) {
return ""
}
if bytes.Contains(buf, []byte("OpusHead")) {
return "opus"
}
// Vorbis ID packet: 0x01 + "vorbis"
if bytes.Contains(buf, []byte{0x01, 'v', 'o', 'r', 'b', 'i', 's'}) {
return "vorbis"
}
return ""
}
func isVorbisInvalidHeader(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "invalid header") || strings.Contains(msg, "vorbis:")
}
func decodeOggVorbis(path string) (beep.Streamer, beep.Format, io.Closer, error) {
f, err := os.Open(path)
if err != nil {
return nil, beep.Format{}, nil, err
}
streamer, format, decErr := vorbis.Decode(f)
if decErr != nil {
f.Close()
return nil, beep.Format{}, nil, fmt.Errorf("ogg/vorbis: %w", decErr)
}
return streamer, format, f, nil
}
func decodeOggOpus(path string) (beep.Streamer, beep.Format, io.Closer, error) {
f, err := os.Open(path)
if err != nil {
return nil, beep.Format{}, nil, err
}
ogg, header, err := oggreader.NewWith(f)
if err != nil {
f.Close()
return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus: %w", err)
}
sr := int(header.SampleRate)
if sr <= 0 {
sr = 48000
}
ch := int(header.Channels)
if ch <= 0 {
ch = 1
}
dec, err := opus.NewDecoderWithOutput(sr, ch)
if err != nil {
f.Close()
return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus decoder: %w", err)
}
const maxFrameSamples = 5760
pcmBuf := make([]float32, maxFrameSamples*ch)
var samples []float64
for {
pkt, _, err := ogg.ParseNextPacket()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
f.Close()
return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus read: %w", err)
}
if len(pkt) == 0 || bytes.HasPrefix(pkt, []byte("OpusHead")) || bytes.HasPrefix(pkt, []byte("OpusTags")) {
continue
}
n, err := dec.DecodeToFloat32(pkt, pcmBuf)
if err != nil {
f.Close()
return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus decode: %w", err)
}
if n <= 0 {
continue
}
total := n * ch
if total > len(pcmBuf) {
total = len(pcmBuf)
}
for i := 0; i < total; i++ {
samples = append(samples, float64(pcmBuf[i]))
}
}
if len(samples) == 0 {
f.Close()
return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus: no audio samples")
}
outCh := ch
if outCh > 1 {
samples = interleavedToMono(samples, outCh)
outCh = 1
}
return newSamplesStreamer(samples, sr), beep.Format{
SampleRate: beep.SampleRate(sr),
NumChannels: outCh,
Precision: 2,
}, f, nil
}

View File

@ -0,0 +1,52 @@
package transcode
import (
"os"
"path/filepath"
"runtime"
"testing"
)
func TestSniffOggCodec_opus(t *testing.T) {
path := pionTinyOggPath(t)
if got := sniffOggCodec(path); got != "opus" {
t.Fatalf("sniffOggCodec() = %q want opus", got)
}
}
func TestDecodeOggOpus_pionTiny(t *testing.T) {
path := pionTinyOggPath(t)
streamer, format, closer, err := decodeOggOpus(path)
if err != nil {
t.Fatal(err)
}
defer closer.Close()
if format.SampleRate == 0 {
t.Fatal("zero sample rate")
}
buf := make([][2]float64, 4096)
n, ok := streamer.Stream(buf)
if !ok || n == 0 {
t.Fatal("expected pcm samples")
}
}
func pionTinyOggPath(t *testing.T) string {
t.Helper()
_, file, _, ok := runtime.Caller(0)
if !ok {
t.Fatal("runtime.Caller failed")
}
modRoot := filepath.Clean(filepath.Join(filepath.Dir(file), ".."))
cache := os.Getenv("GOMODCACHE")
if cache == "" {
home, _ := os.UserHomeDir()
cache = filepath.Join(home, "go", "pkg", "mod")
}
path := filepath.Join(cache, "github.com/pion/opus@v0.0.0-20260601214817-71d58474cec8/testdata/tiny.ogg")
if _, err := os.Stat(path); err != nil {
_ = modRoot
t.Skipf("pion testdata not in module cache: %v", err)
}
return path
}

43
transcode/options.go Normal file
View File

@ -0,0 +1,43 @@
package transcode
import "fmt"
type Options struct {
Format string
SampleRate int
Channels int
Codec string
}
func WhisperOptions() Options {
return Options{
Format: FormatWAV,
SampleRate: 16000,
Channels: 1,
Codec: "pcm_s16le",
}
}
func (o *Options) ApplyDefaults() error {
spec, err := ResolveFormat(o.Format)
if err != nil {
return err
}
if o.Codec == "" {
o.Codec = spec.Codec
}
return nil
}
func (o *Options) Validate() error {
if err := o.ApplyDefaults(); err != nil {
return err
}
if o.SampleRate < 0 || o.Channels < 0 {
return fmt.Errorf("sample_rate and channels must be >= 0")
}
if o.Channels > 8 {
return fmt.Errorf("channels must be <= 8")
}
return nil
}

20
transcode/options_test.go Normal file
View File

@ -0,0 +1,20 @@
package transcode
import "testing"
func TestWhisperOptions(t *testing.T) {
o := WhisperOptions()
if err := o.Validate(); err != nil {
t.Fatal(err)
}
if o.SampleRate != 16000 || o.Channels != 1 {
t.Fatalf("unexpected whisper opts: %+v", o)
}
}
func TestResolveFormat_unknown(t *testing.T) {
_, err := ResolveFormat("xyz")
if err == nil {
t.Fatal("expected error")
}
}

View File

@ -0,0 +1,38 @@
package transcode
import "github.com/gopxl/beep"
type samplesStreamer struct {
samples []float64
pos int
sampleRate beep.SampleRate
}
func newSamplesStreamer(samples []float64, sampleRate int) *samplesStreamer {
return &samplesStreamer{
samples: samples,
sampleRate: beep.SampleRate(sampleRate),
}
}
func (s *samplesStreamer) Stream(buf [][2]float64) (int, bool) {
if s.pos >= len(s.samples) {
return 0, false
}
n := 0
for i := range buf {
if s.pos >= len(s.samples) {
return n, n > 0
}
v := s.samples[s.pos]
buf[i][0] = v
buf[i][1] = v
s.pos++
n++
}
return n, true
}
func (s *samplesStreamer) Err() error {
return nil
}

44
transcode/sniff.go Normal file
View File

@ -0,0 +1,44 @@
package transcode
import (
"bytes"
"io"
)
// sniffFormat detects container/codec from file header when the path has no extension.
func sniffFormat(r io.Reader) string {
head := make([]byte, 32)
n, _ := io.ReadFull(r, head)
head = head[:n]
if len(head) < 4 {
return ""
}
if bytes.HasPrefix(head, []byte("RIFF")) && len(head) >= 12 && bytes.Equal(head[8:12], []byte("WAVE")) {
return ".wav"
}
if bytes.HasPrefix(head, []byte("ID3")) {
return ".mp3"
}
if len(head) >= 2 && head[0] == 0xFF && (head[1]&0xE0) == 0xE0 {
return ".mp3"
}
if bytes.HasPrefix(head, []byte("fLaC")) {
return ".flac"
}
if bytes.HasPrefix(head, []byte("OggS")) {
return ".ogg"
}
if len(head) >= 8 && bytes.Equal(head[4:8], []byte("ftyp")) {
return ".m4a"
}
if bytes.HasPrefix(head, []byte{0xFF, 0xF1}) || bytes.HasPrefix(head, []byte{0xFF, 0xF9}) {
return ".aac"
}
if bytes.HasPrefix(head, []byte("FORM")) && len(head) >= 12 && bytes.Equal(head[8:12], []byte("AIFF")) {
return ".aiff"
}
if bytes.HasPrefix(head, []byte{0x1A, 0x45, 0xDF, 0xA3}) {
return ".webm"
}
return ""
}

42
transcode/stream.go Normal file
View File

@ -0,0 +1,42 @@
package transcode
import (
"context"
"github.com/gopxl/beep"
"github.com/gopxl/beep/effects"
)
func buildPipeline(streamer beep.Streamer, format beep.Format, opt Options) (beep.Streamer, beep.Format) {
out := streamer
if opt.SampleRate > 0 && format.SampleRate != beep.SampleRate(opt.SampleRate) {
out = beep.Resample(4, format.SampleRate, beep.SampleRate(opt.SampleRate), out)
format.SampleRate = beep.SampleRate(opt.SampleRate)
}
if opt.Channels == 1 {
out = effects.Mono(out)
format.NumChannels = 1
}
return out, format
}
func drainSamples(ctx context.Context, s beep.Streamer) ([]float64, error) {
buf := make([][2]float64, 4096)
var samples []float64
for {
if err := ctx.Err(); err != nil {
return nil, err
}
n, ok := s.Stream(buf)
if !ok {
if err := s.Err(); err != nil {
return nil, err
}
break
}
for i := 0; i < n; i++ {
samples = append(samples, buf[i][0])
}
}
return samples, nil
}

53
transcode/wav_out.go Normal file
View File

@ -0,0 +1,53 @@
package transcode
import (
"os"
"github.com/go-audio/audio"
"github.com/go-audio/wav"
)
func writePCM16WAV(path string, sampleRate int, channels int, samples []float64) error {
if channels <= 0 {
channels = 1
}
if err := os.MkdirAll(dirOf(path), 0o755); err != nil && dirOf(path) != "." {
return err
}
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
enc := wav.NewEncoder(f, sampleRate, 16, channels, 1)
data := make([]int, len(samples))
for i, s := range samples {
data[i] = floatToInt16(s)
}
if err := enc.Write(&audio.IntBuffer{
Format: &audio.Format{SampleRate: sampleRate, NumChannels: channels},
Data: data,
}); err != nil {
return err
}
return enc.Close()
}
func floatToInt16(f float64) int {
if f > 1 {
f = 1
}
if f < -1 {
f = -1
}
return int(f * 32767)
}
func dirOf(path string) string {
for i := len(path) - 1; i >= 0; i-- {
if path[i] == '/' || path[i] == '\\' {
return path[:i]
}
}
return "."
}

11
whisper/audio.go Normal file
View File

@ -0,0 +1,11 @@
package whisper
import (
"context"
"go-whisper-api/transcode"
)
func AudioToWav(src, dst string) error {
return transcode.ToWhisperWAV(context.Background(), src, dst)
}

60
whisper/audio_load.go Normal file
View File

@ -0,0 +1,60 @@
package whisper
import (
"fmt"
"os"
"path/filepath"
"go-whisper-api/config"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-audio/wav"
)
// LoadPCM16Mono reads 16 kHz mono WAV into float32 samples (for diarization).
func LoadPCM16Mono(path string) ([]float32, error) {
return loadPCM16Mono(path)
}
func loadPCM16Mono(path string) ([]float32, error) {
fh, err := os.Open(path)
if err != nil {
return nil, err
}
defer fh.Close()
dec := wav.NewDecoder(fh)
buf, err := dec.FullPCMBuffer()
if err != nil {
return nil, err
}
if dec.SampleRate != whisper.SampleRate {
return nil, fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
}
if dec.NumChans != 1 {
return nil, fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
}
return buf.AsFloat32Buffer().Data, nil
}
func prepareAudioPCM(sourcePath string) (data []float32, cleanup func(), err error) {
cleanup = func() {}
if data, err = loadPCM16Mono(sourcePath); err == nil {
return data, cleanup, nil
}
dir, err := config.MkdirTemp("go-whisper-api-whisper-*")
if err != nil {
return nil, nil, err
}
cleanup = func() { os.RemoveAll(dir) }
converted := filepath.Join(dir, "converted.wav")
if err := AudioToWav(sourcePath, converted); err != nil {
cleanup()
return nil, nil, err
}
data, err = loadPCM16Mono(converted)
if err != nil {
cleanup()
return nil, nil, err
}
return data, cleanup, nil
}

154
whisper/format.go Normal file
View File

@ -0,0 +1,154 @@
package whisper
import (
"fmt"
"strings"
"time"
wpkg "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
// Turn is a speaker-active time range from diarization (seconds).
type Turn struct {
Start float32
End float32
Speaker int
}
// FormatOptions controls joining Whisper segments into one string with newlines.
type FormatOptions struct {
PauseGap time.Duration
SpeakerLabel string
UseSpeakers bool
}
// FormatSegments joins segment texts with \n on long pauses and optional speaker labels.
func FormatSegments(segments []wpkg.Segment, turns []Turn, opts FormatOptions) string {
if len(segments) == 0 {
return ""
}
if opts.PauseGap <= 0 {
opts.PauseGap = 1500 * time.Millisecond
}
label := strings.TrimSpace(opts.SpeakerLabel)
if label == "" {
label = "Спикер"
}
lines := make([]segmentLine, len(segments))
for i, seg := range segments {
lines[i] = segmentLine{
Text: strings.TrimSpace(seg.Text),
Start: seg.Start,
End: seg.End,
Speaker: -1,
}
}
if opts.UseSpeakers && len(turns) > 0 {
assignSpeakers(lines, turns)
}
var b strings.Builder
prevSpeaker := -2
prevIdx := -1
for i, line := range lines {
if line.Text == "" {
continue
}
if b.Len() > 0 && prevIdx >= 0 {
speakerBreak := opts.UseSpeakers && line.Speaker >= 0 && line.Speaker != prevSpeaker
pauseBreak := line.Start-lines[prevIdx].End >= opts.PauseGap
switch {
case speakerBreak:
b.WriteString("\n\n")
fmt.Fprintf(&b, "%s %d: ", label, line.Speaker+1)
case pauseBreak:
b.WriteString("\n")
default:
if !strings.HasSuffix(b.String(), " ") && !strings.HasSuffix(b.String(), "\n") {
b.WriteByte(' ')
}
}
} else if opts.UseSpeakers && line.Speaker >= 0 {
fmt.Fprintf(&b, "%s %d: ", label, line.Speaker+1)
}
b.WriteString(line.Text)
if line.Speaker >= 0 {
prevSpeaker = line.Speaker
}
prevIdx = i
}
return strings.TrimSpace(b.String())
}
type segmentLine struct {
Text string
Start time.Duration
End time.Duration
Speaker int
}
func assignSpeakers(lines []segmentLine, turns []Turn) {
for i := range lines {
mid := lines[i].Start + (lines[i].End-lines[i].Start)/2
lines[i].Speaker = speakerAt(mid, turns)
}
}
func speakerAt(t time.Duration, turns []Turn) int {
sec := float32(t.Seconds())
bestSpeaker := -1
bestOverlap := float32(0)
for _, tr := range turns {
if sec >= tr.Start && sec < tr.End {
return tr.Speaker
}
overlap := intervalOverlap(sec, sec, tr.Start, tr.End)
if overlap > bestOverlap {
bestOverlap = overlap
bestSpeaker = tr.Speaker
}
}
return bestSpeaker
}
func intervalOverlap(a0, a1, b0, b1 float32) float32 {
start := max32(a0, b0)
end := min32(a1, b1)
if end <= start {
return 0
}
return end - start
}
func max32(a, b float32) float32 {
if a > b {
return a
}
return b
}
func min32(a, b float32) float32 {
if a < b {
return a
}
return b
}
// PunctuateSegments runs punctuation on each segment separately (preserves line breaks).
func PunctuateSegments(segments []wpkg.Segment, restore func(text string) (string, error)) ([]wpkg.Segment, error) {
out := make([]wpkg.Segment, len(segments))
copy(out, segments)
for i := range out {
t := strings.TrimSpace(out[i].Text)
if t == "" {
continue
}
p, err := restore(t)
if err != nil {
return nil, err
}
out[i].Text = " " + strings.TrimSpace(p)
}
return out, nil
}

40
whisper/format_test.go Normal file
View File

@ -0,0 +1,40 @@
package whisper
import (
"testing"
"time"
wpkg "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)
func TestFormatSegments_pauseAndSpeaker(t *testing.T) {
segments := []wpkg.Segment{
{Text: " привет", Start: 0, End: 2 * time.Second},
{Text: " мир", Start: 4 * time.Second, End: 5 * time.Second},
{Text: " ответ", Start: 6 * time.Second, End: 8 * time.Second},
}
turns := []Turn{
{Start: 0, End: 5.5, Speaker: 0},
{Start: 5.5, End: 10, Speaker: 1},
}
got := FormatSegments(segments, turns, FormatOptions{
PauseGap: 1500 * time.Millisecond,
SpeakerLabel: "Спикер",
UseSpeakers: true,
})
want := "Спикер 1: привет\nмир\n\nСпикер 2: ответ"
if got != want {
t.Fatalf("got %q\nwant %q", got, want)
}
}
func TestFormatSegments_noSpeakers(t *testing.T) {
segments := []wpkg.Segment{
{Text: "a", Start: 0, End: time.Second},
{Text: "b", Start: 3 * time.Second, End: 4 * time.Second},
}
got := FormatSegments(segments, nil, FormatOptions{PauseGap: time.Second, UseSpeakers: false})
if got != "a\nb" {
t.Fatalf("got %q", got)
}
}

15
whisper/helper.go Normal file
View File

@ -0,0 +1,15 @@
package whisper
import (
"fmt"
"time"
)
func srtTimestamp(t time.Duration) string {
return fmt.Sprintf("%02d:%02d:%02d,%03d",
t/time.Hour,
(t%time.Hour)/time.Minute,
(t%time.Minute)/time.Second,
(t%time.Second)/time.Millisecond,
)
}

39
whisper/helper_test.go Normal file
View File

@ -0,0 +1,39 @@
package whisper
import (
"testing"
"time"
)
func TestSrtTimestamp(t *testing.T) {
type args struct {
t time.Duration
}
tests := []struct {
name string
args args
want string
}{
{
name: "test 1",
args: args{
t: time.Duration(1*time.Hour + 2*time.Minute + 3*time.Second + 4*time.Millisecond),
},
want: "01:02:03,004",
},
{
name: "test 2",
args: args{
t: time.Duration(10*time.Hour + 20*time.Minute + 30*time.Second + 40*time.Millisecond),
},
want: "10:20:30,040",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := srtTimestamp(tt.args.t); got != tt.want {
t.Errorf("srtTimestamp() = %v, want %v", got, tt.want)
}
})
}
}

Some files were not shown because too many files have changed in this diff Show More