admin 8dc496b626
Some checks failed
CI / test (push) Has been cancelled
Release / release (push) Failing after 4m36s
first commit
2026-03-08 15:40:34 +07:00

689 lines
21 KiB
Go

package agent
import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"ai-agent/internal/llm"
"ai-agent/internal/tools"
)
// maxTimeout is the longest a single tool invocation is allowed to run,
// regardless of the configured tool timeout. The bash handler enforces the
// same 120-second ceiling.
const maxTimeout = 120 * time.Second
// toolsBuiltinToolDefs returns the definitions of the built-in tools as
// provided by tools.AllToolDefs.
func (a *Agent) toolsBuiltinToolDefs() []llm.ToolDef {
	defs := tools.AllToolDefs()
	return defs
}
// isToolsTool reports whether name refers to a built-in tool, delegating the
// decision to tools.IsBuiltinTool.
func (a *Agent) isToolsTool(name string) bool {
	builtin := tools.IsBuiltinTool(name)
	return builtin
}
// handleToolsTool routes a built-in tool call to the handler registered for
// tc.Name and returns that handler's output together with its is-error flag.
// A name without a handler yields an "unknown tool" error result.
func (a *Agent) handleToolsTool(tc llm.ToolCall) (string, bool) {
	handlers := map[string]func(map[string]any) (string, bool){
		"grep":   a.handleGrep,
		"read":   a.handleRead,
		"write":  a.handleWrite,
		"glob":   a.handleGlob,
		"bash":   a.handleBash,
		"ls":     a.handleLs,
		"find":   a.handleFind,
		"diff":   a.handleDiff,
		"edit":   a.handleEdit,
		"mkdir":  a.handleMkdir,
		"remove": a.handleRemove,
		"copy":   a.handleCopy,
		"move":   a.handleMove,
		"exists": a.handleExists,
	}
	handler, found := handlers[tc.Name]
	if !found {
		return fmt.Sprintf("unknown tool: %s", tc.Name), true
	}
	return handler(tc.Arguments)
}
// handleGrep implements the "grep" tool. It walks args["path"] (resolved
// against the working directory; default: the working directory itself) and
// reports lines matching the regular expression args["pattern"].
//
// Optional arguments:
//   - "include": a filepath.Match pattern applied to file base names.
//   - "context": lines of context shown around each match (default 3).
//
// Matches are printed as "relpath:line: text" and context lines as
// " line: text", in file order, with overlapping context windows merged.
// Output is capped at a.MaxGrepResults() entries, followed by a truncation
// note when the cap is hit. Hidden files, binary files (containing a NUL byte)
// and directories rejected by shouldSkipDir are ignored. The search root
// itself is never pruned.
// The bool result reports whether the returned string is an error.
func (a *Agent) handleGrep(args map[string]any) (string, bool) {
	pattern, _ := args["pattern"].(string)
	if pattern == "" {
		return "error: pattern is required", true
	}
	// Resolve only a caller-supplied path. Resolving the default would join a
	// relative workDir onto itself.
	root := a.workDir
	if p := a.getArgString(args, "path", ""); p != "" {
		root = a.resolvePath(p)
	}
	include := a.getArgString(args, "include", "")
	// Not named "context": that would shadow the imported context package.
	ctxLines := a.getArgInt(args, "context", 3)
	if ctxLines < 0 {
		ctxLines = 0
	}
	maxResults := a.MaxGrepResults()

	rootInfo, err := os.Stat(root)
	if err != nil {
		return fmt.Sprintf("error: path does not exist: %s", root), true
	}
	re, err := regexp.Compile(pattern)
	if err != nil {
		return fmt.Sprintf("error: invalid regex pattern: %v", err), true
	}

	var results []string
	full := false
	err = filepath.Walk(root, func(filePath string, info os.FileInfo, visitErr error) error {
		if visitErr != nil {
			return nil // best effort: skip entries that cannot be read
		}
		if info.IsDir() {
			// Never prune the root: a root such as "." starts with a dot and
			// would otherwise make every search come back empty.
			if filePath != root && shouldSkipDir(info.Name()) {
				return filepath.SkipDir
			}
			return nil
		}
		if include != "" {
			if ok, matchErr := filepath.Match(include, info.Name()); matchErr != nil || !ok {
				return nil
			}
		}
		if filePath != root && strings.HasPrefix(info.Name(), ".") {
			return nil
		}
		content, readErr := os.ReadFile(filePath)
		if readErr != nil || bytes.IndexByte(content, 0) >= 0 {
			return nil // unreadable or binary
		}

		// Relative to a directory root; the base name when the root is a file
		// (filepath.Rel would yield ".").
		displayPath := filepath.Base(filePath)
		if rootInfo.IsDir() {
			if rel, relErr := filepath.Rel(root, filePath); relErr == nil {
				displayPath = rel
			}
		}

		lines := strings.Split(string(content), "\n")
		emitted := -1  // index of the last line already written for this file
		trailing := -1 // last index still owed as trailing context
		for i, line := range lines {
			switch {
			case re.MatchString(line):
				start := i - ctxLines
				if start <= emitted {
					start = emitted + 1 // also clamps to 0, since emitted >= -1
				}
				// Leave one slot under the cap for the match line itself.
				for j := start; j < i && len(results) < maxResults-1; j++ {
					results = append(results, fmt.Sprintf(" %d: %s", j+1, lines[j]))
				}
				results = append(results, fmt.Sprintf("%s:%d: %s", displayPath, i+1, line))
				emitted, trailing = i, i+ctxLines
			case i <= trailing:
				results = append(results, fmt.Sprintf(" %d: %s", i+1, line))
				emitted = i
			default:
				continue
			}
			if len(results) >= maxResults {
				full = true
				return filepath.SkipAll
			}
		}
		return nil
	})
	if err != nil {
		return fmt.Sprintf("error walking directory: %v", err), true
	}
	if len(results) == 0 {
		return fmt.Sprintf("No matches found for pattern: %s", pattern), false
	}
	if full {
		results = append(results, fmt.Sprintf("\n... (truncated, max %d results)", maxResults))
	}
	return strings.Join(results, "\n"), false
}
// handleRead implements the "read" tool: it returns the contents of
// args["path"] (resolved against the working directory).
//
// Optional arguments:
//   - "offset": 1-based line to start from (default 1).
//   - "limit": maximum number of lines to return (0 = no limit).
//
// When the limit cuts the file short, a trailer reports how many lines were
// omitted. An offset past the end of the file is an error.
func (a *Agent) handleRead(args map[string]any) (string, bool) {
	path, _ := args["path"].(string)
	if path == "" {
		return "error: path is required", true
	}
	path = a.resolvePath(path)
	data, err := os.ReadFile(path)
	if err != nil {
		return fmt.Sprintf("error reading file: %v", err), true
	}
	lines := strings.Split(string(data), "\n")
	offset := a.getArgInt(args, "offset", 1)
	limit := a.getArgInt(args, "limit", 0)
	if offset > len(lines) {
		return "error: offset beyond file length", true
	}
	if offset > 1 {
		lines = lines[offset-1:]
	}
	if limit > 0 && len(lines) > limit {
		// Count the omitted lines before truncating. Counting afterwards
		// always reported 0.
		omitted := len(lines) - limit
		content := strings.Join(lines[:limit], "\n")
		return content + fmt.Sprintf("\n\n... (%d more lines)", omitted), false
	}
	return strings.Join(lines, "\n"), false
}
// handleWrite implements the "write" tool: it stores args["content"] at
// args["path"] (resolved against the working directory), creating missing
// parent directories and replacing any previous contents of the file.
func (a *Agent) handleWrite(args map[string]any) (string, bool) {
	target, _ := args["path"].(string)
	body, _ := args["content"].(string)
	if target == "" {
		return "error: path is required", true
	}
	target = a.resolvePath(target)
	if mkErr := os.MkdirAll(filepath.Dir(target), 0755); mkErr != nil {
		return fmt.Sprintf("error creating directory: %v", mkErr), true
	}
	if writeErr := os.WriteFile(target, []byte(body), 0644); writeErr != nil {
		return fmt.Sprintf("error writing file: %v", writeErr), true
	}
	return fmt.Sprintf("Written to %s (%d bytes)", target, len(body)), false
}
// handleGlob implements the "glob" tool: it expands args["pattern"] with
// filepath.Glob relative to args["path"] (resolved against the working
// directory; default: the working directory itself). Matches are returned
// one per line, relative to that directory.
// filepath.Glob does not support recursive "**" segments.
func (a *Agent) handleGlob(args map[string]any) (string, bool) {
	pattern, _ := args["pattern"].(string)
	if pattern == "" {
		return "error: pattern is required", true
	}
	// Resolve a caller-supplied path against workDir like the other file tools
	// do. Previously a relative path was resolved against the process cwd.
	root := a.workDir
	if p := a.getArgString(args, "path", ""); p != "" {
		root = a.resolvePath(p)
	}
	if _, err := os.Stat(root); err != nil {
		return fmt.Sprintf("error: path does not exist: %s", root), true
	}
	matches, err := filepath.Glob(filepath.Join(root, pattern))
	if err != nil {
		return fmt.Sprintf("error: invalid pattern: %v", err), true
	}
	if len(matches) == 0 {
		return fmt.Sprintf("No files match pattern: %s", pattern), false
	}
	relMatches := make([]string, 0, len(matches))
	for _, m := range matches {
		rel, err := filepath.Rel(root, m)
		if err != nil {
			continue
		}
		relMatches = append(relMatches, rel)
	}
	return strings.Join(relMatches, "\n"), false
}
// handleBash implements the "bash" tool: it runs args["command"] through
// `sh -c` in the agent's working directory with the current environment.
//
// The optional "timeout" argument (seconds) defaults to the configured tool
// timeout. It is clamped to [1, min(tool timeout, maxTimeout)]. Stdout is
// returned first, followed by stderr under a "STDERR:" heading. A timeout or a
// non-zero exit yields an error result that still carries any captured output.
func (a *Agent) handleBash(args map[string]any) (string, bool) {
	command, _ := args["command"].(string)
	if command == "" {
		return "error: command is required", true
	}
	ceiling := a.ToolTimeout()
	if ceiling > maxTimeout {
		ceiling = maxTimeout
	}
	timeout := a.getArgInt(args, "timeout", int(a.ToolTimeout().Seconds()))
	if limit := int(ceiling.Seconds()); timeout > limit {
		timeout = limit
	}
	if timeout < 1 {
		timeout = 1
	}
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
	defer cancel()

	// Killing sh does not kill processes it spawned. If one of them (e.g. a
	// backgrounded job or a pipeline member) keeps stdout/stderr open, Run
	// would block until it exits, possibly forever. WaitDelay bounds that wait.
	const pipeGrace = 5 * time.Second
	cmd := exec.CommandContext(ctx, "sh", "-c", command)
	cmd.Dir = a.workDir
	cmd.Env = os.Environ()
	cmd.WaitDelay = pipeGrace
	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	err := cmd.Run()
	if err == exec.ErrWaitDelay {
		// The shell exited successfully; only a leftover child held the pipes.
		err = nil
	}

	output := stdout.String()
	if stderr.Len() > 0 {
		if output != "" {
			output += "\n"
		}
		output += "STDERR:\n" + stderr.String()
	}
	if ctx.Err() == context.DeadlineExceeded {
		msg := fmt.Sprintf("error: command timed out after %d seconds", timeout)
		if output != "" {
			msg += "\npartial output:\n" + output
		}
		return msg, true
	}
	if err != nil {
		if output == "" {
			return fmt.Sprintf("error: %v", err), true
		}
		return fmt.Sprintf("Command exited with error:\n%s", output), true
	}
	if output == "" {
		return "Command completed successfully (no output)", false
	}
	return output, false
}
// handleLs implements the "ls" tool: it lists the entries of args["path"]
// (resolved against the working directory; default: the working directory
// itself). Output has one entry per line: directories first (with a trailing
// "/") and then everything else, each group in os.ReadDir's name order.
func (a *Agent) handleLs(args map[string]any) (string, bool) {
	// Resolve only a caller-supplied path. Resolving the default would join a
	// relative workDir onto itself ("proj" -> "proj/proj").
	dir := a.workDir
	if p := a.getArgString(args, "path", ""); p != "" {
		dir = a.resolvePath(p)
	}
	entries, err := os.ReadDir(dir)
	if err != nil {
		return fmt.Sprintf("error reading directory: %v", err), true
	}
	if len(entries) == 0 {
		return "Directory is empty", false
	}
	var dirs, files []string
	for _, e := range entries {
		if e.IsDir() {
			dirs = append(dirs, e.Name()+"/")
		} else {
			files = append(files, e.Name())
		}
	}
	var out strings.Builder
	for _, name := range append(dirs, files...) {
		out.WriteString(name)
		out.WriteByte('\n')
	}
	return out.String(), false
}
// handleFind implements the "find" tool: it walks args["path"] (resolved
// against the working directory; default: the working directory itself) and
// lists entries whose base name matches the shell glob args["name"]
// (filepath.Match syntax: *, ?, [...]).
//
// The optional "type" argument restricts results to files ("f") or
// directories ("d"). Directories rejected by shouldSkipDir are pruned and
// hidden files are skipped; the search root itself is never reported.
// Results are relative to the root; directories get a trailing "/".
func (a *Agent) handleFind(args map[string]any) (string, bool) {
	name, _ := args["name"].(string)
	if name == "" {
		return "error: name is required", true
	}
	root := a.workDir
	if p := a.getArgString(args, "path", ""); p != "" {
		root = a.resolvePath(p)
	}
	fileType := a.getArgString(args, "type", "")
	if _, err := os.Stat(root); err != nil {
		return fmt.Sprintf("error: path does not exist: %s", root), true
	}
	// filepath.Match reports ErrBadPattern for a malformed pattern. The old
	// "*" -> ".*" regex rewrite left metacharacters such as "." and "("
	// unescaped.
	if _, err := filepath.Match(name, ""); err != nil {
		return fmt.Sprintf("error: invalid name pattern: %v", err), true
	}
	var results []string
	walkErr := filepath.Walk(root, func(filePath string, info os.FileInfo, visitErr error) error {
		if visitErr != nil {
			return nil // best effort: skip entries that cannot be read
		}
		if filePath == root {
			return nil // descend into the root but never report it
		}
		isDir := info.IsDir()
		if isDir && shouldSkipDir(info.Name()) {
			return filepath.SkipDir
		}
		// The skip-dir list applies to directories only. For files, only
		// hidden names are skipped, so a script named "build" is still found.
		if !isDir && strings.HasPrefix(info.Name(), ".") {
			return nil
		}
		if (fileType == "f" && isDir) || (fileType == "d" && !isDir) {
			return nil
		}
		if ok, _ := filepath.Match(name, info.Name()); !ok {
			return nil
		}
		rel, err := filepath.Rel(root, filePath)
		if err != nil {
			rel = filePath
		}
		if isDir {
			rel += "/"
		}
		results = append(results, rel)
		return nil
	})
	if walkErr != nil {
		return fmt.Sprintf("error walking directory: %v", walkErr), true
	}
	if len(results) == 0 {
		return fmt.Sprintf("No files/directories found matching: %s", name), false
	}
	return strings.Join(results, "\n"), false
}
// getArgString returns args[key] when it is a non-empty string, and
// defaultValue otherwise (missing key, wrong type, or empty string).
func (a *Agent) getArgString(args map[string]any, key, defaultValue string) string {
	s, isString := args[key].(string)
	if !isString || s == "" {
		return defaultValue
	}
	return s
}
// getArgInt reads args[key] as an int. It accepts float64 (JSON numbers,
// truncated toward zero), int, and decimal strings. Any other type, an
// unparsable or empty string, or a missing key yields defaultValue.
func (a *Agent) getArgInt(args map[string]any, key string, defaultValue int) int {
	raw, present := args[key]
	if !present {
		return defaultValue
	}
	switch n := raw.(type) {
	case int:
		return n
	case float64:
		return int(n)
	case string:
		// Atoi rejects "", so empty strings fall through to the default.
		if parsed, err := strconv.Atoi(n); err == nil {
			return parsed
		}
	}
	return defaultValue
}
// resolvePath returns path unchanged when it is absolute. Otherwise it joins
// path onto the agent's working directory (filepath.Join also cleans it).
func (a *Agent) resolvePath(path string) string {
	if !filepath.IsAbs(path) {
		path = filepath.Join(a.workDir, path)
	}
	return path
}
// shouldSkipDir reports whether a directory with the given base name should
// be excluded from recursive walks. That covers any dot-prefixed name (VCS
// metadata, caches, virtualenvs) and a fixed list of dependency/build output
// directories.
func shouldSkipDir(name string) bool {
	if strings.HasPrefix(name, ".") {
		return true // includes .git, .venv, .cache, .npm, .svn, .hg, .bzr
	}
	switch name {
	case "node_modules", "__pycache__", "venv", "dist", "build", "target", "CVS":
		return true
	default:
		return false
	}
}
// handleDiff implements the "diff" tool: it shows how the file at
// args["path"] (resolved against the working directory) would change if its
// contents were replaced by args["new_content"]. Nothing is written.
// Output comes from computeDiff: every line prefixed with ' ', '-' or '+'.
func (a *Agent) handleDiff(args map[string]any) (string, bool) {
	path, _ := args["path"].(string)
	if path == "" {
		return "error: path is required", true
	}
	// Require presence, not non-emptiness: diffing against "" (removing
	// everything) is a legitimate request.
	newContent, ok := args["new_content"].(string)
	if !ok {
		return "error: new_content is required", true
	}
	path = a.resolvePath(path)
	oldContent, err := os.ReadFile(path)
	if err != nil {
		return fmt.Sprintf("error reading file: %v", err), true
	}
	// computeDiff also emits unchanged lines, so its output is never empty
	// and cannot signal "identical". Compare directly.
	if string(oldContent) == newContent {
		return "No changes (files are identical)", false
	}
	oldLines := strings.Split(string(oldContent), "\n")
	newLines := strings.Split(newContent, "\n")
	return computeDiff(oldLines, newLines), false
}
// computeDiff renders a line-oriented diff that transforms oldLines into
// newLines. Every line of both inputs appears exactly once in the output,
// prefixed with ' ' (kept), '-' (only in oldLines) or '+' (only in newLines)
// and terminated by '\n'. No hunk headers are produced.
//
// Lines shared at the start and end of both inputs are matched up front, so
// the O(m*n) LCS table only spans the region that actually differs.
func computeDiff(oldLines, newLines []string) string {
	var b strings.Builder
	emit := func(mark byte, line string) {
		b.WriteByte(mark)
		b.WriteString(line)
		b.WriteByte('\n')
	}

	prefix := 0
	for prefix < len(oldLines) && prefix < len(newLines) && oldLines[prefix] == newLines[prefix] {
		prefix++
	}
	suffix := 0
	for suffix < len(oldLines)-prefix && suffix < len(newLines)-prefix &&
		oldLines[len(oldLines)-1-suffix] == newLines[len(newLines)-1-suffix] {
		suffix++
	}
	oldMid := oldLines[prefix : len(oldLines)-suffix]
	newMid := newLines[prefix : len(newLines)-suffix]

	for _, line := range oldLines[:prefix] {
		emit(' ', line)
	}
	oi, ni := 0, 0
	for _, common := range longestCommonSubsequence(oldMid, newMid) {
		// common is a subsequence element of both slices, so both scans stop
		// in range.
		for oldMid[oi] != common {
			emit('-', oldMid[oi])
			oi++
		}
		for newMid[ni] != common {
			emit('+', newMid[ni])
			ni++
		}
		emit(' ', common)
		oi++
		ni++
	}
	for _, line := range oldMid[oi:] {
		emit('-', line)
	}
	for _, line := range newMid[ni:] {
		emit('+', line)
	}
	for _, line := range oldLines[len(oldLines)-suffix:] {
		emit(' ', line)
	}
	return b.String()
}

// longestCommonSubsequence returns one longest common subsequence of a and b
// using the classic dynamic-programming table (O(len(a)*len(b)) time and
// memory). When several exist, backtracking prefers stepping back in b.
func longestCommonSubsequence(a, b []string) []string {
	m, n := len(a), len(b)
	// dp[i][j] is the LCS length of a[:i] and b[:j].
	dp := make([][]int, m+1)
	for i := range dp {
		dp[i] = make([]int, n+1)
	}
	for i := 1; i <= m; i++ {
		for j := 1; j <= n; j++ {
			switch {
			case a[i-1] == b[j-1]:
				dp[i][j] = dp[i-1][j-1] + 1
			case dp[i-1][j] > dp[i][j-1]:
				dp[i][j] = dp[i-1][j]
			default:
				dp[i][j] = dp[i][j-1]
			}
		}
	}
	// Backtrack from the corner, filling the result from its end. Prepending
	// on every match, as before, was quadratic.
	lcs := make([]string, dp[m][n])
	k := len(lcs)
	for i, j := m, n; i > 0 && j > 0; {
		switch {
		case a[i-1] == b[j-1]:
			k--
			lcs[k] = a[i-1]
			i--
			j--
		case dp[i-1][j] > dp[i][j-1]:
			i--
		default:
			j--
		}
	}
	return lcs
}
// handleEdit implements the "edit" tool. It applies the unified-diff text in
// args["patch"] to the file at args["path"] (resolved against the working
// directory) via applyPatch and writes the result back. The file is left
// untouched if reading or patching fails.
func (a *Agent) handleEdit(args map[string]any) (string, bool) {
	target, _ := args["path"].(string)
	diffText, _ := args["patch"].(string)
	switch {
	case target == "":
		return "error: path is required", true
	case diffText == "":
		return "error: patch is required", true
	}
	target = a.resolvePath(target)
	original, readErr := os.ReadFile(target)
	if readErr != nil {
		return fmt.Sprintf("error reading file: %v", readErr), true
	}
	patched, patchErr := applyPatch(string(original), diffText)
	if patchErr != nil {
		return fmt.Sprintf("error applying patch: %v", patchErr), true
	}
	if writeErr := os.WriteFile(target, []byte(patched), 0644); writeErr != nil {
		return fmt.Sprintf("error writing file: %v", writeErr), true
	}
	return fmt.Sprintf("Applied patch to %s (%d bytes)", target, len(patched)), false
}
// handleMkdir implements the "mkdir" tool: it creates args["path"] (resolved
// against the working directory) together with any missing parents, using
// os.MkdirAll.
func (a *Agent) handleMkdir(args map[string]any) (string, bool) {
	requested, _ := args["path"].(string)
	if requested == "" {
		return "error: path is required", true
	}
	dir := a.resolvePath(requested)
	if mkErr := os.MkdirAll(dir, 0755); mkErr != nil {
		return fmt.Sprintf("error creating directory: %v", mkErr), true
	}
	return fmt.Sprintf("Created directory: %s", dir), false
}
// handleRemove implements the "remove" tool: it deletes args["path"]
// (resolved against the working directory).
//
// Optional arguments:
//   - "recursive": remove a directory together with its contents.
//   - "force": report success when the path does not exist.
//
// Without "recursive", os.Remove is used, so a non-empty directory fails.
// The path is inspected with os.Lstat, so a symbolic link is removed itself.
// A dangling link can now be removed instead of being reported as missing.
func (a *Agent) handleRemove(args map[string]any) (string, bool) {
	path, _ := args["path"].(string)
	if path == "" {
		return "error: path is required", true
	}
	path = a.resolvePath(path)
	recursive := a.getArgBool(args, "recursive", false)
	force := a.getArgBool(args, "force", false)
	info, err := os.Lstat(path)
	if err != nil {
		if os.IsNotExist(err) {
			if force {
				return "Removed (ignored nonexistent)", false
			}
			return fmt.Sprintf("error: path does not exist: %s", path), true
		}
		return fmt.Sprintf("error: %v", err), true
	}
	if info.IsDir() && recursive {
		err = os.RemoveAll(path)
	} else {
		err = os.Remove(path)
	}
	if err != nil {
		return fmt.Sprintf("error removing: %v", err), true
	}
	return fmt.Sprintf("Removed: %s", path), false
}
// handleCopy implements the "copy" tool: it copies the regular file
// args["source"] to args["destination"] (both resolved against the working
// directory), creating the destination's parent directories as needed.
// A newly created destination gets the source's permission bits.
// Directories are rejected.
func (a *Agent) handleCopy(args map[string]any) (string, bool) {
	src, _ := args["source"].(string)
	dst, _ := args["destination"].(string)
	if src == "" || dst == "" {
		return "error: source and destination are required", true
	}
	src, dst = a.resolvePath(src), a.resolvePath(dst)
	srcInfo, err := os.Stat(src)
	switch {
	case err != nil:
		return fmt.Sprintf("error: %v", err), true
	case srcInfo.IsDir():
		return "error: copying directories not supported (use bash with cp -r)", true
	}
	data, err := os.ReadFile(src)
	if err != nil {
		return fmt.Sprintf("error reading source: %v", err), true
	}
	if mkErr := os.MkdirAll(filepath.Dir(dst), 0755); mkErr != nil {
		return fmt.Sprintf("error creating destination directory: %v", mkErr), true
	}
	if writeErr := os.WriteFile(dst, data, srcInfo.Mode()); writeErr != nil {
		return fmt.Sprintf("error writing destination: %v", writeErr), true
	}
	return fmt.Sprintf("Copied: %s -> %s", src, dst), false
}
// handleMove implements the "move" tool: it renames args["source"] to
// args["destination"] (both resolved against the working directory) with
// os.Rename, creating the destination's parent directories as needed.
// NOTE(review): os.Rename may fail across filesystems; no copy fallback is
// attempted.
func (a *Agent) handleMove(args map[string]any) (string, bool) {
	source, _ := args["source"].(string)
	destination, _ := args["destination"].(string)
	if source == "" || destination == "" {
		return "error: source and destination are required", true
	}
	source = a.resolvePath(source)
	destination = a.resolvePath(destination)
	// Check the source first so a failed move does not leave freshly created
	// destination directories behind.
	if _, err := os.Lstat(source); err != nil {
		return fmt.Sprintf("error moving: %v", err), true
	}
	if err := os.MkdirAll(filepath.Dir(destination), 0755); err != nil {
		return fmt.Sprintf("error creating destination directory: %v", err), true
	}
	if err := os.Rename(source, destination); err != nil {
		return fmt.Sprintf("error moving: %v", err), true
	}
	return fmt.Sprintf("Moved: %s -> %s", source, destination), false
}
// handleExists implements the "exists" tool: it reports whether args["path"]
// (resolved against the working directory) exists and, if so, whether it is
// a directory or a file (with its size in bytes). A missing path is a normal
// "false" answer; other stat failures are errors.
func (a *Agent) handleExists(args map[string]any) (string, bool) {
	p, _ := args["path"].(string)
	if p == "" {
		return "error: path is required", true
	}
	p = a.resolvePath(p)
	info, err := os.Stat(p)
	switch {
	case os.IsNotExist(err):
		return fmt.Sprintf("false: %s does not exist", p), false
	case err != nil:
		return fmt.Sprintf("error: %v", err), true
	case info.IsDir():
		return fmt.Sprintf("true: %s (directory)", p), false
	default:
		return fmt.Sprintf("true: %s (file, %d bytes)", p, info.Size()), false
	}
}
// getArgBool reads args[key] as a bool. Besides real booleans it accepts the
// strings understood by strconv.ParseBool ("true", "false", "1", "0", ...),
// mirroring getArgInt's tolerance of stringly-typed model output. A missing
// key, another type, or an unparsable string yields defaultValue.
func (a *Agent) getArgBool(args map[string]any, key string, defaultValue bool) bool {
	switch v := args[key].(type) {
	case bool:
		return v
	case string:
		if b, err := strconv.ParseBool(strings.TrimSpace(v)); err == nil {
			return b
		}
	}
	return defaultValue
}
// applyPatch applies the unified-diff hunks in patch to content and returns
// the patched text. Original lines outside every hunk are copied through
// unchanged.
//
// Anything before the first "@@" header (e.g. "---"/"+++" file headers) is
// ignored. Only each hunk's old start line is used; hunk line counts are not
// enforced, except that a zero count ("-5,0") means "insert after line 5".
// Within a hunk, lines are interpreted by their first character:
//
//	'+'  added to the output
//	'-'  removed; must match the current original line
//	' '  context; must match the current original line and is kept as-is
//	'\'  "\ No newline at end of file" marker; ignored
//
// Matching ignores trailing spaces, tabs and '\r'. An empty patch line counts
// as blank context when the original's current line is blank and is ignored
// otherwise. This tolerates editors and models stripping the leading space
// or leaving a trailing newline on the patch.
//
// An error is returned for a malformed header, a hunk that is out of order or
// starts past the end of content, a '-'/' ' line that does not match, an
// unrecognized line prefix, or a patch without any hunk.
func applyPatch(content, patch string) (string, error) {
	lines := strings.Split(content, "\n")
	patchLines := strings.Split(patch, "\n")
	sameLine := func(x, y string) bool {
		return strings.TrimRight(x, " \t\r") == strings.TrimRight(y, " \t\r")
	}

	result := make([]string, 0, len(lines))
	next := 0 // index of the first original line not yet consumed
	hunks := 0
	for i := 0; i < len(patchLines); {
		header := patchLines[i]
		i++
		if !strings.HasPrefix(header, "@@") {
			continue
		}
		hunks++
		start, err := parseHunkOldStart(header)
		if err != nil {
			return "", err
		}
		if start < next || start > len(lines) {
			return "", fmt.Errorf("hunk %q starts at line %d, which overlaps a previous hunk or lies beyond the end of the file (%d lines)", header, start+1, len(lines))
		}
		// Copy the untouched lines between the previous hunk and this one.
		result = append(result, lines[next:start]...)
		next = start

		for ; i < len(patchLines) && !strings.HasPrefix(patchLines[i], "@@"); i++ {
			pl := patchLines[i]
			switch {
			case pl == "":
				if next < len(lines) && sameLine(lines[next], "") {
					result = append(result, lines[next])
					next++
				}
			case pl[0] == '+':
				result = append(result, pl[1:])
			case pl[0] == '-' || pl[0] == ' ':
				if next >= len(lines) {
					return "", fmt.Errorf("patch line %d (%q) runs past the end of the file", i+1, pl)
				}
				if !sameLine(lines[next], pl[1:]) {
					return "", fmt.Errorf("patch line %d expects %q at file line %d, found %q", i+1, pl[1:], next+1, lines[next])
				}
				if pl[0] == ' ' {
					result = append(result, lines[next])
				}
				next++
			case pl[0] == '\\':
				// "\ No newline at end of file": carries no content.
			default:
				return "", fmt.Errorf("invalid patch line %d: %q", i+1, pl)
			}
		}
	}
	if hunks == 0 {
		return "", fmt.Errorf("patch contains no hunks (expected \"@@ -start,count +start,count @@\" headers)")
	}
	result = append(result, lines[next:]...)
	return strings.Join(result, "\n"), nil
}

// parseHunkOldStart parses a "@@ -start[,count] +start[,count] @@" header and
// returns the 0-based index in the original text at which the hunk applies.
func parseHunkOldStart(header string) (int, error) {
	parts := strings.Fields(header)
	if len(parts) < 4 || !strings.HasPrefix(parts[1], "-") {
		return 0, fmt.Errorf("invalid hunk header: %s", header)
	}
	startStr, countStr, hasCount := strings.Cut(parts[1][1:], ",")
	start, err := strconv.Atoi(startStr)
	if err != nil || start < 0 {
		return 0, fmt.Errorf("invalid hunk header: %s", header)
	}
	count := 1
	if hasCount {
		if count, err = strconv.Atoi(countStr); err != nil || count < 0 {
			return 0, fmt.Errorf("invalid hunk header: %s", header)
		}
	}
	if count == 0 {
		// An empty old range means "insert after line start".
		return start, nil
	}
	if start == 0 {
		return 0, nil
	}
	return start - 1, nil
}