ai-agent/internal/mcp/registry.go
admin 8dc496b626
Some checks failed
CI / test (push) Has been cancelled
Release / release (push) Failing after 4m36s
first commit
2026-03-08 15:40:34 +07:00

242 lines
6.5 KiB
Go

package mcp
import (
"context"
"fmt"
"sync"
"time"
"ai-agent/internal/config"
"ai-agent/internal/llm"
)
// FailedServer records a server that could not be connected, together with
// the reason, so the failure can be surfaced in status/health reports and
// retried later.
type FailedServer struct {
	Name   string // server name from its config
	Reason string // error message from the failed connect/list-tools attempt
}
// ServerStatus is a point-in-time health snapshot of a single MCP server,
// as produced by Registry.HealthCheck.
type ServerStatus struct {
	Name      string    // server name
	Connected bool      // true when the most recent ping succeeded
	LastError string    // most recent ping/connect error; empty when healthy
	LastPing  time.Time // when the last ping was attempted (zero value if never pinged)
}
// Registry tracks connected MCP clients, routes tool calls to the client
// that owns each tool, and remembers failed servers plus their configs so
// they can be reconnected later.
//
// All fields below are guarded by mu.
type Registry struct {
	mu            sync.RWMutex
	clients       []*MCPClient                   // successfully connected clients
	toolMap       map[string]*MCPClient          // tool name -> owning client
	toolDefs      []llm.ToolDef                  // flattened tool definitions exposed to the LLM
	failedServers []FailedServer                 // servers whose connect/list failed
	serverConfigs map[string]config.ServerConfig // configs by server name, kept for reconnects
}
// NewRegistry returns an empty Registry ready to accept server connections.
func NewRegistry() *Registry {
	r := &Registry{}
	r.toolMap = make(map[string]*MCPClient)
	r.serverConfigs = make(map[string]config.ServerConfig)
	return r
}
// connectTimeout bounds how long a single server connect (including its
// initial tool listing) may take.
const connectTimeout = 5 * time.Second

// ConnectServer dials the given MCP server, registers its tools, and
// returns how many tools it exposed. On failure the server is recorded in
// failedServers and a wrapped error is returned.
func (r *Registry) ConnectServer(ctx context.Context, srv config.ServerConfig) (int, error) {
	connCtx, cancel := context.WithTimeout(ctx, connectTimeout)
	defer cancel()

	// recordFailure notes the server as failed, under the write lock.
	recordFailure := func(reason string) {
		r.mu.Lock()
		r.failedServers = append(r.failedServers, FailedServer{Name: srv.Name, Reason: reason})
		r.mu.Unlock()
	}

	client, err := Connect(connCtx, srv.Name, srv.Command, srv.Args, srv.Env, srv.Transport, srv.URL)
	if err != nil {
		recordFailure(err.Error())
		return 0, fmt.Errorf("connect to %s: %w", srv.Name, err)
	}

	tools, err := client.ListTools(connCtx)
	if err != nil {
		client.Close()
		recordFailure(err.Error())
		return 0, fmt.Errorf("%s tools: %w", srv.Name, err)
	}

	// Registration happens in one critical section so readers never observe
	// a client without its tools.
	r.mu.Lock()
	r.clients = append(r.clients, client)
	for _, tool := range tools {
		r.toolMap[tool.Name] = client
		r.toolDefs = append(r.toolDefs, ToLLMToolDef(tool.Name, tool.Description, tool.InputSchema))
	}
	r.serverConfigs[srv.Name] = srv
	r.mu.Unlock()

	return len(tools), nil
}
// ConnectAll connects every server in the list, reporting each result via
// logFn. A failing server is logged and skipped; it never aborts the loop.
func (r *Registry) ConnectAll(ctx context.Context, servers []config.ServerConfig, logFn func(string)) {
	for _, srv := range servers {
		if n, err := r.ConnectServer(ctx, srv); err != nil {
			logFn(fmt.Sprintf("skip %s: %v", srv.Name, err))
		} else {
			logFn(fmt.Sprintf("connected %s (%d tools)", srv.Name, n))
		}
	}
}
// Tools returns a snapshot of all registered tool definitions.
//
// A copy is returned rather than the internal slice: after the read lock
// is released, ConnectServer may append to r.toolDefs, and an append into
// spare capacity of a shared backing array would race with a caller
// iterating the returned slice.
func (r *Registry) Tools() []llm.ToolDef {
	r.mu.RLock()
	defer r.mu.RUnlock()
	defs := make([]llm.ToolDef, len(r.toolDefs))
	copy(defs, r.toolDefs)
	return defs
}
// ToolCount reports how many tools are currently registered.
func (r *Registry) ToolCount() int {
	r.mu.RLock()
	n := len(r.toolDefs)
	r.mu.RUnlock()
	return n
}
// ServerCount reports how many servers are currently connected.
func (r *Registry) ServerCount() int {
	r.mu.RLock()
	n := len(r.clients)
	r.mu.RUnlock()
	return n
}
// ServerNames lists the names of all currently connected servers.
func (r *Registry) ServerNames() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	names := make([]string, 0, len(r.clients))
	for _, c := range r.clients {
		names = append(names, c.Name())
	}
	return names
}
// FailedServers returns a snapshot of servers that failed to connect.
//
// A copy is returned rather than the internal slice: ConnectServer appends
// to r.failedServers concurrently, and sharing the backing array with the
// caller after the read lock is released would be a data race.
func (r *Registry) FailedServers() []FailedServer {
	r.mu.RLock()
	defer r.mu.RUnlock()
	out := make([]FailedServer, len(r.failedServers))
	copy(out, r.failedServers)
	return out
}
// CallTool routes a tool invocation to the client that registered the
// tool. An unknown tool name yields an error-flagged ToolResult (not a Go
// error) so the calling LLM can observe the mistake and recover.
func (r *Registry) CallTool(ctx context.Context, name string, args map[string]any) (*ToolResult, error) {
	r.mu.RLock()
	client, ok := r.toolMap[name]
	r.mu.RUnlock()
	if ok {
		return client.CallTool(ctx, name, args)
	}
	return &ToolResult{Content: fmt.Sprintf("unknown tool: %s", name), IsError: true}, nil
}
// Close shuts down every connected client and resets the registry's
// connection state. Failure history and stored server configs are kept.
func (r *Registry) Close() {
	r.mu.Lock()
	defer r.mu.Unlock()
	for i := range r.clients {
		r.clients[i].Close()
	}
	r.clients = nil
	r.toolDefs = nil
	r.toolMap = make(map[string]*MCPClient)
}
// HealthCheck pings every connected client and reports one ServerStatus
// per known server, healthy or failed.
//
// State is snapshotted under the read lock and the pings happen after it
// is released: each ping can block for up to 5 seconds, and the original
// code held the RLock across all of them, starving writers such as
// ConnectServer and ReconnectServer for the whole round.
func (r *Registry) HealthCheck(ctx context.Context) []ServerStatus {
	r.mu.RLock()
	clients := make([]*MCPClient, len(r.clients))
	copy(clients, r.clients)
	failed := make([]FailedServer, len(r.failedServers))
	copy(failed, r.failedServers)
	r.mu.RUnlock()

	var results []ServerStatus
	for _, client := range clients {
		status := ServerStatus{Name: client.Name()}
		if client.IsConnected() {
			pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
			err := client.Ping(pingCtx)
			cancel()
			status.Connected = err == nil
			if err != nil {
				status.LastError = err.Error()
			}
			status.LastPing = time.Now()
		}
		results = append(results, status)
	}
	for _, f := range failed {
		results = append(results, ServerStatus{
			Name:      f.Name,
			Connected: false,
			LastError: f.Reason,
		})
	}
	return results
}
// ReconnectServer tears down any stale connection for the named server,
// clears its failure record, and dials it again using the stored config.
// It returns the number of tools registered by the new connection.
//
// The original implementation only removed the failedServers entry, so a
// server that was still present in r.clients (e.g. unhealthy but once
// connected) accumulated a duplicate client and duplicate toolDefs on
// every reconnect. The stale client and its tool registrations are now
// removed first.
func (r *Registry) ReconnectServer(ctx context.Context, name string) (int, error) {
	r.mu.RLock()
	srv, ok := r.serverConfigs[name]
	r.mu.RUnlock()
	if !ok {
		return 0, fmt.Errorf("no config found for server: %s", name)
	}

	r.mu.Lock()
	// Drop any stale client for this server and unregister its tools.
	for i, c := range r.clients {
		if c.Name() != name {
			continue
		}
		stale := c
		r.clients = append(r.clients[:i], r.clients[i+1:]...)
		// Collect and remove the tool names owned by the stale client.
		owned := make(map[string]bool)
		for toolName, owner := range r.toolMap {
			if owner == stale {
				owned[toolName] = true
				delete(r.toolMap, toolName)
			}
		}
		var keptDefs []llm.ToolDef
		for _, def := range r.toolDefs {
			if !owned[def.Name] {
				keptDefs = append(keptDefs, def)
			}
		}
		r.toolDefs = keptDefs
		stale.Close()
		break
	}
	// Clear this server's failure record; ConnectServer re-adds it on failure.
	var remainingFailed []FailedServer
	for _, f := range r.failedServers {
		if f.Name != name {
			remainingFailed = append(remainingFailed, f)
		}
	}
	r.failedServers = remainingFailed
	r.mu.Unlock()

	return r.ConnectServer(ctx, srv)
}
// MonitorConfig controls the background health monitor: how often to run a
// round, how many reconnect attempts per unhealthy server, and the base
// duration for the linear backoff between attempts.
type MonitorConfig struct {
	Interval    time.Duration // time between health-check rounds
	MaxRetries  int           // reconnect attempts per unhealthy server per round
	BackoffBase time.Duration // wait before attempt n is BackoffBase * n
}

// defaultMonitorConfig supplies defaults when StartHealthMonitor receives
// a config with a zero Interval.
var defaultMonitorConfig = MonitorConfig{
	Interval:    30 * time.Second,
	MaxRetries:  3,
	BackoffBase: 5 * time.Second,
}
// StartHealthMonitor launches a goroutine that periodically health-checks
// all servers and attempts to reconnect unhealthy ones. The returned
// CancelFunc stops the monitor; it also stops when ctx is done.
//
// Each zero field of cfg is filled from defaultMonitorConfig individually.
// The original code replaced the entire cfg when Interval was zero
// (discarding caller-set MaxRetries/BackoffBase), and a non-zero Interval
// with a zero MaxRetries silently disabled reconnection because the retry
// loop in healthCheckRound never executed.
func (r *Registry) StartHealthMonitor(ctx context.Context, cfg MonitorConfig, logFn func(string)) context.CancelFunc {
	if cfg.Interval == 0 {
		cfg.Interval = defaultMonitorConfig.Interval
	}
	if cfg.MaxRetries == 0 {
		cfg.MaxRetries = defaultMonitorConfig.MaxRetries
	}
	if cfg.BackoffBase == 0 {
		cfg.BackoffBase = defaultMonitorConfig.BackoffBase
	}
	monitorCtx, cancel := context.WithCancel(ctx)
	go func() {
		ticker := time.NewTicker(cfg.Interval)
		defer ticker.Stop()
		for {
			select {
			case <-monitorCtx.Done():
				return
			case <-ticker.C:
				r.healthCheckRound(monitorCtx, cfg, logFn)
			}
		}
	}()
	return cancel
}
// healthCheckRound performs one monitoring pass: it health-checks every
// server and, for each unhealthy one, retries the connection up to
// cfg.MaxRetries times with linearly increasing backoff.
func (r *Registry) healthCheckRound(ctx context.Context, cfg MonitorConfig, logFn func(string)) {
	for _, status := range r.HealthCheck(ctx) {
		if status.Connected {
			continue
		}
		logFn(fmt.Sprintf("server %s unhealthy, attempting reconnect...", status.Name))
		for attempt := 1; attempt <= cfg.MaxRetries; attempt++ {
			// Linear backoff: wait attempt * BackoffBase before dialing,
			// bailing out immediately if the monitor is cancelled.
			select {
			case <-ctx.Done():
				return
			case <-time.After(cfg.BackoffBase * time.Duration(attempt)):
			}
			_, err := r.ReconnectServer(ctx, status.Name)
			if err == nil {
				logFn(fmt.Sprintf("server %s reconnected", status.Name))
				break
			}
			if attempt == cfg.MaxRetries {
				logFn(fmt.Sprintf("server %s reconnection failed after %d attempts: %v", status.Name, cfg.MaxRetries, err))
			}
		}
	}
}