first commit
This commit is contained in:
commit
8dc496b626
31
.gitea/workflows/ci.yml
Normal file
31
.gitea/workflows/ci.yml
Normal file
@ -0,0 +1,31 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.25"
|
||||
cache: true
|
||||
|
||||
- name: Run golangci-lint
|
||||
uses: golangci/golangci-lint-action@v6
|
||||
with:
|
||||
version: latest
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Run tests
|
||||
run: go test ./... -count=1 -race
|
||||
37
.gitea/workflows/release.yml
Normal file
37
.gitea/workflows/release.yml
Normal file
@ -0,0 +1,37 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.25"
|
||||
cache: true
|
||||
|
||||
- name: Run tests
|
||||
run: go test ./... -count=1
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
distribution: goreleaser
|
||||
version: latest
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GIT_TOKEN }}
|
||||
HOMEBREW_TAP_TOKEN: ${{ secrets.HOMEBREW_TAP_TOKEN }}
|
||||
32
.gitignore
vendored
Normal file
32
.gitignore
vendored
Normal file
@ -0,0 +1,32 @@
|
||||
# Binaries
|
||||
bin/
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
*.test
|
||||
*.out
|
||||
ai-agent
|
||||
|
||||
# Vector embeddings
|
||||
.vecgrep/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
|
||||
# Temporary files
|
||||
tmp/
|
||||
temp/
|
||||
*.tmp
|
||||
71
.goreleaser.yaml
Normal file
71
.goreleaser.yaml
Normal file
@ -0,0 +1,71 @@
|
||||
version: 2
|
||||
|
||||
project_name: ai-agent
|
||||
|
||||
before:
|
||||
hooks:
|
||||
- go mod tidy
|
||||
|
||||
builds:
|
||||
- id: ai-agent
|
||||
main: ./cmd/ai-agent
|
||||
binary: ai-agent
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
ldflags:
|
||||
- -s -w -X main.version={{.Version}}
|
||||
|
||||
archives:
|
||||
- id: default
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
format_overrides:
|
||||
- goos: windows
|
||||
formats: [zip]
|
||||
|
||||
checksum:
|
||||
name_template: "checksums.txt"
|
||||
|
||||
snapshot:
|
||||
version_template: "{{ incpatch .Version }}-next"
|
||||
|
||||
changelog:
|
||||
sort: asc
|
||||
filters:
|
||||
exclude:
|
||||
- "^docs:"
|
||||
- "^test:"
|
||||
- "^ci:"
|
||||
- "^chore:"
|
||||
- Merge pull request
|
||||
- Merge branch
|
||||
|
||||
release:
|
||||
github:
|
||||
owner: abdul-hamid-achik
|
||||
name: ai-agent
|
||||
draft: false
|
||||
prerelease: auto
|
||||
name_template: "v{{.Version}}"
|
||||
|
||||
brews:
|
||||
- name: ai-agent
|
||||
homepage: https://github.com/abdul-hamid-achik/ai-agent
|
||||
description: "Local AI agent with TUI, powered by Ollama and MCP servers"
|
||||
license: MIT
|
||||
directory: Formula
|
||||
repository:
|
||||
owner: abdul-hamid-achik
|
||||
name: homebrew-tap
|
||||
token: "{{ .Env.HOMEBREW_TAP_TOKEN }}"
|
||||
install: |
|
||||
bin.install "ai-agent"
|
||||
test: |
|
||||
system "#{bin}/ai-agent", "--version"
|
||||
skip_upload: "{{ if .Env.HOMEBREW_TAP_TOKEN }}false{{ else }}true{{ end }}"
|
||||
375
README.md
Normal file
375
README.md
Normal file
@ -0,0 +1,375 @@
|
||||
# ai-agent
|
||||
|
||||
A fully local AI coding agent for the terminal -- powered by Ollama and small models, with intelligent routing, cross-session memory, and MCP tool integration.
|
||||
|
||||
```
|
||||
╭──────────────────────────────────────────╮
|
||||
│ ai-agent │
|
||||
│ 100% local. Your data never leaves. │
|
||||
│ │
|
||||
│ ASK -- PLAN -- BUILD │
|
||||
│ 0.8B 4B 9B │
|
||||
╰──────────────────────────────────────────╯
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## What is ai-agent?
|
||||
|
||||
- **100% local** -- runs entirely on your machine via Ollama. No API keys, no cloud, no data leaving your device.
|
||||
- **Small model optimized** -- intelligent routing across Qwen 3.5 variants (0.8B / 2B / 4B / 9B) based on task complexity.
|
||||
- **Three operational modes** -- ASK for quick answers, PLAN for design and reasoning, BUILD for full execution with tools.
|
||||
- **MCP native** -- first-class Model Context Protocol support (STDIO, SSE, Streamable HTTP) for extensible tool integration.
|
||||
- **Beautiful TUI** -- built with Charm's BubbleTea v2, Lip Gloss v2, and Glamour for rich markdown rendering in the terminal.
|
||||
- **Infinite Context Engine (ICE)** -- cross-session vector retrieval that surfaces relevant past conversations automatically.
|
||||
- **Auto-Memory Detection** -- the LLM extracts facts, decisions, preferences, and TODOs from conversations and persists them.
|
||||
- **Thinking/CoT extraction** -- chain-of-thought reasoning is captured and displayed in collapsible blocks.
|
||||
- **Skills system** -- load `.md` skill files with YAML frontmatter to inject domain-specific instructions into the system prompt.
|
||||
- **Agent profiles** -- configure per-project agents with custom system prompts, skills, and MCP servers.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- [Go 1.25+](https://go.dev/dl/)
|
||||
- [Ollama](https://ollama.ai/) running locally
|
||||
- [Task](https://taskfile.dev/) (optional, for build commands)
|
||||
|
||||
### Install
|
||||
|
||||
Pull the required model, then install:
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.5:2b
|
||||
|
||||
go install github.com/abdul-hamid-achik/ai-agent/cmd/ai-agent@latest
|
||||
```
|
||||
|
||||
For the full model routing suite (optional):
|
||||
|
||||
```bash
|
||||
ollama pull qwen3.5:0.8b
|
||||
ollama pull qwen3.5:4b
|
||||
ollama pull qwen3.5:9b
|
||||
ollama pull nomic-embed-text # for ICE vector embeddings
|
||||
```
|
||||
|
||||
### Configure
|
||||
|
||||
Create a config file (optional -- defaults work out of the box):
|
||||
|
||||
```bash
|
||||
mkdir -p ~/.config/ai-agent
|
||||
cp config.example.yaml ~/.config/ai-agent/config.yaml
|
||||
```
|
||||
|
||||
### Run
|
||||
|
||||
```bash
|
||||
ai-agent
|
||||
```
|
||||
|
||||
Or from source:
|
||||
|
||||
```bash
|
||||
task dev
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Features
|
||||
|
||||
### Model Routing
|
||||
|
||||
ai-agent automatically selects the right model size for the task at hand. Simple questions go to the fast 2B model; complex multi-step reasoning escalates to the 9B model. The router analyzes query complexity using keyword heuristics and word count.
|
||||
|
||||
| Complexity | Model | Speed | Use Cases |
|
||||
|------------|---------------|--------|----------------------------------------------|
|
||||
| Simple | qwen3.5:2b | 2.5x | Quick answers, simple tool use, single edits |
|
||||
| Medium | qwen3.5:4b | 1.5x | Code completion, refactoring, explanations |
|
||||
| Complex | qwen3.5:9b | 1.0x | Multi-step reasoning, debugging, code review |
|
||||
|
||||
The fallback chain ensures graceful degradation if a model is not available: `2b -> 4b -> 9b`.
|
||||
|
||||
### Three Modes: ASK / PLAN / BUILD
|
||||
|
||||
Cycle between modes with `shift+tab`. Each mode configures a different system prompt and preferred model tier.
|
||||
|
||||
- **ASK** -- Direct, concise answers. Routes to the fastest available model. Tools available for file reads and searches.
|
||||
- **PLAN** -- Design and planning. Breaks tasks into steps. Reads and explores with tools but does not modify files.
|
||||
- **BUILD** -- Full execution mode. Uses the most capable model. All tools enabled including writes and modifications.
|
||||
|
||||
### MCP Tool Integration
|
||||
|
||||
Connect any MCP-compatible tool server. Supports all three transport protocols:
|
||||
|
||||
- **STDIO** -- Launch tools as subprocesses (default).
|
||||
- **SSE** -- Connect to Server-Sent Events endpoints.
|
||||
- **Streamable HTTP** -- Connect to HTTP-based MCP servers.
|
||||
|
||||
Tool calls execute in parallel when possible. The registry handles graceful failure if a server becomes unavailable.
|
||||
|
||||
### Infinite Context Engine (ICE)
|
||||
|
||||
ICE embeds each conversation turn using `nomic-embed-text` and stores them persistently. On every new message, it retrieves the most relevant past conversations via cosine similarity and injects them into the system prompt -- giving the agent memory that spans across sessions.
|
||||
|
||||
### Auto-Memory Detection
|
||||
|
||||
After each conversation turn, a background process analyzes the exchange and extracts structured memories:
|
||||
|
||||
- **FACT** -- objective information the user shared
|
||||
- **DECISION** -- choices made during the conversation
|
||||
- **PREFERENCE** -- user preferences and working styles
|
||||
- **TODO** -- action items and follow-ups
|
||||
|
||||
Memories are stored in `~/.config/ai-agent/memories.json` with tag-weighted search scoring (tags weighted 3x over content).
|
||||
|
||||
### Thinking/CoT Display
|
||||
|
||||
When the model produces chain-of-thought reasoning, ai-agent captures it and renders it in collapsible blocks. Toggle the display with `ctrl+t`.
|
||||
|
||||
### Skills System
|
||||
|
||||
Drop `.md` files with YAML frontmatter into the skills directory to inject domain-specific instructions:
|
||||
|
||||
```
|
||||
~/.config/ai-agent/skills/
|
||||
```
|
||||
|
||||
Manage active skills with `/skill list`, `/skill activate <name>`, and `/skill deactivate <name>`.
|
||||
|
||||
### Agent Profiles
|
||||
|
||||
Create per-project or per-domain agent profiles:
|
||||
|
||||
```
|
||||
~/.agents/<name>/
|
||||
AGENT.md # System prompt additions
|
||||
SKILL.md # Agent-specific skills
|
||||
mcp.yaml # Agent-specific MCP servers
|
||||
```
|
||||
|
||||
Switch profiles with `/agent <name>` or `/agent list`.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
### File Locations
|
||||
|
||||
Config is searched in order (first match wins):
|
||||
|
||||
1. `./ai-agent.yaml` (project-local)
|
||||
2. `~/.config/ai-agent/config.yaml` (user-global)
|
||||
|
||||
### Annotated Example
|
||||
|
||||
```yaml
|
||||
ollama:
|
||||
model: "qwen3.5:2b" # Default model
|
||||
base_url: "http://localhost:11434" # Ollama API endpoint
|
||||
num_ctx: 262144 # Context window size
|
||||
|
||||
# Skills directory (default: ~/.config/ai-agent/skills/)
|
||||
# skills_dir: "/path/to/custom/skills"
|
||||
|
||||
# MCP tool servers
|
||||
servers:
|
||||
# STDIO transport (default)
|
||||
- name: noted
|
||||
command: noted
|
||||
args: [mcp]
|
||||
|
||||
# SSE transport
|
||||
# - name: remote-server
|
||||
# transport: sse
|
||||
# url: "http://localhost:8811"
|
||||
|
||||
# Streamable HTTP transport
|
||||
# - name: streamable-server
|
||||
# transport: streamable-http
|
||||
# url: "http://localhost:8812/mcp"
|
||||
|
||||
# ICE configuration
|
||||
# ice:
|
||||
# enabled: true
|
||||
# embed_model: "nomic-embed-text"
|
||||
# store_path: "~/.config/ai-agent/conversations.json"
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Description | Overrides |
|
||||
|--------------------------|------------------------------|----------------------|
|
||||
| `OLLAMA_HOST` | Ollama API base URL | `ollama.base_url` |
|
||||
| `LOCAL_AGENT_MODEL` | Default model name | `ollama.model` |
|
||||
| `LOCAL_AGENT_AGENTS_DIR` | Path to agents directory | `agents.dir` |
|
||||
|
||||
---
|
||||
|
||||
## Keyboard Shortcuts
|
||||
|
||||
### Input
|
||||
|
||||
| Key | Action |
|
||||
|-----------------|-------------------------------|
|
||||
| `enter` | Send message |
|
||||
| `shift+enter` | Insert new line |
|
||||
| `up` / `down` | Browse input history |
|
||||
| `shift+tab` | Cycle mode (ASK/PLAN/BUILD) |
|
||||
| `ctrl+m` | Quick model switch |
|
||||
|
||||
### Navigation
|
||||
|
||||
| Key | Action |
|
||||
|------------------|------------------------------|
|
||||
| `pgup` / `pgdown`| Scroll conversation |
|
||||
| `ctrl+u` | Half-page scroll up |
|
||||
| `ctrl+d` | Half-page scroll down |
|
||||
|
||||
### Display
|
||||
|
||||
| Key | Action |
|
||||
|-----------------|-------------------------------|
|
||||
| `?` | Toggle help overlay |
|
||||
| `t` | Expand/collapse tool calls |
|
||||
| `space` | Toggle last tool details |
|
||||
| `ctrl+t` | Toggle thinking/CoT display |
|
||||
| `ctrl+y` | Copy last response |
|
||||
|
||||
### Control
|
||||
|
||||
| Key | Action |
|
||||
|-----------------|-------------------------------|
|
||||
| `esc` | Cancel streaming / close overlay |
|
||||
| `ctrl+c` | Quit |
|
||||
| `ctrl+l` | Clear screen |
|
||||
| `ctrl+n` | New conversation |
|
||||
|
||||
---
|
||||
|
||||
## Slash Commands
|
||||
|
||||
| Command | Description |
|
||||
|--------------------------------------|-----------------------------------|
|
||||
| `/help` | Show help overlay |
|
||||
| `/clear` | Clear conversation history |
|
||||
| `/new` | Start a fresh conversation |
|
||||
| `/model [name\|list\|fast\|smart]` | Show or switch models |
|
||||
| `/models` | Open model picker |
|
||||
| `/agent [name\|list]` | Show or switch agent profile |
|
||||
| `/load <path>` | Load markdown file as context |
|
||||
| `/unload` | Remove loaded context |
|
||||
| `/skill [list\|activate\|deactivate] [name]` | Manage skills |
|
||||
| `/servers` | List connected MCP servers |
|
||||
| `/ice` | Show ICE engine status |
|
||||
| `/sessions` | Browse saved sessions |
|
||||
| `/exit` | Quit |
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
cmd/ai-agent/ Entry point
|
||||
internal/
|
||||
agent/ ReAct loop orchestration
|
||||
llm/ LLM abstraction (OllamaClient, ModelManager)
|
||||
mcp/ MCP server registry
|
||||
config/ YAML config, env overrides, Router
|
||||
ice/ Infinite Context Engine
|
||||
memory/ Persistent key-value store
|
||||
skill/ Skill file loader
|
||||
command/ Slash command registry
|
||||
tui/ BubbleTea v2 terminal UI
|
||||
logging/ Structured logging
|
||||
```
|
||||
|
||||
### Request Flow
|
||||
|
||||
```
|
||||
User Input
|
||||
|
|
||||
v
|
||||
agent.AddUserMessage()
|
||||
|
|
||||
v
|
||||
ICE embeds message, retrieves relevant past context
|
||||
|
|
||||
v
|
||||
System prompt assembled (tools + skills + context + ICE + memory)
|
||||
|
|
||||
v
|
||||
Router selects model based on task complexity
|
||||
|
|
||||
v
|
||||
LLM streams response via ChatStream()
|
||||
|
|
||||
v
|
||||
Tool calls routed through MCP registry (parallel execution)
|
||||
|
|
||||
v
|
||||
ReAct loop continues (up to 10 iterations) until final text
|
||||
|
|
||||
v
|
||||
Conversation compacted if token budget exceeded
|
||||
Auto-memory detection runs in background
|
||||
```
|
||||
|
||||
### Key Interfaces
|
||||
|
||||
- `llm.Client` -- pluggable LLM provider (`ChatStream`, `Ping`, `Embed`)
|
||||
- `agent.Output` -- streaming callbacks for TUI rendering
|
||||
- `command.Registry` -- extensible slash command dispatch
|
||||
|
||||
### Concurrency
|
||||
|
||||
`sync.RWMutex` protects shared state in `ModelManager`, `mcp.Registry`, and `memory.Store`. Auto-memory detection and MCP connections run as background goroutines. Tool calls execute in parallel when independent.
|
||||
|
||||
---
|
||||
|
||||
## Comparison
|
||||
|
||||
| Feature | ai-agent | opencode | crush |
|
||||
|----------------------------------|:-----------:|:--------:|:-----:|
|
||||
| 100% local (no API keys) | Yes | No | Yes |
|
||||
| Model routing by task complexity | Yes | No | No |
|
||||
| Operational modes (ASK/PLAN/BUILD)| Yes | No | No |
|
||||
| Cross-session memory (ICE) | Yes | No | No |
|
||||
| Auto-memory detection | Yes | No | No |
|
||||
| Thinking/CoT extraction | Yes | Yes | No |
|
||||
| MCP tool support | Yes | Yes | Yes |
|
||||
| Skills system | Yes | No | No |
|
||||
| Plan form overlay | Yes | No | No |
|
||||
| Small model optimized | Yes | No | No |
|
||||
| TUI chat interface | Yes | Yes | Yes |
|
||||
| Language | Go | TypeScript| Go |
|
||||
|
||||
---
|
||||
|
||||
## Building
|
||||
|
||||
This project uses [Task](https://taskfile.dev/) as its build tool.
|
||||
|
||||
```bash
|
||||
task build # Compile to bin/ai-agent
|
||||
task run # Build and run
|
||||
task dev # Quick run via go run ./cmd/ai-agent
|
||||
task test # Run all tests: go test ./...
|
||||
task lint # Run golangci-lint run ./...
|
||||
task clean # Remove bin/ directory
|
||||
```
|
||||
|
||||
Run a single test:
|
||||
|
||||
```bash
|
||||
go test ./internal/agent/ -run TestFunctionName
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
33
Taskfile.yml
Normal file
33
Taskfile.yml
Normal file
@ -0,0 +1,33 @@
|
||||
version: '3'
|
||||
|
||||
tasks:
|
||||
build:
|
||||
desc: Build the binary
|
||||
cmds:
|
||||
- go build -o bin/ai-agent ./cmd/ai-agent
|
||||
|
||||
run:
|
||||
desc: Build and run
|
||||
deps: [build]
|
||||
cmds:
|
||||
- ./bin/ai-agent
|
||||
|
||||
dev:
|
||||
desc: Run with go run
|
||||
cmds:
|
||||
- go run ./cmd/ai-agent
|
||||
|
||||
lint:
|
||||
desc: Run linter
|
||||
cmds:
|
||||
- golangci-lint run ./...
|
||||
|
||||
test:
|
||||
desc: Run tests
|
||||
cmds:
|
||||
- go test ./...
|
||||
|
||||
clean:
|
||||
desc: Clean build artifacts
|
||||
cmds:
|
||||
- rm -rf bin/
|
||||
70
config.example.yaml
Normal file
70
config.example.yaml
Normal file
@ -0,0 +1,70 @@
|
||||
# ai-agent configuration
|
||||
# Place at ~/.config/ai-agent/config.yaml or ./ai-agent.yaml
|
||||
|
||||
ollama:
|
||||
model: "qwen3.5:2b"
|
||||
base_url: "http://localhost:11434"
|
||||
# num_ctx: контекст в токенах. Большие значения (например 262144) требуют много RAM/VRAM;
|
||||
# при ошибке "requires more system memory" уменьшите до 32768 или 8192.
|
||||
num_ctx: 262144
|
||||
|
||||
# Model routing suite (optional - for automatic model tier selection)
|
||||
# models:
|
||||
# - name: "qwen3.5:0.8b"
|
||||
# size: "0.8B"
|
||||
# capability: "simple"
|
||||
# - name: "qwen3.5:2b"
|
||||
# size: "2B"
|
||||
# capability: "medium"
|
||||
# - name: "qwen3.5:4b"
|
||||
# size: "4B"
|
||||
# capability: "complex"
|
||||
# - name: "qwen3.5:9b"
|
||||
# size: "9B"
|
||||
# capability: "advanced"
|
||||
|
||||
# Embedding model for ICE (Infinite Context Engine)
|
||||
# embed_model: "nomic-embed-text"
|
||||
|
||||
# UI Configuration (optional)
|
||||
# ui:
|
||||
# # Theme: "dark", "light", or "auto" (default: auto - detects system theme)
|
||||
# # The Nord theme will be applied automatically:
|
||||
# # - Dark mode: Nord Polar Night (dark blues) + Frost (light text) + Aurora (colorful accents)
|
||||
# # - Light mode: Nord Aurora Light (white background + same colorful accents)
|
||||
# theme: "auto"
|
||||
# # Syntax highlighting theme for code blocks (via Chroma)
|
||||
# # Options: monokai, github, dracula, one-dark, solarized-dark, etc.
|
||||
# # Default: monokai (dark), github (light)
|
||||
# code_theme: "monokai"
|
||||
# # Show line numbers in code blocks (default: false)
|
||||
# code_line_numbers: false
|
||||
# # Compact mode threshold (terminal width < this value triggers compact mode)
|
||||
# compact_threshold: 80
|
||||
|
||||
# MCP tool servers
|
||||
servers:
|
||||
# STDIO transport (default)
|
||||
- name: noted
|
||||
command: noted
|
||||
args: [mcp]
|
||||
|
||||
# tinyvault — requires `tvault init` before first use
|
||||
# - name: tinyvault
|
||||
# command: tvault
|
||||
# args: [mcp-server]
|
||||
|
||||
# Docker MCP Gateway (requires Docker Desktop with MCP enabled)
|
||||
# - name: docker-gateway
|
||||
# command: docker
|
||||
# args: ["mcp", "gateway", "run"]
|
||||
|
||||
# SSE transport
|
||||
# - name: remote-server
|
||||
# transport: sse
|
||||
# url: "http://localhost:8811"
|
||||
|
||||
# Streamable HTTP transport
|
||||
# - name: streamable-server
|
||||
# transport: streamable-http
|
||||
# url: "http://localhost:8812/mcp"
|
||||
65
config.yaml
Normal file
65
config.yaml
Normal file
@ -0,0 +1,65 @@
|
||||
ollama:
|
||||
model: "qwen3.5:27b"
|
||||
base_url: "http://10.2.18.188:11434"
|
||||
# 262144 требует ~22.6 GiB; при 20.4 GiB доступно — уменьшаем контекст (32768 укладывается в память)
|
||||
num_ctx: 32768
|
||||
|
||||
# Model routing suite (optional - for automatic model tier selection)
|
||||
# models:
|
||||
# - name: "qwen3.5:0.8b"
|
||||
# size: "0.8B"
|
||||
# capability: "simple"
|
||||
# - name: "qwen3.5:2b"
|
||||
# size: "2B"
|
||||
# capability: "medium"
|
||||
# - name: "qwen3.5:4b"
|
||||
# size: "4B"
|
||||
# capability: "complex"
|
||||
# - name: "qwen3.5:9b"
|
||||
# size: "9B"
|
||||
# capability: "advanced"
|
||||
|
||||
# Embedding model for ICE (Infinite Context Engine)
|
||||
# embed_model: "nomic-embed-text"
|
||||
|
||||
# UI Configuration (optional)
|
||||
# ui:
|
||||
# # Theme: "dark", "light", or "auto" (default: auto - detects system theme)
|
||||
# # The Nord theme will be applied automatically:
|
||||
# # - Dark mode: Nord Polar Night (dark blues) + Frost (light text) + Aurora (colorful accents)
|
||||
# # - Light mode: Nord Aurora Light (white background + same colorful accents)
|
||||
# theme: "auto"
|
||||
# # Syntax highlighting theme for code blocks (via Chroma)
|
||||
# # Options: monokai, github, dracula, one-dark, solarized-dark, etc.
|
||||
# # Default: monokai (dark), github (light)
|
||||
# code_theme: "monokai"
|
||||
# # Show line numbers in code blocks (default: false)
|
||||
# code_line_numbers: false
|
||||
# # Compact mode threshold (terminal width < this value triggers compact mode)
|
||||
# compact_threshold: 80
|
||||
|
||||
# MCP tool servers
|
||||
servers:
|
||||
- name: noted
|
||||
command: noted
|
||||
args: [mcp]
|
||||
|
||||
# tinyvault — requires `tvault init` before first use
|
||||
# - name: tinyvault
|
||||
# command: tvault
|
||||
# args: [mcp-server]
|
||||
|
||||
# Docker MCP Gateway (requires Docker Desktop with MCP enabled)
|
||||
# - name: docker-gateway
|
||||
# command: docker
|
||||
# args: ["mcp", "gateway", "run"]
|
||||
|
||||
# SSE transport
|
||||
# - name: remote-server
|
||||
# transport: sse
|
||||
# url: "http://localhost:8811"
|
||||
|
||||
# Streamable HTTP transport
|
||||
# - name: streamable-server
|
||||
# transport: streamable-http
|
||||
# url: "http://localhost:8812/mcp"
|
||||
72
go.mod
Normal file
72
go.mod
Normal file
@ -0,0 +1,72 @@
|
||||
module ai-agent
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.0.0
|
||||
charm.land/bubbletea/v2 v2.0.1
|
||||
charm.land/lipgloss/v2 v2.0.0
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/charmbracelet/glamour v0.10.0
|
||||
github.com/charmbracelet/log v0.4.2
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.1
|
||||
github.com/ollama/ollama v0.17.4
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/alecthomas/chroma/v2 v2.14.0 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.2 // indirect
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
github.com/charmbracelet/x/windows v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.11.0 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.4 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.20 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/sahilm/fuzzy v0.1.1 // indirect
|
||||
github.com/segmentio/asm v1.1.3 // indirect
|
||||
github.com/segmentio/encoding v0.5.3 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.7.8 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.5 // indirect
|
||||
golang.org/x/crypto v0.43.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/net v0.46.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/term v0.36.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.46.1 // indirect
|
||||
)
|
||||
164
go.sum
Normal file
164
go.sum
Normal file
@ -0,0 +1,164 @@
|
||||
charm.land/bubbles/v2 v2.0.0 h1:tE3eK/pHjmtrDiRdoC9uGNLgpopOd8fjhEe31B/ai5s=
|
||||
charm.land/bubbles/v2 v2.0.0/go.mod h1:rCHoleP2XhU8um45NTuOWBPNVHxnkXKTiZqcclL/qOI=
|
||||
charm.land/bubbletea/v2 v2.0.0 h1:p0d6CtWyJXJ9GfzMpUUqbP/XUUhhlk06+vCKWmox1wQ=
|
||||
charm.land/bubbletea/v2 v2.0.0/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
|
||||
charm.land/bubbletea/v2 v2.0.1 h1:B8e9zzK7x9JJ+XvHGF4xnYu9Xa0E0y0MyggY6dbaCfQ=
|
||||
charm.land/bubbletea/v2 v2.0.1/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
|
||||
charm.land/lipgloss/v2 v2.0.0 h1:sd8N/B3x892oiOjFfBQdXBQp3cAkvjGaU5TvVZC3ivo=
|
||||
charm.land/lipgloss/v2 v2.0.0/go.mod h1:w6SnmsBFBmEFBodiEDurGS/sdUY/u1+v72DqUzc6J14=
|
||||
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
|
||||
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||
github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE=
|
||||
github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
|
||||
github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E=
|
||||
github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I=
|
||||
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
|
||||
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.4.0 h1:TKnLPh7IbnizJIBKFWa9mKayRUBQ9Kh1BPCk6w2PnYM=
|
||||
github.com/aymanbagabas/go-udiff v0.4.0/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/charmbracelet/colorprofile v0.4.2 h1:BdSNuMjRbotnxHSfxy+PCSa4xAmz7szw70ktAtWRYrY=
|
||||
github.com/charmbracelet/colorprofile v0.4.2/go.mod h1:0rTi81QpwDElInthtrQ6Ni7cG0sDtwAd4C4le060fT8=
|
||||
github.com/charmbracelet/glamour v0.10.0 h1:MtZvfwsYCx8jEPFJm3rIBFIMZUfUJ765oX8V6kXldcY=
|
||||
github.com/charmbracelet/glamour v0.10.0/go.mod h1:f+uf+I/ChNmqo087elLnVdCiVgjSKWuXa/l6NU2ndYk=
|
||||
github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ=
|
||||
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA=
|
||||
github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig=
|
||||
github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8 h1:eyFRbAmexyt43hVfeyBofiGSEmJ7krjLOYt/9CF5NKA=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8/go.mod h1:SQpCTRNBtzJkwku5ye4S3HEuthAlGy2n9VXZnWkEW98=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf h1:rLG0Yb6MQSDKdB52aGX55JT1oi0P0Kuaj7wi1bLUpnI=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||
github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY=
|
||||
github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo=
|
||||
github.com/charmbracelet/x/windows v0.2.2 h1:IofanmuvaxnKHuV04sC0eBy/smG6kIKrWG2/jYn2GuM=
|
||||
github.com/charmbracelet/x/windows v0.2.2/go.mod h1:/8XtdKZzedat74NQFn0NGlGL4soHB0YQZrETF96h75k=
|
||||
github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8=
|
||||
github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0=
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
|
||||
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
|
||||
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/mattn/go-runewidth v0.0.20 h1:WcT52H91ZUAwy8+HUkdM3THM6gXqXuLJi9O3rjcQQaQ=
|
||||
github.com/mattn/go-runewidth v0.0.20/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.1 h1:TfqtNKOIWN4Z1oqmPAiWDC2Jq7K9OdJaooe0teoXASI=
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.1/go.mod h1:DgVX498dMD8UJlseK1S5i1T4tFz2fkBk4xogC3D15nw=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/ollama/ollama v0.17.4 h1:X3KNm9x4BlHqk/AXMGtC7pBAFd46nmJmy8yDaBZLo9s=
|
||||
github.com/ollama/ollama v0.17.4/go.mod h1:tCX4IMV8DHjl3zY0THxuEkpWDZSOchJpzTuLACpMwFw=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
|
||||
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
|
||||
github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc=
|
||||
github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg=
|
||||
github.com/segmentio/encoding v0.5.3 h1:OjMgICtcSFuNvQCdwqMCv9Tg7lEOXGwm1J5RPQccx6w=
|
||||
github.com/segmentio/encoding v0.5.3/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic=
|
||||
github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC4cIk=
|
||||
github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
|
||||
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU=
|
||||
modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
|
||||
194
internal/agent/agent.go
Normal file
194
internal/agent/agent.go
Normal file
@ -0,0 +1,194 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ai-agent/internal/config"
|
||||
"ai-agent/internal/ice"
|
||||
"ai-agent/internal/llm"
|
||||
"ai-agent/internal/mcp"
|
||||
"ai-agent/internal/memory"
|
||||
"ai-agent/internal/permission"
|
||||
)
|
||||
|
||||
type Agent struct {
|
||||
mu sync.RWMutex
|
||||
llmClient llm.Client
|
||||
registry *mcp.Registry
|
||||
messages []llm.Message
|
||||
skillContent string
|
||||
loadedCtx string
|
||||
numCtx int
|
||||
memoryStore *memory.Store
|
||||
iceEngine *ice.Engine
|
||||
router *config.Router
|
||||
modePrefix string
|
||||
toolsEnabled bool
|
||||
workDir string
|
||||
ignoreContent string
|
||||
permChecker *permission.Checker
|
||||
approvalCallback func(permission.ApprovalRequest)
|
||||
toolsConfig config.ToolsConfig
|
||||
}
|
||||
|
||||
func New(llmClient llm.Client, registry *mcp.Registry, numCtx int) *Agent {
|
||||
return &Agent{
|
||||
llmClient: llmClient,
|
||||
registry: registry,
|
||||
numCtx: numCtx,
|
||||
toolsEnabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) SetRouter(router *config.Router) {
|
||||
a.router = router
|
||||
}
|
||||
|
||||
func (a *Agent) SetModeContext(prefix string, allowTools bool) {
|
||||
a.modePrefix = prefix
|
||||
a.toolsEnabled = allowTools
|
||||
}
|
||||
|
||||
func (a *Agent) AppendLoadedContext(content string) {
|
||||
if a.loadedCtx == "" {
|
||||
a.loadedCtx = content
|
||||
} else {
|
||||
a.loadedCtx += content
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) Router() *config.Router {
|
||||
return a.router
|
||||
}
|
||||
|
||||
func (a *Agent) NumCtx() int {
|
||||
return a.numCtx
|
||||
}
|
||||
|
||||
func (a *Agent) SetMemoryStore(store *memory.Store) {
|
||||
a.memoryStore = store
|
||||
}
|
||||
|
||||
func (a *Agent) AddUserMessage(content string) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.messages = append(a.messages, llm.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Agent) Messages() []llm.Message {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
return a.messages
|
||||
}
|
||||
|
||||
func (a *Agent) ClearHistory() {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.messages = nil
|
||||
}
|
||||
|
||||
func (a *Agent) AppendMessage(msg llm.Message) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.messages = append(a.messages, msg)
|
||||
}
|
||||
|
||||
func (a *Agent) ReplaceMessages(msgs []llm.Message) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.messages = msgs
|
||||
}
|
||||
|
||||
func (a *Agent) SetSkillContent(content string) {
|
||||
a.skillContent = content
|
||||
}
|
||||
|
||||
func (a *Agent) SetLoadedContext(content string) {
|
||||
a.loadedCtx = content
|
||||
}
|
||||
|
||||
func (a *Agent) Model() string {
|
||||
return a.llmClient.Model()
|
||||
}
|
||||
|
||||
func (a *Agent) LLMClient() llm.Client {
|
||||
return a.llmClient
|
||||
}
|
||||
|
||||
func (a *Agent) ToolCount() int {
|
||||
count := a.registry.ToolCount()
|
||||
if a.memoryStore != nil {
|
||||
count += 2
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (a *Agent) ServerCount() int {
|
||||
return a.registry.ServerCount()
|
||||
}
|
||||
|
||||
func (a *Agent) ServerNames() []string {
|
||||
return a.registry.ServerNames()
|
||||
}
|
||||
|
||||
func (a *Agent) SetWorkDir(dir string) {
|
||||
a.workDir = dir
|
||||
}
|
||||
|
||||
func (a *Agent) SetIgnoreContent(content string) {
|
||||
a.ignoreContent = content
|
||||
}
|
||||
|
||||
func (a *Agent) SetPermissionChecker(checker *permission.Checker) {
|
||||
a.permChecker = checker
|
||||
}
|
||||
|
||||
func (a *Agent) SetApprovalCallback(cb func(permission.ApprovalRequest)) {
|
||||
a.approvalCallback = cb
|
||||
}
|
||||
|
||||
func (a *Agent) SetICEEngine(engine *ice.Engine) {
|
||||
a.iceEngine = engine
|
||||
}
|
||||
|
||||
func (a *Agent) ICEEngine() *ice.Engine {
|
||||
return a.iceEngine
|
||||
}
|
||||
|
||||
func (a *Agent) SetToolsConfig(cfg config.ToolsConfig) {
|
||||
a.toolsConfig = cfg
|
||||
}
|
||||
|
||||
func (a *Agent) MaxIterations() int {
|
||||
if a.toolsConfig.MaxIterations > 0 {
|
||||
return a.toolsConfig.MaxIterations
|
||||
}
|
||||
return 10
|
||||
}
|
||||
|
||||
func (a *Agent) ToolTimeout() time.Duration {
|
||||
if a.toolsConfig.Timeout != "" {
|
||||
if d, err := time.ParseDuration(a.toolsConfig.Timeout); err == nil {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return 30 * time.Second
|
||||
}
|
||||
|
||||
func (a *Agent) MaxGrepResults() int {
|
||||
if a.toolsConfig.MaxGrepResults > 0 {
|
||||
return a.toolsConfig.MaxGrepResults
|
||||
}
|
||||
return 500
|
||||
}
|
||||
|
||||
func (a *Agent) Close() {
|
||||
if a.iceEngine != nil {
|
||||
_ = a.iceEngine.Flush()
|
||||
}
|
||||
a.registry.Close()
|
||||
}
|
||||
95
internal/agent/compact.go
Normal file
95
internal/agent/compact.go
Normal file
@ -0,0 +1,95 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
)
|
||||
|
||||
const compactThreshold = 0.75
|
||||
const keepMessages = 4
|
||||
|
||||
func (a *Agent) shouldCompact(promptTokens int) bool {
|
||||
if a.numCtx <= 0 || promptTokens <= 0 {
|
||||
return false
|
||||
}
|
||||
return float64(promptTokens) > float64(a.numCtx)*compactThreshold
|
||||
}
|
||||
|
||||
func (a *Agent) compact(ctx context.Context, out Output) bool {
|
||||
a.mu.RLock()
|
||||
msgCount := len(a.messages)
|
||||
a.mu.RUnlock()
|
||||
if msgCount <= keepMessages+1 {
|
||||
return false
|
||||
}
|
||||
a.mu.RLock()
|
||||
splitAt := msgCount - keepMessages
|
||||
older := make([]llm.Message, splitAt)
|
||||
copy(older, a.messages[:splitAt])
|
||||
recent := make([]llm.Message, keepMessages)
|
||||
copy(recent, a.messages[splitAt:])
|
||||
a.mu.RUnlock()
|
||||
summary := summarizeMessages(older)
|
||||
var summaryBuf strings.Builder
|
||||
err := a.llmClient.ChatStream(ctx, llm.ChatOptions{
|
||||
Messages: []llm.Message{
|
||||
{Role: "user", Content: summary},
|
||||
},
|
||||
System: "You are a conversation summarizer. Produce a concise summary of the conversation so far, capturing all key facts, decisions, tool results, and user requests. Keep it under 500 words. Output only the summary, no preamble.",
|
||||
}, func(chunk llm.StreamChunk) error {
|
||||
if chunk.Text != "" {
|
||||
summaryBuf.WriteString(chunk.Text)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
out.Error(fmt.Sprintf("compaction failed: %v", err))
|
||||
return false
|
||||
}
|
||||
summaryText := summaryBuf.String()
|
||||
if summaryText == "" {
|
||||
return false
|
||||
}
|
||||
if a.iceEngine != nil {
|
||||
if err := a.iceEngine.IndexSummary(ctx, summaryText); err != nil {
|
||||
out.Error(fmt.Sprintf("ICE summary indexing failed: %v", err))
|
||||
}
|
||||
}
|
||||
compacted := make([]llm.Message, 0, 1+len(recent))
|
||||
compacted = append(compacted, llm.Message{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("[Conversation summary: %s]", summaryText),
|
||||
})
|
||||
compacted = append(compacted, recent...)
|
||||
a.ReplaceMessages(compacted)
|
||||
out.SystemMessage(fmt.Sprintf("Context compacted: %d messages summarized, %d kept", len(older), len(recent)))
|
||||
return true
|
||||
}
|
||||
|
||||
func summarizeMessages(msgs []llm.Message) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("Summarize this conversation:\n\n")
|
||||
for _, msg := range msgs {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
fmt.Fprintf(&b, "User: %s\n", msg.Content)
|
||||
case "assistant":
|
||||
if msg.Content != "" {
|
||||
fmt.Fprintf(&b, "Assistant: %s\n", msg.Content)
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
fmt.Fprintf(&b, "Assistant called tool %s(%s)\n", tc.Name, FormatToolArgs(tc.Arguments))
|
||||
}
|
||||
case "tool":
|
||||
content := msg.Content
|
||||
if len(content) > 300 {
|
||||
content = content[:297] + "..."
|
||||
}
|
||||
fmt.Fprintf(&b, "Tool %s result: %s\n", msg.ToolName, content)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
143
internal/agent/compact_test.go
Normal file
143
internal/agent/compact_test.go
Normal file
@ -0,0 +1,143 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
"ai-agent/internal/mcp"
|
||||
)
|
||||
|
||||
type mockOutput struct {
|
||||
texts []string
|
||||
errors []string
|
||||
sysMsgs []string
|
||||
}
|
||||
|
||||
func (m *mockOutput) StreamText(text string) {
|
||||
m.texts = append(m.texts, text)
|
||||
}
|
||||
|
||||
func (m *mockOutput) StreamDone(_, _ int) {}
|
||||
|
||||
func (m *mockOutput) ToolCallStart(_ string, _ map[string]any) {}
|
||||
|
||||
func (m *mockOutput) ToolCallResult(_ string, _ string, _ bool, _ time.Duration) {}
|
||||
|
||||
func (m *mockOutput) SystemMessage(msg string) {
|
||||
m.sysMsgs = append(m.sysMsgs, msg)
|
||||
}
|
||||
|
||||
func (m *mockOutput) Error(msg string) {
|
||||
m.errors = append(m.errors, msg)
|
||||
}
|
||||
|
||||
func TestShouldCompact(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numCtx int
|
||||
promptTokens int
|
||||
want bool
|
||||
}{
|
||||
{"below 75%", 1000, 749, false},
|
||||
{"above 75%", 1000, 751, true},
|
||||
{"exactly 75% (strict >)", 1000, 750, false},
|
||||
{"numCtx zero", 0, 500, false},
|
||||
{"promptTokens zero", 1000, 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ag := &Agent{
|
||||
numCtx: tt.numCtx,
|
||||
registry: mcp.NewRegistry(),
|
||||
}
|
||||
got := ag.shouldCompact(tt.promptTokens)
|
||||
if got != tt.want {
|
||||
t.Errorf("shouldCompact(%d) with numCtx=%d = %v, want %v",
|
||||
tt.promptTokens, tt.numCtx, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeMessages(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msgs []llm.Message
|
||||
contains []string
|
||||
}{
|
||||
{
|
||||
name: "user message",
|
||||
msgs: []llm.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
},
|
||||
contains: []string{"User: hello"},
|
||||
},
|
||||
{
|
||||
name: "assistant message",
|
||||
msgs: []llm.Message{
|
||||
{Role: "assistant", Content: "hi there"},
|
||||
},
|
||||
contains: []string{"Assistant: hi there"},
|
||||
},
|
||||
{
|
||||
name: "tool message",
|
||||
msgs: []llm.Message{
|
||||
{Role: "tool", Content: "result data", ToolName: "read_file"},
|
||||
},
|
||||
contains: []string{"Tool read_file result: result data"},
|
||||
},
|
||||
{
|
||||
name: "tool content truncation at 300 chars",
|
||||
msgs: []llm.Message{
|
||||
{Role: "tool", Content: strings.Repeat("x", 400), ToolName: "big_tool"},
|
||||
},
|
||||
contains: []string{"Tool big_tool result: " + strings.Repeat("x", 297) + "..."},
|
||||
},
|
||||
{
|
||||
name: "empty slice",
|
||||
msgs: []llm.Message{},
|
||||
contains: []string{"Summarize this conversation:"},
|
||||
},
|
||||
{
|
||||
name: "assistant with tool calls",
|
||||
msgs: []llm.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []llm.ToolCall{
|
||||
{Name: "search", Arguments: map[string]any{"q": "test"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
contains: []string{"Assistant called tool search("},
|
||||
},
|
||||
{
|
||||
name: "mixed messages",
|
||||
msgs: []llm.Message{
|
||||
{Role: "user", Content: "find files"},
|
||||
{Role: "assistant", Content: "", ToolCalls: []llm.ToolCall{
|
||||
{Name: "glob", Arguments: map[string]any{"pattern": "*.go"}},
|
||||
}},
|
||||
{Role: "tool", Content: "file1.go\nfile2.go", ToolName: "glob"},
|
||||
{Role: "assistant", Content: "Found 2 files"},
|
||||
},
|
||||
contains: []string{
|
||||
"User: find files",
|
||||
"Assistant called tool glob(",
|
||||
"Tool glob result:",
|
||||
"Assistant: Found 2 files",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := summarizeMessages(tt.msgs)
|
||||
for _, want := range tt.contains {
|
||||
if !strings.Contains(result, want) {
|
||||
t.Errorf("summarizeMessages() missing %q in:\n%s", want, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
71
internal/agent/headless_output.go
Normal file
71
internal/agent/headless_output.go
Normal file
@ -0,0 +1,71 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HeadlessOutput implements the Output interface for non-interactive / pipe mode.
|
||||
// Text is written to stdout; tool calls, system messages, and errors go to stderr.
|
||||
type HeadlessOutput struct {
|
||||
stdout io.Writer
|
||||
stderr io.Writer
|
||||
}
|
||||
|
||||
// NewHeadlessOutput creates a HeadlessOutput that writes text to os.Stdout
|
||||
// and diagnostics to os.Stderr.
|
||||
func NewHeadlessOutput() *HeadlessOutput {
|
||||
return &HeadlessOutput{
|
||||
stdout: os.Stdout,
|
||||
stderr: os.Stderr,
|
||||
}
|
||||
}
|
||||
|
||||
// newHeadlessOutput creates a HeadlessOutput with custom writers (for testing).
|
||||
func newHeadlessOutput(stdout, stderr io.Writer) *HeadlessOutput {
|
||||
return &HeadlessOutput{
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
}
|
||||
}
|
||||
|
||||
// StreamText writes incremental text content to stdout.
|
||||
func (h *HeadlessOutput) StreamText(text string) {
|
||||
fmt.Fprint(h.stdout, text)
|
||||
}
|
||||
|
||||
// StreamDone writes a trailing newline to ensure output is terminated.
|
||||
func (h *HeadlessOutput) StreamDone(evalCount, promptTokens int) {
|
||||
fmt.Fprintln(h.stdout)
|
||||
}
|
||||
|
||||
// ToolCallStart writes a brief tool invocation notice to stderr.
|
||||
func (h *HeadlessOutput) ToolCallStart(name string, args map[string]any) {
|
||||
fmt.Fprintf(h.stderr, "→ %s %s\n", name, FormatToolArgs(args))
|
||||
}
|
||||
|
||||
// ToolCallResult writes the tool result summary to stderr.
|
||||
func (h *HeadlessOutput) ToolCallResult(name string, result string, isError bool, duration time.Duration) {
|
||||
status := "ok"
|
||||
if isError {
|
||||
status = "ERROR"
|
||||
}
|
||||
// Truncate long results for stderr display.
|
||||
display := result
|
||||
if len(display) > 200 {
|
||||
display = display[:197] + "..."
|
||||
}
|
||||
fmt.Fprintf(h.stderr, "← %s [%s %s] %s\n", name, status, duration.Round(time.Millisecond), display)
|
||||
}
|
||||
|
||||
// SystemMessage writes a system message to stderr.
|
||||
func (h *HeadlessOutput) SystemMessage(msg string) {
|
||||
fmt.Fprintf(h.stderr, "[system] %s\n", msg)
|
||||
}
|
||||
|
||||
// Error writes an error message to stderr.
|
||||
func (h *HeadlessOutput) Error(msg string) {
|
||||
fmt.Fprintf(h.stderr, "[error] %s\n", msg)
|
||||
}
|
||||
154
internal/agent/headless_output_test.go
Normal file
154
internal/agent/headless_output_test.go
Normal file
@ -0,0 +1,154 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Verify HeadlessOutput satisfies the Output interface at compile time.
|
||||
var _ Output = (*HeadlessOutput)(nil)
|
||||
|
||||
func TestHeadlessOutput_StreamText(t *testing.T) {
|
||||
var stdout, stderr bytes.Buffer
|
||||
out := newHeadlessOutput(&stdout, &stderr)
|
||||
|
||||
out.StreamText("hello ")
|
||||
out.StreamText("world")
|
||||
|
||||
if got := stdout.String(); got != "hello world" {
|
||||
t.Errorf("StreamText: stdout = %q, want %q", got, "hello world")
|
||||
}
|
||||
if stderr.Len() != 0 {
|
||||
t.Errorf("StreamText: unexpected stderr output: %q", stderr.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessOutput_StreamDone(t *testing.T) {
|
||||
var stdout, stderr bytes.Buffer
|
||||
out := newHeadlessOutput(&stdout, &stderr)
|
||||
|
||||
out.StreamText("response")
|
||||
out.StreamDone(100, 50)
|
||||
|
||||
if got := stdout.String(); got != "response\n" {
|
||||
t.Errorf("StreamDone: stdout = %q, want %q", got, "response\n")
|
||||
}
|
||||
if stderr.Len() != 0 {
|
||||
t.Errorf("StreamDone: unexpected stderr output: %q", stderr.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessOutput_ToolCallStart(t *testing.T) {
|
||||
var stdout, stderr bytes.Buffer
|
||||
out := newHeadlessOutput(&stdout, &stderr)
|
||||
|
||||
out.ToolCallStart("read_file", map[string]any{"path": "/tmp/test.go"})
|
||||
|
||||
if stdout.Len() != 0 {
|
||||
t.Errorf("ToolCallStart: unexpected stdout output: %q", stdout.String())
|
||||
}
|
||||
got := stderr.String()
|
||||
if !strings.Contains(got, "read_file") {
|
||||
t.Errorf("ToolCallStart: stderr = %q, missing tool name", got)
|
||||
}
|
||||
if !strings.HasPrefix(got, "→ ") {
|
||||
t.Errorf("ToolCallStart: stderr = %q, missing arrow prefix", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessOutput_ToolCallResult(t *testing.T) {
|
||||
var stdout, stderr bytes.Buffer
|
||||
out := newHeadlessOutput(&stdout, &stderr)
|
||||
|
||||
out.ToolCallResult("read_file", "file contents here", false, 150*time.Millisecond)
|
||||
|
||||
if stdout.Len() != 0 {
|
||||
t.Errorf("ToolCallResult: unexpected stdout output: %q", stdout.String())
|
||||
}
|
||||
got := stderr.String()
|
||||
if !strings.Contains(got, "read_file") {
|
||||
t.Errorf("ToolCallResult: stderr = %q, missing tool name", got)
|
||||
}
|
||||
if !strings.Contains(got, "ok") {
|
||||
t.Errorf("ToolCallResult: stderr = %q, missing ok status", got)
|
||||
}
|
||||
if !strings.Contains(got, "file contents here") {
|
||||
t.Errorf("ToolCallResult: stderr = %q, missing result content", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessOutput_ToolCallResult_Error(t *testing.T) {
|
||||
var stdout, stderr bytes.Buffer
|
||||
out := newHeadlessOutput(&stdout, &stderr)
|
||||
|
||||
out.ToolCallResult("write_file", "permission denied", true, 50*time.Millisecond)
|
||||
|
||||
got := stderr.String()
|
||||
if !strings.Contains(got, "ERROR") {
|
||||
t.Errorf("ToolCallResult error: stderr = %q, missing ERROR status", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessOutput_ToolCallResult_LongResult(t *testing.T) {
|
||||
var stdout, stderr bytes.Buffer
|
||||
out := newHeadlessOutput(&stdout, &stderr)
|
||||
|
||||
longResult := strings.Repeat("x", 300)
|
||||
out.ToolCallResult("search", longResult, false, 100*time.Millisecond)
|
||||
|
||||
got := stderr.String()
|
||||
if strings.Contains(got, strings.Repeat("x", 300)) {
|
||||
t.Error("ToolCallResult: long result should be truncated")
|
||||
}
|
||||
if !strings.Contains(got, "...") {
|
||||
t.Error("ToolCallResult: truncated result should end with ...")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessOutput_SystemMessage(t *testing.T) {
|
||||
var stdout, stderr bytes.Buffer
|
||||
out := newHeadlessOutput(&stdout, &stderr)
|
||||
|
||||
out.SystemMessage("compacting conversation")
|
||||
|
||||
if stdout.Len() != 0 {
|
||||
t.Errorf("SystemMessage: unexpected stdout output: %q", stdout.String())
|
||||
}
|
||||
got := stderr.String()
|
||||
if !strings.Contains(got, "[system]") {
|
||||
t.Errorf("SystemMessage: stderr = %q, missing [system] prefix", got)
|
||||
}
|
||||
if !strings.Contains(got, "compacting conversation") {
|
||||
t.Errorf("SystemMessage: stderr = %q, missing message", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadlessOutput_Error(t *testing.T) {
|
||||
var stdout, stderr bytes.Buffer
|
||||
out := newHeadlessOutput(&stdout, &stderr)
|
||||
|
||||
out.Error("something went wrong")
|
||||
|
||||
if stdout.Len() != 0 {
|
||||
t.Errorf("Error: unexpected stdout output: %q", stdout.String())
|
||||
}
|
||||
got := stderr.String()
|
||||
if !strings.Contains(got, "[error]") {
|
||||
t.Errorf("Error: stderr = %q, missing [error] prefix", got)
|
||||
}
|
||||
if !strings.Contains(got, "something went wrong") {
|
||||
t.Errorf("Error: stderr = %q, missing message", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHeadlessOutput(t *testing.T) {
|
||||
out := NewHeadlessOutput()
|
||||
if out == nil {
|
||||
t.Fatal("NewHeadlessOutput returned nil")
|
||||
}
|
||||
if out.stdout == nil || out.stderr == nil {
|
||||
t.Error("NewHeadlessOutput: writers should not be nil")
|
||||
}
|
||||
}
|
||||
277
internal/agent/loop.go
Normal file
277
internal/agent/loop.go
Normal file
@ -0,0 +1,277 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
permissionPkg "ai-agent/internal/permission"
|
||||
)
|
||||
|
||||
func (a *Agent) Run(ctx context.Context, out Output) {
|
||||
var tools []llm.ToolDef
|
||||
if a.toolsEnabled {
|
||||
tools = a.registry.Tools()
|
||||
if a.memoryStore != nil {
|
||||
tools = append(tools, a.memoryBuiltinToolDefs()...)
|
||||
}
|
||||
tools = append(tools, a.toolsBuiltinToolDefs()...)
|
||||
}
|
||||
var iceContext string
|
||||
a.mu.RLock()
|
||||
hasMessages := len(a.messages) > 0
|
||||
var lastMsg llm.Message
|
||||
if hasMessages {
|
||||
lastMsg = a.messages[len(a.messages)-1]
|
||||
}
|
||||
a.mu.RUnlock()
|
||||
if a.iceEngine != nil && hasMessages {
|
||||
if lastMsg.Role == "user" {
|
||||
if err := a.iceEngine.IndexMessage(ctx, "user", lastMsg.Content); err != nil {
|
||||
out.Error(fmt.Sprintf("ICE indexing failed: %v", err))
|
||||
}
|
||||
if assembled, err := a.iceEngine.AssembleContext(ctx, lastMsg.Content); err == nil {
|
||||
iceContext = assembled
|
||||
}
|
||||
}
|
||||
}
|
||||
system := buildSystemPromptForModel(a.modePrefix, tools, a.skillContent, a.loadedCtx, a.memoryStore, iceContext, a.workDir, a.ignoreContent, a.llmClient.Model())
|
||||
const maxRetries = 2
|
||||
var lastPromptTokens int
|
||||
var retryCount int
|
||||
maxIters := a.MaxIterations()
|
||||
for i := 0; i < maxIters; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
var textBuf strings.Builder
|
||||
var toolCalls []llm.ToolCall
|
||||
err := a.llmClient.ChatStream(ctx, llm.ChatOptions{
|
||||
Messages: a.messages,
|
||||
Tools: tools,
|
||||
System: system,
|
||||
}, func(chunk llm.StreamChunk) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if chunk.Text != "" {
|
||||
textBuf.WriteString(chunk.Text)
|
||||
out.StreamText(chunk.Text)
|
||||
}
|
||||
if len(chunk.ToolCalls) > 0 {
|
||||
toolCalls = append(toolCalls, chunk.ToolCalls...)
|
||||
}
|
||||
if chunk.Done {
|
||||
lastPromptTokens = chunk.PromptEvalCount
|
||||
out.StreamDone(chunk.EvalCount, chunk.PromptEvalCount)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if retryCount < maxRetries && isRetryableError(err) {
|
||||
retryCount++
|
||||
out.Error(fmt.Sprintf("LLM produced malformed output, retrying (%d/%d)...", retryCount, maxRetries))
|
||||
textBuf.Reset()
|
||||
toolCalls = nil
|
||||
continue
|
||||
}
|
||||
out.Error(fmt.Sprintf("LLM error: %v", err))
|
||||
out.SystemMessage(fmt.Sprintf("⚠️ Model response failed: %v\n\nYou can try:\n- Checking if Ollama is running (`ollama ps`)\n- Switching to a different model (ctrl+m)\n- Reducing context size\n\nTool results are still available above.", err))
|
||||
return
|
||||
}
|
||||
retryCount = 0
|
||||
assistantMsg := llm.Message{
|
||||
Role: "assistant",
|
||||
Content: textBuf.String(),
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
a.AppendMessage(assistantMsg)
|
||||
if a.iceEngine != nil && assistantMsg.Content != "" {
|
||||
if err := a.iceEngine.IndexMessage(ctx, "assistant", assistantMsg.Content); err != nil {
|
||||
out.Error(fmt.Sprintf("ICE indexing failed: %v", err))
|
||||
}
|
||||
}
|
||||
if len(toolCalls) == 0 {
|
||||
a.mu.RLock()
|
||||
hasEnoughMessages := len(a.messages) >= 2
|
||||
var userContent string
|
||||
if hasEnoughMessages {
|
||||
for idx := len(a.messages) - 2; idx >= 0; idx-- {
|
||||
if a.messages[idx].Role == "user" {
|
||||
userContent = a.messages[idx].Content
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
a.mu.RUnlock()
|
||||
if a.iceEngine != nil && hasEnoughMessages && userContent != "" {
|
||||
a.iceEngine.DetectAutoMemory(ctx, userContent, assistantMsg.Content)
|
||||
}
|
||||
return
|
||||
}
|
||||
type pendingTool struct {
|
||||
tc llm.ToolCall
|
||||
isMemoryTool bool
|
||||
isMCPTool bool
|
||||
}
|
||||
var pending []pendingTool
|
||||
for _, tc := range toolCalls {
|
||||
if a.memoryStore != nil && a.isMemoryTool(tc.Name) {
|
||||
pending = append(pending, pendingTool{tc: tc, isMemoryTool: true})
|
||||
continue
|
||||
}
|
||||
if a.isToolsTool(tc.Name) {
|
||||
out.ToolCallStart(tc.Name, tc.Arguments)
|
||||
startTime := time.Now()
|
||||
result, isErr := a.handleToolsTool(tc)
|
||||
duration := time.Since(startTime)
|
||||
out.ToolCallResult(tc.Name, result, isErr, duration)
|
||||
a.AppendMessage(llm.Message{
|
||||
Role: "tool",
|
||||
Content: result,
|
||||
ToolName: tc.Name,
|
||||
ToolCallID: tc.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if a.permChecker != nil {
|
||||
switch a.permChecker.ToCheckResult(tc.Name) {
|
||||
case permissionPkg.CheckDeny:
|
||||
errMsg := "tool call blocked by permission policy"
|
||||
out.ToolCallStart(tc.Name, tc.Arguments)
|
||||
out.ToolCallResult(tc.Name, errMsg, true, 0)
|
||||
a.AppendMessage(llm.Message{
|
||||
Role: "tool",
|
||||
Content: errMsg,
|
||||
ToolName: tc.Name,
|
||||
ToolCallID: tc.ID,
|
||||
})
|
||||
continue
|
||||
case permissionPkg.CheckAsk:
|
||||
if a.approvalCallback != nil {
|
||||
allowed, always := permissionPkg.RequestApproval(tc.Name, tc.Arguments, a.approvalCallback)
|
||||
if always {
|
||||
a.permChecker.SetPolicy(tc.Name, permissionPkg.PolicyAllow)
|
||||
}
|
||||
if !allowed {
|
||||
errMsg := "tool call denied by user"
|
||||
out.ToolCallStart(tc.Name, tc.Arguments)
|
||||
out.ToolCallResult(tc.Name, errMsg, true, 0)
|
||||
a.AppendMessage(llm.Message{
|
||||
Role: "tool",
|
||||
Content: errMsg,
|
||||
ToolName: tc.Name,
|
||||
ToolCallID: tc.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
pending = append(pending, pendingTool{tc: tc, isMCPTool: true})
|
||||
}
|
||||
if len(pending) > 0 {
|
||||
var wg sync.WaitGroup
|
||||
mu := sync.Mutex{}
|
||||
results := make([]llm.Message, len(pending))
|
||||
for i, p := range pending {
|
||||
wg.Add(1)
|
||||
go func(idx int, tool pendingTool) {
|
||||
defer wg.Done()
|
||||
tc := tool.tc
|
||||
out.ToolCallStart(tc.Name, tc.Arguments)
|
||||
startTime := time.Now()
|
||||
var result string
|
||||
var isErr bool
|
||||
if tool.isMemoryTool {
|
||||
result, isErr = a.handleMemoryTool(tc)
|
||||
} else if tool.isMCPTool {
|
||||
toolResult, err := a.registry.CallTool(ctx, tc.Name, tc.Arguments)
|
||||
if err != nil {
|
||||
result = fmt.Sprintf("ERROR: Tool '%s' failed: %v\nThis tool call failed but you can still complete the task with other available information.", tc.Name, err)
|
||||
isErr = true
|
||||
} else {
|
||||
result = toolResult.Content
|
||||
isErr = toolResult.IsError
|
||||
}
|
||||
}
|
||||
duration := time.Since(startTime)
|
||||
out.ToolCallResult(tc.Name, result, isErr, duration)
|
||||
mu.Lock()
|
||||
results[idx] = llm.Message{
|
||||
Role: "tool",
|
||||
Content: result,
|
||||
ToolName: tc.Name,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
mu.Unlock()
|
||||
}(i, p)
|
||||
}
|
||||
wg.Wait()
|
||||
for _, msg := range results {
|
||||
if msg.ToolName != "" {
|
||||
a.AppendMessage(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
if a.shouldCompact(lastPromptTokens) {
|
||||
if a.compact(ctx, out) {
|
||||
system = buildSystemPromptForModel(a.modePrefix, tools, a.skillContent, a.loadedCtx, a.memoryStore, iceContext, a.workDir, a.ignoreContent, a.llmClient.Model())
|
||||
}
|
||||
}
|
||||
if i == maxIters-2 {
|
||||
out.Error(fmt.Sprintf("approaching iteration limit (%d/%d)", i+2, maxIters))
|
||||
}
|
||||
}
|
||||
out.Error(fmt.Sprintf("reached max iterations (%d)", maxIters))
|
||||
}
|
||||
|
||||
func isRetryableError(err error) bool {
|
||||
msg := err.Error()
|
||||
return strings.Contains(msg, "parse JSON") || strings.Contains(msg, "unexpected end of JSON")
|
||||
}
|
||||
|
||||
func FormatToolArgs(args map[string]any) string {
|
||||
if len(args) == 0 {
|
||||
return ""
|
||||
}
|
||||
var parts []string
|
||||
for key, value := range args {
|
||||
var valStr string
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if len(v) > 47 {
|
||||
valStr = `"` + v[:44] + `..."`
|
||||
} else {
|
||||
valStr = `"` + v + `"`
|
||||
}
|
||||
case int, float64, bool:
|
||||
valStr = fmt.Sprintf("%v", v)
|
||||
case []any:
|
||||
valStr = fmt.Sprintf("[%d items]", len(v))
|
||||
case map[string]any:
|
||||
valStr = fmt.Sprintf("{%d fields}", len(v))
|
||||
default:
|
||||
valStr = fmt.Sprintf("%v", v)
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("%s=%s", key, valStr))
|
||||
}
|
||||
sort.Strings(parts)
|
||||
result := strings.Join(parts, " ")
|
||||
if len(result) > 60 {
|
||||
return result[:57] + "..."
|
||||
}
|
||||
return result
|
||||
}
|
||||
73
internal/agent/loop_test.go
Normal file
73
internal/agent/loop_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFormatToolArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args map[string]any
|
||||
want string
|
||||
contains []string
|
||||
maxLen int
|
||||
}{
|
||||
{
|
||||
name: "empty map",
|
||||
args: map[string]any{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "simple map",
|
||||
args: map[string]any{"key": "value"},
|
||||
contains: []string{"key=", `"value"`},
|
||||
},
|
||||
{
|
||||
name: "long args truncated at 60",
|
||||
args: map[string]any{"data": strings.Repeat("a", 300)},
|
||||
maxLen: 60,
|
||||
},
|
||||
{
|
||||
name: "multiple args",
|
||||
args: map[string]any{"path": "/tmp/test", "command": "ls"},
|
||||
contains: []string{"path=", "command="},
|
||||
},
|
||||
{
|
||||
name: "numeric args",
|
||||
args: map[string]any{"count": 42, "ratio": 3.14},
|
||||
contains: []string{"count=42", "ratio=3.14"},
|
||||
},
|
||||
{
|
||||
name: "array args",
|
||||
args: map[string]any{"items": []any{1, 2, 3}},
|
||||
contains: []string{"items=", "[3 items]"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FormatToolArgs(tt.args)
|
||||
|
||||
if tt.want != "" && got != tt.want {
|
||||
t.Errorf("FormatToolArgs() = %q, want %q", got, tt.want)
|
||||
}
|
||||
|
||||
for _, substr := range tt.contains {
|
||||
if !strings.Contains(got, substr) {
|
||||
t.Errorf("FormatToolArgs() = %q, missing %q", got, substr)
|
||||
}
|
||||
}
|
||||
|
||||
if tt.maxLen > 0 {
|
||||
if len(got) > tt.maxLen {
|
||||
t.Errorf("FormatToolArgs() len = %d, want <= %d", len(got), tt.maxLen)
|
||||
}
|
||||
// Check for truncation indicator (either "..." in value or at end)
|
||||
if !strings.Contains(got, "...") {
|
||||
t.Errorf("FormatToolArgs() should contain '...' when truncated, got %q", got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
171
internal/agent/memory.go
Normal file
171
internal/agent/memory.go
Normal file
@ -0,0 +1,171 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
"ai-agent/internal/memory"
|
||||
)
|
||||
|
||||
func (a *Agent) memoryBuiltinToolDefs() []llm.ToolDef {
|
||||
return memory.BuiltinToolDefs()
|
||||
}
|
||||
|
||||
func (a *Agent) isMemoryTool(name string) bool {
|
||||
return memory.IsBuiltinTool(name)
|
||||
}
|
||||
|
||||
func (a *Agent) handleMemoryTool(tc llm.ToolCall) (string, bool) {
|
||||
switch tc.Name {
|
||||
case "memory_save":
|
||||
return a.handleMemorySave(tc.Arguments)
|
||||
case "memory_recall":
|
||||
return a.handleMemoryRecall(tc.Arguments)
|
||||
case "memory_delete":
|
||||
return a.handleMemoryDelete(tc.Arguments)
|
||||
case "memory_update":
|
||||
return a.handleMemoryUpdate(tc.Arguments)
|
||||
case "memory_list":
|
||||
return a.handleMemoryList(tc.Arguments)
|
||||
default:
|
||||
return fmt.Sprintf("unknown memory tool: %s", tc.Name), true
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) handleMemorySave(args map[string]any) (string, bool) {
|
||||
content, _ := args["content"].(string)
|
||||
if content == "" {
|
||||
return "error: content is required", true
|
||||
}
|
||||
var tags []string
|
||||
if rawTags, ok := args["tags"]; ok {
|
||||
switch v := rawTags.(type) {
|
||||
case []any:
|
||||
for _, t := range v {
|
||||
if s, ok := t.(string); ok {
|
||||
tags = append(tags, s)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
tags = v
|
||||
}
|
||||
}
|
||||
id, err := a.memoryStore.Save(content, tags)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error saving memory: %v", err), true
|
||||
}
|
||||
return fmt.Sprintf("Memory saved (id: %d)", id), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleMemoryRecall(args map[string]any) (string, bool) {
|
||||
query, _ := args["query"].(string)
|
||||
if query == "" {
|
||||
return "error: query is required", true
|
||||
}
|
||||
memories := a.memoryStore.Recall(query, 5)
|
||||
if len(memories) == 0 {
|
||||
return "No matching memories found.", false
|
||||
}
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "Found %d matching memories:\n", len(memories))
|
||||
for _, mem := range memories {
|
||||
fmt.Fprintf(&b, "- [%d] %s", mem.ID, mem.Content)
|
||||
if len(mem.Tags) > 0 {
|
||||
fmt.Fprintf(&b, " (tags: %s)", strings.Join(mem.Tags, ", "))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
return b.String(), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleMemoryDelete(args map[string]any) (string, bool) {
|
||||
idVal, ok := args["id"]
|
||||
if !ok {
|
||||
return "error: id is required", true
|
||||
}
|
||||
var id int
|
||||
switch v := idVal.(type) {
|
||||
case float64:
|
||||
id = int(v)
|
||||
case int:
|
||||
id = v
|
||||
default:
|
||||
return "error: id must be a number", true
|
||||
}
|
||||
deleted, err := a.memoryStore.Delete(id)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error deleting memory: %v", err), true
|
||||
}
|
||||
if !deleted {
|
||||
return fmt.Sprintf("memory with id %d not found", id), true
|
||||
}
|
||||
return fmt.Sprintf("Memory %d deleted", id), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleMemoryUpdate(args map[string]any) (string, bool) {
|
||||
idVal, ok := args["id"]
|
||||
if !ok {
|
||||
return "error: id is required", true
|
||||
}
|
||||
var id int
|
||||
switch v := idVal.(type) {
|
||||
case float64:
|
||||
id = int(v)
|
||||
case int:
|
||||
id = v
|
||||
default:
|
||||
return "error: id must be a number", true
|
||||
}
|
||||
content, _ := args["content"].(string)
|
||||
var tags []string
|
||||
if rawTags, ok := args["tags"]; ok {
|
||||
switch v := rawTags.(type) {
|
||||
case []any:
|
||||
for _, t := range v {
|
||||
if s, ok := t.(string); ok {
|
||||
tags = append(tags, s)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
tags = v
|
||||
}
|
||||
}
|
||||
if content == "" && len(tags) == 0 {
|
||||
return "error: at least one of content or tags is required", true
|
||||
}
|
||||
updated, err := a.memoryStore.Update(id, content, tags)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error updating memory: %v", err), true
|
||||
}
|
||||
if !updated {
|
||||
return fmt.Sprintf("memory with id %d not found", id), true
|
||||
}
|
||||
return fmt.Sprintf("Memory %d updated", id), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleMemoryList(args map[string]any) (string, bool) {
|
||||
limit := 20
|
||||
if rawLimit, ok := args["limit"]; ok {
|
||||
switch v := rawLimit.(type) {
|
||||
case float64:
|
||||
limit = int(v)
|
||||
case int:
|
||||
limit = v
|
||||
}
|
||||
}
|
||||
memories := a.memoryStore.Recent(limit)
|
||||
if len(memories) == 0 {
|
||||
return "No memories stored.", false
|
||||
}
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "Stored memories (%d total):\n", a.memoryStore.Count())
|
||||
for _, mem := range memories {
|
||||
fmt.Fprintf(&b, "- [%d] %s", mem.ID, mem.Content)
|
||||
if len(mem.Tags) > 0 {
|
||||
fmt.Fprintf(&b, " (tags: %s)", strings.Join(mem.Tags, ", "))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
return b.String(), false
|
||||
}
|
||||
159
internal/agent/memory_test.go
Normal file
159
internal/agent/memory_test.go
Normal file
@ -0,0 +1,159 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
"ai-agent/internal/mcp"
|
||||
"ai-agent/internal/memory"
|
||||
)
|
||||
|
||||
func newTestAgentWithMemory(t *testing.T) *Agent {
|
||||
t.Helper()
|
||||
store := memory.NewStore(filepath.Join(t.TempDir(), "test-memories.json"))
|
||||
return &Agent{memoryStore: store, registry: mcp.NewRegistry()}
|
||||
}
|
||||
|
||||
func TestHandleMemoryTool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolCall llm.ToolCall
|
||||
wantSubstr string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "dispatch to save",
|
||||
toolCall: llm.ToolCall{
|
||||
Name: "memory_save",
|
||||
Arguments: map[string]any{"content": "test fact", "tags": []any{"tag1"}},
|
||||
},
|
||||
wantSubstr: "Memory saved (id:",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "dispatch to recall",
|
||||
toolCall: llm.ToolCall{
|
||||
Name: "memory_recall",
|
||||
Arguments: map[string]any{"query": "test"},
|
||||
},
|
||||
wantSubstr: "No matching memories found.",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unknown tool",
|
||||
toolCall: llm.ToolCall{
|
||||
Name: "unknown",
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
wantSubstr: "unknown memory tool: unknown",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ag := newTestAgentWithMemory(t)
|
||||
result, isErr := ag.handleMemoryTool(tt.toolCall)
|
||||
if isErr != tt.wantErr {
|
||||
t.Errorf("handleMemoryTool() isErr = %v, want %v", isErr, tt.wantErr)
|
||||
}
|
||||
if !strings.Contains(result, tt.wantSubstr) {
|
||||
t.Errorf("handleMemoryTool() = %q, want substring %q", result, tt.wantSubstr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMemorySave(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args map[string]any
|
||||
wantSubstr string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid with tags as []any",
|
||||
args: map[string]any{"content": "test fact", "tags": []any{"tag1", "tag2"}},
|
||||
wantSubstr: "Memory saved (id:",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid without tags",
|
||||
args: map[string]any{"content": "another fact"},
|
||||
wantSubstr: "Memory saved (id:",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing content",
|
||||
args: map[string]any{},
|
||||
wantSubstr: "error: content is required",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
args: map[string]any{"content": ""},
|
||||
wantSubstr: "error: content is required",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ag := newTestAgentWithMemory(t)
|
||||
result, isErr := ag.handleMemorySave(tt.args)
|
||||
if isErr != tt.wantErr {
|
||||
t.Errorf("handleMemorySave() isErr = %v, want %v", isErr, tt.wantErr)
|
||||
}
|
||||
if !strings.Contains(result, tt.wantSubstr) {
|
||||
t.Errorf("handleMemorySave() = %q, want substring %q", result, tt.wantSubstr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleMemoryRecall(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(ag *Agent)
|
||||
args map[string]any
|
||||
wantSubstr string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid recall finds saved memory",
|
||||
setup: func(ag *Agent) {
|
||||
_, _ = ag.memoryStore.Save("user prefers Go", []string{"language"})
|
||||
},
|
||||
args: map[string]any{"query": "Go"},
|
||||
wantSubstr: "Found 1 matching memories:",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing query",
|
||||
setup: func(ag *Agent) {},
|
||||
args: map[string]any{},
|
||||
wantSubstr: "error: query is required",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no matches",
|
||||
setup: func(ag *Agent) {},
|
||||
args: map[string]any{"query": "nonexistent"},
|
||||
wantSubstr: "No matching memories found.",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ag := newTestAgentWithMemory(t)
|
||||
tt.setup(ag)
|
||||
result, isErr := ag.handleMemoryRecall(tt.args)
|
||||
if isErr != tt.wantErr {
|
||||
t.Errorf("handleMemoryRecall() isErr = %v, want %v", isErr, tt.wantErr)
|
||||
}
|
||||
if !strings.Contains(result, tt.wantSubstr) {
|
||||
t.Errorf("handleMemoryRecall() = %q, want substring %q", result, tt.wantSubstr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
24
internal/agent/output.go
Normal file
24
internal/agent/output.go
Normal file
@ -0,0 +1,24 @@
|
||||
package agent
|
||||
|
||||
import "time"
|
||||
|
||||
// Output is the interface the agent uses to stream results to the UI.
|
||||
type Output interface {
|
||||
// StreamText sends incremental text content.
|
||||
StreamText(text string)
|
||||
|
||||
// StreamDone signals that the current response is complete.
|
||||
StreamDone(evalCount, promptTokens int)
|
||||
|
||||
// ToolCallStart signals the beginning of a tool invocation.
|
||||
ToolCallStart(name string, args map[string]any)
|
||||
|
||||
// ToolCallResult delivers the result of a tool invocation.
|
||||
ToolCallResult(name string, result string, isError bool, duration time.Duration)
|
||||
|
||||
// SystemMessage displays a system-level message to the user.
|
||||
SystemMessage(msg string)
|
||||
|
||||
// Error reports a non-fatal error to the user.
|
||||
Error(msg string)
|
||||
}
|
||||
272
internal/agent/system.go
Normal file
272
internal/agent/system.go
Normal file
@ -0,0 +1,272 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
"ai-agent/internal/memory"
|
||||
)
|
||||
|
||||
const systemTemplate = `You are a helpful personal assistant running locally on the user's machine.
|
||||
You have access to tools via MCP servers. You MUST use tools to accomplish tasks — do not guess or make up answers when a tool can provide the real information.
|
||||
%s
|
||||
Current date: %s
|
||||
%s%s
|
||||
%s%s%s
|
||||
## Available Tools
|
||||
%s
|
||||
## Guidelines
|
||||
- **ALWAYS use your tools** when the user asks you to read, explore, search, or modify files. You have filesystem tools — use them.
|
||||
- When the user says "read this codebase" or similar, use list/read tools starting from the working directory shown above.
|
||||
- Be concise and direct in your responses.
|
||||
- When a tool call fails, explain what happened and suggest alternatives.
|
||||
- For multi-step tasks, explain your plan briefly before executing.
|
||||
- Format responses in markdown when it improves readability.
|
||||
- If you're unsure about something, say so rather than guessing.
|
||||
- Never fabricate tool results — always call the actual tool.
|
||||
- Do NOT claim you cannot access files or the filesystem. You have tools for that — use them.
|
||||
%s`
|
||||
|
||||
const smallModelTemplate = `You are a local AI assistant. Use tools to read/write files and run commands.
|
||||
%sDate: %s
|
||||
%s%s
|
||||
%s
|
||||
## Tools
|
||||
%s
|
||||
Guidelines:
|
||||
- Be concise and direct
|
||||
- Use tools when needed to complete tasks
|
||||
- If a tool fails, continue with available information
|
||||
- Don't guess - use tools to verify
|
||||
- You can complete tasks even if some tools fail
|
||||
%s`
|
||||
|
||||
func isSmallModel(modelName string) bool {
|
||||
lower := strings.ToLower(modelName)
|
||||
if strings.Contains(lower, "0.8b") || strings.Contains(lower, "1b") || strings.Contains(lower, "2b") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildSystemPrompt(modePrefix string, tools []llm.ToolDef, skillContent, loadedContext string, memStore *memory.Store, iceContext, workDir, ignoreContent string) string {
|
||||
return buildSystemPromptForModel(modePrefix, tools, skillContent, loadedContext, memStore, iceContext, workDir, ignoreContent, "")
|
||||
}
|
||||
|
||||
func buildSystemPromptForModel(modePrefix string, tools []llm.ToolDef, skillContent, loadedContext string, memStore *memory.Store, iceContext, workDir, ignoreContent string, modelName string) string {
|
||||
useSmallModel := isSmallModel(modelName)
|
||||
var toolList string
|
||||
if len(tools) == 0 {
|
||||
toolList = "No tools currently available.\n"
|
||||
} else if useSmallModel {
|
||||
toolList = simplifyToolsForSmallModel(tools)
|
||||
} else {
|
||||
var b strings.Builder
|
||||
for _, t := range tools {
|
||||
fmt.Fprintf(&b, "- **%s**: %s\n", t.Name, t.Description)
|
||||
}
|
||||
toolList = b.String()
|
||||
}
|
||||
envSection := buildEnvironmentSection(workDir)
|
||||
var skillSection string
|
||||
if skillContent != "" {
|
||||
skillSection = fmt.Sprintf("\n## Active Skills\n%s\n", skillContent)
|
||||
}
|
||||
var ctxSection string
|
||||
if loadedContext != "" {
|
||||
ctxSection = fmt.Sprintf("\n## Loaded Context\n%s\n", loadedContext)
|
||||
}
|
||||
var memorySection string
|
||||
if iceContext != "" {
|
||||
memorySection = iceContext
|
||||
} else if memStore != nil {
|
||||
memorySection = buildMemorySection(memStore)
|
||||
}
|
||||
var memoryGuidelines string
|
||||
if memStore != nil {
|
||||
memoryGuidelines = `
|
||||
## Memory Guidelines
|
||||
- You have access to persistent memory via memory_save and memory_recall tools.
|
||||
- Proactively save important user preferences, project facts, and key decisions.
|
||||
- When the user shares personal information (name, preferences, etc.), save it.
|
||||
- Use memory_recall to look up previously saved information when relevant.
|
||||
- Don't save trivial or session-specific information.
|
||||
`
|
||||
}
|
||||
var ignoreSection string
|
||||
if ignoreContent != "" {
|
||||
ignoreSection = fmt.Sprintf("\n## Ignored Paths\nThe following paths/patterns should be excluded from file operations:\n%s\n", ignoreContent)
|
||||
}
|
||||
var modePrefixSection string
|
||||
if modePrefix != "" {
|
||||
modePrefixSection = "\n" + modePrefix + "\n"
|
||||
}
|
||||
dateStr := time.Now().Format("Monday, January 2, 2006")
|
||||
if useSmallModel {
|
||||
return fmt.Sprintf(smallModelTemplate,
|
||||
modePrefixSection,
|
||||
dateStr,
|
||||
envSection,
|
||||
ignoreSection,
|
||||
skillSection,
|
||||
toolList,
|
||||
memoryGuidelines,
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf(systemTemplate,
|
||||
modePrefixSection,
|
||||
dateStr,
|
||||
envSection,
|
||||
ignoreSection,
|
||||
skillSection,
|
||||
ctxSection,
|
||||
memorySection,
|
||||
toolList,
|
||||
memoryGuidelines,
|
||||
)
|
||||
}
|
||||
|
||||
func simplifyToolsForSmallModel(tools []llm.ToolDef) string {
|
||||
var b strings.Builder
|
||||
for _, t := range tools {
|
||||
desc := t.Description
|
||||
if len(desc) > 50 {
|
||||
desc = desc[:47] + "..."
|
||||
}
|
||||
fmt.Fprintf(&b, "- %s: %s\n", t.Name, desc)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func buildEnvironmentSection(workDir string) string {
|
||||
if workDir == "" {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("\n## Environment\n")
|
||||
b.WriteString(fmt.Sprintf("Working directory: %s\n", workDir))
|
||||
if info := detectProjectInfo(workDir); info != "" {
|
||||
b.WriteString(info)
|
||||
}
|
||||
if gitInfo := detectGitInfo(workDir); gitInfo != "" {
|
||||
b.WriteString(gitInfo)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func detectProjectInfo(workDir string) string {
|
||||
markers := []struct {
|
||||
file string
|
||||
desc string
|
||||
}{
|
||||
{"go.mod", "Go module"},
|
||||
{"package.json", "Node.js/JavaScript"},
|
||||
{"Cargo.toml", "Rust"},
|
||||
{"pyproject.toml", "Python"},
|
||||
{"setup.py", "Python"},
|
||||
{"Makefile", ""},
|
||||
{"Taskfile.yml", ""},
|
||||
}
|
||||
var found []string
|
||||
for _, m := range markers {
|
||||
if _, err := os.Stat(filepath.Join(workDir, m.file)); err == nil {
|
||||
if m.desc != "" {
|
||||
found = append(found, fmt.Sprintf("%s (%s)", m.file, m.desc))
|
||||
} else {
|
||||
found = append(found, m.file)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(found) == 0 {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("Project markers: %s\n", strings.Join(found, ", "))
|
||||
}
|
||||
|
||||
func detectGitInfo(workDir string) string {
|
||||
gitDir := filepath.Join(workDir, ".git")
|
||||
if _, err := os.Stat(gitDir); err != nil {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
branch := runGitCommand(workDir, "rev-parse", "--abbrev-ref", "HEAD")
|
||||
if branch != "" {
|
||||
b.WriteString(fmt.Sprintf("Git branch: %s\n", branch))
|
||||
}
|
||||
status := runGitCommand(workDir, "status", "--porcelain")
|
||||
if status != "" {
|
||||
lines := strings.Split(strings.TrimSpace(status), "\n")
|
||||
var modified, added, deleted int
|
||||
for _, line := range lines {
|
||||
if len(line) >= 2 {
|
||||
switch line[0] {
|
||||
case 'M', 'm':
|
||||
modified++
|
||||
case 'A':
|
||||
added++
|
||||
case 'D':
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
}
|
||||
if modified > 0 || added > 0 || deleted > 0 {
|
||||
statusParts := []string{}
|
||||
if modified > 0 {
|
||||
statusParts = append(statusParts, fmt.Sprintf("%d modified", modified))
|
||||
}
|
||||
if added > 0 {
|
||||
statusParts = append(statusParts, fmt.Sprintf("%d added", added))
|
||||
}
|
||||
if deleted > 0 {
|
||||
statusParts = append(statusParts, fmt.Sprintf("%d deleted", deleted))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("Git status: %s\n", strings.Join(statusParts, ", ")))
|
||||
}
|
||||
}
|
||||
recentLog := runGitCommand(workDir, "log", "-3", "--oneline")
|
||||
if recentLog != "" {
|
||||
b.WriteString(fmt.Sprintf("Recent commits:\n"))
|
||||
for _, line := range strings.Split(strings.TrimSpace(recentLog), "\n") {
|
||||
b.WriteString(fmt.Sprintf(" - %s\n", line))
|
||||
}
|
||||
}
|
||||
if b.Len() == 0 {
|
||||
return ""
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func runGitCommand(dir string, args ...string) string {
|
||||
cmd := exec.Command("git", args...)
|
||||
cmd.Dir = dir
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
func buildMemorySection(store *memory.Store) string {
|
||||
if store.Count() == 0 {
|
||||
return ""
|
||||
}
|
||||
recent := store.Recent(10)
|
||||
if len(recent) == 0 {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("\n## Remembered Facts\n")
|
||||
for _, mem := range recent {
|
||||
b.WriteString(fmt.Sprintf("- %s", mem.Content))
|
||||
if len(mem.Tags) > 0 {
|
||||
b.WriteString(fmt.Sprintf(" [tags: %s]", strings.Join(mem.Tags, ", ")))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
186
internal/agent/system_test.go
Normal file
186
internal/agent/system_test.go
Normal file
@ -0,0 +1,186 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
"ai-agent/internal/memory"
|
||||
)
|
||||
|
||||
func TestBuildSystemPrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tools []llm.ToolDef
|
||||
skillContent string
|
||||
loadedCtx string
|
||||
memStore *memory.Store
|
||||
iceContext string
|
||||
contains []string
|
||||
notContains []string
|
||||
}{
|
||||
{
|
||||
name: "no optional sections",
|
||||
contains: []string{"No tools currently available.", "Current date:"},
|
||||
notContains: []string{"Active Skills", "Loaded Context", "Remembered Facts"},
|
||||
},
|
||||
{
|
||||
name: "with tools",
|
||||
tools: []llm.ToolDef{
|
||||
{Name: "test_tool", Description: "does stuff"},
|
||||
},
|
||||
contains: []string{"test_tool", "does stuff"},
|
||||
notContains: []string{"No tools currently available."},
|
||||
},
|
||||
{
|
||||
name: "with skills",
|
||||
skillContent: "skill content here",
|
||||
contains: []string{"Active Skills", "skill content here"},
|
||||
},
|
||||
{
|
||||
name: "with loaded context",
|
||||
loadedCtx: "some loaded context",
|
||||
contains: []string{"Loaded Context", "some loaded context"},
|
||||
},
|
||||
{
|
||||
name: "ICE overrides memory",
|
||||
iceContext: "ice assembled context",
|
||||
contains: []string{"ice assembled context"},
|
||||
notContains: []string{"Remembered Facts"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildSystemPrompt("", tt.tools, tt.skillContent, tt.loadedCtx, tt.memStore, tt.iceContext, "", "")
|
||||
for _, want := range tt.contains {
|
||||
if !strings.Contains(result, want) {
|
||||
t.Errorf("buildSystemPrompt() missing %q", want)
|
||||
}
|
||||
}
|
||||
for _, notWant := range tt.notContains {
|
||||
if strings.Contains(result, notWant) {
|
||||
t.Errorf("buildSystemPrompt() should not contain %q", notWant)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
t.Run("with memory store entries", func(t *testing.T) {
|
||||
store := memory.NewStore(filepath.Join(t.TempDir(), "test-memories.json"))
|
||||
_, _ = store.Save("user prefers dark mode", []string{"preference"})
|
||||
result := buildSystemPrompt("", nil, "", "", store, "", "", "")
|
||||
if !strings.Contains(result, "Remembered Facts") {
|
||||
t.Error("expected Remembered Facts section")
|
||||
}
|
||||
if !strings.Contains(result, "user prefers dark mode") {
|
||||
t.Error("expected memory content in prompt")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildSystemPrompt_WithWorkDir(t *testing.T) {
|
||||
result := buildSystemPrompt("", nil, "", "", nil, "", "/home/user/myproject", "")
|
||||
if !strings.Contains(result, "Working directory: /home/user/myproject") {
|
||||
t.Error("expected working directory in prompt")
|
||||
}
|
||||
if !strings.Contains(result, "Environment") {
|
||||
t.Error("expected Environment section header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSystemPrompt_EmptyWorkDir(t *testing.T) {
|
||||
result := buildSystemPrompt("", nil, "", "", nil, "", "", "")
|
||||
if strings.Contains(result, "Working directory") {
|
||||
t.Error("should not include working directory when empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSystemPrompt_WithIgnoreContent(t *testing.T) {
|
||||
ignoreContent := "- node_modules\n- *.log\n- build/"
|
||||
result := buildSystemPrompt("", nil, "", "", nil, "", "", ignoreContent)
|
||||
if !strings.Contains(result, "Ignored Paths") {
|
||||
t.Error("expected Ignored Paths section header")
|
||||
}
|
||||
if !strings.Contains(result, "node_modules") {
|
||||
t.Error("expected node_modules in ignore section")
|
||||
}
|
||||
if !strings.Contains(result, "*.log") {
|
||||
t.Error("expected *.log in ignore section")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSystemPrompt_EmptyIgnoreContent(t *testing.T) {
|
||||
result := buildSystemPrompt("", nil, "", "", nil, "", "", "")
|
||||
if strings.Contains(result, "Ignored Paths") {
|
||||
t.Error("should not include Ignored Paths when content is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectProjectInfo_GoProject(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
_ = os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module test"), 0o644)
|
||||
|
||||
info := detectProjectInfo(dir)
|
||||
if !strings.Contains(info, "go.mod") {
|
||||
t.Errorf("expected go.mod in project info, got %q", info)
|
||||
}
|
||||
if !strings.Contains(info, "Go module") {
|
||||
t.Errorf("expected 'Go module' in project info, got %q", info)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectProjectInfo_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
info := detectProjectInfo(dir)
|
||||
if info != "" {
|
||||
t.Errorf("expected empty for dir with no markers, got %q", info)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMemorySection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(s *memory.Store)
|
||||
contains []string
|
||||
wantEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "empty store",
|
||||
setup: func(s *memory.Store) {},
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "store with tagged entry",
|
||||
setup: func(s *memory.Store) {
|
||||
_, _ = s.Save("likes Go", []string{"lang", "preference"})
|
||||
},
|
||||
contains: []string{"Remembered Facts", "likes Go", "[tags: lang, preference]"},
|
||||
},
|
||||
{
|
||||
name: "store with untagged entry",
|
||||
setup: func(s *memory.Store) {
|
||||
_, _ = s.Save("project uses modules", nil)
|
||||
},
|
||||
contains: []string{"Remembered Facts", "project uses modules"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store := memory.NewStore(filepath.Join(t.TempDir(), "mem.json"))
|
||||
tt.setup(store)
|
||||
result := buildMemorySection(store)
|
||||
if tt.wantEmpty {
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string, got %q", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, want := range tt.contains {
|
||||
if !strings.Contains(result, want) {
|
||||
t.Errorf("buildMemorySection() missing %q in:\n%s", want, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
688
internal/agent/tools.go
Normal file
688
internal/agent/tools.go
Normal file
@ -0,0 +1,688 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
"ai-agent/internal/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
maxTimeout = 120 * time.Second
|
||||
)
|
||||
|
||||
func (a *Agent) toolsBuiltinToolDefs() []llm.ToolDef {
|
||||
return tools.AllToolDefs()
|
||||
}
|
||||
|
||||
func (a *Agent) isToolsTool(name string) bool {
|
||||
return tools.IsBuiltinTool(name)
|
||||
}
|
||||
|
||||
func (a *Agent) handleToolsTool(tc llm.ToolCall) (string, bool) {
|
||||
switch tc.Name {
|
||||
case "grep":
|
||||
return a.handleGrep(tc.Arguments)
|
||||
case "read":
|
||||
return a.handleRead(tc.Arguments)
|
||||
case "write":
|
||||
return a.handleWrite(tc.Arguments)
|
||||
case "glob":
|
||||
return a.handleGlob(tc.Arguments)
|
||||
case "bash":
|
||||
return a.handleBash(tc.Arguments)
|
||||
case "ls":
|
||||
return a.handleLs(tc.Arguments)
|
||||
case "find":
|
||||
return a.handleFind(tc.Arguments)
|
||||
case "diff":
|
||||
return a.handleDiff(tc.Arguments)
|
||||
case "edit":
|
||||
return a.handleEdit(tc.Arguments)
|
||||
case "mkdir":
|
||||
return a.handleMkdir(tc.Arguments)
|
||||
case "remove":
|
||||
return a.handleRemove(tc.Arguments)
|
||||
case "copy":
|
||||
return a.handleCopy(tc.Arguments)
|
||||
case "move":
|
||||
return a.handleMove(tc.Arguments)
|
||||
case "exists":
|
||||
return a.handleExists(tc.Arguments)
|
||||
default:
|
||||
return fmt.Sprintf("unknown tool: %s", tc.Name), true
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) handleGrep(args map[string]any) (string, bool) {
|
||||
pattern, _ := args["pattern"].(string)
|
||||
if pattern == "" {
|
||||
return "error: pattern is required", true
|
||||
}
|
||||
path := a.getArgString(args, "path", a.workDir)
|
||||
include := a.getArgString(args, "include", "")
|
||||
context := a.getArgInt(args, "context", 3)
|
||||
maxResults := a.MaxGrepResults()
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return fmt.Sprintf("error: path does not exist: %s", path), true
|
||||
}
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error: invalid regex pattern: %v", err), true
|
||||
}
|
||||
var results []string
|
||||
err = filepath.Walk(path, func(filePath string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if info.IsDir() {
|
||||
if shouldSkipDir(info.Name()) {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if include != "" {
|
||||
matched, err := filepath.Match(include, info.Name())
|
||||
if err != nil || !matched {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(info.Name(), ".") {
|
||||
return nil
|
||||
}
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for i, line := range lines {
|
||||
if re.MatchString(line) {
|
||||
relPath, _ := filepath.Rel(path, filePath)
|
||||
ctxStart := i - context
|
||||
if ctxStart < 0 {
|
||||
ctxStart = 0
|
||||
}
|
||||
ctxEnd := i + context + 1
|
||||
if ctxEnd > len(lines) {
|
||||
ctxEnd = len(lines)
|
||||
}
|
||||
results = append(results, fmt.Sprintf("%s:%d: %s", relPath, i+1, line))
|
||||
if context > 0 && ctxStart < i {
|
||||
for j := ctxStart; j < i; j++ {
|
||||
if len(results) < maxResults {
|
||||
results = append(results, fmt.Sprintf(" %d: %s", j+1, lines[j]))
|
||||
}
|
||||
}
|
||||
}
|
||||
if context > 0 && i+1 < ctxEnd {
|
||||
for j := i + 1; j < ctxEnd; j++ {
|
||||
if len(results) < maxResults {
|
||||
results = append(results, fmt.Sprintf(" %d: %s", j+1, lines[j]))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(results) >= maxResults {
|
||||
results = append(results, fmt.Sprintf("\n... (truncated, max %d results)", maxResults))
|
||||
return filepath.SkipAll
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error walking directory: %v", err), true
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No matches found for pattern: %s", pattern), false
|
||||
}
|
||||
return strings.Join(results, "\n"), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleRead(args map[string]any) (string, bool) {
|
||||
path, _ := args["path"].(string)
|
||||
if path == "" {
|
||||
return "error: path is required", true
|
||||
}
|
||||
path = a.resolvePath(path)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error reading file: %v", err), true
|
||||
}
|
||||
lines := strings.Split(string(data), "\n")
|
||||
offset := a.getArgInt(args, "offset", 1)
|
||||
limit := a.getArgInt(args, "limit", 0)
|
||||
if offset > len(lines) {
|
||||
return "error: offset beyond file length", true
|
||||
}
|
||||
if offset > 1 {
|
||||
lines = lines[offset-1:]
|
||||
}
|
||||
if limit > 0 && len(lines) > limit {
|
||||
lines = lines[:limit]
|
||||
content := strings.Join(lines, "\n")
|
||||
content += fmt.Sprintf("\n\n... (%d more lines)", len(lines)-limit)
|
||||
return content, false
|
||||
}
|
||||
return strings.Join(lines, "\n"), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleWrite(args map[string]any) (string, bool) {
|
||||
path, _ := args["path"].(string)
|
||||
content, _ := args["content"].(string)
|
||||
if path == "" {
|
||||
return "error: path is required", true
|
||||
}
|
||||
path = a.resolvePath(path)
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Sprintf("error creating directory: %v", err), true
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
return fmt.Sprintf("error writing file: %v", err), true
|
||||
}
|
||||
return fmt.Sprintf("Written to %s (%d bytes)", path, len(content)), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleGlob(args map[string]any) (string, bool) {
|
||||
pattern, _ := args["pattern"].(string)
|
||||
if pattern == "" {
|
||||
return "error: pattern is required", true
|
||||
}
|
||||
path := a.getArgString(args, "path", a.workDir)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return fmt.Sprintf("error: path does not exist: %s", path), true
|
||||
}
|
||||
basePattern := filepath.Join(path, pattern)
|
||||
matches, err := filepath.Glob(basePattern)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error: invalid pattern: %v", err), true
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return fmt.Sprintf("No files match pattern: %s", pattern), false
|
||||
}
|
||||
relMatches := make([]string, 0, len(matches))
|
||||
for _, m := range matches {
|
||||
rel, err := filepath.Rel(path, m)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
relMatches = append(relMatches, rel)
|
||||
}
|
||||
return strings.Join(relMatches, "\n"), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleBash(args map[string]any) (string, bool) {
|
||||
command, _ := args["command"].(string)
|
||||
if command == "" {
|
||||
return "error: command is required", true
|
||||
}
|
||||
timeout := a.getArgInt(args, "timeout", int(a.ToolTimeout().Seconds()))
|
||||
maxTimeoutSecs := int(a.ToolTimeout().Seconds())
|
||||
if maxTimeoutSecs > 120 {
|
||||
maxTimeoutSecs = 120
|
||||
}
|
||||
if timeout > maxTimeoutSecs {
|
||||
timeout = maxTimeoutSecs
|
||||
}
|
||||
if timeout < 1 {
|
||||
timeout = 1
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", command)
|
||||
cmd.Dir = a.workDir
|
||||
cmd.Env = os.Environ()
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
err := cmd.Run()
|
||||
output := stdout.String()
|
||||
if stderr.Len() > 0 {
|
||||
if output != "" {
|
||||
output += "\n"
|
||||
}
|
||||
output += "STDERR:\n" + stderr.String()
|
||||
}
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return fmt.Sprintf("error: command timed out after %d seconds", timeout), true
|
||||
}
|
||||
if err != nil {
|
||||
if output == "" {
|
||||
return fmt.Sprintf("error: %v", err), true
|
||||
}
|
||||
return fmt.Sprintf("Command exited with error:\n%s", output), true
|
||||
}
|
||||
if output == "" {
|
||||
return "Command completed successfully (no output)", false
|
||||
}
|
||||
return output, false
|
||||
}
|
||||
|
||||
func (a *Agent) handleLs(args map[string]any) (string, bool) {
|
||||
path := a.getArgString(args, "path", a.workDir)
|
||||
path = a.resolvePath(path)
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error reading directory: %v", err), true
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return "Directory is empty", false
|
||||
}
|
||||
var dirs []string
|
||||
var files []string
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if e.IsDir() {
|
||||
dirs = append(dirs, name+"/")
|
||||
} else {
|
||||
files = append(files, name)
|
||||
}
|
||||
}
|
||||
var result strings.Builder
|
||||
for _, d := range dirs {
|
||||
result.WriteString(d + "\n")
|
||||
}
|
||||
for _, f := range files {
|
||||
result.WriteString(f + "\n")
|
||||
}
|
||||
return result.String(), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleFind(args map[string]any) (string, bool) {
|
||||
name, _ := args["name"].(string)
|
||||
if name == "" {
|
||||
return "error: name is required", true
|
||||
}
|
||||
path := a.getArgString(args, "path", a.workDir)
|
||||
fileType := a.getArgString(args, "type", "")
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return fmt.Sprintf("error: path does not exist: %s", path), true
|
||||
}
|
||||
re, err := regexp.Compile("^" + strings.ReplaceAll(name, "*", ".*") + "$")
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error: invalid name pattern: %v", err), true
|
||||
}
|
||||
var results []string
|
||||
err = filepath.Walk(path, func(filePath string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if shouldSkipDir(info.Name()) && filePath != path {
|
||||
if info.IsDir() {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
isDir := info.IsDir()
|
||||
if fileType == "f" && isDir {
|
||||
return nil
|
||||
}
|
||||
if fileType == "d" && !isDir {
|
||||
return nil
|
||||
}
|
||||
if re.MatchString(info.Name()) {
|
||||
relPath, _ := filepath.Rel(path, filePath)
|
||||
if relPath != "." {
|
||||
if isDir {
|
||||
results = append(results, relPath+"/")
|
||||
} else {
|
||||
results = append(results, relPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error walking directory: %v", err), true
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No files/directories found matching: %s", name), false
|
||||
}
|
||||
return strings.Join(results, "\n"), false
|
||||
}
|
||||
|
||||
func (a *Agent) getArgString(args map[string]any, key, defaultValue string) string {
|
||||
if v, ok := args[key].(string); ok && v != "" {
|
||||
return v
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func (a *Agent) getArgInt(args map[string]any, key string, defaultValue int) int {
|
||||
if v, ok := args[key]; ok {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n)
|
||||
case int:
|
||||
return n
|
||||
case string:
|
||||
if n == "" {
|
||||
return defaultValue
|
||||
}
|
||||
if i, err := strconv.Atoi(n); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func (a *Agent) resolvePath(path string) string {
|
||||
if filepath.IsAbs(path) {
|
||||
return path
|
||||
}
|
||||
return filepath.Join(a.workDir, path)
|
||||
}
|
||||
|
||||
func shouldSkipDir(name string) bool {
|
||||
switch name {
|
||||
case "node_modules", ".git", "__pycache__", ".venv", "venv",
|
||||
"dist", "build", "target", ".cache", ".npm",
|
||||
".svn", "CVS", ".hg", ".bzr":
|
||||
return true
|
||||
}
|
||||
return strings.HasPrefix(name, ".")
|
||||
}
|
||||
|
||||
func (a *Agent) handleDiff(args map[string]any) (string, bool) {
|
||||
path, _ := args["path"].(string)
|
||||
newContent, _ := args["new_content"].(string)
|
||||
if path == "" {
|
||||
return "error: path is required", true
|
||||
}
|
||||
if newContent == "" {
|
||||
return "error: new_content is required", true
|
||||
}
|
||||
path = a.resolvePath(path)
|
||||
oldContent, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error reading file: %v", err), true
|
||||
}
|
||||
oldLines := strings.Split(string(oldContent), "\n")
|
||||
newLines := strings.Split(newContent, "\n")
|
||||
diff := computeDiff(oldLines, newLines)
|
||||
if diff == "" {
|
||||
return "No changes (files are identical)", false
|
||||
}
|
||||
return diff, false
|
||||
}
|
||||
|
||||
func computeDiff(oldLines, newLines []string) string {
|
||||
var result strings.Builder
|
||||
oldLen := len(oldLines)
|
||||
newLen := len(newLines)
|
||||
lcs := longestCommonSubsequence(oldLines, newLines)
|
||||
oldIdx := 0
|
||||
newIdx := 0
|
||||
lcsIdx := 0
|
||||
for oldIdx < oldLen || newIdx < newLen {
|
||||
if lcsIdx < len(lcs) {
|
||||
for oldIdx < oldLen && oldLines[oldIdx] != lcs[lcsIdx] {
|
||||
result.WriteString(fmt.Sprintf("-%s\n", oldLines[oldIdx]))
|
||||
oldIdx++
|
||||
}
|
||||
for newIdx < newLen && newLines[newIdx] != lcs[lcsIdx] {
|
||||
result.WriteString(fmt.Sprintf("+%s\n", newLines[newIdx]))
|
||||
newIdx++
|
||||
}
|
||||
if oldIdx < oldLen && newIdx < newLen {
|
||||
result.WriteString(fmt.Sprintf(" %s\n", lcs[lcsIdx]))
|
||||
oldIdx++
|
||||
newIdx++
|
||||
lcsIdx++
|
||||
}
|
||||
} else {
|
||||
for oldIdx < oldLen {
|
||||
result.WriteString(fmt.Sprintf("-%s\n", oldLines[oldIdx]))
|
||||
oldIdx++
|
||||
}
|
||||
for newIdx < newLen {
|
||||
result.WriteString(fmt.Sprintf("+%s\n", newLines[newIdx]))
|
||||
newIdx++
|
||||
}
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func longestCommonSubsequence(a, b []string) []string {
|
||||
m, n := len(a), len(b)
|
||||
dp := make([][]int, m+1)
|
||||
for i := range dp {
|
||||
dp[i] = make([]int, n+1)
|
||||
}
|
||||
for i := 1; i <= m; i++ {
|
||||
for j := 1; j <= n; j++ {
|
||||
if a[i-1] == b[j-1] {
|
||||
dp[i][j] = dp[i-1][j-1] + 1
|
||||
} else {
|
||||
if dp[i-1][j] > dp[i][j-1] {
|
||||
dp[i][j] = dp[i-1][j]
|
||||
} else {
|
||||
dp[i][j] = dp[i][j-1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
var lcs []string
|
||||
i, j := m, n
|
||||
for i > 0 && j > 0 {
|
||||
if a[i-1] == b[j-1] {
|
||||
lcs = append([]string{a[i-1]}, lcs...)
|
||||
i--
|
||||
j--
|
||||
} else if dp[i-1][j] > dp[i][j-1] {
|
||||
i--
|
||||
} else {
|
||||
j--
|
||||
}
|
||||
}
|
||||
return lcs
|
||||
}
|
||||
|
||||
func (a *Agent) handleEdit(args map[string]any) (string, bool) {
|
||||
path, _ := args["path"].(string)
|
||||
patch, _ := args["patch"].(string)
|
||||
if path == "" {
|
||||
return "error: path is required", true
|
||||
}
|
||||
if patch == "" {
|
||||
return "error: patch is required", true
|
||||
}
|
||||
path = a.resolvePath(path)
|
||||
oldContent, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error reading file: %v", err), true
|
||||
}
|
||||
newContent, err := applyPatch(string(oldContent), patch)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error applying patch: %v", err), true
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(newContent), 0644); err != nil {
|
||||
return fmt.Sprintf("error writing file: %v", err), true
|
||||
}
|
||||
return fmt.Sprintf("Applied patch to %s (%d bytes)", path, len(newContent)), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleMkdir(args map[string]any) (string, bool) {
|
||||
path, _ := args["path"].(string)
|
||||
if path == "" {
|
||||
return "error: path is required", true
|
||||
}
|
||||
path = a.resolvePath(path)
|
||||
if err := os.MkdirAll(path, 0755); err != nil {
|
||||
return fmt.Sprintf("error creating directory: %v", err), true
|
||||
}
|
||||
return fmt.Sprintf("Created directory: %s", path), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleRemove(args map[string]any) (string, bool) {
|
||||
path, _ := args["path"].(string)
|
||||
if path == "" {
|
||||
return "error: path is required", true
|
||||
}
|
||||
path = a.resolvePath(path)
|
||||
recursive := a.getArgBool(args, "recursive", false)
|
||||
force := a.getArgBool(args, "force", false)
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if force {
|
||||
return "Removed (ignored nonexistent)", false
|
||||
}
|
||||
return fmt.Sprintf("error: path does not exist: %s", path), true
|
||||
}
|
||||
return fmt.Sprintf("error: %v", err), true
|
||||
}
|
||||
if info.IsDir() {
|
||||
if recursive {
|
||||
err = os.RemoveAll(path)
|
||||
} else {
|
||||
err = os.Remove(path)
|
||||
}
|
||||
} else {
|
||||
err = os.Remove(path)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error removing: %v", err), true
|
||||
}
|
||||
return fmt.Sprintf("Removed: %s", path), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleCopy(args map[string]any) (string, bool) {
|
||||
source, _ := args["source"].(string)
|
||||
destination, _ := args["destination"].(string)
|
||||
if source == "" || destination == "" {
|
||||
return "error: source and destination are required", true
|
||||
}
|
||||
source = a.resolvePath(source)
|
||||
destination = a.resolvePath(destination)
|
||||
info, err := os.Stat(source)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error: %v", err), true
|
||||
}
|
||||
if info.IsDir() {
|
||||
return "error: copying directories not supported (use bash with cp -r)", true
|
||||
}
|
||||
srcData, err := os.ReadFile(source)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error reading source: %v", err), true
|
||||
}
|
||||
dir := filepath.Dir(destination)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Sprintf("error creating destination directory: %v", err), true
|
||||
}
|
||||
err = os.WriteFile(destination, srcData, info.Mode())
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error writing destination: %v", err), true
|
||||
}
|
||||
return fmt.Sprintf("Copied: %s -> %s", source, destination), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleMove(args map[string]any) (string, bool) {
|
||||
source, _ := args["source"].(string)
|
||||
destination, _ := args["destination"].(string)
|
||||
if source == "" || destination == "" {
|
||||
return "error: source and destination are required", true
|
||||
}
|
||||
source = a.resolvePath(source)
|
||||
destination = a.resolvePath(destination)
|
||||
dir := filepath.Dir(destination)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Sprintf("error creating destination directory: %v", err), true
|
||||
}
|
||||
err := os.Rename(source, destination)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error moving: %v", err), true
|
||||
}
|
||||
return fmt.Sprintf("Moved: %s -> %s", source, destination), false
|
||||
}
|
||||
|
||||
func (a *Agent) handleExists(args map[string]any) (string, bool) {
|
||||
path, _ := args["path"].(string)
|
||||
if path == "" {
|
||||
return "error: path is required", true
|
||||
}
|
||||
path = a.resolvePath(path)
|
||||
info, err := os.Stat(path)
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Sprintf("false: %s does not exist", path), false
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error: %v", err), true
|
||||
}
|
||||
if info.IsDir() {
|
||||
return fmt.Sprintf("true: %s (directory)", path), false
|
||||
}
|
||||
return fmt.Sprintf("true: %s (file, %d bytes)", path, info.Size()), false
|
||||
}
|
||||
|
||||
func (a *Agent) getArgBool(args map[string]any, key string, defaultValue bool) bool {
|
||||
if v, ok := args[key]; ok {
|
||||
if b, ok := v.(bool); ok {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func applyPatch(content, patch string) (string, error) {
|
||||
lines := strings.Split(content, "\n")
|
||||
patchLines := strings.Split(patch, "\n")
|
||||
var result []string
|
||||
i := 0
|
||||
for i < len(patchLines) {
|
||||
line := patchLines[i]
|
||||
if strings.HasPrefix(line, "@@") {
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) < 4 {
|
||||
return "", fmt.Errorf("invalid hunk header: %s", line)
|
||||
}
|
||||
oldSpec := strings.TrimPrefix(parts[1], "-")
|
||||
oldParts := strings.Split(oldSpec, ",")
|
||||
oldStart, _ := strconv.Atoi(oldParts[0])
|
||||
newSpec := strings.TrimPrefix(parts[2], "+")
|
||||
newParts := strings.Split(newSpec, ",")
|
||||
newStart, _ := strconv.Atoi(newParts[0])
|
||||
oldIdx := oldStart - 1
|
||||
newIdx := newStart - 1
|
||||
i++
|
||||
for i < len(patchLines) && !strings.HasPrefix(patchLines[i], "@@") {
|
||||
patchLine := patchLines[i]
|
||||
if strings.HasPrefix(patchLine, "-") {
|
||||
if oldIdx < len(lines) {
|
||||
_ = lines[oldIdx]
|
||||
oldIdx++
|
||||
}
|
||||
} else if strings.HasPrefix(patchLine, "+") {
|
||||
content := strings.TrimPrefix(patchLine, "+")
|
||||
result = append(result, content)
|
||||
newIdx++
|
||||
} else if strings.HasPrefix(patchLine, " ") || patchLine == "" {
|
||||
if oldIdx < len(lines) {
|
||||
result = append(result, lines[oldIdx])
|
||||
oldIdx++
|
||||
}
|
||||
} else {
|
||||
result = append(result, patchLine)
|
||||
}
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
i++
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
return strings.Join(result, "\n"), nil
|
||||
}
|
||||
387
internal/command/commands.go
Normal file
387
internal/command/commands.go
Normal file
@ -0,0 +1,387 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const maxContextFileSize = 32 * 1024 // 32KB
|
||||
|
||||
// RegisterBuiltins adds all built-in slash commands to the registry.
|
||||
func RegisterBuiltins(r *Registry) {
|
||||
r.Register(&Command{
|
||||
Name: "help",
|
||||
Aliases: []string{"h", "?"},
|
||||
Description: "Show help overlay with shortcuts and commands",
|
||||
Handler: func(_ *Context, _ []string) Result {
|
||||
return Result{Action: ActionShowHelp}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "clear",
|
||||
Description: "Clear conversation history",
|
||||
Handler: func(_ *Context, _ []string) Result {
|
||||
return Result{
|
||||
Text: "Conversation cleared.",
|
||||
Action: ActionClear,
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "new",
|
||||
Description: "Start a fresh conversation",
|
||||
Handler: func(_ *Context, _ []string) Result {
|
||||
return Result{
|
||||
Text: "New conversation started.",
|
||||
Action: ActionClear,
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "model",
|
||||
Aliases: []string{"m"},
|
||||
Description: "Show, switch, or list models",
|
||||
Usage: "/model [name|list|fast|smart]",
|
||||
Handler: func(ctx *Context, args []string) Result {
|
||||
if len(args) == 0 {
|
||||
return Result{Action: ActionShowModelPicker}
|
||||
}
|
||||
|
||||
switch args[0] {
|
||||
case "list", "ls":
|
||||
var b strings.Builder
|
||||
b.WriteString("Available models:\n")
|
||||
for _, m := range ctx.ModelList {
|
||||
marker := " "
|
||||
if m == ctx.Model {
|
||||
marker = "* "
|
||||
}
|
||||
fmt.Fprintf(&b, " %s%s\n", marker, m)
|
||||
}
|
||||
b.WriteString("\n* = current")
|
||||
return Result{Text: b.String()}
|
||||
|
||||
case "fast":
|
||||
if len(ctx.ModelList) > 0 {
|
||||
return Result{
|
||||
Text: fmt.Sprintf("Switching to fastest model: %s", ctx.ModelList[0]),
|
||||
Action: ActionSwitchModel,
|
||||
Data: ctx.ModelList[0],
|
||||
}
|
||||
}
|
||||
return Result{Error: "No models available"}
|
||||
|
||||
case "smart":
|
||||
if len(ctx.ModelList) > 0 {
|
||||
smartModel := ctx.ModelList[len(ctx.ModelList)-1]
|
||||
return Result{
|
||||
Text: fmt.Sprintf("Switching to smartest model: %s", smartModel),
|
||||
Action: ActionSwitchModel,
|
||||
Data: smartModel,
|
||||
}
|
||||
}
|
||||
return Result{Error: "No models available"}
|
||||
|
||||
default:
|
||||
for _, m := range ctx.ModelList {
|
||||
if m == args[0] {
|
||||
return Result{
|
||||
Text: fmt.Sprintf("Switching to model: %s", m),
|
||||
Action: ActionSwitchModel,
|
||||
Data: m,
|
||||
}
|
||||
}
|
||||
}
|
||||
return Result{Error: fmt.Sprintf("Unknown model: %s (use /model list to see available)", args[0])}
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "models",
|
||||
Aliases: []string{"ml"},
|
||||
Description: "Open model picker",
|
||||
Handler: func(_ *Context, _ []string) Result {
|
||||
return Result{Action: ActionShowModelPicker}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "agent",
|
||||
Aliases: []string{"a"},
|
||||
Description: "Show or switch agent profile",
|
||||
Usage: "/agent [name|list]",
|
||||
Handler: func(ctx *Context, args []string) Result {
|
||||
if len(args) == 0 || args[0] == "list" {
|
||||
var b strings.Builder
|
||||
if len(ctx.AgentList) == 0 {
|
||||
b.WriteString("No agent profiles found in ~/.agents/agents/")
|
||||
return Result{Text: b.String()}
|
||||
}
|
||||
b.WriteString("Available agent profiles:\n")
|
||||
for _, a := range ctx.AgentList {
|
||||
marker := " "
|
||||
if a == ctx.AgentProfile {
|
||||
marker = "* "
|
||||
}
|
||||
fmt.Fprintf(&b, " %s%s\n", marker, a)
|
||||
}
|
||||
b.WriteString("\n* = current")
|
||||
return Result{Text: b.String()}
|
||||
}
|
||||
|
||||
for _, a := range ctx.AgentList {
|
||||
if a == args[0] {
|
||||
return Result{
|
||||
Text: fmt.Sprintf("Switching to agent: %s", a),
|
||||
Action: ActionSwitchAgent,
|
||||
Data: a,
|
||||
}
|
||||
}
|
||||
}
|
||||
return Result{Error: fmt.Sprintf("Unknown agent: %s (use /agent list to see available)", args[0])}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "load",
|
||||
Aliases: []string{"l"},
|
||||
Description: "Load a markdown file as context",
|
||||
Usage: "/load <path>",
|
||||
Handler: func(_ *Context, args []string) Result {
|
||||
if len(args) == 0 {
|
||||
return Result{Error: "Usage: /load <path>"}
|
||||
}
|
||||
path := strings.Join(args, " ")
|
||||
|
||||
// Expand ~ to home directory.
|
||||
if strings.HasPrefix(path, "~/") {
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
path = home + path[1:]
|
||||
}
|
||||
}
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return Result{Error: fmt.Sprintf("Cannot access %s: %v", path, err)}
|
||||
}
|
||||
if info.Size() > maxContextFileSize {
|
||||
return Result{Error: fmt.Sprintf("File too large (%d bytes, max %d)", info.Size(), maxContextFileSize)}
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return Result{Error: fmt.Sprintf("Cannot read %s: %v", path, err)}
|
||||
}
|
||||
|
||||
return Result{
|
||||
Text: fmt.Sprintf("Loaded context: %s (%d bytes)", path, len(data)),
|
||||
Action: ActionLoadContext,
|
||||
Data: path + "\x00" + string(data), // path\0content
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "unload",
|
||||
Description: "Remove loaded context file",
|
||||
Handler: func(ctx *Context, _ []string) Result {
|
||||
if ctx.LoadedFile == "" {
|
||||
return Result{Text: "No context file loaded."}
|
||||
}
|
||||
return Result{
|
||||
Text: "Context unloaded.",
|
||||
Action: ActionUnloadContext,
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "skill",
|
||||
Aliases: []string{"sk"},
|
||||
Description: "Manage skills (list, activate, deactivate)",
|
||||
Usage: "/skill [list|activate|deactivate] [name]",
|
||||
Handler: func(ctx *Context, args []string) Result {
|
||||
if len(args) == 0 || args[0] == "list" {
|
||||
return skillList(ctx)
|
||||
}
|
||||
if len(args) < 2 {
|
||||
return Result{Error: "Usage: /skill [list|activate|deactivate] <name>"}
|
||||
}
|
||||
switch args[0] {
|
||||
case "activate", "on":
|
||||
return Result{
|
||||
Text: fmt.Sprintf("Activated skill: %s", args[1]),
|
||||
Action: ActionActivateSkill,
|
||||
Data: args[1],
|
||||
}
|
||||
case "deactivate", "off":
|
||||
return Result{
|
||||
Text: fmt.Sprintf("Deactivated skill: %s", args[1]),
|
||||
Action: ActionDeactivateSkill,
|
||||
Data: args[1],
|
||||
}
|
||||
default:
|
||||
return Result{Error: fmt.Sprintf("Unknown skill action: %s (use list, activate, or deactivate)", args[0])}
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "servers",
|
||||
Description: "List connected MCP servers",
|
||||
Handler: func(ctx *Context, _ []string) Result {
|
||||
if len(ctx.ServerNames) == 0 {
|
||||
return Result{Text: "No MCP servers connected."}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("Connected servers (%d):\n", len(ctx.ServerNames)))
|
||||
for _, name := range ctx.ServerNames {
|
||||
fmt.Fprintf(&b, " - %s\n", name)
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\nTotal tools: %d", ctx.ToolCount))
|
||||
return Result{Text: b.String()}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "ice",
|
||||
Description: "Show Infinite Context Engine status",
|
||||
Handler: func(ctx *Context, _ []string) Result {
|
||||
if !ctx.ICEEnabled {
|
||||
return Result{Text: "ICE is not enabled. Add `ice: {enabled: true}` to your config.yaml"}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("Infinite Context Engine (ICE)\n")
|
||||
fmt.Fprintf(&b, " Status: enabled\n")
|
||||
fmt.Fprintf(&b, " Conversations: %d stored\n", ctx.ICEConversations)
|
||||
fmt.Fprintf(&b, " Session ID: %s\n", ctx.ICESessionID)
|
||||
fmt.Fprintf(&b, " Embed model: nomic-embed-text\n")
|
||||
return Result{Text: b.String()}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "sessions",
|
||||
Aliases: []string{"ss"},
|
||||
Description: "Browse and restore saved sessions",
|
||||
Handler: func(_ *Context, _ []string) Result {
|
||||
return Result{Action: ActionShowSessions}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "changes",
|
||||
Description: "List files modified by the agent this session",
|
||||
Handler: func(ctx *Context, _ []string) Result {
|
||||
if len(ctx.FileChanges) == 0 {
|
||||
return Result{Text: "No files modified this session."}
|
||||
}
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "Files modified (%d):\n", len(ctx.FileChanges))
|
||||
for path, count := range ctx.FileChanges {
|
||||
if count > 1 {
|
||||
fmt.Fprintf(&b, " ✎ %s (%dx)\n", path, count)
|
||||
} else {
|
||||
fmt.Fprintf(&b, " ✎ %s\n", path)
|
||||
}
|
||||
}
|
||||
return Result{Text: b.String()}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "commit",
|
||||
Aliases: []string{"ci"},
|
||||
Description: "Generate commit message from staged changes and commit",
|
||||
Handler: func(_ *Context, args []string) Result {
|
||||
return Result{Action: ActionCommit, Data: strings.Join(args, " ")}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "stats",
|
||||
Description: "Show token usage statistics for this session",
|
||||
Handler: func(ctx *Context, _ []string) Result {
|
||||
if ctx.SessionTurnCount == 0 {
|
||||
return Result{Text: "No token usage recorded yet."}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("Session Token Stats\n")
|
||||
fmt.Fprintf(&b, " Model: %s\n", ctx.CurrentModel)
|
||||
fmt.Fprintf(&b, " Turns: %d\n", ctx.SessionTurnCount)
|
||||
fmt.Fprintf(&b, " Output tokens: %d\n", ctx.SessionEvalTotal)
|
||||
fmt.Fprintf(&b, " Prompt tokens: %d (last turn)\n", ctx.SessionPromptTotal)
|
||||
if ctx.NumCtx > 0 {
|
||||
fmt.Fprintf(&b, " Context window: %d\n", ctx.NumCtx)
|
||||
pct := ctx.SessionPromptTotal * 100 / ctx.NumCtx
|
||||
fmt.Fprintf(&b, " Context used: %d%%\n", pct)
|
||||
}
|
||||
avgOut := ctx.SessionEvalTotal / ctx.SessionTurnCount
|
||||
fmt.Fprintf(&b, " Avg out/turn: %d\n", avgOut)
|
||||
return Result{Text: b.String()}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "export",
|
||||
Description: "Export conversation to a markdown file",
|
||||
Usage: "/export [path]",
|
||||
Handler: func(_ *Context, args []string) Result {
|
||||
if len(args) < 1 || args[0] == "" {
|
||||
return Result{Error: "usage: /export <filepath>"}
|
||||
}
|
||||
return Result{
|
||||
Text: fmt.Sprintf("Exporting conversation to: %s", args[0]),
|
||||
Action: ActionExport,
|
||||
Data: args[0],
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "import",
|
||||
Description: "Import conversation from a markdown file",
|
||||
Usage: "/import [path]",
|
||||
Handler: func(_ *Context, args []string) Result {
|
||||
if len(args) < 1 || args[0] == "" {
|
||||
return Result{Error: "usage: /import <filepath>"}
|
||||
}
|
||||
return Result{
|
||||
Text: fmt.Sprintf("Importing conversation from: %s", args[0]),
|
||||
Action: ActionImport,
|
||||
Data: args[0],
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
r.Register(&Command{
|
||||
Name: "exit",
|
||||
Aliases: []string{"quit", "q"},
|
||||
Description: "Quit ai-agent",
|
||||
Handler: func(_ *Context, _ []string) Result {
|
||||
return Result{Action: ActionQuit}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func skillList(ctx *Context) Result {
|
||||
if len(ctx.Skills) == 0 {
|
||||
return Result{Text: "No skills found. Add .md files to ~/.config/ai-agent/skills/"}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("Skills (%d):\n", len(ctx.Skills)))
|
||||
for _, s := range ctx.Skills {
|
||||
status := " "
|
||||
if s.Active {
|
||||
status = "* "
|
||||
}
|
||||
fmt.Fprintf(&b, " %s%s — %s\n", status, s.Name, s.Description)
|
||||
}
|
||||
b.WriteString("\n* = active")
|
||||
return Result{Text: b.String()}
|
||||
}
|
||||
380
internal/command/commands_test.go
Normal file
380
internal/command/commands_test.go
Normal file
@ -0,0 +1,380 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newTestRegistry() *Registry {
|
||||
r := NewRegistry()
|
||||
RegisterBuiltins(r)
|
||||
return r
|
||||
}
|
||||
|
||||
func TestBuiltin_Help(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
result := r.Execute(&Context{}, "help", nil)
|
||||
if result.Action != ActionShowHelp {
|
||||
t.Errorf("help action = %d, want %d (ActionShowHelp)", result.Action, ActionShowHelp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltin_Clear(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
result := r.Execute(&Context{}, "clear", nil)
|
||||
if result.Action != ActionClear {
|
||||
t.Errorf("clear action = %d, want %d (ActionClear)", result.Action, ActionClear)
|
||||
}
|
||||
if result.Text == "" {
|
||||
t.Error("clear should have text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltin_New(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
result := r.Execute(&Context{}, "new", nil)
|
||||
if result.Action != ActionClear {
|
||||
t.Errorf("new action = %d, want %d (ActionClear)", result.Action, ActionClear)
|
||||
}
|
||||
if result.Text == "" {
|
||||
t.Error("new should have text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltin_Model(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
ctx := &Context{
|
||||
Model: "qwen3.5:0.8b",
|
||||
ModelList: []string{"qwen3.5:0.8b", "qwen3.5:2b", "qwen3.5:4b", "qwen3.5:9b"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
wantAction Action
|
||||
wantData string
|
||||
wantErr bool
|
||||
checkText string
|
||||
}{
|
||||
{
|
||||
name: "no args opens model picker",
|
||||
args: nil,
|
||||
wantAction: ActionShowModelPicker,
|
||||
},
|
||||
{
|
||||
name: "list shows models",
|
||||
args: []string{"list"},
|
||||
checkText: "Available models",
|
||||
},
|
||||
{
|
||||
name: "fast switches to first",
|
||||
args: []string{"fast"},
|
||||
wantAction: ActionSwitchModel,
|
||||
wantData: "qwen3.5:0.8b",
|
||||
},
|
||||
{
|
||||
name: "smart switches to last",
|
||||
args: []string{"smart"},
|
||||
wantAction: ActionSwitchModel,
|
||||
wantData: "qwen3.5:9b",
|
||||
},
|
||||
{
|
||||
name: "valid name switches",
|
||||
args: []string{"qwen3.5:2b"},
|
||||
wantAction: ActionSwitchModel,
|
||||
wantData: "qwen3.5:2b",
|
||||
},
|
||||
{
|
||||
name: "invalid name errors",
|
||||
args: []string{"nonexistent"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := r.Execute(ctx, "model", tt.args)
|
||||
if tt.wantErr {
|
||||
if result.Error == "" {
|
||||
t.Error("expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if result.Error != "" {
|
||||
t.Errorf("unexpected error: %s", result.Error)
|
||||
return
|
||||
}
|
||||
if tt.wantAction != ActionNone && result.Action != tt.wantAction {
|
||||
t.Errorf("action = %d, want %d", result.Action, tt.wantAction)
|
||||
}
|
||||
if tt.wantData != "" && result.Data != tt.wantData {
|
||||
t.Errorf("data = %q, want %q", result.Data, tt.wantData)
|
||||
}
|
||||
if tt.checkText != "" && !strings.Contains(result.Text, tt.checkText) {
|
||||
t.Errorf("text %q does not contain %q", result.Text, tt.checkText)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltin_Models(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
ctx := &Context{
|
||||
Model: "qwen3.5:0.8b",
|
||||
ModelList: []string{"qwen3.5:0.8b", "qwen3.5:2b"},
|
||||
}
|
||||
result := r.Execute(ctx, "models", nil)
|
||||
if result.Action != ActionShowModelPicker {
|
||||
t.Errorf("expected ActionShowModelPicker, got %d", result.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltin_Agent(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
|
||||
t.Run("no args lists agents", func(t *testing.T) {
|
||||
ctx := &Context{AgentList: []string{"coder", "reviewer"}, AgentProfile: "coder"}
|
||||
result := r.Execute(ctx, "agent", nil)
|
||||
if !strings.Contains(result.Text, "Available agent profiles") {
|
||||
t.Errorf("expected agent list, got %q", result.Text)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("list subcommand", func(t *testing.T) {
|
||||
ctx := &Context{AgentList: []string{"coder"}}
|
||||
result := r.Execute(ctx, "agent", []string{"list"})
|
||||
if !strings.Contains(result.Text, "Available agent profiles") {
|
||||
t.Errorf("expected agent list, got %q", result.Text)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid switch", func(t *testing.T) {
|
||||
ctx := &Context{AgentList: []string{"coder", "reviewer"}}
|
||||
result := r.Execute(ctx, "agent", []string{"reviewer"})
|
||||
if result.Action != ActionSwitchAgent {
|
||||
t.Errorf("action = %d, want %d", result.Action, ActionSwitchAgent)
|
||||
}
|
||||
if result.Data != "reviewer" {
|
||||
t.Errorf("data = %q, want %q", result.Data, "reviewer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid errors", func(t *testing.T) {
|
||||
ctx := &Context{AgentList: []string{"coder"}}
|
||||
result := r.Execute(ctx, "agent", []string{"unknown"})
|
||||
if result.Error == "" {
|
||||
t.Error("expected error for unknown agent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no agents", func(t *testing.T) {
|
||||
ctx := &Context{AgentList: []string{}}
|
||||
result := r.Execute(ctx, "agent", nil)
|
||||
if !strings.Contains(result.Text, "No agent profiles") {
|
||||
t.Errorf("expected no agents message, got %q", result.Text)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuiltin_Load(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
|
||||
t.Run("no args errors", func(t *testing.T) {
|
||||
result := r.Execute(&Context{}, "load", nil)
|
||||
if result.Error == "" {
|
||||
t.Error("expected error for no args")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid file loads", func(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "test.md")
|
||||
if err := os.WriteFile(path, []byte("# Hello"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
result := r.Execute(&Context{}, "load", []string{path})
|
||||
if result.Error != "" {
|
||||
t.Errorf("unexpected error: %s", result.Error)
|
||||
}
|
||||
if result.Action != ActionLoadContext {
|
||||
t.Errorf("action = %d, want %d", result.Action, ActionLoadContext)
|
||||
}
|
||||
// Data should be path\0content
|
||||
parts := strings.SplitN(result.Data, "\x00", 2)
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected path\\0content, got %q", result.Data)
|
||||
}
|
||||
if parts[0] != path {
|
||||
t.Errorf("data path = %q, want %q", parts[0], path)
|
||||
}
|
||||
if parts[1] != "# Hello" {
|
||||
t.Errorf("data content = %q, want %q", parts[1], "# Hello")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("too large errors", func(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "big.md")
|
||||
data := make([]byte, 33*1024) // > 32KB
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
result := r.Execute(&Context{}, "load", []string{path})
|
||||
if result.Error == "" {
|
||||
t.Error("expected error for oversized file")
|
||||
}
|
||||
if !strings.Contains(result.Error, "too large") {
|
||||
t.Errorf("error = %q, want containing 'too large'", result.Error)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nonexistent errors", func(t *testing.T) {
|
||||
result := r.Execute(&Context{}, "load", []string{"/nonexistent/file.md"})
|
||||
if result.Error == "" {
|
||||
t.Error("expected error for nonexistent file")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuiltin_Unload(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
|
||||
t.Run("no loaded file", func(t *testing.T) {
|
||||
result := r.Execute(&Context{LoadedFile: ""}, "unload", nil)
|
||||
if !strings.Contains(result.Text, "No context") {
|
||||
t.Errorf("expected 'No context' message, got %q", result.Text)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("loaded file unloads", func(t *testing.T) {
|
||||
result := r.Execute(&Context{LoadedFile: "something.md"}, "unload", nil)
|
||||
if result.Action != ActionUnloadContext {
|
||||
t.Errorf("action = %d, want %d", result.Action, ActionUnloadContext)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuiltin_Skill(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
ctx := &Context{
|
||||
Skills: []SkillInfo{
|
||||
{Name: "coder", Description: "Code generation", Active: true},
|
||||
{Name: "reviewer", Description: "Code review", Active: false},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("no args lists skills", func(t *testing.T) {
|
||||
result := r.Execute(ctx, "skill", nil)
|
||||
if !strings.Contains(result.Text, "Skills") {
|
||||
t.Errorf("expected skills list, got %q", result.Text)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("list subcommand", func(t *testing.T) {
|
||||
result := r.Execute(ctx, "skill", []string{"list"})
|
||||
if !strings.Contains(result.Text, "Skills") {
|
||||
t.Errorf("expected skills list, got %q", result.Text)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("activate", func(t *testing.T) {
|
||||
result := r.Execute(ctx, "skill", []string{"activate", "reviewer"})
|
||||
if result.Action != ActionActivateSkill {
|
||||
t.Errorf("action = %d, want %d", result.Action, ActionActivateSkill)
|
||||
}
|
||||
if result.Data != "reviewer" {
|
||||
t.Errorf("data = %q, want %q", result.Data, "reviewer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deactivate", func(t *testing.T) {
|
||||
result := r.Execute(ctx, "skill", []string{"deactivate", "coder"})
|
||||
if result.Action != ActionDeactivateSkill {
|
||||
t.Errorf("action = %d, want %d", result.Action, ActionDeactivateSkill)
|
||||
}
|
||||
if result.Data != "coder" {
|
||||
t.Errorf("data = %q, want %q", result.Data, "coder")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown action errors", func(t *testing.T) {
|
||||
result := r.Execute(ctx, "skill", []string{"unknown", "foo"})
|
||||
if result.Error == "" {
|
||||
t.Error("expected error for unknown skill action")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing name errors", func(t *testing.T) {
|
||||
result := r.Execute(ctx, "skill", []string{"activate"})
|
||||
if result.Error == "" {
|
||||
t.Error("expected error for missing skill name")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuiltin_Servers(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
|
||||
t.Run("no servers", func(t *testing.T) {
|
||||
result := r.Execute(&Context{ServerNames: nil}, "servers", nil)
|
||||
if !strings.Contains(result.Text, "No MCP servers") {
|
||||
t.Errorf("expected no servers message, got %q", result.Text)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with servers", func(t *testing.T) {
|
||||
ctx := &Context{
|
||||
ServerNames: []string{"server-a", "server-b"},
|
||||
ToolCount: 10,
|
||||
}
|
||||
result := r.Execute(ctx, "servers", nil)
|
||||
if !strings.Contains(result.Text, "server-a") {
|
||||
t.Errorf("expected server-a in output, got %q", result.Text)
|
||||
}
|
||||
if !strings.Contains(result.Text, "server-b") {
|
||||
t.Errorf("expected server-b in output, got %q", result.Text)
|
||||
}
|
||||
if !strings.Contains(result.Text, "10") {
|
||||
t.Errorf("expected tool count in output, got %q", result.Text)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuiltin_ICE(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
|
||||
t.Run("disabled", func(t *testing.T) {
|
||||
result := r.Execute(&Context{ICEEnabled: false}, "ice", nil)
|
||||
if !strings.Contains(result.Text, "not enabled") {
|
||||
t.Errorf("expected disabled message, got %q", result.Text)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("enabled shows status", func(t *testing.T) {
|
||||
ctx := &Context{
|
||||
ICEEnabled: true,
|
||||
ICEConversations: 5,
|
||||
ICESessionID: "abc-123",
|
||||
}
|
||||
result := r.Execute(ctx, "ice", nil)
|
||||
if !strings.Contains(result.Text, "enabled") {
|
||||
t.Errorf("expected enabled status, got %q", result.Text)
|
||||
}
|
||||
if !strings.Contains(result.Text, "5") {
|
||||
t.Errorf("expected conversation count, got %q", result.Text)
|
||||
}
|
||||
if !strings.Contains(result.Text, "abc-123") {
|
||||
t.Errorf("expected session ID, got %q", result.Text)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuiltin_Exit(t *testing.T) {
|
||||
r := newTestRegistry()
|
||||
result := r.Execute(&Context{}, "exit", nil)
|
||||
if result.Action != ActionQuit {
|
||||
t.Errorf("exit action = %d, want %d (ActionQuit)", result.Action, ActionQuit)
|
||||
}
|
||||
}
|
||||
116
internal/command/custom.go
Normal file
116
internal/command/custom.go
Normal file
@ -0,0 +1,116 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CustomCommand represents a user-defined command loaded from a markdown file.
|
||||
type CustomCommand struct {
|
||||
Name string
|
||||
Description string
|
||||
Template string // prompt template with {{input}} placeholder
|
||||
}
|
||||
|
||||
// LoadCustomCommands reads .md files from the commands directory and returns
|
||||
// parsed custom commands. Each file should have YAML-like frontmatter:
|
||||
//
|
||||
// ---
|
||||
// name: review
|
||||
// description: Code review prompt
|
||||
// ---
|
||||
// Review this code: {{input}}
|
||||
func LoadCustomCommands(dir string) []CustomCommand {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmds []CustomCommand
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") {
|
||||
continue
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(dir, entry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if cmd, ok := parseCustomCommand(string(data)); ok {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
}
|
||||
return cmds
|
||||
}
|
||||
|
||||
// parseCustomCommand parses a markdown file with YAML frontmatter.
|
||||
func parseCustomCommand(content string) (CustomCommand, bool) {
|
||||
content = strings.TrimSpace(content)
|
||||
if !strings.HasPrefix(content, "---") {
|
||||
return CustomCommand{}, false
|
||||
}
|
||||
|
||||
// Find end of frontmatter.
|
||||
rest := content[3:]
|
||||
idx := strings.Index(rest, "---")
|
||||
if idx < 0 {
|
||||
return CustomCommand{}, false
|
||||
}
|
||||
|
||||
frontmatter := rest[:idx]
|
||||
body := strings.TrimSpace(rest[idx+3:])
|
||||
|
||||
cmd := CustomCommand{Template: body}
|
||||
|
||||
// Parse simple key: value pairs from frontmatter.
|
||||
for _, line := range strings.Split(frontmatter, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(parts[0])
|
||||
val := strings.TrimSpace(parts[1])
|
||||
switch key {
|
||||
case "name":
|
||||
cmd.Name = val
|
||||
case "description":
|
||||
cmd.Description = val
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Name == "" || cmd.Template == "" {
|
||||
return CustomCommand{}, false
|
||||
}
|
||||
|
||||
return cmd, true
|
||||
}
|
||||
|
||||
// RegisterCustomCommands loads and registers custom commands from the given directory.
|
||||
func RegisterCustomCommands(r *Registry, dir string) {
|
||||
cmds := LoadCustomCommands(dir)
|
||||
for _, cc := range cmds {
|
||||
// Capture for closure.
|
||||
tmpl := cc.Template
|
||||
desc := cc.Description
|
||||
if desc == "" {
|
||||
desc = "Custom command"
|
||||
}
|
||||
|
||||
r.Register(&Command{
|
||||
Name: cc.Name,
|
||||
Description: desc,
|
||||
Handler: func(_ *Context, args []string) Result {
|
||||
input := strings.Join(args, " ")
|
||||
prompt := strings.ReplaceAll(tmpl, "{{input}}", input)
|
||||
return Result{
|
||||
Action: ActionSendPrompt,
|
||||
Data: prompt,
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
148
internal/command/custom_test.go
Normal file
148
internal/command/custom_test.go
Normal file
@ -0,0 +1,148 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseCustomCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
wantOK bool
|
||||
wantCmd CustomCommand
|
||||
}{
|
||||
{
|
||||
name: "valid command",
|
||||
content: `---
|
||||
name: review
|
||||
description: Code review prompt
|
||||
---
|
||||
Review this code: {{input}}`,
|
||||
wantOK: true,
|
||||
wantCmd: CustomCommand{
|
||||
Name: "review",
|
||||
Description: "Code review prompt",
|
||||
Template: "Review this code: {{input}}",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no description",
|
||||
content: `---
|
||||
name: explain
|
||||
---
|
||||
Explain this: {{input}}`,
|
||||
wantOK: true,
|
||||
wantCmd: CustomCommand{
|
||||
Name: "explain",
|
||||
Template: "Explain this: {{input}}",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no frontmatter",
|
||||
content: "just some text",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "no name",
|
||||
content: `---
|
||||
description: something
|
||||
---
|
||||
body`,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "no body",
|
||||
content: `---
|
||||
name: empty
|
||||
---`,
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd, ok := parseCustomCommand(tt.content)
|
||||
if ok != tt.wantOK {
|
||||
t.Fatalf("parseCustomCommand() ok = %v, want %v", ok, tt.wantOK)
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if cmd.Name != tt.wantCmd.Name {
|
||||
t.Errorf("Name = %q, want %q", cmd.Name, tt.wantCmd.Name)
|
||||
}
|
||||
if cmd.Description != tt.wantCmd.Description {
|
||||
t.Errorf("Description = %q, want %q", cmd.Description, tt.wantCmd.Description)
|
||||
}
|
||||
if cmd.Template != tt.wantCmd.Template {
|
||||
t.Errorf("Template = %q, want %q", cmd.Template, tt.wantCmd.Template)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCustomCommands(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write a valid command file.
|
||||
err := os.WriteFile(filepath.Join(dir, "review.md"), []byte(`---
|
||||
name: review
|
||||
description: Review code
|
||||
---
|
||||
Review: {{input}}`), 0o644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Write an invalid file (no frontmatter).
|
||||
err = os.WriteFile(filepath.Join(dir, "invalid.md"), []byte("just text"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Write a non-md file (should be ignored).
|
||||
err = os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("not a command"), 0o644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmds := LoadCustomCommands(dir)
|
||||
if len(cmds) != 1 {
|
||||
t.Fatalf("LoadCustomCommands() returned %d commands, want 1", len(cmds))
|
||||
}
|
||||
if cmds[0].Name != "review" {
|
||||
t.Errorf("Name = %q, want %q", cmds[0].Name, "review")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCustomCommands_MissingDir(t *testing.T) {
|
||||
cmds := LoadCustomCommands("/nonexistent/path")
|
||||
if len(cmds) != 0 {
|
||||
t.Errorf("expected empty result for missing dir, got %d", len(cmds))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterCustomCommands(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
err := os.WriteFile(filepath.Join(dir, "test.md"), []byte(`---
|
||||
name: testcmd
|
||||
description: A test command
|
||||
---
|
||||
Do this: {{input}}`), 0o644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
reg := NewRegistry()
|
||||
RegisterCustomCommands(reg, dir)
|
||||
|
||||
result := reg.Execute(&Context{}, "testcmd", []string{"hello", "world"})
|
||||
if result.Action != ActionSendPrompt {
|
||||
t.Errorf("Action = %v, want ActionSendPrompt", result.Action)
|
||||
}
|
||||
if result.Data != "Do this: hello world" {
|
||||
t.Errorf("Data = %q, want %q", result.Data, "Do this: hello world")
|
||||
}
|
||||
}
|
||||
129
internal/command/registry.go
Normal file
129
internal/command/registry.go
Normal file
@ -0,0 +1,129 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Command represents a slash command.
|
||||
type Command struct {
|
||||
Name string
|
||||
Aliases []string
|
||||
Description string
|
||||
Usage string
|
||||
Handler func(ctx *Context, args []string) Result
|
||||
}
|
||||
|
||||
// Context provides commands with read access to application state.
|
||||
type Context struct {
|
||||
Model string
|
||||
ModelList []string
|
||||
AgentProfile string
|
||||
AgentList []string
|
||||
ToolCount int
|
||||
ServerCount int
|
||||
ServerNames []string
|
||||
Skills []SkillInfo
|
||||
LoadedFile string
|
||||
ICEEnabled bool
|
||||
ICEConversations int
|
||||
ICESessionID string
|
||||
// Token stats
|
||||
SessionEvalTotal int
|
||||
SessionPromptTotal int
|
||||
SessionTurnCount int
|
||||
NumCtx int
|
||||
CurrentModel string
|
||||
// File changes
|
||||
FileChanges map[string]int // path → modification count
|
||||
}
|
||||
|
||||
// SkillInfo is a read-only view of a skill for command display.
|
||||
type SkillInfo struct {
|
||||
Name string
|
||||
Description string
|
||||
Active bool
|
||||
}
|
||||
|
||||
// Result is returned by command handlers to describe what to do.
|
||||
type Result struct {
|
||||
Text string // Display text (shown as system message)
|
||||
Action Action // Side effect for the TUI to execute
|
||||
Data string // Optional payload (e.g. file path, model name)
|
||||
Error string // Error text (takes priority over Text)
|
||||
}
|
||||
|
||||
// Action describes a side effect the TUI should perform.
|
||||
type Action int
|
||||
|
||||
const (
|
||||
ActionNone Action = iota
|
||||
ActionShowHelp // Show help overlay
|
||||
ActionClear // Clear conversation history
|
||||
ActionQuit // Exit the application
|
||||
ActionLoadContext // Load markdown context (Data = path)
|
||||
ActionUnloadContext // Remove loaded context
|
||||
ActionActivateSkill // Activate skill (Data = name)
|
||||
ActionDeactivateSkill // Deactivate skill (Data = name)
|
||||
ActionSwitchModel // Switch model (Data = model name)
|
||||
ActionSwitchAgent // Switch agent profile (Data = agent name)
|
||||
ActionShowSessions // Open sessions picker
|
||||
ActionShowModelPicker // Open model picker overlay
|
||||
ActionCommit // Generate commit message and commit
|
||||
ActionSendPrompt // Send Data as a message to the agent
|
||||
ActionExport // Export conversation (Data = path)
|
||||
ActionImport // Import conversation (Data = path)
|
||||
)
|
||||
|
||||
// Registry holds all registered slash commands.
|
||||
type Registry struct {
|
||||
commands map[string]*Command // name/alias → command
|
||||
all []*Command // ordered list
|
||||
}
|
||||
|
||||
// NewRegistry creates an empty command registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
commands: make(map[string]*Command),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a command to the registry.
|
||||
func (r *Registry) Register(cmd *Command) {
|
||||
r.all = append(r.all, cmd)
|
||||
r.commands[cmd.Name] = cmd
|
||||
for _, alias := range cmd.Aliases {
|
||||
r.commands[alias] = cmd
|
||||
}
|
||||
}
|
||||
|
||||
// Execute dispatches a slash command by name and returns the result.
|
||||
func (r *Registry) Execute(ctx *Context, name string, args []string) Result {
|
||||
cmd, ok := r.commands[name]
|
||||
if !ok {
|
||||
return Result{Error: fmt.Sprintf("unknown command: /%s — type /help for available commands", name)}
|
||||
}
|
||||
return cmd.Handler(ctx, args)
|
||||
}
|
||||
|
||||
// All returns all registered commands in registration order.
|
||||
func (r *Registry) All() []*Command {
|
||||
return r.all
|
||||
}
|
||||
|
||||
// Match returns commands whose name starts with the given prefix.
|
||||
func (r *Registry) Match(prefix string) []*Command {
|
||||
var matches []*Command
|
||||
seen := make(map[string]bool)
|
||||
for _, cmd := range r.all {
|
||||
if strings.HasPrefix(cmd.Name, prefix) && !seen[cmd.Name] {
|
||||
matches = append(matches, cmd)
|
||||
seen[cmd.Name] = true
|
||||
}
|
||||
}
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
return matches[i].Name < matches[j].Name
|
||||
})
|
||||
return matches
|
||||
}
|
||||
164
internal/command/registry_test.go
Normal file
164
internal/command/registry_test.go
Normal file
@ -0,0 +1,164 @@
|
||||
package command
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRegistry_Register(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
cmd := &Command{
|
||||
Name: "test",
|
||||
Description: "A test command",
|
||||
Handler: func(_ *Context, _ []string) Result {
|
||||
return Result{Text: "ok"}
|
||||
},
|
||||
}
|
||||
r.Register(cmd)
|
||||
|
||||
all := r.All()
|
||||
if len(all) != 1 {
|
||||
t.Fatalf("expected 1 command, got %d", len(all))
|
||||
}
|
||||
if all[0].Name != "test" {
|
||||
t.Errorf("command name = %q, want %q", all[0].Name, "test")
|
||||
}
|
||||
|
||||
// Execute to verify it was registered correctly
|
||||
result := r.Execute(&Context{}, "test", nil)
|
||||
if result.Text != "ok" {
|
||||
t.Errorf("result text = %q, want %q", result.Text, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Execute(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
called := false
|
||||
r.Register(&Command{
|
||||
Name: "run",
|
||||
Handler: func(_ *Context, _ []string) Result {
|
||||
called = true
|
||||
return Result{Text: "executed"}
|
||||
},
|
||||
})
|
||||
|
||||
t.Run("found command executes handler", func(t *testing.T) {
|
||||
result := r.Execute(&Context{}, "run", nil)
|
||||
if !called {
|
||||
t.Error("handler was not called")
|
||||
}
|
||||
if result.Text != "executed" {
|
||||
t.Errorf("result text = %q, want %q", result.Text, "executed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not found returns error", func(t *testing.T) {
|
||||
result := r.Execute(&Context{}, "nonexistent", nil)
|
||||
if result.Error == "" {
|
||||
t.Error("expected error for unknown command")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegistry_ExecuteByAlias(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&Command{
|
||||
Name: "mycommand",
|
||||
Aliases: []string{"mc", "m"},
|
||||
Handler: func(_ *Context, _ []string) Result {
|
||||
return Result{Text: "alias works"}
|
||||
},
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cmdName string
|
||||
wantOk bool
|
||||
}{
|
||||
{name: "by name", cmdName: "mycommand", wantOk: true},
|
||||
{name: "by alias mc", cmdName: "mc", wantOk: true},
|
||||
{name: "by alias m", cmdName: "m", wantOk: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := r.Execute(&Context{}, tt.cmdName, nil)
|
||||
if tt.wantOk && result.Error != "" {
|
||||
t.Errorf("unexpected error: %s", result.Error)
|
||||
}
|
||||
if tt.wantOk && result.Text != "alias works" {
|
||||
t.Errorf("result text = %q, want %q", result.Text, "alias works")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_All(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
names := []string{"alpha", "beta", "gamma"}
|
||||
for _, name := range names {
|
||||
n := name // capture
|
||||
r.Register(&Command{
|
||||
Name: n,
|
||||
Handler: func(_ *Context, _ []string) Result { return Result{} },
|
||||
})
|
||||
}
|
||||
|
||||
all := r.All()
|
||||
if len(all) != len(names) {
|
||||
t.Fatalf("expected %d commands, got %d", len(names), len(all))
|
||||
}
|
||||
for i, cmd := range all {
|
||||
if cmd.Name != names[i] {
|
||||
t.Errorf("All()[%d].Name = %q, want %q", i, cmd.Name, names[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Match(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&Command{
|
||||
Name: "model",
|
||||
Aliases: []string{"m"},
|
||||
Handler: func(_ *Context, _ []string) Result { return Result{} },
|
||||
})
|
||||
r.Register(&Command{
|
||||
Name: "models",
|
||||
Aliases: []string{"ml"},
|
||||
Handler: func(_ *Context, _ []string) Result { return Result{} },
|
||||
})
|
||||
r.Register(&Command{
|
||||
Name: "help",
|
||||
Handler: func(_ *Context, _ []string) Result { return Result{} },
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
want int
|
||||
}{
|
||||
{name: "prefix mo matches model and models", prefix: "mo", want: 2},
|
||||
{name: "prefix model matches model and models", prefix: "model", want: 2},
|
||||
{name: "prefix models matches only models", prefix: "models", want: 1},
|
||||
{name: "prefix h matches help", prefix: "h", want: 1},
|
||||
{name: "no match", prefix: "z", want: 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := r.Match(tt.prefix)
|
||||
if len(matches) != tt.want {
|
||||
t.Errorf("Match(%q) returned %d results, want %d", tt.prefix, len(matches), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Verify no duplicates from aliases
|
||||
t.Run("aliases dont create dupes", func(t *testing.T) {
|
||||
matches := r.Match("model")
|
||||
seen := make(map[string]bool)
|
||||
for _, m := range matches {
|
||||
if seen[m.Name] {
|
||||
t.Errorf("duplicate match for %q", m.Name)
|
||||
}
|
||||
seen[m.Name] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
366
internal/config/agents.go
Normal file
366
internal/config/agents.go
Normal file
@ -0,0 +1,366 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type AgentsDir struct {
|
||||
Path string
|
||||
Agents map[string]AgentProfile
|
||||
MCPServers []ServerConfig
|
||||
GlobalInstructions string
|
||||
Skills []SkillDef
|
||||
}
|
||||
|
||||
type AgentProfile struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Description string `yaml:"description" json:"description"`
|
||||
Model string `yaml:"model" json:"model"`
|
||||
Skills []string `yaml:"skills" json:"skills"`
|
||||
MCPServers []string `yaml:"mcp_servers" json:"mcp_servers"`
|
||||
SystemPrompt string `yaml:"system_prompt" json:"system_prompt"`
|
||||
UseCases []string `yaml:"use_cases" json:"use_cases"`
|
||||
}
|
||||
|
||||
type SkillDef struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Description string `yaml:"description" json:"description"`
|
||||
Path string `yaml:"path" json:"path"`
|
||||
}
|
||||
|
||||
type MCPConfig struct {
|
||||
Servers []ServerConfig `json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
type ModelsConfig struct {
|
||||
Models []Model `yaml:"models,omitempty"`
|
||||
DefaultModel string `yaml:"default_model,omitempty"`
|
||||
FallbackChain []string `yaml:"fallback_chain,omitempty"`
|
||||
AutoSelect bool `yaml:"auto_select,omitempty"`
|
||||
EmbedModel string `yaml:"embed_model,omitempty"`
|
||||
}
|
||||
|
||||
func FindAgentsDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
candidates := []string{
|
||||
filepath.Join(home, ".agents"),
|
||||
filepath.Join(home, ".config", "agents"),
|
||||
}
|
||||
|
||||
for _, dir := range candidates {
|
||||
if _, err := os.Stat(dir); err == nil {
|
||||
return dir
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func FindAgentsDirWithCreate() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get home dir: %w", err)
|
||||
}
|
||||
|
||||
dirs := []string{
|
||||
filepath.Join(home, ".agents"),
|
||||
filepath.Join(home, ".config", "agents"),
|
||||
}
|
||||
|
||||
for _, dir := range dirs {
|
||||
if _, err := os.Stat(dir); err == nil {
|
||||
return dir, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirs[0], 0755); err != nil {
|
||||
return "", fmt.Errorf("create agents dir: %w", err)
|
||||
}
|
||||
|
||||
return dirs[0], nil
|
||||
}
|
||||
|
||||
func LoadAgentsDir(path string) (*AgentsDir, error) {
|
||||
if path == "" {
|
||||
path = FindAgentsDir()
|
||||
if path == "" {
|
||||
return &AgentsDir{
|
||||
Path: "",
|
||||
Agents: make(map[string]AgentProfile),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
dir := &AgentsDir{
|
||||
Path: path,
|
||||
Agents: make(map[string]AgentProfile),
|
||||
}
|
||||
|
||||
if err := dir.loadAgents(path); err != nil {
|
||||
return nil, fmt.Errorf("load agents: %w", err)
|
||||
}
|
||||
|
||||
if err := dir.loadMCP(path); err != nil {
|
||||
return nil, fmt.Errorf("load MCP: %w", err)
|
||||
}
|
||||
|
||||
if err := dir.loadGlobalInstructions(path); err != nil {
|
||||
return nil, fmt.Errorf("load instructions: %w", err)
|
||||
}
|
||||
|
||||
if err := dir.loadSkills(path); err != nil {
|
||||
return nil, fmt.Errorf("load skills: %w", err)
|
||||
}
|
||||
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
func (d *AgentsDir) loadAgents(path string) error {
|
||||
agentsDir := filepath.Join(path, "agents")
|
||||
entries, err := os.ReadDir(agentsDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
agentPath := filepath.Join(agentsDir, entry.Name(), "agent.yaml")
|
||||
if _, err := os.Stat(agentPath); err != nil {
|
||||
agentPath = filepath.Join(agentsDir, entry.Name(), "agent.md")
|
||||
}
|
||||
if _, err := os.Stat(agentPath); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(agentPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var profile AgentProfile
|
||||
if err := yaml.Unmarshal(data, &profile); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if profile.Name == "" {
|
||||
profile.Name = entry.Name()
|
||||
}
|
||||
|
||||
d.Agents[profile.Name] = profile
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *AgentsDir) loadMCP(path string) error {
|
||||
mcpPath := filepath.Join(path, "mcp.json")
|
||||
data, err := os.ReadFile(mcpPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var mcpCfg MCPConfig
|
||||
if err := json.Unmarshal(data, &mcpCfg); err != nil {
|
||||
return fmt.Errorf("parse mcp.json: %w", err)
|
||||
}
|
||||
|
||||
d.MCPServers = mcpCfg.Servers
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *AgentsDir) loadGlobalInstructions(path string) error {
|
||||
paths := []string{
|
||||
filepath.Join(path, "agents.md"),
|
||||
filepath.Join(path, "instructions.md"),
|
||||
}
|
||||
|
||||
for _, p := range paths {
|
||||
data, err := os.ReadFile(p)
|
||||
if err == nil {
|
||||
d.GlobalInstructions = string(data)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *AgentsDir) loadSkills(path string) error {
|
||||
skillsDir := filepath.Join(path, "skills")
|
||||
entries, err := os.ReadDir(skillsDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
skillDir := filepath.Join(skillsDir, entry.Name())
|
||||
|
||||
// Try both SKILL.md and skill.md (case insensitive check)
|
||||
skillPath := ""
|
||||
for _, name := range []string{"SKILL.md", "skill.md"} {
|
||||
path := filepath.Join(skillDir, name)
|
||||
if info, err := os.Stat(path); err == nil && !info.IsDir() {
|
||||
skillPath = path
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if skillPath == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(skillPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
d.Skills = append(d.Skills, SkillDef{
|
||||
Name: entry.Name(),
|
||||
Description: extractDescription(string(data)),
|
||||
Path: skillPath,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractDescription(content string) string {
|
||||
for _, line := range splitLines(content) {
|
||||
line = trimWhitespace(line)
|
||||
if line == "" || startsWith(line, "#") {
|
||||
continue
|
||||
}
|
||||
return line
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func splitLines(s string) []string {
|
||||
var lines []string
|
||||
start := 0
|
||||
for i, r := range s {
|
||||
if r == '\n' {
|
||||
lines = append(lines, s[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
lines = append(lines, s[start:])
|
||||
return lines
|
||||
}
|
||||
|
||||
func trimWhitespace(s string) string {
|
||||
start := 0
|
||||
end := len(s)
|
||||
for start < end && (s[start] == ' ' || s[start] == '\t') {
|
||||
start++
|
||||
}
|
||||
for end > start && (s[end-1] == ' ' || s[end-1] == '\t') {
|
||||
end--
|
||||
}
|
||||
return s[start:end]
|
||||
}
|
||||
|
||||
func startsWith(s, prefix string) bool {
|
||||
return len(s) >= len(prefix) && s[:len(prefix)] == prefix
|
||||
}
|
||||
|
||||
func (d *AgentsDir) GetAgent(name string) *AgentProfile {
|
||||
if agent, ok := d.Agents[name]; ok {
|
||||
return &agent
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *AgentsDir) ListAgents() []AgentProfile {
|
||||
agents := make([]AgentProfile, 0, len(d.Agents))
|
||||
for _, agent := range d.Agents {
|
||||
agents = append(agents, agent)
|
||||
}
|
||||
return agents
|
||||
}
|
||||
|
||||
func (d *AgentsDir) GetSkills() []SkillDef {
|
||||
return d.Skills
|
||||
}
|
||||
|
||||
func (d *AgentsDir) HasMCP() bool {
|
||||
return len(d.MCPServers) > 0
|
||||
}
|
||||
|
||||
func (d *AgentsDir) GetMCPServers() []ServerConfig {
|
||||
return d.MCPServers
|
||||
}
|
||||
|
||||
func (d *AgentsDir) GetGlobalInstructions() string {
|
||||
return d.GlobalInstructions
|
||||
}
|
||||
|
||||
func CreateDefaultAgentsDir() error {
|
||||
dir, err := FindAgentsDirWithCreate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
subdirs := []string{"agents", "skills", "tasks", "memories"}
|
||||
for _, sub := range subdirs {
|
||||
path := filepath.Join(dir, sub)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if err := os.MkdirAll(path, 0755); err != nil {
|
||||
return fmt.Errorf("create %s: %w", sub, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mcpPath := filepath.Join(dir, "mcp.json")
|
||||
if _, err := os.Stat(mcpPath); err != nil {
|
||||
defaultMCP := MCPConfig{
|
||||
Servers: []ServerConfig{},
|
||||
}
|
||||
data, _ := json.MarshalIndent(defaultMCP, "", " ")
|
||||
if err := os.WriteFile(mcpPath, data, 0644); err != nil {
|
||||
return fmt.Errorf("write mcp.json: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
agentsPath := filepath.Join(dir, "agents.md")
|
||||
if _, err := os.Stat(agentsPath); err != nil {
|
||||
defaultContent := `# Global Agent Instructions
|
||||
|
||||
You are a helpful local AI coding assistant.
|
||||
|
||||
## Guidelines
|
||||
- Be concise and direct
|
||||
- Explain your reasoning
|
||||
- Ask for clarification when needed
|
||||
- Never fabricate information
|
||||
`
|
||||
if err := os.WriteFile(agentsPath, []byte(defaultContent), 0644); err != nil {
|
||||
return fmt.Errorf("write agents.md: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
176
internal/config/agents_test.go
Normal file
176
internal/config/agents_test.go
Normal file
@ -0,0 +1,176 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractDescription(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "first non-header non-empty line",
|
||||
content: "# Title\n\nThis is the description.\nMore text.",
|
||||
want: "This is the description.",
|
||||
},
|
||||
{
|
||||
name: "header only content",
|
||||
content: "# Title\n## Subtitle\n### Another",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace around description",
|
||||
content: "# Title\n\n Indented description \n",
|
||||
want: "Indented description",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractDescription(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractDescription() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitLines(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
want int // expected number of lines
|
||||
}{
|
||||
{name: "normal lines", s: "a\nb\nc", want: 3},
|
||||
{name: "empty string", s: "", want: 1},
|
||||
{name: "trailing newline", s: "a\nb\n", want: 3},
|
||||
{name: "single line", s: "hello", want: 1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := splitLines(tt.s)
|
||||
if len(got) != tt.want {
|
||||
t.Errorf("splitLines(%q) returned %d lines, want %d (lines: %v)", tt.s, len(got), tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrimWhitespace(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
want string
|
||||
}{
|
||||
{name: "tabs", s: "\thello\t", want: "hello"},
|
||||
{name: "spaces", s: " hello ", want: "hello"},
|
||||
{name: "mixed", s: "\t hello \t", want: "hello"},
|
||||
{name: "already trimmed", s: "hello", want: "hello"},
|
||||
{name: "empty", s: "", want: ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := trimWhitespace(tt.s)
|
||||
if got != tt.want {
|
||||
t.Errorf("trimWhitespace(%q) = %q, want %q", tt.s, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartsWith(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
prefix string
|
||||
want bool
|
||||
}{
|
||||
{name: "match", s: "hello world", prefix: "hello", want: true},
|
||||
{name: "no match", s: "hello world", prefix: "world", want: false},
|
||||
{name: "empty prefix", s: "hello", prefix: "", want: true},
|
||||
{name: "longer prefix", s: "hi", prefix: "hello", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := startsWith(tt.s, tt.prefix)
|
||||
if got != tt.want {
|
||||
t.Errorf("startsWith(%q, %q) = %v, want %v", tt.s, tt.prefix, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAgentsDir(t *testing.T) {
|
||||
t.Run("valid temp structure with agent", func(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
// Create agents/test-agent/agent.yaml
|
||||
agentDir := filepath.Join(tmp, "agents", "test-agent")
|
||||
if err := os.MkdirAll(agentDir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
agentYAML := `name: test-agent
|
||||
description: A test agent
|
||||
model: qwen3.5:0.8b
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(agentDir, "agent.yaml"), []byte(agentYAML), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dir, err := LoadAgentsDir(tmp)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAgentsDir() error: %v", err)
|
||||
}
|
||||
if dir.Path != tmp {
|
||||
t.Errorf("Path = %q, want %q", dir.Path, tmp)
|
||||
}
|
||||
if len(dir.Agents) != 1 {
|
||||
t.Errorf("expected 1 agent, got %d", len(dir.Agents))
|
||||
}
|
||||
agent, ok := dir.Agents["test-agent"]
|
||||
if !ok {
|
||||
t.Fatal("expected agent 'test-agent' to exist")
|
||||
}
|
||||
if agent.Description != "A test agent" {
|
||||
t.Errorf("agent description = %q, want %q", agent.Description, "A test agent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty path uses FindAgentsDir", func(t *testing.T) {
|
||||
dir, err := LoadAgentsDir("")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAgentsDir('') error: %v", err)
|
||||
}
|
||||
// Should return a valid AgentsDir (possibly with no agents)
|
||||
if dir == nil {
|
||||
t.Fatal("expected non-nil AgentsDir")
|
||||
}
|
||||
if dir.Agents == nil {
|
||||
t.Error("expected Agents map to be initialized")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nonexistent subdirs dont error", func(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
// Empty temp dir — no agents/, skills/, mcp.json, etc.
|
||||
dir, err := LoadAgentsDir(tmp)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAgentsDir() error: %v", err)
|
||||
}
|
||||
if len(dir.Agents) != 0 {
|
||||
t.Errorf("expected 0 agents, got %d", len(dir.Agents))
|
||||
}
|
||||
})
|
||||
}
|
||||
186
internal/config/config.go
Normal file
186
internal/config/config.go
Normal file
@ -0,0 +1,186 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Ollama OllamaConfig `yaml:"ollama"`
|
||||
Model ModelConfig `yaml:"model,omitempty"`
|
||||
Agents AgentsConfig `yaml:"agents,omitempty"`
|
||||
Servers []ServerConfig `yaml:"servers,omitempty"`
|
||||
SkillsDir string `yaml:"skills_dir,omitempty"`
|
||||
ICE ICEConfig `yaml:"ice,omitempty"`
|
||||
AgentProfile string `yaml:"agent_profile,omitempty"`
|
||||
Tools ToolsConfig `yaml:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type AgentsConfig struct {
|
||||
Dir string `yaml:"dir,omitempty"`
|
||||
AutoLoad bool `yaml:"auto_load"`
|
||||
}
|
||||
|
||||
type ToolsConfig struct {
|
||||
Timeout string `yaml:"timeout,omitempty"` // e.g., "30s", "2m"
|
||||
MaxGrepResults int `yaml:"max_grep_results,omitempty"`
|
||||
MaxIterations int `yaml:"max_iterations,omitempty"`
|
||||
}
|
||||
|
||||
type ICEConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
EmbedModel string `yaml:"embed_model,omitempty"`
|
||||
StorePath string `yaml:"store_path,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaConfig struct {
|
||||
Model string `yaml:"model"`
|
||||
BaseURL string `yaml:"base_url"`
|
||||
NumCtx int `yaml:"num_ctx"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Command string `yaml:"command,omitempty"`
|
||||
Args []string `yaml:"args,omitempty"`
|
||||
Env []string `yaml:"env,omitempty"`
|
||||
Transport string `yaml:"transport,omitempty"`
|
||||
URL string `yaml:"url,omitempty"`
|
||||
}
|
||||
|
||||
func defaults() Config {
|
||||
modelCfg := DefaultModelConfig()
|
||||
return Config{
|
||||
Ollama: OllamaConfig{
|
||||
Model: "qwen3.5:2b",
|
||||
BaseURL: "http://localhost:11434",
|
||||
NumCtx: 262144,
|
||||
},
|
||||
Model: modelCfg,
|
||||
Agents: AgentsConfig{
|
||||
Dir: "",
|
||||
AutoLoad: true,
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
Timeout: "30s",
|
||||
MaxGrepResults: 500,
|
||||
MaxIterations: 10,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
cfg := defaults()
|
||||
|
||||
localPath := findConfigFile()
|
||||
if localPath != "" {
|
||||
data, err := os.ReadFile(localPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read config %s: %w", localPath, err)
|
||||
}
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parse config %s: %w", localPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
agentsDir := cfg.Agents.Dir
|
||||
if agentsDir == "" {
|
||||
agentsDir = FindAgentsDir()
|
||||
}
|
||||
|
||||
var agentsData *AgentsDir
|
||||
if agentsDir != "" && cfg.Agents.AutoLoad {
|
||||
var err error
|
||||
agentsData, err = LoadAgentsDir(agentsDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: failed to load .agents directory: %v\n", err)
|
||||
} else {
|
||||
if agentsData != nil {
|
||||
if cfg.Ollama.Model == "" {
|
||||
cfg.Ollama.Model = cfg.Model.DefaultModel
|
||||
}
|
||||
|
||||
if len(cfg.Servers) == 0 && agentsData.HasMCP() {
|
||||
cfg.Servers = agentsData.GetMCPServers()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
applyEnvOverrides(&cfg)
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func LoadWithAgentsDir() (*Config, *AgentsDir, error) {
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
agentsDir := cfg.Agents.Dir
|
||||
if agentsDir == "" {
|
||||
agentsDir = FindAgentsDir()
|
||||
}
|
||||
var agents *AgentsDir
|
||||
if agentsDir != "" && cfg.Agents.AutoLoad {
|
||||
agents, _ = LoadAgentsDir(agentsDir)
|
||||
}
|
||||
|
||||
return cfg, agents, nil
|
||||
}
|
||||
|
||||
func findConfigFile() string {
|
||||
candidates := []string{
|
||||
"ai-agent.yaml",
|
||||
"ai-agent.yml",
|
||||
"config.yaml",
|
||||
"config.yml",
|
||||
}
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
candidates = append(candidates,
|
||||
filepath.Join(home, ".config", "ai-agent", "config.yaml"),
|
||||
filepath.Join(home, ".config", "ai-agent", "config.yml"),
|
||||
)
|
||||
}
|
||||
for _, path := range candidates {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func applyEnvOverrides(cfg *Config) {
|
||||
if v := os.Getenv("OLLAMA_HOST"); v != "" {
|
||||
cfg.Ollama.BaseURL = v
|
||||
}
|
||||
if v := os.Getenv("LOCAL_AGENT_MODEL"); v != "" {
|
||||
cfg.Ollama.Model = v
|
||||
}
|
||||
if v := os.Getenv("LOCAL_AGENT_AGENTS_DIR"); v != "" {
|
||||
cfg.Agents.Dir = v
|
||||
}
|
||||
if v := os.Getenv("LOCAL_AGENT_TOOLS_TIMEOUT"); v != "" {
|
||||
cfg.Tools.Timeout = v
|
||||
}
|
||||
if v := os.Getenv("LOCAL_AGENT_TOOLS_MAX_GREP"); v != "" {
|
||||
cfg.Tools.MaxGrepResults = parseEnvInt(v, cfg.Tools.MaxGrepResults)
|
||||
}
|
||||
if v := os.Getenv("LOCAL_AGENT_TOOLS_MAX_ITER"); v != "" {
|
||||
cfg.Tools.MaxIterations = parseEnvInt(v, cfg.Tools.MaxIterations)
|
||||
}
|
||||
if v := os.Getenv("LOCAL_AGENT_ICE_EMBED_MODEL"); v != "" {
|
||||
cfg.ICE.EmbedModel = v
|
||||
}
|
||||
}
|
||||
|
||||
func parseEnvInt(v string, defaultVal int) int {
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return i
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
82
internal/config/config_test.go
Normal file
82
internal/config/config_test.go
Normal file
@ -0,0 +1,82 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaults(t *testing.T) {
|
||||
cfg := defaults()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
got string
|
||||
want string
|
||||
}{
|
||||
{name: "Ollama.Model", got: cfg.Ollama.Model, want: "qwen3.5:2b"},
|
||||
{name: "Ollama.BaseURL", got: cfg.Ollama.BaseURL, want: "http://localhost:11434"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.want {
|
||||
t.Errorf("%s = %q, want %q", tt.name, tt.got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if cfg.Ollama.NumCtx != 262144 {
|
||||
t.Errorf("Ollama.NumCtx = %d, want %d", cfg.Ollama.NumCtx, 262144)
|
||||
}
|
||||
|
||||
if !cfg.Model.AutoSelect {
|
||||
t.Error("Model.AutoSelect should be true by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverrides(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envKey string
|
||||
envVal string
|
||||
checkFn func(cfg *Config) string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "OLLAMA_HOST overrides BaseURL",
|
||||
envKey: "OLLAMA_HOST",
|
||||
envVal: "http://custom:1234",
|
||||
checkFn: func(cfg *Config) string {
|
||||
return cfg.Ollama.BaseURL
|
||||
},
|
||||
want: "http://custom:1234",
|
||||
},
|
||||
{
|
||||
name: "LOCAL_AGENT_MODEL overrides Model",
|
||||
envKey: "LOCAL_AGENT_MODEL",
|
||||
envVal: "custom-model",
|
||||
checkFn: func(cfg *Config) string {
|
||||
return cfg.Ollama.Model
|
||||
},
|
||||
want: "custom-model",
|
||||
},
|
||||
{
|
||||
name: "LOCAL_AGENT_AGENTS_DIR overrides AgentsDir",
|
||||
envKey: "LOCAL_AGENT_AGENTS_DIR",
|
||||
envVal: "/custom/agents",
|
||||
checkFn: func(cfg *Config) string {
|
||||
return cfg.Agents.Dir
|
||||
},
|
||||
want: "/custom/agents",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv(tt.envKey, tt.envVal)
|
||||
cfg := defaults()
|
||||
applyEnvOverrides(&cfg)
|
||||
got := tt.checkFn(&cfg)
|
||||
if got != tt.want {
|
||||
t.Errorf("after setting %s=%q, got %q, want %q", tt.envKey, tt.envVal, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
103
internal/config/ignore.go
Normal file
103
internal/config/ignore.go
Normal file
@ -0,0 +1,103 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IgnorePatterns holds parsed .agentignore patterns.
|
||||
type IgnorePatterns struct {
|
||||
patterns []string
|
||||
raw string // original file content for injection into system prompt
|
||||
}
|
||||
|
||||
// LoadIgnoreFile reads and parses an .agentignore file from the given directory.
|
||||
// Returns nil if no .agentignore file exists (not an error).
|
||||
func LoadIgnoreFile(dir string) *IgnorePatterns {
|
||||
path := filepath.Join(dir, ".agentignore")
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var patterns []string
|
||||
var rawLines []string
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
rawLines = append(rawLines, line)
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
// Skip empty lines and comments.
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
patterns = append(patterns, trimmed)
|
||||
}
|
||||
|
||||
return &IgnorePatterns{
|
||||
patterns: patterns,
|
||||
raw: strings.Join(rawLines, "\n"),
|
||||
}
|
||||
}
|
||||
|
||||
// Match returns true if the given path should be ignored.
|
||||
// Returns false if the receiver is nil.
|
||||
func (ip *IgnorePatterns) Match(path string) bool {
|
||||
if ip == nil || len(ip.patterns) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Normalise the path separators and remove trailing slashes for comparison.
|
||||
path = filepath.ToSlash(path)
|
||||
cleanPath := strings.TrimSuffix(path, "/")
|
||||
|
||||
for _, pattern := range ip.patterns {
|
||||
pat := strings.TrimSuffix(pattern, "/")
|
||||
|
||||
// Check each component of the path against the pattern.
|
||||
// e.g. "node_modules" should match "node_modules", "node_modules/foo",
|
||||
// and "src/node_modules/bar".
|
||||
parts := strings.Split(cleanPath, "/")
|
||||
for _, part := range parts {
|
||||
if matched, _ := filepath.Match(pat, part); matched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Also try matching the full path with the pattern (for glob patterns
|
||||
// that include path separators like "build/output").
|
||||
if matched, _ := filepath.Match(pat, cleanPath); matched {
|
||||
return true
|
||||
}
|
||||
|
||||
// Prefix match: path starts with the pattern directory.
|
||||
if strings.HasPrefix(cleanPath, pat+"/") || cleanPath == pat {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Raw returns the raw file content for system prompt injection.
|
||||
// Returns an empty string if the receiver is nil.
|
||||
func (ip *IgnorePatterns) Raw() string {
|
||||
if ip == nil {
|
||||
return ""
|
||||
}
|
||||
return ip.raw
|
||||
}
|
||||
|
||||
// Patterns returns the list of patterns.
|
||||
// Returns nil if the receiver is nil.
|
||||
func (ip *IgnorePatterns) Patterns() []string {
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
return ip.patterns
|
||||
}
|
||||
183
internal/config/ignore_test.go
Normal file
183
internal/config/ignore_test.go
Normal file
@ -0,0 +1,183 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadIgnoreFile_Valid(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
content := `# Build artifacts
|
||||
node_modules
|
||||
*.log
|
||||
.git
|
||||
build/
|
||||
dist/
|
||||
vendor/
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(dir, ".agentignore"), []byte(content), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ip := LoadIgnoreFile(dir)
|
||||
if ip == nil {
|
||||
t.Fatal("expected non-nil IgnorePatterns")
|
||||
}
|
||||
|
||||
wantPatterns := []string{"node_modules", "*.log", ".git", "build/", "dist/", "vendor/"}
|
||||
if len(ip.Patterns()) != len(wantPatterns) {
|
||||
t.Fatalf("got %d patterns, want %d", len(ip.Patterns()), len(wantPatterns))
|
||||
}
|
||||
for i, p := range ip.Patterns() {
|
||||
if p != wantPatterns[i] {
|
||||
t.Errorf("pattern[%d] = %q, want %q", i, p, wantPatterns[i])
|
||||
}
|
||||
}
|
||||
|
||||
if ip.Raw() != content[:len(content)-1] { // raw joins lines without trailing newline from Join
|
||||
// Just check it contains the comment and patterns
|
||||
if ip.Raw() == "" {
|
||||
t.Error("Raw() should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadIgnoreFile_Missing(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
ip := LoadIgnoreFile(dir)
|
||||
if ip != nil {
|
||||
t.Error("expected nil for missing .agentignore")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadIgnoreFile_Empty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, ".agentignore"), []byte(""), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ip := LoadIgnoreFile(dir)
|
||||
if ip == nil {
|
||||
t.Fatal("expected non-nil IgnorePatterns for empty file")
|
||||
}
|
||||
if len(ip.Patterns()) != 0 {
|
||||
t.Errorf("expected 0 patterns, got %d", len(ip.Patterns()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadIgnoreFile_CommentsOnly(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
content := "# This is a comment\n# Another comment\n\n"
|
||||
if err := os.WriteFile(filepath.Join(dir, ".agentignore"), []byte(content), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ip := LoadIgnoreFile(dir)
|
||||
if ip == nil {
|
||||
t.Fatal("expected non-nil IgnorePatterns")
|
||||
}
|
||||
if len(ip.Patterns()) != 0 {
|
||||
t.Errorf("expected 0 patterns for comments-only file, got %d", len(ip.Patterns()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIgnorePatterns_Match_Exact(t *testing.T) {
|
||||
ip := &IgnorePatterns{
|
||||
patterns: []string{"node_modules", ".git", "vendor"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
want bool
|
||||
}{
|
||||
{"node_modules", true},
|
||||
{"node_modules/package/index.js", true},
|
||||
{".git", true},
|
||||
{".git/config", true},
|
||||
{"vendor", true},
|
||||
{"vendor/lib/foo.go", true},
|
||||
{"src/main.go", false},
|
||||
{"README.md", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
if got := ip.Match(tt.path); got != tt.want {
|
||||
t.Errorf("Match(%q) = %v, want %v", tt.path, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIgnorePatterns_Match_Glob(t *testing.T) {
|
||||
ip := &IgnorePatterns{
|
||||
patterns: []string{"*.log", "*.tmp"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
want bool
|
||||
}{
|
||||
{"app.log", true},
|
||||
{"debug.log", true},
|
||||
{"temp.tmp", true},
|
||||
{"logs/app.log", true},
|
||||
{"main.go", false},
|
||||
{"log.txt", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
if got := ip.Match(tt.path); got != tt.want {
|
||||
t.Errorf("Match(%q) = %v, want %v", tt.path, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIgnorePatterns_Match_DirectoryPattern(t *testing.T) {
|
||||
ip := &IgnorePatterns{
|
||||
patterns: []string{"build/", "dist/"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
want bool
|
||||
}{
|
||||
{"build", true},
|
||||
{"build/output.js", true},
|
||||
{"dist", true},
|
||||
{"dist/bundle.js", true},
|
||||
{"src/build.go", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
if got := ip.Match(tt.path); got != tt.want {
|
||||
t.Errorf("Match(%q) = %v, want %v", tt.path, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIgnorePatterns_Match_NilReceiver(t *testing.T) {
|
||||
var ip *IgnorePatterns
|
||||
if ip.Match("anything") {
|
||||
t.Error("nil IgnorePatterns should not match anything")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIgnorePatterns_Raw_NilReceiver(t *testing.T) {
|
||||
var ip *IgnorePatterns
|
||||
if ip.Raw() != "" {
|
||||
t.Error("nil IgnorePatterns Raw() should return empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIgnorePatterns_Patterns_NilReceiver(t *testing.T) {
|
||||
var ip *IgnorePatterns
|
||||
if ip.Patterns() != nil {
|
||||
t.Error("nil IgnorePatterns Patterns() should return nil")
|
||||
}
|
||||
}
|
||||
167
internal/config/models.go
Normal file
167
internal/config/models.go
Normal file
@ -0,0 +1,167 @@
|
||||
package config
|
||||
|
||||
import "fmt"
|
||||
|
||||
type ModelFamily string
|
||||
|
||||
const (
|
||||
FamilyQwen3 ModelFamily = "qwen3"
|
||||
FamilyQwen35 ModelFamily = "qwen3.5"
|
||||
FamilyLlama ModelFamily = "llama"
|
||||
FamilyMistral ModelFamily = "mistral"
|
||||
)
|
||||
|
||||
type ModelCapability int
|
||||
|
||||
const (
|
||||
CapabilitySimple ModelCapability = iota
|
||||
CapabilityMedium
|
||||
CapabilityComplex
|
||||
CapabilityAdvanced
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
Name string `yaml:"name"`
|
||||
Family ModelFamily `yaml:"family"`
|
||||
DisplayName string `yaml:"display_name"`
|
||||
Size string `yaml:"size"`
|
||||
Parameters string `yaml:"parameters"`
|
||||
ContextSize int `yaml:"context_size"`
|
||||
Capability ModelCapability `yaml:"capability"`
|
||||
Speed float64 `yaml:"speed"` // 1.0 = baseline
|
||||
UseCases []string `yaml:"use_cases"`
|
||||
Description string `yaml:"description"`
|
||||
Default bool `yaml:"default,omitempty"`
|
||||
}
|
||||
|
||||
type ModelConfig struct {
|
||||
Models []Model `yaml:"models"`
|
||||
DefaultModel string `yaml:"default_model"`
|
||||
FallbackChain []string `yaml:"fallback_chain"`
|
||||
AutoSelect bool `yaml:"auto_select"`
|
||||
EmbedModel string `yaml:"embed_model,omitempty"`
|
||||
}
|
||||
|
||||
func DefaultModels() []Model {
|
||||
return []Model{
|
||||
{
|
||||
Name: "qwen3.5:0.8b",
|
||||
Family: FamilyQwen35,
|
||||
DisplayName: "Qwen 3.5 0.8B",
|
||||
Size: "0.8B",
|
||||
Parameters: "0.8 billion",
|
||||
ContextSize: 262144,
|
||||
Capability: CapabilitySimple,
|
||||
Speed: 4.0,
|
||||
UseCases: []string{"quick_answers", "simple_tools", "single_file_edits"},
|
||||
Description: "Fast, lightweight model for simple tasks and quick answers",
|
||||
Default: false,
|
||||
},
|
||||
{
|
||||
Name: "qwen3.5:2b",
|
||||
Family: FamilyQwen35,
|
||||
DisplayName: "Qwen 3.5 2B",
|
||||
Size: "2B",
|
||||
Parameters: "2 billion",
|
||||
ContextSize: 262144,
|
||||
Capability: CapabilityMedium,
|
||||
Speed: 2.5,
|
||||
UseCases: []string{"code_completion", "simple_refactoring", "explanations"},
|
||||
Description: "Balanced model for medium complexity tasks",
|
||||
Default: true,
|
||||
},
|
||||
{
|
||||
Name: "qwen3.5:4b",
|
||||
Family: FamilyQwen35,
|
||||
DisplayName: "Qwen 3.5 4B",
|
||||
Size: "4B",
|
||||
Parameters: "4 billion",
|
||||
ContextSize: 262144,
|
||||
Capability: CapabilityComplex,
|
||||
Speed: 1.5,
|
||||
UseCases: []string{"multi_step_reasoning", "code_review", "debugging", "refactoring"},
|
||||
Description: "Capable model for complex reasoning and code analysis",
|
||||
Default: false,
|
||||
},
|
||||
{
|
||||
Name: "qwen3.5:9b",
|
||||
Family: FamilyQwen35,
|
||||
DisplayName: "Qwen 3.5 9B",
|
||||
Size: "9B",
|
||||
Parameters: "9 billion",
|
||||
ContextSize: 262144,
|
||||
Capability: CapabilityAdvanced,
|
||||
Speed: 1.0,
|
||||
UseCases: []string{"complex_reasoning", "architecture", "full_stack", "advanced_debugging"},
|
||||
Description: "Full capability model for advanced tasks",
|
||||
Default: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultModelConfig() ModelConfig {
|
||||
models := DefaultModels()
|
||||
return ModelConfig{
|
||||
Models: models,
|
||||
DefaultModel: "qwen3.5:2b",
|
||||
FallbackChain: []string{"qwen3.5:2b", "qwen3.5:0.8b", "qwen3.5:4b", "qwen3.5:9b"},
|
||||
AutoSelect: true,
|
||||
EmbedModel: "nomic-embed-text",
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) IsSimpleTask() bool {
|
||||
return m.Capability <= CapabilityMedium
|
||||
}
|
||||
|
||||
func (m *Model) IsComplexTask() bool {
|
||||
return m.Capability >= CapabilityComplex
|
||||
}
|
||||
|
||||
func (mc *ModelConfig) GetModel(name string) (*Model, error) {
|
||||
for _, m := range mc.Models {
|
||||
if m.Name == name {
|
||||
return &m, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("model not found: %s", name)
|
||||
}
|
||||
|
||||
func (mc *ModelConfig) GetDefaultModel() *Model {
|
||||
for _, m := range mc.Models {
|
||||
if m.Default {
|
||||
return &m
|
||||
}
|
||||
}
|
||||
if len(mc.Models) > 0 {
|
||||
return &mc.Models[len(mc.Models)-1]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *ModelConfig) SelectModelForTask(taskComplexity string) string {
|
||||
if !mc.AutoSelect {
|
||||
return mc.DefaultModel
|
||||
}
|
||||
|
||||
switch taskComplexity {
|
||||
case "simple":
|
||||
return mc.Models[0].Name
|
||||
case "medium":
|
||||
for _, m := range mc.Models {
|
||||
if m.Capability == CapabilityMedium {
|
||||
return m.Name
|
||||
}
|
||||
}
|
||||
case "complex":
|
||||
for _, m := range mc.Models {
|
||||
if m.Capability == CapabilityComplex {
|
||||
return m.Name
|
||||
}
|
||||
}
|
||||
case "advanced":
|
||||
return mc.DefaultModel
|
||||
}
|
||||
|
||||
return mc.DefaultModel
|
||||
}
|
||||
159
internal/config/models_test.go
Normal file
159
internal/config/models_test.go
Normal file
@ -0,0 +1,159 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestModel_IsSimpleTask(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
capability ModelCapability
|
||||
want bool
|
||||
}{
|
||||
{name: "CapabilitySimple is simple", capability: CapabilitySimple, want: true},
|
||||
{name: "CapabilityMedium is simple", capability: CapabilityMedium, want: true},
|
||||
{name: "CapabilityComplex is not simple", capability: CapabilityComplex, want: false},
|
||||
{name: "CapabilityAdvanced is not simple", capability: CapabilityAdvanced, want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &Model{Capability: tt.capability}
|
||||
if got := m.IsSimpleTask(); got != tt.want {
|
||||
t.Errorf("Model{Capability: %d}.IsSimpleTask() = %v, want %v", tt.capability, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModel_IsComplexTask(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
capability ModelCapability
|
||||
want bool
|
||||
}{
|
||||
{name: "CapabilitySimple is not complex", capability: CapabilitySimple, want: false},
|
||||
{name: "CapabilityMedium is not complex", capability: CapabilityMedium, want: false},
|
||||
{name: "CapabilityComplex is complex", capability: CapabilityComplex, want: true},
|
||||
{name: "CapabilityAdvanced is complex", capability: CapabilityAdvanced, want: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &Model{Capability: tt.capability}
|
||||
if got := m.IsComplexTask(); got != tt.want {
|
||||
t.Errorf("Model{Capability: %d}.IsComplexTask() = %v, want %v", tt.capability, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_GetModel(t *testing.T) {
|
||||
cfg := DefaultModelConfig()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "found model", model: "qwen3.5:0.8b", wantErr: false},
|
||||
{name: "not found", model: "nonexistent", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := cfg.GetModel(tt.model)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if got.Name != tt.model {
|
||||
t.Errorf("GetModel(%q).Name = %q, want %q", tt.model, got.Name, tt.model)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_GetDefaultModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg ModelConfig
|
||||
want string // empty means nil expected
|
||||
}{
|
||||
{
|
||||
name: "model with Default=true",
|
||||
cfg: ModelConfig{
|
||||
Models: []Model{
|
||||
{Name: "a", Default: false},
|
||||
{Name: "b", Default: true},
|
||||
{Name: "c", Default: false},
|
||||
},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
{
|
||||
name: "no default returns last",
|
||||
cfg: ModelConfig{
|
||||
Models: []Model{
|
||||
{Name: "a", Default: false},
|
||||
{Name: "b", Default: false},
|
||||
},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
{
|
||||
name: "empty slice returns nil",
|
||||
cfg: ModelConfig{Models: []Model{}},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.cfg.GetDefaultModel()
|
||||
if tt.want == "" {
|
||||
if got != nil {
|
||||
t.Errorf("expected nil, got %+v", got)
|
||||
}
|
||||
} else {
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil model, got nil")
|
||||
}
|
||||
if got.Name != tt.want {
|
||||
t.Errorf("GetDefaultModel().Name = %q, want %q", got.Name, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_SelectModelForTask(t *testing.T) {
|
||||
cfg := DefaultModelConfig()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
complexity string
|
||||
autoSelect bool
|
||||
want string
|
||||
}{
|
||||
{name: "auto simple", complexity: "simple", autoSelect: true, want: "qwen3.5:0.8b"},
|
||||
{name: "auto medium", complexity: "medium", autoSelect: true, want: "qwen3.5:2b"},
|
||||
{name: "auto complex", complexity: "complex", autoSelect: true, want: "qwen3.5:4b"},
|
||||
{name: "auto advanced", complexity: "advanced", autoSelect: true, want: cfg.DefaultModel},
|
||||
{name: "no autoselect simple", complexity: "simple", autoSelect: false, want: cfg.DefaultModel},
|
||||
{name: "no autoselect complex", complexity: "complex", autoSelect: false, want: cfg.DefaultModel},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg.AutoSelect = tt.autoSelect
|
||||
got := cfg.SelectModelForTask(tt.complexity)
|
||||
if got != tt.want {
|
||||
t.Errorf("SelectModelForTask(%q) = %q, want %q", tt.complexity, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
364
internal/config/qwen_router.go
Normal file
364
internal/config/qwen_router.go
Normal file
@ -0,0 +1,364 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type QwenModelRouter struct {
|
||||
config *ModelConfig
|
||||
overrideLog []ModelOverride
|
||||
modeContext ModeContext
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type ModeContext int
|
||||
|
||||
const (
|
||||
ModeAskContext ModeContext = iota
|
||||
ModePlanContext
|
||||
ModeBuildContext
|
||||
)
|
||||
|
||||
type QwenComplexity string
|
||||
|
||||
const (
|
||||
QwenTrivial QwenComplexity = "trivial"
|
||||
QwenSimple QwenComplexity = "simple"
|
||||
QwenModerate QwenComplexity = "moderate"
|
||||
QwenAdvanced QwenComplexity = "advanced"
|
||||
)
|
||||
|
||||
var (
|
||||
qwenTrivialIndicators = []string{
|
||||
"what is", "who is", "when is", "where is",
|
||||
"define", "meaning of", "synonym", "antonym",
|
||||
"list files", "show me", "display",
|
||||
"yes", "no", "ok", "thanks",
|
||||
"hello", "hi", "hey",
|
||||
}
|
||||
qwenSimpleIndicators = []string{
|
||||
"how do i", "explain", "what does", "why does",
|
||||
"find", "search", "get", "read",
|
||||
"print", "echo", "cat", "ls", "grep",
|
||||
"simple", "quick", "fast", "brief",
|
||||
"check", "verify", "test",
|
||||
"create file", "write file", "save",
|
||||
}
|
||||
qwenModerateIndicators = []string{
|
||||
"create", "generate", "add", "modify", "update",
|
||||
"fix", "debug", "refactor", "optimize",
|
||||
"function", "class", "method", "interface",
|
||||
"test", "unit test", "integration test",
|
||||
"script", "command", "pipeline",
|
||||
"compare", "analyze", "review",
|
||||
"multiple", "several", "across",
|
||||
}
|
||||
qwenAdvancedIndicators = []string{
|
||||
"architecture", "design pattern", "system design",
|
||||
"infrastructure", "deployment", "scaling",
|
||||
"security audit", "performance optimization",
|
||||
"multi-step", "complex", "comprehensive",
|
||||
"build a", "implement", "develop", "engineer",
|
||||
"full stack", "end-to-end", "production",
|
||||
"migration", "refactor entire", "rewrite",
|
||||
}
|
||||
qwenCodePatterns = map[string]QwenComplexity{
|
||||
"variable": QwenSimple,
|
||||
"constant": QwenSimple,
|
||||
"function": QwenSimple,
|
||||
"loop": QwenSimple,
|
||||
"condition": QwenSimple,
|
||||
"array": QwenSimple,
|
||||
"slice": QwenSimple,
|
||||
"map": QwenSimple,
|
||||
"struct": QwenModerate,
|
||||
"interface": QwenModerate,
|
||||
"generics": QwenModerate,
|
||||
"concurrency": QwenModerate,
|
||||
"goroutine": QwenModerate,
|
||||
"channel": QwenModerate,
|
||||
"mutex": QwenModerate,
|
||||
"architecture": QwenAdvanced,
|
||||
"pattern": QwenAdvanced,
|
||||
"microservice": QwenAdvanced,
|
||||
"distributed": QwenAdvanced,
|
||||
"kubernetes": QwenAdvanced,
|
||||
}
|
||||
)
|
||||
|
||||
func NewQwenModelRouter(cfg *ModelConfig) *QwenModelRouter {
|
||||
return &QwenModelRouter{
|
||||
config: cfg,
|
||||
overrideLog: make([]ModelOverride, 0),
|
||||
modeContext: ModeAskContext,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) SetModeContext(mode ModeContext) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.modeContext = mode
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) ClassifyTaskComplexity(query string) QwenComplexity {
|
||||
return classifyQwenTask(query, r.modeContext)
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) SelectModel(query string) string {
|
||||
complexity := r.ClassifyTaskComplexity(query)
|
||||
return r.config.SelectModelForTask(string(complexity))
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) SelectModelForMode(query string, mode ModeContext) string {
|
||||
switch mode {
|
||||
case ModeAskContext:
|
||||
return r.selectAskModel(query)
|
||||
case ModePlanContext:
|
||||
return r.selectPlanModel(query)
|
||||
case ModeBuildContext:
|
||||
return r.selectBuildModel(query)
|
||||
}
|
||||
return r.SelectModel(query)
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) selectAskModel(query string) string {
|
||||
complexity := classifyQwenTask(query, ModeAskContext)
|
||||
switch complexity {
|
||||
case QwenTrivial, QwenSimple:
|
||||
if r.isModelAvailable("qwen3.5:0.8b") {
|
||||
return "qwen3.5:0.8b"
|
||||
}
|
||||
return "qwen3.5:2b"
|
||||
case QwenModerate:
|
||||
return "qwen3.5:2b"
|
||||
case QwenAdvanced:
|
||||
return "qwen3.5:4b"
|
||||
default:
|
||||
return "qwen3.5:2b"
|
||||
}
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) selectPlanModel(query string) string {
|
||||
complexity := classifyQwenTask(query, ModePlanContext)
|
||||
switch complexity {
|
||||
case QwenTrivial, QwenSimple:
|
||||
return "qwen3.5:2b"
|
||||
case QwenModerate:
|
||||
return "qwen3.5:4b"
|
||||
case QwenAdvanced:
|
||||
return "qwen3.5:9b"
|
||||
default:
|
||||
return "qwen3.5:4b"
|
||||
}
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) selectBuildModel(query string) string {
|
||||
complexity := classifyQwenTask(query, ModeBuildContext)
|
||||
switch complexity {
|
||||
case QwenTrivial, QwenSimple:
|
||||
return "qwen3.5:2b"
|
||||
case QwenModerate:
|
||||
return "qwen3.5:4b"
|
||||
case QwenAdvanced:
|
||||
return "qwen3.5:9b"
|
||||
default:
|
||||
return "qwen3.5:4b"
|
||||
}
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) isModelAvailable(name string) bool {
|
||||
for _, m := range r.config.Models {
|
||||
if m.Name == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func classifyQwenTask(query string, mode ModeContext) QwenComplexity {
|
||||
lowerQuery := strings.ToLower(query)
|
||||
words := strings.Fields(lowerQuery)
|
||||
wordCount := len(words)
|
||||
score := 0
|
||||
for _, indicator := range qwenTrivialIndicators {
|
||||
if strings.Contains(lowerQuery, indicator) {
|
||||
score -= 4
|
||||
}
|
||||
}
|
||||
for _, indicator := range qwenSimpleIndicators {
|
||||
if strings.Contains(lowerQuery, indicator) {
|
||||
score -= 1
|
||||
}
|
||||
}
|
||||
for _, indicator := range qwenModerateIndicators {
|
||||
if strings.Contains(lowerQuery, indicator) {
|
||||
score += 2
|
||||
}
|
||||
}
|
||||
for _, indicator := range qwenAdvancedIndicators {
|
||||
if strings.Contains(lowerQuery, indicator) {
|
||||
score += 4
|
||||
}
|
||||
}
|
||||
for pattern, complexity := range qwenCodePatterns {
|
||||
if strings.Contains(lowerQuery, pattern) {
|
||||
switch complexity {
|
||||
case QwenSimple:
|
||||
score -= 1
|
||||
case QwenModerate:
|
||||
score += 2
|
||||
case QwenAdvanced:
|
||||
score += 4
|
||||
}
|
||||
}
|
||||
}
|
||||
if wordCount > 50 {
|
||||
score += 3
|
||||
} else if wordCount > 30 {
|
||||
score += 1
|
||||
} else if wordCount < 5 && score <= 0 {
|
||||
score -= 2
|
||||
}
|
||||
if strings.Contains(lowerQuery, "why") || strings.Contains(lowerQuery, "reason") {
|
||||
score += 2
|
||||
}
|
||||
if strings.Contains(lowerQuery, "how") && wordCount > 10 {
|
||||
score += 1
|
||||
}
|
||||
if strings.Contains(lowerQuery, "?") && wordCount < 10 {
|
||||
score -= 1
|
||||
}
|
||||
switch mode {
|
||||
case ModeAskContext:
|
||||
score -= 1
|
||||
case ModeBuildContext:
|
||||
score += 1
|
||||
}
|
||||
switch {
|
||||
case score <= -3:
|
||||
return QwenTrivial
|
||||
case score <= 1:
|
||||
return QwenSimple
|
||||
case score <= 5:
|
||||
return QwenModerate
|
||||
default:
|
||||
return QwenAdvanced
|
||||
}
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) RecordOverride(query, userModel string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
routerModel := r.SelectModel(query)
|
||||
r.overrideLog = append(r.overrideLog, ModelOverride{
|
||||
Query: query,
|
||||
UserModel: userModel,
|
||||
RouterModel: routerModel,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
if len(r.overrideLog) > 100 {
|
||||
r.overrideLog = r.overrideLog[len(r.overrideLog)-100:]
|
||||
}
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) GetLearnedPatterns() map[string]QwenComplexity {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
if len(r.overrideLog) < 3 {
|
||||
return nil
|
||||
}
|
||||
wordCounts := make(map[string]map[QwenComplexity]int)
|
||||
for _, o := range r.overrideLog {
|
||||
if o.Query == "" || o.UserModel == "" {
|
||||
continue
|
||||
}
|
||||
var complexity QwenComplexity
|
||||
switch {
|
||||
case strings.Contains(o.UserModel, "0.8b"):
|
||||
complexity = QwenTrivial
|
||||
case strings.Contains(o.UserModel, "2b"):
|
||||
complexity = QwenSimple
|
||||
case strings.Contains(o.UserModel, "4b"):
|
||||
complexity = QwenModerate
|
||||
case strings.Contains(o.UserModel, "9b"):
|
||||
complexity = QwenAdvanced
|
||||
default:
|
||||
continue
|
||||
}
|
||||
words := strings.Fields(strings.ToLower(o.Query))
|
||||
for _, w := range words {
|
||||
if len(w) < 3 {
|
||||
continue
|
||||
}
|
||||
if _, ok := wordCounts[w]; !ok {
|
||||
wordCounts[w] = make(map[QwenComplexity]int)
|
||||
}
|
||||
wordCounts[w][complexity]++
|
||||
}
|
||||
}
|
||||
wordComplexity := make(map[string]QwenComplexity)
|
||||
for word, counts := range wordCounts {
|
||||
var maxCount int
|
||||
var dominant QwenComplexity
|
||||
for c, cnt := range counts {
|
||||
if cnt > maxCount {
|
||||
maxCount = cnt
|
||||
dominant = c
|
||||
}
|
||||
}
|
||||
if maxCount >= 2 {
|
||||
wordComplexity[word] = dominant
|
||||
}
|
||||
}
|
||||
return wordComplexity
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) SelectAvailableModelForTask(ctx context.Context, pinger ModelPinger, query string) string {
|
||||
preferred := r.SelectModel(query)
|
||||
fallbackOrder := []string{
|
||||
preferred,
|
||||
"qwen3.5:2b",
|
||||
"qwen3.5:0.8b",
|
||||
"qwen3.5:4b",
|
||||
"qwen3.5:9b",
|
||||
}
|
||||
for _, model := range fallbackOrder {
|
||||
if err := pinger.PingModel(ctx, model); err == nil {
|
||||
return model
|
||||
}
|
||||
}
|
||||
return r.config.DefaultModel
|
||||
}
|
||||
|
||||
func (r *QwenModelRouter) GetRecommendedModel(query string) (model string, reason string, complexity QwenComplexity) {
|
||||
r.mu.RLock()
|
||||
mode := r.modeContext
|
||||
r.mu.RUnlock()
|
||||
complexity = classifyQwenTask(query, mode)
|
||||
switch complexity {
|
||||
case QwenTrivial:
|
||||
model = "qwen3.5:0.8b"
|
||||
reason = "trivial task - ultra-fast response"
|
||||
case QwenSimple:
|
||||
model = "qwen3.5:2b"
|
||||
reason = "simple task - balanced speed/capability"
|
||||
case QwenModerate:
|
||||
model = "qwen3.5:4b"
|
||||
reason = "moderate complexity - multi-step reasoning"
|
||||
case QwenAdvanced:
|
||||
model = "qwen3.5:9b"
|
||||
reason = "advanced task - complex reasoning required"
|
||||
}
|
||||
switch mode {
|
||||
case ModeAskContext:
|
||||
reason += " (ASK mode - prefer speed)"
|
||||
case ModePlanContext:
|
||||
reason += " (PLAN mode - prefer reasoning)"
|
||||
case ModeBuildContext:
|
||||
reason += " (BUILD mode - prefer capability)"
|
||||
}
|
||||
return model, reason, complexity
|
||||
}
|
||||
254
internal/config/qwen_router_test.go
Normal file
254
internal/config/qwen_router_test.go
Normal file
@ -0,0 +1,254 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestQwenRouter_ClassifyTrivial(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
maxComplexity QwenComplexity
|
||||
}{
|
||||
{"simple what", "what is go", QwenTrivial},
|
||||
{"simple who", "who created go", QwenSimple},
|
||||
{"simple define", "define interface", QwenTrivial},
|
||||
{"simple greeting", "hello", QwenTrivial},
|
||||
{"simple thanks", "thanks", QwenTrivial},
|
||||
{"simple list", "list files", QwenTrivial},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyQwenTask(tt.query, ModeAskContext)
|
||||
if got > tt.maxComplexity {
|
||||
t.Errorf("classifyQwenTask(%q) = %v, want <= %v", tt.query, got, tt.maxComplexity)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRouter_ClassifySimple(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
}{
|
||||
{"simple how", "how do i create a file"},
|
||||
{"simple explain", "explain this code"},
|
||||
{"simple find", "find all go files"},
|
||||
{"simple check", "check if file exists"},
|
||||
{"simple read", "read config file"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyQwenTask(tt.query, ModeAskContext)
|
||||
t.Logf("%s: %v", tt.query, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRouter_ClassifyModerate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
}{
|
||||
{"create function", "create a function to parse json"},
|
||||
{"debug issue", "debug this nil pointer error"},
|
||||
{"refactor code", "refactor this function"},
|
||||
{"add test", "add unit tests for handler"},
|
||||
{"optimize query", "optimize this database query"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyQwenTask(tt.query, ModeBuildContext)
|
||||
t.Logf("%s: %v", tt.query, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRouter_ClassifyAdvanced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
}{
|
||||
{"architecture", "design microservice architecture"},
|
||||
{"system design", "system design for high traffic"},
|
||||
{"security audit", "security audit of api"},
|
||||
{"full stack", "build a full stack application"},
|
||||
{"migration", "migration from mysql to postgres"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyQwenTask(tt.query, ModeBuildContext)
|
||||
t.Logf("%s: %v", tt.query, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRouter_ModeAffectsClassification(t *testing.T) {
|
||||
query := "how do i fix this bug"
|
||||
|
||||
ask := classifyQwenTask(query, ModeAskContext)
|
||||
build := classifyQwenTask(query, ModeBuildContext)
|
||||
|
||||
// BUILD mode should generally prefer equal or larger models than ASK
|
||||
// Note: This is a soft requirement - the mode adjustment is subtle
|
||||
t.Logf("ASK mode: %v, BUILD mode: %v", ask, build)
|
||||
}
|
||||
|
||||
func TestQwenRouter_WordCountAffectsClassification(t *testing.T) {
|
||||
short := "what is go"
|
||||
long := "what is the go programming language and how does it compare to rust and what are its main features and use cases in modern software development"
|
||||
|
||||
shortComplexity := classifyQwenTask(short, ModeAskContext)
|
||||
longComplexity := classifyQwenTask(long, ModeAskContext)
|
||||
|
||||
// Long query should ideally be more complex, but at minimum not less
|
||||
// Note: This test documents the behavior - word count does affect scoring
|
||||
t.Logf("short (%d chars): %v, long (%d chars): %v", len(short), shortComplexity, len(long), longComplexity)
|
||||
}
|
||||
|
||||
func TestQwenRouter_CodePatterns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
maxComplexity QwenComplexity
|
||||
}{
|
||||
{"simple variable", "declare a variable", QwenModerate},
|
||||
{"simple function", "write a function", QwenAdvanced},
|
||||
{"moderate struct", "define a struct", QwenAdvanced},
|
||||
{"moderate interface", "implement an interface", QwenAdvanced},
|
||||
{"moderate concurrency", "add concurrency with goroutines", QwenAdvanced},
|
||||
{"advanced architecture", "design the architecture", QwenAdvanced},
|
||||
{"advanced distributed", "distributed system design", QwenAdvanced},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyQwenTask(tt.query, ModeBuildContext)
|
||||
// All code patterns should classify as something (not panic)
|
||||
t.Logf("%s: %v", tt.query, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRouter_SelectAskModel(t *testing.T) {
|
||||
cfg := DefaultModelConfig()
|
||||
router := NewQwenModelRouter(&cfg)
|
||||
router.SetModeContext(ModeAskContext)
|
||||
|
||||
// Simple question should get small model
|
||||
model := router.SelectModelForMode("what is go", ModeAskContext)
|
||||
if model != "qwen3.5:0.8b" && model != "qwen3.5:2b" {
|
||||
t.Errorf("ASK mode simple query should get small model, got %s", model)
|
||||
}
|
||||
|
||||
// Complex question should get capable model (2B or higher)
|
||||
model = router.SelectModelForMode("design a distributed system", ModeAskContext)
|
||||
if model == "qwen3.5:0.8b" {
|
||||
t.Errorf("ASK mode complex query should not get 0.8B model, got %s", model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRouter_SelectPlanModel(t *testing.T) {
|
||||
cfg := DefaultModelConfig()
|
||||
router := NewQwenModelRouter(&cfg)
|
||||
router.SetModeContext(ModePlanContext)
|
||||
|
||||
// Planning should prefer 4B for reasoning
|
||||
model := router.SelectModelForMode("plan the architecture", ModePlanContext)
|
||||
if model != "qwen3.5:4b" && model != "qwen3.5:9b" {
|
||||
t.Errorf("PLAN mode should prefer 4B or 9B, got %s", model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRouter_SelectBuildModel(t *testing.T) {
|
||||
cfg := DefaultModelConfig()
|
||||
router := NewQwenModelRouter(&cfg)
|
||||
router.SetModeContext(ModeBuildContext)
|
||||
|
||||
// Building should prefer capable models
|
||||
model := router.SelectModelForMode("implement the feature", ModeBuildContext)
|
||||
if model != "qwen3.5:4b" && model != "qwen3.5:9b" {
|
||||
t.Errorf("BUILD mode should prefer 4B or 9B, got %s", model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRouter_GetRecommendedModel(t *testing.T) {
|
||||
cfg := DefaultModelConfig()
|
||||
router := NewQwenModelRouter(&cfg)
|
||||
|
||||
model, reason, complexity := router.GetRecommendedModel("what is go")
|
||||
|
||||
if model == "" {
|
||||
t.Error("GetRecommendedModel should return a model")
|
||||
}
|
||||
if reason == "" {
|
||||
t.Error("GetRecommendedModel should return a reason")
|
||||
}
|
||||
if complexity == "" {
|
||||
t.Error("GetRecommendedModel should return a complexity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRouter_QuestionMarkHandling(t *testing.T) {
|
||||
// Short questions with ? should be simpler
|
||||
short := "what is go?"
|
||||
long := "can you explain what the go programming language is and how it works?"
|
||||
|
||||
shortComplexity := classifyQwenTask(short, ModeAskContext)
|
||||
longComplexity := classifyQwenTask(long, ModeAskContext)
|
||||
|
||||
if shortComplexity >= longComplexity {
|
||||
t.Logf("Note: short question complexity (%v) vs long (%v)", shortComplexity, longComplexity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRouter_WhyQuestions(t *testing.T) {
|
||||
// Why questions need reasoning
|
||||
why := "why does this code fail"
|
||||
what := "what does this code do"
|
||||
|
||||
whyComplexity := classifyQwenTask(why, ModeAskContext)
|
||||
whatComplexity := classifyQwenTask(what, ModeAskContext)
|
||||
|
||||
if whyComplexity < whatComplexity {
|
||||
t.Errorf("why questions should be more complex: why=%v, what=%v", whyComplexity, whatComplexity)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkQwenRouter_ClassifyTask(b *testing.B) {
|
||||
queries := []string{
|
||||
"what is go",
|
||||
"how do i create a file",
|
||||
"debug this nil pointer error",
|
||||
"design microservice architecture",
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, q := range queries {
|
||||
_ = classifyQwenTask(q, ModeAskContext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkQwenRouter_SelectModel(b *testing.B) {
|
||||
cfg := DefaultModelConfig()
|
||||
router := NewQwenModelRouter(&cfg)
|
||||
queries := []string{
|
||||
"what is go",
|
||||
"how do i create a file",
|
||||
"debug this nil pointer error",
|
||||
"design microservice architecture",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, q := range queries {
|
||||
_ = router.SelectModel(q)
|
||||
}
|
||||
}
|
||||
}
|
||||
318
internal/config/router.go
Normal file
318
internal/config/router.go
Normal file
@ -0,0 +1,318 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TaskComplexity string
|
||||
|
||||
const (
|
||||
ComplexitySimple TaskComplexity = "simple"
|
||||
ComplexityMedium TaskComplexity = "medium"
|
||||
ComplexityComplex TaskComplexity = "complex"
|
||||
ComplexityAdvanced TaskComplexity = "advanced"
|
||||
)
|
||||
|
||||
var simpleIndicators = []string{
|
||||
"what is", "how do i", "explain", "what does",
|
||||
"find", "search", "list", "show", "get",
|
||||
"print", "echo", "read", "cat", "ls",
|
||||
"simple", "quick", "fast",
|
||||
}
|
||||
|
||||
var mediumIndicators = []string{
|
||||
"create", "write", "generate", "add", "modify",
|
||||
"change", "update", "fix", "refactor",
|
||||
"function", "class", "variable", "test",
|
||||
"script", "command", "file", "directory",
|
||||
}
|
||||
|
||||
var complexIndicators = []string{
|
||||
"debug", "error", "bug", "issue", "problem",
|
||||
"refactor", "architecture", "design", "review",
|
||||
"multiple", "several", "across", "migrate",
|
||||
"optimize", "performance", "security",
|
||||
"explain why", "analyze", "compare",
|
||||
}
|
||||
|
||||
var advancedIndicators = []string{
|
||||
"build a", "create a", "implement", "develop",
|
||||
"full stack", "system", "infrastructure",
|
||||
"multi-step", "complex", "comprehensive",
|
||||
"security audit", "architecture design",
|
||||
}
|
||||
|
||||
// ModelPinger is an interface for checking if a model is available.
|
||||
type ModelPinger interface {
|
||||
PingModel(ctx context.Context, model string) error
|
||||
}
|
||||
|
||||
// ModelOverride records when a user explicitly selects a model.
|
||||
type ModelOverride struct {
|
||||
Query string
|
||||
UserModel string
|
||||
RouterModel string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
type Router struct {
|
||||
config *ModelConfig
|
||||
overrideLog []ModelOverride
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRouter(cfg *ModelConfig) *Router {
|
||||
return &Router{
|
||||
config: cfg,
|
||||
overrideLog: make([]ModelOverride, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) ClassifyTaskComplexity(query string) TaskComplexity {
|
||||
return ClassifyTask(query)
|
||||
}
|
||||
|
||||
func (r *Router) SelectModel(query string) string {
|
||||
complexity := r.ClassifyTaskComplexity(query)
|
||||
|
||||
// Check learned patterns if we have enough data
|
||||
wordComplexity := r.getLearnedPatterns()
|
||||
if len(wordComplexity) > 0 {
|
||||
words := strings.Fields(strings.ToLower(query))
|
||||
|
||||
// Count votes from learned patterns
|
||||
complexityVotes := make(map[TaskComplexity]int)
|
||||
for _, w := range words {
|
||||
if len(w) >= 3 { // Skip short words
|
||||
if c, ok := wordComplexity[w]; ok {
|
||||
complexityVotes[c]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If strong learned signal (>30% words match a pattern), use it
|
||||
if len(words) > 0 {
|
||||
matchRatio := float64(complexityVotes[ComplexitySimple]+complexityVotes[ComplexityAdvanced]) / float64(len(words))
|
||||
if matchRatio > 0.3 {
|
||||
if complexityVotes[ComplexityAdvanced] > complexityVotes[ComplexitySimple] {
|
||||
complexity = ComplexityAdvanced
|
||||
} else if complexityVotes[ComplexitySimple] > complexityVotes[ComplexityAdvanced] {
|
||||
complexity = ComplexitySimple
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return r.config.SelectModelForTask(string(complexity))
|
||||
}
|
||||
|
||||
// RecordOverride logs when a user explicitly selects a model.
|
||||
// This helps the router learn from user preferences.
|
||||
func (r *Router) RecordOverride(query, userModel string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
routerModel := r.SelectModel(query)
|
||||
|
||||
r.overrideLog = append(r.overrideLog, ModelOverride{
|
||||
Query: query,
|
||||
UserModel: userModel,
|
||||
RouterModel: routerModel,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
|
||||
// Keep last 100 overrides
|
||||
if len(r.overrideLog) > 100 {
|
||||
r.overrideLog = r.overrideLog[len(r.overrideLog)-100:]
|
||||
}
|
||||
}
|
||||
|
||||
// getLearnedPatterns analyzes override history to find word->complexity mappings.
|
||||
func (r *Router) getLearnedPatterns() map[string]TaskComplexity {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if len(r.overrideLog) < 3 {
|
||||
return nil // Not enough data
|
||||
}
|
||||
|
||||
wordCounts := make(map[string]map[TaskComplexity]int)
|
||||
|
||||
for _, o := range r.overrideLog {
|
||||
if o.Query == "" || o.UserModel == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine complexity from user-selected model
|
||||
var complexity TaskComplexity
|
||||
switch {
|
||||
case strings.Contains(o.UserModel, "0.8") || strings.Contains(o.UserModel, "2b"):
|
||||
complexity = ComplexitySimple
|
||||
case strings.Contains(o.UserModel, "4b"):
|
||||
complexity = ComplexityMedium
|
||||
case strings.Contains(o.UserModel, "9b"):
|
||||
complexity = ComplexityAdvanced
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
words := strings.Fields(strings.ToLower(o.Query))
|
||||
for _, w := range words {
|
||||
if len(w) < 3 {
|
||||
continue // Skip short words
|
||||
}
|
||||
if _, ok := wordCounts[w]; !ok {
|
||||
wordCounts[w] = make(map[TaskComplexity]int)
|
||||
}
|
||||
wordCounts[w][complexity]++
|
||||
}
|
||||
}
|
||||
|
||||
// For each word, find dominant complexity
|
||||
wordComplexity := make(map[string]TaskComplexity)
|
||||
for word, counts := range wordCounts {
|
||||
var maxCount int
|
||||
var dominant TaskComplexity
|
||||
for c, cnt := range counts {
|
||||
if cnt > maxCount {
|
||||
maxCount = cnt
|
||||
dominant = c
|
||||
}
|
||||
}
|
||||
// Only use if we have enough samples (at least 2 overrides)
|
||||
if maxCount >= 2 {
|
||||
wordComplexity[word] = dominant
|
||||
}
|
||||
}
|
||||
|
||||
return wordComplexity
|
||||
}
|
||||
|
||||
func (r *Router) GetFallbackChain(currentModel string) []string {
|
||||
chain := r.config.FallbackChain
|
||||
|
||||
for i, model := range chain {
|
||||
if model == currentModel {
|
||||
return chain[i:]
|
||||
}
|
||||
}
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
func (r *Router) GetModelForCapability(capability ModelCapability) string {
|
||||
for _, m := range r.config.Models {
|
||||
if m.Capability == capability {
|
||||
return m.Name
|
||||
}
|
||||
}
|
||||
return r.config.DefaultModel
|
||||
}
|
||||
|
||||
// SelectAvailableModel returns the first available model from the fallback chain.
|
||||
// It checks each model in order and returns the first one that responds to a ping.
|
||||
// If no models are available, returns the default model.
|
||||
func (r *Router) SelectAvailableModel(ctx context.Context, pinger ModelPinger) string {
|
||||
chain := r.config.FallbackChain
|
||||
|
||||
for _, model := range chain {
|
||||
if err := pinger.PingModel(ctx, model); err == nil {
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to default if none available
|
||||
return r.config.DefaultModel
|
||||
}
|
||||
|
||||
// SelectAvailableModelForTask returns the first available model for the given task complexity.
|
||||
// It prioritizes models appropriate for the task, then falls back to larger models if unavailable.
|
||||
func (r *Router) SelectAvailableModelForTask(ctx context.Context, pinger ModelPinger, query string) string {
|
||||
// First, get the preferred model for this task
|
||||
preferred := r.SelectModel(query)
|
||||
|
||||
// Check if preferred model is available
|
||||
if err := pinger.PingModel(ctx, preferred); err == nil {
|
||||
return preferred
|
||||
}
|
||||
|
||||
// Try fallback chain
|
||||
chain := r.GetFallbackChain(preferred)
|
||||
for _, model := range chain {
|
||||
if err := pinger.PingModel(ctx, model); err == nil {
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: default model
|
||||
return r.config.DefaultModel
|
||||
}
|
||||
|
||||
func (r *Router) ForceModel(name string) (*Model, error) {
|
||||
return r.config.GetModel(name)
|
||||
}
|
||||
|
||||
func (r *Router) ListModels() []Model {
|
||||
return r.config.Models
|
||||
}
|
||||
|
||||
func (r *Router) GetDefaultModel() string {
|
||||
return r.config.DefaultModel
|
||||
}
|
||||
|
||||
func ClassifyTask(query string) TaskComplexity {
|
||||
lowerQuery := strings.ToLower(query)
|
||||
wordCount := len(strings.Fields(query))
|
||||
|
||||
score := 0
|
||||
|
||||
for _, indicator := range simpleIndicators {
|
||||
if strings.Contains(lowerQuery, indicator) {
|
||||
score -= 2
|
||||
}
|
||||
}
|
||||
|
||||
for _, indicator := range mediumIndicators {
|
||||
if strings.Contains(lowerQuery, indicator) {
|
||||
score += 1
|
||||
}
|
||||
}
|
||||
|
||||
for _, indicator := range complexIndicators {
|
||||
if strings.Contains(lowerQuery, indicator) {
|
||||
score += 2
|
||||
}
|
||||
}
|
||||
|
||||
for _, indicator := range advancedIndicators {
|
||||
if strings.Contains(lowerQuery, indicator) {
|
||||
score += 3
|
||||
}
|
||||
}
|
||||
|
||||
if wordCount > 50 {
|
||||
score += 2
|
||||
}
|
||||
|
||||
if strings.Contains(lowerQuery, "why") || strings.Contains(lowerQuery, "reason") {
|
||||
score += 1
|
||||
}
|
||||
|
||||
if strings.Contains(lowerQuery, "how") && wordCount > 10 {
|
||||
score += 1
|
||||
}
|
||||
|
||||
switch {
|
||||
case score <= -2:
|
||||
return ComplexitySimple
|
||||
case score <= 1:
|
||||
return ComplexityMedium
|
||||
case score <= 4:
|
||||
return ComplexityComplex
|
||||
default:
|
||||
return ComplexityAdvanced
|
||||
}
|
||||
}
|
||||
166
internal/config/router_test.go
Normal file
166
internal/config/router_test.go
Normal file
@ -0,0 +1,166 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClassifyTask(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
want TaskComplexity
|
||||
}{
|
||||
{name: "empty query", query: "", want: ComplexityMedium},
|
||||
{name: "simple what is", query: "what is Go", want: ComplexitySimple},
|
||||
|
||||
// "create a function": medium "create" +1, "function" +1, advanced "create a" +3 = 5 → advanced
|
||||
{name: "create a function is advanced due to overlaps", query: "create a function", want: ComplexityAdvanced},
|
||||
|
||||
// "debug this error across multiple files": complex "debug" +2, "error" +2, "bug" +2 (substring of debug),
|
||||
// "multiple" +2, "across" +2 = 10, medium "file" +1 = 11 → advanced
|
||||
{name: "debug across files is advanced", query: "debug this error across multiple files", want: ComplexityAdvanced},
|
||||
|
||||
// "implement a full stack system with infrastructure": advanced "implement" +3, "full stack" +3, "system" +3,
|
||||
// "infrastructure" +3 = 12 → advanced
|
||||
{name: "advanced full stack system", query: "implement a full stack system with infrastructure", want: ComplexityAdvanced},
|
||||
|
||||
// Boundary: "explain" → simple -2 → score -2 → simple
|
||||
{name: "boundary simple score -2", query: "explain", want: ComplexitySimple},
|
||||
|
||||
// No indicators → score 0 → medium
|
||||
{name: "boundary medium score 0", query: "hello world", want: ComplexityMedium},
|
||||
|
||||
// "create" → medium +1, but also matches advanced "create a"? No, "create" doesn't contain "create a".
|
||||
// So just +1 → medium
|
||||
{name: "boundary medium score 1", query: "create", want: ComplexityMedium},
|
||||
|
||||
// "debug" alone: complex "debug" +2, "bug" +2 (substring) = 4 → complex
|
||||
{name: "debug alone is complex", query: "debug", want: ComplexityComplex},
|
||||
|
||||
// "debug error": "debug" +2, "error" +2, "bug" +2 (substring of debug) = 6 → advanced
|
||||
{name: "debug error is advanced", query: "debug error", want: ComplexityAdvanced},
|
||||
|
||||
// Word count >50 bonus (+2) with "debug": "debug" +2, "bug" +2 = 4, +2 word bonus = 6 → advanced
|
||||
{
|
||||
name: "word count bonus over 50 with debug",
|
||||
query: strings.Repeat("word ", 51) + "debug",
|
||||
want: ComplexityAdvanced,
|
||||
},
|
||||
|
||||
// "why does this happen": "why" +1 = 1 → medium
|
||||
{name: "why bonus", query: "why does this happen", want: ComplexityMedium},
|
||||
|
||||
// "reason for the crash": "reason" +1 = 1 → medium
|
||||
{name: "reason bonus", query: "reason for the crash", want: ComplexityMedium},
|
||||
|
||||
// "how about we think...": no indicators, "how" + >10 words +1 = 1 → medium
|
||||
{name: "how with many words", query: "how about we think about the things that are happening right now in the code base", want: ComplexityMedium},
|
||||
|
||||
// Case insensitivity
|
||||
{name: "case insensitive WHAT IS", query: "WHAT IS Go", want: ComplexitySimple},
|
||||
{name: "case insensitive EXPLAIN", query: "EXPLAIN this code", want: ComplexitySimple},
|
||||
|
||||
// Pure simple: multiple simple indicators
|
||||
{name: "multiple simple indicators", query: "what is this simple quick search", want: ComplexitySimple},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ClassifyTask(tt.query)
|
||||
if got != tt.want {
|
||||
t.Errorf("ClassifyTask(%q) = %q, want %q", tt.query, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_GetFallbackChain(t *testing.T) {
|
||||
cfg := &ModelConfig{
|
||||
FallbackChain: []string{"a", "b", "c", "d"},
|
||||
}
|
||||
r := NewRouter(cfg)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantLen int
|
||||
wantAll bool // true means expect full chain
|
||||
}{
|
||||
{name: "found at start", model: "a", wantLen: 4},
|
||||
{name: "found in middle", model: "c", wantLen: 2},
|
||||
{name: "found at end", model: "d", wantLen: 1},
|
||||
{name: "not found returns full chain", model: "unknown", wantLen: 4, wantAll: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := r.GetFallbackChain(tt.model)
|
||||
if len(got) != tt.wantLen {
|
||||
t.Errorf("GetFallbackChain(%q) returned %d items, want %d", tt.model, len(got), tt.wantLen)
|
||||
}
|
||||
if tt.wantAll && got[0] != "a" {
|
||||
t.Errorf("GetFallbackChain(%q) first element = %q, want %q", tt.model, got[0], "a")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_GetModelForCapability(t *testing.T) {
|
||||
cfg := &ModelConfig{
|
||||
Models: []Model{
|
||||
{Name: "fast", Capability: CapabilitySimple},
|
||||
{Name: "mid", Capability: CapabilityMedium},
|
||||
{Name: "big", Capability: CapabilityComplex},
|
||||
},
|
||||
DefaultModel: "fallback",
|
||||
}
|
||||
r := NewRouter(cfg)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
capability ModelCapability
|
||||
want string
|
||||
}{
|
||||
{name: "match simple", capability: CapabilitySimple, want: "fast"},
|
||||
{name: "match medium", capability: CapabilityMedium, want: "mid"},
|
||||
{name: "match complex", capability: CapabilityComplex, want: "big"},
|
||||
{name: "no match returns default", capability: CapabilityAdvanced, want: "fallback"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := r.GetModelForCapability(tt.capability)
|
||||
if got != tt.want {
|
||||
t.Errorf("GetModelForCapability(%d) = %q, want %q", tt.capability, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel(t *testing.T) {
|
||||
cfg := DefaultModelConfig()
|
||||
r := NewRouter(&cfg)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
want string
|
||||
}{
|
||||
// "what is Go" → simple → first model
|
||||
{name: "simple query selects first model", query: "what is Go", want: cfg.Models[0].Name},
|
||||
// "debug" → complex → complex-capable model
|
||||
{name: "complex query selects complex model", query: "debug", want: "qwen3.5:4b"},
|
||||
// "implement a system" → advanced → DefaultModel
|
||||
{name: "advanced query selects default model", query: "implement a full stack system", want: cfg.DefaultModel},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := r.SelectModel(tt.query)
|
||||
if got != tt.want {
|
||||
t.Errorf("SelectModel(%q) = %q, want %q", tt.query, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
31
internal/db/db.go
Normal file
31
internal/db/db.go
Normal file
@ -0,0 +1,31 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type DBTX interface {
|
||||
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
|
||||
PrepareContext(context.Context, string) (*sql.Stmt, error)
|
||||
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
func New(db DBTX) *Queries {
|
||||
return &Queries{db: db}
|
||||
}
|
||||
|
||||
type Queries struct {
|
||||
db DBTX
|
||||
}
|
||||
|
||||
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
|
||||
return &Queries{
|
||||
db: tx,
|
||||
}
|
||||
}
|
||||
58
internal/db/migrations/001_init.sql
Normal file
58
internal/db/migrations/001_init.sql
Normal file
@ -0,0 +1,58 @@
|
||||
-- Sessions table: replaces noted CLI dependency for session persistence.
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
title TEXT NOT NULL DEFAULT '',
|
||||
model TEXT NOT NULL DEFAULT '',
|
||||
mode TEXT NOT NULL DEFAULT 'BUILD',
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
|
||||
);
|
||||
|
||||
-- Session messages: individual chat entries within a session.
|
||||
CREATE TABLE IF NOT EXISTS session_messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
role TEXT NOT NULL, -- 'user', 'assistant', 'tool', 'system', 'error'
|
||||
content TEXT NOT NULL DEFAULT '',
|
||||
tool_name TEXT NOT NULL DEFAULT '',
|
||||
tool_args TEXT NOT NULL DEFAULT '',
|
||||
is_error INTEGER NOT NULL DEFAULT 0,
|
||||
thinking TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_session_messages_session_id ON session_messages(session_id);
|
||||
|
||||
-- Tool permissions: per-tool allow/deny/always-allow.
|
||||
CREATE TABLE IF NOT EXISTS tool_permissions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
tool_name TEXT NOT NULL UNIQUE,
|
||||
policy TEXT NOT NULL DEFAULT 'ask', -- 'allow', 'deny', 'ask'
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
|
||||
);
|
||||
|
||||
-- Token usage stats: per-turn tracking.
|
||||
CREATE TABLE IF NOT EXISTS token_stats (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
turn INTEGER NOT NULL DEFAULT 0,
|
||||
eval_count INTEGER NOT NULL DEFAULT 0,
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
model TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_token_stats_session_id ON token_stats(session_id);
|
||||
|
||||
-- File changes: files modified by agent during a session.
|
||||
CREATE TABLE IF NOT EXISTS file_changes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
file_path TEXT NOT NULL,
|
||||
tool_name TEXT NOT NULL DEFAULT '',
|
||||
added INTEGER NOT NULL DEFAULT 0,
|
||||
removed INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_file_changes_session_id ON file_changes(session_id);
|
||||
53
internal/db/models.go
Normal file
53
internal/db/models.go
Normal file
@ -0,0 +1,53 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
|
||||
package db
|
||||
|
||||
type FileChange struct {
|
||||
ID int64 `json:"id"`
|
||||
SessionID int64 `json:"session_id"`
|
||||
FilePath string `json:"file_path"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Added int64 `json:"added"`
|
||||
Removed int64 `json:"removed"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ID int64 `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Model string `json:"model"`
|
||||
Mode string `json:"mode"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type SessionMessage struct {
|
||||
ID int64 `json:"id"`
|
||||
SessionID int64 `json:"session_id"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolName string `json:"tool_name"`
|
||||
ToolArgs string `json:"tool_args"`
|
||||
IsError int64 `json:"is_error"`
|
||||
Thinking string `json:"thinking"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
type TokenStat struct {
|
||||
ID int64 `json:"id"`
|
||||
SessionID int64 `json:"session_id"`
|
||||
Turn int64 `json:"turn"`
|
||||
EvalCount int64 `json:"eval_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
Model string `json:"model"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
type ToolPermission struct {
|
||||
ID int64 `json:"id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Policy string `json:"policy"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
100
internal/db/permissions.sql.go
Normal file
100
internal/db/permissions.sql.go
Normal file
@ -0,0 +1,100 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: permissions.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const deleteToolPermission = `-- name: DeleteToolPermission :exec
|
||||
DELETE FROM tool_permissions WHERE tool_name = ?
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteToolPermission(ctx context.Context, toolName string) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteToolPermission, toolName)
|
||||
return err
|
||||
}
|
||||
|
||||
const getToolPermission = `-- name: GetToolPermission :one
|
||||
SELECT id, tool_name, policy, updated_at FROM tool_permissions WHERE tool_name = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetToolPermission(ctx context.Context, toolName string) (ToolPermission, error) {
|
||||
row := q.db.QueryRowContext(ctx, getToolPermission, toolName)
|
||||
var i ToolPermission
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ToolName,
|
||||
&i.Policy,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listToolPermissions = `-- name: ListToolPermissions :many
|
||||
SELECT id, tool_name, policy, updated_at FROM tool_permissions ORDER BY tool_name ASC
|
||||
`
|
||||
|
||||
func (q *Queries) ListToolPermissions(ctx context.Context) ([]ToolPermission, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listToolPermissions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []ToolPermission{}
|
||||
for rows.Next() {
|
||||
var i ToolPermission
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ToolName,
|
||||
&i.Policy,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const resetToolPermissions = `-- name: ResetToolPermissions :exec
|
||||
DELETE FROM tool_permissions
|
||||
`
|
||||
|
||||
func (q *Queries) ResetToolPermissions(ctx context.Context) error {
|
||||
_, err := q.db.ExecContext(ctx, resetToolPermissions)
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertToolPermission = `-- name: UpsertToolPermission :one
|
||||
INSERT INTO tool_permissions (tool_name, policy)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(tool_name) DO UPDATE SET policy = excluded.policy, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
RETURNING id, tool_name, policy, updated_at
|
||||
`
|
||||
|
||||
type UpsertToolPermissionParams struct {
|
||||
ToolName string `json:"tool_name"`
|
||||
Policy string `json:"policy"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpsertToolPermission(ctx context.Context, arg UpsertToolPermissionParams) (ToolPermission, error) {
|
||||
row := q.db.QueryRowContext(ctx, upsertToolPermission, arg.ToolName, arg.Policy)
|
||||
var i ToolPermission
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ToolName,
|
||||
&i.Policy,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
17
internal/db/queries/permissions.sql
Normal file
17
internal/db/queries/permissions.sql
Normal file
@ -0,0 +1,17 @@
|
||||
-- name: GetToolPermission :one
|
||||
SELECT * FROM tool_permissions WHERE tool_name = ?;
|
||||
|
||||
-- name: UpsertToolPermission :one
|
||||
INSERT INTO tool_permissions (tool_name, policy)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(tool_name) DO UPDATE SET policy = excluded.policy, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
RETURNING *;
|
||||
|
||||
-- name: ListToolPermissions :many
|
||||
SELECT * FROM tool_permissions ORDER BY tool_name ASC;
|
||||
|
||||
-- name: DeleteToolPermission :exec
|
||||
DELETE FROM tool_permissions WHERE tool_name = ?;
|
||||
|
||||
-- name: ResetToolPermissions :exec
|
||||
DELETE FROM tool_permissions;
|
||||
27
internal/db/queries/sessions.sql
Normal file
27
internal/db/queries/sessions.sql
Normal file
@ -0,0 +1,27 @@
|
||||
-- name: CreateSession :one
|
||||
INSERT INTO sessions (title, model, mode) VALUES (?, ?, ?) RETURNING *;
|
||||
|
||||
-- name: GetSession :one
|
||||
SELECT * FROM sessions WHERE id = ?;
|
||||
|
||||
-- name: ListSessions :many
|
||||
SELECT * FROM sessions ORDER BY updated_at DESC LIMIT ?;
|
||||
|
||||
-- name: UpdateSessionTitle :exec
|
||||
UPDATE sessions SET title = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ?;
|
||||
|
||||
-- name: UpdateSessionTimestamp :exec
|
||||
UPDATE sessions SET updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ?;
|
||||
|
||||
-- name: DeleteSession :exec
|
||||
DELETE FROM sessions WHERE id = ?;
|
||||
|
||||
-- name: CreateSessionMessage :one
|
||||
INSERT INTO session_messages (session_id, role, content, tool_name, tool_args, is_error, thinking)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING *;
|
||||
|
||||
-- name: GetSessionMessages :many
|
||||
SELECT * FROM session_messages WHERE session_id = ? ORDER BY id ASC;
|
||||
|
||||
-- name: CountSessions :one
|
||||
SELECT COUNT(*) FROM sessions;
|
||||
31
internal/db/queries/stats.sql
Normal file
31
internal/db/queries/stats.sql
Normal file
@ -0,0 +1,31 @@
|
||||
-- name: RecordTokenUsage :one
|
||||
INSERT INTO token_stats (session_id, turn, eval_count, prompt_tokens, model)
|
||||
VALUES (?, ?, ?, ?, ?) RETURNING *;
|
||||
|
||||
-- name: GetSessionTokenStats :many
|
||||
SELECT * FROM token_stats WHERE session_id = ? ORDER BY turn ASC;
|
||||
|
||||
-- name: GetSessionTotalTokens :one
|
||||
SELECT
|
||||
CAST(COALESCE(SUM(eval_count), 0) AS INTEGER) AS total_eval,
|
||||
CAST(COALESCE(SUM(prompt_tokens), 0) AS INTEGER) AS total_prompt,
|
||||
CAST(COUNT(*) AS INTEGER) AS turn_count
|
||||
FROM token_stats WHERE session_id = ?;
|
||||
|
||||
-- name: RecordFileChange :one
|
||||
INSERT INTO file_changes (session_id, file_path, tool_name, added, removed)
|
||||
VALUES (?, ?, ?, ?, ?) RETURNING *;
|
||||
|
||||
-- name: GetSessionFileChanges :many
|
||||
SELECT * FROM file_changes WHERE session_id = ? ORDER BY created_at ASC;
|
||||
|
||||
-- name: GetSessionFileChangeSummary :many
|
||||
SELECT
|
||||
file_path,
|
||||
CAST(COALESCE(SUM(added), 0) AS INTEGER) AS total_added,
|
||||
CAST(COALESCE(SUM(removed), 0) AS INTEGER) AS total_removed,
|
||||
CAST(COUNT(*) AS INTEGER) AS change_count
|
||||
FROM file_changes
|
||||
WHERE session_id = ?
|
||||
GROUP BY file_path
|
||||
ORDER BY file_path ASC;
|
||||
206
internal/db/sessions.sql.go
Normal file
206
internal/db/sessions.sql.go
Normal file
@ -0,0 +1,206 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: sessions.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const countSessions = `-- name: CountSessions :one
|
||||
SELECT COUNT(*) FROM sessions
|
||||
`
|
||||
|
||||
func (q *Queries) CountSessions(ctx context.Context) (int64, error) {
|
||||
row := q.db.QueryRowContext(ctx, countSessions)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
const createSession = `-- name: CreateSession :one
|
||||
INSERT INTO sessions (title, model, mode) VALUES (?, ?, ?) RETURNING id, title, model, mode, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateSessionParams struct {
|
||||
Title string `json:"title"`
|
||||
Model string `json:"model"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
|
||||
row := q.db.QueryRowContext(ctx, createSession, arg.Title, arg.Model, arg.Mode)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Title,
|
||||
&i.Model,
|
||||
&i.Mode,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const createSessionMessage = `-- name: CreateSessionMessage :one
|
||||
INSERT INTO session_messages (session_id, role, content, tool_name, tool_args, is_error, thinking)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING id, session_id, role, content, tool_name, tool_args, is_error, thinking, created_at
|
||||
`
|
||||
|
||||
type CreateSessionMessageParams struct {
|
||||
SessionID int64 `json:"session_id"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolName string `json:"tool_name"`
|
||||
ToolArgs string `json:"tool_args"`
|
||||
IsError int64 `json:"is_error"`
|
||||
Thinking string `json:"thinking"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateSessionMessage(ctx context.Context, arg CreateSessionMessageParams) (SessionMessage, error) {
|
||||
row := q.db.QueryRowContext(ctx, createSessionMessage,
|
||||
arg.SessionID,
|
||||
arg.Role,
|
||||
arg.Content,
|
||||
arg.ToolName,
|
||||
arg.ToolArgs,
|
||||
arg.IsError,
|
||||
arg.Thinking,
|
||||
)
|
||||
var i SessionMessage
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Role,
|
||||
&i.Content,
|
||||
&i.ToolName,
|
||||
&i.ToolArgs,
|
||||
&i.IsError,
|
||||
&i.Thinking,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteSession = `-- name: DeleteSession :exec
|
||||
DELETE FROM sessions WHERE id = ?
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteSession(ctx context.Context, id int64) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteSession, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const getSession = `-- name: GetSession :one
|
||||
SELECT id, title, model, mode, created_at, updated_at FROM sessions WHERE id = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetSession(ctx context.Context, id int64) (Session, error) {
|
||||
row := q.db.QueryRowContext(ctx, getSession, id)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Title,
|
||||
&i.Model,
|
||||
&i.Mode,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getSessionMessages = `-- name: GetSessionMessages :many
|
||||
SELECT id, session_id, role, content, tool_name, tool_args, is_error, thinking, created_at FROM session_messages WHERE session_id = ? ORDER BY id ASC
|
||||
`
|
||||
|
||||
func (q *Queries) GetSessionMessages(ctx context.Context, sessionID int64) ([]SessionMessage, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getSessionMessages, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []SessionMessage{}
|
||||
for rows.Next() {
|
||||
var i SessionMessage
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Role,
|
||||
&i.Content,
|
||||
&i.ToolName,
|
||||
&i.ToolArgs,
|
||||
&i.IsError,
|
||||
&i.Thinking,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listSessions = `-- name: ListSessions :many
|
||||
SELECT id, title, model, mode, created_at, updated_at FROM sessions ORDER BY updated_at DESC LIMIT ?
|
||||
`
|
||||
|
||||
func (q *Queries) ListSessions(ctx context.Context, limit int64) ([]Session, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listSessions, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []Session{}
|
||||
for rows.Next() {
|
||||
var i Session
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Title,
|
||||
&i.Model,
|
||||
&i.Mode,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateSessionTimestamp = `-- name: UpdateSessionTimestamp :exec
|
||||
UPDATE sessions SET updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ?
|
||||
`
|
||||
|
||||
func (q *Queries) UpdateSessionTimestamp(ctx context.Context, id int64) error {
|
||||
_, err := q.db.ExecContext(ctx, updateSessionTimestamp, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateSessionTitle = `-- name: UpdateSessionTitle :exec
|
||||
UPDATE sessions SET title = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ?
|
||||
`
|
||||
|
||||
type UpdateSessionTitleParams struct {
|
||||
Title string `json:"title"`
|
||||
ID int64 `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateSessionTitle(ctx context.Context, arg UpdateSessionTitleParams) error {
|
||||
_, err := q.db.ExecContext(ctx, updateSessionTitle, arg.Title, arg.ID)
|
||||
return err
|
||||
}
|
||||
11
internal/db/sqlc.yaml
Normal file
11
internal/db/sqlc.yaml
Normal file
@ -0,0 +1,11 @@
|
||||
version: "2"
|
||||
sql:
|
||||
- engine: "sqlite"
|
||||
queries: "queries"
|
||||
schema: "migrations"
|
||||
gen:
|
||||
go:
|
||||
package: "db"
|
||||
out: "."
|
||||
emit_json_tags: true
|
||||
emit_empty_slices: true
|
||||
216
internal/db/stats.sql.go
Normal file
216
internal/db/stats.sql.go
Normal file
@ -0,0 +1,216 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: stats.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const getSessionFileChangeSummary = `-- name: GetSessionFileChangeSummary :many
|
||||
SELECT
|
||||
file_path,
|
||||
CAST(COALESCE(SUM(added), 0) AS INTEGER) AS total_added,
|
||||
CAST(COALESCE(SUM(removed), 0) AS INTEGER) AS total_removed,
|
||||
CAST(COUNT(*) AS INTEGER) AS change_count
|
||||
FROM file_changes
|
||||
WHERE session_id = ?
|
||||
GROUP BY file_path
|
||||
ORDER BY file_path ASC
|
||||
`
|
||||
|
||||
type GetSessionFileChangeSummaryRow struct {
|
||||
FilePath string `json:"file_path"`
|
||||
TotalAdded int64 `json:"total_added"`
|
||||
TotalRemoved int64 `json:"total_removed"`
|
||||
ChangeCount int64 `json:"change_count"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetSessionFileChangeSummary(ctx context.Context, sessionID int64) ([]GetSessionFileChangeSummaryRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getSessionFileChangeSummary, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []GetSessionFileChangeSummaryRow{}
|
||||
for rows.Next() {
|
||||
var i GetSessionFileChangeSummaryRow
|
||||
if err := rows.Scan(
|
||||
&i.FilePath,
|
||||
&i.TotalAdded,
|
||||
&i.TotalRemoved,
|
||||
&i.ChangeCount,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getSessionFileChanges = `-- name: GetSessionFileChanges :many
|
||||
SELECT id, session_id, file_path, tool_name, added, removed, created_at FROM file_changes WHERE session_id = ? ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
func (q *Queries) GetSessionFileChanges(ctx context.Context, sessionID int64) ([]FileChange, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getSessionFileChanges, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []FileChange{}
|
||||
for rows.Next() {
|
||||
var i FileChange
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.FilePath,
|
||||
&i.ToolName,
|
||||
&i.Added,
|
||||
&i.Removed,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getSessionTokenStats = `-- name: GetSessionTokenStats :many
|
||||
SELECT id, session_id, turn, eval_count, prompt_tokens, model, created_at FROM token_stats WHERE session_id = ? ORDER BY turn ASC
|
||||
`
|
||||
|
||||
func (q *Queries) GetSessionTokenStats(ctx context.Context, sessionID int64) ([]TokenStat, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getSessionTokenStats, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := []TokenStat{}
|
||||
for rows.Next() {
|
||||
var i TokenStat
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Turn,
|
||||
&i.EvalCount,
|
||||
&i.PromptTokens,
|
||||
&i.Model,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getSessionTotalTokens = `-- name: GetSessionTotalTokens :one
|
||||
SELECT
|
||||
CAST(COALESCE(SUM(eval_count), 0) AS INTEGER) AS total_eval,
|
||||
CAST(COALESCE(SUM(prompt_tokens), 0) AS INTEGER) AS total_prompt,
|
||||
CAST(COUNT(*) AS INTEGER) AS turn_count
|
||||
FROM token_stats WHERE session_id = ?
|
||||
`
|
||||
|
||||
type GetSessionTotalTokensRow struct {
|
||||
TotalEval int64 `json:"total_eval"`
|
||||
TotalPrompt int64 `json:"total_prompt"`
|
||||
TurnCount int64 `json:"turn_count"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetSessionTotalTokens(ctx context.Context, sessionID int64) (GetSessionTotalTokensRow, error) {
|
||||
row := q.db.QueryRowContext(ctx, getSessionTotalTokens, sessionID)
|
||||
var i GetSessionTotalTokensRow
|
||||
err := row.Scan(&i.TotalEval, &i.TotalPrompt, &i.TurnCount)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const recordFileChange = `-- name: RecordFileChange :one
|
||||
INSERT INTO file_changes (session_id, file_path, tool_name, added, removed)
|
||||
VALUES (?, ?, ?, ?, ?) RETURNING id, session_id, file_path, tool_name, added, removed, created_at
|
||||
`
|
||||
|
||||
type RecordFileChangeParams struct {
|
||||
SessionID int64 `json:"session_id"`
|
||||
FilePath string `json:"file_path"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Added int64 `json:"added"`
|
||||
Removed int64 `json:"removed"`
|
||||
}
|
||||
|
||||
func (q *Queries) RecordFileChange(ctx context.Context, arg RecordFileChangeParams) (FileChange, error) {
|
||||
row := q.db.QueryRowContext(ctx, recordFileChange,
|
||||
arg.SessionID,
|
||||
arg.FilePath,
|
||||
arg.ToolName,
|
||||
arg.Added,
|
||||
arg.Removed,
|
||||
)
|
||||
var i FileChange
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.FilePath,
|
||||
&i.ToolName,
|
||||
&i.Added,
|
||||
&i.Removed,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const recordTokenUsage = `-- name: RecordTokenUsage :one
|
||||
INSERT INTO token_stats (session_id, turn, eval_count, prompt_tokens, model)
|
||||
VALUES (?, ?, ?, ?, ?) RETURNING id, session_id, turn, eval_count, prompt_tokens, model, created_at
|
||||
`
|
||||
|
||||
type RecordTokenUsageParams struct {
|
||||
SessionID int64 `json:"session_id"`
|
||||
Turn int64 `json:"turn"`
|
||||
EvalCount int64 `json:"eval_count"`
|
||||
PromptTokens int64 `json:"prompt_tokens"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
func (q *Queries) RecordTokenUsage(ctx context.Context, arg RecordTokenUsageParams) (TokenStat, error) {
|
||||
row := q.db.QueryRowContext(ctx, recordTokenUsage,
|
||||
arg.SessionID,
|
||||
arg.Turn,
|
||||
arg.EvalCount,
|
||||
arg.PromptTokens,
|
||||
arg.Model,
|
||||
)
|
||||
var i TokenStat
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.SessionID,
|
||||
&i.Turn,
|
||||
&i.EvalCount,
|
||||
&i.PromptTokens,
|
||||
&i.Model,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
71
internal/db/store.go
Normal file
71
internal/db/store.go
Normal file
@ -0,0 +1,71 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
var migrations embed.FS
|
||||
|
||||
type Store struct {
|
||||
*Queries
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func Open() (*Store, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("home dir: %w", err)
|
||||
}
|
||||
dir := filepath.Join(home, ".config", "ai-agent")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create config dir: %w", err)
|
||||
}
|
||||
return OpenPath(filepath.Join(dir, "ai-agent.db"))
|
||||
}
|
||||
|
||||
func OpenPath(path string) (*Store, error) {
|
||||
conn, err := sql.Open("sqlite", path+"?_journal_mode=WAL&_busy_timeout=5000&_foreign_keys=ON")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
if err := runMigrations(conn); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("migrations: %w", err)
|
||||
}
|
||||
return &Store{Queries: New(conn), db: conn}, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
func (s *Store) DB() *sql.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
func runMigrations(conn *sql.DB) error {
|
||||
entries, err := migrations.ReadDir("migrations")
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migrations dir: %w", err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
data, err := migrations.ReadFile("migrations/" + entry.Name())
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migration %s: %w", entry.Name(), err)
|
||||
}
|
||||
if _, err := conn.Exec(string(data)); err != nil {
|
||||
return fmt.Errorf("exec migration %s: %w", entry.Name(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
271
internal/db/store_test.go
Normal file
271
internal/db/store_test.go
Normal file
@ -0,0 +1,271 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testStore(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
s, err := OpenPath(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open store: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func TestOpenAndMigrate(t *testing.T) {
|
||||
s := testStore(t)
|
||||
// Verify the store is functional by counting sessions.
|
||||
ctx := context.Background()
|
||||
count, err := s.CountSessions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("count sessions: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatalf("expected 0 sessions, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCRUD(t *testing.T) {
|
||||
s := testStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create.
|
||||
sess, err := s.CreateSession(ctx, CreateSessionParams{
|
||||
Title: "Test Session",
|
||||
Model: "qwen3.5:4b",
|
||||
Mode: "BUILD",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create session: %v", err)
|
||||
}
|
||||
if sess.Title != "Test Session" {
|
||||
t.Fatalf("expected title 'Test Session', got %q", sess.Title)
|
||||
}
|
||||
|
||||
// Read.
|
||||
got, err := s.GetSession(ctx, sess.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get session: %v", err)
|
||||
}
|
||||
if got.Model != "qwen3.5:4b" {
|
||||
t.Fatalf("expected model 'qwen3.5:4b', got %q", got.Model)
|
||||
}
|
||||
|
||||
// Create message.
|
||||
msg, err := s.CreateSessionMessage(ctx, CreateSessionMessageParams{
|
||||
SessionID: sess.ID,
|
||||
Role: "user",
|
||||
Content: "Hello world",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create message: %v", err)
|
||||
}
|
||||
if msg.Content != "Hello world" {
|
||||
t.Fatalf("expected content 'Hello world', got %q", msg.Content)
|
||||
}
|
||||
|
||||
// List messages.
|
||||
msgs, err := s.GetSessionMessages(ctx, sess.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get messages: %v", err)
|
||||
}
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(msgs))
|
||||
}
|
||||
|
||||
// Delete.
|
||||
if err := s.DeleteSession(ctx, sess.ID); err != nil {
|
||||
t.Fatalf("delete session: %v", err)
|
||||
}
|
||||
count, err := s.CountSessions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("count: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatalf("expected 0 after delete, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolPermissions(t *testing.T) {
|
||||
s := testStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Upsert.
|
||||
perm, err := s.UpsertToolPermission(ctx, UpsertToolPermissionParams{
|
||||
ToolName: "bash",
|
||||
Policy: "allow",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("upsert permission: %v", err)
|
||||
}
|
||||
if perm.Policy != "allow" {
|
||||
t.Fatalf("expected policy 'allow', got %q", perm.Policy)
|
||||
}
|
||||
|
||||
// Update via upsert.
|
||||
perm2, err := s.UpsertToolPermission(ctx, UpsertToolPermissionParams{
|
||||
ToolName: "bash",
|
||||
Policy: "deny",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("upsert update: %v", err)
|
||||
}
|
||||
if perm2.Policy != "deny" {
|
||||
t.Fatalf("expected policy 'deny', got %q", perm2.Policy)
|
||||
}
|
||||
|
||||
// List.
|
||||
perms, err := s.ListToolPermissions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list permissions: %v", err)
|
||||
}
|
||||
if len(perms) != 1 {
|
||||
t.Fatalf("expected 1 permission, got %d", len(perms))
|
||||
}
|
||||
|
||||
// Reset.
|
||||
if err := s.ResetToolPermissions(ctx); err != nil {
|
||||
t.Fatalf("reset: %v", err)
|
||||
}
|
||||
perms, err = s.ListToolPermissions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list after reset: %v", err)
|
||||
}
|
||||
if len(perms) != 0 {
|
||||
t.Fatalf("expected 0 after reset, got %d", len(perms))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenStats(t *testing.T) {
|
||||
s := testStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sess, err := s.CreateSession(ctx, CreateSessionParams{
|
||||
Title: "Stats Test",
|
||||
Model: "qwen3.5:4b",
|
||||
Mode: "ASK",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create session: %v", err)
|
||||
}
|
||||
|
||||
// Record usage.
|
||||
_, err = s.RecordTokenUsage(ctx, RecordTokenUsageParams{
|
||||
SessionID: sess.ID,
|
||||
Turn: 1,
|
||||
EvalCount: 100,
|
||||
PromptTokens: 500,
|
||||
Model: "qwen3.5:4b",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("record usage: %v", err)
|
||||
}
|
||||
|
||||
_, err = s.RecordTokenUsage(ctx, RecordTokenUsageParams{
|
||||
SessionID: sess.ID,
|
||||
Turn: 2,
|
||||
EvalCount: 200,
|
||||
PromptTokens: 600,
|
||||
Model: "qwen3.5:4b",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("record usage 2: %v", err)
|
||||
}
|
||||
|
||||
// Get totals.
|
||||
totals, err := s.GetSessionTotalTokens(ctx, sess.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get totals: %v", err)
|
||||
}
|
||||
if totals.TotalEval != 300 {
|
||||
t.Fatalf("expected total_eval 300, got %v", totals.TotalEval)
|
||||
}
|
||||
if totals.TotalPrompt != 1100 {
|
||||
t.Fatalf("expected total_prompt 1100, got %v", totals.TotalPrompt)
|
||||
}
|
||||
if totals.TurnCount != 2 {
|
||||
t.Fatalf("expected turn_count 2, got %v", totals.TurnCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileChanges(t *testing.T) {
|
||||
s := testStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sess, err := s.CreateSession(ctx, CreateSessionParams{
|
||||
Title: "Changes Test",
|
||||
Model: "qwen3.5:4b",
|
||||
Mode: "BUILD",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create session: %v", err)
|
||||
}
|
||||
|
||||
_, err = s.RecordFileChange(ctx, RecordFileChangeParams{
|
||||
SessionID: sess.ID,
|
||||
FilePath: "main.go",
|
||||
ToolName: "write_file",
|
||||
Added: 10,
|
||||
Removed: 3,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("record change: %v", err)
|
||||
}
|
||||
|
||||
_, err = s.RecordFileChange(ctx, RecordFileChangeParams{
|
||||
SessionID: sess.ID,
|
||||
FilePath: "main.go",
|
||||
ToolName: "write_file",
|
||||
Added: 5,
|
||||
Removed: 2,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("record change 2: %v", err)
|
||||
}
|
||||
|
||||
summary, err := s.GetSessionFileChangeSummary(ctx, sess.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get summary: %v", err)
|
||||
}
|
||||
if len(summary) != 1 {
|
||||
t.Fatalf("expected 1 file, got %d", len(summary))
|
||||
}
|
||||
if summary[0].TotalAdded != 15 {
|
||||
t.Fatalf("expected 15 added, got %d", summary[0].TotalAdded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoubleOpen(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.db")
|
||||
|
||||
s1, err := OpenPath(path)
|
||||
if err != nil {
|
||||
t.Fatalf("first open: %v", err)
|
||||
}
|
||||
defer s1.Close()
|
||||
|
||||
// Running migrations again should be idempotent (IF NOT EXISTS).
|
||||
s2, err := OpenPath(path)
|
||||
if err != nil {
|
||||
t.Fatalf("second open: %v", err)
|
||||
}
|
||||
defer s2.Close()
|
||||
}
|
||||
|
||||
func TestOpenDefault(t *testing.T) {
|
||||
if os.Getenv("CI") != "" {
|
||||
t.Skip("skip in CI to avoid side effects")
|
||||
}
|
||||
s, err := Open()
|
||||
if err != nil {
|
||||
t.Fatalf("open default: %v", err)
|
||||
}
|
||||
s.Close()
|
||||
}
|
||||
126
internal/ice/assembler.go
Normal file
126
internal/ice/assembler.go
Normal file
@ -0,0 +1,126 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"ai-agent/internal/memory"
|
||||
)
|
||||
|
||||
type Assembler struct {
|
||||
embedder *Embedder
|
||||
convStore *Store
|
||||
memStore *memory.Store
|
||||
budgetCfg BudgetConfig
|
||||
sessionID string
|
||||
}
|
||||
|
||||
func (a *Assembler) Assemble(ctx context.Context, query string) (string, error) {
|
||||
budget := a.budgetCfg.Calculate(0)
|
||||
type convResult struct {
|
||||
chunks []ContextChunk
|
||||
err error
|
||||
}
|
||||
type memResult struct {
|
||||
chunks []ContextChunk
|
||||
}
|
||||
convCh := make(chan convResult, 1)
|
||||
memCh := make(chan memResult, 1)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
chunks, err := a.retrieveConversations(ctx, query, budget.Conversation)
|
||||
convCh <- convResult{chunks: chunks, err: err}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
chunks := a.retrieveMemories(query, budget.Memory)
|
||||
memCh <- memResult{chunks: chunks}
|
||||
}()
|
||||
wg.Wait()
|
||||
close(convCh)
|
||||
close(memCh)
|
||||
cr := <-convCh
|
||||
mr := <-memCh
|
||||
if cr.err != nil {
|
||||
return "", fmt.Errorf("conversation retrieval: %w", cr.err)
|
||||
}
|
||||
return formatContext(cr.chunks, mr.chunks), nil
|
||||
}
|
||||
|
||||
func (a *Assembler) retrieveConversations(ctx context.Context, query string, tokenBudget int) ([]ContextChunk, error) {
|
||||
if tokenBudget <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
queryEmb, err := a.embedder.Embed(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results := a.convStore.Search(queryEmb, a.sessionID, 20)
|
||||
var chunks []ContextChunk
|
||||
usedTokens := 0
|
||||
for _, r := range results {
|
||||
tokens := estimateTokens(r.Entry.Content)
|
||||
if usedTokens+tokens > tokenBudget {
|
||||
continue
|
||||
}
|
||||
chunks = append(chunks, ContextChunk{
|
||||
Source: SourceConversation,
|
||||
Content: r.Entry.Content,
|
||||
Score: r.Score,
|
||||
Tokens: tokens,
|
||||
})
|
||||
usedTokens += tokens
|
||||
}
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
func (a *Assembler) retrieveMemories(query string, tokenBudget int) []ContextChunk {
|
||||
if a.memStore == nil || tokenBudget <= 0 {
|
||||
return nil
|
||||
}
|
||||
memories := a.memStore.Recall(query, 10)
|
||||
var chunks []ContextChunk
|
||||
usedTokens := 0
|
||||
for _, m := range memories {
|
||||
tokens := estimateTokens(m.Content)
|
||||
if usedTokens+tokens > tokenBudget {
|
||||
continue
|
||||
}
|
||||
content := m.Content
|
||||
if len(m.Tags) > 0 {
|
||||
content += " [" + strings.Join(m.Tags, ", ") + "]"
|
||||
}
|
||||
chunks = append(chunks, ContextChunk{
|
||||
Source: SourceMemory,
|
||||
Content: content,
|
||||
Tokens: tokens,
|
||||
})
|
||||
usedTokens += tokens
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
func formatContext(convChunks, memChunks []ContextChunk) string {
|
||||
var sb strings.Builder
|
||||
if len(convChunks) > 0 {
|
||||
sb.WriteString("\n## Relevant Past Conversations\n\n")
|
||||
for _, c := range convChunks {
|
||||
sb.WriteString("- ")
|
||||
sb.WriteString(c.Content)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
if len(memChunks) > 0 {
|
||||
sb.WriteString("\n## Remembered Facts\n\n")
|
||||
for _, c := range memChunks {
|
||||
sb.WriteString("- ")
|
||||
sb.WriteString(c.Content)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
88
internal/ice/assembler_test.go
Normal file
88
internal/ice/assembler_test.go
Normal file
@ -0,0 +1,88 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFormatContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
convChunks []ContextChunk
|
||||
memChunks []ContextChunk
|
||||
wantConv bool // should contain "Relevant Past Conversations"
|
||||
wantMem bool // should contain "Remembered Facts"
|
||||
wantEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "both conversation and memory chunks",
|
||||
convChunks: []ContextChunk{
|
||||
{Source: SourceConversation, Content: "past chat about Go"},
|
||||
},
|
||||
memChunks: []ContextChunk{
|
||||
{Source: SourceMemory, Content: "user prefers dark mode"},
|
||||
},
|
||||
wantConv: true,
|
||||
wantMem: true,
|
||||
},
|
||||
{
|
||||
name: "conversations only",
|
||||
convChunks: []ContextChunk{
|
||||
{Source: SourceConversation, Content: "previous discussion"},
|
||||
},
|
||||
memChunks: nil,
|
||||
wantConv: true,
|
||||
wantMem: false,
|
||||
},
|
||||
{
|
||||
name: "memories only",
|
||||
convChunks: nil,
|
||||
memChunks: []ContextChunk{
|
||||
{Source: SourceMemory, Content: "user name is Alice"},
|
||||
},
|
||||
wantConv: false,
|
||||
wantMem: true,
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
convChunks: nil,
|
||||
memChunks: nil,
|
||||
wantEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := formatContext(tt.convChunks, tt.memChunks)
|
||||
|
||||
if tt.wantEmpty {
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string, got %q", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
hasConv := strings.Contains(got, "Relevant Past Conversations")
|
||||
hasMem := strings.Contains(got, "Remembered Facts")
|
||||
|
||||
if hasConv != tt.wantConv {
|
||||
t.Errorf("has conversations section = %v, want %v", hasConv, tt.wantConv)
|
||||
}
|
||||
if hasMem != tt.wantMem {
|
||||
t.Errorf("has memories section = %v, want %v", hasMem, tt.wantMem)
|
||||
}
|
||||
|
||||
// Verify content is present in output.
|
||||
for _, c := range tt.convChunks {
|
||||
if !strings.Contains(got, c.Content) {
|
||||
t.Errorf("output missing conversation content %q", c.Content)
|
||||
}
|
||||
}
|
||||
for _, c := range tt.memChunks {
|
||||
if !strings.Contains(got, c.Content) {
|
||||
t.Errorf("output missing memory content %q", c.Content)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
76
internal/ice/automemory.go
Normal file
76
internal/ice/automemory.go
Normal file
@ -0,0 +1,76 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
"ai-agent/internal/memory"
|
||||
)
|
||||
|
||||
var autoMemorySystemPrompt = "Extract any important facts, user preferences, decisions, or action items from this exchange.\n" +
|
||||
"Output one item per line in the format: TYPE: content\n" +
|
||||
"Where TYPE is one of: FACT, DECISION, PREFERENCE, TODO\n" +
|
||||
"If there is nothing worth remembering, output exactly: NONE"
|
||||
|
||||
var autoMemoryUserTemplate = "User: %s\nAssistant: %s"
|
||||
|
||||
type AutoMemory struct {
|
||||
client llm.Client
|
||||
memStore *memory.Store
|
||||
}
|
||||
|
||||
func (am *AutoMemory) Detect(ctx context.Context, userMsg, assistantMsg string) error {
|
||||
if am.memStore == nil {
|
||||
return nil
|
||||
}
|
||||
if len(userMsg) < 20 && len(assistantMsg) < 50 {
|
||||
return nil
|
||||
}
|
||||
prompt := fmt.Sprintf(autoMemoryUserTemplate, userMsg, assistantMsg)
|
||||
var response strings.Builder
|
||||
err := am.client.ChatStream(ctx, llm.ChatOptions{
|
||||
System: autoMemorySystemPrompt,
|
||||
Messages: []llm.Message{
|
||||
{Role: "user", Content: prompt},
|
||||
},
|
||||
}, func(chunk llm.StreamChunk) error {
|
||||
response.WriteString(chunk.Text)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("auto-memory LLM call: %w", err)
|
||||
}
|
||||
return am.parseAndSave(response.String())
|
||||
}
|
||||
|
||||
func (am *AutoMemory) parseAndSave(response string) error {
|
||||
lines := strings.Split(strings.TrimSpace(response), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.EqualFold(line, "NONE") {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, ": ", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
typeName := strings.TrimSpace(parts[0])
|
||||
content := strings.TrimSpace(parts[1])
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
tag := strings.ToLower(typeName)
|
||||
switch tag {
|
||||
case "fact", "decision", "preference", "todo":
|
||||
// Valid type.
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if _, err := am.memStore.Save(content, []string{tag, "auto"}); err != nil {
|
||||
return fmt.Errorf("save auto-memory: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
104
internal/ice/automemory_test.go
Normal file
104
internal/ice/automemory_test.go
Normal file
@ -0,0 +1,104 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"ai-agent/internal/memory"
|
||||
)
|
||||
|
||||
func TestParseAndSave(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantCount int
|
||||
wantTags [][]string
|
||||
}{
|
||||
{
|
||||
name: "valid FACT and DECISION lines",
|
||||
input: "FACT: user likes Go\nDECISION: use postgres",
|
||||
wantCount: 2,
|
||||
wantTags: [][]string{{"fact", "auto"}, {"decision", "auto"}},
|
||||
},
|
||||
{
|
||||
name: "NONE saves nothing",
|
||||
input: "NONE",
|
||||
wantCount: 0,
|
||||
},
|
||||
{
|
||||
name: "empty lines are skipped",
|
||||
input: "\n\n\n",
|
||||
wantCount: 0,
|
||||
},
|
||||
{
|
||||
name: "invalid type is skipped",
|
||||
input: "UNKNOWN: something",
|
||||
wantCount: 0,
|
||||
},
|
||||
{
|
||||
name: "missing colon format is skipped",
|
||||
input: "this has no colon",
|
||||
wantCount: 0,
|
||||
},
|
||||
{
|
||||
name: "PREFERENCE type",
|
||||
input: "PREFERENCE: dark mode",
|
||||
wantCount: 1,
|
||||
wantTags: [][]string{{"preference", "auto"}},
|
||||
},
|
||||
{
|
||||
name: "TODO type",
|
||||
input: "TODO: fix the bug",
|
||||
wantCount: 1,
|
||||
wantTags: [][]string{{"todo", "auto"}},
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid",
|
||||
input: "FACT: real fact\nBAD: not valid\nTODO: real todo",
|
||||
wantCount: 2,
|
||||
wantTags: [][]string{{"fact", "auto"}, {"todo", "auto"}},
|
||||
},
|
||||
{
|
||||
name: "empty content after type is skipped",
|
||||
input: "FACT: ",
|
||||
wantCount: 0,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
memPath := filepath.Join(dir, "memories.json")
|
||||
ms := memory.NewStore(memPath)
|
||||
am := &AutoMemory{memStore: ms}
|
||||
err := am.parseAndSave(tt.input)
|
||||
if err != nil {
|
||||
t.Fatalf("parseAndSave returned error: %v", err)
|
||||
}
|
||||
if ms.Count() != tt.wantCount {
|
||||
t.Errorf("memory count = %d, want %d", ms.Count(), tt.wantCount)
|
||||
}
|
||||
if tt.wantTags != nil {
|
||||
recent := ms.Recent(tt.wantCount)
|
||||
for i, j := 0, len(recent)-1; i < j; i, j = i+1, j-1 {
|
||||
recent[i], recent[j] = recent[j], recent[i]
|
||||
}
|
||||
for i, wantTags := range tt.wantTags {
|
||||
if i >= len(recent) {
|
||||
t.Errorf("missing memory at index %d", i)
|
||||
continue
|
||||
}
|
||||
got := recent[i].Tags
|
||||
if len(got) != len(wantTags) {
|
||||
t.Errorf("memory[%d] tags = %v, want %v", i, got, wantTags)
|
||||
continue
|
||||
}
|
||||
for j := range wantTags {
|
||||
if got[j] != wantTags[j] {
|
||||
t.Errorf("memory[%d] tag[%d] = %q, want %q", i, j, got[j], wantTags[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
54
internal/ice/budget.go
Normal file
54
internal/ice/budget.go
Normal file
@ -0,0 +1,54 @@
|
||||
package ice
|
||||
|
||||
// BudgetConfig controls how the context window is divided among sources.
|
||||
type BudgetConfig struct {
|
||||
NumCtx int
|
||||
SystemReserve int // tokens reserved for system prompt
|
||||
RecentReserve int // tokens reserved for recent conversation
|
||||
ConversationPct float64 // fraction of remaining budget for past conversations
|
||||
MemoryPct float64 // fraction of remaining budget for memories
|
||||
CodePct float64 // fraction of remaining budget for code context
|
||||
}
|
||||
|
||||
// DefaultBudgetConfig returns sensible defaults for a given context window.
|
||||
func DefaultBudgetConfig(numCtx int) BudgetConfig {
|
||||
return BudgetConfig{
|
||||
NumCtx: numCtx,
|
||||
SystemReserve: 1500,
|
||||
RecentReserve: 2000,
|
||||
ConversationPct: 0.40,
|
||||
MemoryPct: 0.20,
|
||||
CodePct: 0.40,
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate allocates token budgets given how many tokens the current prompt uses.
|
||||
func (bc BudgetConfig) Calculate(promptTokens int) Budget {
|
||||
// Use 75% of numCtx as total available.
|
||||
available := int(float64(bc.NumCtx) * 0.75)
|
||||
available -= bc.SystemReserve
|
||||
available -= bc.RecentReserve
|
||||
available -= promptTokens
|
||||
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
return Budget{
|
||||
Total: available,
|
||||
System: bc.SystemReserve,
|
||||
Recent: bc.RecentReserve,
|
||||
Conversation: int(float64(available) * bc.ConversationPct),
|
||||
Memory: int(float64(available) * bc.MemoryPct),
|
||||
Code: int(float64(available) * bc.CodePct),
|
||||
}
|
||||
}
|
||||
|
||||
// estimateTokens returns a rough token count for a string (chars / 4).
|
||||
func estimateTokens(s string) int {
|
||||
n := len(s) / 4
|
||||
if n == 0 && len(s) > 0 {
|
||||
n = 1
|
||||
}
|
||||
return n
|
||||
}
|
||||
133
internal/ice/budget_test.go
Normal file
133
internal/ice/budget_test.go
Normal file
@ -0,0 +1,133 @@
|
||||
package ice
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBudgetConfig_Calculate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg BudgetConfig
|
||||
promptTokens int
|
||||
wantTotal int
|
||||
wantConv int
|
||||
wantMemory int
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
name: "normal allocation",
|
||||
cfg: BudgetConfig{
|
||||
NumCtx: 8192,
|
||||
SystemReserve: 1500,
|
||||
RecentReserve: 2000,
|
||||
ConversationPct: 0.40,
|
||||
MemoryPct: 0.20,
|
||||
CodePct: 0.40,
|
||||
},
|
||||
promptTokens: 500,
|
||||
// available = int(8192*0.75) - 1500 - 2000 - 500 = 6144 - 4000 = 2144
|
||||
wantTotal: 2144,
|
||||
wantConv: 857, // int(2144 * 0.40) = 857
|
||||
wantMemory: 428, // int(2144 * 0.20) = 428
|
||||
wantCode: 857, // int(2144 * 0.40) = 857
|
||||
},
|
||||
{
|
||||
name: "large prompt clamps to zero",
|
||||
cfg: BudgetConfig{
|
||||
NumCtx: 8192,
|
||||
SystemReserve: 1500,
|
||||
RecentReserve: 2000,
|
||||
ConversationPct: 0.40,
|
||||
MemoryPct: 0.20,
|
||||
CodePct: 0.40,
|
||||
},
|
||||
promptTokens: 99999,
|
||||
wantTotal: 0,
|
||||
wantConv: 0,
|
||||
wantMemory: 0,
|
||||
wantCode: 0,
|
||||
},
|
||||
{
|
||||
name: "exact boundary available is zero",
|
||||
cfg: BudgetConfig{
|
||||
NumCtx: 8192,
|
||||
SystemReserve: 1500,
|
||||
RecentReserve: 2000,
|
||||
ConversationPct: 0.40,
|
||||
MemoryPct: 0.20,
|
||||
CodePct: 0.40,
|
||||
},
|
||||
// int(8192*0.75) - 1500 - 2000 = 2644
|
||||
promptTokens: 2644,
|
||||
wantTotal: 0,
|
||||
wantConv: 0,
|
||||
wantMemory: 0,
|
||||
wantCode: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
b := tt.cfg.Calculate(tt.promptTokens)
|
||||
if b.Total != tt.wantTotal {
|
||||
t.Errorf("Total = %d, want %d", b.Total, tt.wantTotal)
|
||||
}
|
||||
if b.Conversation != tt.wantConv {
|
||||
t.Errorf("Conversation = %d, want %d", b.Conversation, tt.wantConv)
|
||||
}
|
||||
if b.Memory != tt.wantMemory {
|
||||
t.Errorf("Memory = %d, want %d", b.Memory, tt.wantMemory)
|
||||
}
|
||||
if b.Code != tt.wantCode {
|
||||
t.Errorf("Code = %d, want %d", b.Code, tt.wantCode)
|
||||
}
|
||||
if b.System != tt.cfg.SystemReserve {
|
||||
t.Errorf("System = %d, want %d", b.System, tt.cfg.SystemReserve)
|
||||
}
|
||||
if b.Recent != tt.cfg.RecentReserve {
|
||||
t.Errorf("Recent = %d, want %d", b.Recent, tt.cfg.RecentReserve)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "len/4 heuristic",
|
||||
input: "hello world",
|
||||
want: 2, // 11/4 = 2
|
||||
},
|
||||
{
|
||||
name: "single char clamps to 1",
|
||||
input: "a",
|
||||
want: 1, // 1/4 = 0, clamp to 1
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "exactly 4 chars",
|
||||
input: "abcd",
|
||||
want: 1, // 4/4 = 1
|
||||
},
|
||||
{
|
||||
name: "three chars clamps to 1",
|
||||
input: "abc",
|
||||
want: 1, // 3/4 = 0, clamp to 1
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateTokens(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("estimateTokens(%q) = %d, want %d", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
56
internal/ice/embed.go
Normal file
56
internal/ice/embed.go
Normal file
@ -0,0 +1,56 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultEmbedModel = "nomic-embed-text"
|
||||
maxBatchSize = 32
|
||||
)
|
||||
|
||||
type Embedder struct {
|
||||
client llm.Client
|
||||
model string
|
||||
}
|
||||
|
||||
func NewEmbedder(client llm.Client, model string) *Embedder {
|
||||
if model == "" {
|
||||
model = defaultEmbedModel
|
||||
}
|
||||
return &Embedder{client: client, model: model}
|
||||
}
|
||||
|
||||
func (e *Embedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
||||
vecs, err := e.EmbedBatch(ctx, []string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(vecs) == 0 {
|
||||
return nil, fmt.Errorf("empty embedding response")
|
||||
}
|
||||
return vecs[0], nil
|
||||
}
|
||||
|
||||
func (e *Embedder) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
if len(texts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var all [][]float32
|
||||
for i := 0; i < len(texts); i += maxBatchSize {
|
||||
end := i + maxBatchSize
|
||||
if end > len(texts) {
|
||||
end = len(texts)
|
||||
}
|
||||
batch := texts[i:end]
|
||||
vecs, err := e.client.Embed(ctx, e.model, batch)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed batch [%d:%d]: %w", i, end, err)
|
||||
}
|
||||
all = append(all, vecs...)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
113
internal/ice/engine.go
Normal file
113
internal/ice/engine.go
Normal file
@ -0,0 +1,113 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
"ai-agent/internal/memory"
|
||||
)
|
||||
|
||||
type EngineConfig struct {
|
||||
EmbedModel string
|
||||
StorePath string
|
||||
NumCtx int
|
||||
}
|
||||
|
||||
type Engine struct {
|
||||
embedder *Embedder
|
||||
store *Store
|
||||
memStore *memory.Store
|
||||
budgetCfg BudgetConfig
|
||||
sessionID string
|
||||
turnIndex int
|
||||
autoMemory *AutoMemory
|
||||
}
|
||||
|
||||
func NewEngine(client llm.Client, memStore *memory.Store, cfg EngineConfig) (*Engine, error) {
|
||||
storePath := cfg.StorePath
|
||||
if storePath == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("determine home dir: %w", err)
|
||||
}
|
||||
storePath = filepath.Join(home, ".config", "ai-agent", "conversations.json")
|
||||
}
|
||||
embedModel := cfg.EmbedModel
|
||||
if embedModel == "" {
|
||||
embedModel = defaultEmbedModel
|
||||
}
|
||||
sessionID := fmt.Sprintf("s_%d", time.Now().UnixNano())
|
||||
return &Engine{
|
||||
embedder: NewEmbedder(client, embedModel),
|
||||
store: NewStore(storePath),
|
||||
memStore: memStore,
|
||||
budgetCfg: DefaultBudgetConfig(cfg.NumCtx),
|
||||
sessionID: sessionID,
|
||||
autoMemory: &AutoMemory{client: client, memStore: memStore},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *Engine) AssembleContext(ctx context.Context, query string) (string, error) {
|
||||
a := &Assembler{
|
||||
embedder: e.embedder,
|
||||
convStore: e.store,
|
||||
memStore: e.memStore,
|
||||
budgetCfg: e.budgetCfg,
|
||||
sessionID: e.sessionID,
|
||||
}
|
||||
return a.Assemble(ctx, query)
|
||||
}
|
||||
|
||||
func (e *Engine) IndexMessage(ctx context.Context, role, content string) error {
|
||||
if content == "" {
|
||||
return nil
|
||||
}
|
||||
text := content
|
||||
if len(text) > 2000 {
|
||||
text = text[:2000]
|
||||
}
|
||||
emb, err := e.embedder.Embed(ctx, text)
|
||||
if err != nil {
|
||||
return fmt.Errorf("embed message: %w", err)
|
||||
}
|
||||
e.turnIndex++
|
||||
_, err = e.store.Add(e.sessionID, role, text, emb, e.turnIndex)
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *Engine) IndexSummary(ctx context.Context, summary string) error {
|
||||
if summary == "" {
|
||||
return nil
|
||||
}
|
||||
emb, err := e.embedder.Embed(ctx, summary)
|
||||
if err != nil {
|
||||
return fmt.Errorf("embed summary: %w", err)
|
||||
}
|
||||
_, err = e.store.Add(e.sessionID, "summary", summary, emb, e.turnIndex)
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *Engine) DetectAutoMemory(ctx context.Context, userMsg, assistantMsg string) {
|
||||
if e.autoMemory == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
_ = e.autoMemory.Detect(ctx, userMsg, assistantMsg)
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *Engine) Flush() error {
|
||||
return e.store.Flush()
|
||||
}
|
||||
|
||||
func (e *Engine) Store() *Store {
|
||||
return e.store
|
||||
}
|
||||
|
||||
func (e *Engine) SessionID() string {
|
||||
return e.sessionID
|
||||
}
|
||||
73
internal/ice/engine_test.go
Normal file
73
internal/ice/engine_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEngineConfigDefaults(t *testing.T) {
|
||||
// Test embed model default
|
||||
embedModel := ""
|
||||
if embedModel == "" {
|
||||
embedModel = defaultEmbedModel
|
||||
}
|
||||
if embedModel != defaultEmbedModel {
|
||||
t.Errorf("embedModel = %q, want %q", embedModel, defaultEmbedModel)
|
||||
}
|
||||
|
||||
// Test custom embed model
|
||||
cfg := EngineConfig{
|
||||
EmbedModel: "custom-model",
|
||||
}
|
||||
if cfg.EmbedModel != "custom-model" {
|
||||
t.Errorf("EmbedModel = %q, want %q", cfg.EmbedModel, "custom-model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBudgetConfigCalculate(t *testing.T) {
|
||||
cfg := DefaultBudgetConfig(16384)
|
||||
|
||||
budget := cfg.Calculate(100)
|
||||
// 16384 * 0.75 = 12288
|
||||
// 12288 - 1500 - 2000 - 100 = 8688
|
||||
if budget.Total != 8688 {
|
||||
t.Errorf("Total = %d, want %d", budget.Total, 8688)
|
||||
}
|
||||
if budget.System != 1500 {
|
||||
t.Errorf("System = %d, want %d", budget.System, 1500)
|
||||
}
|
||||
if budget.Recent != 2000 {
|
||||
t.Errorf("Recent = %d, want %d", budget.Recent, 2000)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBudgetConfigCalculateNegative(t *testing.T) {
|
||||
// With small context, should not panic and return zeros
|
||||
cfg := DefaultBudgetConfig(1000)
|
||||
budget := cfg.Calculate(500)
|
||||
|
||||
// 1000 * 0.75 = 750
|
||||
// 750 - 1500 - 2000 - 500 = -3250 -> clamped to 0
|
||||
if budget.Total != 0 {
|
||||
t.Errorf("Total should be 0 when budget is negative, got %d", budget.Total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBudgetConfigPercentages(t *testing.T) {
|
||||
cfg := DefaultBudgetConfig(16384)
|
||||
budget := cfg.Calculate(100)
|
||||
|
||||
// Check percentages: ConversationPct=0.40, MemoryPct=0.20, CodePct=0.40
|
||||
// available = 12288 - 1500 - 2000 - 100 = 8688
|
||||
// Conversation = 8688 * 0.40 = 3475
|
||||
// Memory = 8688 * 0.20 = 1737
|
||||
// Code = 8688 * 0.40 = 3475
|
||||
if budget.Conversation != 3475 {
|
||||
t.Errorf("Conversation = %d, want %d", budget.Conversation, 3475)
|
||||
}
|
||||
if budget.Memory != 1737 {
|
||||
t.Errorf("Memory = %d, want %d", budget.Memory, 1737)
|
||||
}
|
||||
if budget.Code != 3475 {
|
||||
t.Errorf("Code = %d, want %d", budget.Code, 3475)
|
||||
}
|
||||
}
|
||||
167
internal/ice/store.go
Normal file
167
internal/ice/store.go
Normal file
@ -0,0 +1,167 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// timeNow is a variable for testing.
|
||||
var timeNow = time.Now
|
||||
|
||||
const minSimilarityThreshold = 0.3
|
||||
|
||||
// Store is a flat-file vector store for conversation history.
|
||||
// It holds all entries in memory and persists to a JSON file.
|
||||
type Store struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
entries []ConversationEntry
|
||||
nextID int
|
||||
dirty bool
|
||||
}
|
||||
|
||||
// NewStore loads an existing store from path or creates an empty one.
|
||||
func NewStore(path string) *Store {
|
||||
s := &Store{path: path}
|
||||
s.load()
|
||||
return s
|
||||
}
|
||||
|
||||
// Add appends a new conversation entry and returns its ID.
|
||||
func (s *Store) Add(sessionID, role, content string, embedding []float32, turnIndex int) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.nextID++
|
||||
entry := ConversationEntry{
|
||||
ID: s.nextID,
|
||||
SessionID: sessionID,
|
||||
Role: role,
|
||||
Content: content,
|
||||
Embedding: embedding,
|
||||
TurnIndex: turnIndex,
|
||||
}
|
||||
// Use a zero-value check to set CreatedAt (avoids importing time in every call site).
|
||||
entry.CreatedAt = timeNow()
|
||||
s.entries = append(s.entries, entry)
|
||||
s.dirty = true
|
||||
return s.nextID, nil
|
||||
}
|
||||
|
||||
// Search returns the top-K entries most similar to queryEmbedding.
|
||||
// Entries from excludeSession are skipped. Results are sorted by score descending.
|
||||
func (s *Store) Search(queryEmbedding []float32, excludeSession string, topK int) []ScoredEntry {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if len(queryEmbedding) == 0 || len(s.entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var scored []ScoredEntry
|
||||
for _, e := range s.entries {
|
||||
if e.SessionID == excludeSession {
|
||||
continue
|
||||
}
|
||||
if len(e.Embedding) == 0 {
|
||||
continue
|
||||
}
|
||||
sim := cosineSimilarity(queryEmbedding, e.Embedding)
|
||||
if sim >= minSimilarityThreshold {
|
||||
scored = append(scored, ScoredEntry{Entry: e, Score: sim})
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(scored, func(i, j int) bool {
|
||||
return scored[i].Score > scored[j].Score
|
||||
})
|
||||
|
||||
if len(scored) > topK {
|
||||
scored = scored[:topK]
|
||||
}
|
||||
return scored
|
||||
}
|
||||
|
||||
// Flush persists any pending changes to disk.
|
||||
func (s *Store) Flush() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.dirty {
|
||||
return nil
|
||||
}
|
||||
return s.persist()
|
||||
}
|
||||
|
||||
// Count returns the total number of stored entries.
|
||||
func (s *Store) Count() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return len(s.entries)
|
||||
}
|
||||
|
||||
// load reads entries from the JSON file.
|
||||
func (s *Store) load() {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
return // File doesn't exist yet.
|
||||
}
|
||||
|
||||
var entries []ConversationEntry
|
||||
if err := json.Unmarshal(data, &entries); err != nil {
|
||||
return // Corrupt file, start empty.
|
||||
}
|
||||
|
||||
s.entries = entries
|
||||
for _, e := range s.entries {
|
||||
if e.ID > s.nextID {
|
||||
s.nextID = e.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// persist writes all entries to the JSON file.
|
||||
func (s *Store) persist() error {
|
||||
dir := filepath.Dir(s.path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("create ice store dir: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(s.entries)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal ice store: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(s.path, data, 0o644); err != nil {
|
||||
return fmt.Errorf("write ice store: %w", err)
|
||||
}
|
||||
|
||||
s.dirty = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// cosineSimilarity computes the cosine similarity between two vectors.
|
||||
func cosineSimilarity(a, b []float32) float32 {
|
||||
if len(a) != len(b) || len(a) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var dot, normA, normB float64
|
||||
for i := range a {
|
||||
dot += float64(a[i]) * float64(b[i])
|
||||
normA += float64(a[i]) * float64(a[i])
|
||||
normB += float64(b[i]) * float64(b[i])
|
||||
}
|
||||
|
||||
denom := math.Sqrt(normA) * math.Sqrt(normB)
|
||||
if denom == 0 {
|
||||
return 0
|
||||
}
|
||||
return float32(dot / denom)
|
||||
}
|
||||
232
internal/ice/store_test.go
Normal file
232
internal/ice/store_test.go
Normal file
@ -0,0 +1,232 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCosineSimilarity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a, b []float32
|
||||
want float32
|
||||
tol float32
|
||||
}{
|
||||
{
|
||||
name: "identical vectors",
|
||||
a: []float32{1, 2, 3},
|
||||
b: []float32{1, 2, 3},
|
||||
want: 1.0,
|
||||
tol: 1e-6,
|
||||
},
|
||||
{
|
||||
name: "orthogonal vectors",
|
||||
a: []float32{1, 0},
|
||||
b: []float32{0, 1},
|
||||
want: 0.0,
|
||||
tol: 1e-6,
|
||||
},
|
||||
{
|
||||
name: "opposite vectors",
|
||||
a: []float32{1, 0},
|
||||
b: []float32{-1, 0},
|
||||
want: -1.0,
|
||||
tol: 1e-6,
|
||||
},
|
||||
{
|
||||
name: "different lengths returns 0",
|
||||
a: []float32{1, 0},
|
||||
b: []float32{1, 0, 0},
|
||||
want: 0,
|
||||
tol: 0,
|
||||
},
|
||||
{
|
||||
name: "zero vector returns 0",
|
||||
a: []float32{0, 0},
|
||||
b: []float32{1, 1},
|
||||
want: 0,
|
||||
tol: 0,
|
||||
},
|
||||
{
|
||||
name: "known value with tolerance",
|
||||
a: []float32{1, 1},
|
||||
b: []float32{1, 0},
|
||||
// 1/(sqrt(2)*1) ≈ 0.7071
|
||||
want: float32(1.0 / math.Sqrt(2)),
|
||||
tol: 1e-4,
|
||||
},
|
||||
{
|
||||
name: "empty vectors",
|
||||
a: []float32{},
|
||||
b: []float32{},
|
||||
want: 0,
|
||||
tol: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := cosineSimilarity(tt.a, tt.b)
|
||||
diff := got - tt.want
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
if diff > tt.tol {
|
||||
t.Errorf("cosineSimilarity(%v, %v) = %f, want %f (±%f)",
|
||||
tt.a, tt.b, got, tt.want, tt.tol)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Add_And_Count(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "store.json")
|
||||
|
||||
s := NewStore(path)
|
||||
|
||||
if s.Count() != 0 {
|
||||
t.Fatalf("new store Count = %d, want 0", s.Count())
|
||||
}
|
||||
|
||||
id1, err := s.Add("sess1", "user", "hello", []float32{1, 0, 0}, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Add returned error: %v", err)
|
||||
}
|
||||
if id1 != 1 {
|
||||
t.Errorf("first Add returned id=%d, want 1", id1)
|
||||
}
|
||||
if s.Count() != 1 {
|
||||
t.Errorf("Count after first Add = %d, want 1", s.Count())
|
||||
}
|
||||
|
||||
id2, err := s.Add("sess1", "assistant", "world", []float32{0, 1, 0}, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Add returned error: %v", err)
|
||||
}
|
||||
if id2 != 2 {
|
||||
t.Errorf("second Add returned id=%d, want 2", id2)
|
||||
}
|
||||
if s.Count() != 2 {
|
||||
t.Errorf("Count after second Add = %d, want 2", s.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Search(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "store.json")
|
||||
s := NewStore(path)
|
||||
|
||||
// Add entries with known embeddings.
|
||||
s.Add("sess1", "user", "entry A", []float32{1, 0, 0}, 0)
|
||||
s.Add("sess1", "user", "entry B", []float32{0, 1, 0}, 1)
|
||||
s.Add("sess2", "user", "entry C", []float32{0.9, 0.1, 0}, 0)
|
||||
s.Add("sess2", "user", "entry D", []float32{0, 0, 1}, 1) // orthogonal to query
|
||||
|
||||
t.Run("similarity filtering and sorting", func(t *testing.T) {
|
||||
// Query similar to entries A and C, exclude no session.
|
||||
results := s.Search([]float32{1, 0, 0}, "", 10)
|
||||
if len(results) == 0 {
|
||||
t.Fatal("expected results, got 0")
|
||||
}
|
||||
// Entry A should be highest (identical to query).
|
||||
if results[0].Entry.Content != "entry A" {
|
||||
t.Errorf("top result = %q, want 'entry A'", results[0].Entry.Content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("session exclusion", func(t *testing.T) {
|
||||
results := s.Search([]float32{1, 0, 0}, "sess1", 10)
|
||||
for _, r := range results {
|
||||
if r.Entry.SessionID == "sess1" {
|
||||
t.Errorf("excluded session sess1 should not appear in results")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("min threshold 0.3", func(t *testing.T) {
|
||||
// Entry D: [0,0,1] is orthogonal to [1,0,0] → similarity 0.
|
||||
results := s.Search([]float32{1, 0, 0}, "", 10)
|
||||
for _, r := range results {
|
||||
if r.Score < minSimilarityThreshold {
|
||||
t.Errorf("result %q has score %f below threshold %f",
|
||||
r.Entry.Content, r.Score, minSimilarityThreshold)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("topK limit", func(t *testing.T) {
|
||||
results := s.Search([]float32{1, 0, 0}, "", 1)
|
||||
if len(results) > 1 {
|
||||
t.Errorf("topK=1 but got %d results", len(results))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty store returns nil", func(t *testing.T) {
|
||||
emptyPath := filepath.Join(dir, "empty.json")
|
||||
empty := NewStore(emptyPath)
|
||||
results := empty.Search([]float32{1, 0}, "", 5)
|
||||
if results != nil {
|
||||
t.Errorf("empty store search should return nil, got %v", results)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty query embedding returns nil", func(t *testing.T) {
|
||||
results := s.Search([]float32{}, "", 5)
|
||||
if results != nil {
|
||||
t.Errorf("empty query should return nil, got %v", results)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_Flush_Persistence(t *testing.T) {
|
||||
t.Run("round trip", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "store.json")
|
||||
|
||||
s1 := NewStore(path)
|
||||
s1.Add("sess1", "user", "hello", []float32{1, 0}, 0)
|
||||
s1.Add("sess1", "assistant", "world", []float32{0, 1}, 1)
|
||||
|
||||
if err := s1.Flush(); err != nil {
|
||||
t.Fatalf("Flush: %v", err)
|
||||
}
|
||||
|
||||
// Reload from same path.
|
||||
s2 := NewStore(path)
|
||||
if s2.Count() != 2 {
|
||||
t.Errorf("reloaded store Count = %d, want 2", s2.Count())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("corrupt JSON recovery", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "store.json")
|
||||
|
||||
// Write corrupt JSON.
|
||||
os.WriteFile(path, []byte("not valid json{{{"), 0o644)
|
||||
|
||||
s := NewStore(path)
|
||||
if s.Count() != 0 {
|
||||
t.Errorf("corrupt store Count = %d, want 0", s.Count())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nextID restoration", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "store.json")
|
||||
|
||||
s1 := NewStore(path)
|
||||
s1.Add("sess1", "user", "first", []float32{1}, 0)
|
||||
s1.Add("sess1", "user", "second", []float32{1}, 1)
|
||||
s1.Flush()
|
||||
|
||||
s2 := NewStore(path)
|
||||
id, _ := s2.Add("sess1", "user", "third", []float32{1}, 2)
|
||||
if id != 3 {
|
||||
t.Errorf("continued id = %d, want 3", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
46
internal/ice/types.go
Normal file
46
internal/ice/types.go
Normal file
@ -0,0 +1,46 @@
|
||||
package ice
|
||||
|
||||
import "time"
|
||||
|
||||
// SourceKind identifies where a context chunk came from.
|
||||
type SourceKind int
|
||||
|
||||
const (
|
||||
SourceConversation SourceKind = iota
|
||||
SourceMemory
|
||||
)
|
||||
|
||||
// ConversationEntry is a single stored message with its embedding.
|
||||
type ConversationEntry struct {
|
||||
ID int `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Role string `json:"role"` // "user", "assistant", "summary"
|
||||
Content string `json:"content"`
|
||||
Embedding []float32 `json:"embedding"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
TurnIndex int `json:"turn_index"`
|
||||
}
|
||||
|
||||
// ScoredEntry pairs a conversation entry with its similarity score.
|
||||
type ScoredEntry struct {
|
||||
Entry ConversationEntry
|
||||
Score float32
|
||||
}
|
||||
|
||||
// ContextChunk is a piece of assembled context ready for the prompt.
|
||||
type ContextChunk struct {
|
||||
Source SourceKind
|
||||
Content string
|
||||
Score float32
|
||||
Tokens int
|
||||
}
|
||||
|
||||
// Budget holds the token allocation for each context source.
|
||||
type Budget struct {
|
||||
Total int
|
||||
System int
|
||||
Conversation int
|
||||
Memory int
|
||||
Code int
|
||||
Recent int
|
||||
}
|
||||
184
internal/initcmd/initcmd.go
Normal file
184
internal/initcmd/initcmd.go
Normal file
@ -0,0 +1,184 @@
|
||||
package initcmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// projectMarker maps a marker file name to its detected project type.
|
||||
var projectMarkers = map[string]string{
|
||||
"go.mod": "Go",
|
||||
"go.sum": "Go",
|
||||
"package.json": "Node.js",
|
||||
"Cargo.toml": "Rust",
|
||||
"pyproject.toml": "Python",
|
||||
"requirements.txt": "Python",
|
||||
"setup.py": "Python",
|
||||
"Pipfile": "Python",
|
||||
"Gemfile": "Ruby",
|
||||
"pom.xml": "Java (Maven)",
|
||||
"build.gradle": "Java (Gradle)",
|
||||
"build.gradle.kts": "Kotlin (Gradle)",
|
||||
"CMakeLists.txt": "C/C++ (CMake)",
|
||||
"Makefile": "Make",
|
||||
"Taskfile.yml": "Taskfile",
|
||||
"Taskfile.yaml": "Taskfile",
|
||||
"docker-compose.yml": "Docker Compose",
|
||||
"docker-compose.yaml": "Docker Compose",
|
||||
"Dockerfile": "Docker",
|
||||
".gitignore": "Git",
|
||||
}
|
||||
|
||||
// Options configures the behaviour of Run.
|
||||
type Options struct {
|
||||
// Force overwrites an existing AGENT.md.
|
||||
Force bool
|
||||
}
|
||||
|
||||
// Run scans dir for project markers and generates an AGENT.md file.
|
||||
// It returns an error if AGENT.md already exists unless opts.Force is true.
|
||||
func Run(dir string, opts Options) error {
|
||||
agentPath := filepath.Join(dir, "AGENT.md")
|
||||
|
||||
if !opts.Force {
|
||||
if _, err := os.Stat(agentPath); err == nil {
|
||||
return fmt.Errorf("AGENT.md already exists in %s (use --force to overwrite)", dir)
|
||||
}
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading directory: %w", err)
|
||||
}
|
||||
|
||||
// Detect project types from marker files.
|
||||
detectedTypes := detectProjectTypes(entries)
|
||||
|
||||
// Build directory listing.
|
||||
listing := buildDirectoryListing(dir, entries)
|
||||
|
||||
// Generate AGENT.md content.
|
||||
content := generateAgentMD(detectedTypes, listing)
|
||||
|
||||
if err := os.WriteFile(agentPath, []byte(content), 0644); err != nil {
|
||||
return fmt.Errorf("writing AGENT.md: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectProjectTypes returns a deduplicated, sorted list of project types
|
||||
// found based on marker files in the directory entries.
|
||||
func detectProjectTypes(entries []os.DirEntry) []string {
|
||||
seen := make(map[string]bool)
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
if pt, ok := projectMarkers[e.Name()]; ok {
|
||||
seen[pt] = true
|
||||
}
|
||||
}
|
||||
|
||||
types := make([]string, 0, len(seen))
|
||||
for t := range seen {
|
||||
types = append(types, t)
|
||||
}
|
||||
sort.Strings(types)
|
||||
return types
|
||||
}
|
||||
|
||||
// buildDirectoryListing returns a formatted string of top-level files and
|
||||
// first-level subdirectory contents.
|
||||
func buildDirectoryListing(dir string, entries []os.DirEntry) string {
|
||||
var b strings.Builder
|
||||
|
||||
var files []string
|
||||
var dirs []string
|
||||
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
// Skip hidden files/dirs except well-known ones.
|
||||
if strings.HasPrefix(name, ".") && name != ".gitignore" {
|
||||
continue
|
||||
}
|
||||
if e.IsDir() {
|
||||
dirs = append(dirs, name)
|
||||
} else {
|
||||
files = append(files, name)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(files)
|
||||
sort.Strings(dirs)
|
||||
|
||||
for _, f := range files {
|
||||
b.WriteString(f)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
|
||||
for _, d := range dirs {
|
||||
b.WriteString(d + "/\n")
|
||||
subEntries, err := os.ReadDir(filepath.Join(dir, d))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var subNames []string
|
||||
for _, se := range subEntries {
|
||||
n := se.Name()
|
||||
if strings.HasPrefix(n, ".") {
|
||||
continue
|
||||
}
|
||||
if se.IsDir() {
|
||||
subNames = append(subNames, n+"/")
|
||||
} else {
|
||||
subNames = append(subNames, n)
|
||||
}
|
||||
}
|
||||
sort.Strings(subNames)
|
||||
for _, sn := range subNames {
|
||||
b.WriteString(" " + sn + "\n")
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// generateAgentMD produces the Markdown content for the AGENT.md file.
|
||||
func generateAgentMD(projectTypes []string, listing string) string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("# AGENT.md\n\n")
|
||||
|
||||
// Project type section.
|
||||
b.WriteString("## Project Type\n\n")
|
||||
if len(projectTypes) == 0 {
|
||||
b.WriteString("Unknown\n")
|
||||
} else {
|
||||
b.WriteString(strings.Join(projectTypes, ", ") + "\n")
|
||||
}
|
||||
|
||||
// Directory structure.
|
||||
b.WriteString("\n## Directory Structure\n\n")
|
||||
b.WriteString("```\n")
|
||||
b.WriteString(listing)
|
||||
b.WriteString("```\n")
|
||||
|
||||
// Placeholder sections.
|
||||
b.WriteString("\n## Build Commands\n\n")
|
||||
b.WriteString("<!-- Add build, test, and run commands here -->\n")
|
||||
|
||||
b.WriteString("\n## Architecture\n\n")
|
||||
b.WriteString("<!-- Describe the high-level architecture here -->\n")
|
||||
|
||||
b.WriteString("\n## Key Files\n\n")
|
||||
b.WriteString("<!-- List important files and their purposes here -->\n")
|
||||
|
||||
b.WriteString("\n## Notes\n\n")
|
||||
b.WriteString("<!-- Any additional notes for the agent -->\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
141
internal/initcmd/initcmd_test.go
Normal file
141
internal/initcmd/initcmd_test.go
Normal file
@ -0,0 +1,141 @@
|
||||
package initcmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRun_GoProject(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create a go.mod marker file.
|
||||
if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com/test\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Create a source directory with a file.
|
||||
if err := os.Mkdir(filepath.Join(dir, "cmd"), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "cmd", "main.go"), []byte("package main\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Run(dir, Options{}); err != nil {
|
||||
t.Fatalf("Run() returned error: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(dir, "AGENT.md"))
|
||||
if err != nil {
|
||||
t.Fatalf("reading AGENT.md: %v", err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
|
||||
// Check project type detection.
|
||||
if !strings.Contains(content, "Go") {
|
||||
t.Error("expected AGENT.md to contain 'Go' project type")
|
||||
}
|
||||
|
||||
// Check directory listing includes go.mod.
|
||||
if !strings.Contains(content, "go.mod") {
|
||||
t.Error("expected AGENT.md to list go.mod")
|
||||
}
|
||||
|
||||
// Check directory listing includes cmd/ subdirectory.
|
||||
if !strings.Contains(content, "cmd/") {
|
||||
t.Error("expected AGENT.md to list cmd/ directory")
|
||||
}
|
||||
|
||||
// Check that main.go appears under cmd/.
|
||||
if !strings.Contains(content, "main.go") {
|
||||
t.Error("expected AGENT.md to list main.go inside cmd/")
|
||||
}
|
||||
|
||||
// Check placeholder sections exist.
|
||||
for _, section := range []string{"## Build Commands", "## Architecture", "## Key Files", "## Notes"} {
|
||||
if !strings.Contains(content, section) {
|
||||
t.Errorf("expected AGENT.md to contain section %q", section)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_EmptyDirectory(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
if err := Run(dir, Options{}); err != nil {
|
||||
t.Fatalf("Run() returned error: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(dir, "AGENT.md"))
|
||||
if err != nil {
|
||||
t.Fatalf("reading AGENT.md: %v", err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
|
||||
// With no marker files, project type should be "Unknown".
|
||||
if !strings.Contains(content, "Unknown") {
|
||||
t.Error("expected AGENT.md to contain 'Unknown' project type for empty dir")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ExistingAgentMD_NoOverwrite(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
agentPath := filepath.Join(dir, "AGENT.md")
|
||||
original := "# Original content\n"
|
||||
if err := os.WriteFile(agentPath, []byte(original), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err := Run(dir, Options{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when AGENT.md already exists")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "already exists") {
|
||||
t.Errorf("expected 'already exists' in error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify file was not modified.
|
||||
data, err := os.ReadFile(agentPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != original {
|
||||
t.Error("AGENT.md was unexpectedly modified")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_Force(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
agentPath := filepath.Join(dir, "AGENT.md")
|
||||
if err := os.WriteFile(agentPath, []byte("# Old content\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a go.mod so the new content is distinguishable.
|
||||
if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module test\n"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Run(dir, Options{Force: true}); err != nil {
|
||||
t.Fatalf("Run() with Force=true returned error: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(agentPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if strings.Contains(content, "Old content") {
|
||||
t.Error("AGENT.md should have been overwritten with Force=true")
|
||||
}
|
||||
if !strings.Contains(content, "Go") {
|
||||
t.Error("expected new AGENT.md to detect Go project type")
|
||||
}
|
||||
}
|
||||
230
internal/integration/integration_test.go
Normal file
230
internal/integration/integration_test.go
Normal file
@ -0,0 +1,230 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ai-agent/internal/agent"
|
||||
"ai-agent/internal/command"
|
||||
"ai-agent/internal/config"
|
||||
"ai-agent/internal/llm"
|
||||
"ai-agent/internal/mcp"
|
||||
"ai-agent/internal/tui"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
)
|
||||
|
||||
func skipIfNoOllama(t *testing.T) {
|
||||
if os.Getenv("OLLAMA_HOST") == "" {
|
||||
os.Setenv("OLLAMA_HOST", "http://localhost:11434")
|
||||
}
|
||||
client := llm.NewClient(llm.Config{
|
||||
BaseURL: os.Getenv("OLLAMA_HOST"),
|
||||
Model: "qwen3.5:2b",
|
||||
NumCtx: 262144,
|
||||
})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := client.Ping(ctx); err != nil {
|
||||
t.Skip("Ollama not available: skipping integration test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTUI_Initialization(t *testing.T) {
|
||||
skipIfNoOllama(t)
|
||||
reg := command.NewRegistry()
|
||||
command.RegisterBuiltins(reg)
|
||||
cfg := config.DefaultModelConfig()
|
||||
router := config.NewRouter(&cfg)
|
||||
modelManager := llm.NewModelManager("http://localhost:11434", 262144)
|
||||
modelManager.SetCurrentModel("qwen3.5:2b")
|
||||
ag := agent.New(modelManager, mcp.NewRegistry(), cfg.Ollama.NumCtx)
|
||||
ag.SetRouter(router)
|
||||
completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil)
|
||||
m := tui.New(ag, reg, nil, completer, modelManager, router, nil)
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
|
||||
m = updated.(*tui.Model)
|
||||
if !m.Ready() {
|
||||
t.Error("TUI should be ready after WindowSizeMsg")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTUI_ScrollAnchorDuringStreaming(t *testing.T) {
|
||||
skipIfNoOllama(t)
|
||||
reg := command.NewRegistry()
|
||||
command.RegisterBuiltins(reg)
|
||||
cfg := config.DefaultModelConfig()
|
||||
router := config.NewRouter(&cfg)
|
||||
modelManager := llm.NewModelManager("http://localhost:11434", 262144)
|
||||
ag := agent.New(modelManager, mcp.NewRegistry(), 262144)
|
||||
ag.SetRouter(router)
|
||||
completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil)
|
||||
m := tui.New(ag, reg, nil, completer, modelManager, router, nil)
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
|
||||
m = updated.(*tui.Model)
|
||||
if !m.AnchorActive() {
|
||||
t.Error("anchorActive should be true after initialization")
|
||||
}
|
||||
updated, _ = m.Update(tui.StreamTextMsg{Text: "Hello"})
|
||||
m = updated.(*tui.Model)
|
||||
if !m.AnchorActive() {
|
||||
t.Error("anchorActive should remain true during streaming")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTUI_OverlayRendering(t *testing.T) {
|
||||
reg := command.NewRegistry()
|
||||
command.RegisterBuiltins(reg)
|
||||
cfg := config.DefaultModelConfig()
|
||||
router := config.NewRouter(&cfg)
|
||||
modelManager := llm.NewModelManager("http://localhost:11434", 262144)
|
||||
ag := agent.New(modelManager, mcp.NewRegistry(), 262144)
|
||||
ag.SetRouter(router)
|
||||
completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil)
|
||||
m := tui.New(ag, reg, nil, completer, modelManager, router, nil)
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
|
||||
m = updated.(*tui.Model)
|
||||
updated, _ = m.Update(tui.KeyPressMsg{Code: '?'})
|
||||
m = updated.(*tui.Model)
|
||||
view := m.View()
|
||||
if view == nil {
|
||||
t.Error("View should not be nil")
|
||||
}
|
||||
updated, _ = m.Update(tui.KeyPressMsg{Code: tea.KeyEscape})
|
||||
m = updated.(*tui.Model)
|
||||
}
|
||||
|
||||
func TestTUI_ToolCardRendering(t *testing.T) {
|
||||
reg := command.NewRegistry()
|
||||
command.RegisterBuiltins(reg)
|
||||
cfg := config.DefaultModelConfig()
|
||||
router := config.NewRouter(&cfg)
|
||||
modelManager := llm.NewModelManager("http://localhost:11434", 262144)
|
||||
ag := agent.New(modelManager, mcp.NewRegistry(), 262144)
|
||||
ag.SetRouter(router)
|
||||
completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil)
|
||||
m := tui.New(ag, reg, nil, completer, modelManager, router, nil)
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
|
||||
m = updated.(*tui.Model)
|
||||
startTime := time.Now()
|
||||
updated, _ = m.Update(tui.ToolCallStartMsg{
|
||||
Name: "read_file",
|
||||
Args: map[string]any{"path": "test.go"},
|
||||
StartTime: startTime,
|
||||
})
|
||||
m = updated.(*tui.Model)
|
||||
updated, _ = m.Update(tui.ToolCallResultMsg{
|
||||
Name: "read_file",
|
||||
Result: "file content",
|
||||
IsError: false,
|
||||
Duration: 100 * time.Millisecond,
|
||||
})
|
||||
m = updated.(*tui.Model)
|
||||
view := m.View()
|
||||
if view == nil {
|
||||
t.Error("View should not be nil after tool execution")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenRouter_Integration(t *testing.T) {
|
||||
cfg := config.DefaultModelConfig()
|
||||
router := config.NewQwenModelRouter(&cfg)
|
||||
tests := []struct {
|
||||
query string
|
||||
mode config.ModeContext
|
||||
expectSmaller string
|
||||
expectLarger string
|
||||
}{
|
||||
{"what is go?", config.ModeAskContext, "qwen3.5:2b", ""},
|
||||
{"design architecture", config.ModeBuildContext, "", "qwen3.5:4b"},
|
||||
{"plan the system", config.ModePlanContext, "", "qwen3.5:4b"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.query, func(t *testing.T) {
|
||||
model := router.SelectModelForMode(tt.query, tt.mode)
|
||||
if tt.expectSmaller != "" {
|
||||
if modelRank(model) > modelRank(tt.expectSmaller) {
|
||||
t.Errorf("model %s is larger than expected %s", model, tt.expectSmaller)
|
||||
}
|
||||
}
|
||||
if tt.expectLarger != "" {
|
||||
if modelRank(model) < modelRank(tt.expectLarger) {
|
||||
t.Errorf("model %s is smaller than expected %s", model, tt.expectLarger)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func modelRank(model string) int {
|
||||
switch {
|
||||
case strings.Contains(model, "0.8b"):
|
||||
return 1
|
||||
case strings.Contains(model, "2b"):
|
||||
return 2
|
||||
case strings.Contains(model, "4b"):
|
||||
return 3
|
||||
case strings.Contains(model, "9b"):
|
||||
return 4
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileOperations_Integration(t *testing.T) {
|
||||
skipIfNoOllama(t)
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
if err := os.WriteFile(testFile, []byte("hello world"), 0644); err != nil {
|
||||
t.Fatalf("failed to create test file: %v", err)
|
||||
}
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read test file: %v", err)
|
||||
}
|
||||
if string(content) != "hello world" {
|
||||
t.Errorf("unexpected file content: %q", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTUI_Render(b *testing.B) {
|
||||
reg := command.NewRegistry()
|
||||
command.RegisterBuiltins(reg)
|
||||
cfg := config.DefaultModelConfig()
|
||||
router := config.NewRouter(&cfg)
|
||||
modelManager := llm.NewModelManager("http://localhost:11434", 262144)
|
||||
ag := agent.New(modelManager, mcp.NewRegistry(), 262144)
|
||||
ag.SetRouter(router)
|
||||
completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil)
|
||||
m := tui.New(ag, reg, nil, completer, modelManager, router, nil)
|
||||
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
|
||||
m = updated.(*tui.Model)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = m.View()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkQwenRouter_Classification(b *testing.B) {
|
||||
cfg := config.DefaultModelConfig()
|
||||
router := config.NewQwenModelRouter(&cfg)
|
||||
queries := []string{
|
||||
"what is go",
|
||||
"how do i create a file",
|
||||
"debug this nil pointer error",
|
||||
"design microservice architecture",
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, q := range queries {
|
||||
_ = router.SelectModelForMode(q, config.ModeAskContext)
|
||||
}
|
||||
}
|
||||
}
|
||||
58
internal/llm/client.go
Normal file
58
internal/llm/client.go
Normal file
@ -0,0 +1,58 @@
|
||||
package llm
|
||||
|
||||
import "context"
|
||||
|
||||
// Client is the interface for LLM providers.
|
||||
type Client interface {
|
||||
// ChatStream sends messages to the LLM and streams the response.
|
||||
// The callback is called for each chunk. Return a non-nil error to abort.
|
||||
ChatStream(ctx context.Context, opts ChatOptions, fn func(StreamChunk) error) error
|
||||
|
||||
// Ping checks if the LLM is reachable and the model is available.
|
||||
Ping() error
|
||||
|
||||
// Model returns the current model name.
|
||||
Model() string
|
||||
|
||||
// Embed generates embeddings for the given texts using the specified model.
|
||||
Embed(ctx context.Context, model string, texts []string) ([][]float32, error)
|
||||
}
|
||||
|
||||
// ChatOptions holds parameters for a chat request.
|
||||
type ChatOptions struct {
|
||||
Messages []Message
|
||||
Tools []ToolDef
|
||||
System string
|
||||
}
|
||||
|
||||
// Message represents a conversation message.
|
||||
type Message struct {
|
||||
Role string `json:"role"` // system, user, assistant, tool
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// StreamChunk is a piece of a streaming response.
|
||||
type StreamChunk struct {
|
||||
Text string // incremental text content
|
||||
ToolCalls []ToolCall // tool calls (usually in final chunk)
|
||||
Done bool // true on the last chunk
|
||||
EvalCount int // tokens generated (only on Done)
|
||||
PromptEvalCount int // prompt tokens evaluated (only on Done)
|
||||
}
|
||||
|
||||
// ToolCall represents a tool invocation requested by the LLM.
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
}
|
||||
|
||||
// ToolDef defines a tool the LLM can call.
|
||||
type ToolDef struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]any `json:"parameters"` // JSON Schema
|
||||
}
|
||||
163
internal/llm/manager.go
Normal file
163
internal/llm/manager.go
Normal file
@ -0,0 +1,163 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ModelManager struct {
|
||||
baseURL string
|
||||
numCtx int
|
||||
clients map[string]*OllamaClient
|
||||
currentModel string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var _ Client = (*ModelManager)(nil)
|
||||
|
||||
func NewModelManager(baseURL string, numCtx int) *ModelManager {
|
||||
return &ModelManager{
|
||||
baseURL: baseURL,
|
||||
numCtx: numCtx,
|
||||
clients: make(map[string]*OllamaClient),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ModelManager) GetClient(modelName string) (*OllamaClient, error) {
|
||||
m.mu.RLock()
|
||||
client, exists := m.clients[modelName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if client, exists := m.clients[modelName]; exists {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
client, err := NewOllamaClient(m.baseURL, modelName, m.numCtx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create client for %s: %w", modelName, err)
|
||||
}
|
||||
|
||||
m.clients[modelName] = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (m *ModelManager) SetCurrentModel(model string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
client, err := NewOllamaClient(m.baseURL, model, m.numCtx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client for %s: %w", model, err)
|
||||
}
|
||||
|
||||
m.clients[model] = client
|
||||
m.currentModel = model
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ModelManager) CurrentModel() string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.currentModel
|
||||
}
|
||||
|
||||
func (m *ModelManager) ChatStream(ctx context.Context, opts ChatOptions, fn func(StreamChunk) error) error {
|
||||
m.mu.RLock()
|
||||
model := m.currentModel
|
||||
m.mu.RUnlock()
|
||||
|
||||
if model == "" {
|
||||
return fmt.Errorf("no model selected")
|
||||
}
|
||||
|
||||
client, err := m.GetClient(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.ChatStream(ctx, opts, fn)
|
||||
}
|
||||
|
||||
func (m *ModelManager) ChatStreamForModel(ctx context.Context, model string, opts ChatOptions, fn func(StreamChunk) error) error {
|
||||
client, err := m.GetClient(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.ChatStream(ctx, opts, fn)
|
||||
}
|
||||
|
||||
func (m *ModelManager) Ping() error {
|
||||
m.mu.RLock()
|
||||
model := m.currentModel
|
||||
m.mu.RUnlock()
|
||||
|
||||
if model == "" {
|
||||
return fmt.Errorf("no model selected")
|
||||
}
|
||||
|
||||
client, err := m.GetClient(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.Ping()
|
||||
}
|
||||
|
||||
func (m *ModelManager) PingModel(model string) error {
|
||||
client, err := m.GetClient(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.Ping()
|
||||
}
|
||||
|
||||
func (m *ModelManager) Embed(ctx context.Context, model string, texts []string) ([][]float32, error) {
|
||||
client, err := m.GetClient(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return client.Embed(ctx, model, texts)
|
||||
}
|
||||
|
||||
func (m *ModelManager) EmbedWithCurrentModel(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
m.mu.RLock()
|
||||
model := m.currentModel
|
||||
m.mu.RUnlock()
|
||||
|
||||
if model == "" {
|
||||
return nil, fmt.Errorf("no model selected")
|
||||
}
|
||||
return m.Embed(ctx, model, texts)
|
||||
}
|
||||
|
||||
func (m *ModelManager) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for range m.clients {
|
||||
}
|
||||
m.clients = make(map[string]*OllamaClient)
|
||||
}
|
||||
|
||||
func (m *ModelManager) BaseURL() string {
|
||||
return m.baseURL
|
||||
}
|
||||
|
||||
func (m *ModelManager) NumCtx() int {
|
||||
return m.numCtx
|
||||
}
|
||||
|
||||
func (m *ModelManager) Model() string {
|
||||
return m.CurrentModel()
|
||||
}
|
||||
|
||||
// ListModels returns model names available in Ollama at the manager's base URL.
|
||||
func (m *ModelManager) ListModels(ctx context.Context) ([]string, error) {
|
||||
return ListModels(ctx, m.baseURL)
|
||||
}
|
||||
92
internal/llm/manager_test.go
Normal file
92
internal/llm/manager_test.go
Normal file
@ -0,0 +1,92 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewModelManager(t *testing.T) {
|
||||
m := NewModelManager("http://localhost:11434", 4096)
|
||||
|
||||
if m.baseURL != "http://localhost:11434" {
|
||||
t.Errorf("baseURL = %q, want %q", m.baseURL, "http://localhost:11434")
|
||||
}
|
||||
if m.numCtx != 4096 {
|
||||
t.Errorf("numCtx = %d, want %d", m.numCtx, 4096)
|
||||
}
|
||||
if m.clients == nil {
|
||||
t.Error("clients map should be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelManagerBaseURL(t *testing.T) {
|
||||
m := NewModelManager("http://custom:9999", 2048)
|
||||
if m.BaseURL() != "http://custom:9999" {
|
||||
t.Errorf("BaseURL() = %q, want %q", m.BaseURL(), "http://custom:9999")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelManagerNumCtx(t *testing.T) {
|
||||
m := NewModelManager("http://localhost:11434", 8192)
|
||||
if m.NumCtx() != 8192 {
|
||||
t.Errorf("NumCtx() = %d, want %d", m.NumCtx(), 8192)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelManagerCurrentModel(t *testing.T) {
|
||||
m := NewModelManager("http://localhost:11434", 4096)
|
||||
|
||||
// Should return empty when no model set
|
||||
if m.CurrentModel() != "" {
|
||||
t.Errorf("CurrentModel() = %q, want %q", m.CurrentModel(), "")
|
||||
}
|
||||
|
||||
// Set a model
|
||||
m.SetCurrentModel("llama3")
|
||||
|
||||
if m.CurrentModel() != "llama3" {
|
||||
t.Errorf("CurrentModel() = %q, want %q", m.CurrentModel(), "llama3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelManagerChatStreamNoModel(t *testing.T) {
|
||||
m := NewModelManager("http://localhost:11434", 4096)
|
||||
|
||||
err := m.ChatStream(nil, ChatOptions{}, func(chunk StreamChunk) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("ChatStream should fail when no model is set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelManagerPingNoModel(t *testing.T) {
|
||||
m := NewModelManager("http://localhost:11434", 4096)
|
||||
|
||||
err := m.Ping()
|
||||
|
||||
if err == nil {
|
||||
t.Error("Ping should fail when no model is set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelManagerEmbedWithCurrentModelNoModel(t *testing.T) {
|
||||
m := NewModelManager("http://localhost:11434", 4096)
|
||||
|
||||
_, err := m.EmbedWithCurrentModel(nil, []string{"test"})
|
||||
|
||||
if err == nil {
|
||||
t.Error("EmbedWithCurrentModel should fail when no model is set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelManagerClose(t *testing.T) {
|
||||
m := NewModelManager("http://localhost:11434", 4096)
|
||||
|
||||
// Should not panic
|
||||
m.Close()
|
||||
|
||||
if len(m.clients) != 0 {
|
||||
t.Errorf("after Close, clients map should be empty, got %d", len(m.clients))
|
||||
}
|
||||
}
|
||||
222
internal/llm/ollama.go
Normal file
222
internal/llm/ollama.go
Normal file
@ -0,0 +1,222 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
ollamaapi "github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// OllamaClient implements Client using the official Ollama Go library.
|
||||
type OllamaClient struct {
|
||||
client *ollamaapi.Client
|
||||
model string
|
||||
numCtx int
|
||||
}
|
||||
|
||||
// NewOllamaClient creates a new Ollama client.
|
||||
func NewOllamaClient(baseURL, model string, numCtx int) (*OllamaClient, error) {
|
||||
// The official client reads OLLAMA_HOST, but we want to support our config too.
|
||||
if baseURL != "" {
|
||||
os.Setenv("OLLAMA_HOST", baseURL)
|
||||
}
|
||||
|
||||
client, err := ollamaapi.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ollama client: %w", err)
|
||||
}
|
||||
|
||||
return &OllamaClient{
|
||||
client: client,
|
||||
model: model,
|
||||
numCtx: numCtx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *OllamaClient) Model() string { return o.model }
|
||||
|
||||
// Ping checks Ollama is running and the model exists.
|
||||
func (o *OllamaClient) Ping() error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Check the model is available by requesting a show.
|
||||
req := &ollamaapi.ShowRequest{Model: o.model}
|
||||
_, err := o.client.Show(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("model %q not available: %w", o.model, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChatStream sends a chat request and streams the response via callback.
|
||||
func (o *OllamaClient) ChatStream(ctx context.Context, opts ChatOptions, fn func(StreamChunk) error) error {
|
||||
|
||||
messages := make([]ollamaapi.Message, 0, len(opts.Messages)+1)
|
||||
if opts.System != "" {
|
||||
messages = append(messages, ollamaapi.Message{
|
||||
Role: "system",
|
||||
Content: opts.System,
|
||||
})
|
||||
}
|
||||
for _, m := range opts.Messages {
|
||||
msg := ollamaapi.Message{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ToolName: m.ToolName,
|
||||
ToolCallID: m.ToolCallID,
|
||||
}
|
||||
// Convert tool calls for assistant messages.
|
||||
for _, tc := range m.ToolCalls {
|
||||
args := ollamaapi.NewToolCallFunctionArguments()
|
||||
for k, v := range tc.Arguments {
|
||||
args.Set(k, v)
|
||||
}
|
||||
msg.ToolCalls = append(msg.ToolCalls, ollamaapi.ToolCall{
|
||||
ID: tc.ID,
|
||||
Function: ollamaapi.ToolCallFunction{
|
||||
Name: tc.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
tools := convertTools(opts.Tools)
|
||||
req := &ollamaapi.ChatRequest{
|
||||
Model: o.model,
|
||||
Messages: messages,
|
||||
Tools: tools,
|
||||
Options: map[string]any{
|
||||
"num_ctx": o.numCtx,
|
||||
},
|
||||
}
|
||||
|
||||
return o.client.Chat(ctx, req, func(resp ollamaapi.ChatResponse) error {
|
||||
chunk := StreamChunk{
|
||||
Text: resp.Message.Content,
|
||||
Done: resp.Done,
|
||||
}
|
||||
if resp.Done {
|
||||
chunk.EvalCount = resp.EvalCount
|
||||
chunk.PromptEvalCount = resp.PromptEvalCount
|
||||
}
|
||||
// Collect tool calls from the response.
|
||||
for _, tc := range resp.Message.ToolCalls {
|
||||
chunk.ToolCalls = append(chunk.ToolCalls, ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments.ToMap(),
|
||||
})
|
||||
}
|
||||
return fn(chunk)
|
||||
})
|
||||
}
|
||||
|
||||
// Embed generates embeddings for the given texts using the specified model.
|
||||
func (o *OllamaClient) Embed(ctx context.Context, model string, texts []string) ([][]float32, error) {
|
||||
resp, err := o.client.Embed(ctx, &ollamaapi.EmbedRequest{
|
||||
Model: model,
|
||||
Input: texts,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embedding failed: %w", err)
|
||||
}
|
||||
return resp.Embeddings, nil
|
||||
}
|
||||
|
||||
// convertTools transforms our ToolDef slice into Ollama's Tools format.
|
||||
func convertTools(defs []ToolDef) ollamaapi.Tools {
|
||||
if len(defs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tools := make(ollamaapi.Tools, 0, len(defs))
|
||||
for _, d := range defs {
|
||||
props := ollamaapi.NewToolPropertiesMap()
|
||||
var required []string
|
||||
|
||||
// Extract properties from JSON Schema.
|
||||
if propsRaw, ok := d.Parameters["properties"].(map[string]any); ok {
|
||||
for name, schema := range propsRaw {
|
||||
schemaMap, _ := schema.(map[string]any)
|
||||
prop := ollamaapi.ToolProperty{
|
||||
Description: strFromMap(schemaMap, "description"),
|
||||
}
|
||||
if t, ok := schemaMap["type"].(string); ok {
|
||||
prop.Type = ollamaapi.PropertyType{t}
|
||||
}
|
||||
if enumRaw, ok := schemaMap["enum"].([]any); ok {
|
||||
prop.Enum = enumRaw
|
||||
}
|
||||
props.Set(name, prop)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract required fields.
|
||||
if reqRaw, ok := d.Parameters["required"].([]any); ok {
|
||||
for _, r := range reqRaw {
|
||||
if s, ok := r.(string); ok {
|
||||
required = append(required, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tools = append(tools, ollamaapi.Tool{
|
||||
Type: "function",
|
||||
Function: ollamaapi.ToolFunction{
|
||||
Name: d.Name,
|
||||
Description: d.Description,
|
||||
Parameters: ollamaapi.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
Required: required,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
func strFromMap(m map[string]any, key string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
s, _ := m[key].(string)
|
||||
return s
|
||||
}
|
||||
|
||||
// BaseURL returns the configured Ollama base URL for display.
|
||||
func (o *OllamaClient) BaseURL() string {
|
||||
if v := os.Getenv("OLLAMA_HOST"); v != "" {
|
||||
return v
|
||||
}
|
||||
return "http://localhost:11434"
|
||||
}
|
||||
|
||||
// ParseBaseURL validates the Ollama URL.
|
||||
func ParseBaseURL(rawURL string) (*url.URL, error) {
|
||||
return url.Parse(rawURL)
|
||||
}
|
||||
|
||||
// ListModels returns model names available in Ollama at baseURL.
|
||||
func ListModels(ctx context.Context, baseURL string) ([]string, error) {
|
||||
if baseURL != "" {
|
||||
os.Setenv("OLLAMA_HOST", baseURL)
|
||||
}
|
||||
client, err := ollamaapi.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama client: %w", err)
|
||||
}
|
||||
resp, err := client.List(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ollama list: %w", err)
|
||||
}
|
||||
names := make([]string, 0, len(resp.Models))
|
||||
for _, m := range resp.Models {
|
||||
names = append(names, m.Name)
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
121
internal/llm/ollama_test.go
Normal file
121
internal/llm/ollama_test.go
Normal file
@ -0,0 +1,121 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConvertTools(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []ToolDef
|
||||
wantNil bool
|
||||
wantCount int
|
||||
}{
|
||||
{
|
||||
name: "nil input",
|
||||
input: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "single tool with properties and required",
|
||||
input: []ToolDef{
|
||||
{
|
||||
Name: "read_file",
|
||||
Description: "Read a file",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "file path",
|
||||
},
|
||||
},
|
||||
"required": []any{"path"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantCount: 1,
|
||||
},
|
||||
{
|
||||
name: "tool without properties in parameters",
|
||||
input: []ToolDef{
|
||||
{
|
||||
Name: "noop",
|
||||
Description: "Does nothing",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
wantCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := convertTools(tt.input)
|
||||
if tt.wantNil {
|
||||
if result != nil {
|
||||
t.Errorf("convertTools() = %v, want nil", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
if len(result) != tt.wantCount {
|
||||
t.Errorf("convertTools() returned %d tools, want %d", len(result), tt.wantCount)
|
||||
}
|
||||
if tt.wantCount > 0 {
|
||||
tool := result[0]
|
||||
if tool.Function.Name != tt.input[0].Name {
|
||||
t.Errorf("tool name = %q, want %q", tool.Function.Name, tt.input[0].Name)
|
||||
}
|
||||
if tool.Function.Description != tt.input[0].Description {
|
||||
t.Errorf("tool description = %q, want %q", tool.Function.Description, tt.input[0].Description)
|
||||
}
|
||||
if tool.Type != "function" {
|
||||
t.Errorf("tool type = %q, want %q", tool.Type, "function")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrFromMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
m map[string]any
|
||||
key string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "key present",
|
||||
m: map[string]any{"description": "a desc"},
|
||||
key: "description",
|
||||
want: "a desc",
|
||||
},
|
||||
{
|
||||
name: "key missing",
|
||||
m: map[string]any{"other": "value"},
|
||||
key: "description",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil map",
|
||||
m: nil,
|
||||
key: "description",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "non-string value",
|
||||
m: map[string]any{"count": 42},
|
||||
key: "count",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := strFromMap(tt.m, tt.key)
|
||||
if got != tt.want {
|
||||
t.Errorf("strFromMap() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
33
internal/logging/logger.go
Normal file
33
internal/logging/logger.go
Normal file
@ -0,0 +1,33 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
func NewSessionLogger() (*log.Logger, *os.File, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("home dir: %w", err)
|
||||
}
|
||||
logDir := filepath.Join(home, ".config", "ai-agent", "logs")
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return nil, nil, fmt.Errorf("create log dir: %w", err)
|
||||
}
|
||||
filename := time.Now().Format("2006-01-02_15-04-05") + ".log"
|
||||
f, err := os.Create(filepath.Join(logDir, filename))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create log file: %w", err)
|
||||
}
|
||||
logger := log.NewWithOptions(f, log.Options{
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.RFC3339,
|
||||
Prefix: "ai-agent",
|
||||
Level: log.DebugLevel,
|
||||
})
|
||||
return logger, f, nil
|
||||
}
|
||||
42
internal/logging/logger_test.go
Normal file
42
internal/logging/logger_test.go
Normal file
@ -0,0 +1,42 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewSessionLogger(t *testing.T) {
|
||||
logger, f, err := NewSessionLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("NewSessionLogger() error: %v", err)
|
||||
}
|
||||
if f != nil {
|
||||
defer f.Close()
|
||||
defer os.Remove(f.Name())
|
||||
}
|
||||
if logger == nil {
|
||||
t.Fatal("logger should not be nil")
|
||||
}
|
||||
if f == nil {
|
||||
t.Fatal("file should not be nil")
|
||||
}
|
||||
dir := filepath.Dir(f.Name())
|
||||
if !strings.Contains(dir, filepath.Join(".config", "ai-agent", "logs")) {
|
||||
t.Errorf("log file should be in ~/.config/ai-agent/logs/, got %q", dir)
|
||||
}
|
||||
logger.Info("test message", "key", "value")
|
||||
}
|
||||
|
||||
func TestNilLoggerNoPanic(t *testing.T) {
|
||||
var called bool
|
||||
logger, f, err := NewSessionLogger()
|
||||
if err == nil && f != nil {
|
||||
defer f.Close()
|
||||
defer os.Remove(f.Name())
|
||||
logger.Info("test")
|
||||
called = true
|
||||
}
|
||||
_ = called
|
||||
}
|
||||
92
internal/logging/reader.go
Normal file
92
internal/logging/reader.go
Normal file
@ -0,0 +1,92 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LogEntry struct {
|
||||
Path string
|
||||
ModTime time.Time
|
||||
Size int64
|
||||
}
|
||||
|
||||
func LogDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".config", "ai-agent", "logs")
|
||||
}
|
||||
|
||||
func ListLogs(n int) ([]LogEntry, error) {
|
||||
return listLogsIn(LogDir(), n)
|
||||
}
|
||||
|
||||
func listLogsIn(dir string, n int) ([]LogEntry, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read log dir: %w", err)
|
||||
}
|
||||
var logs []LogEntry
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
info, err := e.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
logs = append(logs, LogEntry{
|
||||
Path: filepath.Join(dir, e.Name()),
|
||||
ModTime: info.ModTime(),
|
||||
Size: info.Size(),
|
||||
})
|
||||
}
|
||||
sort.Slice(logs, func(i, j int) bool {
|
||||
return logs[i].ModTime.After(logs[j].ModTime)
|
||||
})
|
||||
if n > 0 && n < len(logs) {
|
||||
logs = logs[:n]
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func LatestLogPath() (string, error) {
|
||||
return latestLogPathIn(LogDir())
|
||||
}
|
||||
|
||||
func latestLogPathIn(dir string) (string, error) {
|
||||
logs, err := listLogsIn(dir, 1)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(logs) == 0 {
|
||||
return "", fmt.Errorf("no log files found in %s", dir)
|
||||
}
|
||||
return logs[0].Path, nil
|
||||
}
|
||||
|
||||
func TailLog(path string, n int) ([]string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open log: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
var lines []string
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("read log: %w", err)
|
||||
}
|
||||
if n > 0 && n < len(lines) {
|
||||
lines = lines[len(lines)-n:]
|
||||
}
|
||||
return lines, nil
|
||||
}
|
||||
157
internal/logging/reader_test.go
Normal file
157
internal/logging/reader_test.go
Normal file
@ -0,0 +1,157 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// helper creates n temp log files in dir with distinct mod times.
|
||||
func createFakeLogs(t *testing.T, dir string, n int) []string {
|
||||
t.Helper()
|
||||
var paths []string
|
||||
for i := range n {
|
||||
name := filepath.Join(dir, "2025-01-01_00-00-0"+string(rune('0'+i))+".log")
|
||||
if err := os.WriteFile(name, []byte("line "+string(rune('0'+i))+"\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Stagger mod times so ordering is deterministic.
|
||||
ts := time.Now().Add(time.Duration(i) * time.Second)
|
||||
if err := os.Chtimes(name, ts, ts); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
paths = append(paths, name)
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func TestListLogs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
createFakeLogs(t, dir, 5)
|
||||
|
||||
logs, err := listLogsIn(dir, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("listLogsIn error: %v", err)
|
||||
}
|
||||
if len(logs) != 3 {
|
||||
t.Fatalf("expected 3 entries, got %d", len(logs))
|
||||
}
|
||||
|
||||
// Verify newest-first ordering.
|
||||
for i := 1; i < len(logs); i++ {
|
||||
if logs[i].ModTime.After(logs[i-1].ModTime) {
|
||||
t.Errorf("entry %d (%v) is newer than entry %d (%v)", i, logs[i].ModTime, i-1, logs[i-1].ModTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListLogs_All(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
createFakeLogs(t, dir, 4)
|
||||
|
||||
logs, err := listLogsIn(dir, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("listLogsIn error: %v", err)
|
||||
}
|
||||
if len(logs) != 4 {
|
||||
t.Fatalf("expected 4 entries, got %d", len(logs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListLogs_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
logs, err := listLogsIn(dir, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("listLogsIn error: %v", err)
|
||||
}
|
||||
if len(logs) != 0 {
|
||||
t.Fatalf("expected 0 entries, got %d", len(logs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListLogs_MissingDir(t *testing.T) {
|
||||
_, err := listLogsIn("/tmp/nonexistent-log-dir-test-xyz", 5)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing dir")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTailLog(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.log")
|
||||
content := "line1\nline2\nline3\nline4\nline5\n"
|
||||
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
lines, err := TailLog(path, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("TailLog error: %v", err)
|
||||
}
|
||||
if len(lines) != 3 {
|
||||
t.Fatalf("expected 3 lines, got %d", len(lines))
|
||||
}
|
||||
if lines[0] != "line3" {
|
||||
t.Errorf("expected 'line3', got %q", lines[0])
|
||||
}
|
||||
if lines[2] != "line5" {
|
||||
t.Errorf("expected 'line5', got %q", lines[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTailLog_FewerLines(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "short.log")
|
||||
if err := os.WriteFile(path, []byte("only\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
lines, err := TailLog(path, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("TailLog error: %v", err)
|
||||
}
|
||||
if len(lines) != 1 {
|
||||
t.Fatalf("expected 1 line, got %d", len(lines))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTailLog_MissingFile(t *testing.T) {
|
||||
_, err := TailLog("/tmp/nonexistent-file-test-xyz.log", 10)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLatestLogPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
paths := createFakeLogs(t, dir, 3)
|
||||
|
||||
latest, err := latestLogPathIn(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("latestLogPathIn error: %v", err)
|
||||
}
|
||||
// The last created file has the newest mod time.
|
||||
expected := paths[len(paths)-1]
|
||||
if latest != expected {
|
||||
t.Errorf("expected %q, got %q", expected, latest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLatestLogPath_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
_, err := latestLogPathIn(dir)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty dir")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogDir(t *testing.T) {
|
||||
dir := LogDir()
|
||||
if dir == "" {
|
||||
t.Fatal("LogDir should not be empty")
|
||||
}
|
||||
if filepath.Base(dir) != "logs" {
|
||||
t.Errorf("expected dir to end in 'logs', got %q", dir)
|
||||
}
|
||||
}
|
||||
110
internal/mcp/client.go
Normal file
110
internal/mcp/client.go
Normal file
@ -0,0 +1,110 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
type MCPClient struct {
|
||||
name string
|
||||
client *sdkmcp.Client
|
||||
session *sdkmcp.ClientSession
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func Connect(ctx context.Context, name, command string, args []string, env []string, transport, url string) (*MCPClient, error) {
|
||||
client := sdkmcp.NewClient(
|
||||
&sdkmcp.Implementation{Name: "ai-agent", Version: "0.2.0"},
|
||||
nil,
|
||||
)
|
||||
var t sdkmcp.Transport
|
||||
switch transport {
|
||||
case "sse":
|
||||
if url == "" {
|
||||
return nil, fmt.Errorf("sse transport requires url for %s", name)
|
||||
}
|
||||
t = &sdkmcp.SSEClientTransport{Endpoint: url}
|
||||
case "streamable-http":
|
||||
if url == "" {
|
||||
return nil, fmt.Errorf("streamable-http transport requires url for %s", name)
|
||||
}
|
||||
t = &sdkmcp.StreamableClientTransport{Endpoint: url}
|
||||
default:
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("stdio transport requires command for %s", name)
|
||||
}
|
||||
cmd := exec.Command(command, args...)
|
||||
if len(env) > 0 {
|
||||
cmd.Env = append(cmd.Environ(), env...)
|
||||
}
|
||||
t = &sdkmcp.CommandTransport{Command: cmd}
|
||||
}
|
||||
session, err := client.Connect(ctx, t, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to %s: %w", name, err)
|
||||
}
|
||||
return &MCPClient{
|
||||
name: name,
|
||||
client: client,
|
||||
session: session,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MCPClient) Name() string { return c.name }
|
||||
|
||||
func (c *MCPClient) ListTools(ctx context.Context) ([]*sdkmcp.Tool, error) {
|
||||
caps := c.session.InitializeResult()
|
||||
if caps == nil || caps.Capabilities.Tools == nil {
|
||||
return nil, nil
|
||||
}
|
||||
var tools []*sdkmcp.Tool
|
||||
for tool, err := range c.session.Tools(ctx, nil) {
|
||||
if err != nil {
|
||||
return tools, fmt.Errorf("list tools from %s: %w", c.name, err)
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
func (c *MCPClient) CallTool(ctx context.Context, name string, args map[string]any) (*ToolResult, error) {
|
||||
result, err := c.session.CallTool(ctx, &sdkmcp.CallToolParams{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("call tool %s on %s: %w", name, c.name, err)
|
||||
}
|
||||
var text string
|
||||
for _, ct := range result.Content {
|
||||
if tc, ok := ct.(*sdkmcp.TextContent); ok {
|
||||
if text != "" {
|
||||
text += "\n"
|
||||
}
|
||||
text += tc.Text
|
||||
}
|
||||
}
|
||||
return &ToolResult{Content: text, IsError: result.IsError}, nil
|
||||
}
|
||||
|
||||
func (c *MCPClient) Close() error {
|
||||
if c.session != nil {
|
||||
return c.session.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MCPClient) IsConnected() bool {
|
||||
return c.session != nil
|
||||
}
|
||||
|
||||
func (c *MCPClient) Ping(ctx context.Context) error {
|
||||
if c.session == nil {
|
||||
return fmt.Errorf("no session")
|
||||
}
|
||||
_, err := c.ListTools(ctx)
|
||||
return err
|
||||
}
|
||||
241
internal/mcp/registry.go
Normal file
241
internal/mcp/registry.go
Normal file
@ -0,0 +1,241 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ai-agent/internal/config"
|
||||
"ai-agent/internal/llm"
|
||||
)
|
||||
|
||||
type FailedServer struct {
|
||||
Name string
|
||||
Reason string
|
||||
}
|
||||
|
||||
type ServerStatus struct {
|
||||
Name string
|
||||
Connected bool
|
||||
LastError string
|
||||
LastPing time.Time
|
||||
}
|
||||
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
clients []*MCPClient
|
||||
toolMap map[string]*MCPClient
|
||||
toolDefs []llm.ToolDef
|
||||
failedServers []FailedServer
|
||||
serverConfigs map[string]config.ServerConfig
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{toolMap: make(map[string]*MCPClient), serverConfigs: make(map[string]config.ServerConfig)}
|
||||
}
|
||||
|
||||
const connectTimeout = 5 * time.Second
|
||||
|
||||
func (r *Registry) ConnectServer(ctx context.Context, srv config.ServerConfig) (int, error) {
|
||||
connCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
||||
defer cancel()
|
||||
client, err := Connect(connCtx, srv.Name, srv.Command, srv.Args, srv.Env, srv.Transport, srv.URL)
|
||||
if err != nil {
|
||||
r.mu.Lock()
|
||||
r.failedServers = append(r.failedServers, FailedServer{Name: srv.Name, Reason: err.Error()})
|
||||
r.mu.Unlock()
|
||||
return 0, fmt.Errorf("connect to %s: %w", srv.Name, err)
|
||||
}
|
||||
tools, err := client.ListTools(connCtx)
|
||||
if err != nil {
|
||||
client.Close()
|
||||
r.mu.Lock()
|
||||
r.failedServers = append(r.failedServers, FailedServer{Name: srv.Name, Reason: err.Error()})
|
||||
r.mu.Unlock()
|
||||
return 0, fmt.Errorf("%s tools: %w", srv.Name, err)
|
||||
}
|
||||
r.mu.Lock()
|
||||
r.clients = append(r.clients, client)
|
||||
for _, tool := range tools {
|
||||
r.toolMap[tool.Name] = client
|
||||
r.toolDefs = append(r.toolDefs, ToLLMToolDef(tool.Name, tool.Description, tool.InputSchema))
|
||||
}
|
||||
r.serverConfigs[srv.Name] = srv
|
||||
r.mu.Unlock()
|
||||
return len(tools), nil
|
||||
}
|
||||
|
||||
func (r *Registry) ConnectAll(ctx context.Context, servers []config.ServerConfig, logFn func(string)) {
|
||||
for _, srv := range servers {
|
||||
toolCount, err := r.ConnectServer(ctx, srv)
|
||||
if err != nil {
|
||||
logFn(fmt.Sprintf("skip %s: %v", srv.Name, err))
|
||||
continue
|
||||
}
|
||||
logFn(fmt.Sprintf("connected %s (%d tools)", srv.Name, toolCount))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Registry) Tools() []llm.ToolDef {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.toolDefs
|
||||
}
|
||||
|
||||
func (r *Registry) ToolCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.toolDefs)
|
||||
}
|
||||
|
||||
func (r *Registry) ServerCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.clients)
|
||||
}
|
||||
|
||||
func (r *Registry) ServerNames() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
names := make([]string, len(r.clients))
|
||||
for i, c := range r.clients {
|
||||
names[i] = c.Name()
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func (r *Registry) FailedServers() []FailedServer {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.failedServers
|
||||
}
|
||||
|
||||
func (r *Registry) CallTool(ctx context.Context, name string, args map[string]any) (*ToolResult, error) {
|
||||
r.mu.RLock()
|
||||
client, ok := r.toolMap[name]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return &ToolResult{
|
||||
Content: fmt.Sprintf("unknown tool: %s", name),
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
return client.CallTool(ctx, name, args)
|
||||
}
|
||||
|
||||
func (r *Registry) Close() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, c := range r.clients {
|
||||
c.Close()
|
||||
}
|
||||
r.clients = nil
|
||||
r.toolMap = make(map[string]*MCPClient)
|
||||
r.toolDefs = nil
|
||||
}
|
||||
|
||||
func (r *Registry) HealthCheck(ctx context.Context) []ServerStatus {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
var results []ServerStatus
|
||||
for _, client := range r.clients {
|
||||
status := ServerStatus{Name: client.Name()}
|
||||
if client.IsConnected() {
|
||||
pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
err := client.Ping(pingCtx)
|
||||
cancel()
|
||||
status.Connected = err == nil
|
||||
if err != nil {
|
||||
status.LastError = err.Error()
|
||||
}
|
||||
status.LastPing = time.Now()
|
||||
}
|
||||
results = append(results, status)
|
||||
}
|
||||
for _, failed := range r.failedServers {
|
||||
results = append(results, ServerStatus{
|
||||
Name: failed.Name,
|
||||
Connected: false,
|
||||
LastError: failed.Reason,
|
||||
})
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func (r *Registry) ReconnectServer(ctx context.Context, name string) (int, error) {
|
||||
r.mu.RLock()
|
||||
srv, ok := r.serverConfigs[name]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("no config found for server: %s", name)
|
||||
}
|
||||
r.mu.Lock()
|
||||
var remainingFailed []FailedServer
|
||||
for _, f := range r.failedServers {
|
||||
if f.Name != name {
|
||||
remainingFailed = append(remainingFailed, f)
|
||||
}
|
||||
}
|
||||
r.failedServers = remainingFailed
|
||||
r.mu.Unlock()
|
||||
return r.ConnectServer(ctx, srv)
|
||||
}
|
||||
|
||||
type MonitorConfig struct {
|
||||
Interval time.Duration
|
||||
MaxRetries int
|
||||
BackoffBase time.Duration
|
||||
}
|
||||
|
||||
var defaultMonitorConfig = MonitorConfig{
|
||||
Interval: 30 * time.Second,
|
||||
MaxRetries: 3,
|
||||
BackoffBase: 5 * time.Second,
|
||||
}
|
||||
|
||||
func (r *Registry) StartHealthMonitor(ctx context.Context, cfg MonitorConfig, logFn func(string)) context.CancelFunc {
|
||||
if cfg.Interval == 0 {
|
||||
cfg = defaultMonitorConfig
|
||||
}
|
||||
monitorCtx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
ticker := time.NewTicker(cfg.Interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-monitorCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.healthCheckRound(monitorCtx, cfg, logFn)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return cancel
|
||||
}
|
||||
|
||||
func (r *Registry) healthCheckRound(ctx context.Context, cfg MonitorConfig, logFn func(string)) {
|
||||
statuses := r.HealthCheck(ctx)
|
||||
for _, status := range statuses {
|
||||
if status.Connected {
|
||||
continue
|
||||
}
|
||||
logFn(fmt.Sprintf("server %s unhealthy, attempting reconnect...", status.Name))
|
||||
for attempt := 1; attempt <= cfg.MaxRetries; attempt++ {
|
||||
backoff := cfg.BackoffBase * time.Duration(attempt)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
_, err := r.ReconnectServer(ctx, status.Name)
|
||||
if err == nil {
|
||||
logFn(fmt.Sprintf("server %s reconnected", status.Name))
|
||||
break
|
||||
}
|
||||
if attempt == cfg.MaxRetries {
|
||||
logFn(fmt.Sprintf("server %s reconnection failed after %d attempts: %v", status.Name, cfg.MaxRetries, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
73
internal/mcp/registry_test.go
Normal file
73
internal/mcp/registry_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewRegistry(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
if r.ToolCount() != 0 {
|
||||
t.Errorf("ToolCount() = %d, want 0", r.ToolCount())
|
||||
}
|
||||
if r.ServerCount() != 0 {
|
||||
t.Errorf("ServerCount() = %d, want 0", r.ServerCount())
|
||||
}
|
||||
if tools := r.Tools(); len(tools) != 0 {
|
||||
t.Errorf("Tools() = %v, want empty", tools)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_CallTool_Unknown(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
result, err := r.CallTool(context.Background(), "nonexistent_tool", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("CallTool() unexpected error: %v", err)
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("CallTool() IsError = false, want true for unknown tool")
|
||||
}
|
||||
if !strings.Contains(result.Content, "unknown tool") {
|
||||
t.Errorf("CallTool() Content = %q, want to contain 'unknown tool'", result.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_HealthCheck_Empty(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
statuses := r.HealthCheck(context.Background())
|
||||
if len(statuses) != 0 {
|
||||
t.Errorf("HealthCheck() returned %d statuses, want 0", len(statuses))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_HealthCheck_TracksFailedServers(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
// Simulate a failed server by directly adding to failedServers
|
||||
r.mu.Lock()
|
||||
r.failedServers = append(r.failedServers, FailedServer{
|
||||
Name: "failed-server",
|
||||
Reason: "connection refused",
|
||||
})
|
||||
r.mu.Unlock()
|
||||
|
||||
statuses := r.HealthCheck(context.Background())
|
||||
if len(statuses) != 1 {
|
||||
t.Fatalf("HealthCheck() returned %d statuses, want 1", len(statuses))
|
||||
}
|
||||
|
||||
status := statuses[0]
|
||||
if status.Name != "failed-server" {
|
||||
t.Errorf("status.Name = %q, want 'failed-server'", status.Name)
|
||||
}
|
||||
if status.Connected {
|
||||
t.Error("status.Connected = true, want false")
|
||||
}
|
||||
if status.LastError != "connection refused" {
|
||||
t.Errorf("status.LastError = %q, want 'connection refused'", status.LastError)
|
||||
}
|
||||
}
|
||||
27
internal/mcp/types.go
Normal file
27
internal/mcp/types.go
Normal file
@ -0,0 +1,27 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"ai-agent/internal/llm"
|
||||
)
|
||||
|
||||
type ServerInfo struct {
|
||||
Name string
|
||||
ToolCount int
|
||||
}
|
||||
|
||||
type ToolResult struct {
|
||||
Content string
|
||||
IsError bool
|
||||
}
|
||||
|
||||
func ToLLMToolDef(name, description string, inputSchema any) llm.ToolDef {
|
||||
params, _ := inputSchema.(map[string]any)
|
||||
if params == nil {
|
||||
params = map[string]any{"type": "object", "properties": map[string]any{}}
|
||||
}
|
||||
return llm.ToolDef{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Parameters: params,
|
||||
}
|
||||
}
|
||||
71
internal/mcp/types_test.go
Normal file
71
internal/mcp/types_test.go
Normal file
@ -0,0 +1,71 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestToLLMToolDef(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
description string
|
||||
inputSchema any
|
||||
wantName string
|
||||
wantDesc string
|
||||
wantParams bool // true = should have non-nil params
|
||||
}{
|
||||
{
|
||||
name: "normal with valid schema",
|
||||
toolName: "read_file",
|
||||
description: "Read a file",
|
||||
inputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
wantName: "read_file",
|
||||
wantDesc: "Read a file",
|
||||
wantParams: true,
|
||||
},
|
||||
{
|
||||
name: "nil schema uses default",
|
||||
toolName: "noop",
|
||||
description: "No-op tool",
|
||||
inputSchema: nil,
|
||||
wantName: "noop",
|
||||
wantDesc: "No-op tool",
|
||||
wantParams: true,
|
||||
},
|
||||
{
|
||||
name: "non-map schema uses default",
|
||||
toolName: "bad_schema",
|
||||
description: "Bad schema tool",
|
||||
inputSchema: "not a map",
|
||||
wantName: "bad_schema",
|
||||
wantDesc: "Bad schema tool",
|
||||
wantParams: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ToLLMToolDef(tt.toolName, tt.description, tt.inputSchema)
|
||||
if result.Name != tt.wantName {
|
||||
t.Errorf("Name = %q, want %q", result.Name, tt.wantName)
|
||||
}
|
||||
if result.Description != tt.wantDesc {
|
||||
t.Errorf("Description = %q, want %q", result.Description, tt.wantDesc)
|
||||
}
|
||||
if tt.wantParams && result.Parameters == nil {
|
||||
t.Error("Parameters should not be nil")
|
||||
}
|
||||
// Nil and non-map schemas should get the default object schema.
|
||||
if tt.inputSchema == nil || func() bool { _, ok := tt.inputSchema.(map[string]any); return !ok }() {
|
||||
if result.Parameters["type"] != "object" {
|
||||
t.Errorf("default schema type = %v, want 'object'", result.Parameters["type"])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
259
internal/memory/store.go
Normal file
259
internal/memory/store.go
Normal file
@ -0,0 +1,259 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Memory struct {
|
||||
ID int `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
memories []Memory
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewStore(path string) *Store {
|
||||
if path == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
path = filepath.Join(home, ".config", "ai-agent", "memories.json")
|
||||
}
|
||||
s := &Store{path: path}
|
||||
s.load()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Store) Save(content string, tags []string) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.nextID++
|
||||
mem := Memory{
|
||||
ID: s.nextID,
|
||||
Content: content,
|
||||
Tags: tags,
|
||||
CreatedAt: time.Now(),
|
||||
LastUsed: time.Now(),
|
||||
}
|
||||
s.memories = append(s.memories, mem)
|
||||
if err := s.persist(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return mem.ID, nil
|
||||
}
|
||||
|
||||
func (s *Store) Recall(query string, maxResults int) []Memory {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if maxResults <= 0 {
|
||||
maxResults = 5
|
||||
}
|
||||
queryLower := strings.ToLower(query)
|
||||
words := strings.Fields(queryLower)
|
||||
type scored struct {
|
||||
mem Memory
|
||||
score int
|
||||
}
|
||||
var results []scored
|
||||
for i := range s.memories {
|
||||
mem := s.memories[i]
|
||||
score := 0
|
||||
contentLower := strings.ToLower(mem.Content)
|
||||
for _, w := range words {
|
||||
if strings.Contains(contentLower, w) {
|
||||
score += 2
|
||||
}
|
||||
}
|
||||
for _, tag := range mem.Tags {
|
||||
tagLower := strings.ToLower(tag)
|
||||
for _, w := range words {
|
||||
if strings.Contains(tagLower, w) {
|
||||
score += 3
|
||||
}
|
||||
}
|
||||
}
|
||||
if score > 0 {
|
||||
results = append(results, scored{mem: mem, score: score})
|
||||
}
|
||||
}
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
if results[i].score != results[j].score {
|
||||
return results[i].score > results[j].score
|
||||
}
|
||||
return results[i].mem.LastUsed.After(results[j].mem.LastUsed)
|
||||
})
|
||||
if len(results) > maxResults {
|
||||
results = results[:maxResults]
|
||||
}
|
||||
now := time.Now()
|
||||
out := make([]Memory, len(results))
|
||||
for i, r := range results {
|
||||
out[i] = r.mem
|
||||
for j := range s.memories {
|
||||
if s.memories[j].ID == r.mem.ID {
|
||||
s.memories[j].LastUsed = now
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = s.persist()
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Store) Recent(n int) []Memory {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if len(s.memories) == 0 {
|
||||
return nil
|
||||
}
|
||||
sorted := make([]Memory, len(s.memories))
|
||||
copy(sorted, s.memories)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].LastUsed.After(sorted[j].LastUsed)
|
||||
})
|
||||
if n > len(sorted) {
|
||||
n = len(sorted)
|
||||
}
|
||||
return sorted[:n]
|
||||
}
|
||||
|
||||
func (s *Store) Count() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return len(s.memories)
|
||||
}
|
||||
|
||||
func (s *Store) Delete(id int) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for i, mem := range s.memories {
|
||||
if mem.ID == id {
|
||||
s.memories = append(s.memories[:i], s.memories[i+1:]...)
|
||||
return true, s.persist()
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteByTag(tag string) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
tagLower := strings.ToLower(tag)
|
||||
var remaining []Memory
|
||||
deleted := 0
|
||||
for _, mem := range s.memories {
|
||||
found := false
|
||||
for _, t := range mem.Tags {
|
||||
if strings.ToLower(t) == tagLower {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
deleted++
|
||||
} else {
|
||||
remaining = append(remaining, mem)
|
||||
}
|
||||
}
|
||||
s.memories = remaining
|
||||
if deleted > 0 {
|
||||
return deleted, s.persist()
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (s *Store) Update(id int, content string, tags []string) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for i, mem := range s.memories {
|
||||
if mem.ID == id {
|
||||
if content != "" {
|
||||
s.memories[i].Content = content
|
||||
}
|
||||
if tags != nil {
|
||||
s.memories[i].Tags = tags
|
||||
}
|
||||
s.memories[i].LastUsed = time.Now()
|
||||
return true, s.persist()
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *Store) Prune(olderThan time.Duration) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
cutoff := time.Now().Add(-olderThan)
|
||||
var remaining []Memory
|
||||
deleted := 0
|
||||
for _, mem := range s.memories {
|
||||
if mem.CreatedAt.Before(cutoff) {
|
||||
deleted++
|
||||
} else {
|
||||
remaining = append(remaining, mem)
|
||||
}
|
||||
}
|
||||
s.memories = remaining
|
||||
if deleted > 0 {
|
||||
return deleted, s.persist()
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (s *Store) Get(id int) (Memory, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, mem := range s.memories {
|
||||
if mem.ID == id {
|
||||
return mem, true
|
||||
}
|
||||
}
|
||||
return Memory{}, false
|
||||
}
|
||||
|
||||
func (s *Store) load() {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var memories []Memory
|
||||
if err := json.Unmarshal(data, &memories); err != nil {
|
||||
return
|
||||
}
|
||||
s.memories = memories
|
||||
for _, m := range s.memories {
|
||||
if m.ID > s.nextID {
|
||||
s.nextID = m.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) persist() error {
|
||||
dir := filepath.Dir(s.path)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("create memory dir: %w", err)
|
||||
}
|
||||
data, err := json.MarshalIndent(s.memories, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal memories: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(s.path, data, 0o644); err != nil {
|
||||
return fmt.Errorf("write memories: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
381
internal/memory/store_test.go
Normal file
381
internal/memory/store_test.go
Normal file
@ -0,0 +1,381 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStore_Save_And_Count(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "memories.json")
|
||||
|
||||
s := NewStore(path)
|
||||
if s.Count() != 0 {
|
||||
t.Fatalf("new store Count = %d, want 0", s.Count())
|
||||
}
|
||||
|
||||
id1, err := s.Save("first memory", []string{"tag1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Save returned error: %v", err)
|
||||
}
|
||||
if id1 != 1 {
|
||||
t.Errorf("first Save id = %d, want 1", id1)
|
||||
}
|
||||
if s.Count() != 1 {
|
||||
t.Errorf("Count after first Save = %d, want 1", s.Count())
|
||||
}
|
||||
|
||||
id2, err := s.Save("second memory", []string{"tag2"})
|
||||
if err != nil {
|
||||
t.Fatalf("Save returned error: %v", err)
|
||||
}
|
||||
if id2 != 2 {
|
||||
t.Errorf("second Save id = %d, want 2", id2)
|
||||
}
|
||||
if s.Count() != 2 {
|
||||
t.Errorf("Count after second Save = %d, want 2", s.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Recall(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "memories.json")
|
||||
s := NewStore(path)
|
||||
|
||||
s.Save("the user prefers Go language", []string{"preference", "golang"})
|
||||
s.Save("project uses PostgreSQL database", []string{"tech", "database"})
|
||||
s.Save("user name is Alice", []string{"name"})
|
||||
|
||||
t.Run("content match", func(t *testing.T) {
|
||||
results := s.Recall("Go", 10)
|
||||
if len(results) == 0 {
|
||||
t.Fatal("expected results for 'Go' query")
|
||||
}
|
||||
found := false
|
||||
for _, r := range results {
|
||||
if r.Content == "the user prefers Go language" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected to find 'the user prefers Go language'")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tag match", func(t *testing.T) {
|
||||
results := s.Recall("golang", 10)
|
||||
if len(results) == 0 {
|
||||
t.Fatal("expected results for 'golang' tag query")
|
||||
}
|
||||
if results[0].Content != "the user prefers Go language" {
|
||||
t.Errorf("top result = %q, want 'the user prefers Go language'", results[0].Content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("combined scoring", func(t *testing.T) {
|
||||
// "database" matches both content and tag for PostgreSQL entry.
|
||||
results := s.Recall("database", 10)
|
||||
if len(results) == 0 {
|
||||
t.Fatal("expected results for 'database' query")
|
||||
}
|
||||
if results[0].Content != "project uses PostgreSQL database" {
|
||||
t.Errorf("top result = %q, want 'project uses PostgreSQL database'",
|
||||
results[0].Content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("maxResults limit", func(t *testing.T) {
|
||||
results := s.Recall("user", 1)
|
||||
if len(results) > 1 {
|
||||
t.Errorf("maxResults=1 but got %d results", len(results))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("default maxResults 5 when 0", func(t *testing.T) {
|
||||
// With 3 memories, should return all 3 (default limit is 5).
|
||||
results := s.Recall("user", 0)
|
||||
if len(results) > 5 {
|
||||
t.Errorf("default maxResults should be 5, got %d results", len(results))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("case insensitive", func(t *testing.T) {
|
||||
results := s.Recall("ALICE", 10)
|
||||
if len(results) == 0 {
|
||||
t.Fatal("expected case-insensitive match for 'ALICE'")
|
||||
}
|
||||
if results[0].Content != "user name is Alice" {
|
||||
t.Errorf("result = %q, want 'user name is Alice'", results[0].Content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no matches", func(t *testing.T) {
|
||||
results := s.Recall("xyzzyzxyz", 10)
|
||||
if len(results) != 0 {
|
||||
t.Errorf("expected no results for nonsense query, got %d", len(results))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_Recall_TieBreakByRecency(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "memories.json")
|
||||
s := NewStore(path)
|
||||
|
||||
// Save two memories with the same scoring potential.
|
||||
s.Save("alpha topic info", []string{"info"})
|
||||
// Small delay so LastUsed differs.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
s.Save("beta topic info", []string{"info"})
|
||||
|
||||
results := s.Recall("info", 10)
|
||||
if len(results) < 2 {
|
||||
t.Fatalf("expected at least 2 results, got %d", len(results))
|
||||
}
|
||||
// Both match tag "info" equally (+3), so more recent (beta) should come first.
|
||||
if results[0].Content != "beta topic info" {
|
||||
t.Errorf("expected more recent 'beta topic info' first, got %q", results[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Recent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "memories.json")
|
||||
s := NewStore(path)
|
||||
|
||||
s.Save("old memory", nil)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
s.Save("new memory", nil)
|
||||
|
||||
t.Run("ordering by LastUsed", func(t *testing.T) {
|
||||
recent := s.Recent(2)
|
||||
if len(recent) != 2 {
|
||||
t.Fatalf("Recent(2) returned %d, want 2", len(recent))
|
||||
}
|
||||
if recent[0].Content != "new memory" {
|
||||
t.Errorf("first recent = %q, want 'new memory'", recent[0].Content)
|
||||
}
|
||||
if recent[1].Content != "old memory" {
|
||||
t.Errorf("second recent = %q, want 'old memory'", recent[1].Content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("limit exceeds count returns all", func(t *testing.T) {
|
||||
recent := s.Recent(100)
|
||||
if len(recent) != 2 {
|
||||
t.Errorf("Recent(100) returned %d, want 2", len(recent))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty store", func(t *testing.T) {
|
||||
emptyPath := filepath.Join(dir, "empty.json")
|
||||
empty := NewStore(emptyPath)
|
||||
recent := empty.Recent(5)
|
||||
if recent != nil {
|
||||
t.Errorf("empty Recent should return nil, got %v", recent)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStore_Persistence_RoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "memories.json")
|
||||
|
||||
s1 := NewStore(path)
|
||||
s1.Save("persistent memory", []string{"test"})
|
||||
s1.Save("another memory", []string{"test2"})
|
||||
|
||||
// Create new store from same path.
|
||||
s2 := NewStore(path)
|
||||
if s2.Count() != 2 {
|
||||
t.Errorf("reloaded Count = %d, want 2", s2.Count())
|
||||
}
|
||||
|
||||
// Verify data is intact.
|
||||
recent := s2.Recent(2)
|
||||
contents := map[string]bool{}
|
||||
for _, m := range recent {
|
||||
contents[m.Content] = true
|
||||
}
|
||||
if !contents["persistent memory"] {
|
||||
t.Error("missing 'persistent memory' after reload")
|
||||
}
|
||||
if !contents["another memory"] {
|
||||
t.Error("missing 'another memory' after reload")
|
||||
}
|
||||
|
||||
// Verify IDs continue.
|
||||
id, err := s2.Save("third", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Save after reload: %v", err)
|
||||
}
|
||||
if id != 3 {
|
||||
t.Errorf("continued id = %d, want 3", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Delete(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "memories.json")
|
||||
s := NewStore(path)
|
||||
|
||||
id, _ := s.Save("to be deleted", []string{"temp"})
|
||||
if s.Count() != 1 {
|
||||
t.Fatalf("expected 1 memory, got %d", s.Count())
|
||||
}
|
||||
|
||||
deleted, err := s.Delete(id)
|
||||
if err != nil {
|
||||
t.Fatalf("Delete returned error: %v", err)
|
||||
}
|
||||
if !deleted {
|
||||
t.Error("Delete returned false for existing memory")
|
||||
}
|
||||
if s.Count() != 0 {
|
||||
t.Errorf("Count after delete = %d, want 0", s.Count())
|
||||
}
|
||||
|
||||
// Try deleting non-existent.
|
||||
deleted, err = s.Delete(999)
|
||||
if err != nil {
|
||||
t.Fatalf("Delete returned error: %v", err)
|
||||
}
|
||||
if deleted {
|
||||
t.Error("Delete should return false for non-existent memory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Update(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "memories.json")
|
||||
s := NewStore(path)
|
||||
|
||||
id, _ := s.Save("original content", []string{"original"})
|
||||
|
||||
updated, err := s.Update(id, "updated content", []string{"updated"})
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned error: %v", err)
|
||||
}
|
||||
if !updated {
|
||||
t.Error("Update returned false for existing memory")
|
||||
}
|
||||
|
||||
// Verify update.
|
||||
mem, found := s.Get(id)
|
||||
if !found {
|
||||
t.Fatal("memory not found after update")
|
||||
}
|
||||
if mem.Content != "updated content" {
|
||||
t.Errorf("Content = %q, want 'updated content'", mem.Content)
|
||||
}
|
||||
if len(mem.Tags) != 1 || mem.Tags[0] != "updated" {
|
||||
t.Errorf("Tags = %v, want ['updated']", mem.Tags)
|
||||
}
|
||||
|
||||
// Try updating non-existent.
|
||||
updated, err = s.Update(999, "test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned error: %v", err)
|
||||
}
|
||||
if updated {
|
||||
t.Error("Update should return false for non-existent memory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_DeleteByTag(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "memories.json")
|
||||
s := NewStore(path)
|
||||
|
||||
s.Save("keep this 1", []string{"keep"})
|
||||
s.Save("delete this", []string{"temp"})
|
||||
s.Save("keep this 2", []string{"keep"})
|
||||
s.Save("delete this too", []string{"temp"})
|
||||
s.Save("also keep", []string{"permanent"})
|
||||
|
||||
deleted, err := s.DeleteByTag("temp")
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteByTag returned error: %v", err)
|
||||
}
|
||||
if deleted != 2 {
|
||||
t.Errorf("DeleteByTag deleted = %d, want 2", deleted)
|
||||
}
|
||||
if s.Count() != 3 {
|
||||
t.Errorf("Count after delete = %d, want 3", s.Count())
|
||||
}
|
||||
|
||||
// Verify only temp memories are gone.
|
||||
results := s.Recall("keep", 10)
|
||||
if len(results) != 3 {
|
||||
t.Errorf("Recall returned %d, want 3", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Get(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "memories.json")
|
||||
s := NewStore(path)
|
||||
|
||||
id, _ := s.Save("test memory", []string{"tag"})
|
||||
|
||||
mem, found := s.Get(id)
|
||||
if !found {
|
||||
t.Fatal("Get returned false for existing memory")
|
||||
}
|
||||
if mem.Content != "test memory" {
|
||||
t.Errorf("Content = %q, want 'test memory'", mem.Content)
|
||||
}
|
||||
if len(mem.Tags) != 1 || mem.Tags[0] != "tag" {
|
||||
t.Errorf("Tags = %v, want ['tag']", mem.Tags)
|
||||
}
|
||||
|
||||
// Try getting non-existent.
|
||||
_, found = s.Get(999)
|
||||
if found {
|
||||
t.Error("Get should return false for non-existent memory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_UpdatePartial(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "memories.json")
|
||||
s := NewStore(path)
|
||||
|
||||
id, _ := s.Save("original content", []string{"original", "tags"})
|
||||
|
||||
// Update only content, keep tags.
|
||||
updated, err := s.Update(id, "new content", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned error: %v", err)
|
||||
}
|
||||
if !updated {
|
||||
t.Error("Update returned false")
|
||||
}
|
||||
|
||||
mem, _ := s.Get(id)
|
||||
if mem.Content != "new content" {
|
||||
t.Errorf("Content = %q, want 'new content'", mem.Content)
|
||||
}
|
||||
// Tags should remain unchanged when nil is passed.
|
||||
if len(mem.Tags) != 2 {
|
||||
t.Errorf("Tags = %v, want 2 tags", mem.Tags)
|
||||
}
|
||||
|
||||
// Update only tags, keep content.
|
||||
updated, err = s.Update(id, "", []string{"only", "tags"})
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned error: %v", err)
|
||||
}
|
||||
if !updated {
|
||||
t.Error("Update returned false")
|
||||
}
|
||||
|
||||
mem, _ = s.Get(id)
|
||||
if mem.Content != "new content" {
|
||||
t.Errorf("Content changed unexpectedly to %q", mem.Content)
|
||||
}
|
||||
if len(mem.Tags) != 2 || mem.Tags[0] != "only" || mem.Tags[1] != "tags" {
|
||||
t.Errorf("Tags = %v, want ['only', 'tags']", mem.Tags)
|
||||
}
|
||||
}
|
||||
102
internal/memory/tools.go
Normal file
102
internal/memory/tools.go
Normal file
@ -0,0 +1,102 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"ai-agent/internal/llm"
|
||||
)
|
||||
|
||||
func BuiltinToolDefs() []llm.ToolDef {
|
||||
return []llm.ToolDef{
|
||||
{
|
||||
Name: "memory_save",
|
||||
Description: "Save an important fact, user preference, or piece of context to persistent memory. Use this proactively when the user shares information worth remembering across sessions.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The fact or information to remember.",
|
||||
},
|
||||
"tags": map[string]any{
|
||||
"type": "array",
|
||||
"items": map[string]any{"type": "string"},
|
||||
"description": "Optional tags for categorization (e.g., 'preference', 'project', 'name').",
|
||||
},
|
||||
},
|
||||
"required": []string{"content"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "memory_recall",
|
||||
Description: "Search persistent memory for previously saved facts. Use this when you need to recall user preferences, project details, or other saved context.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Search query to find relevant memories.",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "memory_delete",
|
||||
Description: "Delete a memory by its ID. Use memory_recall or memory_list first to find the ID of the memory you want to delete.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"id": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The ID of the memory to delete (use memory_list or memory_recall to find IDs).",
|
||||
},
|
||||
},
|
||||
"required": []string{"id"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "memory_update",
|
||||
Description: "Update an existing memory's content or tags. Use memory_recall or memory_list first to find the ID.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"id": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "The ID of the memory to update (use memory_list or memory_recall to find IDs).",
|
||||
},
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New content for the memory.",
|
||||
},
|
||||
"tags": map[string]any{
|
||||
"type": "array",
|
||||
"items": map[string]any{"type": "string"},
|
||||
"description": "New tags for the memory.",
|
||||
},
|
||||
},
|
||||
"required": []string{"id"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "memory_list",
|
||||
Description: "List all stored memories with their IDs, content, and tags. Use this to see what has been saved.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"limit": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Maximum number of memories to return (default: 20).",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func IsBuiltinTool(name string) bool {
|
||||
switch name {
|
||||
case "memory_save", "memory_recall", "memory_delete", "memory_update", "memory_list":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
48
internal/memory/tools_test.go
Normal file
48
internal/memory/tools_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
package memory
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuiltinToolDefs(t *testing.T) {
|
||||
defs := BuiltinToolDefs()
|
||||
if len(defs) != 5 {
|
||||
t.Fatalf("BuiltinToolDefs() returned %d defs, want 5", len(defs))
|
||||
}
|
||||
|
||||
names := map[string]bool{}
|
||||
for _, d := range defs {
|
||||
names[d.Name] = true
|
||||
}
|
||||
|
||||
expected := []string{"memory_save", "memory_recall", "memory_delete", "memory_update", "memory_list"}
|
||||
for _, name := range expected {
|
||||
if !names[name] {
|
||||
t.Errorf("missing %s tool definition", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBuiltinTool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tool string
|
||||
want bool
|
||||
}{
|
||||
{name: "memory_save", tool: "memory_save", want: true},
|
||||
{name: "memory_recall", tool: "memory_recall", want: true},
|
||||
{name: "memory_delete", tool: "memory_delete", want: true},
|
||||
{name: "memory_update", tool: "memory_update", want: true},
|
||||
{name: "memory_list", tool: "memory_list", want: true},
|
||||
{name: "unknown tool", tool: "unknown", want: false},
|
||||
{name: "empty string", tool: "", want: false},
|
||||
{name: "partial match", tool: "memory_", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsBuiltinTool(tt.tool)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsBuiltinTool(%q) = %v, want %v", tt.tool, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
158
internal/permission/checker.go
Normal file
158
internal/permission/checker.go
Normal file
@ -0,0 +1,158 @@
|
||||
package permission
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"ai-agent/internal/db"
|
||||
)
|
||||
|
||||
type Policy string
|
||||
|
||||
const (
|
||||
PolicyAllow Policy = "allow"
|
||||
PolicyDeny Policy = "deny"
|
||||
PolicyAsk Policy = "ask"
|
||||
)
|
||||
|
||||
type Checker struct {
|
||||
store *db.Store
|
||||
cache map[string]Policy
|
||||
mu sync.RWMutex
|
||||
yolo bool
|
||||
}
|
||||
|
||||
func NewChecker(store *db.Store, yolo bool) *Checker {
|
||||
c := &Checker{
|
||||
store: store,
|
||||
cache: make(map[string]Policy),
|
||||
yolo: yolo,
|
||||
}
|
||||
if store != nil {
|
||||
c.loadFromDB()
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Checker) Check(toolName string) Policy {
|
||||
if c.yolo {
|
||||
return PolicyAllow
|
||||
}
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if p, ok := c.cache[toolName]; ok {
|
||||
return p
|
||||
}
|
||||
return PolicyAsk
|
||||
}
|
||||
|
||||
func (c *Checker) SetPolicy(toolName string, policy Policy) {
|
||||
c.mu.Lock()
|
||||
c.cache[toolName] = policy
|
||||
c.mu.Unlock()
|
||||
if c.store != nil {
|
||||
c.store.UpsertToolPermission(context.Background(), db.UpsertToolPermissionParams{
|
||||
ToolName: toolName,
|
||||
Policy: string(policy),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Checker) IsYolo() bool {
|
||||
return c.yolo
|
||||
}
|
||||
|
||||
func (c *Checker) AllPolicies() map[string]Policy {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
result := make(map[string]Policy, len(c.cache))
|
||||
for k, v := range c.cache {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *Checker) Reset() {
|
||||
c.mu.Lock()
|
||||
c.cache = make(map[string]Policy)
|
||||
c.mu.Unlock()
|
||||
if c.store != nil {
|
||||
c.store.ResetToolPermissions(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Checker) loadFromDB() {
|
||||
perms, err := c.store.ListToolPermissions(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for _, p := range perms {
|
||||
switch Policy(p.Policy) {
|
||||
case PolicyAllow, PolicyDeny, PolicyAsk:
|
||||
c.cache[p.ToolName] = Policy(p.Policy)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ApprovalRequest struct {
|
||||
ToolName string
|
||||
Args map[string]any
|
||||
Response chan ApprovalResponse
|
||||
}
|
||||
|
||||
type ApprovalResponse struct {
|
||||
Allowed bool
|
||||
Always bool
|
||||
}
|
||||
|
||||
func RequestApproval(toolName string, args map[string]any, callback func(ApprovalRequest)) (bool, bool) {
|
||||
if callback == nil {
|
||||
return true, false
|
||||
}
|
||||
ch := make(chan ApprovalResponse, 1)
|
||||
callback(ApprovalRequest{
|
||||
ToolName: toolName,
|
||||
Args: args,
|
||||
Response: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
return resp.Allowed, resp.Always
|
||||
}
|
||||
|
||||
type CheckResult int
|
||||
|
||||
const (
|
||||
CheckAllow CheckResult = iota
|
||||
CheckDeny
|
||||
CheckAsk
|
||||
)
|
||||
|
||||
func (c *Checker) ToCheckResult(toolName string) CheckResult {
|
||||
if c == nil || c.yolo {
|
||||
return CheckAllow
|
||||
}
|
||||
switch c.Check(toolName) {
|
||||
case PolicyAllow:
|
||||
return CheckAllow
|
||||
case PolicyDeny:
|
||||
return CheckDeny
|
||||
default:
|
||||
return CheckAsk
|
||||
}
|
||||
}
|
||||
|
||||
func NilSafe(store *db.Store, yolo bool) *Checker {
|
||||
return NewChecker(store, yolo)
|
||||
}
|
||||
|
||||
var AlwaysAllow = func(_ ApprovalRequest) {}
|
||||
|
||||
type ErrDenied struct {
|
||||
ToolName string
|
||||
}
|
||||
|
||||
func (e *ErrDenied) Error() string {
|
||||
return "tool call denied by permission policy: " + e.ToolName
|
||||
}
|
||||
98
internal/permission/checker_test.go
Normal file
98
internal/permission/checker_test.go
Normal file
@ -0,0 +1,98 @@
|
||||
package permission
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"ai-agent/internal/db"
|
||||
)
|
||||
|
||||
func TestChecker_DefaultPolicy(t *testing.T) {
|
||||
c := NewChecker(nil, false)
|
||||
if got := c.Check("some_tool"); got != PolicyAsk {
|
||||
t.Errorf("Check() = %q, want %q", got, PolicyAsk)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecker_Yolo(t *testing.T) {
|
||||
c := NewChecker(nil, true)
|
||||
if got := c.Check("any_tool"); got != PolicyAllow {
|
||||
t.Errorf("Check() = %q, want %q", got, PolicyAllow)
|
||||
}
|
||||
if !c.IsYolo() {
|
||||
t.Error("expected IsYolo() = true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecker_SetPolicy(t *testing.T) {
|
||||
c := NewChecker(nil, false)
|
||||
c.SetPolicy("bash", PolicyAllow)
|
||||
if got := c.Check("bash"); got != PolicyAllow {
|
||||
t.Errorf("Check() = %q, want %q", got, PolicyAllow)
|
||||
}
|
||||
c.SetPolicy("bash", PolicyDeny)
|
||||
if got := c.Check("bash"); got != PolicyDeny {
|
||||
t.Errorf("Check() = %q, want %q", got, PolicyDeny)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecker_WithDB(t *testing.T) {
|
||||
store, err := db.OpenPath(filepath.Join(t.TempDir(), "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
c := NewChecker(store, false)
|
||||
c.SetPolicy("file_write", PolicyAllow)
|
||||
c2 := NewChecker(store, false)
|
||||
if got := c2.Check("file_write"); got != PolicyAllow {
|
||||
t.Errorf("persisted Check() = %q, want %q", got, PolicyAllow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecker_Reset(t *testing.T) {
|
||||
c := NewChecker(nil, false)
|
||||
c.SetPolicy("tool1", PolicyAllow)
|
||||
c.SetPolicy("tool2", PolicyDeny)
|
||||
c.Reset()
|
||||
if got := c.Check("tool1"); got != PolicyAsk {
|
||||
t.Errorf("after reset Check() = %q, want %q", got, PolicyAsk)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecker_AllPolicies(t *testing.T) {
|
||||
c := NewChecker(nil, false)
|
||||
c.SetPolicy("a", PolicyAllow)
|
||||
c.SetPolicy("b", PolicyDeny)
|
||||
|
||||
policies := c.AllPolicies()
|
||||
if len(policies) != 2 {
|
||||
t.Errorf("AllPolicies() len = %d, want 2", len(policies))
|
||||
}
|
||||
if policies["a"] != PolicyAllow {
|
||||
t.Errorf("policies[a] = %q, want %q", policies["a"], PolicyAllow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToCheckResult(t *testing.T) {
|
||||
c := NewChecker(nil, false)
|
||||
c.SetPolicy("allowed", PolicyAllow)
|
||||
c.SetPolicy("denied", PolicyDeny)
|
||||
|
||||
if c.ToCheckResult("allowed") != CheckAllow {
|
||||
t.Error("expected CheckAllow for allowed tool")
|
||||
}
|
||||
if c.ToCheckResult("denied") != CheckDeny {
|
||||
t.Error("expected CheckDeny for denied tool")
|
||||
}
|
||||
if c.ToCheckResult("unknown") != CheckAsk {
|
||||
t.Error("expected CheckAsk for unknown tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToCheckResult_Nil(t *testing.T) {
|
||||
var c *Checker
|
||||
if c.ToCheckResult("anything") != CheckAllow {
|
||||
t.Error("nil checker should return CheckAllow")
|
||||
}
|
||||
}
|
||||
118
internal/skill/manager.go
Normal file
118
internal/skill/manager.go
Normal file
@ -0,0 +1,118 @@
|
||||
package skill
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
skills []*Skill
|
||||
dirs []string
|
||||
}
|
||||
|
||||
func NewManager(dir string) *Manager {
|
||||
dirs := []string{}
|
||||
if dir != "" {
|
||||
dirs = append(dirs, dir)
|
||||
} else {
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
dirs = append(dirs, filepath.Join(home, ".config", "ai-agent", "skills"))
|
||||
}
|
||||
}
|
||||
return &Manager{dirs: dirs}
|
||||
}
|
||||
|
||||
func (m *Manager) AddSearchPath(dir string) {
|
||||
for _, d := range m.dirs {
|
||||
if d == dir {
|
||||
return
|
||||
}
|
||||
}
|
||||
m.dirs = append(m.dirs, dir)
|
||||
}
|
||||
|
||||
func (m *Manager) Names() []string {
|
||||
var names []string
|
||||
for _, s := range m.skills {
|
||||
names = append(names, s.Name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func (m *Manager) LoadAll() error {
|
||||
for _, dir := range m.dirs {
|
||||
if err := m.loadFromDir(dir); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) loadFromDir(dir string) error {
|
||||
if dir == "" {
|
||||
return nil
|
||||
}
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read skills dir: %w", err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dir, entry.Name())
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
skill, err := parseFrontmatter(string(data))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
skill.Path = path
|
||||
if skill.Name == "" {
|
||||
skill.Name = strings.TrimSuffix(entry.Name(), ".md")
|
||||
}
|
||||
m.skills = append(m.skills, skill)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) All() []*Skill {
|
||||
return m.skills
|
||||
}
|
||||
|
||||
func (m *Manager) Activate(name string) error {
|
||||
for _, s := range m.skills {
|
||||
if s.Name == name {
|
||||
s.Active = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("skill not found: %s", name)
|
||||
}
|
||||
|
||||
func (m *Manager) Deactivate(name string) error {
|
||||
for _, s := range m.skills {
|
||||
if s.Name == name {
|
||||
s.Active = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("skill not found: %s", name)
|
||||
}
|
||||
|
||||
func (m *Manager) ActiveContent() string {
|
||||
var parts []string
|
||||
for _, s := range m.skills {
|
||||
if s.Active && s.Content != "" {
|
||||
parts = append(parts, fmt.Sprintf("### %s\n%s", s.Name, s.Content))
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
169
internal/skill/manager_test.go
Normal file
169
internal/skill/manager_test.go
Normal file
@ -0,0 +1,169 @@
|
||||
package skill
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestManager_LoadAll(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create valid skill files.
|
||||
os.WriteFile(filepath.Join(dir, "greeting.md"), []byte("---\nname: greeting\ndescription: Say hello\n---\nHello!"), 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "farewell.md"), []byte("---\nname: farewell\ndescription: Say bye\n---\nGoodbye!"), 0o644)
|
||||
|
||||
// Create a non-.md file (should be skipped).
|
||||
os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("not a skill"), 0o644)
|
||||
|
||||
// Create a subdirectory (should be skipped).
|
||||
os.MkdirAll(filepath.Join(dir, "subdir"), 0o755)
|
||||
|
||||
m := NewManager(dir)
|
||||
if err := m.LoadAll(); err != nil {
|
||||
t.Fatalf("LoadAll: %v", err)
|
||||
}
|
||||
|
||||
skills := m.All()
|
||||
if len(skills) != 2 {
|
||||
t.Fatalf("loaded %d skills, want 2", len(skills))
|
||||
}
|
||||
|
||||
names := map[string]bool{}
|
||||
for _, s := range skills {
|
||||
names[s.Name] = true
|
||||
}
|
||||
if !names["greeting"] {
|
||||
t.Error("missing 'greeting' skill")
|
||||
}
|
||||
if !names["farewell"] {
|
||||
t.Error("missing 'farewell' skill")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_LoadAll_NoFrontmatter(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// File without frontmatter uses filename as name.
|
||||
os.WriteFile(filepath.Join(dir, "plain.md"), []byte("Just content, no frontmatter"), 0o644)
|
||||
|
||||
m := NewManager(dir)
|
||||
if err := m.LoadAll(); err != nil {
|
||||
t.Fatalf("LoadAll: %v", err)
|
||||
}
|
||||
|
||||
skills := m.All()
|
||||
if len(skills) != 1 {
|
||||
t.Fatalf("loaded %d skills, want 1", len(skills))
|
||||
}
|
||||
if skills[0].Name != "plain" {
|
||||
t.Errorf("Name = %q, want 'plain'", skills[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_LoadAll_NonexistentDir(t *testing.T) {
|
||||
m := NewManager("/nonexistent/path/that/does/not/exist")
|
||||
if err := m.LoadAll(); err != nil {
|
||||
t.Fatalf("LoadAll on nonexistent dir should not error, got: %v", err)
|
||||
}
|
||||
if len(m.All()) != 0 {
|
||||
t.Errorf("expected 0 skills from nonexistent dir, got %d", len(m.All()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Activate_Deactivate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(dir, "test.md"), []byte("---\nname: test\n---\nTest content"), 0o644)
|
||||
|
||||
m := NewManager(dir)
|
||||
m.LoadAll()
|
||||
|
||||
t.Run("activate found", func(t *testing.T) {
|
||||
err := m.Activate("test")
|
||||
if err != nil {
|
||||
t.Fatalf("Activate: %v", err)
|
||||
}
|
||||
skill := m.All()[0]
|
||||
if !skill.Active {
|
||||
t.Error("skill should be active after Activate")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("activate not found", func(t *testing.T) {
|
||||
err := m.Activate("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent skill")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deactivate found", func(t *testing.T) {
|
||||
err := m.Deactivate("test")
|
||||
if err != nil {
|
||||
t.Fatalf("Deactivate: %v", err)
|
||||
}
|
||||
skill := m.All()[0]
|
||||
if skill.Active {
|
||||
t.Error("skill should be inactive after Deactivate")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deactivate not found", func(t *testing.T) {
|
||||
err := m.Deactivate("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent skill")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestManager_ActiveContent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(dir, "alpha.md"), []byte("---\nname: alpha\n---\nAlpha content"), 0o644)
|
||||
os.WriteFile(filepath.Join(dir, "beta.md"), []byte("---\nname: beta\n---\nBeta content"), 0o644)
|
||||
|
||||
m := NewManager(dir)
|
||||
m.LoadAll()
|
||||
|
||||
t.Run("none active returns empty", func(t *testing.T) {
|
||||
content := m.ActiveContent()
|
||||
if content != "" {
|
||||
t.Errorf("expected empty content, got %q", content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("one active returns its content", func(t *testing.T) {
|
||||
m.Activate("alpha")
|
||||
content := m.ActiveContent()
|
||||
if content == "" {
|
||||
t.Fatal("expected non-empty content")
|
||||
}
|
||||
if !contains(content, "Alpha content") {
|
||||
t.Errorf("content missing 'Alpha content': %q", content)
|
||||
}
|
||||
if contains(content, "Beta content") {
|
||||
t.Errorf("content should not contain inactive 'Beta content': %q", content)
|
||||
}
|
||||
m.Deactivate("alpha")
|
||||
})
|
||||
|
||||
t.Run("multiple active returns combined", func(t *testing.T) {
|
||||
m.Activate("alpha")
|
||||
m.Activate("beta")
|
||||
content := m.ActiveContent()
|
||||
if !contains(content, "Alpha content") || !contains(content, "Beta content") {
|
||||
t.Errorf("combined content missing expected parts: %q", content)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchString(s, substr)
|
||||
}
|
||||
|
||||
func searchString(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
65
internal/skill/types.go
Normal file
65
internal/skill/types.go
Normal file
@ -0,0 +1,65 @@
|
||||
package skill
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Skill represents a loadable skill definition.
|
||||
type Skill struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
Active bool `yaml:"-"`
|
||||
Content string `yaml:"-"` // markdown body after frontmatter
|
||||
Path string `yaml:"-"` // file path
|
||||
}
|
||||
|
||||
// parseFrontmatter extracts YAML frontmatter and markdown body from a skill file.
|
||||
// Frontmatter is delimited by "---" on the first and closing lines.
|
||||
func parseFrontmatter(data string) (*Skill, error) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(data))
|
||||
|
||||
// Check for opening "---".
|
||||
if !scanner.Scan() || strings.TrimSpace(scanner.Text()) != "---" {
|
||||
// No frontmatter — treat entire content as body.
|
||||
return &Skill{Content: data}, nil
|
||||
}
|
||||
|
||||
// Read YAML lines until closing "---".
|
||||
var yamlBuf strings.Builder
|
||||
foundEnd := false
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) == "---" {
|
||||
foundEnd = true
|
||||
break
|
||||
}
|
||||
yamlBuf.WriteString(line)
|
||||
yamlBuf.WriteString("\n")
|
||||
}
|
||||
|
||||
if !foundEnd {
|
||||
// No closing delimiter — treat as body only.
|
||||
return &Skill{Content: data}, nil
|
||||
}
|
||||
|
||||
// Parse YAML frontmatter.
|
||||
s := &Skill{}
|
||||
if err := yaml.Unmarshal([]byte(yamlBuf.String()), s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remaining content is the markdown body.
|
||||
var bodyBuf strings.Builder
|
||||
for scanner.Scan() {
|
||||
if bodyBuf.Len() > 0 {
|
||||
bodyBuf.WriteString("\n")
|
||||
}
|
||||
bodyBuf.WriteString(scanner.Text())
|
||||
}
|
||||
s.Content = strings.TrimSpace(bodyBuf.String())
|
||||
|
||||
return s, nil
|
||||
}
|
||||
78
internal/skill/types_test.go
Normal file
78
internal/skill/types_test.go
Normal file
@ -0,0 +1,78 @@
|
||||
package skill
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseFrontmatter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantName string
|
||||
wantDesc string
|
||||
wantContent string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid frontmatter",
|
||||
input: "---\nname: test\ndescription: desc\n---\nBody content",
|
||||
wantName: "test",
|
||||
wantDesc: "desc",
|
||||
wantContent: "Body content",
|
||||
},
|
||||
{
|
||||
name: "no frontmatter",
|
||||
input: "Just body",
|
||||
wantContent: "Just body",
|
||||
},
|
||||
{
|
||||
name: "missing closing delimiter",
|
||||
input: "---\nname: test\nBody",
|
||||
wantContent: "---\nname: test\nBody",
|
||||
},
|
||||
{
|
||||
name: "invalid YAML",
|
||||
input: "---\n: :\n---\nbody",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
input: "---\nname: test\n---\n",
|
||||
wantName: "test",
|
||||
wantContent: "",
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
wantContent: "",
|
||||
},
|
||||
{
|
||||
name: "multiline body",
|
||||
input: "---\nname: multi\n---\nline 1\nline 2\nline 3",
|
||||
wantName: "multi",
|
||||
wantContent: "line 1\nline 2\nline 3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
skill, err := parseFrontmatter(tt.input)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if skill.Name != tt.wantName {
|
||||
t.Errorf("Name = %q, want %q", skill.Name, tt.wantName)
|
||||
}
|
||||
if skill.Description != tt.wantDesc {
|
||||
t.Errorf("Description = %q, want %q", skill.Description, tt.wantDesc)
|
||||
}
|
||||
if skill.Content != tt.wantContent {
|
||||
t.Errorf("Content = %q, want %q", skill.Content, tt.wantContent)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
45
internal/tools/definitions.go
Normal file
45
internal/tools/definitions.go
Normal file
@ -0,0 +1,45 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"ai-agent/internal/llm"
|
||||
)
|
||||
|
||||
var builtinToolNames = map[string]bool{
|
||||
"grep": true,
|
||||
"read": true,
|
||||
"write": true,
|
||||
"glob": true,
|
||||
"bash": true,
|
||||
"ls": true,
|
||||
"find": true,
|
||||
"diff": true,
|
||||
"edit": true,
|
||||
"mkdir": true,
|
||||
"remove": true,
|
||||
"copy": true,
|
||||
"move": true,
|
||||
"exists": true,
|
||||
}
|
||||
|
||||
func AllToolDefs() []llm.ToolDef {
|
||||
return []llm.ToolDef{
|
||||
GrepToolDef(),
|
||||
ReadToolDef(),
|
||||
WriteToolDef(),
|
||||
GlobToolDef(),
|
||||
BashToolDef(),
|
||||
LsToolDef(),
|
||||
FindToolDef(),
|
||||
DiffToolDef(),
|
||||
EditToolDef(),
|
||||
MkdirToolDef(),
|
||||
RemoveToolDef(),
|
||||
CopyToolDef(),
|
||||
MoveToolDef(),
|
||||
ExistsToolDef(),
|
||||
}
|
||||
}
|
||||
|
||||
func IsBuiltinTool(name string) bool {
|
||||
return builtinToolNames[name]
|
||||
}
|
||||
306
internal/tools/tools.go
Normal file
306
internal/tools/tools.go
Normal file
@ -0,0 +1,306 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"ai-agent/internal/llm"
|
||||
)
|
||||
|
||||
func GrepToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "grep",
|
||||
Description: "Search for a pattern in files. Use this to find code, text, or values across multiple files.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"pattern": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The regex pattern to search for.",
|
||||
},
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Directory path to search in (defaults to current directory).",
|
||||
},
|
||||
"include": map[string]any{
|
||||
"type": "string",
|
||||
"description": "File pattern to include (e.g., '*.go', '*.ts').",
|
||||
},
|
||||
"context": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Number of lines of context to show around matches (default: 3).",
|
||||
},
|
||||
},
|
||||
"required": []string{"pattern"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ReadToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "read",
|
||||
Description: "Read the contents of a file. Use this to view source code, configuration files, or any text file.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the file to read.",
|
||||
},
|
||||
"limit": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read (optional).",
|
||||
},
|
||||
"offset": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (optional, 1-indexed).",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func WriteToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "write",
|
||||
Description: "Write content to a file. Use this to create new files or overwrite existing ones. Creates parent directories if needed.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the file to write.",
|
||||
},
|
||||
"content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Content to write to the file.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path", "content"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func GlobToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "glob",
|
||||
Description: "Find files matching a pattern. Use this to discover files by name patterns like '*.go', '**/*.ts', etc.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"pattern": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Glob pattern to match (e.g., '**/*.go', 'src/**/*.ts').",
|
||||
},
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Directory to search in (defaults to current directory).",
|
||||
},
|
||||
},
|
||||
"required": []string{"pattern"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func BashToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "bash",
|
||||
Description: "Execute a shell command. Use this to run git, npm, go, or other command-line tools. Output is returned after completion.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The shell command to execute.",
|
||||
},
|
||||
"timeout": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Timeout in seconds (default: 30, max: 120).",
|
||||
},
|
||||
},
|
||||
"required": []string{"command"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func LsToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "ls",
|
||||
Description: "List files and directories. Use this to see what's in a directory.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Directory path to list (defaults to current directory).",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func FindToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "find",
|
||||
Description: "Find files or directories by name. Use this to locate specific files when you know all or part of the filename.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Name or pattern to search for (supports * and ? wildcards).",
|
||||
},
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Directory to search in (defaults to current directory).",
|
||||
},
|
||||
"type": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Type to find: 'f' for files, 'd' for directories (default: both).",
|
||||
},
|
||||
},
|
||||
"required": []string{"name"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func DiffToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "diff",
|
||||
Description: "Show the differences between the current file content and new content. Use this to preview changes before writing.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the file to diff.",
|
||||
},
|
||||
"new_content": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The new content to compare against the current file.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path", "new_content"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func EditToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "edit",
|
||||
Description: "Apply a patch to a file. Use this to make targeted edits to specific lines without overwriting the entire file. The patch format is: @@ -start,count +new_start,new_count @@\nfollowed by lines starting with - (remove), + (add), or (context).",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the file to edit.",
|
||||
},
|
||||
"patch": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Unified diff patch to apply. Format: @@ -start,count +new_start,new_count @@ followed by -line (remove), +line (add), or context line.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path", "patch"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func MkdirToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "mkdir",
|
||||
Description: "Create one or more directories. Creates parent directories as needed.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the directory to create.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func RemoveToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "remove",
|
||||
Description: "Remove files or directories. Use with caution - this permanently deletes files.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to remove (file or directory).",
|
||||
},
|
||||
"recursive": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Remove directories recursively (default: false).",
|
||||
},
|
||||
"force": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Ignore nonexistent files (default: false).",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func CopyToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "copy",
|
||||
Description: "Copy a file from source to destination.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"source": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Source path to copy from.",
|
||||
},
|
||||
"destination": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Destination path to copy to.",
|
||||
},
|
||||
},
|
||||
"required": []string{"source", "destination"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func MoveToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "move",
|
||||
Description: "Move or rename a file or directory.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"source": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Source path to move from.",
|
||||
},
|
||||
"destination": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Destination path to move to.",
|
||||
},
|
||||
},
|
||||
"required": []string{"source", "destination"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ExistsToolDef() llm.ToolDef {
|
||||
return llm.ToolDef{
|
||||
Name: "exists",
|
||||
Description: "Check if a file or directory exists and get information about it.",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to check.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
}
|
||||
}
|
||||
156
internal/tools/tools_test.go
Normal file
156
internal/tools/tools_test.go
Normal file
@ -0,0 +1,156 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGrepToolDef(t *testing.T) {
|
||||
tool := GrepToolDef()
|
||||
|
||||
if tool.Name != "grep" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "grep")
|
||||
}
|
||||
if tool.Description == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
if tool.Parameters == nil {
|
||||
t.Error("Parameters should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadToolDef(t *testing.T) {
|
||||
tool := ReadToolDef()
|
||||
|
||||
if tool.Name != "read" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "read")
|
||||
}
|
||||
props := tool.Parameters["properties"].(map[string]any)
|
||||
if _, ok := props["path"]; !ok {
|
||||
t.Error("should have path property")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteToolDef(t *testing.T) {
|
||||
tool := WriteToolDef()
|
||||
|
||||
if tool.Name != "write" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "write")
|
||||
}
|
||||
props := tool.Parameters["properties"].(map[string]any)
|
||||
if _, ok := props["path"]; !ok {
|
||||
t.Error("should have path property")
|
||||
}
|
||||
if _, ok := props["content"]; !ok {
|
||||
t.Error("should have content property")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobToolDef(t *testing.T) {
|
||||
tool := GlobToolDef()
|
||||
|
||||
if tool.Name != "glob" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "glob")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashToolDef(t *testing.T) {
|
||||
tool := BashToolDef()
|
||||
|
||||
if tool.Name != "bash" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "bash")
|
||||
}
|
||||
props := tool.Parameters["properties"].(map[string]any)
|
||||
if _, ok := props["command"]; !ok {
|
||||
t.Error("should have command property")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLsToolDef(t *testing.T) {
|
||||
tool := LsToolDef()
|
||||
|
||||
if tool.Name != "ls" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "ls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolDef(t *testing.T) {
|
||||
tool := FindToolDef()
|
||||
|
||||
if tool.Name != "find" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "find")
|
||||
}
|
||||
props := tool.Parameters["properties"].(map[string]any)
|
||||
if _, ok := props["name"]; !ok {
|
||||
t.Error("should have name property")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffToolDef(t *testing.T) {
|
||||
tool := DiffToolDef()
|
||||
|
||||
if tool.Name != "diff" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "diff")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditToolDef(t *testing.T) {
|
||||
tool := EditToolDef()
|
||||
|
||||
if tool.Name != "edit" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "edit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMkdirToolDef(t *testing.T) {
|
||||
tool := MkdirToolDef()
|
||||
|
||||
if tool.Name != "mkdir" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "mkdir")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveToolDef(t *testing.T) {
|
||||
tool := RemoveToolDef()
|
||||
|
||||
if tool.Name != "remove" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "remove")
|
||||
}
|
||||
props := tool.Parameters["properties"].(map[string]any)
|
||||
if _, ok := props["recursive"]; !ok {
|
||||
t.Error("should have recursive property")
|
||||
}
|
||||
if _, ok := props["force"]; !ok {
|
||||
t.Error("should have force property")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyToolDef(t *testing.T) {
|
||||
tool := CopyToolDef()
|
||||
|
||||
if tool.Name != "copy" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "copy")
|
||||
}
|
||||
props := tool.Parameters["properties"].(map[string]any)
|
||||
if _, ok := props["source"]; !ok {
|
||||
t.Error("should have source property")
|
||||
}
|
||||
if _, ok := props["destination"]; !ok {
|
||||
t.Error("should have destination property")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveToolDef(t *testing.T) {
|
||||
tool := MoveToolDef()
|
||||
|
||||
if tool.Name != "move" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "move")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistsToolDef(t *testing.T) {
|
||||
tool := ExistsToolDef()
|
||||
|
||||
if tool.Name != "exists" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "exists")
|
||||
}
|
||||
}
|
||||
162
internal/tui/RESPONSIVE_WIDTH.md
Normal file
162
internal/tui/RESPONSIVE_WIDTH.md
Normal file
@ -0,0 +1,162 @@
|
||||
# Responsive Width Implementation
|
||||
|
||||
## Overview
|
||||
This document describes the responsive width calculations implemented to prevent horizontal scrolling in the TUI chat interface.
|
||||
|
||||
## Width Calculation Hierarchy
|
||||
|
||||
### 1. Viewport Width (Primary Constraint)
|
||||
The viewport is the main container for chat content. All other widths derive from this.
|
||||
|
||||
**Formula** (from `model.go:373-380`):
|
||||
```go
|
||||
viewportWidth := screenWidth - 1
|
||||
if sidePanel.IsVisible() {
|
||||
viewportWidth = screenWidth - panelWidth - 2
|
||||
}
|
||||
if viewportWidth < 20 {
|
||||
viewportWidth = 20 // minimum width
|
||||
}
|
||||
```
|
||||
|
||||
**Breakdown**:
|
||||
- `screenWidth - 1`: Full width minus right edge padding (when panel hidden)
|
||||
- `screenWidth - panelWidth - 2`: Width minus panel and separator line (when panel visible)
|
||||
- Minimum 20 characters to ensure readability
|
||||
|
||||
### 2. Content Width (Text Wrapping)
|
||||
Used for wrapping text in `renderEntries()`, `renderUserMsg()`, `renderAssistantMsg()`, etc.
|
||||
|
||||
**Formula** (from `view.go:422-429`):
|
||||
```go
|
||||
contentW := screenWidth - 4
|
||||
if sidePanel.IsVisible() {
|
||||
contentW = screenWidth - panelWidth - 5
|
||||
}
|
||||
if contentW < 20 {
|
||||
contentW = 20
|
||||
}
|
||||
```
|
||||
|
||||
**Breakdown**:
|
||||
- `screenWidth - 4`: Full width with 2-char padding on each side
|
||||
- `screenWidth - panelWidth - 5`: Accounts for panel, separator, and padding
|
||||
- Minimum 20 characters
|
||||
|
||||
### 3. Markdown Width (Glamour Rendering)
|
||||
Used for rendering markdown content via Glamour.
|
||||
|
||||
**Formula** (from `model.go:382-386`):
|
||||
```go
|
||||
markdownWidth := viewportWidth - 3
|
||||
if markdownWidth < 20 {
|
||||
markdownWidth = 20
|
||||
}
|
||||
```
|
||||
|
||||
**Breakdown**:
|
||||
- Derived from viewport width minus 3 chars for padding/indentation
|
||||
- Minimum 20 characters
|
||||
|
||||
### 4. Input Width
|
||||
Matches viewport width exactly for unified appearance.
|
||||
|
||||
**Formula** (from `model.go:431`):
|
||||
```go
|
||||
input.SetWidth(viewportWidth)
|
||||
```
|
||||
|
||||
## Panel Width Calculation
|
||||
|
||||
Panel width is dynamic based on screen size (from `model.go:365-371`):
|
||||
|
||||
```go
|
||||
panelWidth := 30 // default
|
||||
if screenWidth < 100 {
|
||||
panelWidth = 25
|
||||
} else if screenWidth > 160 {
|
||||
panelWidth = 40
|
||||
}
|
||||
```
|
||||
|
||||
## Layout Constraints
|
||||
|
||||
### With Panel Visible
|
||||
```
|
||||
┌─────────────────────────────────────────────────┐
|
||||
│ Panel (25-40) ││ Chat Viewport │
|
||||
│ ││ (screen - panel - 2) │
|
||||
│ ││ │
|
||||
│ ││ Content wrapped to: │
|
||||
│ ││ (screen - panel - 5) │
|
||||
└─────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Without Panel
|
||||
```
|
||||
┌─────────────────────────────────────────────────┐
|
||||
│ Chat Viewport (screen - 1) │
|
||||
│ │
|
||||
│ Content wrapped to: (screen - 4) │
|
||||
└─────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Critical Invariants
|
||||
|
||||
The following invariants are enforced to prevent horizontal scrolling:
|
||||
|
||||
1. **viewportWidth ≤ screenWidth - 1** (or `screenWidth - panelWidth - 1` when panel visible)
|
||||
2. **contentWidth ≤ viewportWidth**
|
||||
3. **markdownWidth ≤ viewportWidth**
|
||||
4. **All widths ≥ 20** (minimum readability)
|
||||
|
||||
## Test Coverage
|
||||
|
||||
Comprehensive tests in `width_test.go` verify:
|
||||
|
||||
- `TestViewportWidthCalculation`: Validates width calculations for various screen sizes
|
||||
- `TestResponsiveWidthToggle`: Ensures widths adjust correctly when panel is toggled
|
||||
- `TestMinimumWidthConstraints`: Verifies minimum width enforcement on small screens
|
||||
- `TestRenderedTextWidth`: Tests actual text wrapping behavior
|
||||
- `TestLayoutConsistency`: Exhaustive testing across screen sizes 40-200 chars
|
||||
|
||||
## Example Calculations
|
||||
|
||||
### 120-char screen with panel (30 chars)
|
||||
```
|
||||
Viewport: 120 - 30 - 2 = 88 chars
|
||||
Content: 120 - 30 - 5 = 85 chars
|
||||
Markdown: 88 - 3 = 85 chars
|
||||
Input: 88 chars
|
||||
Total: 30 (panel) + 1 (separator) + 88 (viewport) = 119 ✓
|
||||
```
|
||||
|
||||
### 80-char screen without panel
|
||||
```
|
||||
Viewport: 80 - 1 = 79 chars
|
||||
Content: 80 - 4 = 76 chars
|
||||
Markdown: 79 - 3 = 76 chars
|
||||
Input: 79 chars
|
||||
Total: 79 chars ✓
|
||||
```
|
||||
|
||||
### 40-char screen with panel (25 chars) - Edge Case
|
||||
```
|
||||
Viewport: 40 - 25 - 2 = 13 → 20 (minimum enforced)
|
||||
Content: 40 - 25 - 5 = 10 → 20 (minimum enforced)
|
||||
Markdown: 20 - 3 = 17 → 20 (minimum enforced)
|
||||
Input: 20 chars
|
||||
Total: 25 + 1 + 20 = 46 (exceeds screen, but minimum width takes priority)
|
||||
```
|
||||
|
||||
**Note**: On very small screens (< 46 chars with panel), the minimum width constraints take precedence. Users should be advised to use larger terminal windows for optimal experience.
|
||||
|
||||
## Responsive Behavior
|
||||
|
||||
When the side panel is toggled:
|
||||
1. Viewport width recalculates immediately
|
||||
2. Content is re-wrapped to new width via `invalidateRenderedCache()`
|
||||
3. Markdown renderer is recreated with new width
|
||||
4. Input field resizes to match viewport
|
||||
|
||||
This ensures seamless responsive behavior without horizontal scrolling.
|
||||
240
internal/tui/accessibility.go
Normal file
240
internal/tui/accessibility.go
Normal file
@ -0,0 +1,240 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
// AccessibilityHelper provides accessibility features like screen reader support.
|
||||
type AccessibilityHelper struct {
|
||||
isDark bool
|
||||
styles AccessibilityStyles
|
||||
speakFunc func(string) // Function to speak text (for screen readers)
|
||||
announceFunc func(string) // Function to announce changes
|
||||
}
|
||||
|
||||
// AccessibilityStyles holds styling.
|
||||
type AccessibilityStyles struct {
|
||||
Announce lipgloss.Style
|
||||
}
|
||||
|
||||
// DefaultAccessibilityStyles returns default styles.
|
||||
func DefaultAccessibilityStyles(isDark bool) AccessibilityStyles {
|
||||
return AccessibilityStyles{
|
||||
Announce: lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0")),
|
||||
}
|
||||
}
|
||||
|
||||
// NewAccessibilityHelper creates a new accessibility helper.
|
||||
func NewAccessibilityHelper(isDark bool) *AccessibilityHelper {
|
||||
return &AccessibilityHelper{
|
||||
isDark: isDark,
|
||||
styles: DefaultAccessibilityStyles(isDark),
|
||||
}
|
||||
}
|
||||
|
||||
// SetDark updates theme.
|
||||
func (ah *AccessibilityHelper) SetDark(isDark bool) {
|
||||
ah.isDark = isDark
|
||||
ah.styles = DefaultAccessibilityStyles(isDark)
|
||||
}
|
||||
|
||||
// SetSpeakFunc sets the function to speak text.
|
||||
func (ah *AccessibilityHelper) SetSpeakFunc(f func(string)) {
|
||||
ah.speakFunc = f
|
||||
}
|
||||
|
||||
// SetAnnounceFunc sets the function to announce changes.
|
||||
func (ah *AccessibilityHelper) SetAnnounceFunc(f func(string)) {
|
||||
ah.announceFunc = f
|
||||
}
|
||||
|
||||
// Announce announces a message to the user.
|
||||
func (ah *AccessibilityHelper) Announce(format string, args ...string) {
|
||||
if ah.announceFunc != nil {
|
||||
msg := format
|
||||
if len(args) > 0 {
|
||||
msg = fmt.Sprintf(format, args)
|
||||
}
|
||||
ah.announceFunc(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Speak speaks text directly.
|
||||
func (ah *AccessibilityHelper) Speak(text string) {
|
||||
if ah.speakFunc != nil {
|
||||
ah.speakFunc(text)
|
||||
}
|
||||
}
|
||||
|
||||
// DescribeEntry creates an accessibility description for a chat entry.
|
||||
func (ah *AccessibilityHelper) DescribeEntry(entry ChatEntry, index int, toolCount int) string {
|
||||
var desc strings.Builder
|
||||
|
||||
switch entry.Kind {
|
||||
case "user":
|
||||
desc.WriteString("User message")
|
||||
case "assistant":
|
||||
desc.WriteString("Assistant response")
|
||||
if entry.ThinkingContent != "" {
|
||||
desc.WriteString(", has thinking")
|
||||
}
|
||||
case "tool_group":
|
||||
desc.WriteString("Tool execution")
|
||||
if index >= 0 && index < toolCount {
|
||||
desc.WriteString(", tool result")
|
||||
}
|
||||
case "system":
|
||||
desc.WriteString("System message")
|
||||
case "error":
|
||||
desc.WriteString("Error")
|
||||
}
|
||||
|
||||
// Add content preview
|
||||
if entry.Content != "" {
|
||||
preview := truncateStr(entry.Content, 50)
|
||||
desc.WriteString(": ")
|
||||
desc.WriteString(preview)
|
||||
}
|
||||
|
||||
return desc.String()
|
||||
}
|
||||
|
||||
// DescribeState creates an accessibility description of the current state.
|
||||
func (ah *AccessibilityHelper) DescribeState(state State, model, mode string) string {
|
||||
var desc string
|
||||
|
||||
switch state {
|
||||
case StateIdle:
|
||||
desc = "Ready"
|
||||
case StateWaiting:
|
||||
desc = "Waiting for response"
|
||||
case StateStreaming:
|
||||
desc = "Receiving response"
|
||||
}
|
||||
|
||||
if model != "" {
|
||||
desc += ", model: " + model
|
||||
}
|
||||
if mode != "" {
|
||||
desc += ", mode: " + mode
|
||||
}
|
||||
|
||||
return desc
|
||||
}
|
||||
|
||||
// DescribeOverlay creates an accessibility description of the current overlay.
|
||||
func (ah *AccessibilityHelper) DescribeOverlay(overlay OverlayKind) string {
|
||||
switch overlay {
|
||||
case OverlayNone:
|
||||
return ""
|
||||
case OverlayHelp:
|
||||
return "Help overlay open"
|
||||
case OverlayCompletion:
|
||||
return "Completion menu open"
|
||||
case OverlayModelPicker:
|
||||
return "Model picker open"
|
||||
case OverlayPlanForm:
|
||||
return "Plan form open"
|
||||
case OverlaySessionsPicker:
|
||||
return "Sessions picker open"
|
||||
default:
|
||||
return "Overlay open"
|
||||
}
|
||||
}
|
||||
|
||||
// DescribeTools creates an accessibility description of tool status.
|
||||
func (ah *AccessibilityHelper) DescribeTools(pending, total int) string {
|
||||
if pending == 0 && total == 0 {
|
||||
return "No tools running"
|
||||
}
|
||||
if pending > 0 {
|
||||
return fmt.Sprintf("%d tool running", pending)
|
||||
}
|
||||
return fmt.Sprintf("%d tools completed", total)
|
||||
}
|
||||
|
||||
// truncate truncates a string to maxLength.
|
||||
func truncateStr(s string, maxLength int) string {
|
||||
if len(s) <= maxLength {
|
||||
return s
|
||||
}
|
||||
return s[:maxLength-3] + "..."
|
||||
}
|
||||
|
||||
// AccessibilityLabel returns an accessibility label for a view element.
|
||||
func AccessibilityLabel(role, name string, props ...string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(role)
|
||||
b.WriteString(": ")
|
||||
b.WriteString(name)
|
||||
|
||||
for _, p := range props {
|
||||
b.WriteString(", ")
|
||||
b.WriteString(p)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// FocusOrder represents the focus order for keyboard navigation.
|
||||
type FocusOrder struct {
|
||||
Current int
|
||||
Items []Focusable
|
||||
}
|
||||
|
||||
// Focusable is an interface for focusable elements.
|
||||
type Focusable interface {
|
||||
Focus() error
|
||||
Blur() error
|
||||
IsFocused() bool
|
||||
}
|
||||
|
||||
// NewFocusOrder creates a new focus order.
|
||||
func NewFocusOrder(items []Focusable) *FocusOrder {
|
||||
return &FocusOrder{
|
||||
Current: 0,
|
||||
Items: items,
|
||||
}
|
||||
}
|
||||
|
||||
// Next moves focus to the next item.
|
||||
func (fo *FocusOrder) Next() {
|
||||
if len(fo.Items) == 0 {
|
||||
return
|
||||
}
|
||||
fo.Current = (fo.Current + 1) % len(fo.Items)
|
||||
fo.focusCurrent()
|
||||
}
|
||||
|
||||
// Prev moves focus to the previous item.
|
||||
func (fo *FocusOrder) Prev() {
|
||||
if len(fo.Items) == 0 {
|
||||
return
|
||||
}
|
||||
fo.Current--
|
||||
if fo.Current < 0 {
|
||||
fo.Current = len(fo.Items) - 1
|
||||
}
|
||||
fo.focusCurrent()
|
||||
}
|
||||
|
||||
// Current returns the currently focused item.
|
||||
func (fo *FocusOrder) CurrentItem() Focusable {
|
||||
if fo.Current >= 0 && fo.Current < len(fo.Items) {
|
||||
return fo.Items[fo.Current]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fo *FocusOrder) focusCurrent() {
|
||||
for i, item := range fo.Items {
|
||||
if i == fo.Current {
|
||||
item.Focus()
|
||||
} else {
|
||||
item.Blur()
|
||||
}
|
||||
}
|
||||
}
|
||||
50
internal/tui/adapter.go
Normal file
50
internal/tui/adapter.go
Normal file
@ -0,0 +1,50 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
)
|
||||
|
||||
// Adapter bridges the agent.Output interface to BubbleTea messages.
|
||||
type Adapter struct {
|
||||
program *tea.Program
|
||||
}
|
||||
|
||||
// NewAdapter creates an Adapter that sends messages to the given program.
|
||||
func NewAdapter(p *tea.Program) *Adapter {
|
||||
return &Adapter{program: p}
|
||||
}
|
||||
|
||||
func (a *Adapter) StreamText(text string) {
|
||||
sendMsg(a.program, StreamTextMsg{Text: text})
|
||||
}
|
||||
|
||||
func (a *Adapter) StreamDone(evalCount, promptTokens int) {
|
||||
sendMsg(a.program, StreamDoneMsg{EvalCount: evalCount, PromptTokens: promptTokens})
|
||||
}
|
||||
|
||||
func (a *Adapter) ToolCallStart(name string, args map[string]any) {
|
||||
sendMsg(a.program, ToolCallStartMsg{Name: name, Args: args, StartTime: time.Now()})
|
||||
}
|
||||
|
||||
func (a *Adapter) ToolCallResult(name string, result string, isError bool, duration time.Duration) {
|
||||
sendMsg(a.program, ToolCallResultMsg{Name: name, Result: result, IsError: isError, Duration: duration})
|
||||
}
|
||||
|
||||
func (a *Adapter) SystemMessage(msg string) {
|
||||
sendMsg(a.program, SystemMessageMsg{Msg: msg})
|
||||
}
|
||||
|
||||
func (a *Adapter) Error(msg string) {
|
||||
// Log error for debugging
|
||||
if len(msg) > 100 {
|
||||
msg = msg[:97] + "..."
|
||||
}
|
||||
sendMsg(a.program, ErrorMsg{Msg: msg})
|
||||
}
|
||||
|
||||
// Done sends the final completion message.
|
||||
func (a *Adapter) Done() {
|
||||
sendMsg(a.program, AgentDoneMsg{})
|
||||
}
|
||||
127
internal/tui/clipboard_test.go
Normal file
127
internal/tui/clipboard_test.go
Normal file
@ -0,0 +1,127 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLastAssistantContent(t *testing.T) {
|
||||
t.Run("found", func(t *testing.T) {
|
||||
m := newTestModel(t)
|
||||
m.entries = []ChatEntry{
|
||||
{Kind: "user", Content: "hello"},
|
||||
{Kind: "assistant", Content: "world"},
|
||||
}
|
||||
got := m.lastAssistantContent()
|
||||
if got != "world" {
|
||||
t.Errorf("expected 'world', got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not_found", func(t *testing.T) {
|
||||
m := newTestModel(t)
|
||||
m.entries = []ChatEntry{
|
||||
{Kind: "user", Content: "hello"},
|
||||
{Kind: "system", Content: "info"},
|
||||
}
|
||||
got := m.lastAssistantContent()
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns_last", func(t *testing.T) {
|
||||
m := newTestModel(t)
|
||||
m.entries = []ChatEntry{
|
||||
{Kind: "assistant", Content: "first"},
|
||||
{Kind: "user", Content: "question"},
|
||||
{Kind: "assistant", Content: "second"},
|
||||
}
|
||||
got := m.lastAssistantContent()
|
||||
if got != "second" {
|
||||
t.Errorf("expected 'second', got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty_entries", func(t *testing.T) {
|
||||
m := newTestModel(t)
|
||||
m.entries = nil
|
||||
got := m.lastAssistantContent()
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCopyLast_OnlyWhenIdleAndEmpty(t *testing.T) {
|
||||
t.Run("idle_empty_with_assistant", func(t *testing.T) {
|
||||
m := newTestModel(t)
|
||||
m.state = StateIdle
|
||||
m.entries = []ChatEntry{
|
||||
{Kind: "assistant", Content: "response text"},
|
||||
}
|
||||
m.input.SetValue("")
|
||||
|
||||
_, cmd := m.Update(ctrlKey('y'))
|
||||
if cmd == nil {
|
||||
t.Error("expected a command to be returned for copy")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non_empty_input_no_trigger", func(t *testing.T) {
|
||||
m := newTestModel(t)
|
||||
m.state = StateIdle
|
||||
m.entries = []ChatEntry{
|
||||
{Kind: "assistant", Content: "response text"},
|
||||
}
|
||||
m.input.SetValue("some text")
|
||||
|
||||
_, cmd := m.Update(ctrlKey('y'))
|
||||
// When input is non-empty, ctrl+y should not trigger copy.
|
||||
// The cmd may be non-nil (textarea update), but no copy should occur.
|
||||
// Verify no system message about clipboard appears.
|
||||
if cmd != nil {
|
||||
msg := cmd()
|
||||
if sysMsg, ok := msg.(SystemMessageMsg); ok {
|
||||
if sysMsg.Msg == "Copied to clipboard." {
|
||||
t.Error("should not trigger copy when input is non-empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non_idle_no_trigger", func(t *testing.T) {
|
||||
m := newTestModel(t)
|
||||
m.state = StateStreaming
|
||||
m.entries = []ChatEntry{
|
||||
{Kind: "assistant", Content: "response text"},
|
||||
}
|
||||
m.input.SetValue("")
|
||||
|
||||
initialEntryCount := len(m.entries)
|
||||
m.Update(ctrlKey('y'))
|
||||
// Should not add any system message about clipboard
|
||||
if len(m.entries) > initialEntryCount {
|
||||
t.Error("should not trigger copy when not idle")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no_assistant_entries", func(t *testing.T) {
|
||||
m := newTestModel(t)
|
||||
m.state = StateIdle
|
||||
m.entries = []ChatEntry{
|
||||
{Kind: "user", Content: "hello"},
|
||||
}
|
||||
m.input.SetValue("")
|
||||
|
||||
_, cmd := m.Update(ctrlKey('y'))
|
||||
// Should not return a copy command when there's no assistant content
|
||||
if cmd != nil {
|
||||
msg := cmd()
|
||||
if sysMsg, ok := msg.(SystemMessageMsg); ok {
|
||||
if sysMsg.Msg == "Copied to clipboard." {
|
||||
t.Error("should not trigger copy when no assistant content")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
79
internal/tui/commit.go
Normal file
79
internal/tui/commit.go
Normal file
@ -0,0 +1,79 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"ai-agent/internal/llm"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
)
|
||||
|
||||
func runCommit(client llm.Client, model string, extraMsg string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
diff, err := gitDiff()
|
||||
if err != nil {
|
||||
return CommitResultMsg{Err: fmt.Errorf("git diff: %w", err)}
|
||||
}
|
||||
if strings.TrimSpace(diff) == "" {
|
||||
return CommitResultMsg{Err: fmt.Errorf("no staged changes (use `git add` first)")}
|
||||
}
|
||||
if len(diff) > 8000 {
|
||||
diff = diff[:8000] + "\n... (truncated)"
|
||||
}
|
||||
prompt := "Write a concise git commit message for the following staged diff. " +
|
||||
"Return ONLY the commit message, no explanation or markdown. " +
|
||||
"Use conventional commit style (e.g. feat:, fix:, refactor:). " +
|
||||
"Keep the first line under 72 characters."
|
||||
if extraMsg != "" {
|
||||
prompt += "\n\nAdditional context: " + extraMsg
|
||||
}
|
||||
prompt += "\n\nDiff:\n" + diff
|
||||
var msgBuf strings.Builder
|
||||
err = client.ChatStream(context.Background(), llm.ChatOptions{
|
||||
Messages: []llm.Message{{Role: "user", Content: prompt}},
|
||||
System: "You are a helpful assistant that writes git commit messages.",
|
||||
}, func(chunk llm.StreamChunk) error {
|
||||
if chunk.Text != "" {
|
||||
msgBuf.WriteString(chunk.Text)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return CommitResultMsg{Err: fmt.Errorf("LLM error: %w", err)}
|
||||
}
|
||||
commitMsg := strings.TrimSpace(msgBuf.String())
|
||||
if commitMsg == "" {
|
||||
return CommitResultMsg{Err: fmt.Errorf("LLM returned empty commit message")}
|
||||
}
|
||||
commitMsg += fmt.Sprintf("\n\nAssisted-by: ai-agent (%s)", model)
|
||||
if err := gitCommit(commitMsg); err != nil {
|
||||
return CommitResultMsg{Err: fmt.Errorf("git commit: %w", err)}
|
||||
}
|
||||
return CommitResultMsg{Message: commitMsg}
|
||||
}
|
||||
}
|
||||
|
||||
func gitDiff() (string, error) {
|
||||
cmd := exec.Command("git", "diff", "--cached", "--stat")
|
||||
stat, _ := cmd.Output()
|
||||
cmd = exec.Command("git", "diff", "--cached")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(stat) + "\n" + string(out), nil
|
||||
}
|
||||
|
||||
func gitCommit(msg string) error {
|
||||
cmd := exec.Command("git", "commit", "-m", msg)
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("%s: %s", err, stderr.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user