ai-agent/internal/llm/manager.go
admin 8dc496b626
Some checks failed
CI / test (push) Has been cancelled
Release / release (push) Failing after 4m36s
first commit
2026-03-08 15:40:34 +07:00

164 lines
3.6 KiB
Go

package llm
import (
"context"
"fmt"
"sync"
)
// ModelManager keeps one OllamaClient per model name, all built against the
// same base URL and context size, and tracks which model is currently
// selected.
type ModelManager struct {
	// Settings passed to NewOllamaClient whenever a client is built.
	// Within this file they are assigned only by NewModelManager.
	baseURL string
	numCtx  int

	// mu guards clients and currentModel.
	mu sync.RWMutex
	// clients caches the client built for each model name.
	clients map[string]*OllamaClient
	// currentModel is the model used by the methods that take no explicit
	// model argument; the empty string means none has been selected.
	currentModel string
}
// Compile-time assertion that *ModelManager implements the Client interface.
var _ Client = (*ModelManager)(nil)
// NewModelManager returns a ModelManager that builds its clients against
// baseURL with the given numCtx. The client cache starts empty and no model
// is selected.
func NewModelManager(baseURL string, numCtx int) *ModelManager {
	mgr := new(ModelManager)
	mgr.baseURL = baseURL
	mgr.numCtx = numCtx
	mgr.clients = map[string]*OllamaClient{}
	return mgr
}
// GetClient returns the cached client for modelName, building and caching one
// with NewOllamaClient on first use.
//
// The cache is first consulted under the read lock. On a miss the write lock
// is taken and the cache is checked again before building, so a client that
// another goroutine inserted in between is returned instead of being
// replaced. A construction failure is returned wrapped and nothing is cached.
func (m *ModelManager) GetClient(modelName string) (*OllamaClient, error) {
	// Fast path: read lock only.
	m.mu.RLock()
	cached, ok := m.clients[modelName]
	m.mu.RUnlock()
	if ok {
		return cached, nil
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	// Re-check now that we hold the write lock.
	cached, ok = m.clients[modelName]
	if !ok {
		created, err := NewOllamaClient(m.baseURL, modelName, m.numCtx)
		if err != nil {
			return nil, fmt.Errorf("create client for %s: %w", modelName, err)
		}
		m.clients[modelName] = created
		cached = created
	}
	return cached, nil
}
// SetCurrentModel selects model as the one used by the methods that take no
// explicit model argument (ChatStream, Ping, EmbedWithCurrentModel).
//
// If a client for model is already cached it is kept as is, so instances
// previously handed out by GetClient stay the ones in the cache. Otherwise a
// client is built with NewOllamaClient and cached. When construction fails
// the wrapped error is returned and neither the cache nor the current
// selection is modified.
func (m *ModelManager) SetCurrentModel(model string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Only build a client on a cache miss; rebuilding unconditionally would
	// clobber an entry GetClient may already have returned to callers.
	if _, cached := m.clients[model]; !cached {
		client, err := NewOllamaClient(m.baseURL, model, m.numCtx)
		if err != nil {
			return fmt.Errorf("create client for %s: %w", model, err)
		}
		m.clients[model] = client
	}

	m.currentModel = model
	return nil
}
// CurrentModel returns the name of the currently selected model, read under
// the manager's read lock. It is the empty string when no model has been
// selected.
func (m *ModelManager) CurrentModel() string {
	m.mu.RLock()
	selected := m.currentModel
	m.mu.RUnlock()
	return selected
}
// ChatStream forwards a streaming chat request to the client of the currently
// selected model, passing ctx, opts and fn through unchanged.
//
// It returns a "no model selected" error when no model is selected, the
// error from GetClient when the client cannot be obtained, and otherwise
// whatever the client's ChatStream returns.
func (m *ModelManager) ChatStream(ctx context.Context, opts ChatOptions, fn func(StreamChunk) error) error {
	active := m.CurrentModel()
	if active == "" {
		return fmt.Errorf("no model selected")
	}

	c, err := m.GetClient(active)
	if err != nil {
		return err
	}
	return c.ChatStream(ctx, opts, fn)
}
// ChatStreamForModel forwards a streaming chat request to the client for the
// named model, regardless of which model is currently selected. ctx, opts and
// fn are passed through unchanged. An error from GetClient is returned as is;
// otherwise the result is whatever the client's ChatStream returns.
func (m *ModelManager) ChatStreamForModel(ctx context.Context, model string, opts ChatOptions, fn func(StreamChunk) error) error {
	c, err := m.GetClient(model)
	if err != nil {
		return err
	}

	return c.ChatStream(ctx, opts, fn)
}
// Ping calls Ping on the client of the currently selected model.
//
// It returns a "no model selected" error when no model is selected, the
// error from GetClient when the client cannot be obtained, and otherwise
// whatever the client's Ping returns.
func (m *ModelManager) Ping() error {
	active := m.CurrentModel()
	if active == "" {
		return fmt.Errorf("no model selected")
	}

	c, err := m.GetClient(active)
	if err != nil {
		return err
	}
	return c.Ping()
}
// PingModel calls Ping on the client for the named model, regardless of which
// model is currently selected. An error from GetClient is returned as is;
// otherwise the result is whatever the client's Ping returns.
func (m *ModelManager) PingModel(model string) error {
	c, err := m.GetClient(model)
	if err != nil {
		return err
	}

	return c.Ping()
}
// Embed obtains the client for model and calls its Embed with the same ctx,
// model and texts, returning the client's result unchanged. If the client
// cannot be obtained, the error from GetClient is returned with a nil result.
func (m *ModelManager) Embed(ctx context.Context, model string, texts []string) ([][]float32, error) {
	c, err := m.GetClient(model)
	if err != nil {
		return nil, err
	}

	return c.Embed(ctx, model, texts)
}
// EmbedWithCurrentModel calls Embed for texts using the currently selected
// model. It returns a nil result and a "no model selected" error when no
// model is selected.
func (m *ModelManager) EmbedWithCurrentModel(ctx context.Context, texts []string) ([][]float32, error) {
	active := m.CurrentModel()
	if active == "" {
		return nil, fmt.Errorf("no model selected")
	}

	return m.Embed(ctx, active, texts)
}
// Close discards every cached client by replacing the cache with a new empty
// map, under the write lock.
//
// No method is called on the discarded clients, and the current model
// selection is left unchanged, so a later GetClient for the same model builds
// a fresh client.
func (m *ModelManager) Close() {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.clients = make(map[string]*OllamaClient)
}
// BaseURL returns the base URL the manager builds its clients against.
// The field is read without locking; within this file it is assigned only in
// NewModelManager.
func (m *ModelManager) BaseURL() string {
return m.baseURL
}
// NumCtx returns the numCtx value the manager passes to NewOllamaClient.
// The field is read without locking; within this file it is assigned only in
// NewModelManager.
func (m *ModelManager) NumCtx() int {
return m.numCtx
}
// Model returns the currently selected model name. It is an alias for
// CurrentModel.
// NOTE(review): presumably kept to satisfy the Client interface — confirm.
func (m *ModelManager) Model() string {
return m.CurrentModel()
}
// ListModels returns model names available in Ollama at the manager's base URL.
// It delegates to the package-level ListModels function, passing ctx and the
// manager's base URL through and returning that function's result unchanged.
func (m *ModelManager) ListModels(ctx context.Context) ([]string, error) {
return ListModels(ctx, m.baseURL)
}