Some checks failed
CodeQL / Analyze (go) (push) Successful in 6m28s
Docker Image / build-docker (push) Failing after 13m26s
Lint and Testing / lint (push) Successful in 11m17s
Lint and Testing / test (push) Successful in 11m17s
Lint and Testing / golangci (push) Successful in 2m40s
182 lines
4.4 KiB
Go
182 lines
4.4 KiB
Go
package api
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
var (
	// reservedModelSubdirs names the subdirectories under models_dir that hold
	// non-Whisper assets (VAD, punctuation, etc.). Keys are lower-case because
	// lookups are made against strings.ToLower of the candidate name.
	reservedModelSubdirs = map[string]struct{}{
		"punctuation": {},
		"vad":         {},
	}

	// excludedWhisperModelFiles lists top-level .bin files that are not Whisper
	// STT models. Keys are lower-case for the same reason as above.
	excludedWhisperModelFiles = map[string]struct{}{
		"vad.bin": {},
	}
)
|
|
|
|
// Registry manages the Whisper model files stored directly in one directory.
// It contains a mutex, so a Registry must not be copied after first use.
type Registry struct {
	// mu orders mutations of the models directory (Import, Delete) against
	// lookups (List, Path).
	mu sync.RWMutex
	// dir is the models directory; it may not exist yet.
	dir string
}
|
|
|
|
func NewRegistry(dir string) *Registry {
|
|
return &Registry{dir: dir}
|
|
}
|
|
|
|
func isReservedModelSubdir(name string) bool {
|
|
_, ok := reservedModelSubdirs[strings.ToLower(name)]
|
|
return ok
|
|
}
|
|
|
|
func isWhisperModelFile(name string) bool {
|
|
if !strings.HasSuffix(strings.ToLower(name), ".bin") {
|
|
return false
|
|
}
|
|
_, excluded := excludedWhisperModelFiles[strings.ToLower(name)]
|
|
return !excluded
|
|
}
|
|
|
|
func (r *Registry) List() ([]string, error) {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
entries, err := os.ReadDir(r.dir)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
var models []string
|
|
for _, e := range entries {
|
|
if e.IsDir() {
|
|
continue
|
|
}
|
|
name := e.Name()
|
|
if !isWhisperModelFile(name) {
|
|
continue
|
|
}
|
|
models = append(models, strings.TrimSuffix(name, filepath.Ext(name)))
|
|
}
|
|
return models, nil
|
|
}
|
|
|
|
func (r *Registry) Path(id string) (string, error) {
|
|
if id == "" {
|
|
return "", fmt.Errorf("model id is required")
|
|
}
|
|
if strings.Contains(id, "/") || strings.Contains(id, "..") {
|
|
return "", fmt.Errorf("invalid model id")
|
|
}
|
|
if isReservedModelSubdir(id) {
|
|
return "", fmt.Errorf("model %q not found", id)
|
|
}
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
candidates := []string{
|
|
filepath.Join(r.dir, id+".bin"),
|
|
filepath.Join(r.dir, id),
|
|
filepath.Join(r.dir, "ggml-"+id+".bin"),
|
|
}
|
|
for _, p := range candidates {
|
|
if st, err := os.Stat(p); err == nil && !st.IsDir() && isWhisperModelFile(filepath.Base(p)) {
|
|
return p, nil
|
|
}
|
|
}
|
|
return "", fmt.Errorf("model %q not found", id)
|
|
}
|
|
|
|
func (r *Registry) Delete(id string) error {
|
|
p, err := r.Path(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
return os.Remove(p)
|
|
}
|
|
|
|
func (r *Registry) Import(id string, src io.Reader) error {
|
|
if id == "" {
|
|
return fmt.Errorf("model id is required")
|
|
}
|
|
if strings.Contains(id, "/") || strings.Contains(id, "..") {
|
|
return fmt.Errorf("invalid model id")
|
|
}
|
|
if isReservedModelSubdir(id) {
|
|
return fmt.Errorf("invalid model id")
|
|
}
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if err := os.MkdirAll(r.dir, 0o755); err != nil {
|
|
return err
|
|
}
|
|
dst := filepath.Join(r.dir, id+".bin")
|
|
f, err := os.Create(dst)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
if _, err := io.Copy(f, src); err != nil {
|
|
os.Remove(dst)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *Registry) Open(id string) (*os.File, error) {
|
|
p, err := r.Path(id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return os.Open(p)
|
|
}
|
|
|
|
// Resolve maps an OpenAI-style model name to a local whisper model id.
|
|
func (r *Registry) Resolve(id, defaultModel string) (string, error) {
|
|
id = strings.TrimSpace(id)
|
|
if id == "" {
|
|
id = strings.TrimSpace(defaultModel)
|
|
}
|
|
if id == "" {
|
|
models, err := r.List()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(models) == 0 {
|
|
return "", fmt.Errorf("model is required")
|
|
}
|
|
return models[0], nil
|
|
}
|
|
if id == "whisper-1" {
|
|
if dm := strings.TrimSpace(defaultModel); dm != "" {
|
|
if _, err := r.Path(dm); err == nil {
|
|
return dm, nil
|
|
}
|
|
}
|
|
models, err := r.List()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(models) == 0 {
|
|
return "", fmt.Errorf("no whisper models installed")
|
|
}
|
|
return models[0], nil
|
|
}
|
|
if _, err := r.Path(id); err == nil {
|
|
return id, nil
|
|
}
|
|
if alt := strings.TrimPrefix(id, "whisper-"); alt != id {
|
|
if _, err := r.Path(alt); err == nil {
|
|
return alt, nil
|
|
}
|
|
}
|
|
return "", fmt.Errorf("model %q not found", id)
|
|
}
|