go-whisper-api/api/models.go
admin b5c083e06f
Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
first commit
2026-06-04 18:10:52 +07:00

182 lines
4.4 KiB
Go

package api
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
)
var (
	// reservedModelSubdirs names subdirectories of models_dir that hold
	// non-Whisper assets (VAD, punctuation, etc.). Keys are lowercase because
	// lookups lowercase the name first.
	reservedModelSubdirs = map[string]struct{}{
		"punctuation": {},
		"vad":         {},
	}

	// excludedWhisperModelFiles lists top-level .bin files that are not
	// Whisper STT models. Keys are lowercase for the same reason.
	excludedWhisperModelFiles = map[string]struct{}{
		"vad.bin": {},
	}
)
// Registry manages the Whisper model files kept directly inside one
// directory. mu guards access to that directory: List and Path take it for
// reading, Import and Delete take it for writing.
type Registry struct {
	dir string
	mu  sync.RWMutex
}

// NewRegistry returns a Registry rooted at dir. The directory is not accessed
// here, so it does not need to exist yet.
func NewRegistry(dir string) *Registry {
	reg := new(Registry)
	reg.dir = dir
	return reg
}
// isReservedModelSubdir reports whether name, compared case-insensitively,
// is one of the entries in reservedModelSubdirs.
func isReservedModelSubdir(name string) bool {
	key := strings.ToLower(name)
	_, reserved := reservedModelSubdirs[key]
	return reserved
}
// isWhisperModelFile reports whether a file name looks like a Whisper STT
// model. The name must end in ".bin" (in any letter case) and must not appear,
// case-insensitively, in excludedWhisperModelFiles.
func isWhisperModelFile(name string) bool {
	lower := strings.ToLower(name)
	if _, excluded := excludedWhisperModelFiles[lower]; excluded {
		return false
	}
	return strings.HasSuffix(lower, ".bin")
}
// List returns the ids of the Whisper models stored directly in the registry
// directory. They come in the order os.ReadDir reports them, which is sorted
// by file name. Subdirectories and files rejected by isWhisperModelFile are
// skipped.
//
// Every returned id resolves through Path. An id is the file name with a
// trailing lowercase ".bin" removed, because Path re-appends exactly ".bin".
// When the suffix has another letter case (e.g. "model.BIN"), the full file
// name is returned instead. Path also tries the bare id, so it still finds
// the file; a trimmed "model" would miss it on case-sensitive filesystems.
// A file named exactly ".bin" would produce an empty id and is skipped.
//
// A missing directory is not an error: List returns nil, nil.
func (r *Registry) List() ([]string, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	entries, err := os.ReadDir(r.dir)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, nil
		}
		return nil, err
	}

	var models []string
	for _, e := range entries {
		if e.IsDir() {
			continue
		}
		name := e.Name()
		if !isWhisperModelFile(name) {
			continue
		}
		// Strip only the exact suffix Path re-appends; otherwise keep the full
		// name so Path's bare-id candidate resolves it.
		id := strings.TrimSuffix(name, ".bin")
		if id == "" {
			continue
		}
		models = append(models, id)
	}
	return models, nil
}
// Path returns the on-disk path of the Whisper model identified by id.
//
// id must be non-empty and must not contain "/", the OS path separator or
// "..". It also must not name a reserved subdirectory (see
// isReservedModelSubdir). The candidates below are tried in order. The first
// one that exists, is not a directory and passes isWhisperModelFile wins:
//
//	<dir>/<id>.bin
//	<dir>/<id>
//	<dir>/ggml-<id>.bin
func (r *Registry) Path(id string) (string, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.lookupLocked(id)
}

// Delete removes the model file that Path(id) resolves to. Lookup and removal
// happen under the same write lock. A concurrent Import or Delete therefore
// cannot replace or remove the file between resolving it and os.Remove.
// Lookup errors are returned as from Path; removal errors come straight from
// os.Remove.
func (r *Registry) Delete(id string) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	p, err := r.lookupLocked(id)
	if err != nil {
		return err
	}
	return os.Remove(p)
}

// lookupLocked validates id and finds its model file, as documented on Path.
// The caller must hold r.mu, for reading or writing. It exists because
// sync.RWMutex is not reentrant: Delete cannot call Path while holding the
// write lock.
func (r *Registry) lookupLocked(id string) (string, error) {
	if id == "" {
		return "", fmt.Errorf("model id is required")
	}
	// Reject anything that could address a path outside r.dir. On Windows,
	// filepath.Separator is '\\', which is a separator in addition to '/'.
	if strings.Contains(id, "/") || strings.ContainsRune(id, filepath.Separator) || strings.Contains(id, "..") {
		return "", fmt.Errorf("invalid model id")
	}
	if isReservedModelSubdir(id) {
		return "", fmt.Errorf("model %q not found", id)
	}
	candidates := [...]string{
		filepath.Join(r.dir, id+".bin"),
		filepath.Join(r.dir, id),
		filepath.Join(r.dir, "ggml-"+id+".bin"),
	}
	for _, p := range candidates {
		if st, err := os.Stat(p); err == nil && !st.IsDir() && isWhisperModelFile(filepath.Base(p)) {
			return p, nil
		}
	}
	return "", fmt.Errorf("model %q not found", id)
}
// Import streams src into the registry as <dir>/<id>.bin. It creates the
// directory (mode 0755) if needed and replaces any existing file of that name.
//
// The data is first written to <dir>/<id>.bin.part and renamed over the
// destination only after the copy and Close succeed. A failed or truncated
// upload therefore leaves a previously installed model untouched, and the
// temporary file is removed. The ".part" suffix keeps the in-progress file
// invisible to List and Path, which only accept names ending in ".bin".
//
// id is validated like Path's, except that reserved subdirectory names are
// reported as "invalid model id".
func (r *Registry) Import(id string, src io.Reader) error {
	if id == "" {
		return fmt.Errorf("model id is required")
	}
	// Reject anything that could address a path outside r.dir; on Windows
	// filepath.Separator ('\\') is a separator in addition to '/'.
	if strings.Contains(id, "/") || strings.ContainsRune(id, filepath.Separator) || strings.Contains(id, "..") {
		return fmt.Errorf("invalid model id")
	}
	if isReservedModelSubdir(id) {
		return fmt.Errorf("invalid model id")
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	if err := os.MkdirAll(r.dir, 0o755); err != nil {
		return err
	}
	dst := filepath.Join(r.dir, id+".bin")
	tmp := dst + ".part"

	f, err := os.Create(tmp)
	if err != nil {
		return err
	}
	if _, err := io.Copy(f, src); err != nil {
		// Best-effort cleanup; the copy error is the one worth reporting.
		_ = f.Close()
		_ = os.Remove(tmp)
		return err
	}
	// Close can surface write errors that Write did not report, so check it
	// before publishing the file.
	if err := f.Close(); err != nil {
		_ = os.Remove(tmp)
		return err
	}
	if err := os.Rename(tmp, dst); err != nil {
		_ = os.Remove(tmp)
		return err
	}
	return nil
}
// Open resolves id with Path and opens the resulting model file with os.Open,
// which opens it read-only. The caller owns the returned file and must close
// it.
func (r *Registry) Open(id string) (*os.File, error) {
	modelPath, err := r.Path(id)
	if err == nil {
		return os.Open(modelPath)
	}
	return nil, err
}
// Resolve maps an OpenAI-style model name to a local whisper model id.
// Surrounding whitespace is ignored in both id and defaultModel. The rules
// are:
//
//   - A blank id falls back to defaultModel. If that is blank too, the first
//     model reported by List is used ("model is required" when none exist).
//   - "whisper-1" maps to defaultModel when Path can resolve it. Otherwise it
//     maps to the first listed model ("no whisper models installed" when none
//     exist).
//   - Any other id is returned as-is when Path resolves it. Otherwise the id
//     with a leading "whisper-" removed is tried before giving up.
func (r *Registry) Resolve(id, defaultModel string) (string, error) {
	fallback := strings.TrimSpace(defaultModel)
	want := strings.TrimSpace(id)
	if want == "" {
		want = fallback
	}

	// firstInstalled returns the first id from List, or an error carrying
	// emptyMsg when no model is installed.
	firstInstalled := func(emptyMsg string) (string, error) {
		installed, err := r.List()
		switch {
		case err != nil:
			return "", err
		case len(installed) == 0:
			return "", fmt.Errorf("%s", emptyMsg)
		default:
			return installed[0], nil
		}
	}

	switch want {
	case "":
		return firstInstalled("model is required")
	case "whisper-1":
		if fallback != "" {
			if _, err := r.Path(fallback); err == nil {
				return fallback, nil
			}
		}
		return firstInstalled("no whisper models installed")
	}

	if _, err := r.Path(want); err == nil {
		return want, nil
	}
	const openAIPrefix = "whisper-"
	if strings.HasPrefix(want, openAIPrefix) {
		stripped := want[len(openAIPrefix):]
		if _, err := r.Path(stripped); err == nil {
			return stripped, nil
		}
	}
	return "", fmt.Errorf("model %q not found", want)
}