package mcp

import (
	"context"
	"fmt"
	"sync"
	"time"

	"ai-agent/internal/config"
	"ai-agent/internal/llm"
)

// FailedServer records a server that could not be connected and why.
type FailedServer struct {
	Name   string
	Reason string
}

// ServerStatus is a point-in-time health snapshot for one server.
type ServerStatus struct {
	Name      string
	Connected bool
	LastError string
	LastPing  time.Time
}

// Registry tracks connected MCP servers, routes tool calls to the owning
// client, and retains enough configuration to reconnect failed servers.
// All methods are safe for concurrent use.
type Registry struct {
	mu            sync.RWMutex
	clients       []*MCPClient
	toolMap       map[string]*MCPClient
	toolDefs      []llm.ToolDef
	failedServers []FailedServer
	serverConfigs map[string]config.ServerConfig
	// serverToolDefs keeps each server's tool definitions separately so a
	// reconnect can cleanly drop the stale server's defs from toolDefs.
	serverToolDefs map[string][]llm.ToolDef
}

// NewRegistry returns an empty, ready-to-use Registry.
func NewRegistry() *Registry {
	return &Registry{
		toolMap:        make(map[string]*MCPClient),
		serverConfigs:  make(map[string]config.ServerConfig),
		serverToolDefs: make(map[string][]llm.ToolDef),
	}
}

// connectTimeout bounds the transport handshake, the initial tool listing,
// and each health-check ping for a single server.
const connectTimeout = 5 * time.Second

// ConnectServer connects to one server, registers its tools, and returns the
// number of tools registered. On failure the server is recorded in
// FailedServers (deduplicated). If the server was already registered, the
// stale client is closed and replaced so reconnects never leak the old
// client or duplicate registry entries.
func (r *Registry) ConnectServer(ctx context.Context, srv config.ServerConfig) (int, error) {
	connCtx, cancel := context.WithTimeout(ctx, connectTimeout)
	defer cancel()

	client, err := Connect(connCtx, srv.Name, srv.Command, srv.Args, srv.Env, srv.Transport, srv.URL)
	if err != nil {
		r.recordFailure(srv.Name, err)
		return 0, fmt.Errorf("connect to %s: %w", srv.Name, err)
	}

	tools, err := client.ListTools(connCtx)
	if err != nil {
		client.Close()
		r.recordFailure(srv.Name, err)
		return 0, fmt.Errorf("%s tools: %w", srv.Name, err)
	}

	defs := make([]llm.ToolDef, 0, len(tools))

	r.mu.Lock()
	// Drop any previous client for this server before appending; otherwise a
	// reconnect leaks the old client and duplicates clients/toolDefs entries.
	r.removeServerLocked(srv.Name)
	r.clients = append(r.clients, client)
	for _, tool := range tools {
		r.toolMap[tool.Name] = client
		defs = append(defs, ToLLMToolDef(tool.Name, tool.Description, tool.InputSchema))
	}
	r.serverToolDefs[srv.Name] = defs
	r.rebuildToolDefsLocked()
	r.serverConfigs[srv.Name] = srv
	// A successful connection supersedes any recorded failure.
	r.clearFailureLocked(srv.Name)
	r.mu.Unlock()

	return len(tools), nil
}

// recordFailure stores (or updates in place) the failure reason for name so
// repeated failures of the same server do not grow failedServers unboundedly.
func (r *Registry) recordFailure(name string, err error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	for i := range r.failedServers {
		if r.failedServers[i].Name == name {
			r.failedServers[i].Reason = err.Error()
			return
		}
	}
	r.failedServers = append(r.failedServers, FailedServer{Name: name, Reason: err.Error()})
}

// clearFailureLocked removes any failedServers entry for name.
// Caller must hold r.mu for writing.
func (r *Registry) clearFailureLocked(name string) {
	kept := r.failedServers[:0]
	for _, f := range r.failedServers {
		if f.Name != name {
			kept = append(kept, f)
		}
	}
	r.failedServers = kept
}

// removeServerLocked closes and removes the client registered under name,
// along with its toolMap routes and per-server tool defs. It does NOT rebuild
// r.toolDefs; the caller must invoke rebuildToolDefsLocked afterwards.
// Caller must hold r.mu for writing.
func (r *Registry) removeServerLocked(name string) {
	kept := r.clients[:0]
	for _, c := range r.clients {
		if c.Name() != name {
			kept = append(kept, c)
			continue
		}
		// Unroute every tool owned by the stale client.
		for toolName, owner := range r.toolMap {
			if owner == c {
				delete(r.toolMap, toolName)
			}
		}
		c.Close()
	}
	// Nil the tail so removed clients are not retained by the backing array.
	for i := len(kept); i < len(r.clients); i++ {
		r.clients[i] = nil
	}
	r.clients = kept
	delete(r.serverToolDefs, name)
}

// rebuildToolDefsLocked recomputes the flat toolDefs list from the per-server
// defs, in client registration order. A fresh slice is allocated so slices
// previously returned by Tools are never mutated. Caller must hold r.mu.
func (r *Registry) rebuildToolDefsLocked() {
	n := 0
	for _, defs := range r.serverToolDefs {
		n += len(defs)
	}
	rebuilt := make([]llm.ToolDef, 0, n)
	for _, c := range r.clients {
		rebuilt = append(rebuilt, r.serverToolDefs[c.Name()]...)
	}
	r.toolDefs = rebuilt
}

// ConnectAll connects every server in order, logging one line per server via
// logFn; failures are logged and skipped rather than aborting the loop.
func (r *Registry) ConnectAll(ctx context.Context, servers []config.ServerConfig, logFn func(string)) {
	for _, srv := range servers {
		toolCount, err := r.ConnectServer(ctx, srv)
		if err != nil {
			logFn(fmt.Sprintf("skip %s: %v", srv.Name, err))
			continue
		}
		logFn(fmt.Sprintf("connected %s (%d tools)", srv.Name, toolCount))
	}
}

// Tools returns a copy of the registered tool definitions; the copy protects
// callers from concurrent registry mutation.
func (r *Registry) Tools() []llm.ToolDef {
	r.mu.RLock()
	defer r.mu.RUnlock()
	out := make([]llm.ToolDef, len(r.toolDefs))
	copy(out, r.toolDefs)
	return out
}

// ToolCount reports the number of registered tools.
func (r *Registry) ToolCount() int {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return len(r.toolDefs)
}

// ServerCount reports the number of connected servers.
func (r *Registry) ServerCount() int {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return len(r.clients)
}

// ServerNames returns the names of all connected servers, in registration order.
func (r *Registry) ServerNames() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	names := make([]string, len(r.clients))
	for i, c := range r.clients {
		names[i] = c.Name()
	}
	return names
}

// FailedServers returns a copy of the failed-server records; the copy protects
// callers from concurrent registry mutation.
func (r *Registry) FailedServers() []FailedServer {
	r.mu.RLock()
	defer r.mu.RUnlock()
	out := make([]FailedServer, len(r.failedServers))
	copy(out, r.failedServers)
	return out
}

// CallTool routes a tool invocation to the client that registered it. An
// unknown tool name is reported as a tool-level error result, not a Go error,
// so the caller can surface it to the model.
func (r *Registry) CallTool(ctx context.Context, name string, args map[string]any) (*ToolResult, error) {
	r.mu.RLock()
	client, ok := r.toolMap[name]
	r.mu.RUnlock()
	if !ok {
		return &ToolResult{
			Content: fmt.Sprintf("unknown tool: %s", name),
			IsError: true,
		}, nil
	}
	return client.CallTool(ctx, name, args)
}

// Close shuts down every client and resets all tool state. Server configs and
// failure records are retained, matching prior behavior.
func (r *Registry) Close() {
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, c := range r.clients {
		c.Close()
	}
	r.clients = nil
	r.toolMap = make(map[string]*MCPClient)
	r.toolDefs = nil
	r.serverToolDefs = make(map[string][]llm.ToolDef)
}

// HealthCheck pings every connected server and returns one status per server,
// followed by one entry per recorded connection failure. State is snapshotted
// under the read lock first so slow pings (up to connectTimeout each) do not
// block registry writers such as reconnection.
func (r *Registry) HealthCheck(ctx context.Context) []ServerStatus {
	r.mu.RLock()
	clients := make([]*MCPClient, len(r.clients))
	copy(clients, r.clients)
	failed := make([]FailedServer, len(r.failedServers))
	copy(failed, r.failedServers)
	r.mu.RUnlock()

	results := make([]ServerStatus, 0, len(clients)+len(failed))
	for _, client := range clients {
		status := ServerStatus{Name: client.Name()}
		if client.IsConnected() {
			pingCtx, cancel := context.WithTimeout(ctx, connectTimeout)
			err := client.Ping(pingCtx)
			cancel()
			status.Connected = err == nil
			if err != nil {
				status.LastError = err.Error()
			}
			status.LastPing = time.Now()
		}
		results = append(results, status)
	}
	for _, f := range failed {
		results = append(results, ServerStatus{
			Name:      f.Name,
			Connected: false,
			LastError: f.Reason,
		})
	}
	return results
}

// ReconnectServer re-establishes the connection to a previously configured
// server. ConnectServer handles replacing any stale client and clearing the
// server's failure record on success.
func (r *Registry) ReconnectServer(ctx context.Context, name string) (int, error) {
	r.mu.RLock()
	srv, ok := r.serverConfigs[name]
	r.mu.RUnlock()
	if !ok {
		return 0, fmt.Errorf("no config found for server: %s", name)
	}
	return r.ConnectServer(ctx, srv)
}

// MonitorConfig tunes the background health monitor.
type MonitorConfig struct {
	Interval    time.Duration // time between health-check rounds
	MaxRetries  int           // reconnect attempts per unhealthy server per round
	BackoffBase time.Duration // wait before attempt k is BackoffBase*k
}

var defaultMonitorConfig = MonitorConfig{
	Interval:    30 * time.Second,
	MaxRetries:  3,
	BackoffBase: 5 * time.Second,
}

// StartHealthMonitor launches a goroutine that periodically health-checks all
// servers and retries unhealthy ones with linear backoff. Unset (zero) config
// fields are filled individually from defaultMonitorConfig, so a caller may
// override just one knob without losing the other defaults. The returned
// CancelFunc stops the monitor; it also stops when ctx is cancelled.
func (r *Registry) StartHealthMonitor(ctx context.Context, cfg MonitorConfig, logFn func(string)) context.CancelFunc {
	if cfg.Interval <= 0 {
		cfg.Interval = defaultMonitorConfig.Interval
	}
	if cfg.MaxRetries <= 0 {
		cfg.MaxRetries = defaultMonitorConfig.MaxRetries
	}
	if cfg.BackoffBase <= 0 {
		cfg.BackoffBase = defaultMonitorConfig.BackoffBase
	}

	monitorCtx, cancel := context.WithCancel(ctx)
	go func() {
		ticker := time.NewTicker(cfg.Interval)
		defer ticker.Stop()
		for {
			select {
			case <-monitorCtx.Done():
				return
			case <-ticker.C:
				r.healthCheckRound(monitorCtx, cfg, logFn)
			}
		}
	}()
	return cancel
}

// healthCheckRound runs one monitor pass: for each unhealthy server it makes
// up to cfg.MaxRetries reconnect attempts, sleeping BackoffBase*attempt before
// each one, and stops early on success or context cancellation.
func (r *Registry) healthCheckRound(ctx context.Context, cfg MonitorConfig, logFn func(string)) {
	for _, status := range r.HealthCheck(ctx) {
		if status.Connected {
			continue
		}
		logFn(fmt.Sprintf("server %s unhealthy, attempting reconnect...", status.Name))
		for attempt := 1; attempt <= cfg.MaxRetries; attempt++ {
			backoff := cfg.BackoffBase * time.Duration(attempt)
			select {
			case <-ctx.Done():
				return
			case <-time.After(backoff):
			}
			_, err := r.ReconnectServer(ctx, status.Name)
			if err == nil {
				logFn(fmt.Sprintf("server %s reconnected", status.Name))
				break
			}
			if attempt == cfg.MaxRetries {
				logFn(fmt.Sprintf("server %s reconnection failed after %d attempts: %v", status.Name, cfg.MaxRetries, err))
			}
		}
	}
}