111 lines
3.0 KiB
Go
111 lines
3.0 KiB
Go
package mcp
|
|
|
|
import (
	"context"
	"fmt"
	"os/exec"
	"strings"

	sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
)
|
|
|
|
// MCPClient wraps one MCP client connection to a single named server.
type MCPClient struct {
	// name identifies the server; it is returned by Name and included in
	// error messages.
	name string
	// client is the SDK client that produced session.
	client *sdkmcp.Client
	// session is the client session used for all server calls. A nil
	// session means no connection has been established.
	session *sdkmcp.ClientSession
	// cmd is meant to hold the subprocess command for a stdio-launched
	// server. It may be nil.
	// NOTE(review): confirm where (and whether) this is populated.
	cmd *exec.Cmd
}
|
|
|
|
func Connect(ctx context.Context, name, command string, args []string, env []string, transport, url string) (*MCPClient, error) {
|
|
client := sdkmcp.NewClient(
|
|
&sdkmcp.Implementation{Name: "ai-agent", Version: "0.2.0"},
|
|
nil,
|
|
)
|
|
var t sdkmcp.Transport
|
|
switch transport {
|
|
case "sse":
|
|
if url == "" {
|
|
return nil, fmt.Errorf("sse transport requires url for %s", name)
|
|
}
|
|
t = &sdkmcp.SSEClientTransport{Endpoint: url}
|
|
case "streamable-http":
|
|
if url == "" {
|
|
return nil, fmt.Errorf("streamable-http transport requires url for %s", name)
|
|
}
|
|
t = &sdkmcp.StreamableClientTransport{Endpoint: url}
|
|
default:
|
|
if command == "" {
|
|
return nil, fmt.Errorf("stdio transport requires command for %s", name)
|
|
}
|
|
cmd := exec.Command(command, args...)
|
|
if len(env) > 0 {
|
|
cmd.Env = append(cmd.Environ(), env...)
|
|
}
|
|
t = &sdkmcp.CommandTransport{Command: cmd}
|
|
}
|
|
session, err := client.Connect(ctx, t, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("connect to %s: %w", name, err)
|
|
}
|
|
return &MCPClient{
|
|
name: name,
|
|
client: client,
|
|
session: session,
|
|
}, nil
|
|
}
|
|
|
|
// Name reports the server name this client was created with.
func (c *MCPClient) Name() string {
	serverName := c.name
	return serverName
}
|
|
|
|
func (c *MCPClient) ListTools(ctx context.Context) ([]*sdkmcp.Tool, error) {
|
|
caps := c.session.InitializeResult()
|
|
if caps == nil || caps.Capabilities.Tools == nil {
|
|
return nil, nil
|
|
}
|
|
var tools []*sdkmcp.Tool
|
|
for tool, err := range c.session.Tools(ctx, nil) {
|
|
if err != nil {
|
|
return tools, fmt.Errorf("list tools from %s: %w", c.name, err)
|
|
}
|
|
tools = append(tools, tool)
|
|
}
|
|
return tools, nil
|
|
}
|
|
|
|
func (c *MCPClient) CallTool(ctx context.Context, name string, args map[string]any) (*ToolResult, error) {
|
|
result, err := c.session.CallTool(ctx, &sdkmcp.CallToolParams{
|
|
Name: name,
|
|
Arguments: args,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("call tool %s on %s: %w", name, c.name, err)
|
|
}
|
|
var text string
|
|
for _, ct := range result.Content {
|
|
if tc, ok := ct.(*sdkmcp.TextContent); ok {
|
|
if text != "" {
|
|
text += "\n"
|
|
}
|
|
text += tc.Text
|
|
}
|
|
}
|
|
return &ToolResult{Content: text, IsError: result.IsError}, nil
|
|
}
|
|
|
|
func (c *MCPClient) Close() error {
|
|
if c.session != nil {
|
|
return c.session.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *MCPClient) IsConnected() bool {
|
|
return c.session != nil
|
|
}
|
|
|
|
func (c *MCPClient) Ping(ctx context.Context) error {
|
|
if c.session == nil {
|
|
return fmt.Errorf("no session")
|
|
}
|
|
_, err := c.ListTools(ctx)
|
|
return err
|
|
}
|