242 lines
6.5 KiB
Go
242 lines
6.5 KiB
Go
package mcp
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"ai-agent/internal/config"
|
|
"ai-agent/internal/llm"
|
|
)
|
|
|
|
// FailedServer records a server that could not be brought up, together with
// the error text explaining why.
type FailedServer struct {
	// Name is the server's configured name; Reason is the error message from
	// the failed connection attempt.
	Name, Reason string
}
|
|
|
|
// ServerStatus is one entry of a HealthCheck report.
type ServerStatus struct {
	Name      string    // server name
	Connected bool      // true only when the health probe succeeded
	LastError string    // error text from the failed probe or connection; empty otherwise
	LastPing  time.Time // when the server was last pinged; zero if it was not pinged
}
|
|
|
|
type Registry struct {
|
|
mu sync.RWMutex
|
|
clients []*MCPClient
|
|
toolMap map[string]*MCPClient
|
|
toolDefs []llm.ToolDef
|
|
failedServers []FailedServer
|
|
serverConfigs map[string]config.ServerConfig
|
|
}
|
|
|
|
func NewRegistry() *Registry {
|
|
return &Registry{toolMap: make(map[string]*MCPClient), serverConfigs: make(map[string]config.ServerConfig)}
|
|
}
|
|
|
|
// connectTimeout bounds the time ConnectServer may spend connecting to a
// server and listing its tools (both steps share this one budget).
const connectTimeout time.Duration = 5000 * time.Millisecond
|
|
|
|
func (r *Registry) ConnectServer(ctx context.Context, srv config.ServerConfig) (int, error) {
|
|
connCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
|
defer cancel()
|
|
client, err := Connect(connCtx, srv.Name, srv.Command, srv.Args, srv.Env, srv.Transport, srv.URL)
|
|
if err != nil {
|
|
r.mu.Lock()
|
|
r.failedServers = append(r.failedServers, FailedServer{Name: srv.Name, Reason: err.Error()})
|
|
r.mu.Unlock()
|
|
return 0, fmt.Errorf("connect to %s: %w", srv.Name, err)
|
|
}
|
|
tools, err := client.ListTools(connCtx)
|
|
if err != nil {
|
|
client.Close()
|
|
r.mu.Lock()
|
|
r.failedServers = append(r.failedServers, FailedServer{Name: srv.Name, Reason: err.Error()})
|
|
r.mu.Unlock()
|
|
return 0, fmt.Errorf("%s tools: %w", srv.Name, err)
|
|
}
|
|
r.mu.Lock()
|
|
r.clients = append(r.clients, client)
|
|
for _, tool := range tools {
|
|
r.toolMap[tool.Name] = client
|
|
r.toolDefs = append(r.toolDefs, ToLLMToolDef(tool.Name, tool.Description, tool.InputSchema))
|
|
}
|
|
r.serverConfigs[srv.Name] = srv
|
|
r.mu.Unlock()
|
|
return len(tools), nil
|
|
}
|
|
|
|
func (r *Registry) ConnectAll(ctx context.Context, servers []config.ServerConfig, logFn func(string)) {
|
|
for _, srv := range servers {
|
|
toolCount, err := r.ConnectServer(ctx, srv)
|
|
if err != nil {
|
|
logFn(fmt.Sprintf("skip %s: %v", srv.Name, err))
|
|
continue
|
|
}
|
|
logFn(fmt.Sprintf("connected %s (%d tools)", srv.Name, toolCount))
|
|
}
|
|
}
|
|
|
|
func (r *Registry) Tools() []llm.ToolDef {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
return r.toolDefs
|
|
}
|
|
|
|
func (r *Registry) ToolCount() int {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
return len(r.toolDefs)
|
|
}
|
|
|
|
func (r *Registry) ServerCount() int {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
return len(r.clients)
|
|
}
|
|
|
|
func (r *Registry) ServerNames() []string {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
names := make([]string, len(r.clients))
|
|
for i, c := range r.clients {
|
|
names[i] = c.Name()
|
|
}
|
|
return names
|
|
}
|
|
|
|
func (r *Registry) FailedServers() []FailedServer {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
return r.failedServers
|
|
}
|
|
|
|
func (r *Registry) CallTool(ctx context.Context, name string, args map[string]any) (*ToolResult, error) {
|
|
r.mu.RLock()
|
|
client, ok := r.toolMap[name]
|
|
r.mu.RUnlock()
|
|
if !ok {
|
|
return &ToolResult{
|
|
Content: fmt.Sprintf("unknown tool: %s", name),
|
|
IsError: true,
|
|
}, nil
|
|
}
|
|
return client.CallTool(ctx, name, args)
|
|
}
|
|
|
|
func (r *Registry) Close() {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
for _, c := range r.clients {
|
|
c.Close()
|
|
}
|
|
r.clients = nil
|
|
r.toolMap = make(map[string]*MCPClient)
|
|
r.toolDefs = nil
|
|
}
|
|
|
|
func (r *Registry) HealthCheck(ctx context.Context) []ServerStatus {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
var results []ServerStatus
|
|
for _, client := range r.clients {
|
|
status := ServerStatus{Name: client.Name()}
|
|
if client.IsConnected() {
|
|
pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
|
err := client.Ping(pingCtx)
|
|
cancel()
|
|
status.Connected = err == nil
|
|
if err != nil {
|
|
status.LastError = err.Error()
|
|
}
|
|
status.LastPing = time.Now()
|
|
}
|
|
results = append(results, status)
|
|
}
|
|
for _, failed := range r.failedServers {
|
|
results = append(results, ServerStatus{
|
|
Name: failed.Name,
|
|
Connected: false,
|
|
LastError: failed.Reason,
|
|
})
|
|
}
|
|
return results
|
|
}
|
|
|
|
func (r *Registry) ReconnectServer(ctx context.Context, name string) (int, error) {
|
|
r.mu.RLock()
|
|
srv, ok := r.serverConfigs[name]
|
|
r.mu.RUnlock()
|
|
if !ok {
|
|
return 0, fmt.Errorf("no config found for server: %s", name)
|
|
}
|
|
r.mu.Lock()
|
|
var remainingFailed []FailedServer
|
|
for _, f := range r.failedServers {
|
|
if f.Name != name {
|
|
remainingFailed = append(remainingFailed, f)
|
|
}
|
|
}
|
|
r.failedServers = remainingFailed
|
|
r.mu.Unlock()
|
|
return r.ConnectServer(ctx, srv)
|
|
}
|
|
|
|
// MonitorConfig tunes the background health monitor started by
// StartHealthMonitor.
type MonitorConfig struct {
	// Interval is the time between health-check rounds.
	Interval time.Duration
	// MaxRetries is the number of reconnect attempts an unhealthy server gets
	// per round.
	MaxRetries int
	// BackoffBase is the backoff unit: attempt k waits k*BackoffBase before
	// trying to reconnect.
	BackoffBase time.Duration
}
|
|
|
|
var defaultMonitorConfig = MonitorConfig{
|
|
Interval: 30 * time.Second,
|
|
MaxRetries: 3,
|
|
BackoffBase: 5 * time.Second,
|
|
}
|
|
|
|
func (r *Registry) StartHealthMonitor(ctx context.Context, cfg MonitorConfig, logFn func(string)) context.CancelFunc {
|
|
if cfg.Interval == 0 {
|
|
cfg = defaultMonitorConfig
|
|
}
|
|
monitorCtx, cancel := context.WithCancel(ctx)
|
|
go func() {
|
|
ticker := time.NewTicker(cfg.Interval)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-monitorCtx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
r.healthCheckRound(monitorCtx, cfg, logFn)
|
|
}
|
|
}
|
|
}()
|
|
return cancel
|
|
}
|
|
|
|
func (r *Registry) healthCheckRound(ctx context.Context, cfg MonitorConfig, logFn func(string)) {
|
|
statuses := r.HealthCheck(ctx)
|
|
for _, status := range statuses {
|
|
if status.Connected {
|
|
continue
|
|
}
|
|
logFn(fmt.Sprintf("server %s unhealthy, attempting reconnect...", status.Name))
|
|
for attempt := 1; attempt <= cfg.MaxRetries; attempt++ {
|
|
backoff := cfg.BackoffBase * time.Duration(attempt)
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(backoff):
|
|
}
|
|
_, err := r.ReconnectServer(ctx, status.Name)
|
|
if err == nil {
|
|
logFn(fmt.Sprintf("server %s reconnected", status.Name))
|
|
break
|
|
}
|
|
if attempt == cfg.MaxRetries {
|
|
logFn(fmt.Sprintf("server %s reconnection failed after %d attempts: %v", status.Name, cfg.MaxRetries, err))
|
|
}
|
|
}
|
|
}
|
|
}
|