299 lines
11 KiB
Go
299 lines
11 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"ai-agent/internal/llm"
|
|
permissionPkg "ai-agent/internal/permission"
|
|
)
|
|
|
|
// Run executes one agent turn as an iterative tool-calling loop: stream a
// model response, run any tool calls it requests, append the results to the
// conversation, and ask the model again — until the model answers without
// tool calls or the iteration limit is hit. All user-visible activity
// (streamed text, tool progress, errors) is reported through out; cancelling
// ctx stops the loop at the next checkpoint.
func (a *Agent) Run(ctx context.Context, out Output) {
	// Tool definitions advertised to the model: registry tools plus the
	// memory and "tools" built-ins, only when tool use is enabled.
	var tools []llm.ToolDef
	if a.toolsEnabled {
		tools = a.registry.Tools()
		if a.memoryStore != nil {
			tools = append(tools, a.memoryBuiltinToolDefs()...)
		}
		tools = append(tools, a.toolsBuiltinToolDefs()...)
	}
	// Snapshot the newest message under the read lock so the ICE work below
	// runs without holding a.mu.
	var iceContext string
	a.mu.RLock()
	hasMessages := len(a.messages) > 0
	var lastMsg llm.Message
	if hasMessages {
		lastMsg = a.messages[len(a.messages)-1]
	}
	a.mu.RUnlock()
	// If the newest message is from the user, index it into ICE and try to
	// assemble retrieval context for the system prompt. Indexing failures are
	// reported but non-fatal; context-assembly failures are silently skipped.
	if a.iceEngine != nil && hasMessages {
		if lastMsg.Role == "user" {
			if err := a.iceEngine.IndexMessage(ctx, "user", lastMsg.Content); err != nil {
				out.Error(fmt.Sprintf("ICE indexing failed: %v", err))
			}
			if assembled, err := a.iceEngine.AssembleContext(ctx, lastMsg.Content); err == nil {
				iceContext = assembled
			}
		}
	}
	system := buildSystemPromptForModel(a.modePrefix, tools, a.skillContent, a.loadedCtx, a.memoryStore, iceContext, a.workDir, a.ignoreContent, a.llmClient.Model())
	// Transient stream failures are retried up to maxRetries times without
	// consuming model turns; retryCount resets after any successful stream.
	const maxRetries = 2
	var lastPromptTokens int
	var retryCount int
	maxIters := a.MaxIterations()
	for i := 0; i < maxIters; i++ {
		// Bail out promptly if the caller cancelled.
		select {
		case <-ctx.Done():
			return
		default:
		}
		var textBuf strings.Builder
		var toolCalls []llm.ToolCall
		// Stream one model response, accumulating text and tool calls as
		// chunks arrive.
		// NOTE(review): a.messages is read here without holding a.mu —
		// assumes nothing appends concurrently while Run is active; confirm.
		err := a.llmClient.ChatStream(ctx, llm.ChatOptions{
			Messages: a.messages,
			Tools:    tools,
			System:   system,
		}, func(chunk llm.StreamChunk) error {
			// Abort the stream as soon as ctx is cancelled.
			select {
			case <-ctx.Done():
				return ctx.Err()
			default:
			}

			if chunk.Text != "" {
				textBuf.WriteString(chunk.Text)
				out.StreamText(chunk.Text)
			}
			if len(chunk.ToolCalls) > 0 {
				toolCalls = append(toolCalls, chunk.ToolCalls...)
			}
			if chunk.Done {
				// Remember the prompt size for the compaction check below.
				lastPromptTokens = chunk.PromptEvalCount
				out.StreamDone(chunk.EvalCount, chunk.PromptEvalCount)
			}
			return nil
		})
		if err != nil {
			if ctx.Err() != nil {
				return
			}
			// Retryable failures: wait 2s for connection errors, retry
			// immediately for malformed (unparsable) model output.
			if retryCount < maxRetries && isRetryableError(err) {
				retryCount++
				if isConnectionError(err) {
					out.Error(fmt.Sprintf("Connection lost, retrying (%d/%d) in 2s...", retryCount, maxRetries))
					select {
					case <-ctx.Done():
						return
					case <-time.After(2 * time.Second):
					}
				} else {
					out.Error(fmt.Sprintf("LLM produced malformed output, retrying (%d/%d)...", retryCount, maxRetries))
				}
				// Discard any partial output from the failed stream so the
				// retry starts clean.
				textBuf.Reset()
				toolCalls = nil
				continue
			}
			// Retries exhausted or non-retryable: surface the error with
			// recovery hints and end the turn.
			out.Error(fmt.Sprintf("LLM error: %v", err))
			out.SystemMessage(fmt.Sprintf("⚠️ Model response failed: %v\n\nYou can try:\n- Checking if Ollama is running (`ollama ps`)\n- Switching to a different model (F6)\n- Reducing context size (num_ctx in config)\n\nTool results are still available above.", err))
			return
		}
		retryCount = 0
		// Record the assistant's reply (text and/or tool calls) in history.
		assistantMsg := llm.Message{
			Role:      "assistant",
			Content:   textBuf.String(),
			ToolCalls: toolCalls,
		}
		a.AppendMessage(assistantMsg)
		if a.iceEngine != nil && assistantMsg.Content != "" {
			if err := a.iceEngine.IndexMessage(ctx, "assistant", assistantMsg.Content); err != nil {
				out.Error(fmt.Sprintf("ICE indexing failed: %v", err))
			}
		}
		// No tool calls means the model is finished: optionally run
		// auto-memory detection on the (user, assistant) exchange, then end
		// the turn.
		if len(toolCalls) == 0 {
			a.mu.RLock()
			hasEnoughMessages := len(a.messages) >= 2
			var userContent string
			if hasEnoughMessages {
				// Walk backwards past the assistant reply we just appended to
				// find the most recent user message.
				for idx := len(a.messages) - 2; idx >= 0; idx-- {
					if a.messages[idx].Role == "user" {
						userContent = a.messages[idx].Content
						break
					}
				}
			}
			a.mu.RUnlock()
			if a.iceEngine != nil && hasEnoughMessages && userContent != "" {
				a.iceEngine.DetectAutoMemory(ctx, userContent, assistantMsg.Content)
			}
			return
		}
		// Classify the tool calls: memory and MCP tools are collected into
		// pending and executed concurrently below; "tools" built-ins and
		// permission rejections are handled inline, in request order.
		type pendingTool struct {
			tc           llm.ToolCall
			isMemoryTool bool
			isMCPTool    bool
		}
		var pending []pendingTool
		for _, tc := range toolCalls {
			// Memory built-ins join the concurrent batch.
			if a.memoryStore != nil && a.isMemoryTool(tc.Name) {
				pending = append(pending, pendingTool{tc: tc, isMemoryTool: true})
				continue
			}
			// "Tools" built-ins run synchronously right here.
			if a.isToolsTool(tc.Name) {
				out.ToolCallStart(tc.Name, tc.Arguments)
				startTime := time.Now()
				result, isErr := a.handleToolsTool(tc)
				duration := time.Since(startTime)
				out.ToolCallResult(tc.Name, result, isErr, duration)
				a.AppendMessage(llm.Message{
					Role:       "tool",
					Content:    result,
					ToolName:   tc.Name,
					ToolCallID: tc.ID,
				})
				continue
			}
			// Permission gate for everything else: deny outright, or prompt
			// the user through the approval callback. In either rejection
			// case a "tool" message is still appended so the model learns
			// the call was blocked.
			if a.permChecker != nil {
				switch a.permChecker.ToCheckResult(tc.Name) {
				case permissionPkg.CheckDeny:
					errMsg := "tool call blocked by permission policy"
					out.ToolCallStart(tc.Name, tc.Arguments)
					out.ToolCallResult(tc.Name, errMsg, true, 0)
					a.AppendMessage(llm.Message{
						Role:       "tool",
						Content:    errMsg,
						ToolName:   tc.Name,
						ToolCallID: tc.ID,
					})
					continue
				case permissionPkg.CheckAsk:
					if a.approvalCallback != nil {
						allowed, always := permissionPkg.RequestApproval(tc.Name, tc.Arguments, a.approvalCallback)
						// "Always allow" upgrades the policy so future calls
						// to this tool skip the prompt.
						if always {
							a.permChecker.SetPolicy(tc.Name, permissionPkg.PolicyAllow)
						}
						if !allowed {
							errMsg := "tool call denied by user"
							out.ToolCallStart(tc.Name, tc.Arguments)
							out.ToolCallResult(tc.Name, errMsg, true, 0)
							a.AppendMessage(llm.Message{
								Role:       "tool",
								Content:    errMsg,
								ToolName:   tc.Name,
								ToolCallID: tc.ID,
							})
							continue
						}
					}
				}
			}
			pending = append(pending, pendingTool{tc: tc, isMCPTool: true})
		}
		// Execute the approved memory/MCP tools concurrently. Each goroutine
		// writes only its own slot of results; idx/tool are passed as
		// arguments to avoid pre-Go-1.22 loop-variable capture.
		if len(pending) > 0 {
			var wg sync.WaitGroup
			mu := sync.Mutex{}
			results := make([]llm.Message, len(pending))
			for i, p := range pending {
				wg.Add(1)
				go func(idx int, tool pendingTool) {
					defer wg.Done()
					tc := tool.tc
					// NOTE(review): out.ToolCallStart/Result are invoked from
					// multiple goroutines here — assumes Output is safe for
					// concurrent use; confirm.
					out.ToolCallStart(tc.Name, tc.Arguments)
					startTime := time.Now()
					var result string
					var isErr bool
					if tool.isMemoryTool {
						result, isErr = a.handleMemoryTool(tc)
					} else if tool.isMCPTool {
						toolResult, err := a.registry.CallTool(ctx, tc.Name, tc.Arguments)
						if err != nil {
							// Report the failure to the model as a tool result
							// rather than aborting the whole turn.
							result = fmt.Sprintf("ERROR: Tool '%s' failed: %v\nThis tool call failed but you can still complete the task with other available information.", tc.Name, err)
							isErr = true
						} else {
							result = toolResult.Content
							isErr = toolResult.IsError
						}
					}
					duration := time.Since(startTime)
					out.ToolCallResult(tc.Name, result, isErr, duration)
					mu.Lock()
					results[idx] = llm.Message{
						Role:       "tool",
						Content:    result,
						ToolName:   tc.Name,
						ToolCallID: tc.ID,
					}
					mu.Unlock()
				}(i, p)
			}
			wg.Wait()
			// Append results in the original call order, skipping any slot a
			// goroutine never filled (ToolName left empty).
			for _, msg := range results {
				if msg.ToolName != "" {
					a.AppendMessage(msg)
				}
			}
		}
		// Compact the conversation when the last prompt grew too large, and
		// rebuild the system prompt if compaction changed the history.
		if a.shouldCompact(lastPromptTokens) {
			if a.compact(ctx, out) {
				system = buildSystemPromptForModel(a.modePrefix, tools, a.skillContent, a.loadedCtx, a.memoryStore, iceContext, a.workDir, a.ignoreContent, a.llmClient.Model())
			}
		}
		// Warn one iteration before the hard cap is reached.
		if i == maxIters-2 {
			out.Error(fmt.Sprintf("approaching iteration limit (%d/%d)", i+2, maxIters))
		}
	}
	out.Error(fmt.Sprintf("reached max iterations (%d)", maxIters))
}
|
|
|
|
func isRetryableError(err error) bool {
|
|
msg := err.Error()
|
|
if strings.Contains(msg, "parse JSON") || strings.Contains(msg, "unexpected end of JSON") {
|
|
return true
|
|
}
|
|
return isConnectionError(err)
|
|
}
|
|
|
|
// isConnectionError reports whether err looks like a transient network
// failure (EOF, connection reset/refused, broken pipe) that may succeed on
// retry.
func isConnectionError(err error) bool {
	text := err.Error()
	for _, marker := range [...]string{
		"EOF",
		"connection reset",
		"connection refused",
		"broken pipe",
	} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
|
|
|
|
// FormatToolArgs renders a tool-call argument map as a compact, deterministic
// one-line summary such as `path="/tmp/x" recursive=true`. Keys are sorted so
// output is stable despite random map iteration order. Long string values and
// the overall result are truncated with "..." so the summary fits on one
// display line; truncation happens on rune boundaries so multi-byte UTF-8
// characters are never split. Returns "" for an empty map.
func FormatToolArgs(args map[string]any) string {
	if len(args) == 0 {
		return ""
	}
	parts := make([]string, 0, len(args))
	for key, value := range args {
		var valStr string
		switch v := value.(type) {
		case string:
			// Truncate at a rune boundary: the previous byte slice v[:44]
			// could cut a multi-byte character and emit invalid UTF-8.
			if r := []rune(v); len(r) > 47 {
				valStr = `"` + string(r[:44]) + `..."`
			} else {
				valStr = `"` + v + `"`
			}
		case int, float64, bool:
			valStr = fmt.Sprintf("%v", v)
		case []any:
			// Summarize collections by size rather than dumping contents.
			valStr = fmt.Sprintf("[%d items]", len(v))
		case map[string]any:
			valStr = fmt.Sprintf("{%d fields}", len(v))
		default:
			valStr = fmt.Sprintf("%v", v)
		}
		parts = append(parts, fmt.Sprintf("%s=%s", key, valStr))
	}
	// Sort for deterministic output across calls.
	sort.Strings(parts)
	result := strings.Join(parts, " ")
	// Cap the whole line, again cutting on a rune boundary.
	if r := []rune(result); len(r) > 60 {
		return string(r[:57]) + "..."
	}
	return result
}
|