first commit
Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
This commit is contained in:
commit
b5c083e06f
13
.gitea/FUNDING.yml
Normal file
13
.gitea/FUNDING.yml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# These are supported funding model platforms
|
||||||
|
|
||||||
|
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||||
|
patreon: # Replace with a single Patreon username
|
||||||
|
open_collective: # Replace with a single Open Collective username
|
||||||
|
ko_fi: # Replace with a single Ko-fi username
|
||||||
|
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||||
|
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||||
|
liberapay: # Replace with a single Liberapay username
|
||||||
|
issuehunt: # Replace with a single IssueHunt username
|
||||||
|
otechie: # Replace with a single Otechie username
|
||||||
|
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||||
|
custom: ['https://www.paypal.me/appleboy46']
|
||||||
10
.gitea/dependabot.yml
Normal file
10
.gitea/dependabot.yml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: github-actions
|
||||||
|
directory: /
|
||||||
|
schedule:
|
||||||
|
interval: weekly
|
||||||
|
- package-ecosystem: gomod
|
||||||
|
directory: /
|
||||||
|
schedule:
|
||||||
|
interval: weekly
|
||||||
27
.gitea/workflows/README.md
Normal file
27
.gitea/workflows/README.md
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Gitea Actions
|
||||||
|
|
||||||
|
Workflows use [Gitea Actions](https://docs.gitea.com/usage/actions/overview) (compatible with GitHub Actions syntax).
|
||||||
|
|
||||||
|
## Secrets
|
||||||
|
|
||||||
|
| Secret | Used by | Purpose |
|
||||||
|
|--------|---------|---------|
|
||||||
|
| `GITEA_TOKEN` | `docker.yml` | Push images to the instance container registry |
|
||||||
|
|
||||||
|
Create a token with **write:package** (and **read:repository** for checkout if needed).
|
||||||
|
|
||||||
|
## Images
|
||||||
|
|
||||||
|
- **CI / CPU**: `docker/Dockerfile.ci` — built on every push to `main` and tags `v*`
|
||||||
|
- **GPU**: `docker/Dockerfile` — build manually on NVIDIA hosts
|
||||||
|
|
||||||
|
Registry URL: `${{ gitea.server_url }}/${{ gitea.repository }}`
|
||||||
|
|
||||||
|
## Local parity
|
||||||
|
|
||||||
|
```sh
|
||||||
|
sudo apt-get install -y build-essential cmake git
|
||||||
|
go test ./config/... ./punctuation/... ./transcode/...
|
||||||
|
make dependency && make test
|
||||||
|
docker build -f docker/Dockerfile.ci -t go-whisper-api:ci .
|
||||||
|
```
|
||||||
32
.gitea/workflows/codeql.yml
Normal file
32
.gitea/workflows/codeql.yml
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# CodeQL requires GitHub.com or a Gitea instance with the codeql-action available.
|
||||||
|
# Disable this workflow on self-hosted Gitea if the action cannot be resolved.
|
||||||
|
|
||||||
|
name: CodeQL
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
schedule:
|
||||||
|
- cron: "41 23 * * 6"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
analyze:
|
||||||
|
name: Analyze
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
security-events: write
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
language: ["go"]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: github/codeql-action/init@v3
|
||||||
|
with:
|
||||||
|
languages: ${{ matrix.language }}
|
||||||
|
continue-on-error: true
|
||||||
|
- uses: github/codeql-action/analyze@v3
|
||||||
|
continue-on-error: true
|
||||||
60
.gitea/workflows/docker.yml
Normal file
60
.gitea/workflows/docker.yml
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
name: Docker Image
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
tags:
|
||||||
|
- "v*"
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
env:
|
||||||
|
# Gitea container registry: <server>/<owner>/<repo>
|
||||||
|
IMAGE: ${{ gitea.server_url }}/${{ gitea.repository }}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-docker:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Login to Gitea Container Registry
|
||||||
|
if: github.event_name != 'pull_request'
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ${{ gitea.server_url }}
|
||||||
|
username: ${{ gitea.actor }}
|
||||||
|
password: ${{ secrets.GITEA_TOKEN }}
|
||||||
|
|
||||||
|
- name: Docker meta
|
||||||
|
id: meta
|
||||||
|
uses: docker/metadata-action@v5
|
||||||
|
with:
|
||||||
|
images: |
|
||||||
|
${{ env.IMAGE }}
|
||||||
|
tags: |
|
||||||
|
type=raw,value=latest,enable={{is_default_branch}}
|
||||||
|
type=ref,event=branch
|
||||||
|
type=ref,event=pr
|
||||||
|
type=semver,pattern={{version}}
|
||||||
|
type=semver,pattern={{major}}.{{minor}}
|
||||||
|
|
||||||
|
- name: Build and push (CPU)
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: docker/Dockerfile.ci
|
||||||
|
platforms: linux/amd64
|
||||||
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
|
cache-from: type=gha
|
||||||
|
cache-to: type=gha,mode=max
|
||||||
43
.gitea/workflows/goreleaser.yml
Normal file
43
.gitea/workflows/goreleaser.yml
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- "v*"
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
packages: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
goreleaser:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
cache: true
|
||||||
|
|
||||||
|
- name: Install build dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y --no-install-recommends build-essential cmake git
|
||||||
|
|
||||||
|
- name: Build whisper dependency
|
||||||
|
run: make dependency
|
||||||
|
|
||||||
|
- name: Run GoReleaser
|
||||||
|
uses: goreleaser/goreleaser-action@v6
|
||||||
|
with:
|
||||||
|
distribution: goreleaser
|
||||||
|
version: latest
|
||||||
|
args: release --clean
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
GITEA_TOKEN: ${{ secrets.GITEA_TOKEN }}
|
||||||
62
.gitea/workflows/lint.yml
Normal file
62
.gitea/workflows/lint.yml
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
name: Lint and Testing
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Hadolint Dockerfile (CI)
|
||||||
|
uses: hadolint/hadolint-action@v3.1.0
|
||||||
|
with:
|
||||||
|
dockerfile: docker/Dockerfile.ci
|
||||||
|
|
||||||
|
- name: Hadolint Dockerfile (GPU)
|
||||||
|
uses: hadolint/hadolint-action@v3.1.0
|
||||||
|
with:
|
||||||
|
dockerfile: docker/Dockerfile
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
cache: true
|
||||||
|
|
||||||
|
- name: Install build dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y --no-install-recommends \
|
||||||
|
build-essential cmake git libsentencepiece-dev pkg-config
|
||||||
|
|
||||||
|
- name: Unit tests (no CGO whisper)
|
||||||
|
run: go test ./config/... ./punctuation/... ./transcode/... -count=1
|
||||||
|
|
||||||
|
- name: Build with xlm punctuation and run full tests
|
||||||
|
run: make dependency && make test TAGS=xlm
|
||||||
|
|
||||||
|
golangci:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
cache: true
|
||||||
|
- name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v6
|
||||||
|
with:
|
||||||
|
version: latest
|
||||||
|
args: --timeout=10m
|
||||||
|
continue-on-error: true
|
||||||
44
.gitignore
vendored
Normal file
44
.gitignore
vendored
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# Go build artifacts
|
||||||
|
*.exe
|
||||||
|
*.exe~
|
||||||
|
*.dll
|
||||||
|
*.so
|
||||||
|
*.dylib
|
||||||
|
*.test
|
||||||
|
*.out
|
||||||
|
go.work
|
||||||
|
|
||||||
|
# Binaries and native libs (local build)
|
||||||
|
/bin/
|
||||||
|
bin/
|
||||||
|
/lib/
|
||||||
|
lib/
|
||||||
|
|
||||||
|
# Runtime data (never commit)
|
||||||
|
config.yaml
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
cache/
|
||||||
|
|
||||||
|
# Packaged releases
|
||||||
|
dist/
|
||||||
|
*.tar.gz
|
||||||
|
|
||||||
|
# Downloaded models and ONNX weights
|
||||||
|
models/**/*.bin
|
||||||
|
models/**/*.onnx
|
||||||
|
models/**/*.model
|
||||||
|
models/*.bin
|
||||||
|
models/*.onnx
|
||||||
|
models/*.model
|
||||||
|
|
||||||
|
# whisper.cpp submodule and its build tree
|
||||||
|
third_party/
|
||||||
|
|
||||||
|
# IDE and OS
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*~
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
3
.golangci.yml
Normal file
3
.golangci.yml
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
run:
|
||||||
|
skip-dirs:
|
||||||
|
- third_party/whisper.cpp
|
||||||
3
.goreleaser.yaml
Normal file
3
.goreleaser.yaml
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
builds:
|
||||||
|
- skip: true
|
||||||
|
|
||||||
3
.hadolint.yaml
Normal file
3
.hadolint.yaml
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
ignored:
|
||||||
|
- DL3018
|
||||||
|
- DL3008
|
||||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Bo-Yi Wu
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
280
Makefile
Normal file
280
Makefile
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
EXECUTABLE := go-whisper-api
|
||||||
|
GO ?= go
|
||||||
|
GOFILES := $(shell find . -name "*.go" -type f)
|
||||||
|
HAS_GO = $(shell hash $(GO) > /dev/null 2>&1 && echo "GO" || echo "NOGO" )
|
||||||
|
|
||||||
|
WHISPER_CPP := $(abspath third_party/whisper.cpp)
|
||||||
|
WHISPER_BUILD := $(WHISPER_CPP)/build
|
||||||
|
WHISPER_VENDOR := $(WHISPER_CPP)/bindings/go
|
||||||
|
WHISPER_LIBDIR := $(WHISPER_BUILD)/src:$(WHISPER_BUILD)/ggml/src
|
||||||
|
|
||||||
|
RUNTIME_LIB_DIR := $(abspath lib)
|
||||||
|
# $ORIGIN/lib — binary next to lib/ (e.g. ./go-whisper-api + ./lib/)
|
||||||
|
# $ORIGIN/../lib — binary in bin/ (e.g. bin/go-whisper-api + lib/)
|
||||||
|
RUNTIME_RPATH := -Wl,-rpath,$$ORIGIN/lib:$$ORIGIN/../lib
|
||||||
|
|
||||||
|
ifneq ($(shell uname), Darwin)
|
||||||
|
EXTLDFLAGS = -extldflags "$(RUNTIME_RPATH)"
|
||||||
|
else
|
||||||
|
EXTLDFLAGS =
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(HAS_GO), GO)
|
||||||
|
GOPATH ?= $(shell $(GO) env GOPATH)
|
||||||
|
export PATH := $(GOPATH)/bin:$(PATH)
|
||||||
|
|
||||||
|
CGO_EXTRA_CFLAGS := -DSQLITE_MAX_VARIABLE_NUMBER=32766
|
||||||
|
CGO_CFLAGS ?= $(shell $(GO) env CGO_CFLAGS) $(CGO_EXTRA_CFLAGS)
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(OS), Windows_NT)
|
||||||
|
GOFLAGS := -v -buildmode=exe
|
||||||
|
EXECUTABLE ?= $(EXECUTABLE).exe
|
||||||
|
else ifeq ($(OS), Windows)
|
||||||
|
GOFLAGS := -v -buildmode=exe
|
||||||
|
EXECUTABLE ?= $(EXECUTABLE).exe
|
||||||
|
else
|
||||||
|
GOFLAGS := -v
|
||||||
|
EXECUTABLE ?= $(EXECUTABLE)
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifneq ($(DRONE_TAG),)
|
||||||
|
VERSION ?= $(DRONE_TAG)
|
||||||
|
else
|
||||||
|
VERSION ?= $(shell git describe --tags --always || git rev-parse --short HEAD)
|
||||||
|
endif
|
||||||
|
|
||||||
|
TAGS ?=
|
||||||
|
UNAME_M := $(shell uname -m)
|
||||||
|
ifeq ($(UNAME_M),x86_64)
|
||||||
|
SHERPA_LIBARCH := x86_64-unknown-linux-gnu
|
||||||
|
endif
|
||||||
|
ifeq ($(UNAME_M),aarch64)
|
||||||
|
SHERPA_LIBARCH := aarch64-unknown-linux-gnu
|
||||||
|
endif
|
||||||
|
SHERPA_LINUX_VER := $(shell awk '/sherpa-onnx-go-linux/ {print $$2; exit}' go.mod)
|
||||||
|
SHERPA_LIBDIR := $(GOPATH)/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-linux@$(SHERPA_LINUX_VER)/lib/$(SHERPA_LIBARCH)
|
||||||
|
ifneq ($(shell uname), Darwin)
|
||||||
|
EXTLDFLAGS_SHERPA = -extldflags "$(RUNTIME_RPATH)"
|
||||||
|
EXTLDFLAGS_XLM = -extldflags "$(RUNTIME_RPATH)"
|
||||||
|
else
|
||||||
|
EXTLDFLAGS_SHERPA =
|
||||||
|
EXTLDFLAGS_XLM =
|
||||||
|
endif
|
||||||
|
GOLDFLAGS ?= -X 'main.Version=$(VERSION)'
|
||||||
|
|
||||||
|
INCLUDE_PATH := $(WHISPER_CPP)/include:$(WHISPER_CPP)/ggml/include:$(WHISPER_VENDOR):$(INCLUDE_PATH)
|
||||||
|
LIBRARY_PATH := $(WHISPER_LIBDIR):$(LIBRARY_PATH)
|
||||||
|
export LD_LIBRARY_PATH := $(WHISPER_LIBDIR):$(LD_LIBRARY_PATH)
|
||||||
|
|
||||||
|
ifdef WHISPER_CUBLAS
|
||||||
|
CGO_CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
||||||
|
CGO_CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
||||||
|
EXTLDFLAGS = -extldflags "-lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib"
|
||||||
|
|
||||||
|
build: $(EXECUTABLE)
|
||||||
|
|
||||||
|
$(EXECUTABLE): $(GOFILES)
|
||||||
|
CGO_CXXFLAGS=${CGO_CXXFLAGS} CGO_CFLAGS=${CGO_CFLAGS} C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)' -o bin/$@
|
||||||
|
endif
|
||||||
|
|
||||||
|
MODEL_URL ?= https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin
|
||||||
|
MODEL_PATH ?= models/ggml-tiny.en.bin
|
||||||
|
VAD_MODEL ?= silero-v6.2.0
|
||||||
|
VAD_MODEL_PATH ?= models/ggml-silero-v6.2.0.bin
|
||||||
|
|
||||||
|
all: build
|
||||||
|
|
||||||
|
PUNCT_MODEL_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8.tar.bz2
|
||||||
|
PUNCT_MODEL_DIR ?= models/punctuation/ct-transformer-zh-en-int8
|
||||||
|
|
||||||
|
XLM_PUNCT_DIR ?= models/punctuation/xlm-roberta
|
||||||
|
XLM_HF_REPO ?= Salama1429/xlm-roberta_punctuation_fullstop_truecase
|
||||||
|
|
||||||
|
XLM_MODEL_CONFIG_SRC ?= config/xlm-roberta-model.yaml
|
||||||
|
ORT_LIB_SRC ?= $(shell $(GO) env GOMODCACHE 2>/dev/null)/github.com/k2-fsa/sherpa-onnx-go-linux@$(SHERPA_LINUX_VER)/lib/$(SHERPA_LIBARCH)/libonnxruntime.so
|
||||||
|
|
||||||
|
# Copy runtime .so into ./lib/. Binary rpath: $ORIGIN/lib or $ORIGIN/../lib (see RUNTIME_RPATH).
|
||||||
|
# Use cp -n where possible: existing root-owned libs in lib/ must not break the build.
|
||||||
|
install-runtime-libs: dependency
|
||||||
|
@mkdir -p "$(RUNTIME_LIB_DIR)"
|
||||||
|
@cp -an "$(WHISPER_BUILD)/src"/libwhisper.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true
|
||||||
|
@cp -an "$(WHISPER_BUILD)/ggml/src"/libggml*.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true
|
||||||
|
@echo "Whisper/ggml libs ready in $(RUNTIME_LIB_DIR)/"
|
||||||
|
|
||||||
|
install-ort-lib:
|
||||||
|
@mkdir -p "$(RUNTIME_LIB_DIR)"
|
||||||
|
@if [ ! -f "$(ORT_LIB_SRC)" ]; then echo "missing $(ORT_LIB_SRC); run: go mod download"; exit 1; fi
|
||||||
|
@dest="$(RUNTIME_LIB_DIR)/libonnxruntime.so"; \
|
||||||
|
if [ -f "$$dest" ] && cmp -s "$(ORT_LIB_SRC)" "$$dest"; then \
|
||||||
|
echo "libonnxruntime.so already up to date in $(RUNTIME_LIB_DIR)/"; \
|
||||||
|
elif cp -f "$(ORT_LIB_SRC)" "$$dest" 2>/dev/null; then \
|
||||||
|
echo "Installed $$dest"; \
|
||||||
|
elif [ -f "$$dest" ] && cmp -s "$(ORT_LIB_SRC)" "$$dest"; then \
|
||||||
|
echo "libonnxruntime.so present in $(RUNTIME_LIB_DIR)/ (unchanged, not writable)"; \
|
||||||
|
else \
|
||||||
|
echo "cannot install libonnxruntime.so to $$dest"; \
|
||||||
|
echo "fix: sudo chown -R $$USER:$$(id -gn) $(RUNTIME_LIB_DIR)"; \
|
||||||
|
exit 1; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
# XLM punctuation links -lsentencepiece; bundle .so for hosts without libsentencepiece0 package.
|
||||||
|
SP_LIB_DIRS := /usr/lib/x86_64-linux-gnu /usr/lib/aarch64-linux-gnu /usr/lib64 /usr/lib
|
||||||
|
install-sp-lib:
|
||||||
|
@mkdir -p "$(RUNTIME_LIB_DIR)"
|
||||||
|
@found=0; \
|
||||||
|
for d in $(SP_LIB_DIRS); do \
|
||||||
|
if [ -e "$$d/libsentencepiece.so.0" ] || [ -L "$$d/libsentencepiece.so.0" ]; then \
|
||||||
|
cp -an "$$d"/libsentencepiece.so* "$(RUNTIME_LIB_DIR)/" 2>/dev/null || true; \
|
||||||
|
found=1; \
|
||||||
|
break; \
|
||||||
|
fi; \
|
||||||
|
done; \
|
||||||
|
if [ "$$found" = "0" ]; then \
|
||||||
|
echo "libsentencepiece.so.0 not found; install: sudo apt-get install libsentencepiece0"; \
|
||||||
|
exit 1; \
|
||||||
|
fi
|
||||||
|
@test -e "$(RUNTIME_LIB_DIR)/libsentencepiece.so.0" || (echo "missing $(RUNTIME_LIB_DIR)/libsentencepiece.so.0 after install-sp-lib"; exit 1)
|
||||||
|
@echo "Sentencepiece libs ready in $(RUNTIME_LIB_DIR)/"
|
||||||
|
|
||||||
|
# If lib/*.so were created as root (e.g. manual cp with sudo), reclaim ownership for builds.
|
||||||
|
fix-lib-perms:
|
||||||
|
@if [ -d "$(RUNTIME_LIB_DIR)" ]; then \
|
||||||
|
chown -R "$$USER:$$(id -gn)" "$(RUNTIME_LIB_DIR)" 2>/dev/null || \
|
||||||
|
sudo chown -R "$$USER:$$(id -gn)" "$(RUNTIME_LIB_DIR)"; \
|
||||||
|
echo "Ownership of $(RUNTIME_LIB_DIR)/ updated"; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
install-runtime-libs-xlm: install-runtime-libs install-ort-lib install-sp-lib
|
||||||
|
|
||||||
|
# Fail fast before deploy if ./lib is incomplete (lib/ is not in git: *.so is gitignored).
|
||||||
|
verify-runtime-libs-xlm:
|
||||||
|
@test -f bin/$(EXECUTABLE) || (echo "missing bin/$(EXECUTABLE); run: make build-xlm"; exit 1)
|
||||||
|
@test -f "$(RUNTIME_LIB_DIR)/libonnxruntime.so" || (echo "missing $(RUNTIME_LIB_DIR)/libonnxruntime.so; run: make install-runtime-libs-xlm"; exit 1)
|
||||||
|
@test -e "$(RUNTIME_LIB_DIR)/libsentencepiece.so.0" || (echo "missing $(RUNTIME_LIB_DIR)/libsentencepiece.so.0; run: make install-sp-lib"; exit 1)
|
||||||
|
@test -e "$(RUNTIME_LIB_DIR)/libwhisper.so.1" || (echo "missing $(RUNTIME_LIB_DIR)/libwhisper.so.1; run: make install-runtime-libs"; exit 1)
|
||||||
|
@echo "Runtime libs OK in $(RUNTIME_LIB_DIR)/"
|
||||||
|
|
||||||
|
RUNTIME_TARBALL := dist/go-whisper-api-runtime-$(shell uname -m).tar.gz
|
||||||
|
package-runtime-xlm: verify-runtime-libs-xlm
|
||||||
|
@mkdir -p dist
|
||||||
|
tar -czf "$(RUNTIME_TARBALL)" bin/$(EXECUTABLE) lib
|
||||||
|
@echo "Created $(RUNTIME_TARBALL) — on prod: tar -xzf ... -C /opt/go-whisper-api (keeps bin/ and lib/)"
|
||||||
|
|
||||||
|
# Copy bundled label config (needs write access to $(XLM_PUNCT_DIR); fix with: sudo chown -R $$USER models/punctuation)
|
||||||
|
install-xlm-punctuation-config:
|
||||||
|
@mkdir -p $(XLM_PUNCT_DIR)
|
||||||
|
@cp "$(XLM_MODEL_CONFIG_SRC)" "$(XLM_PUNCT_DIR)/config.yaml"
|
||||||
|
@echo "Installed $(XLM_PUNCT_DIR)/config.yaml"
|
||||||
|
|
||||||
|
download-xlm-punctuation-model: install-xlm-punctuation-config
|
||||||
|
@mkdir -p $(XLM_PUNCT_DIR)
|
||||||
|
@for f in model.onnx sp.model; do \
|
||||||
|
if [ ! -f "$(XLM_PUNCT_DIR)/$$f" ]; then \
|
||||||
|
echo "Downloading $$f from $(XLM_HF_REPO)..."; \
|
||||||
|
curl -fL "https://huggingface.co/$(XLM_HF_REPO)/resolve/main/$$f" -o "$(XLM_PUNCT_DIR)/$$f"; \
|
||||||
|
else \
|
||||||
|
echo "Already have $(XLM_PUNCT_DIR)/$$f"; \
|
||||||
|
fi; \
|
||||||
|
done
|
||||||
|
|
||||||
|
DIAR_SEG_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||||
|
DIAR_EMB_URL ?= https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
|
||||||
|
DIAR_DIR ?= models/diarization
|
||||||
|
|
||||||
|
download-diarization-models:
|
||||||
|
@mkdir -p $(DIAR_DIR)
|
||||||
|
@if [ ! -f "$(DIAR_DIR)/pyannote-segmentation-3-0/model.onnx" ]; then \
|
||||||
|
echo "Downloading speaker segmentation model..."; \
|
||||||
|
curl -fL "$(DIAR_SEG_URL)" -o /tmp/diar-seg.tar.bz2; \
|
||||||
|
tar -xjf /tmp/diar-seg.tar.bz2 -C $(DIAR_DIR); \
|
||||||
|
rm -f /tmp/diar-seg.tar.bz2; \
|
||||||
|
else \
|
||||||
|
echo "Segmentation model present"; \
|
||||||
|
fi
|
||||||
|
@if [ ! -f "$(DIAR_DIR)/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" ]; then \
|
||||||
|
echo "Downloading speaker embedding model..."; \
|
||||||
|
curl -fL "$(DIAR_EMB_URL)" -o "$(DIAR_DIR)/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"; \
|
||||||
|
else \
|
||||||
|
echo "Embedding model present"; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
download-punctuation-model:
|
||||||
|
@mkdir -p models/punctuation
|
||||||
|
@if [ ! -f "$(PUNCT_MODEL_DIR)/model.int8.onnx" ]; then \
|
||||||
|
echo "Downloading punctuation model..."; \
|
||||||
|
curl -fL "$(PUNCT_MODEL_URL)" -o /tmp/punct-model.tar.bz2; \
|
||||||
|
tar -xjf /tmp/punct-model.tar.bz2 -C models/punctuation; \
|
||||||
|
rm -f /tmp/punct-model.tar.bz2; \
|
||||||
|
if [ -d models/punctuation/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8 ]; then \
|
||||||
|
mv models/punctuation/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8 "$(PUNCT_MODEL_DIR)"; \
|
||||||
|
fi; \
|
||||||
|
else \
|
||||||
|
echo "Punctuation model already exists: $(PUNCT_MODEL_DIR)/model.int8.onnx"; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
download-model:
|
||||||
|
@mkdir -p models
|
||||||
|
@if [ ! -f "$(MODEL_PATH)" ]; then \
|
||||||
|
echo "Downloading $(MODEL_PATH)..."; \
|
||||||
|
curl -fL "$(MODEL_URL)" -o "$(MODEL_PATH)"; \
|
||||||
|
else \
|
||||||
|
echo "Model already exists: $(MODEL_PATH)"; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
download-vad-model:
|
||||||
|
@mkdir -p models
|
||||||
|
@if [ ! -f "$(VAD_MODEL_PATH)" ]; then \
|
||||||
|
echo "Downloading VAD model $(VAD_MODEL) to models/..."; \
|
||||||
|
./third_party/whisper.cpp/models/download-vad-model.sh $(VAD_MODEL) models; \
|
||||||
|
else \
|
||||||
|
echo "VAD model already exists: $(VAD_MODEL_PATH)"; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
clone:
|
||||||
|
@[ -d third_party/whisper.cpp ] || git clone https://github.com/appleboy/whisper.cpp.git third_party/whisper.cpp
|
||||||
|
|
||||||
|
dependency: clone
|
||||||
|
@echo Build whisper
|
||||||
|
@if [ ! -f "$(WHISPER_BUILD)/src/libwhisper.so" ] && [ ! -f "$(WHISPER_BUILD)/src/libwhisper.a" ]; then \
|
||||||
|
cmake -S "$(WHISPER_CPP)" -B "$(WHISPER_BUILD)" -DCMAKE_BUILD_TYPE=Release && \
|
||||||
|
cmake --build "$(WHISPER_BUILD)" --config Release -j$$(nproc 2>/dev/null || echo 4); \
|
||||||
|
else \
|
||||||
|
echo "whisper library already built in $(WHISPER_BUILD)"; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
test:
|
||||||
|
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) test -v -cover -coverprofile coverage.txt ./... && echo "\n==>\033[32m Ok\033[m\n" || exit 1
|
||||||
|
|
||||||
|
install: $(GOFILES)
|
||||||
|
C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) install -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)'
|
||||||
|
|
||||||
|
build: install-runtime-libs $(EXECUTABLE)
|
||||||
|
|
||||||
|
# Build with sherpa-onnx (punctuation + speaker diarization)
|
||||||
|
build-sherpa:
|
||||||
|
@$(MAKE) build TAGS=sherpa
|
||||||
|
|
||||||
|
# XLM-RoBERTa punctuation (47 languages); requires libsentencepiece-dev
|
||||||
|
build-xlm:
|
||||||
|
@$(MAKE) install-runtime-libs-xlm
|
||||||
|
@$(MAKE) build TAGS=xlm
|
||||||
|
|
||||||
|
$(EXECUTABLE): $(GOFILES)
|
||||||
|
ifneq (,$(findstring xlm,$(TAGS)))
|
||||||
|
C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=$(SHERPA_LIBDIR):${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS_XLM) -s -w $(GOLDFLAGS)' -o bin/$@
|
||||||
|
else ifneq (,$(findstring sherpa,$(TAGS)))
|
||||||
|
C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=$(SHERPA_LIBDIR):${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS_SHERPA) -s -w $(GOLDFLAGS)' -o bin/$@
|
||||||
|
else
|
||||||
|
C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GO) build -v -tags '$(TAGS)' -ldflags '$(EXTLDFLAGS)-s -w $(GOLDFLAGS)' -o bin/$@
|
||||||
|
endif
|
||||||
|
|
||||||
|
clean:
|
||||||
|
$(GO) clean -x -i ./...
|
||||||
|
rm -rf coverage.txt $(EXECUTABLE) $(DIST) bin/$(EXECUTABLE)
|
||||||
|
|
||||||
|
clean-whisper:
|
||||||
|
rm -rf "$(WHISPER_BUILD)"
|
||||||
|
|
||||||
|
version:
|
||||||
|
@echo $(VERSION)
|
||||||
129
README.md
Normal file
129
README.md
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
# go-whisper-api
|
||||||
|
|
||||||
|
HTTP-сервер распознавания речи на базе [whisper.cpp](https://github.com/ggerganov/whisper.cpp): совместимость с **SPR** (`/spr/*`) и **OpenAI Whisper API** (`/v1/*`) для Open WebUI и других клиентов.
|
||||||
|
|
||||||
|
## Модели Whisper (ggml)
|
||||||
|
|
||||||
|
См. [каталог моделей на Hugging Face](https://huggingface.co/ggerganov/whisper.cpp/tree/main).
|
||||||
|
|
||||||
|
```sh
|
||||||
|
make download-model
|
||||||
|
# или вручную:
|
||||||
|
curl -LJ https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin \
|
||||||
|
--output models/ggml-small.bin
|
||||||
|
```
|
||||||
|
|
||||||
|
Веса STT: `*.bin` в корне `api.models_dir` (список — `GET /spr/models`). Подкаталоги `models/vad/`, `models/punctuation/` и т.п. в список моделей не входят.
|
||||||
|
|
||||||
|
## Конфигурация
|
||||||
|
|
||||||
|
Шаблон в репозитории — `config.yaml.example`. Рабочий файл создайте локально (он в `.gitignore`):
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cp config.yaml.example config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
Если `--config` не указан и в текущем каталоге есть `config.yaml`, он подгружается автоматически.
|
||||||
|
|
||||||
|
Промежуточные файлы (загрузки, конвертация) — только в `/tmp`, удаляются после обработки. Форматы **wav, mp3, flac, ogg, m4a, mp4, aac** декодируются на чистом Go (без ffmpeg).
|
||||||
|
|
||||||
|
## Запуск
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go-whisper-api serve --config config.yaml
|
||||||
|
# или с флагами (перекрывают YAML):
|
||||||
|
go-whisper-api --models-dir ./models --addr :8080
|
||||||
|
```
|
||||||
|
|
||||||
|
Docker:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
docker run -p 8080:8080 \
|
||||||
|
-v $PWD/models:/app/models \
|
||||||
|
-v $PWD/config.yaml:/app/config.yaml \
|
||||||
|
ghcr.io/appleboy/go-whisper-api:latest \
|
||||||
|
serve --config /app/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
Swagger UI: [http://localhost:8080/](http://localhost:8080/)
|
||||||
|
|
||||||
|
## VAD (детекция речи)
|
||||||
|
|
||||||
|
```sh
|
||||||
|
make download-vad-model
|
||||||
|
```
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
api:
|
||||||
|
models_dir: ./models
|
||||||
|
vad:
|
||||||
|
enabled: true
|
||||||
|
model: ggml-silero-v6.2.0.bin
|
||||||
|
```
|
||||||
|
|
||||||
|
Полезно на длинных записях с тишиной или музыкой в начале/конце.
|
||||||
|
|
||||||
|
## Пунктуация
|
||||||
|
|
||||||
|
| Уровень | Параметр | Эффект |
|
||||||
|
|---------|----------|--------|
|
||||||
|
| Мастер | `punctuation.enabled: false` | Пунктуация отключена |
|
||||||
|
| API по умолчанию | `api.default_punctuation: true` | Если нет `?punctuation=` |
|
||||||
|
| На запрос | `?punctuation=1` / `?punctuation=0` | Только при `enabled: true` |
|
||||||
|
|
||||||
|
Движки: `heuristic`, `xlm`, `sherpa`, `sherpa-online`, `http`, `off`.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
curl -X POST "http://localhost:8080/spr/stt/ggml-small?punctuation=1" -F "wav=@audio.wav"
|
||||||
|
```
|
||||||
|
|
||||||
|
Сборка XLM: `make download-xlm-punctuation-model && make build-xlm`. Продакшен: `make package-runtime-xlm` — в git не коммитятся `lib/*.so` (см. `.gitignore`).
|
||||||
|
|
||||||
|
Фильтр артефактов: `api.garbage` (по умолчанию `*выбая*`), отключить: `garbage: []`.
|
||||||
|
|
||||||
|
## HTTP API
|
||||||
|
|
||||||
|
- **SPR** (`/spr/*`) — очередь, waveform, импорт/экспорт моделей
|
||||||
|
- **OpenAI** (`/v1/*`) — синхронный STT
|
||||||
|
|
||||||
|
Пример SPR:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
curl -X POST "http://localhost:8080/spr/stt/ggml-small?async=0" -F "wav=@audio.wav"
|
||||||
|
```
|
||||||
|
|
||||||
|
По умолчанию STT **асинхронный**: `POST /spr/stt/{id}` → `taskID`, опрос `GET /spr/queue/{taskID}`, результат `GET /spr/result/{taskID}`. Кэш: `./cache/waiting/` и `./cache/ready/`.
|
||||||
|
|
||||||
|
Язык по умолчанию — **`ru`** (`?language=en`, `?language=auto`).
|
||||||
|
|
||||||
|
| Endpoint | Метод | Описание |
|
||||||
|
|----------|--------|-------------|
|
||||||
|
| `/spr/models` | GET | Список моделей |
|
||||||
|
| `/spr/stt/{id}` | POST | Транскрипция |
|
||||||
|
| `/spr/result/{taskID}` | GET | Результат |
|
||||||
|
| `/spr/queue` | GET | Все задачи |
|
||||||
|
| `/spr/queue/{taskID}` | GET / DELETE | Статус / удаление |
|
||||||
|
| `/v1/audio/transcriptions` | POST | OpenAI-совместимая транскрипция |
|
||||||
|
| `/v1/models` | GET | Список моделей |
|
||||||
|
|
||||||
|
## Open WebUI
|
||||||
|
|
||||||
|
| Параметр | Значение |
|
||||||
|
|----------|----------|
|
||||||
|
| Engine | `OpenAI` |
|
||||||
|
| API Base URL | `http://<host>:6183/v1` |
|
||||||
|
| API Key | любая непустая строка |
|
||||||
|
| STT Model | `whisper-1` (см. `api.default_model` в config) |
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
api:
|
||||||
|
default_model: ggml-large-v3-turbo
|
||||||
|
```
|
||||||
|
|
||||||
|
## CI/CD
|
||||||
|
|
||||||
|
Workflows в `.gitea/workflows/` (`lint.yml`, `docker.yml`). Секрет `GITEA_TOKEN` с `write:package` для push образов.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
docker build -f docker/Dockerfile.ci -t go-whisper-api .
|
||||||
|
```
|
||||||
362
api/cache.go
Normal file
362
api/cache.go
Normal file
@ -0,0 +1,362 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go-whisper-api/whisper"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
cacheWaiting = "waiting"
|
||||||
|
cacheReady = "ready"
|
||||||
|
fileParams = "params.conf"
|
||||||
|
fileAudio = "audio.wav"
|
||||||
|
fileAudioJSON = "audio.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TaskParams struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Created string `json:"created"`
|
||||||
|
Processed string `json:"processed,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Language string `json:"language,omitempty"`
|
||||||
|
Punctuation bool `json:"punctuation,omitempty"`
|
||||||
|
Speakers bool `json:"speakers,omitempty"`
|
||||||
|
NumClusters int `json:"num_clusters,omitempty"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Words []whisper.Word `json:"words,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AudioJSON struct {
|
||||||
|
Waveform []float64 `json:"waveform"`
|
||||||
|
Buckets int `json:"buckets"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DiskCache struct {
|
||||||
|
root string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) Root() string {
|
||||||
|
return c.root
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveCacheRoot(root string) (string, error) {
|
||||||
|
if root == "" {
|
||||||
|
root = "./cache"
|
||||||
|
}
|
||||||
|
abs, err := filepath.Abs(root)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("cache dir: %w", err)
|
||||||
|
}
|
||||||
|
return filepath.Clean(abs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDiskCache(root string) (*DiskCache, error) {
|
||||||
|
abs, err := resolveCacheRoot(root)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c := &DiskCache{root: abs}
|
||||||
|
for _, sub := range []string{cacheWaiting, cacheReady} {
|
||||||
|
if err := os.MkdirAll(filepath.Join(abs, sub), 0o755); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) waitingDir(id string) string {
|
||||||
|
return filepath.Join(c.root, cacheWaiting, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) readyDir(id string) string {
|
||||||
|
return filepath.Join(c.root, cacheReady, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) locate(id string) (dir, phase string, ok bool) {
|
||||||
|
if id == "" {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
ready := c.readyDir(id)
|
||||||
|
if st, err := os.Stat(ready); err == nil && st.IsDir() {
|
||||||
|
return ready, cacheReady, true
|
||||||
|
}
|
||||||
|
waiting := c.waitingDir(id)
|
||||||
|
if st, err := os.Stat(waiting); err == nil && st.IsDir() {
|
||||||
|
return waiting, cacheWaiting, true
|
||||||
|
}
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) Enqueue(id string, params TaskParams, audioWavPath string) error {
|
||||||
|
dir := c.waitingDir(id)
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
params.ID = id
|
||||||
|
if err := c.writeParams(dir, params); err != nil {
|
||||||
|
_ = os.RemoveAll(dir)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst := filepath.Join(dir, fileAudio)
|
||||||
|
if err := os.Rename(audioWavPath, dst); err != nil {
|
||||||
|
if err2 := copyFile(audioWavPath, dst); err2 != nil {
|
||||||
|
_ = os.RemoveAll(dir)
|
||||||
|
return fmt.Errorf("move audio to %s: %w", dst, err2)
|
||||||
|
}
|
||||||
|
_ = os.Remove(audioWavPath)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) writeParams(dir string, params TaskParams) error {
|
||||||
|
data, err := json.Marshal(params)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(filepath.Join(dir, fileParams), data, 0o644)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) LoadParams(id string) (TaskParams, string, error) {
|
||||||
|
dir, phase, ok := c.locate(id)
|
||||||
|
if !ok {
|
||||||
|
return TaskParams{}, "", fmt.Errorf("task not found")
|
||||||
|
}
|
||||||
|
params, err := c.readParams(dir)
|
||||||
|
if err != nil {
|
||||||
|
return TaskParams{}, "", err
|
||||||
|
}
|
||||||
|
return params, phase, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) readParams(dir string) (TaskParams, error) {
|
||||||
|
path := filepath.Join(dir, fileParams)
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return TaskParams{}, fmt.Errorf("read %s: %w", path, err)
|
||||||
|
}
|
||||||
|
var params TaskParams
|
||||||
|
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||||
|
return TaskParams{}, err
|
||||||
|
}
|
||||||
|
return params, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) List() (map[string]map[string]string, error) {
|
||||||
|
out := make(map[string]map[string]string)
|
||||||
|
for _, phase := range []string{cacheWaiting, cacheReady} {
|
||||||
|
base := filepath.Join(c.root, phase)
|
||||||
|
entries, err := os.ReadDir(base)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
if !e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id := e.Name()
|
||||||
|
params, err := c.readParams(filepath.Join(base, id))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[id] = map[string]string{
|
||||||
|
"created": params.Created,
|
||||||
|
"status": params.Status,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) NextWaiting() (string, bool, error) {
|
||||||
|
base := filepath.Join(c.root, cacheWaiting)
|
||||||
|
entries, err := os.ReadDir(base)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
type item struct {
|
||||||
|
id string
|
||||||
|
created time.Time
|
||||||
|
}
|
||||||
|
var pending []item
|
||||||
|
for _, e := range entries {
|
||||||
|
if !e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
params, err := c.readParams(filepath.Join(base, e.Name()))
|
||||||
|
if err != nil || params.Status != string(statusWaiting) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t, _ := time.ParseInLocation("2006-01-02 15:04:05", params.Created, time.Local)
|
||||||
|
pending = append(pending, item{id: e.Name(), created: t})
|
||||||
|
}
|
||||||
|
if len(pending) == 0 {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
sort.Slice(pending, func(i, j int) bool {
|
||||||
|
return pending[i].created.Before(pending[j].created)
|
||||||
|
})
|
||||||
|
return pending[0].id, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) SetStatus(id string, status taskStatus, mutate func(*TaskParams)) error {
|
||||||
|
dir := c.waitingDir(id)
|
||||||
|
if _, err := os.Stat(dir); err != nil {
|
||||||
|
return fmt.Errorf("task %s not in waiting", id)
|
||||||
|
}
|
||||||
|
params, err := c.readParams(dir)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
params.Status = string(status)
|
||||||
|
if mutate != nil {
|
||||||
|
mutate(¶ms)
|
||||||
|
}
|
||||||
|
return c.writeParams(dir, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) FinishWaiting(id string, result whisper.TranscriptResult, errMsg string, waveform []float64) error {
|
||||||
|
dir := c.waitingDir(id)
|
||||||
|
params, err := c.readParams(dir)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
params.Processed = time.Now().Format("2006-01-02 15:04:05")
|
||||||
|
if errMsg != "" {
|
||||||
|
params.Status = string(statusError)
|
||||||
|
params.Error = errMsg
|
||||||
|
} else {
|
||||||
|
params.Status = string(statusReady)
|
||||||
|
params.Text = result.Text
|
||||||
|
params.Words = result.Words
|
||||||
|
}
|
||||||
|
if err := c.writeParams(dir, params); err != nil {
|
||||||
|
return fmt.Errorf("update %s: %w", filepath.Join(dir, fileParams), err)
|
||||||
|
}
|
||||||
|
if len(waveform) > 0 {
|
||||||
|
aj := AudioJSON{Waveform: waveform, Buckets: len(waveform)}
|
||||||
|
data, err := json.Marshal(aj)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, fileAudioJSON), data, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) PromoteToReady(id string) error {
|
||||||
|
if _, phase, ok := c.locate(id); ok && phase == cacheReady {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
src := c.waitingDir(id)
|
||||||
|
dst := c.readyDir(id)
|
||||||
|
if _, err := os.Stat(src); err != nil {
|
||||||
|
return fmt.Errorf("task %s not in waiting", id)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(dst); err == nil {
|
||||||
|
return os.RemoveAll(src)
|
||||||
|
}
|
||||||
|
return os.Rename(src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) Delete(id string) bool {
|
||||||
|
dir, _, ok := c.locate(id)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_ = os.RemoveAll(dir)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) AudioPath(id string) (string, bool) {
|
||||||
|
dir, _, ok := c.locate(id)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
p := filepath.Join(dir, fileAudio)
|
||||||
|
if _, err := os.Stat(p); err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return p, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) Waveform(id string) ([]float64, error) {
|
||||||
|
dir, _, ok := c.locate(id)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("task not found")
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(filepath.Join(dir, fileAudioJSON))
|
||||||
|
if err == nil {
|
||||||
|
var aj AudioJSON
|
||||||
|
if json.Unmarshal(data, &aj) == nil && len(aj.Waveform) > 0 {
|
||||||
|
return aj.Waveform, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return waveformFromWav(filepath.Join(dir, fileAudio), 512)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DiskCache) RecoverInterrupted() error {
|
||||||
|
base := filepath.Join(c.root, cacheWaiting)
|
||||||
|
entries, err := os.ReadDir(base)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
if !e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id := e.Name()
|
||||||
|
dir := filepath.Join(base, id)
|
||||||
|
params, err := c.readParams(dir)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch params.Status {
|
||||||
|
case string(statusProcessing):
|
||||||
|
params.Status = string(statusWaiting)
|
||||||
|
_ = c.writeParams(dir, params)
|
||||||
|
case string(statusReady), string(statusError):
|
||||||
|
_ = c.PromoteToReady(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyFile(src, dst string) error {
|
||||||
|
in, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer in.Close()
|
||||||
|
out, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer out.Close()
|
||||||
|
_, err = io.Copy(out, in)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidTaskID(id string) bool {
|
||||||
|
return id != "" && !strings.Contains(id, "..") && !strings.ContainsAny(id, `/\`)
|
||||||
|
}
|
||||||
127
api/cache_test.go
Normal file
127
api/cache_test.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDiskCache_params_promote_list(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
c, err := NewDiskCache(root)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id := "319d72c7-301d-44fd-935f-3526dfb70f9f"
|
||||||
|
dir := c.waitingDir(id)
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
params := TaskParams{
|
||||||
|
ID: id,
|
||||||
|
Created: "2026-03-31 21:37:46",
|
||||||
|
Status: string(statusReady),
|
||||||
|
Model: "ggml-small",
|
||||||
|
Text: "test",
|
||||||
|
}
|
||||||
|
if err := c.writeParams(dir, params); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := c.List()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if list[id]["status"] != string(statusReady) {
|
||||||
|
t.Fatalf("list: %v", list[id])
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.PromoteToReady(id); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, phase, err := c.LoadParams(id)
|
||||||
|
if err != nil || phase != cacheReady {
|
||||||
|
t.Fatalf("phase=%s err=%v", phase, err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(filepath.Join(c.readyDir(id), fileParams)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidTaskID(t *testing.T) {
|
||||||
|
if !isValidTaskID("319d72c7-301d-44fd-935f-3526dfb70f9f") {
|
||||||
|
t.Fatal("uuid should be valid")
|
||||||
|
}
|
||||||
|
if isValidTaskID("../etc") {
|
||||||
|
t.Fatal("path traversal must be rejected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCacheRoot_absolute(t *testing.T) {
|
||||||
|
cwd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
abs, err := resolveCacheRoot("./cache")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
want := filepath.Join(cwd, "cache")
|
||||||
|
if abs != want {
|
||||||
|
t.Fatalf("got %q want %q", abs, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiskCache_RecoverInterrupted(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
c, err := NewDiskCache(root)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
id := "task-1"
|
||||||
|
dir := c.waitingDir(id)
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.writeParams(dir, TaskParams{
|
||||||
|
ID: id, Created: "2026-01-01 00:00:00", Status: string(statusProcessing), Model: "m",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.RecoverInterrupted(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
p, _, err := c.LoadParams(id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if p.Status != string(statusWaiting) {
|
||||||
|
t.Fatalf("got %q", p.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiskCache_RecoverInterrupted_promotesReady(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
c, err := NewDiskCache(root)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
id := "done-task"
|
||||||
|
dir := c.waitingDir(id)
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.writeParams(dir, TaskParams{
|
||||||
|
ID: id, Created: "2026-01-01 00:00:00", Status: string(statusReady), Model: "m", Text: "hi",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.RecoverInterrupted(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, phase, err := c.LoadParams(id)
|
||||||
|
if err != nil || phase != cacheReady {
|
||||||
|
t.Fatalf("phase=%s err=%v", phase, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
26
api/garbage.go
Normal file
26
api/garbage.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"go-whisper-api/garbage"
|
||||||
|
"go-whisper-api/whisper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func applyGarbage(r whisper.TranscriptResult, patterns []string) whisper.TranscriptResult {
|
||||||
|
if len(patterns) == 0 {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.Text = garbage.FilterText(r.Text, patterns)
|
||||||
|
if len(r.Words) == 0 {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
gw := make([]garbage.Word, len(r.Words))
|
||||||
|
for i, w := range r.Words {
|
||||||
|
gw[i] = garbage.Word{Word: w.Word, Start: w.Start, Stop: w.Stop}
|
||||||
|
}
|
||||||
|
gw = garbage.FilterWords(gw, patterns)
|
||||||
|
r.Words = make([]whisper.Word, len(gw))
|
||||||
|
for i, w := range gw {
|
||||||
|
r.Words[i] = whisper.Word{Word: w.Word, Start: w.Start, Stop: w.Stop}
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
181
api/models.go
Normal file
181
api/models.go
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Subdirectories under models_dir that hold non-Whisper assets (VAD, punctuation, etc.).
|
||||||
|
var reservedModelSubdirs = map[string]struct{}{
|
||||||
|
"vad": {},
|
||||||
|
"punctuation": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Top-level .bin files that are not Whisper STT models.
|
||||||
|
var excludedWhisperModelFiles = map[string]struct{}{
|
||||||
|
"vad.bin": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
type Registry struct {
|
||||||
|
dir string
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRegistry(dir string) *Registry {
|
||||||
|
return &Registry{dir: dir}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isReservedModelSubdir(name string) bool {
|
||||||
|
_, ok := reservedModelSubdirs[strings.ToLower(name)]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWhisperModelFile(name string) bool {
|
||||||
|
if !strings.HasSuffix(strings.ToLower(name), ".bin") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, excluded := excludedWhisperModelFiles[strings.ToLower(name)]
|
||||||
|
return !excluded
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) List() ([]string, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
entries, err := os.ReadDir(r.dir)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var models []string
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := e.Name()
|
||||||
|
if !isWhisperModelFile(name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
models = append(models, strings.TrimSuffix(name, filepath.Ext(name)))
|
||||||
|
}
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) Path(id string) (string, error) {
|
||||||
|
if id == "" {
|
||||||
|
return "", fmt.Errorf("model id is required")
|
||||||
|
}
|
||||||
|
if strings.Contains(id, "/") || strings.Contains(id, "..") {
|
||||||
|
return "", fmt.Errorf("invalid model id")
|
||||||
|
}
|
||||||
|
if isReservedModelSubdir(id) {
|
||||||
|
return "", fmt.Errorf("model %q not found", id)
|
||||||
|
}
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
candidates := []string{
|
||||||
|
filepath.Join(r.dir, id+".bin"),
|
||||||
|
filepath.Join(r.dir, id),
|
||||||
|
filepath.Join(r.dir, "ggml-"+id+".bin"),
|
||||||
|
}
|
||||||
|
for _, p := range candidates {
|
||||||
|
if st, err := os.Stat(p); err == nil && !st.IsDir() && isWhisperModelFile(filepath.Base(p)) {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("model %q not found", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) Delete(id string) error {
|
||||||
|
p, err := r.Path(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
return os.Remove(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) Import(id string, src io.Reader) error {
|
||||||
|
if id == "" {
|
||||||
|
return fmt.Errorf("model id is required")
|
||||||
|
}
|
||||||
|
if strings.Contains(id, "/") || strings.Contains(id, "..") {
|
||||||
|
return fmt.Errorf("invalid model id")
|
||||||
|
}
|
||||||
|
if isReservedModelSubdir(id) {
|
||||||
|
return fmt.Errorf("invalid model id")
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if err := os.MkdirAll(r.dir, 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst := filepath.Join(r.dir, id+".bin")
|
||||||
|
f, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if _, err := io.Copy(f, src); err != nil {
|
||||||
|
os.Remove(dst)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) Open(id string) (*os.File, error) {
|
||||||
|
p, err := r.Path(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return os.Open(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve maps an OpenAI-style model name to a local whisper model id.
|
||||||
|
func (r *Registry) Resolve(id, defaultModel string) (string, error) {
|
||||||
|
id = strings.TrimSpace(id)
|
||||||
|
if id == "" {
|
||||||
|
id = strings.TrimSpace(defaultModel)
|
||||||
|
}
|
||||||
|
if id == "" {
|
||||||
|
models, err := r.List()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(models) == 0 {
|
||||||
|
return "", fmt.Errorf("model is required")
|
||||||
|
}
|
||||||
|
return models[0], nil
|
||||||
|
}
|
||||||
|
if id == "whisper-1" {
|
||||||
|
if dm := strings.TrimSpace(defaultModel); dm != "" {
|
||||||
|
if _, err := r.Path(dm); err == nil {
|
||||||
|
return dm, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
models, err := r.List()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(models) == 0 {
|
||||||
|
return "", fmt.Errorf("no whisper models installed")
|
||||||
|
}
|
||||||
|
return models[0], nil
|
||||||
|
}
|
||||||
|
if _, err := r.Path(id); err == nil {
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
if alt := strings.TrimPrefix(id, "whisper-"); alt != id {
|
||||||
|
if _, err := r.Path(alt); err == nil {
|
||||||
|
return alt, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("model %q not found", id)
|
||||||
|
}
|
||||||
55
api/models_test.go
Normal file
55
api/models_test.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRegistry_List_excludesAuxiliaryDirs(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
for _, name := range []string{"common.bin", "vad/vad.bin", "punctuation/model.onnx"} {
|
||||||
|
p := filepath.Join(root, name)
|
||||||
|
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(p, []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r := NewRegistry(root)
|
||||||
|
list, err := r.List()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(list) != 1 || list[0] != "common" {
|
||||||
|
t.Fatalf("list=%v want [common]", list)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry_List_excludesVADAtRoot(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(root, "vad.bin"), []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(root, "ggml-small.bin"), []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
list, err := NewRegistry(root).List()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(list) != 1 || list[0] != "ggml-small" {
|
||||||
|
t.Fatalf("list=%v", list)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry_Path_rejectsReservedIDs(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
r := NewRegistry(root)
|
||||||
|
for _, id := range []string{"vad", "punctuation"} {
|
||||||
|
if _, err := r.Path(id); err == nil {
|
||||||
|
t.Fatalf("expected error for %q", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
121
api/openai.go
Normal file
121
api/openai.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) handleOpenAITranscriptions(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
writeOpenAIError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := r.ParseMultipartForm(128 << 20); err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelID, err := s.models.Resolve(r.FormValue("model"), s.cfg.DefaultModel)
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelPath, err := s.models.Path(modelID)
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
audioPath, cleanup, err := s.saveUploadedOpenAI(r)
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
stt := s.parseOpenAISTTOptions(r)
|
||||||
|
result, err := s.transcribe(r.Context(), modelPath, audioPath, stt)
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strings.ToLower(strings.TrimSpace(r.FormValue("response_format"))) {
|
||||||
|
case "text":
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(result.Text))
|
||||||
|
default:
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{"text": result.Text})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleOpenAIModels(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
writeOpenAIError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ids, err := s.models.List()
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
now := time.Now().Unix()
|
||||||
|
data := make([]map[string]any, 0, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
data = append(data, map[string]any{
|
||||||
|
"id": id,
|
||||||
|
"object": "model",
|
||||||
|
"created": now,
|
||||||
|
"owned_by": "go-whisper-api",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
|
"object": "list",
|
||||||
|
"data": data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) parseOpenAISTTOptions(r *http.Request) sttOptions {
|
||||||
|
lang := strings.TrimSpace(r.FormValue("language"))
|
||||||
|
if lang == "" {
|
||||||
|
lang = s.cfg.Language
|
||||||
|
}
|
||||||
|
return sttOptions{
|
||||||
|
language: lang,
|
||||||
|
punctuate: s.punctCfg.ShouldApplyAPI(r, s.cfg.DefaultPunctuation),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) saveUploadedOpenAI(r *http.Request) (path string, cleanup func(), err error) {
|
||||||
|
if r.MultipartForm == nil {
|
||||||
|
if err := r.ParseMultipartForm(128 << 20); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return saveUploadedRawFields(r, []string{"file", "audio", "wav"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOpenAIError(w http.ResponseWriter, code int, msg string) {
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
w.WriteHeader(code)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"error": map[string]any{
|
||||||
|
"message": msg,
|
||||||
|
"type": openAIErrorType(code),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIErrorType(code int) string {
|
||||||
|
switch code {
|
||||||
|
case http.StatusBadRequest:
|
||||||
|
return "invalid_request_error"
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
return "authentication_error"
|
||||||
|
case http.StatusNotFound:
|
||||||
|
return "not_found_error"
|
||||||
|
default:
|
||||||
|
return "server_error"
|
||||||
|
}
|
||||||
|
}
|
||||||
90
api/openai_test.go
Normal file
90
api/openai_test.go
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRegistryResolve(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "ggml-large-v3-turbo.bin"), []byte("y"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
reg := NewRegistry(dir)
|
||||||
|
|
||||||
|
id, err := reg.Resolve("whisper-1", "ggml-large-v3-turbo")
|
||||||
|
if err != nil || id != "ggml-large-v3-turbo" {
|
||||||
|
t.Fatalf("whisper-1: id=%q err=%v", id, err)
|
||||||
|
}
|
||||||
|
id, err = reg.Resolve("whisper-large-v3-turbo", "")
|
||||||
|
if err != nil || id != "large-v3-turbo" {
|
||||||
|
t.Fatalf("whisper-large-v3-turbo: id=%q err=%v", id, err)
|
||||||
|
}
|
||||||
|
id, err = reg.Resolve("ggml-small", "")
|
||||||
|
if err != nil || id != "ggml-small" {
|
||||||
|
t.Fatalf("ggml-small: id=%q err=%v", id, err)
|
||||||
|
}
|
||||||
|
id, err = reg.Resolve("", "ggml-small")
|
||||||
|
if err != nil || id != "ggml-small" {
|
||||||
|
t.Fatalf("empty with default: id=%q err=%v", id, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleOpenAITranscriptionsMissingFile(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
srv := &Server{
|
||||||
|
cfg: config.API{ModelsDir: dir, Language: "ru"},
|
||||||
|
models: NewRegistry(dir),
|
||||||
|
}
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
w := multipart.NewWriter(body)
|
||||||
|
_ = w.WriteField("model", "ggml-small")
|
||||||
|
w.Close()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", body)
|
||||||
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
srv.handleOpenAITranscriptions(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(rec.Body.String(), "file") {
|
||||||
|
t.Fatalf("expected file error, got %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleOpenAIModels(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "ggml-small.bin"), []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
srv := &Server{
|
||||||
|
cfg: config.API{ModelsDir: dir, Language: "ru"},
|
||||||
|
models: NewRegistry(dir),
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
srv.handleOpenAIModels(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(rec.Body.String(), "ggml-small") {
|
||||||
|
t.Fatalf("expected model list, got %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
98
api/queue_worker.go
Normal file
98
api/queue_worker.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go-whisper-api/whisper"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) StartWorker(ctx context.Context) {
|
||||||
|
go s.queueWorker(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) queueWorker(ctx context.Context) {
|
||||||
|
const idlePoll = 2 * time.Second
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id, ok, err := s.cache.NextWaiting()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("cache queue scan")
|
||||||
|
if !sleepOrWake(ctx, s.queueWake, time.Second) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
s.processCacheTask(ctx, id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !sleepOrWake(ctx, s.queueWake, idlePoll) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sleepOrWake(ctx context.Context, wake <-chan struct{}, d time.Duration) bool {
|
||||||
|
timer := time.NewTimer(d)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
case <-wake:
|
||||||
|
return true
|
||||||
|
case <-timer.C:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) processCacheTask(ctx context.Context, id string) {
|
||||||
|
params, _, err := s.cache.LoadParams(id)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelPath, err := s.models.Path(params.Model)
|
||||||
|
if err != nil {
|
||||||
|
s.completeCacheTask(id, whisper.TranscriptResult{}, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.cache.SetStatus(id, statusProcessing, nil); err != nil {
|
||||||
|
log.Error().Err(err).Str("task", id).Msg("set processing")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
audioPath, ok := s.cache.AudioPath(id)
|
||||||
|
if !ok {
|
||||||
|
s.completeCacheTask(id, whisper.TranscriptResult{}, "audio file missing")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stt := sttOptions{
|
||||||
|
language: params.Language,
|
||||||
|
punctuate: params.Punctuation,
|
||||||
|
speakers: params.Speakers,
|
||||||
|
numClusters: params.NumClusters,
|
||||||
|
}
|
||||||
|
if stt.language == "" {
|
||||||
|
stt.language = s.cfg.Language
|
||||||
|
}
|
||||||
|
result, err := s.transcribe(ctx, modelPath, audioPath, stt)
|
||||||
|
if err != nil {
|
||||||
|
s.completeCacheTask(id, result, err.Error())
|
||||||
|
log.Error().Err(err).Str("task", id).Msg("async transcribe")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.completeCacheTask(id, result, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) completeCacheTask(id string, result whisper.TranscriptResult, errMsg string) {
|
||||||
|
if err := s.cache.FinishWaiting(id, result, errMsg, nil); err != nil {
|
||||||
|
log.Error().Err(err).Str("task", id).Str("cache", s.cache.Root()).Msg("finish task")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.cache.PromoteToReady(id); err != nil {
|
||||||
|
log.Error().Err(err).Str("task", id).Str("cache", s.cache.Root()).Msg("promote to ready")
|
||||||
|
}
|
||||||
|
}
|
||||||
28
api/result.go
Normal file
28
api/result.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import "go-whisper-api/whisper"
|
||||||
|
|
||||||
|
// sprResultReady builds GET /spr/result/{taskID} body for completed tasks (SPR-compatible).
|
||||||
|
func sprResultReady(params TaskParams) map[string]any {
|
||||||
|
words := params.Words
|
||||||
|
if words == nil {
|
||||||
|
words = []whisper.Word{}
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"model": params.Model,
|
||||||
|
"text": params.Text,
|
||||||
|
"words": words,
|
||||||
|
"toxicity": map[string]float64{
|
||||||
|
"insult": 0,
|
||||||
|
"obscenity": 0,
|
||||||
|
"threat": 0,
|
||||||
|
"politeness": 0,
|
||||||
|
},
|
||||||
|
"emotion": map[string]any{},
|
||||||
|
"voice_analysis": map[string]any{},
|
||||||
|
"status": "ready",
|
||||||
|
"taskID": params.ID,
|
||||||
|
"created": params.Created,
|
||||||
|
"processed": params.Processed,
|
||||||
|
}
|
||||||
|
}
|
||||||
26
api/result_test.go
Normal file
26
api/result_test.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSprResultReady(t *testing.T) {
|
||||||
|
body := sprResultReady(TaskParams{
|
||||||
|
ID: "701b3e84-5815-4baf-97a2-933ff820f16d",
|
||||||
|
Created: "2026-06-03 10:06:35",
|
||||||
|
Processed: "2026-06-03 10:07:48",
|
||||||
|
Status: "ready",
|
||||||
|
Model: "common",
|
||||||
|
Text: "hello",
|
||||||
|
})
|
||||||
|
for _, key := range []string{"model", "text", "words", "toxicity", "emotion", "voice_analysis", "status", "taskID", "created", "processed"} {
|
||||||
|
if body[key] == nil {
|
||||||
|
t.Fatalf("missing %q", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if body["status"] != "ready" {
|
||||||
|
t.Fatalf("status=%v", body["status"])
|
||||||
|
}
|
||||||
|
tox, ok := body["toxicity"].(map[string]float64)
|
||||||
|
if !ok || len(tox) != 4 {
|
||||||
|
t.Fatalf("toxicity=%v", body["toxicity"])
|
||||||
|
}
|
||||||
|
}
|
||||||
640
api/server.go
Normal file
640
api/server.go
Normal file
@ -0,0 +1,640 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
"go-whisper-api/diarization"
|
||||||
|
"go-whisper-api/punctuation"
|
||||||
|
"go-whisper-api/transcode"
|
||||||
|
"go-whisper-api/whisper"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
cfg config.API
|
||||||
|
punctCfg config.Punctuation
|
||||||
|
diarCfg config.Diarization
|
||||||
|
transcode *transcode.Engine
|
||||||
|
modelPool *whisper.ModelPool
|
||||||
|
punct punctuation.Restorer
|
||||||
|
diarizer diarization.Engine
|
||||||
|
models *Registry
|
||||||
|
cache *DiskCache
|
||||||
|
mux *http.ServeMux
|
||||||
|
queueWake chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer(cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) (*Server, error) {
|
||||||
|
cfg = cfg.WithDefaults()
|
||||||
|
tc = tc.WithDefaults()
|
||||||
|
pc = pc.WithDefaults()
|
||||||
|
dc = dc.WithDefaults()
|
||||||
|
restorer, err := punctuation.New(pc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if pc.Active() && !restorer.Active() {
|
||||||
|
return nil, fmt.Errorf("punctuation is enabled but engine %q is not available", pc.Engine)
|
||||||
|
}
|
||||||
|
if cfg.ModelsDir == "" {
|
||||||
|
cfg.ModelsDir = "./models"
|
||||||
|
}
|
||||||
|
if cfg.Addr == "" {
|
||||||
|
cfg.Addr = ":8080"
|
||||||
|
}
|
||||||
|
if cfg.Threads == 0 {
|
||||||
|
cfg.Threads = uint(runtime.NumCPU())
|
||||||
|
}
|
||||||
|
cfg = cfg.WithDefaults()
|
||||||
|
if cfg.MaxContext == 0 {
|
||||||
|
cfg.MaxContext = 32
|
||||||
|
}
|
||||||
|
if cfg.BeamSize == 0 {
|
||||||
|
cfg.BeamSize = 5
|
||||||
|
}
|
||||||
|
if cfg.EntropyThold == 0 {
|
||||||
|
cfg.EntropyThold = 2.4
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(cfg.ModelsDir, 0o755); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cacheDir := cfg.CacheDir
|
||||||
|
if cacheDir == "" {
|
||||||
|
cacheDir = "./cache"
|
||||||
|
}
|
||||||
|
cache, err := NewDiskCache(cacheDir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cfg.CacheDir = cache.Root()
|
||||||
|
if err := cache.RecoverInterrupted(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
diar, err := diarization.New(dc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s := &Server{
|
||||||
|
cfg: cfg,
|
||||||
|
punctCfg: pc,
|
||||||
|
diarCfg: dc,
|
||||||
|
transcode: transcode.NewEngine(tc.FFmpegPath),
|
||||||
|
modelPool: whisper.NewModelPool(),
|
||||||
|
punct: restorer,
|
||||||
|
diarizer: diar,
|
||||||
|
models: NewRegistry(cfg.ModelsDir),
|
||||||
|
cache: cache,
|
||||||
|
mux: http.NewServeMux(),
|
||||||
|
queueWake: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
s.routes()
|
||||||
|
go s.warmModels()
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) warmModels() {
|
||||||
|
ids, err := s.models.List()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, id := range ids {
|
||||||
|
path, err := s.models.Path(id)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := s.modelPool.WithModel(path, func(whisper.Model) error { return nil }); err != nil {
|
||||||
|
log.Warn().Err(err).Str("model", id).Msg("preload whisper model")
|
||||||
|
} else {
|
||||||
|
log.Info().Str("model", id).Msg("whisper model loaded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) routes() {
|
||||||
|
s.mux.HandleFunc("/", s.handleSwaggerUI)
|
||||||
|
s.mux.HandleFunc("/swagger.json", s.handleSwaggerJSON)
|
||||||
|
s.mux.HandleFunc("/spr/models", s.handleModels)
|
||||||
|
s.mux.HandleFunc("/spr/hostname", s.handleHostname)
|
||||||
|
s.mux.HandleFunc("/spr/queue", s.handleQueue)
|
||||||
|
s.mux.HandleFunc("/spr/stt/", s.handleSTT)
|
||||||
|
s.mux.HandleFunc("/spr/result/", s.handleResult)
|
||||||
|
s.mux.HandleFunc("/spr/queue/", s.handleQueueItem)
|
||||||
|
s.mux.HandleFunc("/spr/audio/", s.handleAudio)
|
||||||
|
s.mux.HandleFunc("/spr/waveform/", s.handleWaveform)
|
||||||
|
s.mux.HandleFunc("/spr/delete/", s.handleDeleteModel)
|
||||||
|
s.mux.HandleFunc("/spr/export/", s.handleExportModel)
|
||||||
|
s.mux.HandleFunc("/spr/import/", s.handleImportModel)
|
||||||
|
s.mux.HandleFunc("/v1/audio/transcriptions", s.handleOpenAITranscriptions)
|
||||||
|
s.mux.HandleFunc("/v1/audio/transcriptions/", s.handleOpenAITranscriptions)
|
||||||
|
s.mux.HandleFunc("/v1/models", s.handleOpenAIModels)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) ListenAndServe() error {
|
||||||
|
log.Info().
|
||||||
|
Str("addr", s.cfg.Addr).
|
||||||
|
Str("models", s.cfg.ModelsDir).
|
||||||
|
Str("cache", s.cache.Root()).
|
||||||
|
Msg("starting API server")
|
||||||
|
return http.ListenAndServe(s.cfg.Addr, s.mux)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
models, err := s.models.List()
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"models": models})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleHostname(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
host, _ := os.Hostname()
|
||||||
|
cwd, _ := os.Getwd()
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
|
"error": 0,
|
||||||
|
"message": "Success",
|
||||||
|
"hostname": host,
|
||||||
|
"version": "go-whisper-api",
|
||||||
|
"cwd": cwd,
|
||||||
|
"models": s.cfg.ModelsDir,
|
||||||
|
"cache": s.cache.Root(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleQueue(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list, err := s.cache.List()
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, list)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleQueueItem(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id := strings.TrimPrefix(r.URL.Path, "/spr/queue/")
|
||||||
|
if id == "" || !isValidTaskID(id) {
|
||||||
|
writeError(w, http.StatusBadRequest, "task id required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
s.handleQueueGet(w, id)
|
||||||
|
case http.MethodDelete:
|
||||||
|
if !s.cache.Delete(id) {
|
||||||
|
writeAPIError(w, http.StatusNotFound, "TaskNotFound")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"})
|
||||||
|
default:
|
||||||
|
methodNotAllowed(w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleQueueGet(w http.ResponseWriter, id string) {
|
||||||
|
params, phase, err := s.cache.LoadParams(id)
|
||||||
|
if err != nil {
|
||||||
|
writeAPIError(w, http.StatusNotFound, "task not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch params.Status {
|
||||||
|
case string(statusReady):
|
||||||
|
if phase == cacheWaiting {
|
||||||
|
if err := s.cache.PromoteToReady(id); err != nil {
|
||||||
|
writeAPIError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"error": 0, "message": "Success"})
|
||||||
|
case string(statusError):
|
||||||
|
msg := params.Error
|
||||||
|
if msg == "" {
|
||||||
|
msg = "transcription failed"
|
||||||
|
}
|
||||||
|
writeAPIError(w, http.StatusNotFound, msg)
|
||||||
|
default:
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
|
"error": 0,
|
||||||
|
"message": params.Status,
|
||||||
|
"status": params.Status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleSTT(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelID := strings.TrimPrefix(r.URL.Path, "/spr/stt/")
|
||||||
|
if modelID == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "model id required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelPath, err := s.models.Path(modelID)
|
||||||
|
if err != nil {
|
||||||
|
writeAPIError(w, http.StatusNotFound, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
audioPath, cleanup, err := s.saveUploadedWav(r)
|
||||||
|
if err != nil {
|
||||||
|
writeAPIError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stt, err := s.parseSTTOptions(r)
|
||||||
|
if err != nil {
|
||||||
|
writeAPIError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if queryAsync(r, s.cfg.DefaultAsync) {
|
||||||
|
taskID, err := s.enqueueAsync(r, modelID, audioPath, stt)
|
||||||
|
cleanup()
|
||||||
|
if err != nil {
|
||||||
|
writeAPIError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{"taskID": taskID})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
result, err := s.transcribe(r.Context(), modelPath, audioPath, stt)
|
||||||
|
if err != nil {
|
||||||
|
writeAPIError(w, http.StatusMethodNotAllowed, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
|
"model": modelID,
|
||||||
|
"text": result.Text,
|
||||||
|
"words": result.Words,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleResult(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := strings.TrimPrefix(r.URL.Path, "/spr/result/")
|
||||||
|
if !isValidTaskID(id) {
|
||||||
|
writeAPIError(w, http.StatusBadRequest, "task id required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
params, _, err := s.cache.LoadParams(id)
|
||||||
|
if err != nil {
|
||||||
|
writeAPIError(w, http.StatusNotFound, "TaskNotFound")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch params.Status {
|
||||||
|
case string(statusWaiting), string(statusProcessing):
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{"status": params.Status})
|
||||||
|
case string(statusError):
|
||||||
|
msg := params.Error
|
||||||
|
if msg == "" {
|
||||||
|
msg = "TaskNotFound"
|
||||||
|
}
|
||||||
|
writeAPIError(w, http.StatusNotFound, msg)
|
||||||
|
case string(statusReady):
|
||||||
|
writeJSON(w, http.StatusOK, sprResultReady(params))
|
||||||
|
default:
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{"status": params.Status})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleAudio(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := strings.TrimPrefix(r.URL.Path, "/spr/audio/")
|
||||||
|
path, ok := s.cache.AudioPath(id)
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusNotFound, "task not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.ServeFile(w, r, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleWaveform(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := strings.TrimPrefix(r.URL.Path, "/spr/waveform/")
|
||||||
|
if !isValidTaskID(id) {
|
||||||
|
writeAPIError(w, http.StatusBadRequest, "task id required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wf, err := s.cache.Waveform(id)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"error": 0, "waveform": wf})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeleteModel(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodDelete {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := strings.TrimPrefix(r.URL.Path, "/spr/delete/")
|
||||||
|
if err := s.models.Delete(id); err != nil {
|
||||||
|
writeAPIError(w, http.StatusNotFound, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleExportModel(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := strings.TrimPrefix(r.URL.Path, "/spr/export/")
|
||||||
|
f, err := s.models.Open(id)
|
||||||
|
if err != nil {
|
||||||
|
writeAPIError(w, http.StatusNotFound, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q.bin", id))
|
||||||
|
w.Header().Set("Content-Type", "application/octet-stream")
|
||||||
|
io.Copy(w, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleImportModel(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := strings.TrimPrefix(r.URL.Path, "/spr/import/")
|
||||||
|
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||||
|
writeAPIError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file, header, err := r.FormFile("zip-model")
|
||||||
|
if err != nil {
|
||||||
|
file, header, err = r.FormFile("model")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
writeAPIError(w, http.StatusBadRequest, "model file required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
src := io.Reader(file)
|
||||||
|
if strings.HasSuffix(strings.ToLower(header.Filename), ".zip") {
|
||||||
|
writeAPIError(w, http.StatusBadRequest, "zip import is not supported; upload .bin model file as zip-model field")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.models.Import(id, src); err != nil {
|
||||||
|
writeAPIError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) enqueueAsync(r *http.Request, modelID, audioWavPath string, stt sttOptions) (string, error) {
|
||||||
|
id := uuid.New().String()
|
||||||
|
params := TaskParams{
|
||||||
|
ID: id,
|
||||||
|
Created: time.Now().Format("2006-01-02 15:04:05"),
|
||||||
|
Status: string(statusWaiting),
|
||||||
|
Model: modelID,
|
||||||
|
Language: stt.language,
|
||||||
|
Punctuation: stt.punctuate,
|
||||||
|
Speakers: stt.speakers,
|
||||||
|
NumClusters: stt.numClusters,
|
||||||
|
}
|
||||||
|
if err := s.cache.Enqueue(id, params, audioWavPath); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
s.notifyQueue()
|
||||||
|
log.Info().
|
||||||
|
Str("task", id).
|
||||||
|
Str("model", modelID).
|
||||||
|
Str("cache", s.cache.Root()).
|
||||||
|
Msg("enqueued async task")
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) transcribe(ctx context.Context, modelPath, audioPath string, stt sttOptions) (whisper.TranscriptResult, error) {
|
||||||
|
turns, err := s.runDiarization(ctx, audioPath, stt)
|
||||||
|
if err != nil {
|
||||||
|
return whisper.TranscriptResult{}, err
|
||||||
|
}
|
||||||
|
vad := s.cfg.VAD
|
||||||
|
if vad.Enabled {
|
||||||
|
vad.Model = vad.ResolveModelPath(s.cfg.ModelsDir)
|
||||||
|
}
|
||||||
|
cfg := &config.Whisper{
|
||||||
|
Model: modelPath,
|
||||||
|
AudioPath: audioPath,
|
||||||
|
Threads: s.cfg.Threads,
|
||||||
|
Language: stt.language,
|
||||||
|
Debug: s.cfg.Debug,
|
||||||
|
SpeedUp: s.cfg.SpeedUp,
|
||||||
|
Translate: s.cfg.Translate,
|
||||||
|
Prompt: s.cfg.Prompt,
|
||||||
|
MaxContext: s.cfg.MaxContext,
|
||||||
|
BeamSize: s.cfg.BeamSize,
|
||||||
|
EntropyThold: s.cfg.EntropyThold,
|
||||||
|
VAD: vad,
|
||||||
|
PrintProgress: false,
|
||||||
|
PrintSegment: false,
|
||||||
|
}
|
||||||
|
runOpts := s.whisperRunOpts(stt, turns)
|
||||||
|
if stt.punctuate && s.punct.Active() {
|
||||||
|
runOpts.PunctuateRestore = func(text string) (string, error) {
|
||||||
|
return punctuation.Apply(ctx, s.punct, true, text, stt.language)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result, err := whisper.TranscribeWithPool(s.modelPool, cfg, runOpts)
|
||||||
|
if err != nil {
|
||||||
|
return whisper.TranscriptResult{}, err
|
||||||
|
}
|
||||||
|
return applyGarbage(result, s.cfg.GarbagePatterns()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) runDiarization(ctx context.Context, audioPath string, stt sttOptions) ([]whisper.Turn, error) {
|
||||||
|
if !stt.speakers {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
samples, err := whisper.LoadPCM16Mono(audioPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("diarization audio: %w", err)
|
||||||
|
}
|
||||||
|
return s.diarizer.Process(ctx, samples, stt.numClusters)
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveUploadedRaw(r *http.Request) (path string, cleanup func(), err error) {
|
||||||
|
return saveUploadedRawFields(r, []string{"audio", "wav", "file"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveUploadedRawFields(r *http.Request, fieldNames []string) (path string, cleanup func(), err error) {
|
||||||
|
if r.MultipartForm == nil {
|
||||||
|
if err := r.ParseMultipartForm(128 << 20); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
file multipart.File
|
||||||
|
header *multipart.FileHeader
|
||||||
|
found bool
|
||||||
|
)
|
||||||
|
for _, name := range fieldNames {
|
||||||
|
file, header, err = r.FormFile(name)
|
||||||
|
if err == nil {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return "", nil, fmt.Errorf("audio file required (form field: %s)", strings.Join(fieldNames, ", "))
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
dir, err := config.MkdirTemp("go-whisper-api-upload-*")
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
cleanup = func() { os.RemoveAll(dir) }
|
||||||
|
base := "input"
|
||||||
|
if header != nil {
|
||||||
|
if ext := filepath.Ext(header.Filename); ext != "" {
|
||||||
|
base += ext
|
||||||
|
}
|
||||||
|
}
|
||||||
|
raw := filepath.Join(dir, base)
|
||||||
|
out, err := os.Create(raw)
|
||||||
|
if err != nil {
|
||||||
|
cleanup()
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
if _, err := io.Copy(out, file); err != nil {
|
||||||
|
out.Close()
|
||||||
|
cleanup()
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
out.Close()
|
||||||
|
return raw, cleanup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) saveUploadedWav(r *http.Request) (path string, cleanup func(), err error) {
|
||||||
|
raw, cleanup, err := saveUploadedRaw(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
dst := filepath.Join(filepath.Dir(raw), "audio.wav")
|
||||||
|
if err := s.transcode.Transcode(r.Context(), raw, dst, transcode.WhisperOptions()); err != nil {
|
||||||
|
cleanup()
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
return dst, cleanup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryBoolDefault(r *http.Request, key string, def bool) bool {
|
||||||
|
v := r.URL.Query().Get(key)
|
||||||
|
if v == "" {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
b, err := strconv.ParseBool(v)
|
||||||
|
if err != nil {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryAsync(r *http.Request, defaultAsync bool) bool {
|
||||||
|
v := r.URL.Query().Get("async")
|
||||||
|
if v == "" {
|
||||||
|
return defaultAsync
|
||||||
|
}
|
||||||
|
n, err := strconv.Atoi(v)
|
||||||
|
if err != nil {
|
||||||
|
return defaultAsync
|
||||||
|
}
|
||||||
|
return n == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryInt(r *http.Request, key string, def int) int {
|
||||||
|
v := r.URL.Query().Get(key)
|
||||||
|
if v == "" {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
n, err := strconv.Atoi(v)
|
||||||
|
if err != nil {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSON(w http.ResponseWriter, code int, v any) {
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
w.WriteHeader(code)
|
||||||
|
_ = json.NewEncoder(w).Encode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeError(w http.ResponseWriter, code int, msg string) {
|
||||||
|
http.Error(w, msg, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeAPIError(w http.ResponseWriter, code int, msg string) {
|
||||||
|
writeJSON(w, code, map[string]any{"error": 1, "message": msg})
|
||||||
|
}
|
||||||
|
|
||||||
|
func methodNotAllowed(w http.ResponseWriter) {
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) notifyQueue() {
|
||||||
|
select {
|
||||||
|
case s.queueWake <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Run(ctx context.Context, cfg config.API, tc config.Transcode, pc config.Punctuation, dc config.Diarization) error {
|
||||||
|
srv, err := NewServer(cfg, tc, pc, dc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer srv.modelPool.Close()
|
||||||
|
srv.StartWorker(ctx)
|
||||||
|
hs := &http.Server{Addr: cfg.Addr, Handler: srv.mux}
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = hs.Shutdown(shutdownCtx)
|
||||||
|
}()
|
||||||
|
if err := hs.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
punctuation.Close(srv.punct)
|
||||||
|
srv.diarizer.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
23
api/swagger-ui.html
Normal file
23
api/swagger-ui.html
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>go-whisper-api</title>
|
||||||
|
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css">
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="swagger-ui"></div>
|
||||||
|
<script src="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
|
||||||
|
<script>
|
||||||
|
window.onload = function () {
|
||||||
|
window.ui = SwaggerUIBundle({
|
||||||
|
url: '/swagger.json',
|
||||||
|
dom_id: '#swagger-ui',
|
||||||
|
presets: [SwaggerUIBundle.presets.apis],
|
||||||
|
docExpansion: 'list',
|
||||||
|
displayRequestDuration: true,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
34
api/swagger.go
Normal file
34
api/swagger.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed swagger.json
|
||||||
|
var swaggerSpec []byte
|
||||||
|
|
||||||
|
//go:embed swagger-ui.html
|
||||||
|
var swaggerUI []byte
|
||||||
|
|
||||||
|
func (s *Server) handleSwaggerJSON(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
w.Write(swaggerSpec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleSwaggerUI(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
methodNotAllowed(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
w.Write(swaggerUI)
|
||||||
|
}
|
||||||
457
api/swagger.json
Normal file
457
api/swagger.json
Normal file
@ -0,0 +1,457 @@
|
|||||||
|
{
|
||||||
|
"swagger": "2.0",
|
||||||
|
"basePath": "/",
|
||||||
|
"paths": {
|
||||||
|
"/spr/audio/{taskID}": {
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "taskID",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Success"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "get_audio_stt",
|
||||||
|
"tags": [
|
||||||
|
"spr"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/spr/delete/{id}": {
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "id",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"delete": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Success"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "delete_model_delete",
|
||||||
|
"tags": [
|
||||||
|
"spr"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/spr/export/{id}": {
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "id",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Success"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "get_model_export",
|
||||||
|
"tags": [
|
||||||
|
"spr"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/spr/hostname": {
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Success"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "get_hostname_class",
|
||||||
|
"tags": [
|
||||||
|
"spr"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/spr/import/{id}": {
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "id",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"post": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Success"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "post_model_import",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "zip-model",
|
||||||
|
"in": "formData",
|
||||||
|
"type": "file",
|
||||||
|
"required": true,
|
||||||
|
"description": "prepared model zip file"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"consumes": [
|
||||||
|
"multipart/form-data"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"spr"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/spr/models": {
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Success",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/modelList"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "get_model_list",
|
||||||
|
"tags": [
|
||||||
|
"spr"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/spr/queue": {
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Success"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "get_queue_stt",
|
||||||
|
"tags": [
|
||||||
|
"spr"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/spr/queue/{taskID}": {
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "taskID",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"delete": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Success",
|
||||||
|
"schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"error": { "type": "integer" },
|
||||||
|
"message": { "type": "string" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"404": {
|
||||||
|
"description": "TaskNotFound"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "delete_queue_del_stt",
|
||||||
|
"tags": [
|
||||||
|
"spr"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/spr/result/{taskID}": {
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "taskID",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"404": {
|
||||||
|
"description": "Not found",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/error"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"200": {
|
||||||
|
"description": "Success",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/resultTTS"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "get_result_stt",
|
||||||
|
"tags": [
|
||||||
|
"spr"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/spr/stt/{id}": {
|
||||||
|
"post": {
|
||||||
|
"responses": {
|
||||||
|
"405": {
|
||||||
|
"description": "Error",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/error"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"404": {
|
||||||
|
"description": "Not found",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/error"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"200": {
|
||||||
|
"description": "Success",
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/definitions/modelTTS"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "post_model_test",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "id",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"type": "string",
|
||||||
|
"description": "NN Model ID"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "wav",
|
||||||
|
"in": "formData",
|
||||||
|
"type": "file",
|
||||||
|
"description": "file to recognize"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "async",
|
||||||
|
"in": "query",
|
||||||
|
"type": "integer",
|
||||||
|
"description": "async mode (default 1: enqueue to cache/waiting; 0: sync text response)",
|
||||||
|
"default": 1,
|
||||||
|
"enum": [
|
||||||
|
0,
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "speakers",
|
||||||
|
"in": "query",
|
||||||
|
"type": "integer",
|
||||||
|
"description": "find speakers",
|
||||||
|
"default": 0,
|
||||||
|
"enum": [
|
||||||
|
0,
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "speaker_counter",
|
||||||
|
"in": "query",
|
||||||
|
"type": "integer",
|
||||||
|
"description": "number of speakers. 0 for autodetect. -1 disable speaker detection.",
|
||||||
|
"default": 0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "normalization",
|
||||||
|
"in": "query",
|
||||||
|
"type": "integer",
|
||||||
|
"description": "normalize text",
|
||||||
|
"default": 1,
|
||||||
|
"enum": [
|
||||||
|
0,
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "punctuation",
|
||||||
|
"in": "query",
|
||||||
|
"type": "integer",
|
||||||
|
"description": "punctuate text",
|
||||||
|
"default": 1,
|
||||||
|
"enum": [
|
||||||
|
0,
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "toxicity",
|
||||||
|
"in": "query",
|
||||||
|
"type": "integer",
|
||||||
|
"description": "toxicity analyzer",
|
||||||
|
"default": 1,
|
||||||
|
"enum": [
|
||||||
|
0,
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "emotion",
|
||||||
|
"in": "query",
|
||||||
|
"type": "integer",
|
||||||
|
"description": "emotion analyzer",
|
||||||
|
"default": 0,
|
||||||
|
"enum": [
|
||||||
|
0,
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "voice_analyzer",
|
||||||
|
"in": "query",
|
||||||
|
"type": "integer",
|
||||||
|
"description": "voice analyzer",
|
||||||
|
"default": 1,
|
||||||
|
"enum": [
|
||||||
|
0,
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "vad",
|
||||||
|
"in": "query",
|
||||||
|
"type": "string",
|
||||||
|
"description": "VAD type",
|
||||||
|
"default": "webrtc",
|
||||||
|
"enum": [
|
||||||
|
"webrtc"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "classifiers",
|
||||||
|
"in": "query",
|
||||||
|
"type": "string",
|
||||||
|
"description": "JSON with classification models to process each sentence"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "webhook",
|
||||||
|
"in": "query",
|
||||||
|
"type": "string",
|
||||||
|
"description": "webhook url to send stt async result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"consumes": [
|
||||||
|
"multipart/form-data"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"spr"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/spr/waveform/{taskID}": {
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "taskID",
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"get": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Success"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "get_audioarray_stt",
|
||||||
|
"tags": [
|
||||||
|
"spr"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"info": {
|
||||||
|
"title": "go-whisper-api",
|
||||||
|
"version": "5.008 release",
|
||||||
|
"description": "Whisper speech-to-text API (SPR-compatible)"
|
||||||
|
},
|
||||||
|
"produces": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"consumes": [
|
||||||
|
"application/json"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
{
|
||||||
|
"name": "spr",
|
||||||
|
"description": "NN Model operations"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"definitions": {
|
||||||
|
"modelList": {
|
||||||
|
"properties": {
|
||||||
|
"models": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "NN Model ID"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"properties": {
|
||||||
|
"error": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Error flag"
|
||||||
|
},
|
||||||
|
"message": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Error description"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
|
"modelTTS": {
|
||||||
|
"required": [
|
||||||
|
"text"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Recognized text"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
|
"resultTTS": {
|
||||||
|
"properties": {
|
||||||
|
"model": { "type": "string" },
|
||||||
|
"text": { "type": "string" },
|
||||||
|
"words": { "type": "array" },
|
||||||
|
"toxicity": { "type": "object" },
|
||||||
|
"emotion": { "type": "object" },
|
||||||
|
"voice_analysis": { "type": "object" },
|
||||||
|
"status": { "type": "string" },
|
||||||
|
"taskID": { "type": "string" },
|
||||||
|
"created": { "type": "string" },
|
||||||
|
"processed": { "type": "string" }
|
||||||
|
},
|
||||||
|
"type": "object"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"ParseError": {
|
||||||
|
"description": "When a mask can't be parsed"
|
||||||
|
},
|
||||||
|
"MaskError": {
|
||||||
|
"description": "When any error occurs on mask"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
10
api/tasks.go
Normal file
10
api/tasks.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
type taskStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
statusWaiting taskStatus = "waiting"
|
||||||
|
statusProcessing taskStatus = "processing"
|
||||||
|
statusReady taskStatus = "ready"
|
||||||
|
statusError taskStatus = "error"
|
||||||
|
)
|
||||||
77
api/transcribe_opts.go
Normal file
77
api/transcribe_opts.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
"go-whisper-api/whisper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sttOptions struct {
|
||||||
|
language string
|
||||||
|
punctuate bool
|
||||||
|
speakers bool
|
||||||
|
numClusters int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) parseSTTOptions(r *http.Request) (sttOptions, error) {
|
||||||
|
opts := sttOptions{
|
||||||
|
language: resolveLanguage(r, s.cfg.Language),
|
||||||
|
punctuate: s.punctCfg.ShouldApplyAPI(r, s.cfg.DefaultPunctuation),
|
||||||
|
}
|
||||||
|
sp, clusters, err := querySpeakers(r, s.cfg.DefaultSpeakers, s.diarCfg, s.diarizer.Active())
|
||||||
|
if err != nil {
|
||||||
|
return opts, err
|
||||||
|
}
|
||||||
|
opts.speakers = sp
|
||||||
|
opts.numClusters = clusters
|
||||||
|
return opts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveLanguage(r *http.Request, defaultLang string) string {
|
||||||
|
if v := strings.TrimSpace(r.URL.Query().Get("language")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(defaultLang)
|
||||||
|
}
|
||||||
|
|
||||||
|
func querySpeakers(r *http.Request, defaultOn bool, dc config.Diarization, diarizerActive bool) (enabled bool, numClusters int, err error) {
|
||||||
|
counter := queryInt(r, "speaker_counter", -999)
|
||||||
|
if counter == -1 {
|
||||||
|
return false, 0, nil
|
||||||
|
}
|
||||||
|
speakersQ := r.URL.Query().Get("speakers")
|
||||||
|
enabled = defaultOn
|
||||||
|
if speakersQ != "" {
|
||||||
|
enabled = queryInt(r, "speakers", 0) == 1
|
||||||
|
}
|
||||||
|
if !enabled {
|
||||||
|
return false, 0, nil
|
||||||
|
}
|
||||||
|
if !dc.Active() {
|
||||||
|
return false, 0, fmt.Errorf("speaker diarization is disabled in config (diarization.enabled: true)")
|
||||||
|
}
|
||||||
|
if !diarizerActive {
|
||||||
|
return false, 0, fmt.Errorf("speaker diarization requires server built with -tags sherpa (make build-sherpa) and models (make download-diarization-models)")
|
||||||
|
}
|
||||||
|
if counter > 0 {
|
||||||
|
numClusters = counter
|
||||||
|
} else if dc.NumClusters > 0 {
|
||||||
|
numClusters = dc.NumClusters
|
||||||
|
}
|
||||||
|
return true, numClusters, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) whisperRunOpts(stt sttOptions, turns []whisper.Turn) whisper.RunOptions {
|
||||||
|
t := s.cfg.Transcript.WithDefaults()
|
||||||
|
return whisper.RunOptions{
|
||||||
|
Turns: turns,
|
||||||
|
Format: whisper.FormatOptions{
|
||||||
|
PauseGap: t.PauseGapDuration(),
|
||||||
|
SpeakerLabel: t.SpeakerLabel,
|
||||||
|
UseSpeakers: stt.speakers && len(turns) > 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
48
api/waveform.go
Normal file
48
api/waveform.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/go-audio/wav"
|
||||||
|
)
|
||||||
|
|
||||||
|
func waveformFromWav(path string, buckets int) ([]float64, error) {
|
||||||
|
if buckets <= 0 {
|
||||||
|
buckets = 256
|
||||||
|
}
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
dec := wav.NewDecoder(f)
|
||||||
|
buf, err := dec.FullPCMBuffer()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
samples := buf.AsFloat32Buffer().Data
|
||||||
|
if len(samples) == 0 {
|
||||||
|
return make([]float64, buckets), nil
|
||||||
|
}
|
||||||
|
chunk := len(samples) / buckets
|
||||||
|
if chunk < 1 {
|
||||||
|
chunk = 1
|
||||||
|
}
|
||||||
|
out := make([]float64, 0, buckets)
|
||||||
|
for i := 0; i < len(samples) && len(out) < buckets; i += chunk {
|
||||||
|
end := i + chunk
|
||||||
|
if end > len(samples) {
|
||||||
|
end = len(samples)
|
||||||
|
}
|
||||||
|
peak := 0.0
|
||||||
|
for _, s := range samples[i:end] {
|
||||||
|
v := math.Abs(float64(s))
|
||||||
|
if v > peak {
|
||||||
|
peak = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, math.Round(peak*1000)/1000)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
55
config.yaml.example
Normal file
55
config.yaml.example
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
api:
|
||||||
|
addr: "0.0.0.0:6183"
|
||||||
|
models_dir: "./models"
|
||||||
|
cache_dir: "./cache"
|
||||||
|
# Open WebUI / OpenAI STT: model id when client sends whisper-1 (e.g. ggml-large-v3-turbo)
|
||||||
|
default_model: ggml-large-v3-turbo
|
||||||
|
threads: 16
|
||||||
|
language: ru
|
||||||
|
transcript:
|
||||||
|
pause_gap_sec: 1.5
|
||||||
|
speaker_label: "Спикер"
|
||||||
|
default_speakers: false
|
||||||
|
debug: false
|
||||||
|
speedup: false
|
||||||
|
translate: false
|
||||||
|
prompt: ""
|
||||||
|
max_context: 32
|
||||||
|
beam_size: 5
|
||||||
|
entropy_thold: 2.4
|
||||||
|
vad:
|
||||||
|
enabled: false
|
||||||
|
model: vad/vad.bin
|
||||||
|
threshold: 0.5
|
||||||
|
min_speech_duration_ms: 250
|
||||||
|
min_silence_duration_ms: 100
|
||||||
|
speech_pad_ms: 30
|
||||||
|
samples_overlap: 0.1
|
||||||
|
default_punctuation: true
|
||||||
|
default_async: true
|
||||||
|
garbage:
|
||||||
|
- "*выбая*"
|
||||||
|
|
||||||
|
# transcode: pure Go (wav, mp3, flac, ogg, m4a, mp4, aac) — no ffmpeg required
|
||||||
|
transcode: {}
|
||||||
|
|
||||||
|
diarization:
|
||||||
|
enabled: false
|
||||||
|
model_dir: ./models/diarization
|
||||||
|
segmentation_model: pyannote-segmentation-3-0/model.onnx
|
||||||
|
embedding_model: 3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
|
||||||
|
num_threads: 2
|
||||||
|
num_clusters: 0
|
||||||
|
clustering_threshold: 0.5
|
||||||
|
|
||||||
|
punctuation:
|
||||||
|
enabled: true
|
||||||
|
default_on: true
|
||||||
|
# off | heuristic | xlm | sherpa | sherpa-online | http
|
||||||
|
engine: xlm
|
||||||
|
model_dir: ./models/punctuation/xlm-roberta
|
||||||
|
model_file: model.onnx
|
||||||
|
sp_model: sp.model
|
||||||
|
config_file: config.yaml
|
||||||
|
apply_sbd: true
|
||||||
|
num_threads: 2
|
||||||
42
config/api.go
Normal file
42
config/api.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
type API struct {
|
||||||
|
Addr string `yaml:"addr"`
|
||||||
|
ModelsDir string `yaml:"models_dir"`
|
||||||
|
CacheDir string `yaml:"cache_dir"`
|
||||||
|
// DefaultModel: whisper model id for OpenAI /v1/audio/transcriptions (maps whisper-1).
|
||||||
|
DefaultModel string `yaml:"default_model"`
|
||||||
|
Language string `yaml:"language"`
|
||||||
|
Transcript Transcript `yaml:"transcript"`
|
||||||
|
DefaultSpeakers bool `yaml:"default_speakers"`
|
||||||
|
Prompt string `yaml:"prompt"`
|
||||||
|
Threads uint `yaml:"threads"`
|
||||||
|
MaxContext uint `yaml:"max_context"`
|
||||||
|
BeamSize uint `yaml:"beam_size"`
|
||||||
|
EntropyThold float64 `yaml:"entropy_thold"`
|
||||||
|
VAD VAD `yaml:"vad"`
|
||||||
|
Debug bool `yaml:"debug"`
|
||||||
|
SpeedUp bool `yaml:"speedup"`
|
||||||
|
Translate bool `yaml:"translate"`
|
||||||
|
DefaultPunctuation bool `yaml:"default_punctuation"`
|
||||||
|
// DefaultAsync: STT via API enqueues to cache/waiting and returns taskID (use ?async=0 for sync).
|
||||||
|
DefaultAsync bool `yaml:"default_async"`
|
||||||
|
// Garbage: artifact substrings removed from transcript text and words (default includes *выбая*).
|
||||||
|
Garbage []string `yaml:"garbage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a API) WithDefaults() API {
|
||||||
|
a.VAD = a.VAD.WithDefaults()
|
||||||
|
a.Transcript = a.Transcript.WithDefaults()
|
||||||
|
if strings.TrimSpace(a.Language) == "" {
|
||||||
|
a.Language = "ru"
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
// GarbagePatterns returns garbage filter list (never empty unless explicitly set to [] in YAML).
|
||||||
|
func (a API) GarbagePatterns() []string {
|
||||||
|
return a.garbagePatterns()
|
||||||
|
}
|
||||||
72
config/diarization.go
Normal file
72
config/diarization.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Diarization struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
ModelDir string `yaml:"model_dir"`
|
||||||
|
SegmentationModel string `yaml:"segmentation_model"`
|
||||||
|
EmbeddingModel string `yaml:"embedding_model"`
|
||||||
|
NumThreads int `yaml:"num_threads"`
|
||||||
|
NumClusters int `yaml:"num_clusters"`
|
||||||
|
ClusteringThreshold float32 `yaml:"clustering_threshold"`
|
||||||
|
MinDurationOn float32 `yaml:"min_duration_on"`
|
||||||
|
MinDurationOff float32 `yaml:"min_duration_off"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Diarization) WithDefaults() Diarization {
|
||||||
|
if d.ModelDir == "" {
|
||||||
|
d.ModelDir = "./models/diarization"
|
||||||
|
}
|
||||||
|
if d.SegmentationModel == "" {
|
||||||
|
d.SegmentationModel = "pyannote-segmentation-3-0/model.onnx"
|
||||||
|
}
|
||||||
|
if d.EmbeddingModel == "" {
|
||||||
|
d.EmbeddingModel = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
|
||||||
|
}
|
||||||
|
if d.NumThreads <= 0 {
|
||||||
|
d.NumThreads = 2
|
||||||
|
}
|
||||||
|
if d.ClusteringThreshold <= 0 {
|
||||||
|
d.ClusteringThreshold = 0.5
|
||||||
|
}
|
||||||
|
if d.MinDurationOn <= 0 {
|
||||||
|
d.MinDurationOn = 0.3
|
||||||
|
}
|
||||||
|
if d.MinDurationOff <= 0 {
|
||||||
|
d.MinDurationOff = 0.5
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Diarization) SegmentationPath() string {
|
||||||
|
return resolveModelPath(d.ModelDir, d.SegmentationModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Diarization) EmbeddingPath() string {
|
||||||
|
return resolveModelPath(d.ModelDir, d.EmbeddingModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Diarization) ModelsPresent() bool {
|
||||||
|
if _, err := os.Stat(d.SegmentationPath()); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(d.EmbeddingPath()); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveModelPath(dir, name string) string {
|
||||||
|
if filepath.IsAbs(name) {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
return filepath.Join(dir, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Diarization) Active() bool {
|
||||||
|
return d.Enabled
|
||||||
|
}
|
||||||
53
config/file.go
Normal file
53
config/file.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type File struct {
|
||||||
|
API API `yaml:"api"`
|
||||||
|
Transcode Transcode `yaml:"transcode"`
|
||||||
|
Punctuation Punctuation `yaml:"punctuation"`
|
||||||
|
Diarization Diarization `yaml:"diarization"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultFile() File {
|
||||||
|
return File{
|
||||||
|
API: API{
|
||||||
|
Addr: ":8080",
|
||||||
|
ModelsDir: "./models",
|
||||||
|
CacheDir: "./cache",
|
||||||
|
Threads: uint(runtime.NumCPU()),
|
||||||
|
Language: "ru",
|
||||||
|
Transcript: Transcript{}.WithDefaults(),
|
||||||
|
MaxContext: 32,
|
||||||
|
BeamSize: 5,
|
||||||
|
EntropyThold: 2.4,
|
||||||
|
DefaultAsync: true,
|
||||||
|
Garbage: DefaultGarbage(),
|
||||||
|
},
|
||||||
|
Diarization: Diarization{}.WithDefaults(),
|
||||||
|
Transcode: Transcode{}.WithDefaults(),
|
||||||
|
Punctuation: Punctuation{Enabled: false, Engine: "off"}.WithDefaults(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadFile(path string) (File, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return File{}, err
|
||||||
|
}
|
||||||
|
cfg := DefaultFile()
|
||||||
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return File{}, fmt.Errorf("parse config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f File) APIConfig() API {
|
||||||
|
return f.API
|
||||||
|
}
|
||||||
35
config/file_test.go
Normal file
35
config/file_test.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadFile(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "config.yaml")
|
||||||
|
err := os.WriteFile(path, []byte(`
|
||||||
|
api:
|
||||||
|
addr: ":9090"
|
||||||
|
models_dir: "/data/models"
|
||||||
|
language: ru
|
||||||
|
`), 0o644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := LoadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if cfg.API.Addr != ":9090" {
|
||||||
|
t.Fatalf("addr: got %q", cfg.API.Addr)
|
||||||
|
}
|
||||||
|
if cfg.API.ModelsDir != "/data/models" {
|
||||||
|
t.Fatalf("models_dir: got %q", cfg.API.ModelsDir)
|
||||||
|
}
|
||||||
|
if cfg.API.Language != "ru" {
|
||||||
|
t.Fatalf("language: got %q", cfg.API.Language)
|
||||||
|
}
|
||||||
|
}
|
||||||
13
config/garbage.go
Normal file
13
config/garbage.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
// DefaultGarbage returns STT artifact tokens filtered from API/CLI transcript output.
|
||||||
|
func DefaultGarbage() []string {
|
||||||
|
return []string{"*выбая*"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a API) garbagePatterns() []string {
|
||||||
|
if a.Garbage == nil {
|
||||||
|
return DefaultGarbage()
|
||||||
|
}
|
||||||
|
return a.Garbage
|
||||||
|
}
|
||||||
23
config/garbage_test.go
Normal file
23
config/garbage_test.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDefaultGarbage(t *testing.T) {
|
||||||
|
if len(DefaultGarbage()) == 0 || DefaultGarbage()[0] != "*выбая*" {
|
||||||
|
t.Fatalf("got %v", DefaultGarbage())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPI_GarbagePatterns_default(t *testing.T) {
|
||||||
|
p := (API{}).GarbagePatterns()
|
||||||
|
if len(p) != 1 || p[0] != "*выбая*" {
|
||||||
|
t.Fatalf("got %v", p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPI_GarbagePatterns_explicitEmpty(t *testing.T) {
|
||||||
|
p := (API{Garbage: []string{}}).GarbagePatterns()
|
||||||
|
if len(p) != 0 {
|
||||||
|
t.Fatalf("got %v", p)
|
||||||
|
}
|
||||||
|
}
|
||||||
136
config/merge.go
Normal file
136
config/merge.go
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/urfave/cli/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadResolved(path string) (File, error) {
|
||||||
|
if path == "" {
|
||||||
|
if _, err := os.Stat("config.yaml"); err == nil {
|
||||||
|
path = "config.yaml"
|
||||||
|
} else {
|
||||||
|
return DefaultFile(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return LoadFile(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeVAD(c *cli.Context, v VAD) VAD {
|
||||||
|
if c.IsSet("vad") {
|
||||||
|
v.Enabled = c.Bool("vad")
|
||||||
|
}
|
||||||
|
if c.IsSet("vad-model") {
|
||||||
|
v.Model = c.String("vad-model")
|
||||||
|
}
|
||||||
|
if c.IsSet("vad-threshold") {
|
||||||
|
v.Threshold = c.Float64("vad-threshold")
|
||||||
|
}
|
||||||
|
if c.IsSet("vad-min-speech-ms") {
|
||||||
|
v.MinSpeechMs = c.Int("vad-min-speech-ms")
|
||||||
|
}
|
||||||
|
if c.IsSet("vad-min-silence-ms") {
|
||||||
|
v.MinSilenceMs = c.Int("vad-min-silence-ms")
|
||||||
|
}
|
||||||
|
if c.IsSet("vad-max-speech-sec") {
|
||||||
|
v.MaxSpeechSec = c.Float64("vad-max-speech-sec")
|
||||||
|
}
|
||||||
|
if c.IsSet("vad-speech-pad-ms") {
|
||||||
|
v.SpeechPadMs = c.Int("vad-speech-pad-ms")
|
||||||
|
}
|
||||||
|
if c.IsSet("vad-samples-overlap") {
|
||||||
|
v.SamplesOverlap = c.Float64("vad-samples-overlap")
|
||||||
|
}
|
||||||
|
return v.WithDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeAPI(c *cli.Context, a API) API {
|
||||||
|
if c.IsSet("addr") {
|
||||||
|
a.Addr = c.String("addr")
|
||||||
|
}
|
||||||
|
if c.IsSet("models-dir") {
|
||||||
|
a.ModelsDir = c.String("models-dir")
|
||||||
|
}
|
||||||
|
if c.IsSet("cache-dir") {
|
||||||
|
a.CacheDir = c.String("cache-dir")
|
||||||
|
}
|
||||||
|
if c.IsSet("threads") {
|
||||||
|
a.Threads = c.Uint("threads")
|
||||||
|
}
|
||||||
|
if c.IsSet("language") {
|
||||||
|
a.Language = c.String("language")
|
||||||
|
}
|
||||||
|
if c.IsSet("debug") {
|
||||||
|
a.Debug = c.Bool("debug")
|
||||||
|
}
|
||||||
|
if c.IsSet("speedup") {
|
||||||
|
a.SpeedUp = c.Bool("speedup")
|
||||||
|
}
|
||||||
|
if c.IsSet("translate") {
|
||||||
|
a.Translate = c.Bool("translate")
|
||||||
|
}
|
||||||
|
if c.IsSet("prompt") {
|
||||||
|
a.Prompt = c.String("prompt")
|
||||||
|
}
|
||||||
|
if c.IsSet("max-context") {
|
||||||
|
a.MaxContext = c.Uint("max-context")
|
||||||
|
}
|
||||||
|
if c.IsSet("beam-size") {
|
||||||
|
a.BeamSize = c.Uint("beam-size")
|
||||||
|
}
|
||||||
|
if c.IsSet("entropy-thold") {
|
||||||
|
a.EntropyThold = c.Float64("entropy-thold")
|
||||||
|
}
|
||||||
|
a.VAD = mergeVAD(c, a.VAD)
|
||||||
|
if c.IsSet("default-punctuation") {
|
||||||
|
a.DefaultPunctuation = c.Bool("default-punctuation")
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func APIFromCLI(c *cli.Context) (API, error) {
|
||||||
|
file, err := LoadResolved(c.String("config"))
|
||||||
|
if err != nil {
|
||||||
|
return API{}, err
|
||||||
|
}
|
||||||
|
return mergeAPI(c, file.API), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TranscodeFromCLI(c *cli.Context) (Transcode, error) {
|
||||||
|
file, err := LoadResolved(c.String("config"))
|
||||||
|
if err != nil {
|
||||||
|
return Transcode{}, err
|
||||||
|
}
|
||||||
|
return file.Transcode.WithDefaults(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergePunctuation(c *cli.Context, p Punctuation) Punctuation {
|
||||||
|
if c.IsSet("punctuation-enabled") {
|
||||||
|
p.Enabled = c.Bool("punctuation-enabled")
|
||||||
|
}
|
||||||
|
if c.IsSet("punctuation-engine") {
|
||||||
|
p.Engine = c.String("punctuation-engine")
|
||||||
|
}
|
||||||
|
if c.IsSet("punctuation-default-on") {
|
||||||
|
p.DefaultOn = c.Bool("punctuation-default-on")
|
||||||
|
}
|
||||||
|
return p.WithDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
|
func PunctuationFromCLI(c *cli.Context) (Punctuation, error) {
|
||||||
|
file, err := LoadResolved(c.String("config"))
|
||||||
|
if err != nil {
|
||||||
|
return Punctuation{}, err
|
||||||
|
}
|
||||||
|
return mergePunctuation(c, file.Punctuation), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DiarizationFromCLI(c *cli.Context) (Diarization, error) {
|
||||||
|
file, err := LoadResolved(c.String("config"))
|
||||||
|
if err != nil {
|
||||||
|
return Diarization{}, err
|
||||||
|
}
|
||||||
|
return file.Diarization.WithDefaults(), nil
|
||||||
|
}
|
||||||
|
|
||||||
28
config/merge_test.go
Normal file
28
config/merge_test.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultFile_apiDefaults(t *testing.T) {
|
||||||
|
f := DefaultFile()
|
||||||
|
if f.API.Addr == "" {
|
||||||
|
t.Fatal("expected API listen addr")
|
||||||
|
}
|
||||||
|
if f.API.ModelsDir == "" {
|
||||||
|
t.Fatal("expected models_dir")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMkdirTemp_usesTmpRoot(t *testing.T) {
|
||||||
|
dir, err := MkdirTemp("go-whisper-api-test-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() { _ = os.RemoveAll(dir) }()
|
||||||
|
if !strings.HasPrefix(dir, TempRoot+"/") {
|
||||||
|
t.Fatalf("temp dir should be under %s, got %s", TempRoot, dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
125
config/punctuation.go
Normal file
125
config/punctuation.go
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p Punctuation) XLMJoinSentences() bool {
|
||||||
|
return p.ApplySBD
|
||||||
|
}
|
||||||
|
|
||||||
|
type Punctuation struct {
|
||||||
|
Command []string `yaml:"command"`
|
||||||
|
NumThreads int `yaml:"num_threads"`
|
||||||
|
TimeoutSec int `yaml:"timeout_sec"`
|
||||||
|
Engine string `yaml:"engine"`
|
||||||
|
ModelDir string `yaml:"model_dir"`
|
||||||
|
ModelFile string `yaml:"model_file"`
|
||||||
|
SPModel string `yaml:"sp_model"`
|
||||||
|
ConfigFile string `yaml:"config_file"`
|
||||||
|
BpeVocab string `yaml:"bpe_vocab"`
|
||||||
|
HTTPURL string `yaml:"http_url"`
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
DefaultOn bool `yaml:"default_on"`
|
||||||
|
ApplySBD bool `yaml:"apply_sbd"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Punctuation) WithDefaults() Punctuation {
|
||||||
|
if p.Engine == "" {
|
||||||
|
p.Engine = "heuristic"
|
||||||
|
}
|
||||||
|
engine := strings.ToLower(strings.TrimSpace(p.Engine))
|
||||||
|
if p.ModelDir == "" {
|
||||||
|
if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" {
|
||||||
|
p.ModelDir = "./models/punctuation/xlm-roberta"
|
||||||
|
} else {
|
||||||
|
p.ModelDir = "./models/punctuation/ct-transformer-zh-en-int8"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if p.ModelFile == "" {
|
||||||
|
if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" {
|
||||||
|
p.ModelFile = "model.onnx"
|
||||||
|
} else {
|
||||||
|
p.ModelFile = "model.int8.onnx"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if p.NumThreads <= 0 {
|
||||||
|
p.NumThreads = 2
|
||||||
|
}
|
||||||
|
if p.TimeoutSec <= 0 {
|
||||||
|
p.TimeoutSec = 120
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Punctuation) Timeout() time.Duration {
|
||||||
|
p = p.WithDefaults()
|
||||||
|
return time.Duration(p.TimeoutSec) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Punctuation) Active() bool {
|
||||||
|
p = p.WithDefaults()
|
||||||
|
return p.Enabled && p.Engine != "" && !strings.EqualFold(p.Engine, "off")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Punctuation) ModelPath() string {
|
||||||
|
p = p.WithDefaults()
|
||||||
|
return filepath.Join(p.ModelDir, p.ModelFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Punctuation) SPModelPath() string {
|
||||||
|
p = p.WithDefaults()
|
||||||
|
if p.SPModel == "" {
|
||||||
|
return filepath.Join(p.ModelDir, "sp.model")
|
||||||
|
}
|
||||||
|
return filepath.Join(p.ModelDir, p.SPModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Punctuation) XLMConfigPath() string {
|
||||||
|
p = p.WithDefaults()
|
||||||
|
if p.ConfigFile == "" {
|
||||||
|
return filepath.Join(p.ModelDir, "config.yaml")
|
||||||
|
}
|
||||||
|
return filepath.Join(p.ModelDir, p.ConfigFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Punctuation) BpeVocabPath() string {
|
||||||
|
p = p.WithDefaults()
|
||||||
|
if p.BpeVocab == "" {
|
||||||
|
return filepath.Join(p.ModelDir, "bpe.vocab")
|
||||||
|
}
|
||||||
|
return filepath.Join(p.ModelDir, p.BpeVocab)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Punctuation) ShouldApplyAPI(r *http.Request, apiDefault bool) bool {
|
||||||
|
if !p.Active() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
q := strings.TrimSpace(r.URL.Query().Get("punctuation"))
|
||||||
|
if q != "" {
|
||||||
|
return parsePunctuationQuery(q, true)
|
||||||
|
}
|
||||||
|
return apiDefault || p.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePunctuationQuery(raw string, def bool) bool {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
switch strings.ToLower(raw) {
|
||||||
|
case "1", "true", "yes", "on":
|
||||||
|
return true
|
||||||
|
case "0", "false", "no", "off":
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
b, err := strconv.ParseBool(raw)
|
||||||
|
if err != nil {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
58
config/punctuation_test.go
Normal file
58
config/punctuation_test.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPunctuation_Active(t *testing.T) {
|
||||||
|
if (Punctuation{Enabled: false, Engine: "heuristic"}).Active() {
|
||||||
|
t.Fatal("disabled should be inactive")
|
||||||
|
}
|
||||||
|
if (Punctuation{Enabled: true, Engine: "off"}).Active() {
|
||||||
|
t.Fatal("engine off should be inactive")
|
||||||
|
}
|
||||||
|
if !(Punctuation{Enabled: true, Engine: "heuristic"}).Active() {
|
||||||
|
t.Fatal("enabled heuristic should be active")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPunctuation_ShouldApplyAPI(t *testing.T) {
|
||||||
|
p := Punctuation{Enabled: true, Engine: "heuristic"}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/?punctuation=1", nil)
|
||||||
|
if !p.ShouldApplyAPI(req, false) {
|
||||||
|
t.Fatal("query 1 should enable")
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest("GET", "/?punctuation=0", nil)
|
||||||
|
if p.ShouldApplyAPI(req, true) {
|
||||||
|
t.Fatal("query 0 should disable")
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest("GET", "/", nil)
|
||||||
|
if !p.ShouldApplyAPI(req, false) {
|
||||||
|
t.Fatal("enabled in config => apply when query omitted")
|
||||||
|
}
|
||||||
|
if !p.ShouldApplyAPI(req, true) {
|
||||||
|
t.Fatal("api default true")
|
||||||
|
}
|
||||||
|
|
||||||
|
off := Punctuation{Enabled: false, Engine: "heuristic"}
|
||||||
|
if off.ShouldApplyAPI(req, true) {
|
||||||
|
t.Fatal("master disabled must never apply")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePunctuationQuery(t *testing.T) {
|
||||||
|
if !parsePunctuationQuery("1", false) {
|
||||||
|
t.Fatal("1 => true")
|
||||||
|
}
|
||||||
|
if parsePunctuationQuery("0", true) {
|
||||||
|
t.Fatal("0 => false")
|
||||||
|
}
|
||||||
|
if !parsePunctuationQuery("", true) {
|
||||||
|
t.Fatal("empty => default")
|
||||||
|
}
|
||||||
|
}
|
||||||
9
config/tmp.go
Normal file
9
config/tmp.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "os"
|
||||||
|
|
||||||
|
const TempRoot = "/tmp"
|
||||||
|
|
||||||
|
func MkdirTemp(prefix string) (string, error) {
|
||||||
|
return os.MkdirTemp(TempRoot, prefix)
|
||||||
|
}
|
||||||
11
config/transcode.go
Normal file
11
config/transcode.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
// Transcode holds audio normalization settings (pure Go decoders; no external ffmpeg).
|
||||||
|
type Transcode struct {
|
||||||
|
// FFmpegPath is deprecated and ignored; kept for backward-compatible YAML.
|
||||||
|
FFmpegPath string `yaml:"ffmpeg_path,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Transcode) WithDefaults() Transcode {
|
||||||
|
return t
|
||||||
|
}
|
||||||
26
config/transcript.go
Normal file
26
config/transcript.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Transcript controls how STT segments are joined into one text field (with embedded newlines).
|
||||||
|
type Transcript struct {
|
||||||
|
PauseGapSec float64 `yaml:"pause_gap_sec"`
|
||||||
|
SpeakerLabel string `yaml:"speaker_label"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Transcript) WithDefaults() Transcript {
|
||||||
|
if t.PauseGapSec <= 0 {
|
||||||
|
t.PauseGapSec = 1.5
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(t.SpeakerLabel) == "" {
|
||||||
|
t.SpeakerLabel = "Спикер"
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Transcript) PauseGapDuration() time.Duration {
|
||||||
|
return time.Duration(t.PauseGapSec * float64(time.Second))
|
||||||
|
}
|
||||||
88
config/vad.go
Normal file
88
config/vad.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
type VAD struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
Model string `yaml:"model"`
|
||||||
|
Threshold float64 `yaml:"threshold"`
|
||||||
|
MinSpeechMs int `yaml:"min_speech_duration_ms"`
|
||||||
|
MinSilenceMs int `yaml:"min_silence_duration_ms"`
|
||||||
|
MaxSpeechSec float64 `yaml:"max_speech_duration_s"`
|
||||||
|
SpeechPadMs int `yaml:"speech_pad_ms"`
|
||||||
|
SamplesOverlap float64 `yaml:"samples_overlap"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultVAD() VAD {
|
||||||
|
return VAD{
|
||||||
|
Threshold: 0.5,
|
||||||
|
MinSpeechMs: 250,
|
||||||
|
MinSilenceMs: 100,
|
||||||
|
MaxSpeechSec: 0,
|
||||||
|
SpeechPadMs: 30,
|
||||||
|
SamplesOverlap: 0.1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v VAD) WithDefaults() VAD {
|
||||||
|
d := DefaultVAD()
|
||||||
|
if v.Threshold <= 0 {
|
||||||
|
v.Threshold = d.Threshold
|
||||||
|
}
|
||||||
|
if v.MinSpeechMs <= 0 {
|
||||||
|
v.MinSpeechMs = d.MinSpeechMs
|
||||||
|
}
|
||||||
|
if v.MinSilenceMs <= 0 {
|
||||||
|
v.MinSilenceMs = d.MinSilenceMs
|
||||||
|
}
|
||||||
|
if v.SpeechPadMs <= 0 {
|
||||||
|
v.SpeechPadMs = d.SpeechPadMs
|
||||||
|
}
|
||||||
|
if v.SamplesOverlap <= 0 {
|
||||||
|
v.SamplesOverlap = d.SamplesOverlap
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v VAD) ResolveModelPath(modelsDir string) string {
|
||||||
|
if v.Model == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if filepath.IsAbs(v.Model) {
|
||||||
|
return v.Model
|
||||||
|
}
|
||||||
|
if modelsDir == "" {
|
||||||
|
return v.Model
|
||||||
|
}
|
||||||
|
direct := filepath.Join(modelsDir, v.Model)
|
||||||
|
if _, err := os.Stat(direct); err == nil {
|
||||||
|
return direct
|
||||||
|
}
|
||||||
|
// VAD weights often live under models_dir/vad/ (not listed in /spr/models).
|
||||||
|
base := filepath.Base(v.Model)
|
||||||
|
inVAD := filepath.Join(modelsDir, "vad", base)
|
||||||
|
if _, err := os.Stat(inVAD); err == nil {
|
||||||
|
return inVAD
|
||||||
|
}
|
||||||
|
return direct
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v VAD) Validate() error {
|
||||||
|
if !v.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if v.Model == "" {
|
||||||
|
return fmt.Errorf("vad.model is required when vad.enabled is true")
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(v.Model); err != nil {
|
||||||
|
return fmt.Errorf("vad model %q: %w", v.Model, err)
|
||||||
|
}
|
||||||
|
if v.Threshold < 0 || v.Threshold > 1 {
|
||||||
|
return fmt.Errorf("vad.threshold must be between 0 and 1")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
65
config/vad_test.go
Normal file
65
config/vad_test.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVAD_Validate_disabled(t *testing.T) {
|
||||||
|
if err := (VAD{}).Validate(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVAD_Validate_requiresModel(t *testing.T) {
|
||||||
|
err := (VAD{Enabled: true}).Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVAD_Validate_modelExists(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
model := filepath.Join(dir, "ggml-silero.bin")
|
||||||
|
if err := os.WriteFile(model, []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
v := VAD{Enabled: true, Model: model, Threshold: 0.5}
|
||||||
|
if err := v.Validate(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVAD_ResolveModelPath(t *testing.T) {
|
||||||
|
v := VAD{Model: "ggml-silero-v6.2.0.bin"}
|
||||||
|
got := v.ResolveModelPath("/data/models")
|
||||||
|
want := filepath.Join("/data/models", "ggml-silero-v6.2.0.bin")
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVAD_ResolveModelPath_vadSubdir(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
vadDir := filepath.Join(dir, "vad")
|
||||||
|
if err := os.MkdirAll(vadDir, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(vadDir, "vad.bin"), []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
v := VAD{Model: "vad.bin"}
|
||||||
|
got := v.ResolveModelPath(dir)
|
||||||
|
want := filepath.Join(dir, "vad", "vad.bin")
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVAD_WithDefaults(t *testing.T) {
|
||||||
|
v := VAD{Enabled: true}.WithDefaults()
|
||||||
|
if v.Threshold != 0.5 || v.MinSpeechMs != 250 || v.MinSilenceMs != 100 {
|
||||||
|
t.Fatalf("unexpected defaults: %+v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
43
config/whisper.go
Normal file
43
config/whisper.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type Whisper struct {
|
||||||
|
OutputFormat []string `yaml:"output_format"`
|
||||||
|
Model string `yaml:"model"`
|
||||||
|
AudioPath string `yaml:"audio_path"`
|
||||||
|
Language string `yaml:"language"`
|
||||||
|
Prompt string `yaml:"prompt"`
|
||||||
|
OutputFolder string `yaml:"output_folder"`
|
||||||
|
OutputFilename string `yaml:"output_filename"`
|
||||||
|
Threads uint `yaml:"threads"`
|
||||||
|
MaxContext uint `yaml:"max_context"`
|
||||||
|
BeamSize uint `yaml:"beam_size"`
|
||||||
|
EntropyThold float64 `yaml:"entropy_thold"`
|
||||||
|
VAD VAD `yaml:"vad"`
|
||||||
|
Debug bool `yaml:"debug"`
|
||||||
|
SpeedUp bool `yaml:"speedup"`
|
||||||
|
Translate bool `yaml:"translate"`
|
||||||
|
PrintProgress bool `yaml:"print_progress"`
|
||||||
|
PrintSegment bool `yaml:"print_segment"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Whisper) ValidateModel() error {
|
||||||
|
if c.Model == "" {
|
||||||
|
return fmt.Errorf("model is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Whisper) Validate() error {
|
||||||
|
if err := c.ValidateModel(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if c.AudioPath == "" {
|
||||||
|
return fmt.Errorf("audio path is required")
|
||||||
|
}
|
||||||
|
if err := c.VAD.WithDefaults().Validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
81
config/xlm-roberta-model.yaml
Normal file
81
config/xlm-roberta-model.yaml
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
# Metadata for Salama1429/xlm-roberta_punctuation_fullstop_truecase (ONNX punctuation).
|
||||||
|
# Install into the model directory:
|
||||||
|
# cp config/xlm-roberta-model.yaml models/punctuation/xlm-roberta/config.yaml
|
||||||
|
# or: make install-xlm-punctuation-config
|
||||||
|
|
||||||
|
languages:
|
||||||
|
- af
|
||||||
|
- am
|
||||||
|
- ar
|
||||||
|
- bg
|
||||||
|
- bn
|
||||||
|
- de
|
||||||
|
- el
|
||||||
|
- en
|
||||||
|
- es
|
||||||
|
- et
|
||||||
|
- fa
|
||||||
|
- fi
|
||||||
|
- fr
|
||||||
|
- gu
|
||||||
|
- hi
|
||||||
|
- hr
|
||||||
|
- hu
|
||||||
|
- id
|
||||||
|
- is
|
||||||
|
- it
|
||||||
|
- ja
|
||||||
|
- kk
|
||||||
|
- kn
|
||||||
|
- ko
|
||||||
|
- ky
|
||||||
|
- lt
|
||||||
|
- lv
|
||||||
|
- mk
|
||||||
|
- ml
|
||||||
|
- mr
|
||||||
|
- nl
|
||||||
|
- or
|
||||||
|
- pa
|
||||||
|
- pl
|
||||||
|
- ps
|
||||||
|
- pt
|
||||||
|
- ro
|
||||||
|
- ru
|
||||||
|
- rw
|
||||||
|
- so
|
||||||
|
- sr
|
||||||
|
- sw
|
||||||
|
- ta
|
||||||
|
- te
|
||||||
|
- tr
|
||||||
|
- uk
|
||||||
|
- zh
|
||||||
|
|
||||||
|
max_length: 256
|
||||||
|
|
||||||
|
pre_labels:
|
||||||
|
- "<NULL>"
|
||||||
|
- "¿"
|
||||||
|
|
||||||
|
post_labels:
|
||||||
|
- "<NULL>"
|
||||||
|
- "<ACRONYM>"
|
||||||
|
- "."
|
||||||
|
- ","
|
||||||
|
- "?"
|
||||||
|
- "?"
|
||||||
|
- ","
|
||||||
|
- "。"
|
||||||
|
- "、"
|
||||||
|
- "・"
|
||||||
|
- "।"
|
||||||
|
- "؟"
|
||||||
|
- "،"
|
||||||
|
- ";"
|
||||||
|
- "።"
|
||||||
|
- "፣"
|
||||||
|
- "፧"
|
||||||
|
|
||||||
|
null_token: "<NULL>"
|
||||||
|
acronym_token: "<ACRONYM>"
|
||||||
19
diarization/diarization.go
Normal file
19
diarization/diarization.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package diarization
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
"go-whisper-api/whisper"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Engine runs offline speaker diarization when built with -tags sherpa.
|
||||||
|
type Engine interface {
|
||||||
|
Active() bool
|
||||||
|
Process(ctx context.Context, samples []float32, numClusters int) ([]whisper.Turn, error)
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(cfg config.Diarization) (Engine, error) {
|
||||||
|
return newEngine(cfg)
|
||||||
|
}
|
||||||
107
diarization/sherpa.go
Normal file
107
diarization/sherpa.go
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
//go:build sherpa
|
||||||
|
|
||||||
|
package diarization
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
"go-whisper-api/whisper"
|
||||||
|
|
||||||
|
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sherpaEngine struct {
|
||||||
|
cfg config.Diarization
|
||||||
|
sd *sherpa.OfflineSpeakerDiarization
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEngine(cfg config.Diarization) (Engine, error) {
|
||||||
|
cfg = cfg.WithDefaults()
|
||||||
|
if !cfg.Active() {
|
||||||
|
return &noopEngine{}, nil
|
||||||
|
}
|
||||||
|
if !cfg.ModelsPresent() {
|
||||||
|
return nil, fmt.Errorf("diarization models missing (run: make download-diarization-models)")
|
||||||
|
}
|
||||||
|
conf := &sherpa.OfflineSpeakerDiarizationConfig{
|
||||||
|
Segmentation: sherpa.OfflineSpeakerSegmentationModelConfig{
|
||||||
|
Pyannote: sherpa.OfflineSpeakerSegmentationPyannoteModelConfig{
|
||||||
|
Model: cfg.SegmentationPath(),
|
||||||
|
},
|
||||||
|
NumThreads: cfg.NumThreads,
|
||||||
|
Debug: 0,
|
||||||
|
Provider: "cpu",
|
||||||
|
},
|
||||||
|
Embedding: sherpa.SpeakerEmbeddingExtractorConfig{
|
||||||
|
Model: cfg.EmbeddingPath(),
|
||||||
|
NumThreads: cfg.NumThreads,
|
||||||
|
Debug: 0,
|
||||||
|
Provider: "cpu",
|
||||||
|
},
|
||||||
|
Clustering: sherpa.FastClusteringConfig{
|
||||||
|
NumClusters: cfg.NumClusters,
|
||||||
|
Threshold: cfg.ClusteringThreshold,
|
||||||
|
},
|
||||||
|
MinDurationOn: cfg.MinDurationOn,
|
||||||
|
MinDurationOff: cfg.MinDurationOff,
|
||||||
|
}
|
||||||
|
sd := sherpa.NewOfflineSpeakerDiarization(conf)
|
||||||
|
if sd == nil {
|
||||||
|
return nil, fmt.Errorf("failed to create sherpa speaker diarization")
|
||||||
|
}
|
||||||
|
return &sherpaEngine{cfg: cfg, sd: sd}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type noopEngine struct{}
|
||||||
|
|
||||||
|
func (noopEngine) Active() bool { return false }
|
||||||
|
func (noopEngine) Process(context.Context, []float32, int) ([]whisper.Turn, error) { return nil, nil }
|
||||||
|
func (noopEngine) Close() {}
|
||||||
|
|
||||||
|
func (e *sherpaEngine) Active() bool {
|
||||||
|
return e.sd != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *sherpaEngine) Process(ctx context.Context, samples []float32, numClusters int) ([]whisper.Turn, error) {
|
||||||
|
_ = ctx
|
||||||
|
if len(samples) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty audio for diarization")
|
||||||
|
}
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
if _, err := os.Stat(e.cfg.SegmentationPath()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
clusters := numClusters
|
||||||
|
if clusters <= 0 {
|
||||||
|
clusters = e.cfg.NumClusters
|
||||||
|
}
|
||||||
|
e.sd.SetConfig(&sherpa.OfflineSpeakerDiarizationConfig{
|
||||||
|
Clustering: sherpa.FastClusteringConfig{
|
||||||
|
NumClusters: clusters,
|
||||||
|
Threshold: e.cfg.ClusteringThreshold,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
segments := e.sd.Process(samples)
|
||||||
|
out := make([]whisper.Turn, len(segments))
|
||||||
|
for i, s := range segments {
|
||||||
|
out[i] = whisper.Turn{
|
||||||
|
Start: s.Start,
|
||||||
|
End: s.End,
|
||||||
|
Speaker: s.Speaker,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *sherpaEngine) Close() {
|
||||||
|
if e.sd != nil {
|
||||||
|
sherpa.DeleteOfflineSpeakerDiarization(e.sd)
|
||||||
|
e.sd = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
29
diarization/stub.go
Normal file
29
diarization/stub.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
//go:build !sherpa
|
||||||
|
|
||||||
|
package diarization
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
"go-whisper-api/whisper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stubEngine struct {
|
||||||
|
cfg config.Diarization
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEngine(cfg config.Diarization) (Engine, error) {
|
||||||
|
return &stubEngine{cfg: cfg.WithDefaults()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubEngine) Active() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubEngine) Process(context.Context, []float32, int) ([]whisper.Turn, error) {
|
||||||
|
return nil, fmt.Errorf("speaker diarization requires build with -tags sherpa and models (make build-sherpa, make download-diarization-models)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubEngine) Close() {}
|
||||||
55
docker/Dockerfile
Normal file
55
docker/Dockerfile
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
ARG UBUNTU_VERSION=22.04
|
||||||
|
# This needs to generally match the container host's environment.
|
||||||
|
ARG CUDA_VERSION=12.0.0
|
||||||
|
# Target the CUDA build image
|
||||||
|
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||||
|
# Target the CUDA runtime image
|
||||||
|
ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
|
||||||
|
|
||||||
|
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} AS build
|
||||||
|
WORKDIR /app
|
||||||
|
# Unless otherwise specified, we make a fat build.
|
||||||
|
ARG CUDA_DOCKER_ARCH=all
|
||||||
|
# Set nvcc architecture
|
||||||
|
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||||
|
# Enable cuBLAS
|
||||||
|
ENV WHISPER_CUBLAS=1
|
||||||
|
|
||||||
|
#apt-get
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends build-essential git gcc g++ wget \
|
||||||
|
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||||
|
|
||||||
|
# install golang
|
||||||
|
RUN wget --progress=dot:giga https://go.dev/dl/go1.22.10.linux-amd64.tar.gz
|
||||||
|
RUN rm -rf /usr/local/go && tar -C /usr/local -xzf go1.22.10.linux-amd64.tar.gz
|
||||||
|
ENV PATH ${PATH}:/usr/local/go/bin
|
||||||
|
|
||||||
|
# Ref: https://stackoverflow.com/a/53464012
|
||||||
|
ENV CUDA_MAIN_VERSION=12.0
|
||||||
|
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
|
COPY ./ .
|
||||||
|
RUN make dependency && env && make build && \
|
||||||
|
mv bin/go-whisper-api /bin/ && \
|
||||||
|
rm -rf bin
|
||||||
|
|
||||||
|
FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} AS runtime
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
LABEL maintainer="Bo-Yi Wu <appleboy.tw@gmail.com>" \
|
||||||
|
org.label-schema.name="Speech-to-Text" \
|
||||||
|
org.label-schema.vendor="Bo-Yi Wu" \
|
||||||
|
org.label-schema.schema-version="1.0"
|
||||||
|
|
||||||
|
LABEL org.opencontainers.image.source=https://github.com/appleboy/go-whisper-api
|
||||||
|
LABEL org.opencontainers.image.description="Speech-to-Text."
|
||||||
|
LABEL org.opencontainers.image.licenses=MIT
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends curl \
|
||||||
|
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||||
|
|
||||||
|
COPY --from=build /bin/go-whisper-api /bin/go-whisper-api
|
||||||
|
EXPOSE 8080
|
||||||
|
ENTRYPOINT ["/bin/go-whisper-api"]
|
||||||
27
docker/Dockerfile.ci
Normal file
27
docker/Dockerfile.ci
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# CPU image for Gitea CI and hosts without NVIDIA GPU.
|
||||||
|
# Production GPU image: docker/Dockerfile
|
||||||
|
|
||||||
|
FROM golang:1.23-bookworm AS build
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
build-essential git cmake pkg-config libsentencepiece-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
RUN make dependency && make build-xlm
|
||||||
|
|
||||||
|
FROM debian:bookworm-slim AS runtime
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
ca-certificates \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY --from=build /app/bin/go-whisper-api /bin/go-whisper-api
|
||||||
|
COPY --from=build /app/third_party/whisper.cpp/build/src/libwhisper.so* /usr/local/lib/
|
||||||
|
COPY --from=build /app/third_party/whisper.cpp/build/ggml/src/libggml*.so* /usr/local/lib/
|
||||||
|
ENV LD_LIBRARY_PATH=/usr/local/lib
|
||||||
|
|
||||||
|
EXPOSE 8080
|
||||||
|
ENTRYPOINT ["/bin/go-whisper-api"]
|
||||||
56
garbage/filter.go
Normal file
56
garbage/filter.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package garbage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var spaceCollapse = regexp.MustCompile(`\s+`)
|
||||||
|
|
||||||
|
// Word is a timed token for garbage filtering (mirrors whisper.Word JSON shape).
|
||||||
|
type Word struct {
|
||||||
|
Word string `json:"word"`
|
||||||
|
Start int `json:"start"`
|
||||||
|
Stop int `json:"stop"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterText removes configured artifact substrings and normalizes whitespace.
|
||||||
|
func FilterText(text string, patterns []string) string {
|
||||||
|
for _, p := range patterns {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
text = strings.ReplaceAll(text, p, " ")
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(spaceCollapse.ReplaceAllString(text, " "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterWords drops tokens that match any garbage pattern.
|
||||||
|
func FilterWords(words []Word, patterns []string) []Word {
|
||||||
|
if len(words) == 0 {
|
||||||
|
return words
|
||||||
|
}
|
||||||
|
out := make([]Word, 0, len(words))
|
||||||
|
for _, w := range words {
|
||||||
|
if matchesGarbage(w.Word, patterns) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, w)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func matchesGarbage(word string, patterns []string) bool {
|
||||||
|
word = strings.TrimSpace(word)
|
||||||
|
for _, p := range patterns {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if word == p || strings.Contains(word, p) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
30
garbage/filter_test.go
Normal file
30
garbage/filter_test.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package garbage
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestFilterText(t *testing.T) {
|
||||||
|
in := "Привет *выбая* мир *выбая*"
|
||||||
|
got := FilterText(in, []string{"*выбая*"})
|
||||||
|
want := "Привет мир"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterWords(t *testing.T) {
|
||||||
|
words := []Word{
|
||||||
|
{Word: "Что", Start: 0, Stop: 100},
|
||||||
|
{Word: "*выбая*", Start: 100, Stop: 200},
|
||||||
|
{Word: "мир", Start: 200, Stop: 300},
|
||||||
|
}
|
||||||
|
got := FilterWords(words, []string{"*выбая*"})
|
||||||
|
if len(got) != 2 || got[1].Word != "мир" {
|
||||||
|
t.Fatalf("got %+v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterText_emptyPatterns(t *testing.T) {
|
||||||
|
if got := FilterText("a b", nil); got != "a b" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
48
go.mod
Normal file
48
go.mod
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
module go-whisper-api
|
||||||
|
|
||||||
|
go 1.26
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/Eyevinn/mp4ff v0.51.0
|
||||||
|
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230606002726-57543c169e27
|
||||||
|
github.com/go-audio/audio v1.0.0
|
||||||
|
github.com/go-audio/wav v1.1.0
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/gopxl/beep v1.4.1
|
||||||
|
github.com/joho/godotenv v1.5.1
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go v1.13.2
|
||||||
|
github.com/mattn/go-isatty v0.0.20
|
||||||
|
github.com/olivier-w/climp-aac-decoder v0.1.0
|
||||||
|
github.com/pion/opus v0.0.0-20260601214817-71d58474cec8
|
||||||
|
github.com/rs/zerolog v1.35.0
|
||||||
|
github.com/skrashevich/go-aac v0.1.0
|
||||||
|
github.com/urfave/cli/v2 v2.27.7
|
||||||
|
github.com/yalue/onnxruntime_go v1.24.0
|
||||||
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
|
github.com/go-audio/riff v1.0.0 // indirect
|
||||||
|
github.com/hajimehoshi/go-mp3 v0.3.4 // indirect
|
||||||
|
github.com/icza/bitio v1.1.0 // indirect
|
||||||
|
github.com/jfreymuth/oggvorbis v1.0.5 // indirect
|
||||||
|
github.com/jfreymuth/vorbis v1.0.2 // indirect
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2 // indirect
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2 // indirect
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2 // indirect
|
||||||
|
github.com/kr/pretty v0.3.1 // indirect
|
||||||
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
|
github.com/mewkiz/flac v1.0.8 // indirect
|
||||||
|
github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14 // indirect
|
||||||
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||||
|
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||||
|
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect
|
||||||
|
golang.org/x/sys v0.45.0 // indirect
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||||
|
)
|
||||||
|
|
||||||
|
replace github.com/ggerganov/whisper.cpp/bindings/go => ./third_party/whisper.cpp/bindings/go
|
||||||
127
go.sum
Normal file
127
go.sum
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
github.com/Eyevinn/mp4ff v0.51.0 h1:ZYdHFXEcB3kJkCeCHMHl/tbCm64FJsD2XOU0Sj+ME2M=
|
||||||
|
github.com/Eyevinn/mp4ff v0.51.0/go.mod h1:hJNUUqOBryLAzUW9wpCJyw2HaI+TCd2rUPhafoS5lgg=
|
||||||
|
github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo=
|
||||||
|
github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||||
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
|
github.com/d4l3k/messagediff v1.2.2-0.20190829033028-7e0a312ae40b/go.mod h1:Oozbb1TVXFac9FtSIxHBMnBCq2qeH/2KkEQxENCrlLo=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/ebitengine/oto/v3 v3.1.0 h1:9tChG6rizyeR2w3vsygTTTVVJ9QMMyu00m2yBOCch6U=
|
||||||
|
github.com/ebitengine/oto/v3 v3.1.0/go.mod h1:IK1QTnlfZK2GIB6ziyECm433hAdTaPpOsGMLhEyEGTg=
|
||||||
|
github.com/ebitengine/purego v0.7.1 h1:6/55d26lG3o9VCZX8lping+bZcmShseiqlh2bnUDiPA=
|
||||||
|
github.com/ebitengine/purego v0.7.1/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ=
|
||||||
|
github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
|
||||||
|
github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
|
||||||
|
github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
|
||||||
|
github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
|
||||||
|
github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
|
||||||
|
github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
|
||||||
|
github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg=
|
||||||
|
github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/gopxl/beep v1.4.1 h1:WqNs9RsDAhG9M3khMyc1FaVY50dTdxG/6S6a3qsUHqE=
|
||||||
|
github.com/gopxl/beep v1.4.1/go.mod h1:A1dmiUkuY8kxsvcNJNUBIEcchmiP6eUyCHSxpXl0YO0=
|
||||||
|
github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68=
|
||||||
|
github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo=
|
||||||
|
github.com/hajimehoshi/oto/v2 v2.3.1/go.mod h1:seWLbgHH7AyUMYKfKYT9pg7PhUu9/SisyJvNTT+ASQo=
|
||||||
|
github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0=
|
||||||
|
github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A=
|
||||||
|
github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k=
|
||||||
|
github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA=
|
||||||
|
github.com/jfreymuth/oggvorbis v1.0.5 h1:u+Ck+R0eLSRhgq8WTmffYnrVtSztJcYrl588DM4e3kQ=
|
||||||
|
github.com/jfreymuth/oggvorbis v1.0.5/go.mod h1:1U4pqWmghcoVsCJJ4fRBKv9peUJMBHixthRlBeD6uII=
|
||||||
|
github.com/jfreymuth/vorbis v1.0.2 h1:m1xH6+ZI4thH927pgKD8JOH4eaGRm18rEE9/0WKjvNE=
|
||||||
|
github.com/jfreymuth/vorbis v1.0.2/go.mod h1:DoftRo4AznKnShRl1GxiTFCseHr4zR9BN3TWXyuzrqQ=
|
||||||
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
|
github.com/jszwec/csvutil v1.5.1/go.mod h1:Rpu7Uu9giO9subDyMCIQfHVDuLrcaC36UA4YcJjGBkg=
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go v1.13.2 h1:1orlwwdJcVk37OoV1gi1L6fIdeZHKJcLQjm6S5h/WWs=
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go v1.13.2/go.mod h1:NyA3OsF/hmcvk4FpbXpBO9Rl/I3dlnrOB9Ny+TnhvIc=
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2 h1:w/RHaU9liD/ovA9Q5iI2lCoVjVEd4M8+uTrf54rjQPY=
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go-linux v1.13.2/go.mod h1:NXEH2rsBgTdqY59YpPq6CtSBlBAXy/8a9FmpLERU97I=
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2 h1:oIIdSfU3NEMT7oq7yxoH7Rk37kekkbmr/b+7e4er1WE=
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go-macos v1.13.2/go.mod h1:ZOhUAXC62Unj0ZNfu6zxSFKcW96aXf7P3BsqiUyOBbE=
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2 h1:tkfwXnmJRsChv59ZzLsycphwpvJpSkyd29aMG9kUoEE=
|
||||||
|
github.com/k2-fsa/sherpa-onnx-go-windows v1.13.2/go.mod h1:5AX7TU8+P/gInjglY1ijtWUM2b8iyR0QX4yEngzMe64=
|
||||||
|
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||||
|
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mewkiz/flac v1.0.8 h1:cophRjvafteDGmqsfXRK28YAX6l8wy19QxTHruEEg1s=
|
||||||
|
github.com/mewkiz/flac v1.0.8/go.mod h1:l7dt5uFY724eKVkHQtAJAQSkhpC3helU3RDxN0ESAqo=
|
||||||
|
github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14 h1:tnAPMExbRERsyEYkmR1YjhTgDM0iqyiBYf8ojRXxdbA=
|
||||||
|
github.com/mewkiz/pkg v0.0.0-20230226050401-4010bf0fec14/go.mod h1:QYCFBiH5q6XTHEbWhR0uhR3M9qNPoD2CSQzr0g75kE4=
|
||||||
|
github.com/olivier-w/climp-aac-decoder v0.1.0 h1:jdwYeBlnfQlnm6KA/fG20s22HrlBlVVaaWQTj/tnb3A=
|
||||||
|
github.com/olivier-w/climp-aac-decoder v0.1.0/go.mod h1:Hpqi8cI4NeN4JEcgKxiAK0zEwx+hO+lWdg9Q3ruHouI=
|
||||||
|
github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e h1:s2RNOM/IGdY0Y6qfTeUKhDawdHDpK9RGBdx80qN4Ttw=
|
||||||
|
github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBdnFKj15wFbf94Rwfq4m30eAcyY9V/IyKAGQFtqkW0=
|
||||||
|
github.com/pion/opus v0.0.0-20260601214817-71d58474cec8 h1:THWkLaX+FzbTbHIo1irDpbBMXbNrqV6q2Xs+00p4d+0=
|
||||||
|
github.com/pion/opus v0.0.0-20260601214817-71d58474cec8/go.mod h1:t5Xog2n682JnawoykACE6nKVmupFvmJvkpM7x6bTv6g=
|
||||||
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||||
|
github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI=
|
||||||
|
github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw=
|
||||||
|
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
|
||||||
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
|
github.com/skrashevich/go-aac v0.1.0 h1:7oHNj1ADmgfjAHvi3wAIFbmbCpQBrcjZEVTLlRtAS1A=
|
||||||
|
github.com/skrashevich/go-aac v0.1.0/go.mod h1:Mj7r//4LDL4FC0ezORj+MnmQ+nDEkJhTOy2aMC8dzww=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU=
|
||||||
|
github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4=
|
||||||
|
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg=
|
||||||
|
github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
|
||||||
|
github.com/yalue/onnxruntime_go v1.24.0 h1:IdgJLxxyotlsUTmL1UnHZgBzXJGgY51LZ4vQ5rZeOXU=
|
||||||
|
github.com/yalue/onnxruntime_go v1.24.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
|
||||||
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
|
golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4=
|
||||||
|
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||||
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
|
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||||
|
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||||
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
||||||
|
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
|
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||||
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
|
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
|
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||||
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
251
main.go
Normal file
251
main.go
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go-whisper-api/api"
|
||||||
|
"go-whisper-api/config"
|
||||||
|
|
||||||
|
_ "github.com/joho/godotenv/autoload"
|
||||||
|
"github.com/mattn/go-isatty"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/urfave/cli/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
Version string
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
isTerm := isatty.IsTerminal(os.Stdout.Fd())
|
||||||
|
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||||
|
log.Logger = log.Output(
|
||||||
|
zerolog.ConsoleWriter{
|
||||||
|
Out: os.Stderr,
|
||||||
|
NoColor: !isTerm,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
zerolog.CallerMarshalFunc = func(pc uintptr, file string, line int) string {
|
||||||
|
short := file
|
||||||
|
for i := len(file) - 1; i > 0; i-- {
|
||||||
|
if file[i] == '/' {
|
||||||
|
short = file[i+1:]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file = short
|
||||||
|
return file + ":" + strconv.Itoa(line)
|
||||||
|
}
|
||||||
|
app := cli.NewApp()
|
||||||
|
app.Name = "go-whisper-api"
|
||||||
|
app.Usage = "HTTP API for speech-to-text (SPR + OpenAI-compatible)."
|
||||||
|
app.Copyright = "Copyright (c) " + strconv.Itoa(time.Now().Year()) + " Bo-Yi Wu"
|
||||||
|
app.Authors = []*cli.Author{
|
||||||
|
{
|
||||||
|
Name: "Bo-Yi Wu",
|
||||||
|
Email: "appleboy.tw@gmail.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
app.Version = Version
|
||||||
|
app.Commands = []*cli.Command{
|
||||||
|
{
|
||||||
|
Name: "serve",
|
||||||
|
Aliases: []string{"s", "api"},
|
||||||
|
Usage: "start HTTP API server",
|
||||||
|
Flags: serveFlags(),
|
||||||
|
Action: runServe,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
app.Flags = append(append([]cli.Flag{configFlag()}, punctuationFlags()...), serveFlags()...)
|
||||||
|
app.Action = runServe
|
||||||
|
if err := app.Run(os.Args); err != nil {
|
||||||
|
log.Fatal().Err(err).Msg("can't run app")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func configFlag() cli.Flag {
|
||||||
|
return &cli.StringFlag{
|
||||||
|
Name: "config",
|
||||||
|
Usage: "path to YAML config file (default: config.yaml if present)",
|
||||||
|
EnvVars: []string{"CONFIG_PATH", "GO_WHISPER_CONFIG"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func punctuationFlags() []cli.Flag {
|
||||||
|
return []cli.Flag{
|
||||||
|
&cli.BoolFlag{
|
||||||
|
Name: "punctuation-enabled",
|
||||||
|
Usage: "master switch: enable or disable punctuation (YAML: punctuation.enabled)",
|
||||||
|
EnvVars: []string{"PUNCTUATION_ENABLED"},
|
||||||
|
},
|
||||||
|
&cli.BoolFlag{
|
||||||
|
Name: "punctuation-default-on",
|
||||||
|
Usage: "apply punctuation by default when query/flags omitted (YAML: punctuation.default_on)",
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "punctuation-engine",
|
||||||
|
Usage: "punctuation engine: off, heuristic, sherpa, sherpa-online, http, xlm",
|
||||||
|
EnvVars: []string{"PUNCTUATION_ENGINE"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func vadFlags() []cli.Flag {
|
||||||
|
return []cli.Flag{
|
||||||
|
&cli.BoolFlag{
|
||||||
|
Name: "vad",
|
||||||
|
Usage: "enable voice activity detection (Silero VAD)",
|
||||||
|
EnvVars: []string{"API_VAD"},
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "vad-model",
|
||||||
|
Usage: "path to ggml Silero VAD model (e.g. models/ggml-silero-v6.2.0.bin)",
|
||||||
|
EnvVars: []string{"API_VAD_MODEL"},
|
||||||
|
},
|
||||||
|
&cli.Float64Flag{
|
||||||
|
Name: "vad-threshold",
|
||||||
|
Usage: "VAD speech threshold (0.0-1.0)",
|
||||||
|
EnvVars: []string{"API_VAD_THRESHOLD"},
|
||||||
|
},
|
||||||
|
&cli.IntFlag{
|
||||||
|
Name: "vad-min-speech-ms",
|
||||||
|
Usage: "minimum speech duration in ms",
|
||||||
|
EnvVars: []string{"API_VAD_MIN_SPEECH_MS"},
|
||||||
|
},
|
||||||
|
&cli.IntFlag{
|
||||||
|
Name: "vad-min-silence-ms",
|
||||||
|
Usage: "minimum silence between segments in ms",
|
||||||
|
EnvVars: []string{"API_VAD_MIN_SILENCE_MS"},
|
||||||
|
},
|
||||||
|
&cli.Float64Flag{
|
||||||
|
Name: "vad-max-speech-sec",
|
||||||
|
Usage: "maximum speech segment length in seconds (0 = unlimited)",
|
||||||
|
EnvVars: []string{"API_VAD_MAX_SPEECH_SEC"},
|
||||||
|
},
|
||||||
|
&cli.IntFlag{
|
||||||
|
Name: "vad-speech-pad-ms",
|
||||||
|
Usage: "padding around speech segments in ms",
|
||||||
|
EnvVars: []string{"API_VAD_SPEECH_PAD_MS"},
|
||||||
|
},
|
||||||
|
&cli.Float64Flag{
|
||||||
|
Name: "vad-samples-overlap",
|
||||||
|
Usage: "overlap between VAD segments in seconds",
|
||||||
|
EnvVars: []string{"API_VAD_SAMPLES_OVERLAP"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func serveFlags() []cli.Flag {
|
||||||
|
flags := []cli.Flag{
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "addr",
|
||||||
|
Usage: "HTTP listen address",
|
||||||
|
EnvVars: []string{"API_ADDR"},
|
||||||
|
Value: ":8080",
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "models-dir",
|
||||||
|
Usage: "directory with ggml *.bin whisper models",
|
||||||
|
EnvVars: []string{"API_MODELS_DIR"},
|
||||||
|
Value: "./models",
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "cache-dir",
|
||||||
|
Usage: "directory for async task cache (waiting/ready)",
|
||||||
|
EnvVars: []string{"API_CACHE_DIR"},
|
||||||
|
Value: "./cache",
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "language",
|
||||||
|
Usage: "default language for speech recognition",
|
||||||
|
EnvVars: []string{"API_LANGUAGE"},
|
||||||
|
Value: "auto",
|
||||||
|
},
|
||||||
|
&cli.UintFlag{
|
||||||
|
Name: "threads",
|
||||||
|
Usage: "number of threads for whisper",
|
||||||
|
EnvVars: []string{"API_THREADS"},
|
||||||
|
Value: uint(runtime.NumCPU()),
|
||||||
|
},
|
||||||
|
&cli.BoolFlag{
|
||||||
|
Name: "debug",
|
||||||
|
Usage: "enable debug mode",
|
||||||
|
EnvVars: []string{"API_DEBUG"},
|
||||||
|
},
|
||||||
|
&cli.BoolFlag{
|
||||||
|
Name: "speedup",
|
||||||
|
Usage: "speed up audio by x2",
|
||||||
|
EnvVars: []string{"API_SPEEDUP"},
|
||||||
|
},
|
||||||
|
&cli.BoolFlag{
|
||||||
|
Name: "translate",
|
||||||
|
Usage: "translate to english",
|
||||||
|
EnvVars: []string{"API_TRANSLATE"},
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "prompt",
|
||||||
|
Usage: "initial prompt",
|
||||||
|
EnvVars: []string{"API_PROMPT"},
|
||||||
|
},
|
||||||
|
&cli.UintFlag{
|
||||||
|
Name: "max-context",
|
||||||
|
Usage: "maximum text context tokens",
|
||||||
|
EnvVars: []string{"API_MAX_CONTEXT"},
|
||||||
|
Value: 32,
|
||||||
|
},
|
||||||
|
&cli.UintFlag{
|
||||||
|
Name: "beam-size",
|
||||||
|
Usage: "beam size for beam search",
|
||||||
|
EnvVars: []string{"API_BEAM_SIZE"},
|
||||||
|
Value: 5,
|
||||||
|
},
|
||||||
|
&cli.Float64Flag{
|
||||||
|
Name: "entropy-thold",
|
||||||
|
Usage: "entropy threshold",
|
||||||
|
EnvVars: []string{"API_ENTROPY_THOLD"},
|
||||||
|
Value: 2.4,
|
||||||
|
},
|
||||||
|
&cli.BoolFlag{
|
||||||
|
Name: "default-punctuation",
|
||||||
|
Usage: "enable punctuation on STT by default",
|
||||||
|
EnvVars: []string{"API_DEFAULT_PUNCTUATION"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return append(flags, vadFlags()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runServe(c *cli.Context) error {
|
||||||
|
apiCfg, err := config.APIFromCLI(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if apiCfg.Debug {
|
||||||
|
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||||
|
log.Logger = log.With().Caller().Logger()
|
||||||
|
}
|
||||||
|
vad := apiCfg.VAD.WithDefaults()
|
||||||
|
if vad.Enabled {
|
||||||
|
vad.Model = vad.ResolveModelPath(apiCfg.ModelsDir)
|
||||||
|
if err := vad.Validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
apiCfg.VAD = vad
|
||||||
|
}
|
||||||
|
tc, err := config.TranscodeFromCLI(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pc, err := config.PunctuationFromCLI(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dc, err := config.DiarizationFromCLI(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return api.Run(c.Context, apiCfg, tc, pc, dc)
|
||||||
|
}
|
||||||
0
models/.gitkeep
Normal file
0
models/.gitkeep
Normal file
0
models/diarization/.gitkeep
Normal file
0
models/diarization/.gitkeep
Normal file
0
models/punctuation/.gitkeep
Normal file
0
models/punctuation/.gitkeep
Normal file
0
models/punctuation/xlm-roberta/.gitkeep
Normal file
0
models/punctuation/xlm-roberta/.gitkeep
Normal file
0
models/vad/.gitkeep
Normal file
0
models/vad/.gitkeep
Normal file
86
punctuation/heuristic.go
Normal file
86
punctuation/heuristic.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package punctuation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Heuristic struct{}
|
||||||
|
|
||||||
|
func (Heuristic) Active() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Heuristic) Restore(ctx context.Context, text, language string) (string, error) {
|
||||||
|
_ = ctx
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
if text == "" {
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
text = normalizeSpaces(text)
|
||||||
|
text = capitalizeFirst(text)
|
||||||
|
lang := strings.ToLower(strings.TrimSpace(language))
|
||||||
|
if lang == "ru" || lang == "rus" || lang == "russian" || lang == "auto" {
|
||||||
|
text = heuristicRU(text)
|
||||||
|
} else {
|
||||||
|
text = heuristicEN(text)
|
||||||
|
}
|
||||||
|
return ensureTerminalPunct(text), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeSpaces(s string) string {
|
||||||
|
return strings.Join(strings.Fields(s), " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func capitalizeFirst(s string) string {
|
||||||
|
r, size := utf8.DecodeRuneInString(s)
|
||||||
|
if r == utf8.RuneError {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return string(unicode.ToUpper(r)) + s[size:]
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
reQuestionRU = regexp.MustCompile(`(?i)(^|.*\s)(как|что|где|когда|почему|зачем|кто|чей|какой|какая|какое|какие|сколько|зачем|откуда|куда|ли)(\s+[^.?!]+)$`)
|
||||||
|
reQuestionEN = regexp.MustCompile(`(?i)^(who|what|when|where|why|how|which|whose|whom|is|are|am|was|were|do|does|did|can|could|would|will|shall|should)\b`)
|
||||||
|
)
|
||||||
|
|
||||||
|
func heuristicRU(s string) string {
|
||||||
|
if reQuestionRU.MatchString(s) && !strings.HasSuffix(s, "?") {
|
||||||
|
return s + "?"
|
||||||
|
}
|
||||||
|
if !hasTerminalPunct(s) && len(strings.Fields(s)) <= 24 {
|
||||||
|
return s + "."
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func heuristicEN(s string) string {
|
||||||
|
lower := strings.ToLower(s)
|
||||||
|
if reQuestionEN.MatchString(lower) && !strings.HasSuffix(s, "?") {
|
||||||
|
return s + "?"
|
||||||
|
}
|
||||||
|
if !hasTerminalPunct(s) && len(strings.Fields(s)) <= 24 {
|
||||||
|
return s + "."
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasTerminalPunct(s string) bool {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
r, _ := utf8.DecodeLastRuneInString(s)
|
||||||
|
return r == '.' || r == '?' || r == '!' || r == '…'
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureTerminalPunct(s string) string {
|
||||||
|
if hasTerminalPunct(s) {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s + "."
|
||||||
|
}
|
||||||
32
punctuation/heuristic_test.go
Normal file
32
punctuation/heuristic_test.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
package punctuation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHeuristicRU_question(t *testing.T) {
|
||||||
|
h := Heuristic{}
|
||||||
|
out, err := h.Restore(context.Background(), "как дела", "ru")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !stringsHasSuffix(out, "?") {
|
||||||
|
t.Fatalf("expected question mark, got %q", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeuristicEN_period(t *testing.T) {
|
||||||
|
h := Heuristic{}
|
||||||
|
out, err := h.Restore(context.Background(), "hello world", "en")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !stringsHasSuffix(out, ".") {
|
||||||
|
t.Fatalf("expected period, got %q", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringsHasSuffix(s, suffix string) bool {
|
||||||
|
return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix
|
||||||
|
}
|
||||||
92
punctuation/internal/spwrap/sp.go
Normal file
92
punctuation/internal/spwrap/sp.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
//go:build xlm
|
||||||
|
|
||||||
|
package spwrap
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CXXFLAGS: -std=c++17
|
||||||
|
#cgo LDFLAGS: -lsentencepiece
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include "sp_wrap.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Processor struct {
|
||||||
|
p *C.SPProcessor
|
||||||
|
}
|
||||||
|
|
||||||
|
func Load(path string) (*Processor, error) {
|
||||||
|
cpath := C.CString(path)
|
||||||
|
defer C.free(unsafe.Pointer(cpath))
|
||||||
|
var errMsg *C.char
|
||||||
|
p := C.sp_load(cpath, &errMsg)
|
||||||
|
if p == nil {
|
||||||
|
if errMsg != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errMsg))
|
||||||
|
return nil, fmt.Errorf("sentencepiece: %s", C.GoString(errMsg))
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("sentencepiece: failed to load %s", path)
|
||||||
|
}
|
||||||
|
return &Processor{p: p}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (proc *Processor) Close() {
|
||||||
|
if proc.p != nil {
|
||||||
|
C.sp_free(proc.p)
|
||||||
|
proc.p = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (proc *Processor) BOSID() int {
|
||||||
|
return int(C.sp_bos_id(proc.p))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (proc *Processor) EOSID() int {
|
||||||
|
return int(C.sp_eos_id(proc.p))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (proc *Processor) PadID() int {
|
||||||
|
return int(C.sp_pad_id(proc.p))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (proc *Processor) EncodeAsIDs(text string) ([]int, error) {
|
||||||
|
ctext := C.CString(text)
|
||||||
|
defer C.free(unsafe.Pointer(ctext))
|
||||||
|
var ids *C.int
|
||||||
|
var n C.int
|
||||||
|
var errMsg *C.char
|
||||||
|
if C.sp_encode(proc.p, ctext, &ids, &n, &errMsg) == 0 {
|
||||||
|
if errMsg != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errMsg))
|
||||||
|
return nil, fmt.Errorf("sentencepiece encode: %s", C.GoString(errMsg))
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("sentencepiece encode failed")
|
||||||
|
}
|
||||||
|
if ids == nil || n == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
defer C.free(unsafe.Pointer(ids))
|
||||||
|
out := make([]int, int(n))
|
||||||
|
slice := unsafe.Slice(ids, int(n))
|
||||||
|
for i := range out {
|
||||||
|
out[i] = int(slice[i])
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (proc *Processor) IDToPiece(id int) (string, error) {
|
||||||
|
var errMsg *C.char
|
||||||
|
piece := C.sp_id_to_piece(proc.p, C.int(id), &errMsg)
|
||||||
|
if piece == nil {
|
||||||
|
if errMsg != nil {
|
||||||
|
defer C.free(unsafe.Pointer(errMsg))
|
||||||
|
return "", fmt.Errorf("sentencepiece id to piece: %s", C.GoString(errMsg))
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("sentencepiece id to piece failed")
|
||||||
|
}
|
||||||
|
defer C.free(unsafe.Pointer(piece))
|
||||||
|
return C.GoString(piece), nil
|
||||||
|
}
|
||||||
88
punctuation/internal/spwrap/sp_wrap.cc
Normal file
88
punctuation/internal/spwrap/sp_wrap.cc
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
#include "sp_wrap.h"
|
||||||
|
|
||||||
|
#include <sentencepiece_processor.h>
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
struct SPProcessor {
|
||||||
|
sentencepiece::SentencePieceProcessor proc;
|
||||||
|
};
|
||||||
|
|
||||||
|
static char *copy_err(const std::string &msg) {
|
||||||
|
char *out = static_cast<char *>(std::malloc(msg.size() + 1));
|
||||||
|
if (out != nullptr) {
|
||||||
|
std::memcpy(out, msg.c_str(), msg.size() + 1);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
SPProcessor *sp_load(const char *path, char **err) {
|
||||||
|
if (err != nullptr) {
|
||||||
|
*err = nullptr;
|
||||||
|
}
|
||||||
|
auto *p = new SPProcessor();
|
||||||
|
const auto status = p->proc.Load(path);
|
||||||
|
if (!status.ok()) {
|
||||||
|
if (err != nullptr) {
|
||||||
|
*err = copy_err(status.ToString());
|
||||||
|
}
|
||||||
|
delete p;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
void sp_free(SPProcessor *p) { delete p; }
|
||||||
|
|
||||||
|
int sp_bos_id(const SPProcessor *p) { return p->proc.bos_id(); }
|
||||||
|
|
||||||
|
int sp_eos_id(const SPProcessor *p) { return p->proc.eos_id(); }
|
||||||
|
|
||||||
|
int sp_pad_id(const SPProcessor *p) { return p->proc.pad_id(); }
|
||||||
|
|
||||||
|
int sp_encode(const SPProcessor *p, const char *text, int **out_ids, int *out_len, char **err) {
|
||||||
|
if (err != nullptr) {
|
||||||
|
*err = nullptr;
|
||||||
|
}
|
||||||
|
if (out_ids != nullptr) {
|
||||||
|
*out_ids = nullptr;
|
||||||
|
}
|
||||||
|
if (out_len != nullptr) {
|
||||||
|
*out_len = 0;
|
||||||
|
}
|
||||||
|
std::vector<int> ids;
|
||||||
|
const auto status = p->proc.Encode(text, &ids);
|
||||||
|
if (!status.ok()) {
|
||||||
|
if (err != nullptr) {
|
||||||
|
*err = copy_err(status.ToString());
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
if (ids.empty()) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
int *buf = static_cast<int *>(std::malloc(sizeof(int) * ids.size()));
|
||||||
|
if (buf == nullptr) {
|
||||||
|
if (err != nullptr) {
|
||||||
|
*err = copy_err("malloc failed");
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < ids.size(); i++) {
|
||||||
|
buf[i] = ids[i];
|
||||||
|
}
|
||||||
|
*out_ids = buf;
|
||||||
|
*out_len = static_cast<int>(ids.size());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
char *sp_id_to_piece(const SPProcessor *p, int id, char **err) {
|
||||||
|
if (err != nullptr) {
|
||||||
|
*err = nullptr;
|
||||||
|
}
|
||||||
|
const std::string piece = p->proc.IdToPiece(id);
|
||||||
|
return copy_err(piece);
|
||||||
|
}
|
||||||
21
punctuation/internal/spwrap/sp_wrap.h
Normal file
21
punctuation/internal/spwrap/sp_wrap.h
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef struct SPProcessor SPProcessor;
|
||||||
|
|
||||||
|
SPProcessor *sp_load(const char *path, char **err);
|
||||||
|
void sp_free(SPProcessor *p);
|
||||||
|
|
||||||
|
int sp_bos_id(const SPProcessor *p);
|
||||||
|
int sp_eos_id(const SPProcessor *p);
|
||||||
|
int sp_pad_id(const SPProcessor *p);
|
||||||
|
|
||||||
|
int sp_encode(const SPProcessor *p, const char *text, int **out_ids, int *out_len, char **err);
|
||||||
|
char *sp_id_to_piece(const SPProcessor *p, int id, char **err);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
177
punctuation/ort_env.go
Normal file
177
punctuation/ort_env.go
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
//go:build xlm
|
||||||
|
|
||||||
|
package punctuation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
ort "github.com/yalue/onnxruntime_go"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ortOnce sync.Once
|
||||||
|
ortErr error
|
||||||
|
)
|
||||||
|
|
||||||
|
func ensureORT() error {
|
||||||
|
ortOnce.Do(func() {
|
||||||
|
if p := resolveONNXRuntimeLib(); p != "" {
|
||||||
|
ort.SetSharedLibraryPath(p)
|
||||||
|
}
|
||||||
|
ortErr = ort.InitializeEnvironment()
|
||||||
|
})
|
||||||
|
return ortErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveONNXRuntimeLib() string {
|
||||||
|
if p := strings.TrimSpace(os.Getenv("ONNXRUNTIME_SHARED_LIBRARY_PATH")); p != "" {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
for _, p := range onnxRuntimeCandidates() {
|
||||||
|
if st, err := os.Stat(p); err == nil && !st.IsDir() {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func onnxRuntimeCandidates() []string {
|
||||||
|
arch := sherpaLibArch()
|
||||||
|
ver := sherpaLinuxModuleVersion()
|
||||||
|
var out []string
|
||||||
|
for _, root := range goModCacheRoots() {
|
||||||
|
if ver != "" {
|
||||||
|
out = append(out, filepath.Join(root,
|
||||||
|
"github.com/k2-fsa/sherpa-onnx-go-linux@"+ver,
|
||||||
|
"lib", arch, "libonnxruntime.so"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if exe, err := os.Executable(); err == nil {
|
||||||
|
exeDir := filepath.Dir(exe)
|
||||||
|
out = append(out,
|
||||||
|
filepath.Join(exeDir, "libonnxruntime.so"),
|
||||||
|
filepath.Join(exeDir, "lib", "libonnxruntime.so"),
|
||||||
|
filepath.Join(exeDir, "..", "lib", "libonnxruntime.so"),
|
||||||
|
)
|
||||||
|
if modRoot := findModuleRoot(exeDir); modRoot != "" {
|
||||||
|
out = append(out, filepath.Join(modRoot, "lib", "libonnxruntime.so"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func goModCacheRoots() []string {
|
||||||
|
var roots []string
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
add := func(p string) {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := seen[p]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[p] = struct{}{}
|
||||||
|
roots = append(roots, p)
|
||||||
|
}
|
||||||
|
add(os.Getenv("GOMODCACHE"))
|
||||||
|
if gopath := os.Getenv("GOPATH"); gopath != "" {
|
||||||
|
for _, gp := range filepath.SplitList(gopath) {
|
||||||
|
add(filepath.Join(gp, "pkg", "mod"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if home, err := os.UserHomeDir(); err == nil {
|
||||||
|
add(filepath.Join(home, "go", "pkg", "mod"))
|
||||||
|
}
|
||||||
|
return roots
|
||||||
|
}
|
||||||
|
|
||||||
|
func findModuleRoot(start string) string {
|
||||||
|
dir := start
|
||||||
|
for i := 0; i < 8; i++ {
|
||||||
|
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
|
||||||
|
return dir
|
||||||
|
}
|
||||||
|
parent := filepath.Dir(dir)
|
||||||
|
if parent == dir {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
dir = parent
|
||||||
|
}
|
||||||
|
if cwd, err := os.Getwd(); err == nil {
|
||||||
|
return findModuleRootFrom(cwd)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func findModuleRootFrom(dir string) string {
|
||||||
|
for i := 0; i < 8; i++ {
|
||||||
|
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
|
||||||
|
return dir
|
||||||
|
}
|
||||||
|
parent := filepath.Dir(dir)
|
||||||
|
if parent == dir {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
dir = parent
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func sherpaLinuxModuleVersion() string {
|
||||||
|
for _, dir := range []string{findModuleRootFrom(mustCwd()), ""} {
|
||||||
|
if dir == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if v := readSherpaVersion(filepath.Join(dir, "go.mod")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if exe, err := os.Executable(); err == nil {
|
||||||
|
if root := findModuleRoot(filepath.Dir(exe)); root != "" {
|
||||||
|
return readSherpaVersion(filepath.Join(root, "go.mod"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func readSherpaVersion(goModPath string) string {
|
||||||
|
data, err := os.ReadFile(goModPath)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, line := range strings.Split(string(data), "\n") {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if !strings.Contains(line, "github.com/k2-fsa/sherpa-onnx-go-linux") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) >= 2 {
|
||||||
|
return fields[1] // e.g. v1.13.2 — must match pkg/mod path @v1.13.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func sherpaLibArch() string {
|
||||||
|
switch runtime.GOARCH {
|
||||||
|
case "arm64":
|
||||||
|
return "aarch64-unknown-linux-gnu"
|
||||||
|
case "arm":
|
||||||
|
return "arm-unknown-linux-gnueabihf"
|
||||||
|
default:
|
||||||
|
return "x86_64-unknown-linux-gnu"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCwd() string {
|
||||||
|
cwd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return cwd
|
||||||
|
}
|
||||||
154
punctuation/punctuation.go
Normal file
154
punctuation/punctuation.go
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
package punctuation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Restorer interface {
|
||||||
|
Active() bool
|
||||||
|
Restore(ctx context.Context, text, language string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Closer interface {
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(cfg config.Punctuation) (Restorer, error) {
|
||||||
|
cfg = cfg.WithDefaults()
|
||||||
|
if !cfg.Active() {
|
||||||
|
return Nop{}, nil
|
||||||
|
}
|
||||||
|
engine := strings.ToLower(strings.TrimSpace(cfg.Engine))
|
||||||
|
switch engine {
|
||||||
|
case "heuristic":
|
||||||
|
return Heuristic{}, nil
|
||||||
|
case "sherpa", "sherpa-offline", "offline":
|
||||||
|
cfg.Engine = "sherpa"
|
||||||
|
return newSherpaRestorer(cfg)
|
||||||
|
case "sherpa-online", "online":
|
||||||
|
cfg.Engine = "sherpa-online"
|
||||||
|
return newSherpaRestorer(cfg)
|
||||||
|
case "xlm", "xlm-roberta", "roberta":
|
||||||
|
return newXLM(cfg)
|
||||||
|
case "http":
|
||||||
|
if cfg.HTTPURL == "" {
|
||||||
|
return nil, fmt.Errorf("punctuation.http_url is required when engine=http")
|
||||||
|
}
|
||||||
|
return HTTP{cfg: cfg}, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported punctuation engine %q (use: off, heuristic, xlm, sherpa, sherpa-online, http)", engine)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Nop struct{}
|
||||||
|
|
||||||
|
func (Nop) Active() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Nop) Restore(ctx context.Context, text, language string) (string, error) {
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTTP struct {
|
||||||
|
cfg config.Punctuation
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h HTTP) Active() bool { return true }
|
||||||
|
|
||||||
|
func (h HTTP) Restore(ctx context.Context, text, language string) (string, error) {
|
||||||
|
body, err := json.Marshal(map[string]string{
|
||||||
|
"text": text,
|
||||||
|
"language": language,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.cfg.HTTPURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := h.client
|
||||||
|
if client == nil {
|
||||||
|
client = &http.Client{Timeout: h.cfg.Timeout()}
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 300 {
|
||||||
|
return "", fmt.Errorf("punctuation http %s: %s", resp.Status, strings.TrimSpace(string(raw)))
|
||||||
|
}
|
||||||
|
var out struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(raw, &out); err != nil {
|
||||||
|
return strings.TrimSpace(string(raw)), nil
|
||||||
|
}
|
||||||
|
if out.Text == "" {
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
return out.Text, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Apply(ctx context.Context, r Restorer, enabled bool, text, language string) (string, error) {
|
||||||
|
if !enabled || r == nil || !r.Active() {
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
if text == "" {
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
return r.Restore(ctx, text, language)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Close(r Restorer) {
|
||||||
|
if c, ok := r.(Closer); ok {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AutoSelect(cfg config.Punctuation) (Restorer, error) {
|
||||||
|
cfg = cfg.WithDefaults()
|
||||||
|
if !cfg.Active() {
|
||||||
|
return Nop{}, nil
|
||||||
|
}
|
||||||
|
engine := strings.ToLower(cfg.Engine)
|
||||||
|
if engine == "heuristic" {
|
||||||
|
return Heuristic{}, nil
|
||||||
|
}
|
||||||
|
if engine == "http" {
|
||||||
|
return New(cfg)
|
||||||
|
}
|
||||||
|
if engine == "xlm" || engine == "xlm-roberta" || engine == "roberta" {
|
||||||
|
return newXLM(cfg)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(cfg.ModelPath()); err == nil {
|
||||||
|
cfg.Engine = engine
|
||||||
|
if engine == "sherpa" || engine == "sherpa-offline" || engine == "offline" || engine == "" {
|
||||||
|
cfg.Engine = "sherpa"
|
||||||
|
}
|
||||||
|
return newSherpaRestorer(cfg)
|
||||||
|
}
|
||||||
|
if engine == "sherpa" || engine == "sherpa-online" || engine == "online" {
|
||||||
|
return nil, fmt.Errorf("punctuation model not found at %s (run: make download-punctuation-model)", cfg.ModelPath())
|
||||||
|
}
|
||||||
|
return Heuristic{}, nil
|
||||||
|
}
|
||||||
40
punctuation/punctuation_test.go
Normal file
40
punctuation/punctuation_test.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package punctuation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNop(t *testing.T) {
|
||||||
|
r, err := New(config.Punctuation{Engine: "off"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
out, err := Apply(context.Background(), r, true, "hello world", "en")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if out != "hello world" {
|
||||||
|
t.Fatalf("got %q", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApply_disabled(t *testing.T) {
|
||||||
|
r := Heuristic{}
|
||||||
|
out, err := Apply(context.Background(), r, false, "hello", "en")
|
||||||
|
if err != nil || out != "hello" {
|
||||||
|
t.Fatalf("got %q err=%v", out, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew_heuristic(t *testing.T) {
|
||||||
|
r, err := New(config.Punctuation{Enabled: true, Engine: "heuristic"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !r.Active() {
|
||||||
|
t.Fatal("expected active")
|
||||||
|
}
|
||||||
|
}
|
||||||
107
punctuation/sherpa.go
Normal file
107
punctuation/sherpa.go
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
//go:build sherpa
|
||||||
|
|
||||||
|
package punctuation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
|
||||||
|
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Sherpa struct {
|
||||||
|
offline *sherpa.OfflinePunctuation
|
||||||
|
online *sherpa.OnlinePunctuation
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSherpaRestorer(cfg config.Punctuation) (Restorer, error) {
|
||||||
|
return newSherpa(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSherpa(cfg config.Punctuation) (*Sherpa, error) {
|
||||||
|
cfg = cfg.WithDefaults()
|
||||||
|
engine := strings.ToLower(cfg.Engine)
|
||||||
|
s := &Sherpa{}
|
||||||
|
switch engine {
|
||||||
|
case "sherpa-online", "online":
|
||||||
|
modelPath := cfg.ModelPath()
|
||||||
|
vocabPath := cfg.BpeVocabPath()
|
||||||
|
if _, err := os.Stat(modelPath); err != nil {
|
||||||
|
return nil, fmt.Errorf("sherpa online punctuation model %q: %w", modelPath, err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(vocabPath); err != nil {
|
||||||
|
return nil, fmt.Errorf("sherpa bpe vocab %q: %w", vocabPath, err)
|
||||||
|
}
|
||||||
|
conf := sherpa.OnlinePunctuationConfig{}
|
||||||
|
conf.Model.CnnBilstm = modelPath
|
||||||
|
conf.Model.BpeVocab = vocabPath
|
||||||
|
conf.Model.NumThreads = cfg.NumThreads
|
||||||
|
conf.Model.Provider = "cpu"
|
||||||
|
s.online = sherpa.NewOnlinePunctuation(&conf)
|
||||||
|
if s.online == nil {
|
||||||
|
return nil, fmt.Errorf("failed to create sherpa online punctuation")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
modelPath := cfg.ModelPath()
|
||||||
|
if _, err := os.Stat(modelPath); err != nil {
|
||||||
|
return nil, fmt.Errorf("sherpa offline punctuation model %q: %w (run: make download-punctuation-model)", modelPath, err)
|
||||||
|
}
|
||||||
|
conf := sherpa.OfflinePunctuationConfig{}
|
||||||
|
conf.Model.CtTransformer = modelPath
|
||||||
|
conf.Model.NumThreads = cfg.NumThreads
|
||||||
|
conf.Model.Provider = "cpu"
|
||||||
|
s.offline = sherpa.NewOfflinePunctuation(&conf)
|
||||||
|
if s.offline == nil {
|
||||||
|
return nil, fmt.Errorf("failed to create sherpa offline punctuation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sherpa) Active() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sherpa) Restore(ctx context.Context, text, language string) (string, error) {
|
||||||
|
_ = ctx
|
||||||
|
_ = language
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
if text == "" {
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
var out string
|
||||||
|
switch {
|
||||||
|
case s.offline != nil:
|
||||||
|
out = s.offline.AddPunct(text)
|
||||||
|
case s.online != nil:
|
||||||
|
out = s.online.AddPunct(text)
|
||||||
|
default:
|
||||||
|
return text, fmt.Errorf("sherpa punctuation not initialized")
|
||||||
|
}
|
||||||
|
out = strings.TrimSpace(out)
|
||||||
|
if out == "" {
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sherpa) Close() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.offline != nil {
|
||||||
|
sherpa.DeleteOfflinePunc(s.offline)
|
||||||
|
s.offline = nil
|
||||||
|
}
|
||||||
|
if s.online != nil {
|
||||||
|
sherpa.DeleteOnlinePunctuation(s.online)
|
||||||
|
s.online = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
13
punctuation/sherpa_stub.go
Normal file
13
punctuation/sherpa_stub.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
//go:build !sherpa
|
||||||
|
|
||||||
|
package punctuation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newSherpaRestorer(cfg config.Punctuation) (Restorer, error) {
|
||||||
|
return nil, fmt.Errorf("punctuation engine %q requires build tag sherpa (use: make build)", cfg.Engine)
|
||||||
|
}
|
||||||
319
punctuation/xlm.go
Normal file
319
punctuation/xlm.go
Normal file
@ -0,0 +1,319 @@
|
|||||||
|
//go:build xlm
|
||||||
|
|
||||||
|
package punctuation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
"go-whisper-api/punctuation/internal/spwrap"
|
||||||
|
|
||||||
|
ort "github.com/yalue/onnxruntime_go"
|
||||||
|
)
|
||||||
|
|
||||||
|
type XLM struct {
|
||||||
|
cfg config.Punctuation
|
||||||
|
modelCfg xlmModelConfig
|
||||||
|
sp *spwrap.Processor
|
||||||
|
session *ort.DynamicAdvancedSession
|
||||||
|
inputName string
|
||||||
|
outputNames []string
|
||||||
|
joinSBD bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newXLM(cfg config.Punctuation) (*XLM, error) {
|
||||||
|
if err := ensureORT(); err != nil {
|
||||||
|
return nil, fmt.Errorf("onnxruntime: %w (set ONNXRUNTIME_SHARED_LIBRARY_PATH or install sherpa-onnx libs)", err)
|
||||||
|
}
|
||||||
|
onnxPath := cfg.ModelPath()
|
||||||
|
if _, err := os.Stat(onnxPath); err != nil {
|
||||||
|
return nil, fmt.Errorf("xlm onnx model not found at %s (run: make download-xlm-punctuation-model)", onnxPath)
|
||||||
|
}
|
||||||
|
spPath := cfg.SPModelPath()
|
||||||
|
if _, err := os.Stat(spPath); err != nil {
|
||||||
|
return nil, fmt.Errorf("xlm sp.model not found at %s (run: make download-xlm-punctuation-model)", spPath)
|
||||||
|
}
|
||||||
|
cfgPath := cfg.XLMConfigPath()
|
||||||
|
modelCfg, err := loadXLMConfig(cfgPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sp, err := spwrap.Load(spPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
inputs, outputs, err := ort.GetInputOutputInfo(onnxPath)
|
||||||
|
if err != nil {
|
||||||
|
sp.Close()
|
||||||
|
return nil, fmt.Errorf("xlm model io: %w", err)
|
||||||
|
}
|
||||||
|
if len(inputs) == 0 || len(outputs) < 4 {
|
||||||
|
sp.Close()
|
||||||
|
return nil, fmt.Errorf("xlm model: unexpected inputs/outputs")
|
||||||
|
}
|
||||||
|
inNames := make([]string, len(inputs))
|
||||||
|
for i, in := range inputs {
|
||||||
|
inNames[i] = in.Name
|
||||||
|
}
|
||||||
|
outNames := make([]string, len(outputs))
|
||||||
|
for i, out := range outputs {
|
||||||
|
outNames[i] = out.Name
|
||||||
|
}
|
||||||
|
session, err := ort.NewDynamicAdvancedSession(onnxPath, inNames, outNames, nil)
|
||||||
|
if err != nil {
|
||||||
|
sp.Close()
|
||||||
|
return nil, fmt.Errorf("xlm onnx session: %w", err)
|
||||||
|
}
|
||||||
|
return &XLM{
|
||||||
|
cfg: cfg,
|
||||||
|
modelCfg: modelCfg,
|
||||||
|
sp: sp,
|
||||||
|
session: session,
|
||||||
|
inputName: inNames[0],
|
||||||
|
outputNames: outNames,
|
||||||
|
joinSBD: cfg.XLMJoinSentences(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *XLM) Active() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *XLM) Close() {
|
||||||
|
if x.session != nil {
|
||||||
|
_ = x.session.Destroy()
|
||||||
|
x.session = nil
|
||||||
|
}
|
||||||
|
if x.sp != nil {
|
||||||
|
x.sp.Close()
|
||||||
|
x.sp = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *XLM) Restore(ctx context.Context, text, language string) (string, error) {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
text = strings.TrimSpace(normalizeXLMSpaces(text))
|
||||||
|
if text == "" {
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ids, err := x.sp.EncodeAsIDs(text)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
full := make([]int, 0, len(ids)+2)
|
||||||
|
full = append(full, x.sp.BOSID())
|
||||||
|
full = append(full, ids...)
|
||||||
|
full = append(full, x.sp.EOSID())
|
||||||
|
maxLen := x.modelCfg.MaxLength
|
||||||
|
if maxLen <= 2 {
|
||||||
|
maxLen = 256
|
||||||
|
}
|
||||||
|
if len(full) <= maxLen {
|
||||||
|
return x.inferIDs(full)
|
||||||
|
}
|
||||||
|
var parts []string
|
||||||
|
content := full[1 : len(full)-1]
|
||||||
|
step := maxLen - 2
|
||||||
|
for start := 0; start < len(content); {
|
||||||
|
end := start + step
|
||||||
|
if end > len(content) {
|
||||||
|
end = len(content)
|
||||||
|
}
|
||||||
|
chunk := make([]int, 0, end-start+2)
|
||||||
|
chunk = append(chunk, x.sp.BOSID())
|
||||||
|
chunk = append(chunk, content[start:end]...)
|
||||||
|
chunk = append(chunk, x.sp.EOSID())
|
||||||
|
out, err := x.inferIDs(chunk)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if out != "" {
|
||||||
|
parts = append(parts, out)
|
||||||
|
}
|
||||||
|
if end >= len(content) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
start = end
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(strings.Join(parts, " ")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *XLM) inferIDs(inputIDs []int) (string, error) {
|
||||||
|
data := make([]int64, len(inputIDs))
|
||||||
|
for i, id := range inputIDs {
|
||||||
|
data[i] = int64(id)
|
||||||
|
}
|
||||||
|
inputTensor, err := ort.NewTensor(ort.NewShape(1, int64(len(inputIDs))), data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer inputTensor.Destroy()
|
||||||
|
outputs := make([]ort.Value, len(x.outputNames))
|
||||||
|
if err := x.session.Run([]ort.Value{inputTensor}, outputs); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer destroyValues(outputs)
|
||||||
|
pre, err := int64Row(outputs[0], len(inputIDs))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
post, err := int64Row(outputs[1], len(inputIDs))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
cap, err := boolMatrix(outputs[2], len(inputIDs))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sbd, err := boolRow(outputs[3], len(inputIDs))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return decodeXLMSegment(x.sp, x.modelCfg, inputIDs, pre, post, cap, sbd, x.joinSBD)
|
||||||
|
}
|
||||||
|
|
||||||
|
func destroyValues(vals []ort.Value) {
|
||||||
|
for _, v := range vals {
|
||||||
|
if v != nil {
|
||||||
|
_ = v.Destroy()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func int64Row(v ort.Value, wantLen int) ([]int64, error) {
|
||||||
|
switch t := v.(type) {
|
||||||
|
case *ort.Tensor[int64]:
|
||||||
|
d := t.GetData()
|
||||||
|
if len(d) == wantLen {
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
if len(d) > wantLen {
|
||||||
|
return d[len(d)-wantLen:], nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("int64 output short: %d < %d", len(d), wantLen)
|
||||||
|
case *ort.Tensor[int32]:
|
||||||
|
d := t.GetData()
|
||||||
|
if len(d) > wantLen {
|
||||||
|
d = d[len(d)-wantLen:]
|
||||||
|
}
|
||||||
|
out := make([]int64, wantLen)
|
||||||
|
for i := 0; i < wantLen && i < len(d); i++ {
|
||||||
|
out[i] = int64(d[i])
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected int output type %T", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolRow(v ort.Value, wantLen int) ([]bool, error) {
|
||||||
|
switch t := v.(type) {
|
||||||
|
case *ort.Tensor[bool]:
|
||||||
|
d := t.GetData()
|
||||||
|
if len(d) == wantLen {
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
if len(d) > wantLen {
|
||||||
|
return d[len(d)-wantLen:], nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("bool output short")
|
||||||
|
case *ort.Tensor[float32]:
|
||||||
|
d := t.GetData()
|
||||||
|
out := make([]bool, wantLen)
|
||||||
|
for i := 0; i < wantLen && i < len(d); i++ {
|
||||||
|
out[i] = d[i] > 0.5
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected bool output type %T", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolMatrix(v ort.Value, seqLen int) ([][]bool, error) {
|
||||||
|
switch t := v.(type) {
|
||||||
|
case *ort.Tensor[bool]:
|
||||||
|
shape := t.GetShape()
|
||||||
|
d := t.GetData()
|
||||||
|
if len(shape) == 3 {
|
||||||
|
_, sl, width := shape[0], shape[1], shape[2]
|
||||||
|
out := make([][]bool, sl)
|
||||||
|
for i := 0; i < int(sl); i++ {
|
||||||
|
row := make([]bool, width)
|
||||||
|
base := int(i) * int(width)
|
||||||
|
copy(row, d[base:base+int(width)])
|
||||||
|
out[i] = row
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
width := len(d) / seqLen
|
||||||
|
if width < 1 {
|
||||||
|
width = 1
|
||||||
|
}
|
||||||
|
out := make([][]bool, seqLen)
|
||||||
|
for i := 0; i < seqLen; i++ {
|
||||||
|
row := make([]bool, width)
|
||||||
|
base := i * width
|
||||||
|
if base+width <= len(d) {
|
||||||
|
copy(row, d[base:base+width])
|
||||||
|
}
|
||||||
|
out[i] = row
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
case *ort.Tensor[float32]:
|
||||||
|
shape := t.GetShape()
|
||||||
|
d := t.GetData()
|
||||||
|
if len(shape) == 3 {
|
||||||
|
_, sl, width := shape[0], shape[1], shape[2]
|
||||||
|
out := make([][]bool, sl)
|
||||||
|
for i := 0; i < int(sl); i++ {
|
||||||
|
row := make([]bool, width)
|
||||||
|
base := int(i) * int(width)
|
||||||
|
for j := 0; j < int(width); j++ {
|
||||||
|
row[j] = d[base+j] > 0.5
|
||||||
|
}
|
||||||
|
out[i] = row
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
width := len(d) / seqLen
|
||||||
|
if width < 1 {
|
||||||
|
width = 1
|
||||||
|
}
|
||||||
|
out := make([][]bool, seqLen)
|
||||||
|
for i := 0; i < seqLen; i++ {
|
||||||
|
row := make([]bool, width)
|
||||||
|
base := i * width
|
||||||
|
for j := 0; j < width && base+j < len(d); j++ {
|
||||||
|
row[j] = d[base+j] > 0.5
|
||||||
|
}
|
||||||
|
out[i] = row
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected cap output type %T", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeXLMSpaces(s string) string {
|
||||||
|
var b strings.Builder
|
||||||
|
prevSpace := false
|
||||||
|
for _, r := range s {
|
||||||
|
if unicode.IsSpace(r) {
|
||||||
|
if !prevSpace {
|
||||||
|
b.WriteRune(' ')
|
||||||
|
prevSpace = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
prevSpace = false
|
||||||
|
b.WriteRune(r)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
43
punctuation/xlm_config.go
Normal file
43
punctuation/xlm_config.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package punctuation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type xlmModelConfig struct {
|
||||||
|
Languages []string `yaml:"languages"`
|
||||||
|
MaxLength int `yaml:"max_length"`
|
||||||
|
PreLabels []string `yaml:"pre_labels"`
|
||||||
|
PostLabels []string `yaml:"post_labels"`
|
||||||
|
NullToken string `yaml:"null_token"`
|
||||||
|
Acronym string `yaml:"acronym_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadXLMConfig(path string) (xlmModelConfig, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return xlmModelConfig{}, err
|
||||||
|
}
|
||||||
|
var cfg xlmModelConfig
|
||||||
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return xlmModelConfig{}, fmt.Errorf("parse xlm config %s: %w", path, err)
|
||||||
|
}
|
||||||
|
if cfg.MaxLength <= 0 {
|
||||||
|
cfg.MaxLength = 256
|
||||||
|
}
|
||||||
|
if cfg.NullToken == "" {
|
||||||
|
cfg.NullToken = "<NULL>"
|
||||||
|
}
|
||||||
|
if cfg.Acronym == "" {
|
||||||
|
cfg.Acronym = "<ACRONYM>"
|
||||||
|
}
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultXLMConfigPath(modelDir string) string {
|
||||||
|
return filepath.Join(modelDir, "config.yaml")
|
||||||
|
}
|
||||||
97
punctuation/xlm_decode.go
Normal file
97
punctuation/xlm_decode.go
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
//go:build xlm
|
||||||
|
|
||||||
|
package punctuation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"go-whisper-api/punctuation/internal/spwrap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func decodeXLMSegment(
|
||||||
|
sp *spwrap.Processor,
|
||||||
|
cfg xlmModelConfig,
|
||||||
|
inputIDs []int,
|
||||||
|
prePred, postPred []int64,
|
||||||
|
capPred [][]bool,
|
||||||
|
sbdPred []bool,
|
||||||
|
joinSentences bool,
|
||||||
|
) (string, error) {
|
||||||
|
var outputTexts []string
|
||||||
|
current := make([]string, 0, len(inputIDs)*4)
|
||||||
|
for tokenIdx := 1; tokenIdx < len(inputIDs)-1; tokenIdx++ {
|
||||||
|
piece, err := sp.IDToPiece(inputIDs[tokenIdx])
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(piece, "▁") && len(current) > 0 {
|
||||||
|
current = append(current, " ")
|
||||||
|
}
|
||||||
|
preLabel := labelAt(cfg.PreLabels, prePred, tokenIdx)
|
||||||
|
postLabel := labelAt(cfg.PostLabels, postPred, tokenIdx)
|
||||||
|
if preLabel != cfg.NullToken {
|
||||||
|
current = append(current, preLabel)
|
||||||
|
}
|
||||||
|
charStart := 0
|
||||||
|
if strings.HasPrefix(piece, "▁") {
|
||||||
|
charStart = 1
|
||||||
|
}
|
||||||
|
runes := []rune(piece)
|
||||||
|
for tokenCharIdx := charStart; tokenCharIdx < len(runes); tokenCharIdx++ {
|
||||||
|
ch := string(runes[tokenCharIdx])
|
||||||
|
if capAt(capPred, tokenIdx, tokenCharIdx) {
|
||||||
|
ch = strings.ToUpper(ch)
|
||||||
|
}
|
||||||
|
current = append(current, ch)
|
||||||
|
if postLabel == cfg.Acronym {
|
||||||
|
current = append(current, ".")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if postLabel != cfg.NullToken && postLabel != cfg.Acronym {
|
||||||
|
current = append(current, postLabel)
|
||||||
|
}
|
||||||
|
if sbdAt(sbdPred, tokenIdx) {
|
||||||
|
outputTexts = append(outputTexts, strings.Join(current, ""))
|
||||||
|
current = current[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(current) > 0 {
|
||||||
|
outputTexts = append(outputTexts, strings.Join(current, ""))
|
||||||
|
}
|
||||||
|
if len(outputTexts) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
if joinSentences {
|
||||||
|
return strings.Join(outputTexts, " "), nil
|
||||||
|
}
|
||||||
|
return outputTexts[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func labelAt(labels []string, preds []int64, idx int) string {
|
||||||
|
if idx < 0 || idx >= len(preds) {
|
||||||
|
return labels[0]
|
||||||
|
}
|
||||||
|
pi := int(preds[idx])
|
||||||
|
if pi < 0 || pi >= len(labels) {
|
||||||
|
return labels[0]
|
||||||
|
}
|
||||||
|
return labels[pi]
|
||||||
|
}
|
||||||
|
|
||||||
|
func capAt(capPred [][]bool, tokenIdx, charIdx int) bool {
|
||||||
|
if tokenIdx < 0 || tokenIdx >= len(capPred) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
row := capPred[tokenIdx]
|
||||||
|
if charIdx < 0 || charIdx >= len(row) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return row[charIdx]
|
||||||
|
}
|
||||||
|
|
||||||
|
func sbdAt(sbd []bool, tokenIdx int) bool {
|
||||||
|
if tokenIdx < 0 || tokenIdx >= len(sbd) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return sbd[tokenIdx]
|
||||||
|
}
|
||||||
13
punctuation/xlm_stub.go
Normal file
13
punctuation/xlm_stub.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
//go:build !xlm
|
||||||
|
|
||||||
|
package punctuation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newXLM(cfg config.Punctuation) (Restorer, error) {
|
||||||
|
return nil, fmt.Errorf("punctuation engine %q requires build tag xlm (run: make build-xlm)", cfg.Engine)
|
||||||
|
}
|
||||||
132
transcode/aac_decode.go
Normal file
132
transcode/aac_decode.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/olivier-w/climp-aac-decoder/aacfile"
|
||||||
|
)
|
||||||
|
|
||||||
|
func decodeAACPath(path, ext string) ([]float64, int, int, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
st, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
// aacfile picks the parser from the *name* extension, not file content (.mp4 → .m4a).
|
||||||
|
name := aacOpenName(path, ext)
|
||||||
|
size := st.Size()
|
||||||
|
r, err := aacfile.Open(f, size, name)
|
||||||
|
if err != nil && isMP4SampleDeltaError(err) {
|
||||||
|
return decodeMP4AACRelaxed(f, size)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
sr := r.SampleRate()
|
||||||
|
ch := r.ChannelCount()
|
||||||
|
pcm, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, fmt.Errorf("read aac pcm: %w", err)
|
||||||
|
}
|
||||||
|
samples := pcm16LEToFloat(pcm, ch)
|
||||||
|
if ch > 1 {
|
||||||
|
samples = interleavedToMono(samples, ch)
|
||||||
|
ch = 1
|
||||||
|
}
|
||||||
|
return samples, sr, ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// aacContainerExt maps file extensions to a container name understood by aacfile.
|
||||||
|
func aacContainerExt(ext string) string {
|
||||||
|
switch strings.ToLower(ext) {
|
||||||
|
case ".mp4", ".m4v", ".mov", ".3gp", ".3g2":
|
||||||
|
return ".m4a"
|
||||||
|
case ".aac", ".m4a", ".m4b":
|
||||||
|
return ext
|
||||||
|
default:
|
||||||
|
return ".m4a"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func aacOpenName(path, ext string) string {
|
||||||
|
containerExt := aacContainerExt(ext)
|
||||||
|
if containerExt == "" {
|
||||||
|
containerExt = ".m4a"
|
||||||
|
}
|
||||||
|
base := filepath.Base(path)
|
||||||
|
if e := strings.ToLower(filepath.Ext(base)); e == containerExt {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
stem := strings.TrimSuffix(base, filepath.Ext(base))
|
||||||
|
if stem == "" || stem == base {
|
||||||
|
stem = "audio"
|
||||||
|
}
|
||||||
|
return stem + containerExt
|
||||||
|
}
|
||||||
|
|
||||||
|
func pcm16LEToFloat(pcm []byte, channels int) []float64 {
|
||||||
|
if channels <= 0 {
|
||||||
|
channels = 1
|
||||||
|
}
|
||||||
|
frameBytes := 2 * channels
|
||||||
|
nFrames := len(pcm) / frameBytes
|
||||||
|
out := make([]float64, nFrames*channels)
|
||||||
|
for i := 0; i < nFrames*channels; i++ {
|
||||||
|
off := i * 2
|
||||||
|
if off+1 >= len(pcm) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
v := int16(pcm[off]) | int16(pcm[off+1])<<8
|
||||||
|
out[i] = float64(v) / 32768.0
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func interleavedToMono(samples []float64, channels int) []float64 {
|
||||||
|
if channels <= 1 {
|
||||||
|
return samples
|
||||||
|
}
|
||||||
|
nFrames := len(samples) / channels
|
||||||
|
out := make([]float64, nFrames)
|
||||||
|
for i := 0; i < nFrames; i++ {
|
||||||
|
var sum float64
|
||||||
|
for c := 0; c < channels; c++ {
|
||||||
|
sum += samples[i*channels+c]
|
||||||
|
}
|
||||||
|
out[i] = sum / float64(channels)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func resampleLinear(samples []float64, fromRate, toRate int) []float64 {
|
||||||
|
if fromRate <= 0 || toRate <= 0 || fromRate == toRate || len(samples) == 0 {
|
||||||
|
return samples
|
||||||
|
}
|
||||||
|
ratio := float64(fromRate) / float64(toRate)
|
||||||
|
outLen := int(math.Round(float64(len(samples)) / ratio))
|
||||||
|
if outLen < 1 {
|
||||||
|
outLen = 1
|
||||||
|
}
|
||||||
|
out := make([]float64, outLen)
|
||||||
|
for i := 0; i < outLen; i++ {
|
||||||
|
src := float64(i) * ratio
|
||||||
|
j := int(src)
|
||||||
|
if j >= len(samples)-1 {
|
||||||
|
out[i] = samples[len(samples)-1]
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
frac := src - float64(j)
|
||||||
|
out[i] = samples[j]*(1-frac) + samples[j+1]*frac
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
31
transcode/aac_decode_test.go
Normal file
31
transcode/aac_decode_test.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestAacOpenName_mp4(t *testing.T) {
|
||||||
|
got := aacOpenName("/tmp/cache/input.mp4", ".mp4")
|
||||||
|
if got != "input.m4a" {
|
||||||
|
t.Fatalf("got %q want input.m4a", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAacOpenName_noExt(t *testing.T) {
|
||||||
|
got := aacOpenName("/tmp/input", ".m4a")
|
||||||
|
if got != "audio.m4a" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAacContainerExt(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
".mp4": ".m4a",
|
||||||
|
".mov": ".m4a",
|
||||||
|
".aac": ".aac",
|
||||||
|
".m4a": ".m4a",
|
||||||
|
}
|
||||||
|
for in, want := range cases {
|
||||||
|
if got := aacContainerExt(in); got != want {
|
||||||
|
t.Fatalf("%s: got %q want %q", in, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
127
transcode/decode.go
Normal file
127
transcode/decode.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gopxl/beep"
|
||||||
|
"github.com/gopxl/beep/flac"
|
||||||
|
"github.com/gopxl/beep/mp3"
|
||||||
|
beepwav "github.com/gopxl/beep/wav"
|
||||||
|
)
|
||||||
|
|
||||||
|
var probeFormats = []string{
|
||||||
|
".wav", ".wave", ".mp3", ".flac", ".ogg", ".opus",
|
||||||
|
".m4a", ".m4b", ".mp4", ".mov", ".m4v", ".3gp", ".3g2", ".aac",
|
||||||
|
}
|
||||||
|
|
||||||
|
func supportedFormatsMessage() string {
|
||||||
|
return strings.Join(probeFormats, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveFormat(path string) (string, error) {
|
||||||
|
ext := strings.ToLower(filepath.Ext(path))
|
||||||
|
if ext != "" {
|
||||||
|
return ext, nil
|
||||||
|
}
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
ext = sniffFormat(f)
|
||||||
|
if ext == "" {
|
||||||
|
return "", fmt.Errorf("could not detect audio format (supported: %s)", supportedFormatsMessage())
|
||||||
|
}
|
||||||
|
return ext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openDecoder(path string) (beep.Streamer, beep.Format, io.Closer, error) {
|
||||||
|
ext, err := resolveFormat(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, beep.Format{}, nil, err
|
||||||
|
}
|
||||||
|
streamer, format, closer, err := decodeByExt(path, ext)
|
||||||
|
if err == nil {
|
||||||
|
return streamer, format, closer, nil
|
||||||
|
}
|
||||||
|
for _, try := range probeFormats {
|
||||||
|
if try == ext {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
streamer, format, closer, tryErr := decodeByExt(path, try)
|
||||||
|
if tryErr == nil {
|
||||||
|
return streamer, format, closer, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, beep.Format{}, nil, fmt.Errorf("unsupported audio format %q (supported: %s): %w", ext, supportedFormatsMessage(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeByExt(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) {
|
||||||
|
switch ext {
|
||||||
|
case ".wav", ".wave":
|
||||||
|
return decodeBeepFile(path, ext)
|
||||||
|
case ".mp3":
|
||||||
|
return decodeBeepFile(path, ext)
|
||||||
|
case ".flac":
|
||||||
|
return decodeBeepFile(path, ext)
|
||||||
|
case ".ogg", ".opus":
|
||||||
|
return decodeOggFile(path)
|
||||||
|
case ".m4a", ".m4b", ".mp4", ".mov", ".m4v", ".aac":
|
||||||
|
return decodeAACAsStreamer(path, ext)
|
||||||
|
case ".webm":
|
||||||
|
return nil, beep.Format{}, nil, fmt.Errorf("webm is not supported yet")
|
||||||
|
default:
|
||||||
|
return nil, beep.Format{}, nil, fmt.Errorf("unsupported extension %q", ext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBeepFile(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, beep.Format{}, nil, err
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
streamer beep.StreamSeekCloser
|
||||||
|
format beep.Format
|
||||||
|
decErr error
|
||||||
|
)
|
||||||
|
switch ext {
|
||||||
|
case ".wav", ".wave":
|
||||||
|
streamer, format, decErr = beepwav.Decode(f)
|
||||||
|
case ".mp3":
|
||||||
|
streamer, format, decErr = mp3.Decode(f)
|
||||||
|
case ".flac":
|
||||||
|
streamer, format, decErr = flac.Decode(f)
|
||||||
|
default:
|
||||||
|
f.Close()
|
||||||
|
return nil, beep.Format{}, nil, fmt.Errorf("internal: beep decode for %q", ext)
|
||||||
|
}
|
||||||
|
if decErr != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, beep.Format{}, nil, decErr
|
||||||
|
}
|
||||||
|
return streamer, format, f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeAACAsStreamer(path, ext string) (beep.Streamer, beep.Format, io.Closer, error) {
|
||||||
|
samples, sr, ch, err := decodeAACPath(path, ext)
|
||||||
|
if err != nil {
|
||||||
|
return nil, beep.Format{}, nil, err
|
||||||
|
}
|
||||||
|
if ch <= 0 {
|
||||||
|
ch = 1
|
||||||
|
}
|
||||||
|
return newSamplesStreamer(samples, sr), beep.Format{
|
||||||
|
SampleRate: beep.SampleRate(sr),
|
||||||
|
NumChannels: ch,
|
||||||
|
Precision: 2,
|
||||||
|
}, noopCloser{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type noopCloser struct{}
|
||||||
|
|
||||||
|
func (noopCloser) Close() error { return nil }
|
||||||
119
transcode/engine.go
Normal file
119
transcode/engine.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Engine converts input audio to PCM WAV using pure Go decoders (no ffmpeg).
|
||||||
|
type Engine struct{}
|
||||||
|
|
||||||
|
// NewEngine creates a transcoder. The ffmpegPath argument is ignored (kept for config compatibility).
|
||||||
|
func NewEngine(_ string) *Engine {
|
||||||
|
return &Engine{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) Available() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) Transcode(ctx context.Context, src, dst string, opt Options) error {
|
||||||
|
if err := opt.Validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
spec, err := ResolveFormat(opt.Format)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst, err = OutputPath(dst, spec.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
streamer, format, closer, err := openDecoder(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer closer.Close()
|
||||||
|
s, format := buildPipeline(streamer, format, opt)
|
||||||
|
samples, err := drainSamples(ctx, s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ch := format.NumChannels
|
||||||
|
if opt.Channels > 0 {
|
||||||
|
ch = opt.Channels
|
||||||
|
}
|
||||||
|
sr := int(format.SampleRate)
|
||||||
|
if opt.SampleRate > 0 {
|
||||||
|
sr = opt.SampleRate
|
||||||
|
}
|
||||||
|
if err := writePCM16WAV(dst, sr, ch, samples); err != nil {
|
||||||
|
return fmt.Errorf("write wav: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Transcode(ctx context.Context, src, dst string, opt Options) error {
|
||||||
|
return NewEngine("").Transcode(ctx, src, dst, opt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToWhisperWAV(ctx context.Context, src, dst string) error {
|
||||||
|
return Transcode(ctx, src, dst, WhisperOptions())
|
||||||
|
}
|
||||||
|
|
||||||
|
// SupportedInputFormats lists file extensions decoded without external tools.
|
||||||
|
func SupportedInputFormats() []string {
|
||||||
|
return append([]string(nil), probeFormats...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) Probe(ctx context.Context, path string) (MediaInfo, error) {
|
||||||
|
_ = ctx
|
||||||
|
ext, err := resolveFormat(path)
|
||||||
|
if err != nil {
|
||||||
|
return MediaInfo{}, err
|
||||||
|
}
|
||||||
|
streamer, format, closer, err := openDecoder(path)
|
||||||
|
if err != nil {
|
||||||
|
return MediaInfo{}, err
|
||||||
|
}
|
||||||
|
defer closer.Close()
|
||||||
|
info := MediaInfo{
|
||||||
|
Format: ext,
|
||||||
|
Streams: []StreamInfo{{
|
||||||
|
Index: 0,
|
||||||
|
Codec: extTrim(ext),
|
||||||
|
Type: "audio",
|
||||||
|
SampleRate: int(format.SampleRate),
|
||||||
|
Channels: format.NumChannels,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
if st, err := os.Stat(path); err == nil {
|
||||||
|
info.BitRate = st.Size() * 8
|
||||||
|
}
|
||||||
|
_ = streamer
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extTrim(ext string) string {
|
||||||
|
if len(ext) > 0 && ext[0] == '.' {
|
||||||
|
return ext[1:]
|
||||||
|
}
|
||||||
|
return ext
|
||||||
|
}
|
||||||
|
|
||||||
|
// MediaInfo describes decoded input (for optional diagnostics).
|
||||||
|
type MediaInfo struct {
|
||||||
|
Format string `json:"format"`
|
||||||
|
Duration float64 `json:"duration_seconds"`
|
||||||
|
BitRate int64 `json:"bit_rate"`
|
||||||
|
Streams []StreamInfo `json:"streams"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type StreamInfo struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Codec string `json:"codec"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
SampleRate int `json:"sample_rate,omitempty"`
|
||||||
|
Channels int `json:"channels,omitempty"`
|
||||||
|
}
|
||||||
85
transcode/engine_test.go
Normal file
85
transcode/engine_test.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-audio/wav"
|
||||||
|
)
|
||||||
|
|
||||||
|
func samplePath(t *testing.T, name string) string {
|
||||||
|
t.Helper()
|
||||||
|
p := filepath.Join("..", "third_party", "whisper.cpp", "samples", name)
|
||||||
|
if _, err := os.Stat(p); err != nil {
|
||||||
|
t.Skip("sample not found:", p)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToWhisperWAV_mp3(t *testing.T) {
|
||||||
|
src := samplePath(t, "jfk.mp3")
|
||||||
|
dst := filepath.Join(t.TempDir(), "out.wav")
|
||||||
|
if err := ToWhisperWAV(context.Background(), src, dst); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
assertWhisperWAV(t, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToWhisperWAV_wav(t *testing.T) {
|
||||||
|
src := samplePath(t, "jfk.wav")
|
||||||
|
dst := filepath.Join(t.TempDir(), "out.wav")
|
||||||
|
if err := ToWhisperWAV(context.Background(), src, dst); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
assertWhisperWAV(t, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveFormat_noExtension_mp3(t *testing.T) {
|
||||||
|
src := samplePath(t, "jfk.mp3")
|
||||||
|
data, err := os.ReadFile(src)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
path := filepath.Join(t.TempDir(), "upload")
|
||||||
|
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ext, err := resolveFormat(path)
|
||||||
|
if err != nil || ext != ".mp3" {
|
||||||
|
t.Fatalf("ext=%q err=%v", ext, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_Available(t *testing.T) {
|
||||||
|
if err := NewEngine("").Available(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertWhisperWAV(t *testing.T, path string) {
|
||||||
|
t.Helper()
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
dec := wav.NewDecoder(f)
|
||||||
|
if !dec.IsValidFile() {
|
||||||
|
t.Fatal("invalid wav")
|
||||||
|
}
|
||||||
|
buf, err := dec.FullPCMBuffer()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if dec.SampleRate != 16000 {
|
||||||
|
t.Fatalf("sample rate %d", dec.SampleRate)
|
||||||
|
}
|
||||||
|
if dec.NumChans != 1 {
|
||||||
|
t.Fatalf("channels %d", dec.NumChans)
|
||||||
|
}
|
||||||
|
if len(buf.Data) == 0 {
|
||||||
|
t.Fatal("empty audio")
|
||||||
|
}
|
||||||
|
}
|
||||||
47
transcode/format.go
Normal file
47
transcode/format.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
FormatWAV = "wav"
|
||||||
|
FormatPCM = "pcm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FormatSpec struct {
|
||||||
|
ID string
|
||||||
|
Extension string
|
||||||
|
Codec string
|
||||||
|
}
|
||||||
|
|
||||||
|
var formats = map[string]FormatSpec{
|
||||||
|
FormatWAV: {ID: FormatWAV, Extension: ".wav", Codec: "pcm_s16le"},
|
||||||
|
FormatPCM: {ID: FormatPCM, Extension: ".wav", Codec: "pcm_s16le"},
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResolveFormat(name string) (FormatSpec, error) {
|
||||||
|
name = strings.ToLower(strings.TrimSpace(name))
|
||||||
|
if name == "" {
|
||||||
|
name = FormatWAV
|
||||||
|
}
|
||||||
|
spec, ok := formats[name]
|
||||||
|
if !ok {
|
||||||
|
return FormatSpec{}, fmt.Errorf("unsupported format %q (supported: wav)", name)
|
||||||
|
}
|
||||||
|
return spec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func OutputPath(dst, format string) (string, error) {
|
||||||
|
spec, err := ResolveFormat(format)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
ext := filepath.Ext(dst)
|
||||||
|
if ext == "" {
|
||||||
|
return dst + spec.Extension, nil
|
||||||
|
}
|
||||||
|
return dst, nil
|
||||||
|
}
|
||||||
288
transcode/mp4_aac_decode.go
Normal file
288
transcode/mp4_aac_decode.go
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/Eyevinn/mp4ff/mp4"
|
||||||
|
"github.com/olivier-w/climp-aac-decoder/aacfile"
|
||||||
|
aacdec "github.com/skrashevich/go-aac/pkg/decoder"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mp4AACSample struct {
|
||||||
|
offset int64
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMP4SampleDeltaError(err error) bool {
|
||||||
|
var uf *aacfile.UnsupportedFeatureError
|
||||||
|
if !errors.As(err, &uf) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return uf.Feature == "MP4 sample delta" || uf.Feature == "MP4 sample delta layout"
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeMP4AACRelaxed demuxes MP4/M4A with mp4ff (ignoring stts sample deltas) and
|
||||||
|
// decodes raw AAC frames with go-aac. Used when climp-aac-decoder rejects stts
|
||||||
|
// entries whose delta is not exactly 1024 (common in ffmpeg/phone muxers).
|
||||||
|
func decodeMP4AACRelaxed(r io.ReaderAt, size int64) ([]float64, int, int, error) {
|
||||||
|
asc, samples, leading, err := demuxMP4AAC(r, size)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, err
|
||||||
|
}
|
||||||
|
dec := aacdec.New()
|
||||||
|
if err := dec.SetASC(asc); err != nil {
|
||||||
|
return nil, 0, 0, fmt.Errorf("aac config: %w", err)
|
||||||
|
}
|
||||||
|
ch := dec.Config.ChanConfig
|
||||||
|
if ch < 1 {
|
||||||
|
return nil, 0, 0, fmt.Errorf("aac config: invalid channel count %d", ch)
|
||||||
|
}
|
||||||
|
sr := dec.Config.SampleRate
|
||||||
|
if sr <= 0 {
|
||||||
|
return nil, 0, 0, fmt.Errorf("aac config: invalid sample rate %d", sr)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxSize := 0
|
||||||
|
for _, s := range samples {
|
||||||
|
if s.size > maxSize {
|
||||||
|
maxSize = s.size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
buf := make([]byte, maxSize)
|
||||||
|
var pcm []float32
|
||||||
|
for i, s := range samples {
|
||||||
|
if cap(buf) < s.size {
|
||||||
|
buf = make([]byte, s.size)
|
||||||
|
}
|
||||||
|
frame := buf[:s.size]
|
||||||
|
if _, err := r.ReadAt(frame, s.offset); err != nil {
|
||||||
|
return nil, 0, 0, fmt.Errorf("read mp4 aac sample %d: %w", i, err)
|
||||||
|
}
|
||||||
|
out, err := dec.DecodeFrame(frame)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, 0, fmt.Errorf("decode mp4 aac sample %d: %w", i, err)
|
||||||
|
}
|
||||||
|
pcm = append(pcm, out...)
|
||||||
|
}
|
||||||
|
|
||||||
|
skipSamples := leading * ch
|
||||||
|
if skipSamples > len(pcm) {
|
||||||
|
skipSamples = len(pcm)
|
||||||
|
}
|
||||||
|
pcm = pcm[skipSamples:]
|
||||||
|
|
||||||
|
samplesF64 := make([]float64, len(pcm))
|
||||||
|
for i, v := range pcm {
|
||||||
|
samplesF64[i] = float64(v)
|
||||||
|
}
|
||||||
|
if ch > 1 {
|
||||||
|
samplesF64 = float32InterleavedToMono(samplesF64, ch)
|
||||||
|
ch = 1
|
||||||
|
}
|
||||||
|
return samplesF64, sr, ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func float32InterleavedToMono(samples []float64, channels int) []float64 {
|
||||||
|
if channels <= 1 {
|
||||||
|
return samples
|
||||||
|
}
|
||||||
|
nFrames := len(samples) / channels
|
||||||
|
out := make([]float64, nFrames)
|
||||||
|
for i := 0; i < nFrames; i++ {
|
||||||
|
var sum float64
|
||||||
|
for c := 0; c < channels; c++ {
|
||||||
|
sum += samples[i*channels+c]
|
||||||
|
}
|
||||||
|
out[i] = sum / float64(channels)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func demuxMP4AAC(r io.ReaderAt, size int64) (asc []byte, samples []mp4AACSample, leading int, err error) {
|
||||||
|
file, err := mp4.DecodeFile(io.NewSectionReader(r, 0, size), mp4.WithDecodeMode(mp4.DecModeLazyMdat))
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, 0, fmt.Errorf("mp4 decode: %w", err)
|
||||||
|
}
|
||||||
|
if file.IsFragmented() {
|
||||||
|
return nil, nil, 0, fmt.Errorf("fragmented mp4 is not supported")
|
||||||
|
}
|
||||||
|
if file.Moov == nil {
|
||||||
|
return nil, nil, 0, fmt.Errorf("mp4: missing moov")
|
||||||
|
}
|
||||||
|
|
||||||
|
var audioTracks []*mp4.TrakBox
|
||||||
|
for _, trak := range file.Moov.Traks {
|
||||||
|
if trak != nil && trak.Mdia != nil && trak.Mdia.Hdlr != nil && trak.Mdia.Hdlr.HandlerType == "soun" {
|
||||||
|
audioTracks = append(audioTracks, trak)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(audioTracks) != 1 {
|
||||||
|
return nil, nil, 0, fmt.Errorf("mp4: expected one audio track, found %d", len(audioTracks))
|
||||||
|
}
|
||||||
|
trak := audioTracks[0]
|
||||||
|
if trak.Mdia == nil || trak.Mdia.Minf == nil || trak.Mdia.Minf.Stbl == nil || trak.Mdia.Minf.Stbl.Stsd == nil {
|
||||||
|
return nil, nil, 0, fmt.Errorf("mp4: incomplete audio track")
|
||||||
|
}
|
||||||
|
stsd := trak.Mdia.Minf.Stbl.Stsd
|
||||||
|
if len(stsd.Children) != 1 {
|
||||||
|
return nil, nil, 0, fmt.Errorf("mp4: multiple sample descriptions")
|
||||||
|
}
|
||||||
|
if stsd.Enca != nil {
|
||||||
|
return nil, nil, 0, fmt.Errorf("mp4: encrypted audio")
|
||||||
|
}
|
||||||
|
sampleEntry := stsd.Mp4a
|
||||||
|
if sampleEntry == nil {
|
||||||
|
return nil, nil, 0, fmt.Errorf("mp4: unsupported audio sample entry %s", stsd.Children[0].Type())
|
||||||
|
}
|
||||||
|
if sampleEntry.Sinf != nil {
|
||||||
|
return nil, nil, 0, fmt.Errorf("mp4: encrypted audio")
|
||||||
|
}
|
||||||
|
if sampleEntry.Esds == nil ||
|
||||||
|
sampleEntry.Esds.DecConfigDescriptor == nil ||
|
||||||
|
sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo == nil ||
|
||||||
|
len(sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo.DecConfig) == 0 {
|
||||||
|
return nil, nil, 0, fmt.Errorf("mp4: missing AudioSpecificConfig")
|
||||||
|
}
|
||||||
|
asc = append([]byte(nil), sampleEntry.Esds.DecConfigDescriptor.DecSpecificInfo.DecConfig...)
|
||||||
|
|
||||||
|
leading, _ = mp4LeadingTrimRelaxed(trak)
|
||||||
|
|
||||||
|
samples, err = buildMP4AACSamples(trak, size)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, 0, err
|
||||||
|
}
|
||||||
|
if len(samples) == 0 {
|
||||||
|
return nil, nil, 0, fmt.Errorf("mp4: no audio samples")
|
||||||
|
}
|
||||||
|
return asc, samples, leading, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mp4LeadingTrimRelaxed(trak *mp4.TrakBox) (int, error) {
|
||||||
|
if trak.Edts == nil || len(trak.Edts.Elst) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if len(trak.Edts.Elst) != 1 || len(trak.Edts.Elst[0].Entries) != 1 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
entry := trak.Edts.Elst[0].Entries[0]
|
||||||
|
if entry.MediaRateInteger != 1 || entry.MediaRateFraction != 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if entry.MediaTime < 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return int(entry.MediaTime), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMP4AACSamples(trak *mp4.TrakBox, size int64) ([]mp4AACSample, error) {
|
||||||
|
if trak.Mdia == nil || trak.Mdia.Minf == nil || trak.Mdia.Minf.Stbl == nil {
|
||||||
|
return nil, fmt.Errorf("mp4: incomplete sample table")
|
||||||
|
}
|
||||||
|
stbl := trak.Mdia.Minf.Stbl
|
||||||
|
if stbl.Stsc == nil || stbl.Stsz == nil {
|
||||||
|
return nil, fmt.Errorf("mp4: incomplete sample table")
|
||||||
|
}
|
||||||
|
if stbl.Stco == nil && stbl.Co64 == nil {
|
||||||
|
return nil, fmt.Errorf("mp4: missing chunk offsets")
|
||||||
|
}
|
||||||
|
if len(stbl.Stsc.Entries) == 0 {
|
||||||
|
return nil, fmt.Errorf("mp4: empty chunk map")
|
||||||
|
}
|
||||||
|
|
||||||
|
totalSamples := int(trak.GetNrSamples())
|
||||||
|
if totalSamples <= 0 {
|
||||||
|
return nil, fmt.Errorf("mp4: empty sample table")
|
||||||
|
}
|
||||||
|
|
||||||
|
sampleSizes, err := mp4AACSampleSizes(stbl.Stsz, totalSamples)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
chunkOffsets, err := mp4AACChunkOffsets(stbl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]mp4AACSample, 0, totalSamples)
|
||||||
|
sampleIndex := 0
|
||||||
|
entryIndex := 0
|
||||||
|
entry := stbl.Stsc.Entries[entryIndex]
|
||||||
|
|
||||||
|
for chunkIndex := 0; chunkIndex < len(chunkOffsets) && sampleIndex < totalSamples; chunkIndex++ {
|
||||||
|
chunkNr := uint32(chunkIndex + 1)
|
||||||
|
for entryIndex+1 < len(stbl.Stsc.Entries) && chunkNr >= stbl.Stsc.Entries[entryIndex+1].FirstChunk {
|
||||||
|
entryIndex++
|
||||||
|
entry = stbl.Stsc.Entries[entryIndex]
|
||||||
|
}
|
||||||
|
if entry.SamplesPerChunk == 0 {
|
||||||
|
return nil, fmt.Errorf("mp4: zero samples per chunk")
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := chunkOffsets[chunkIndex]
|
||||||
|
samplesPerChunk := int(entry.SamplesPerChunk)
|
||||||
|
for i := 0; i < samplesPerChunk && sampleIndex < totalSamples; i++ {
|
||||||
|
sampleSize := sampleSizes[sampleIndex]
|
||||||
|
end := offset + int64(sampleSize)
|
||||||
|
if offset < 0 || end < offset || end > size {
|
||||||
|
return nil, fmt.Errorf("mp4: invalid sample bounds at sample %d", sampleIndex+1)
|
||||||
|
}
|
||||||
|
out = append(out, mp4AACSample{offset: offset, size: sampleSize})
|
||||||
|
offset = end
|
||||||
|
sampleIndex++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sampleIndex != totalSamples {
|
||||||
|
return nil, fmt.Errorf("mp4: sample table mismatch")
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mp4AACSampleSizes(stsz *mp4.StszBox, totalSamples int) ([]int, error) {
|
||||||
|
if stsz == nil {
|
||||||
|
return nil, fmt.Errorf("mp4: missing sample sizes")
|
||||||
|
}
|
||||||
|
if int(stsz.GetNrSamples()) != totalSamples {
|
||||||
|
return nil, fmt.Errorf("mp4: sample size count mismatch")
|
||||||
|
}
|
||||||
|
sizes := make([]int, totalSamples)
|
||||||
|
if stsz.SampleUniformSize != 0 {
|
||||||
|
sz := int(stsz.SampleUniformSize)
|
||||||
|
for i := range sizes {
|
||||||
|
sizes[i] = sz
|
||||||
|
}
|
||||||
|
return sizes, nil
|
||||||
|
}
|
||||||
|
if len(stsz.SampleSize) != totalSamples {
|
||||||
|
return nil, fmt.Errorf("mp4: sample size table mismatch")
|
||||||
|
}
|
||||||
|
for i, sz := range stsz.SampleSize {
|
||||||
|
sizes[i] = int(sz)
|
||||||
|
}
|
||||||
|
return sizes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mp4AACChunkOffsets(stbl *mp4.StblBox) ([]int64, error) {
|
||||||
|
switch {
|
||||||
|
case stbl == nil:
|
||||||
|
return nil, fmt.Errorf("mp4: incomplete sample table")
|
||||||
|
case stbl.Stco != nil:
|
||||||
|
offsets := make([]int64, len(stbl.Stco.ChunkOffset))
|
||||||
|
for i, off := range stbl.Stco.ChunkOffset {
|
||||||
|
offsets[i] = int64(off)
|
||||||
|
}
|
||||||
|
return offsets, nil
|
||||||
|
case stbl.Co64 != nil:
|
||||||
|
offsets := make([]int64, len(stbl.Co64.ChunkOffset))
|
||||||
|
for i, off := range stbl.Co64.ChunkOffset {
|
||||||
|
if off > uint64(^uint64(0)>>1) {
|
||||||
|
return nil, fmt.Errorf("mp4: invalid chunk offset")
|
||||||
|
}
|
||||||
|
offsets[i] = int64(off)
|
||||||
|
}
|
||||||
|
return offsets, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("mp4: missing chunk offsets")
|
||||||
|
}
|
||||||
|
}
|
||||||
154
transcode/ogg_decode.go
Normal file
154
transcode/ogg_decode.go
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gopxl/beep"
|
||||||
|
"github.com/gopxl/beep/vorbis"
|
||||||
|
"github.com/pion/opus"
|
||||||
|
"github.com/pion/opus/pkg/oggreader"
|
||||||
|
)
|
||||||
|
|
||||||
|
func decodeOggFile(path string) (beep.Streamer, beep.Format, io.Closer, error) {
|
||||||
|
switch sniffOggCodec(path) {
|
||||||
|
case "opus":
|
||||||
|
return decodeOggOpus(path)
|
||||||
|
case "vorbis":
|
||||||
|
return decodeOggVorbis(path)
|
||||||
|
default:
|
||||||
|
streamer, format, closer, err := decodeOggVorbis(path)
|
||||||
|
if err == nil {
|
||||||
|
return streamer, format, closer, nil
|
||||||
|
}
|
||||||
|
if isVorbisInvalidHeader(err) {
|
||||||
|
return decodeOggOpus(path)
|
||||||
|
}
|
||||||
|
return nil, beep.Format{}, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sniffOggCodec(path string) string {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
buf := make([]byte, 8192)
|
||||||
|
n, _ := io.ReadFull(f, buf)
|
||||||
|
buf = buf[:n]
|
||||||
|
if len(buf) < 4 || !bytes.HasPrefix(buf, []byte("OggS")) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if bytes.Contains(buf, []byte("OpusHead")) {
|
||||||
|
return "opus"
|
||||||
|
}
|
||||||
|
// Vorbis ID packet: 0x01 + "vorbis"
|
||||||
|
if bytes.Contains(buf, []byte{0x01, 'v', 'o', 'r', 'b', 'i', 's'}) {
|
||||||
|
return "vorbis"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func isVorbisInvalidHeader(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(msg, "invalid header") || strings.Contains(msg, "vorbis:")
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeOggVorbis(path string) (beep.Streamer, beep.Format, io.Closer, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, beep.Format{}, nil, err
|
||||||
|
}
|
||||||
|
streamer, format, decErr := vorbis.Decode(f)
|
||||||
|
if decErr != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, beep.Format{}, nil, fmt.Errorf("ogg/vorbis: %w", decErr)
|
||||||
|
}
|
||||||
|
return streamer, format, f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeOggOpus(path string) (beep.Streamer, beep.Format, io.Closer, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, beep.Format{}, nil, err
|
||||||
|
}
|
||||||
|
ogg, header, err := oggreader.NewWith(f)
|
||||||
|
if err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sr := int(header.SampleRate)
|
||||||
|
if sr <= 0 {
|
||||||
|
sr = 48000
|
||||||
|
}
|
||||||
|
ch := int(header.Channels)
|
||||||
|
if ch <= 0 {
|
||||||
|
ch = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
dec, err := opus.NewDecoderWithOutput(sr, ch)
|
||||||
|
if err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus decoder: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxFrameSamples = 5760
|
||||||
|
pcmBuf := make([]float32, maxFrameSamples*ch)
|
||||||
|
var samples []float64
|
||||||
|
|
||||||
|
for {
|
||||||
|
pkt, _, err := ogg.ParseNextPacket()
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus read: %w", err)
|
||||||
|
}
|
||||||
|
if len(pkt) == 0 || bytes.HasPrefix(pkt, []byte("OpusHead")) || bytes.HasPrefix(pkt, []byte("OpusTags")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := dec.DecodeToFloat32(pkt, pcmBuf)
|
||||||
|
if err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus decode: %w", err)
|
||||||
|
}
|
||||||
|
if n <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
total := n * ch
|
||||||
|
if total > len(pcmBuf) {
|
||||||
|
total = len(pcmBuf)
|
||||||
|
}
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
samples = append(samples, float64(pcmBuf[i]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(samples) == 0 {
|
||||||
|
f.Close()
|
||||||
|
return nil, beep.Format{}, nil, fmt.Errorf("ogg/opus: no audio samples")
|
||||||
|
}
|
||||||
|
|
||||||
|
outCh := ch
|
||||||
|
if outCh > 1 {
|
||||||
|
samples = interleavedToMono(samples, outCh)
|
||||||
|
outCh = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return newSamplesStreamer(samples, sr), beep.Format{
|
||||||
|
SampleRate: beep.SampleRate(sr),
|
||||||
|
NumChannels: outCh,
|
||||||
|
Precision: 2,
|
||||||
|
}, f, nil
|
||||||
|
}
|
||||||
52
transcode/ogg_decode_test.go
Normal file
52
transcode/ogg_decode_test.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSniffOggCodec_opus(t *testing.T) {
|
||||||
|
path := pionTinyOggPath(t)
|
||||||
|
if got := sniffOggCodec(path); got != "opus" {
|
||||||
|
t.Fatalf("sniffOggCodec() = %q want opus", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeOggOpus_pionTiny(t *testing.T) {
|
||||||
|
path := pionTinyOggPath(t)
|
||||||
|
streamer, format, closer, err := decodeOggOpus(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer closer.Close()
|
||||||
|
if format.SampleRate == 0 {
|
||||||
|
t.Fatal("zero sample rate")
|
||||||
|
}
|
||||||
|
buf := make([][2]float64, 4096)
|
||||||
|
n, ok := streamer.Stream(buf)
|
||||||
|
if !ok || n == 0 {
|
||||||
|
t.Fatal("expected pcm samples")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func pionTinyOggPath(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
_, file, _, ok := runtime.Caller(0)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("runtime.Caller failed")
|
||||||
|
}
|
||||||
|
modRoot := filepath.Clean(filepath.Join(filepath.Dir(file), ".."))
|
||||||
|
cache := os.Getenv("GOMODCACHE")
|
||||||
|
if cache == "" {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
cache = filepath.Join(home, "go", "pkg", "mod")
|
||||||
|
}
|
||||||
|
path := filepath.Join(cache, "github.com/pion/opus@v0.0.0-20260601214817-71d58474cec8/testdata/tiny.ogg")
|
||||||
|
if _, err := os.Stat(path); err != nil {
|
||||||
|
_ = modRoot
|
||||||
|
t.Skipf("pion testdata not in module cache: %v", err)
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
43
transcode/options.go
Normal file
43
transcode/options.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
Format string
|
||||||
|
SampleRate int
|
||||||
|
Channels int
|
||||||
|
Codec string
|
||||||
|
}
|
||||||
|
|
||||||
|
func WhisperOptions() Options {
|
||||||
|
return Options{
|
||||||
|
Format: FormatWAV,
|
||||||
|
SampleRate: 16000,
|
||||||
|
Channels: 1,
|
||||||
|
Codec: "pcm_s16le",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) ApplyDefaults() error {
|
||||||
|
spec, err := ResolveFormat(o.Format)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if o.Codec == "" {
|
||||||
|
o.Codec = spec.Codec
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) Validate() error {
|
||||||
|
if err := o.ApplyDefaults(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if o.SampleRate < 0 || o.Channels < 0 {
|
||||||
|
return fmt.Errorf("sample_rate and channels must be >= 0")
|
||||||
|
}
|
||||||
|
if o.Channels > 8 {
|
||||||
|
return fmt.Errorf("channels must be <= 8")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
20
transcode/options_test.go
Normal file
20
transcode/options_test.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestWhisperOptions(t *testing.T) {
|
||||||
|
o := WhisperOptions()
|
||||||
|
if err := o.Validate(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if o.SampleRate != 16000 || o.Channels != 1 {
|
||||||
|
t.Fatalf("unexpected whisper opts: %+v", o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveFormat_unknown(t *testing.T) {
|
||||||
|
_, err := ResolveFormat("xyz")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
}
|
||||||
38
transcode/samples_stream.go
Normal file
38
transcode/samples_stream.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import "github.com/gopxl/beep"
|
||||||
|
|
||||||
|
type samplesStreamer struct {
|
||||||
|
samples []float64
|
||||||
|
pos int
|
||||||
|
sampleRate beep.SampleRate
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSamplesStreamer(samples []float64, sampleRate int) *samplesStreamer {
|
||||||
|
return &samplesStreamer{
|
||||||
|
samples: samples,
|
||||||
|
sampleRate: beep.SampleRate(sampleRate),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *samplesStreamer) Stream(buf [][2]float64) (int, bool) {
|
||||||
|
if s.pos >= len(s.samples) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
n := 0
|
||||||
|
for i := range buf {
|
||||||
|
if s.pos >= len(s.samples) {
|
||||||
|
return n, n > 0
|
||||||
|
}
|
||||||
|
v := s.samples[s.pos]
|
||||||
|
buf[i][0] = v
|
||||||
|
buf[i][1] = v
|
||||||
|
s.pos++
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
return n, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *samplesStreamer) Err() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
44
transcode/sniff.go
Normal file
44
transcode/sniff.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// sniffFormat detects container/codec from file header when the path has no extension.
|
||||||
|
func sniffFormat(r io.Reader) string {
|
||||||
|
head := make([]byte, 32)
|
||||||
|
n, _ := io.ReadFull(r, head)
|
||||||
|
head = head[:n]
|
||||||
|
if len(head) < 4 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(head, []byte("RIFF")) && len(head) >= 12 && bytes.Equal(head[8:12], []byte("WAVE")) {
|
||||||
|
return ".wav"
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(head, []byte("ID3")) {
|
||||||
|
return ".mp3"
|
||||||
|
}
|
||||||
|
if len(head) >= 2 && head[0] == 0xFF && (head[1]&0xE0) == 0xE0 {
|
||||||
|
return ".mp3"
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(head, []byte("fLaC")) {
|
||||||
|
return ".flac"
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(head, []byte("OggS")) {
|
||||||
|
return ".ogg"
|
||||||
|
}
|
||||||
|
if len(head) >= 8 && bytes.Equal(head[4:8], []byte("ftyp")) {
|
||||||
|
return ".m4a"
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(head, []byte{0xFF, 0xF1}) || bytes.HasPrefix(head, []byte{0xFF, 0xF9}) {
|
||||||
|
return ".aac"
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(head, []byte("FORM")) && len(head) >= 12 && bytes.Equal(head[8:12], []byte("AIFF")) {
|
||||||
|
return ".aiff"
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(head, []byte{0x1A, 0x45, 0xDF, 0xA3}) {
|
||||||
|
return ".webm"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
42
transcode/stream.go
Normal file
42
transcode/stream.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/gopxl/beep"
|
||||||
|
"github.com/gopxl/beep/effects"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildPipeline(streamer beep.Streamer, format beep.Format, opt Options) (beep.Streamer, beep.Format) {
|
||||||
|
out := streamer
|
||||||
|
if opt.SampleRate > 0 && format.SampleRate != beep.SampleRate(opt.SampleRate) {
|
||||||
|
out = beep.Resample(4, format.SampleRate, beep.SampleRate(opt.SampleRate), out)
|
||||||
|
format.SampleRate = beep.SampleRate(opt.SampleRate)
|
||||||
|
}
|
||||||
|
if opt.Channels == 1 {
|
||||||
|
out = effects.Mono(out)
|
||||||
|
format.NumChannels = 1
|
||||||
|
}
|
||||||
|
return out, format
|
||||||
|
}
|
||||||
|
|
||||||
|
func drainSamples(ctx context.Context, s beep.Streamer) ([]float64, error) {
|
||||||
|
buf := make([][2]float64, 4096)
|
||||||
|
var samples []float64
|
||||||
|
for {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
n, ok := s.Stream(buf)
|
||||||
|
if !ok {
|
||||||
|
if err := s.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
samples = append(samples, buf[i][0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return samples, nil
|
||||||
|
}
|
||||||
53
transcode/wav_out.go
Normal file
53
transcode/wav_out.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package transcode
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/go-audio/audio"
|
||||||
|
"github.com/go-audio/wav"
|
||||||
|
)
|
||||||
|
|
||||||
|
func writePCM16WAV(path string, sampleRate int, channels int, samples []float64) error {
|
||||||
|
if channels <= 0 {
|
||||||
|
channels = 1
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(dirOf(path), 0o755); err != nil && dirOf(path) != "." {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
enc := wav.NewEncoder(f, sampleRate, 16, channels, 1)
|
||||||
|
data := make([]int, len(samples))
|
||||||
|
for i, s := range samples {
|
||||||
|
data[i] = floatToInt16(s)
|
||||||
|
}
|
||||||
|
if err := enc.Write(&audio.IntBuffer{
|
||||||
|
Format: &audio.Format{SampleRate: sampleRate, NumChannels: channels},
|
||||||
|
Data: data,
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return enc.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func floatToInt16(f float64) int {
|
||||||
|
if f > 1 {
|
||||||
|
f = 1
|
||||||
|
}
|
||||||
|
if f < -1 {
|
||||||
|
f = -1
|
||||||
|
}
|
||||||
|
return int(f * 32767)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dirOf(path string) string {
|
||||||
|
for i := len(path) - 1; i >= 0; i-- {
|
||||||
|
if path[i] == '/' || path[i] == '\\' {
|
||||||
|
return path[:i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "."
|
||||||
|
}
|
||||||
11
whisper/audio.go
Normal file
11
whisper/audio.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"go-whisper-api/transcode"
|
||||||
|
)
|
||||||
|
|
||||||
|
func AudioToWav(src, dst string) error {
|
||||||
|
return transcode.ToWhisperWAV(context.Background(), src, dst)
|
||||||
|
}
|
||||||
60
whisper/audio_load.go
Normal file
60
whisper/audio_load.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"go-whisper-api/config"
|
||||||
|
|
||||||
|
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
|
"github.com/go-audio/wav"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LoadPCM16Mono reads 16 kHz mono WAV into float32 samples (for diarization).
|
||||||
|
func LoadPCM16Mono(path string) ([]float32, error) {
|
||||||
|
return loadPCM16Mono(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadPCM16Mono(path string) ([]float32, error) {
|
||||||
|
fh, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer fh.Close()
|
||||||
|
dec := wav.NewDecoder(fh)
|
||||||
|
buf, err := dec.FullPCMBuffer()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if dec.SampleRate != whisper.SampleRate {
|
||||||
|
return nil, fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
|
||||||
|
}
|
||||||
|
if dec.NumChans != 1 {
|
||||||
|
return nil, fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
|
||||||
|
}
|
||||||
|
return buf.AsFloat32Buffer().Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareAudioPCM(sourcePath string) (data []float32, cleanup func(), err error) {
|
||||||
|
cleanup = func() {}
|
||||||
|
if data, err = loadPCM16Mono(sourcePath); err == nil {
|
||||||
|
return data, cleanup, nil
|
||||||
|
}
|
||||||
|
dir, err := config.MkdirTemp("go-whisper-api-whisper-*")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
cleanup = func() { os.RemoveAll(dir) }
|
||||||
|
converted := filepath.Join(dir, "converted.wav")
|
||||||
|
if err := AudioToWav(sourcePath, converted); err != nil {
|
||||||
|
cleanup()
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
data, err = loadPCM16Mono(converted)
|
||||||
|
if err != nil {
|
||||||
|
cleanup()
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return data, cleanup, nil
|
||||||
|
}
|
||||||
154
whisper/format.go
Normal file
154
whisper/format.go
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
wpkg "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Turn is a speaker-active time range from diarization (seconds).
|
||||||
|
type Turn struct {
|
||||||
|
Start float32
|
||||||
|
End float32
|
||||||
|
Speaker int
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatOptions controls joining Whisper segments into one string with newlines.
|
||||||
|
type FormatOptions struct {
|
||||||
|
PauseGap time.Duration
|
||||||
|
SpeakerLabel string
|
||||||
|
UseSpeakers bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatSegments joins segment texts with \n on long pauses and optional speaker labels.
|
||||||
|
func FormatSegments(segments []wpkg.Segment, turns []Turn, opts FormatOptions) string {
|
||||||
|
if len(segments) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if opts.PauseGap <= 0 {
|
||||||
|
opts.PauseGap = 1500 * time.Millisecond
|
||||||
|
}
|
||||||
|
label := strings.TrimSpace(opts.SpeakerLabel)
|
||||||
|
if label == "" {
|
||||||
|
label = "Спикер"
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := make([]segmentLine, len(segments))
|
||||||
|
for i, seg := range segments {
|
||||||
|
lines[i] = segmentLine{
|
||||||
|
Text: strings.TrimSpace(seg.Text),
|
||||||
|
Start: seg.Start,
|
||||||
|
End: seg.End,
|
||||||
|
Speaker: -1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if opts.UseSpeakers && len(turns) > 0 {
|
||||||
|
assignSpeakers(lines, turns)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
prevSpeaker := -2
|
||||||
|
prevIdx := -1
|
||||||
|
for i, line := range lines {
|
||||||
|
if line.Text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if b.Len() > 0 && prevIdx >= 0 {
|
||||||
|
speakerBreak := opts.UseSpeakers && line.Speaker >= 0 && line.Speaker != prevSpeaker
|
||||||
|
pauseBreak := line.Start-lines[prevIdx].End >= opts.PauseGap
|
||||||
|
switch {
|
||||||
|
case speakerBreak:
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
fmt.Fprintf(&b, "%s %d: ", label, line.Speaker+1)
|
||||||
|
case pauseBreak:
|
||||||
|
b.WriteString("\n")
|
||||||
|
default:
|
||||||
|
if !strings.HasSuffix(b.String(), " ") && !strings.HasSuffix(b.String(), "\n") {
|
||||||
|
b.WriteByte(' ')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if opts.UseSpeakers && line.Speaker >= 0 {
|
||||||
|
fmt.Fprintf(&b, "%s %d: ", label, line.Speaker+1)
|
||||||
|
}
|
||||||
|
b.WriteString(line.Text)
|
||||||
|
if line.Speaker >= 0 {
|
||||||
|
prevSpeaker = line.Speaker
|
||||||
|
}
|
||||||
|
prevIdx = i
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(b.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
type segmentLine struct {
|
||||||
|
Text string
|
||||||
|
Start time.Duration
|
||||||
|
End time.Duration
|
||||||
|
Speaker int
|
||||||
|
}
|
||||||
|
|
||||||
|
func assignSpeakers(lines []segmentLine, turns []Turn) {
|
||||||
|
for i := range lines {
|
||||||
|
mid := lines[i].Start + (lines[i].End-lines[i].Start)/2
|
||||||
|
lines[i].Speaker = speakerAt(mid, turns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func speakerAt(t time.Duration, turns []Turn) int {
|
||||||
|
sec := float32(t.Seconds())
|
||||||
|
bestSpeaker := -1
|
||||||
|
bestOverlap := float32(0)
|
||||||
|
for _, tr := range turns {
|
||||||
|
if sec >= tr.Start && sec < tr.End {
|
||||||
|
return tr.Speaker
|
||||||
|
}
|
||||||
|
overlap := intervalOverlap(sec, sec, tr.Start, tr.End)
|
||||||
|
if overlap > bestOverlap {
|
||||||
|
bestOverlap = overlap
|
||||||
|
bestSpeaker = tr.Speaker
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bestSpeaker
|
||||||
|
}
|
||||||
|
|
||||||
|
func intervalOverlap(a0, a1, b0, b1 float32) float32 {
|
||||||
|
start := max32(a0, b0)
|
||||||
|
end := min32(a1, b1)
|
||||||
|
if end <= start {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return end - start
|
||||||
|
}
|
||||||
|
|
||||||
|
func max32(a, b float32) float32 {
|
||||||
|
if a > b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func min32(a, b float32) float32 {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// PunctuateSegments runs punctuation on each segment separately (preserves line breaks).
|
||||||
|
func PunctuateSegments(segments []wpkg.Segment, restore func(text string) (string, error)) ([]wpkg.Segment, error) {
|
||||||
|
out := make([]wpkg.Segment, len(segments))
|
||||||
|
copy(out, segments)
|
||||||
|
for i := range out {
|
||||||
|
t := strings.TrimSpace(out[i].Text)
|
||||||
|
if t == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p, err := restore(t)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out[i].Text = " " + strings.TrimSpace(p)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
40
whisper/format_test.go
Normal file
40
whisper/format_test.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
wpkg "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFormatSegments_pauseAndSpeaker(t *testing.T) {
|
||||||
|
segments := []wpkg.Segment{
|
||||||
|
{Text: " привет", Start: 0, End: 2 * time.Second},
|
||||||
|
{Text: " мир", Start: 4 * time.Second, End: 5 * time.Second},
|
||||||
|
{Text: " ответ", Start: 6 * time.Second, End: 8 * time.Second},
|
||||||
|
}
|
||||||
|
turns := []Turn{
|
||||||
|
{Start: 0, End: 5.5, Speaker: 0},
|
||||||
|
{Start: 5.5, End: 10, Speaker: 1},
|
||||||
|
}
|
||||||
|
got := FormatSegments(segments, turns, FormatOptions{
|
||||||
|
PauseGap: 1500 * time.Millisecond,
|
||||||
|
SpeakerLabel: "Спикер",
|
||||||
|
UseSpeakers: true,
|
||||||
|
})
|
||||||
|
want := "Спикер 1: привет\nмир\n\nСпикер 2: ответ"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("got %q\nwant %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSegments_noSpeakers(t *testing.T) {
|
||||||
|
segments := []wpkg.Segment{
|
||||||
|
{Text: "a", Start: 0, End: time.Second},
|
||||||
|
{Text: "b", Start: 3 * time.Second, End: 4 * time.Second},
|
||||||
|
}
|
||||||
|
got := FormatSegments(segments, nil, FormatOptions{PauseGap: time.Second, UseSpeakers: false})
|
||||||
|
if got != "a\nb" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
15
whisper/helper.go
Normal file
15
whisper/helper.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func srtTimestamp(t time.Duration) string {
|
||||||
|
return fmt.Sprintf("%02d:%02d:%02d,%03d",
|
||||||
|
t/time.Hour,
|
||||||
|
(t%time.Hour)/time.Minute,
|
||||||
|
(t%time.Minute)/time.Second,
|
||||||
|
(t%time.Second)/time.Millisecond,
|
||||||
|
)
|
||||||
|
}
|
||||||
39
whisper/helper_test.go
Normal file
39
whisper/helper_test.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
package whisper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSrtTimestamp(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
t time.Duration
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "test 1",
|
||||||
|
args: args{
|
||||||
|
t: time.Duration(1*time.Hour + 2*time.Minute + 3*time.Second + 4*time.Millisecond),
|
||||||
|
},
|
||||||
|
want: "01:02:03,004",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "test 2",
|
||||||
|
args: args{
|
||||||
|
t: time.Duration(10*time.Hour + 20*time.Minute + 30*time.Second + 40*time.Millisecond),
|
||||||
|
},
|
||||||
|
want: "10:20:30,040",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := srtTimestamp(tt.args.t); got != tt.want {
|
||||||
|
t.Errorf("srtTimestamp() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user