admin 8dc496b626
Some checks failed
CI / test (push) Has been cancelled
Release / release (push) Failing after 4m36s
first commit
2026-03-08 15:40:34 +07:00

278 lines
10 KiB
Go

package agent
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"time"
"ai-agent/internal/llm"
permissionPkg "ai-agent/internal/permission"
)
// Run drives one agent turn: it repeatedly streams a model response, appends
// it to the conversation, executes any tool calls the model requested, and
// feeds the tool results back, until one of the following happens:
//   - the model answers without requesting tools (normal completion),
//   - ctx is cancelled (returns silently),
//   - the LLM call fails with a non-retryable error or retries are exhausted,
//   - a.MaxIterations() loop iterations have been used.
//
// All user-visible progress (streamed text, tool call start/result, errors,
// system messages) is reported through out. Run returns nothing; failures are
// surfaced only via out.
func (a *Agent) Run(ctx context.Context, out Output) {
// Collect the tool definitions offered to the model. With tools disabled the
// slice stays nil. Memory tool definitions are only offered when a memory
// store is configured.
var tools []llm.ToolDef
if a.toolsEnabled {
tools = a.registry.Tools()
if a.memoryStore != nil {
tools = append(tools, a.memoryBuiltinToolDefs()...)
}
tools = append(tools, a.toolsBuiltinToolDefs()...)
}
var iceContext string
// Snapshot the most recent message under the read lock so the ICE calls
// below run without holding a.mu.
a.mu.RLock()
hasMessages := len(a.messages) > 0
var lastMsg llm.Message
if hasMessages {
lastMsg = a.messages[len(a.messages)-1]
}
a.mu.RUnlock()
// When an ICE engine is configured and the conversation ends with a user
// message, index that message and ask the engine for assembled context to
// include in the system prompt.
if a.iceEngine != nil && hasMessages {
if lastMsg.Role == "user" {
if err := a.iceEngine.IndexMessage(ctx, "user", lastMsg.Content); err != nil {
out.Error(fmt.Sprintf("ICE indexing failed: %v", err))
}
// NOTE(review): an AssembleContext error is dropped silently here (iceContext
// just stays empty), unlike the indexing error above which is reported.
if assembled, err := a.iceEngine.AssembleContext(ctx, lastMsg.Content); err == nil {
iceContext = assembled
}
}
}
// The system prompt is built once up front and only rebuilt after a
// successful compaction (see the end of the loop body).
system := buildSystemPromptForModel(a.modePrefix, tools, a.skillContent, a.loadedCtx, a.memoryStore, iceContext, a.workDir, a.ignoreContent, a.llmClient.Model())
// maxRetries bounds consecutive retries of a retryable LLM error;
// retryCount is reset after every successful response.
const maxRetries = 2
// lastPromptTokens holds the prompt token count from the latest completed
// stream; it is passed to shouldCompact below.
var lastPromptTokens int
var retryCount int
maxIters := a.MaxIterations()
for i := 0; i < maxIters; i++ {
// Non-blocking cancellation check before starting another model call.
select {
case <-ctx.Done():
return
default:
}
// Per-iteration accumulators filled by the streaming callback.
var textBuf strings.Builder
var toolCalls []llm.ToolCall
// NOTE(review): a.messages is read here without holding a.mu, while the other
// reads in this function take a.mu.RLock — confirm no concurrent writer exists.
err := a.llmClient.ChatStream(ctx, llm.ChatOptions{
Messages: a.messages,
Tools: tools,
System: system,
}, func(chunk llm.StreamChunk) error {
// Abort the stream promptly if the context was cancelled mid-response.
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if chunk.Text != "" {
textBuf.WriteString(chunk.Text)
out.StreamText(chunk.Text)
}
if len(chunk.ToolCalls) > 0 {
toolCalls = append(toolCalls, chunk.ToolCalls...)
}
if chunk.Done {
lastPromptTokens = chunk.PromptEvalCount
out.StreamDone(chunk.EvalCount, chunk.PromptEvalCount)
}
return nil
})
if err != nil {
// A cancelled context is not reported as an error; just stop.
if ctx.Err() != nil {
return
}
// Retry errors matched by isRetryableError. Note that `continue` advances i,
// so each retry also consumes one of the maxIters iterations, and any text
// already passed to out.StreamText is not retracted.
if retryCount < maxRetries && isRetryableError(err) {
retryCount++
out.Error(fmt.Sprintf("LLM produced malformed output, retrying (%d/%d)...", retryCount, maxRetries))
textBuf.Reset()
toolCalls = nil
continue
}
out.Error(fmt.Sprintf("LLM error: %v", err))
out.SystemMessage(fmt.Sprintf("⚠️ Model response failed: %v\n\nYou can try:\n- Checking if Ollama is running (`ollama ps`)\n- Switching to a different model (ctrl+m)\n- Reducing context size\n\nTool results are still available above.", err))
return
}
retryCount = 0
// Record the assistant turn (text plus any requested tool calls) in the
// conversation before acting on it.
assistantMsg := llm.Message{
Role: "assistant",
Content: textBuf.String(),
ToolCalls: toolCalls,
}
a.AppendMessage(assistantMsg)
if a.iceEngine != nil && assistantMsg.Content != "" {
if err := a.iceEngine.IndexMessage(ctx, "assistant", assistantMsg.Content); err != nil {
out.Error(fmt.Sprintf("ICE indexing failed: %v", err))
}
}
// No tool calls means the model produced its final answer for this turn.
if len(toolCalls) == 0 {
// Search backwards, starting just before the last message, for the most
// recent user message so it can be paired with this assistant reply.
a.mu.RLock()
hasEnoughMessages := len(a.messages) >= 2
var userContent string
if hasEnoughMessages {
for idx := len(a.messages) - 2; idx >= 0; idx-- {
if a.messages[idx].Role == "user" {
userContent = a.messages[idx].Content
break
}
}
}
a.mu.RUnlock()
if a.iceEngine != nil && hasEnoughMessages && userContent != "" {
a.iceEngine.DetectAutoMemory(ctx, userContent, assistantMsg.Content)
}
return
}
// pendingTool is a tool call queued for concurrent execution, tagged with
// which handler should run it.
type pendingTool struct {
tc llm.ToolCall
isMemoryTool bool
isMCPTool bool
}
var pending []pendingTool
// Classify each requested call. Calls matched by isToolsTool and calls that
// are denied are answered inline right here; memory tools and permitted
// registry tools are queued in pending.
// NOTE(review): because inline results are appended before the queued ones,
// tool messages can end up in a different order than toolCalls.
for _, tc := range toolCalls {
if a.memoryStore != nil && a.isMemoryTool(tc.Name) {
pending = append(pending, pendingTool{tc: tc, isMemoryTool: true})
continue
}
// These calls run synchronously and bypass the permission checker below.
if a.isToolsTool(tc.Name) {
out.ToolCallStart(tc.Name, tc.Arguments)
startTime := time.Now()
result, isErr := a.handleToolsTool(tc)
duration := time.Since(startTime)
out.ToolCallResult(tc.Name, result, isErr, duration)
a.AppendMessage(llm.Message{
Role: "tool",
Content: result,
ToolName: tc.Name,
ToolCallID: tc.ID,
})
continue
}
// Permission gate for everything else; skipped when no checker is configured.
if a.permChecker != nil {
switch a.permChecker.ToCheckResult(tc.Name) {
case permissionPkg.CheckDeny:
// Report the denial as an error tool result so the model still receives
// a response for this call ID.
errMsg := "tool call blocked by permission policy"
out.ToolCallStart(tc.Name, tc.Arguments)
out.ToolCallResult(tc.Name, errMsg, true, 0)
a.AppendMessage(llm.Message{
Role: "tool",
Content: errMsg,
ToolName: tc.Name,
ToolCallID: tc.ID,
})
continue
case permissionPkg.CheckAsk:
// NOTE(review): when approvalCallback is nil, a CheckAsk result falls
// through and the tool is queued for execution without any prompt —
// confirm this is the intended default.
if a.approvalCallback != nil {
allowed, always := permissionPkg.RequestApproval(tc.Name, tc.Arguments, a.approvalCallback)
// `always` is applied before `allowed` is inspected, so the allow policy
// is stored whenever RequestApproval reports always == true.
if always {
a.permChecker.SetPolicy(tc.Name, permissionPkg.PolicyAllow)
}
if !allowed {
errMsg := "tool call denied by user"
out.ToolCallStart(tc.Name, tc.Arguments)
out.ToolCallResult(tc.Name, errMsg, true, 0)
a.AppendMessage(llm.Message{
Role: "tool",
Content: errMsg,
ToolName: tc.Name,
ToolCallID: tc.ID,
})
continue
}
}
}
}
pending = append(pending, pendingTool{tc: tc, isMCPTool: true})
}
// Execute all queued calls concurrently, one goroutine per call, and wait
// for every one of them before continuing.
if len(pending) > 0 {
var wg sync.WaitGroup
mu := sync.Mutex{}
// results is indexed by position in pending so that the messages are
// appended in queue order regardless of which goroutine finishes first.
results := make([]llm.Message, len(pending))
for i, p := range pending {
wg.Add(1)
go func(idx int, tool pendingTool) {
defer wg.Done()
tc := tool.tc
out.ToolCallStart(tc.Name, tc.Arguments)
startTime := time.Now()
var result string
var isErr bool
if tool.isMemoryTool {
result, isErr = a.handleMemoryTool(tc)
} else if tool.isMCPTool {
toolResult, err := a.registry.CallTool(ctx, tc.Name, tc.Arguments)
if err != nil {
// A failed call becomes an error-flagged tool result rather than
// aborting the run, so the model can continue with what it has.
result = fmt.Sprintf("ERROR: Tool '%s' failed: %v\nThis tool call failed but you can still complete the task with other available information.", tc.Name, err)
isErr = true
} else {
result = toolResult.Content
isErr = toolResult.IsError
}
}
duration := time.Since(startTime)
out.ToolCallResult(tc.Name, result, isErr, duration)
mu.Lock()
results[idx] = llm.Message{
Role: "tool",
Content: result,
ToolName: tc.Name,
ToolCallID: tc.ID,
}
mu.Unlock()
}(i, p)
}
wg.Wait()
// NOTE(review): a result whose ToolName is empty is skipped here, so a tool
// call with an empty name would get no tool message appended.
for _, msg := range results {
if msg.ToolName != "" {
a.AppendMessage(msg)
}
}
}
// Compact the conversation when shouldCompact says so for the latest prompt
// token count; rebuild the system prompt only if compact returns true.
if a.shouldCompact(lastPromptTokens) {
if a.compact(ctx, out) {
system = buildSystemPromptForModel(a.modePrefix, tools, a.skillContent, a.loadedCtx, a.memoryStore, iceContext, a.workDir, a.ignoreContent, a.llmClient.Model())
}
}
// Emitted once, at the end of the second-to-last iteration, i.e. when exactly
// one iteration remains. NOTE(review): i+2 equals maxIters at this point, so
// the message always reads "(N/N)" — confirm that is the intended wording.
if i == maxIters-2 {
out.Error(fmt.Sprintf("approaching iteration limit (%d/%d)", i+2, maxIters))
}
}
// Reached only when the loop ran out of iterations without returning.
out.Error(fmt.Sprintf("reached max iterations (%d)", maxIters))
}
// isRetryableError reports whether err should trigger another attempt at the
// same model call. An error qualifies when its message contains "parse JSON"
// or "unexpected end of JSON". A nil error is never retryable.
func isRetryableError(err error) bool {
	// Guard against a nil error so callers cannot trigger a nil-interface
	// method call panic.
	if err == nil {
		return false
	}
	msg := err.Error()
	return strings.Contains(msg, "parse JSON") || strings.Contains(msg, "unexpected end of JSON")
}
// FormatToolArgs renders a tool call's arguments as a compact single-line
// summary of space-separated key=value pairs. The rendered pairs are sorted
// lexically, so the output is deterministic despite Go's random map iteration
// order.
//
// Values are abbreviated as follows:
//   - strings are double-quoted; a string longer than 47 bytes is cut to at
//     most 44 bytes and suffixed with "...",
//   - []any becomes "[N items]" and map[string]any becomes "{N fields}",
//   - everything else is formatted with %v.
//
// A summary longer than 60 bytes is cut to at most 57 bytes and suffixed with
// "...". Both cuts back off to the nearest rune boundary, so a multi-byte
// UTF-8 character is never split in half. A nil or empty map yields "".
func FormatToolArgs(args map[string]any) string {
	if len(args) == 0 {
		return ""
	}

	// clip returns the longest prefix of s that is at most limit bytes long
	// and ends on a rune boundary. Ranging over a string yields the byte
	// offset at which each rune starts, so the last offset <= limit is always
	// a safe cut point; for pure ASCII input that is exactly limit.
	clip := func(s string, limit int) string {
		if len(s) <= limit {
			return s
		}
		cut := 0
		for i := range s {
			if i > limit {
				break
			}
			cut = i
		}
		return s[:cut]
	}

	parts := make([]string, 0, len(args))
	for key, value := range args {
		var valStr string
		switch v := value.(type) {
		case string:
			if len(v) > 47 {
				valStr = `"` + clip(v, 44) + `..."`
			} else {
				valStr = `"` + v + `"`
			}
		case []any:
			valStr = fmt.Sprintf("[%d items]", len(v))
		case map[string]any:
			valStr = fmt.Sprintf("{%d fields}", len(v))
		default:
			// Covers int, float64, bool and any other type.
			valStr = fmt.Sprintf("%v", v)
		}
		parts = append(parts, fmt.Sprintf("%s=%s", key, valStr))
	}
	sort.Strings(parts)

	result := strings.Join(parts, " ")
	if len(result) > 60 {
		return clip(result, 57) + "..."
	}
	return result
}