first commit
Some checks failed
CI / test (push) Has been cancelled
Release / release (push) Failing after 4m36s

This commit is contained in:
admin 2026-03-08 15:40:34 +07:00
commit 8dc496b626
159 changed files with 27932 additions and 0 deletions

31
.gitea/workflows/ci.yml Normal file
View File

@ -0,0 +1,31 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.25"
cache: true
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: latest
- name: Run go vet
run: go vet ./...
- name: Run tests
run: go test ./... -count=1 -race

View File

@ -0,0 +1,37 @@
name: Release
on:
push:
tags:
- "v*"
permissions:
contents: write
jobs:
release:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.25"
cache: true
- name: Run tests
run: go test ./... -count=1
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
with:
distribution: goreleaser
version: latest
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GIT_TOKEN }}
HOMEBREW_TAP_TOKEN: ${{ secrets.HOMEBREW_TAP_TOKEN }}

32
.gitignore vendored Normal file
View File

@ -0,0 +1,32 @@
# Binaries
bin/
*.exe
*.exe~
*.dll
*.so
*.dylib
*.test
*.out
ai-agent
# Vector embeddings
.vecgrep/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# OS
.DS_Store
Thumbs.db
# Logs
*.log
# Temporary files
tmp/
temp/
*.tmp

71
.goreleaser.yaml Normal file
View File

@ -0,0 +1,71 @@
version: 2
project_name: ai-agent
before:
hooks:
- go mod tidy
builds:
- id: ai-agent
main: ./cmd/ai-agent
binary: ai-agent
goos:
- linux
- darwin
- windows
goarch:
- amd64
- arm64
env:
- CGO_ENABLED=0
ldflags:
- -s -w -X main.version={{.Version}}
archives:
- id: default
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
format_overrides:
- goos: windows
formats: [zip]
checksum:
name_template: "checksums.txt"
snapshot:
version_template: "{{ incpatch .Version }}-next"
changelog:
sort: asc
filters:
exclude:
- "^docs:"
- "^test:"
- "^ci:"
- "^chore:"
- Merge pull request
- Merge branch
release:
github:
owner: abdul-hamid-achik
name: ai-agent
draft: false
prerelease: auto
name_template: "v{{.Version}}"
brews:
- name: ai-agent
homepage: https://github.com/abdul-hamid-achik/ai-agent
description: "Local AI agent with TUI, powered by Ollama and MCP servers"
license: MIT
directory: Formula
repository:
owner: abdul-hamid-achik
name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_TOKEN }}"
install: |
bin.install "ai-agent"
test: |
system "#{bin}/ai-agent", "--version"
skip_upload: "{{ if .Env.HOMEBREW_TAP_TOKEN }}false{{ else }}true{{ end }}"

375
README.md Normal file
View File

@ -0,0 +1,375 @@
# ai-agent
A fully local AI coding agent for the terminal -- powered by Ollama and small models, with intelligent routing, cross-session memory, and MCP tool integration.
```
╭──────────────────────────────────────────╮
│ ai-agent │
│ 100% local. Your data never leaves. │
│ │
│ ASK -- PLAN -- BUILD │
│ 0.8B 4B 9B │
╰──────────────────────────────────────────╯
```
---
## What is ai-agent?
- **100% local** -- runs entirely on your machine via Ollama. No API keys, no cloud, no data leaving your device.
- **Small model optimized** -- intelligent routing across Qwen 3.5 variants (0.8B / 2B / 4B / 9B) based on task complexity.
- **Three operational modes** -- ASK for quick answers, PLAN for design and reasoning, BUILD for full execution with tools.
- **MCP native** -- first-class Model Context Protocol support (STDIO, SSE, Streamable HTTP) for extensible tool integration.
- **Beautiful TUI** -- built with Charm's BubbleTea v2, Lip Gloss v2, and Glamour for rich markdown rendering in the terminal.
- **Infinite Context Engine (ICE)** -- cross-session vector retrieval that surfaces relevant past conversations automatically.
- **Auto-Memory Detection** -- the LLM extracts facts, decisions, preferences, and TODOs from conversations and persists them.
- **Thinking/CoT extraction** -- chain-of-thought reasoning is captured and displayed in collapsible blocks.
- **Skills system** -- load `.md` skill files with YAML frontmatter to inject domain-specific instructions into the system prompt.
- **Agent profiles** -- configure per-project agents with custom system prompts, skills, and MCP servers.
---
## Quick Start
### Prerequisites
- [Go 1.25+](https://go.dev/dl/)
- [Ollama](https://ollama.ai/) running locally
- [Task](https://taskfile.dev/) (optional, for build commands)
### Install
Pull the required model, then install:
```bash
ollama pull qwen3.5:2b
go install github.com/abdul-hamid-achik/ai-agent/cmd/ai-agent@latest
```
For the full model routing suite (optional):
```bash
ollama pull qwen3.5:0.8b
ollama pull qwen3.5:4b
ollama pull qwen3.5:9b
ollama pull nomic-embed-text # for ICE vector embeddings
```
### Configure
Create a config file (optional -- defaults work out of the box):
```bash
mkdir -p ~/.config/ai-agent
cp config.example.yaml ~/.config/ai-agent/config.yaml
```
### Run
```bash
ai-agent
```
Or from source:
```bash
task dev
```
---
## Features
### Model Routing
ai-agent automatically selects the right model size for the task at hand. Simple questions go to the fast 2B model; complex multi-step reasoning escalates to the 9B model. The router analyzes query complexity using keyword heuristics and word count.
| Complexity | Model | Speed | Use Cases |
|------------|---------------|--------|----------------------------------------------|
| Simple | qwen3.5:2b | 2.5x | Quick answers, simple tool use, single edits |
| Medium | qwen3.5:4b | 1.5x | Code completion, refactoring, explanations |
| Complex | qwen3.5:9b | 1.0x | Multi-step reasoning, debugging, code review |
The fallback chain ensures graceful degradation if a model is not available: `2b -> 4b -> 9b`.
### Three Modes: ASK / PLAN / BUILD
Cycle between modes with `shift+tab`. Each mode configures a different system prompt and preferred model tier.
- **ASK** -- Direct, concise answers. Routes to the fastest available model. Tools available for file reads and searches.
- **PLAN** -- Design and planning. Breaks tasks into steps. Reads and explores with tools but does not modify files.
- **BUILD** -- Full execution mode. Uses the most capable model. All tools enabled including writes and modifications.
### MCP Tool Integration
Connect any MCP-compatible tool server. Supports all three transport protocols:
- **STDIO** -- Launch tools as subprocesses (default).
- **SSE** -- Connect to Server-Sent Events endpoints.
- **Streamable HTTP** -- Connect to HTTP-based MCP servers.
Tool calls execute in parallel when possible. The registry handles graceful failure if a server becomes unavailable.
### Infinite Context Engine (ICE)
ICE embeds each conversation turn using `nomic-embed-text` and stores them persistently. On every new message, it retrieves the most relevant past conversations via cosine similarity and injects them into the system prompt -- giving the agent memory that spans across sessions.
### Auto-Memory Detection
After each conversation turn, a background process analyzes the exchange and extracts structured memories:
- **FACT** -- objective information the user shared
- **DECISION** -- choices made during the conversation
- **PREFERENCE** -- user preferences and working styles
- **TODO** -- action items and follow-ups
Memories are stored in `~/.config/ai-agent/memories.json` with tag-weighted search scoring (tags weighted 3x over content).
### Thinking/CoT Display
When the model produces chain-of-thought reasoning, ai-agent captures it and renders it in collapsible blocks. Toggle the display with `ctrl+t`.
### Skills System
Drop `.md` files with YAML frontmatter into the skills directory to inject domain-specific instructions:
```
~/.config/ai-agent/skills/
```
Manage active skills with `/skill list`, `/skill activate <name>`, and `/skill deactivate <name>`.
### Agent Profiles
Create per-project or per-domain agent profiles:
```
~/.agents/<name>/
AGENT.md # System prompt additions
SKILL.md # Agent-specific skills
mcp.yaml # Agent-specific MCP servers
```
Switch profiles with `/agent <name>` or `/agent list`.
---
## Configuration
### File Locations
Config is searched in order (first match wins):
1. `./ai-agent.yaml` (project-local)
2. `~/.config/ai-agent/config.yaml` (user-global)
### Annotated Example
```yaml
ollama:
model: "qwen3.5:2b" # Default model
base_url: "http://localhost:11434" # Ollama API endpoint
num_ctx: 262144 # Context window size
# Skills directory (default: ~/.config/ai-agent/skills/)
# skills_dir: "/path/to/custom/skills"
# MCP tool servers
servers:
# STDIO transport (default)
- name: noted
command: noted
args: [mcp]
# SSE transport
# - name: remote-server
# transport: sse
# url: "http://localhost:8811"
# Streamable HTTP transport
# - name: streamable-server
# transport: streamable-http
# url: "http://localhost:8812/mcp"
# ICE configuration
# ice:
# enabled: true
# embed_model: "nomic-embed-text"
# store_path: "~/.config/ai-agent/conversations.json"
```
### Environment Variables
| Variable | Description | Overrides |
|--------------------------|------------------------------|----------------------|
| `OLLAMA_HOST` | Ollama API base URL | `ollama.base_url` |
| `LOCAL_AGENT_MODEL` | Default model name | `ollama.model` |
| `LOCAL_AGENT_AGENTS_DIR` | Path to agents directory | `agents.dir` |
---
## Keyboard Shortcuts
### Input
| Key | Action |
|-----------------|-------------------------------|
| `enter` | Send message |
| `shift+enter` | Insert new line |
| `up` / `down` | Browse input history |
| `shift+tab` | Cycle mode (ASK/PLAN/BUILD) |
| `ctrl+m` | Quick model switch |
### Navigation
| Key | Action |
|------------------|------------------------------|
| `pgup` / `pgdown`| Scroll conversation |
| `ctrl+u` | Half-page scroll up |
| `ctrl+d` | Half-page scroll down |
### Display
| Key | Action |
|-----------------|-------------------------------|
| `?` | Toggle help overlay |
| `t` | Expand/collapse tool calls |
| `space` | Toggle last tool details |
| `ctrl+t` | Toggle thinking/CoT display |
| `ctrl+y` | Copy last response |
### Control
| Key | Action |
|-----------------|-------------------------------|
| `esc` | Cancel streaming / close overlay |
| `ctrl+c` | Quit |
| `ctrl+l` | Clear screen |
| `ctrl+n` | New conversation |
---
## Slash Commands
| Command | Description |
|--------------------------------------|-----------------------------------|
| `/help` | Show help overlay |
| `/clear` | Clear conversation history |
| `/new` | Start a fresh conversation |
| `/model [name\|list\|fast\|smart]` | Show or switch models |
| `/models` | Open model picker |
| `/agent [name\|list]` | Show or switch agent profile |
| `/load <path>` | Load markdown file as context |
| `/unload` | Remove loaded context |
| `/skill [list\|activate\|deactivate] [name]` | Manage skills |
| `/servers` | List connected MCP servers |
| `/ice` | Show ICE engine status |
| `/sessions` | Browse saved sessions |
| `/exit` | Quit |
---
## Architecture
```
cmd/ai-agent/ Entry point
internal/
agent/ ReAct loop orchestration
llm/ LLM abstraction (OllamaClient, ModelManager)
mcp/ MCP server registry
config/ YAML config, env overrides, Router
ice/ Infinite Context Engine
memory/ Persistent key-value store
skill/ Skill file loader
command/ Slash command registry
tui/ BubbleTea v2 terminal UI
logging/ Structured logging
```
### Request Flow
```
User Input
|
v
agent.AddUserMessage()
|
v
ICE embeds message, retrieves relevant past context
|
v
System prompt assembled (tools + skills + context + ICE + memory)
|
v
Router selects model based on task complexity
|
v
LLM streams response via ChatStream()
|
v
Tool calls routed through MCP registry (parallel execution)
|
v
ReAct loop continues (up to 10 iterations) until final text
|
v
Conversation compacted if token budget exceeded
Auto-memory detection runs in background
```
### Key Interfaces
- `llm.Client` -- pluggable LLM provider (`ChatStream`, `Ping`, `Embed`)
- `agent.Output` -- streaming callbacks for TUI rendering
- `command.Registry` -- extensible slash command dispatch
### Concurrency
`sync.RWMutex` protects shared state in `ModelManager`, `mcp.Registry`, and `memory.Store`. Auto-memory detection and MCP connections run as background goroutines. Tool calls execute in parallel when independent.
---
## Comparison
| Feature | ai-agent | opencode | crush |
|----------------------------------|:-----------:|:--------:|:-----:|
| 100% local (no API keys) | Yes | No | Yes |
| Model routing by task complexity | Yes | No | No |
| Operational modes (ASK/PLAN/BUILD)| Yes | No | No |
| Cross-session memory (ICE) | Yes | No | No |
| Auto-memory detection | Yes | No | No |
| Thinking/CoT extraction | Yes | Yes | No |
| MCP tool support | Yes | Yes | Yes |
| Skills system | Yes | No | No |
| Plan form overlay | Yes | No | No |
| Small model optimized | Yes | No | No |
| TUI chat interface | Yes | Yes | Yes |
| Language | Go | TypeScript| Go |
---
## Building
This project uses [Task](https://taskfile.dev/) as its build tool.
```bash
task build # Compile to bin/ai-agent
task run # Build and run
task dev # Quick run via go run ./cmd/ai-agent
task test # Run all tests: go test ./...
task lint # Run golangci-lint run ./...
task clean # Remove bin/ directory
```
Run a single test:
```bash
go test ./internal/agent/ -run TestFunctionName
```
---
## License
MIT

33
Taskfile.yml Normal file
View File

@ -0,0 +1,33 @@
version: '3'
tasks:
build:
desc: Build the binary
cmds:
- go build -o bin/ai-agent ./cmd/ai-agent
run:
desc: Build and run
deps: [build]
cmds:
- ./bin/ai-agent
dev:
desc: Run with go run
cmds:
- go run ./cmd/ai-agent
lint:
desc: Run linter
cmds:
- golangci-lint run ./...
test:
desc: Run tests
cmds:
- go test ./...
clean:
desc: Clean build artifacts
cmds:
- rm -rf bin/

70
config.example.yaml Normal file
View File

@ -0,0 +1,70 @@
# ai-agent configuration
# Place at ~/.config/ai-agent/config.yaml or ./ai-agent.yaml
ollama:
model: "qwen3.5:2b"
base_url: "http://localhost:11434"
# num_ctx: контекст в токенах. Большие значения (например 262144) требуют много RAM/VRAM;
# при ошибке "requires more system memory" уменьшите до 32768 или 8192.
num_ctx: 262144
# Model routing suite (optional - for automatic model tier selection)
# models:
# - name: "qwen3.5:0.8b"
# size: "0.8B"
# capability: "simple"
# - name: "qwen3.5:2b"
# size: "2B"
# capability: "medium"
# - name: "qwen3.5:4b"
# size: "4B"
# capability: "complex"
# - name: "qwen3.5:9b"
# size: "9B"
# capability: "advanced"
# Embedding model for ICE (Infinite Context Engine)
# embed_model: "nomic-embed-text"
# UI Configuration (optional)
# ui:
# # Theme: "dark", "light", or "auto" (default: auto - detects system theme)
# # The Nord theme will be applied automatically:
# # - Dark mode: Nord Polar Night (dark blues) + Frost (light text) + Aurora (colorful accents)
# # - Light mode: Nord Aurora Light (white background + same colorful accents)
# theme: "auto"
# # Syntax highlighting theme for code blocks (via Chroma)
# # Options: monokai, github, dracula, one-dark, solarized-dark, etc.
# # Default: monokai (dark), github (light)
# code_theme: "monokai"
# # Show line numbers in code blocks (default: false)
# code_line_numbers: false
# # Compact mode threshold (terminal width < this value triggers compact mode)
# compact_threshold: 80
# MCP tool servers
servers:
# STDIO transport (default)
- name: noted
command: noted
args: [mcp]
# tinyvault — requires `tvault init` before first use
# - name: tinyvault
# command: tvault
# args: [mcp-server]
# Docker MCP Gateway (requires Docker Desktop with MCP enabled)
# - name: docker-gateway
# command: docker
# args: ["mcp", "gateway", "run"]
# SSE transport
# - name: remote-server
# transport: sse
# url: "http://localhost:8811"
# Streamable HTTP transport
# - name: streamable-server
# transport: streamable-http
# url: "http://localhost:8812/mcp"

65
config.yaml Normal file
View File

@ -0,0 +1,65 @@
ollama:
model: "qwen3.5:27b"
base_url: "http://10.2.18.188:11434"
# 262144 требует ~22.6 GiB; при 20.4 GiB доступно — уменьшаем контекст (32768 укладывается в память)
num_ctx: 32768
# Model routing suite (optional - for automatic model tier selection)
# models:
# - name: "qwen3.5:0.8b"
# size: "0.8B"
# capability: "simple"
# - name: "qwen3.5:2b"
# size: "2B"
# capability: "medium"
# - name: "qwen3.5:4b"
# size: "4B"
# capability: "complex"
# - name: "qwen3.5:9b"
# size: "9B"
# capability: "advanced"
# Embedding model for ICE (Infinite Context Engine)
# embed_model: "nomic-embed-text"
# UI Configuration (optional)
# ui:
# # Theme: "dark", "light", or "auto" (default: auto - detects system theme)
# # The Nord theme will be applied automatically:
# # - Dark mode: Nord Polar Night (dark blues) + Frost (light text) + Aurora (colorful accents)
# # - Light mode: Nord Aurora Light (white background + same colorful accents)
# theme: "auto"
# # Syntax highlighting theme for code blocks (via Chroma)
# # Options: monokai, github, dracula, one-dark, solarized-dark, etc.
# # Default: monokai (dark), github (light)
# code_theme: "monokai"
# # Show line numbers in code blocks (default: false)
# code_line_numbers: false
# # Compact mode threshold (terminal width < this value triggers compact mode)
# compact_threshold: 80
# MCP tool servers
servers:
- name: noted
command: noted
args: [mcp]
# tinyvault — requires `tvault init` before first use
# - name: tinyvault
# command: tvault
# args: [mcp-server]
# Docker MCP Gateway (requires Docker Desktop with MCP enabled)
# - name: docker-gateway
# command: docker
# args: ["mcp", "gateway", "run"]
# SSE transport
# - name: remote-server
# transport: sse
# url: "http://localhost:8811"
# Streamable HTTP transport
# - name: streamable-server
# transport: streamable-http
# url: "http://localhost:8812/mcp"

72
go.mod Normal file
View File

@ -0,0 +1,72 @@
module ai-agent
go 1.25.5
require (
charm.land/bubbles/v2 v2.0.0
charm.land/bubbletea/v2 v2.0.1
charm.land/lipgloss/v2 v2.0.0
github.com/atotto/clipboard v0.1.4
github.com/charmbracelet/glamour v0.10.0
github.com/charmbracelet/log v0.4.2
github.com/lucasb-eyer/go-colorful v1.3.0
github.com/modelcontextprotocol/go-sdk v1.3.1
github.com/ollama/ollama v0.17.4
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/alecthomas/chroma/v2 v2.14.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/charmbracelet/colorprofile v0.4.2 // indirect
github.com/charmbracelet/harmonica v0.2.0 // indirect
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect
github.com/charmbracelet/x/windows v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.11.0 // indirect
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.20 // indirect
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/sahilm/fuzzy v0.1.1 // indirect
github.com/segmentio/asm v1.1.3 // indirect
github.com/segmentio/encoding v0.5.3 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/yuin/goldmark v1.7.8 // indirect
github.com/yuin/goldmark-emoji v1.0.5 // indirect
golang.org/x/crypto v0.43.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/net v0.46.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/term v0.36.0 // indirect
golang.org/x/text v0.30.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.46.1 // indirect
)

164
go.sum Normal file
View File

@ -0,0 +1,164 @@
charm.land/bubbles/v2 v2.0.0 h1:tE3eK/pHjmtrDiRdoC9uGNLgpopOd8fjhEe31B/ai5s=
charm.land/bubbles/v2 v2.0.0/go.mod h1:rCHoleP2XhU8um45NTuOWBPNVHxnkXKTiZqcclL/qOI=
charm.land/bubbletea/v2 v2.0.0 h1:p0d6CtWyJXJ9GfzMpUUqbP/XUUhhlk06+vCKWmox1wQ=
charm.land/bubbletea/v2 v2.0.0/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
charm.land/bubbletea/v2 v2.0.1 h1:B8e9zzK7x9JJ+XvHGF4xnYu9Xa0E0y0MyggY6dbaCfQ=
charm.land/bubbletea/v2 v2.0.1/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
charm.land/lipgloss/v2 v2.0.0 h1:sd8N/B3x892oiOjFfBQdXBQp3cAkvjGaU5TvVZC3ivo=
charm.land/lipgloss/v2 v2.0.0/go.mod h1:w6SnmsBFBmEFBodiEDurGS/sdUY/u1+v72DqUzc6J14=
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE=
github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E=
github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.4.0 h1:TKnLPh7IbnizJIBKFWa9mKayRUBQ9Kh1BPCk6w2PnYM=
github.com/aymanbagabas/go-udiff v0.4.0/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/charmbracelet/colorprofile v0.4.2 h1:BdSNuMjRbotnxHSfxy+PCSa4xAmz7szw70ktAtWRYrY=
github.com/charmbracelet/colorprofile v0.4.2/go.mod h1:0rTi81QpwDElInthtrQ6Ni7cG0sDtwAd4C4le060fT8=
github.com/charmbracelet/glamour v0.10.0 h1:MtZvfwsYCx8jEPFJm3rIBFIMZUfUJ765oX8V6kXldcY=
github.com/charmbracelet/glamour v0.10.0/go.mod h1:f+uf+I/ChNmqo087elLnVdCiVgjSKWuXa/l6NU2ndYk=
github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ=
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA=
github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig=
github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw=
github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8 h1:eyFRbAmexyt43hVfeyBofiGSEmJ7krjLOYt/9CF5NKA=
github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8/go.mod h1:SQpCTRNBtzJkwku5ye4S3HEuthAlGy2n9VXZnWkEW98=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf h1:rLG0Yb6MQSDKdB52aGX55JT1oi0P0Kuaj7wi1bLUpnI=
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY=
github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo=
github.com/charmbracelet/x/windows v0.2.2 h1:IofanmuvaxnKHuV04sC0eBy/smG6kIKrWG2/jYn2GuM=
github.com/charmbracelet/x/windows v0.2.2/go.mod h1:/8XtdKZzedat74NQFn0NGlGL4soHB0YQZrETF96h75k=
github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8=
github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0=
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.20 h1:WcT52H91ZUAwy8+HUkdM3THM6gXqXuLJi9O3rjcQQaQ=
github.com/mattn/go-runewidth v0.0.20/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
github.com/modelcontextprotocol/go-sdk v1.3.1 h1:TfqtNKOIWN4Z1oqmPAiWDC2Jq7K9OdJaooe0teoXASI=
github.com/modelcontextprotocol/go-sdk v1.3.1/go.mod h1:DgVX498dMD8UJlseK1S5i1T4tFz2fkBk4xogC3D15nw=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/ollama/ollama v0.17.4 h1:X3KNm9x4BlHqk/AXMGtC7pBAFd46nmJmy8yDaBZLo9s=
github.com/ollama/ollama v0.17.4/go.mod h1:tCX4IMV8DHjl3zY0THxuEkpWDZSOchJpzTuLACpMwFw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc=
github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg=
github.com/segmentio/encoding v0.5.3 h1:OjMgICtcSFuNvQCdwqMCv9Tg7lEOXGwm1J5RPQccx6w=
github.com/segmentio/encoding v0.5.3/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic=
github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC4cIk=
github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU=
modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=

194
internal/agent/agent.go Normal file
View File

@ -0,0 +1,194 @@
package agent
import (
"sync"
"time"
"ai-agent/internal/config"
"ai-agent/internal/ice"
"ai-agent/internal/llm"
"ai-agent/internal/mcp"
"ai-agent/internal/memory"
"ai-agent/internal/permission"
)
type Agent struct {
mu sync.RWMutex
llmClient llm.Client
registry *mcp.Registry
messages []llm.Message
skillContent string
loadedCtx string
numCtx int
memoryStore *memory.Store
iceEngine *ice.Engine
router *config.Router
modePrefix string
toolsEnabled bool
workDir string
ignoreContent string
permChecker *permission.Checker
approvalCallback func(permission.ApprovalRequest)
toolsConfig config.ToolsConfig
}
func New(llmClient llm.Client, registry *mcp.Registry, numCtx int) *Agent {
return &Agent{
llmClient: llmClient,
registry: registry,
numCtx: numCtx,
toolsEnabled: true,
}
}
func (a *Agent) SetRouter(router *config.Router) {
a.router = router
}
func (a *Agent) SetModeContext(prefix string, allowTools bool) {
a.modePrefix = prefix
a.toolsEnabled = allowTools
}
func (a *Agent) AppendLoadedContext(content string) {
if a.loadedCtx == "" {
a.loadedCtx = content
} else {
a.loadedCtx += content
}
}
func (a *Agent) Router() *config.Router {
return a.router
}
func (a *Agent) NumCtx() int {
return a.numCtx
}
func (a *Agent) SetMemoryStore(store *memory.Store) {
a.memoryStore = store
}
func (a *Agent) AddUserMessage(content string) {
a.mu.Lock()
defer a.mu.Unlock()
a.messages = append(a.messages, llm.Message{
Role: "user",
Content: content,
})
}
func (a *Agent) Messages() []llm.Message {
a.mu.RLock()
defer a.mu.RUnlock()
return a.messages
}
func (a *Agent) ClearHistory() {
a.mu.Lock()
defer a.mu.Unlock()
a.messages = nil
}
func (a *Agent) AppendMessage(msg llm.Message) {
a.mu.Lock()
defer a.mu.Unlock()
a.messages = append(a.messages, msg)
}
func (a *Agent) ReplaceMessages(msgs []llm.Message) {
a.mu.Lock()
defer a.mu.Unlock()
a.messages = msgs
}
func (a *Agent) SetSkillContent(content string) {
a.skillContent = content
}
func (a *Agent) SetLoadedContext(content string) {
a.loadedCtx = content
}
func (a *Agent) Model() string {
return a.llmClient.Model()
}
func (a *Agent) LLMClient() llm.Client {
return a.llmClient
}
func (a *Agent) ToolCount() int {
count := a.registry.ToolCount()
if a.memoryStore != nil {
count += 2
}
return count
}
func (a *Agent) ServerCount() int {
return a.registry.ServerCount()
}
func (a *Agent) ServerNames() []string {
return a.registry.ServerNames()
}
func (a *Agent) SetWorkDir(dir string) {
a.workDir = dir
}
func (a *Agent) SetIgnoreContent(content string) {
a.ignoreContent = content
}
func (a *Agent) SetPermissionChecker(checker *permission.Checker) {
a.permChecker = checker
}
func (a *Agent) SetApprovalCallback(cb func(permission.ApprovalRequest)) {
a.approvalCallback = cb
}
func (a *Agent) SetICEEngine(engine *ice.Engine) {
a.iceEngine = engine
}
func (a *Agent) ICEEngine() *ice.Engine {
return a.iceEngine
}
func (a *Agent) SetToolsConfig(cfg config.ToolsConfig) {
a.toolsConfig = cfg
}
func (a *Agent) MaxIterations() int {
if a.toolsConfig.MaxIterations > 0 {
return a.toolsConfig.MaxIterations
}
return 10
}
func (a *Agent) ToolTimeout() time.Duration {
if a.toolsConfig.Timeout != "" {
if d, err := time.ParseDuration(a.toolsConfig.Timeout); err == nil {
return d
}
}
return 30 * time.Second
}
func (a *Agent) MaxGrepResults() int {
if a.toolsConfig.MaxGrepResults > 0 {
return a.toolsConfig.MaxGrepResults
}
return 500
}
func (a *Agent) Close() {
if a.iceEngine != nil {
_ = a.iceEngine.Flush()
}
a.registry.Close()
}

95
internal/agent/compact.go Normal file
View File

@ -0,0 +1,95 @@
package agent
import (
"context"
"fmt"
"strings"
"ai-agent/internal/llm"
)
const compactThreshold = 0.75
const keepMessages = 4
func (a *Agent) shouldCompact(promptTokens int) bool {
if a.numCtx <= 0 || promptTokens <= 0 {
return false
}
return float64(promptTokens) > float64(a.numCtx)*compactThreshold
}
func (a *Agent) compact(ctx context.Context, out Output) bool {
a.mu.RLock()
msgCount := len(a.messages)
a.mu.RUnlock()
if msgCount <= keepMessages+1 {
return false
}
a.mu.RLock()
splitAt := msgCount - keepMessages
older := make([]llm.Message, splitAt)
copy(older, a.messages[:splitAt])
recent := make([]llm.Message, keepMessages)
copy(recent, a.messages[splitAt:])
a.mu.RUnlock()
summary := summarizeMessages(older)
var summaryBuf strings.Builder
err := a.llmClient.ChatStream(ctx, llm.ChatOptions{
Messages: []llm.Message{
{Role: "user", Content: summary},
},
System: "You are a conversation summarizer. Produce a concise summary of the conversation so far, capturing all key facts, decisions, tool results, and user requests. Keep it under 500 words. Output only the summary, no preamble.",
}, func(chunk llm.StreamChunk) error {
if chunk.Text != "" {
summaryBuf.WriteString(chunk.Text)
}
return nil
})
if err != nil {
out.Error(fmt.Sprintf("compaction failed: %v", err))
return false
}
summaryText := summaryBuf.String()
if summaryText == "" {
return false
}
if a.iceEngine != nil {
if err := a.iceEngine.IndexSummary(ctx, summaryText); err != nil {
out.Error(fmt.Sprintf("ICE summary indexing failed: %v", err))
}
}
compacted := make([]llm.Message, 0, 1+len(recent))
compacted = append(compacted, llm.Message{
Role: "user",
Content: fmt.Sprintf("[Conversation summary: %s]", summaryText),
})
compacted = append(compacted, recent...)
a.ReplaceMessages(compacted)
out.SystemMessage(fmt.Sprintf("Context compacted: %d messages summarized, %d kept", len(older), len(recent)))
return true
}
func summarizeMessages(msgs []llm.Message) string {
var b strings.Builder
b.WriteString("Summarize this conversation:\n\n")
for _, msg := range msgs {
switch msg.Role {
case "user":
fmt.Fprintf(&b, "User: %s\n", msg.Content)
case "assistant":
if msg.Content != "" {
fmt.Fprintf(&b, "Assistant: %s\n", msg.Content)
}
for _, tc := range msg.ToolCalls {
fmt.Fprintf(&b, "Assistant called tool %s(%s)\n", tc.Name, FormatToolArgs(tc.Arguments))
}
case "tool":
content := msg.Content
if len(content) > 300 {
content = content[:297] + "..."
}
fmt.Fprintf(&b, "Tool %s result: %s\n", msg.ToolName, content)
}
}
return b.String()
}

View File

@ -0,0 +1,143 @@
package agent
import (
"strings"
"testing"
"time"
"ai-agent/internal/llm"
"ai-agent/internal/mcp"
)
type mockOutput struct {
texts []string
errors []string
sysMsgs []string
}
func (m *mockOutput) StreamText(text string) {
m.texts = append(m.texts, text)
}
func (m *mockOutput) StreamDone(_, _ int) {}
func (m *mockOutput) ToolCallStart(_ string, _ map[string]any) {}
func (m *mockOutput) ToolCallResult(_ string, _ string, _ bool, _ time.Duration) {}
func (m *mockOutput) SystemMessage(msg string) {
m.sysMsgs = append(m.sysMsgs, msg)
}
func (m *mockOutput) Error(msg string) {
m.errors = append(m.errors, msg)
}
func TestShouldCompact(t *testing.T) {
tests := []struct {
name string
numCtx int
promptTokens int
want bool
}{
{"below 75%", 1000, 749, false},
{"above 75%", 1000, 751, true},
{"exactly 75% (strict >)", 1000, 750, false},
{"numCtx zero", 0, 500, false},
{"promptTokens zero", 1000, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ag := &Agent{
numCtx: tt.numCtx,
registry: mcp.NewRegistry(),
}
got := ag.shouldCompact(tt.promptTokens)
if got != tt.want {
t.Errorf("shouldCompact(%d) with numCtx=%d = %v, want %v",
tt.promptTokens, tt.numCtx, got, tt.want)
}
})
}
}
func TestSummarizeMessages(t *testing.T) {
tests := []struct {
name string
msgs []llm.Message
contains []string
}{
{
name: "user message",
msgs: []llm.Message{
{Role: "user", Content: "hello"},
},
contains: []string{"User: hello"},
},
{
name: "assistant message",
msgs: []llm.Message{
{Role: "assistant", Content: "hi there"},
},
contains: []string{"Assistant: hi there"},
},
{
name: "tool message",
msgs: []llm.Message{
{Role: "tool", Content: "result data", ToolName: "read_file"},
},
contains: []string{"Tool read_file result: result data"},
},
{
name: "tool content truncation at 300 chars",
msgs: []llm.Message{
{Role: "tool", Content: strings.Repeat("x", 400), ToolName: "big_tool"},
},
contains: []string{"Tool big_tool result: " + strings.Repeat("x", 297) + "..."},
},
{
name: "empty slice",
msgs: []llm.Message{},
contains: []string{"Summarize this conversation:"},
},
{
name: "assistant with tool calls",
msgs: []llm.Message{
{
Role: "assistant",
ToolCalls: []llm.ToolCall{
{Name: "search", Arguments: map[string]any{"q": "test"}},
},
},
},
contains: []string{"Assistant called tool search("},
},
{
name: "mixed messages",
msgs: []llm.Message{
{Role: "user", Content: "find files"},
{Role: "assistant", Content: "", ToolCalls: []llm.ToolCall{
{Name: "glob", Arguments: map[string]any{"pattern": "*.go"}},
}},
{Role: "tool", Content: "file1.go\nfile2.go", ToolName: "glob"},
{Role: "assistant", Content: "Found 2 files"},
},
contains: []string{
"User: find files",
"Assistant called tool glob(",
"Tool glob result:",
"Assistant: Found 2 files",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := summarizeMessages(tt.msgs)
for _, want := range tt.contains {
if !strings.Contains(result, want) {
t.Errorf("summarizeMessages() missing %q in:\n%s", want, result)
}
}
})
}
}

View File

@ -0,0 +1,71 @@
package agent
import (
"fmt"
"io"
"os"
"time"
)
// HeadlessOutput implements the Output interface for non-interactive / pipe mode.
// Text is written to stdout; tool calls, system messages, and errors go to stderr.
type HeadlessOutput struct {
stdout io.Writer
stderr io.Writer
}
// NewHeadlessOutput creates a HeadlessOutput that writes text to os.Stdout
// and diagnostics to os.Stderr.
func NewHeadlessOutput() *HeadlessOutput {
return &HeadlessOutput{
stdout: os.Stdout,
stderr: os.Stderr,
}
}
// newHeadlessOutput creates a HeadlessOutput with custom writers (for testing).
func newHeadlessOutput(stdout, stderr io.Writer) *HeadlessOutput {
return &HeadlessOutput{
stdout: stdout,
stderr: stderr,
}
}
// StreamText writes incremental text content to stdout.
func (h *HeadlessOutput) StreamText(text string) {
fmt.Fprint(h.stdout, text)
}
// StreamDone writes a trailing newline to ensure output is terminated.
func (h *HeadlessOutput) StreamDone(evalCount, promptTokens int) {
fmt.Fprintln(h.stdout)
}
// ToolCallStart writes a brief tool invocation notice to stderr.
func (h *HeadlessOutput) ToolCallStart(name string, args map[string]any) {
fmt.Fprintf(h.stderr, "→ %s %s\n", name, FormatToolArgs(args))
}
// ToolCallResult writes the tool result summary to stderr.
func (h *HeadlessOutput) ToolCallResult(name string, result string, isError bool, duration time.Duration) {
status := "ok"
if isError {
status = "ERROR"
}
// Truncate long results for stderr display.
display := result
if len(display) > 200 {
display = display[:197] + "..."
}
fmt.Fprintf(h.stderr, "← %s [%s %s] %s\n", name, status, duration.Round(time.Millisecond), display)
}
// SystemMessage writes a system message to stderr.
func (h *HeadlessOutput) SystemMessage(msg string) {
fmt.Fprintf(h.stderr, "[system] %s\n", msg)
}
// Error writes an error message to stderr.
func (h *HeadlessOutput) Error(msg string) {
fmt.Fprintf(h.stderr, "[error] %s\n", msg)
}

View File

@ -0,0 +1,154 @@
package agent
import (
"bytes"
"strings"
"testing"
"time"
)
// Verify HeadlessOutput satisfies the Output interface at compile time.
var _ Output = (*HeadlessOutput)(nil)
func TestHeadlessOutput_StreamText(t *testing.T) {
var stdout, stderr bytes.Buffer
out := newHeadlessOutput(&stdout, &stderr)
out.StreamText("hello ")
out.StreamText("world")
if got := stdout.String(); got != "hello world" {
t.Errorf("StreamText: stdout = %q, want %q", got, "hello world")
}
if stderr.Len() != 0 {
t.Errorf("StreamText: unexpected stderr output: %q", stderr.String())
}
}
func TestHeadlessOutput_StreamDone(t *testing.T) {
var stdout, stderr bytes.Buffer
out := newHeadlessOutput(&stdout, &stderr)
out.StreamText("response")
out.StreamDone(100, 50)
if got := stdout.String(); got != "response\n" {
t.Errorf("StreamDone: stdout = %q, want %q", got, "response\n")
}
if stderr.Len() != 0 {
t.Errorf("StreamDone: unexpected stderr output: %q", stderr.String())
}
}
func TestHeadlessOutput_ToolCallStart(t *testing.T) {
var stdout, stderr bytes.Buffer
out := newHeadlessOutput(&stdout, &stderr)
out.ToolCallStart("read_file", map[string]any{"path": "/tmp/test.go"})
if stdout.Len() != 0 {
t.Errorf("ToolCallStart: unexpected stdout output: %q", stdout.String())
}
got := stderr.String()
if !strings.Contains(got, "read_file") {
t.Errorf("ToolCallStart: stderr = %q, missing tool name", got)
}
if !strings.HasPrefix(got, "→ ") {
t.Errorf("ToolCallStart: stderr = %q, missing arrow prefix", got)
}
}
func TestHeadlessOutput_ToolCallResult(t *testing.T) {
var stdout, stderr bytes.Buffer
out := newHeadlessOutput(&stdout, &stderr)
out.ToolCallResult("read_file", "file contents here", false, 150*time.Millisecond)
if stdout.Len() != 0 {
t.Errorf("ToolCallResult: unexpected stdout output: %q", stdout.String())
}
got := stderr.String()
if !strings.Contains(got, "read_file") {
t.Errorf("ToolCallResult: stderr = %q, missing tool name", got)
}
if !strings.Contains(got, "ok") {
t.Errorf("ToolCallResult: stderr = %q, missing ok status", got)
}
if !strings.Contains(got, "file contents here") {
t.Errorf("ToolCallResult: stderr = %q, missing result content", got)
}
}
func TestHeadlessOutput_ToolCallResult_Error(t *testing.T) {
var stdout, stderr bytes.Buffer
out := newHeadlessOutput(&stdout, &stderr)
out.ToolCallResult("write_file", "permission denied", true, 50*time.Millisecond)
got := stderr.String()
if !strings.Contains(got, "ERROR") {
t.Errorf("ToolCallResult error: stderr = %q, missing ERROR status", got)
}
}
func TestHeadlessOutput_ToolCallResult_LongResult(t *testing.T) {
var stdout, stderr bytes.Buffer
out := newHeadlessOutput(&stdout, &stderr)
longResult := strings.Repeat("x", 300)
out.ToolCallResult("search", longResult, false, 100*time.Millisecond)
got := stderr.String()
if strings.Contains(got, strings.Repeat("x", 300)) {
t.Error("ToolCallResult: long result should be truncated")
}
if !strings.Contains(got, "...") {
t.Error("ToolCallResult: truncated result should end with ...")
}
}
func TestHeadlessOutput_SystemMessage(t *testing.T) {
var stdout, stderr bytes.Buffer
out := newHeadlessOutput(&stdout, &stderr)
out.SystemMessage("compacting conversation")
if stdout.Len() != 0 {
t.Errorf("SystemMessage: unexpected stdout output: %q", stdout.String())
}
got := stderr.String()
if !strings.Contains(got, "[system]") {
t.Errorf("SystemMessage: stderr = %q, missing [system] prefix", got)
}
if !strings.Contains(got, "compacting conversation") {
t.Errorf("SystemMessage: stderr = %q, missing message", got)
}
}
func TestHeadlessOutput_Error(t *testing.T) {
var stdout, stderr bytes.Buffer
out := newHeadlessOutput(&stdout, &stderr)
out.Error("something went wrong")
if stdout.Len() != 0 {
t.Errorf("Error: unexpected stdout output: %q", stdout.String())
}
got := stderr.String()
if !strings.Contains(got, "[error]") {
t.Errorf("Error: stderr = %q, missing [error] prefix", got)
}
if !strings.Contains(got, "something went wrong") {
t.Errorf("Error: stderr = %q, missing message", got)
}
}
func TestNewHeadlessOutput(t *testing.T) {
out := NewHeadlessOutput()
if out == nil {
t.Fatal("NewHeadlessOutput returned nil")
}
if out.stdout == nil || out.stderr == nil {
t.Error("NewHeadlessOutput: writers should not be nil")
}
}

277
internal/agent/loop.go Normal file
View File

@ -0,0 +1,277 @@
package agent
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"time"
"ai-agent/internal/llm"
permissionPkg "ai-agent/internal/permission"
)
func (a *Agent) Run(ctx context.Context, out Output) {
var tools []llm.ToolDef
if a.toolsEnabled {
tools = a.registry.Tools()
if a.memoryStore != nil {
tools = append(tools, a.memoryBuiltinToolDefs()...)
}
tools = append(tools, a.toolsBuiltinToolDefs()...)
}
var iceContext string
a.mu.RLock()
hasMessages := len(a.messages) > 0
var lastMsg llm.Message
if hasMessages {
lastMsg = a.messages[len(a.messages)-1]
}
a.mu.RUnlock()
if a.iceEngine != nil && hasMessages {
if lastMsg.Role == "user" {
if err := a.iceEngine.IndexMessage(ctx, "user", lastMsg.Content); err != nil {
out.Error(fmt.Sprintf("ICE indexing failed: %v", err))
}
if assembled, err := a.iceEngine.AssembleContext(ctx, lastMsg.Content); err == nil {
iceContext = assembled
}
}
}
system := buildSystemPromptForModel(a.modePrefix, tools, a.skillContent, a.loadedCtx, a.memoryStore, iceContext, a.workDir, a.ignoreContent, a.llmClient.Model())
const maxRetries = 2
var lastPromptTokens int
var retryCount int
maxIters := a.MaxIterations()
for i := 0; i < maxIters; i++ {
select {
case <-ctx.Done():
return
default:
}
var textBuf strings.Builder
var toolCalls []llm.ToolCall
err := a.llmClient.ChatStream(ctx, llm.ChatOptions{
Messages: a.messages,
Tools: tools,
System: system,
}, func(chunk llm.StreamChunk) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if chunk.Text != "" {
textBuf.WriteString(chunk.Text)
out.StreamText(chunk.Text)
}
if len(chunk.ToolCalls) > 0 {
toolCalls = append(toolCalls, chunk.ToolCalls...)
}
if chunk.Done {
lastPromptTokens = chunk.PromptEvalCount
out.StreamDone(chunk.EvalCount, chunk.PromptEvalCount)
}
return nil
})
if err != nil {
if ctx.Err() != nil {
return
}
if retryCount < maxRetries && isRetryableError(err) {
retryCount++
out.Error(fmt.Sprintf("LLM produced malformed output, retrying (%d/%d)...", retryCount, maxRetries))
textBuf.Reset()
toolCalls = nil
continue
}
out.Error(fmt.Sprintf("LLM error: %v", err))
out.SystemMessage(fmt.Sprintf("⚠️ Model response failed: %v\n\nYou can try:\n- Checking if Ollama is running (`ollama ps`)\n- Switching to a different model (ctrl+m)\n- Reducing context size\n\nTool results are still available above.", err))
return
}
retryCount = 0
assistantMsg := llm.Message{
Role: "assistant",
Content: textBuf.String(),
ToolCalls: toolCalls,
}
a.AppendMessage(assistantMsg)
if a.iceEngine != nil && assistantMsg.Content != "" {
if err := a.iceEngine.IndexMessage(ctx, "assistant", assistantMsg.Content); err != nil {
out.Error(fmt.Sprintf("ICE indexing failed: %v", err))
}
}
if len(toolCalls) == 0 {
a.mu.RLock()
hasEnoughMessages := len(a.messages) >= 2
var userContent string
if hasEnoughMessages {
for idx := len(a.messages) - 2; idx >= 0; idx-- {
if a.messages[idx].Role == "user" {
userContent = a.messages[idx].Content
break
}
}
}
a.mu.RUnlock()
if a.iceEngine != nil && hasEnoughMessages && userContent != "" {
a.iceEngine.DetectAutoMemory(ctx, userContent, assistantMsg.Content)
}
return
}
type pendingTool struct {
tc llm.ToolCall
isMemoryTool bool
isMCPTool bool
}
var pending []pendingTool
for _, tc := range toolCalls {
if a.memoryStore != nil && a.isMemoryTool(tc.Name) {
pending = append(pending, pendingTool{tc: tc, isMemoryTool: true})
continue
}
if a.isToolsTool(tc.Name) {
out.ToolCallStart(tc.Name, tc.Arguments)
startTime := time.Now()
result, isErr := a.handleToolsTool(tc)
duration := time.Since(startTime)
out.ToolCallResult(tc.Name, result, isErr, duration)
a.AppendMessage(llm.Message{
Role: "tool",
Content: result,
ToolName: tc.Name,
ToolCallID: tc.ID,
})
continue
}
if a.permChecker != nil {
switch a.permChecker.ToCheckResult(tc.Name) {
case permissionPkg.CheckDeny:
errMsg := "tool call blocked by permission policy"
out.ToolCallStart(tc.Name, tc.Arguments)
out.ToolCallResult(tc.Name, errMsg, true, 0)
a.AppendMessage(llm.Message{
Role: "tool",
Content: errMsg,
ToolName: tc.Name,
ToolCallID: tc.ID,
})
continue
case permissionPkg.CheckAsk:
if a.approvalCallback != nil {
allowed, always := permissionPkg.RequestApproval(tc.Name, tc.Arguments, a.approvalCallback)
if always {
a.permChecker.SetPolicy(tc.Name, permissionPkg.PolicyAllow)
}
if !allowed {
errMsg := "tool call denied by user"
out.ToolCallStart(tc.Name, tc.Arguments)
out.ToolCallResult(tc.Name, errMsg, true, 0)
a.AppendMessage(llm.Message{
Role: "tool",
Content: errMsg,
ToolName: tc.Name,
ToolCallID: tc.ID,
})
continue
}
}
}
}
pending = append(pending, pendingTool{tc: tc, isMCPTool: true})
}
if len(pending) > 0 {
var wg sync.WaitGroup
mu := sync.Mutex{}
results := make([]llm.Message, len(pending))
for i, p := range pending {
wg.Add(1)
go func(idx int, tool pendingTool) {
defer wg.Done()
tc := tool.tc
out.ToolCallStart(tc.Name, tc.Arguments)
startTime := time.Now()
var result string
var isErr bool
if tool.isMemoryTool {
result, isErr = a.handleMemoryTool(tc)
} else if tool.isMCPTool {
toolResult, err := a.registry.CallTool(ctx, tc.Name, tc.Arguments)
if err != nil {
result = fmt.Sprintf("ERROR: Tool '%s' failed: %v\nThis tool call failed but you can still complete the task with other available information.", tc.Name, err)
isErr = true
} else {
result = toolResult.Content
isErr = toolResult.IsError
}
}
duration := time.Since(startTime)
out.ToolCallResult(tc.Name, result, isErr, duration)
mu.Lock()
results[idx] = llm.Message{
Role: "tool",
Content: result,
ToolName: tc.Name,
ToolCallID: tc.ID,
}
mu.Unlock()
}(i, p)
}
wg.Wait()
for _, msg := range results {
if msg.ToolName != "" {
a.AppendMessage(msg)
}
}
}
if a.shouldCompact(lastPromptTokens) {
if a.compact(ctx, out) {
system = buildSystemPromptForModel(a.modePrefix, tools, a.skillContent, a.loadedCtx, a.memoryStore, iceContext, a.workDir, a.ignoreContent, a.llmClient.Model())
}
}
if i == maxIters-2 {
out.Error(fmt.Sprintf("approaching iteration limit (%d/%d)", i+2, maxIters))
}
}
out.Error(fmt.Sprintf("reached max iterations (%d)", maxIters))
}
func isRetryableError(err error) bool {
msg := err.Error()
return strings.Contains(msg, "parse JSON") || strings.Contains(msg, "unexpected end of JSON")
}
func FormatToolArgs(args map[string]any) string {
if len(args) == 0 {
return ""
}
var parts []string
for key, value := range args {
var valStr string
switch v := value.(type) {
case string:
if len(v) > 47 {
valStr = `"` + v[:44] + `..."`
} else {
valStr = `"` + v + `"`
}
case int, float64, bool:
valStr = fmt.Sprintf("%v", v)
case []any:
valStr = fmt.Sprintf("[%d items]", len(v))
case map[string]any:
valStr = fmt.Sprintf("{%d fields}", len(v))
default:
valStr = fmt.Sprintf("%v", v)
}
parts = append(parts, fmt.Sprintf("%s=%s", key, valStr))
}
sort.Strings(parts)
result := strings.Join(parts, " ")
if len(result) > 60 {
return result[:57] + "..."
}
return result
}

View File

@ -0,0 +1,73 @@
package agent
import (
"strings"
"testing"
)
func TestFormatToolArgs(t *testing.T) {
tests := []struct {
name string
args map[string]any
want string
contains []string
maxLen int
}{
{
name: "empty map",
args: map[string]any{},
want: "",
},
{
name: "simple map",
args: map[string]any{"key": "value"},
contains: []string{"key=", `"value"`},
},
{
name: "long args truncated at 60",
args: map[string]any{"data": strings.Repeat("a", 300)},
maxLen: 60,
},
{
name: "multiple args",
args: map[string]any{"path": "/tmp/test", "command": "ls"},
contains: []string{"path=", "command="},
},
{
name: "numeric args",
args: map[string]any{"count": 42, "ratio": 3.14},
contains: []string{"count=42", "ratio=3.14"},
},
{
name: "array args",
args: map[string]any{"items": []any{1, 2, 3}},
contains: []string{"items=", "[3 items]"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FormatToolArgs(tt.args)
if tt.want != "" && got != tt.want {
t.Errorf("FormatToolArgs() = %q, want %q", got, tt.want)
}
for _, substr := range tt.contains {
if !strings.Contains(got, substr) {
t.Errorf("FormatToolArgs() = %q, missing %q", got, substr)
}
}
if tt.maxLen > 0 {
if len(got) > tt.maxLen {
t.Errorf("FormatToolArgs() len = %d, want <= %d", len(got), tt.maxLen)
}
// Check for truncation indicator (either "..." in value or at end)
if !strings.Contains(got, "...") {
t.Errorf("FormatToolArgs() should contain '...' when truncated, got %q", got)
}
}
})
}
}

171
internal/agent/memory.go Normal file
View File

@ -0,0 +1,171 @@
package agent
import (
"fmt"
"strings"
"ai-agent/internal/llm"
"ai-agent/internal/memory"
)
func (a *Agent) memoryBuiltinToolDefs() []llm.ToolDef {
return memory.BuiltinToolDefs()
}
func (a *Agent) isMemoryTool(name string) bool {
return memory.IsBuiltinTool(name)
}
func (a *Agent) handleMemoryTool(tc llm.ToolCall) (string, bool) {
switch tc.Name {
case "memory_save":
return a.handleMemorySave(tc.Arguments)
case "memory_recall":
return a.handleMemoryRecall(tc.Arguments)
case "memory_delete":
return a.handleMemoryDelete(tc.Arguments)
case "memory_update":
return a.handleMemoryUpdate(tc.Arguments)
case "memory_list":
return a.handleMemoryList(tc.Arguments)
default:
return fmt.Sprintf("unknown memory tool: %s", tc.Name), true
}
}
func (a *Agent) handleMemorySave(args map[string]any) (string, bool) {
content, _ := args["content"].(string)
if content == "" {
return "error: content is required", true
}
var tags []string
if rawTags, ok := args["tags"]; ok {
switch v := rawTags.(type) {
case []any:
for _, t := range v {
if s, ok := t.(string); ok {
tags = append(tags, s)
}
}
case []string:
tags = v
}
}
id, err := a.memoryStore.Save(content, tags)
if err != nil {
return fmt.Sprintf("error saving memory: %v", err), true
}
return fmt.Sprintf("Memory saved (id: %d)", id), false
}
func (a *Agent) handleMemoryRecall(args map[string]any) (string, bool) {
query, _ := args["query"].(string)
if query == "" {
return "error: query is required", true
}
memories := a.memoryStore.Recall(query, 5)
if len(memories) == 0 {
return "No matching memories found.", false
}
var b strings.Builder
fmt.Fprintf(&b, "Found %d matching memories:\n", len(memories))
for _, mem := range memories {
fmt.Fprintf(&b, "- [%d] %s", mem.ID, mem.Content)
if len(mem.Tags) > 0 {
fmt.Fprintf(&b, " (tags: %s)", strings.Join(mem.Tags, ", "))
}
b.WriteString("\n")
}
return b.String(), false
}
func (a *Agent) handleMemoryDelete(args map[string]any) (string, bool) {
idVal, ok := args["id"]
if !ok {
return "error: id is required", true
}
var id int
switch v := idVal.(type) {
case float64:
id = int(v)
case int:
id = v
default:
return "error: id must be a number", true
}
deleted, err := a.memoryStore.Delete(id)
if err != nil {
return fmt.Sprintf("error deleting memory: %v", err), true
}
if !deleted {
return fmt.Sprintf("memory with id %d not found", id), true
}
return fmt.Sprintf("Memory %d deleted", id), false
}
func (a *Agent) handleMemoryUpdate(args map[string]any) (string, bool) {
idVal, ok := args["id"]
if !ok {
return "error: id is required", true
}
var id int
switch v := idVal.(type) {
case float64:
id = int(v)
case int:
id = v
default:
return "error: id must be a number", true
}
content, _ := args["content"].(string)
var tags []string
if rawTags, ok := args["tags"]; ok {
switch v := rawTags.(type) {
case []any:
for _, t := range v {
if s, ok := t.(string); ok {
tags = append(tags, s)
}
}
case []string:
tags = v
}
}
if content == "" && len(tags) == 0 {
return "error: at least one of content or tags is required", true
}
updated, err := a.memoryStore.Update(id, content, tags)
if err != nil {
return fmt.Sprintf("error updating memory: %v", err), true
}
if !updated {
return fmt.Sprintf("memory with id %d not found", id), true
}
return fmt.Sprintf("Memory %d updated", id), false
}
func (a *Agent) handleMemoryList(args map[string]any) (string, bool) {
limit := 20
if rawLimit, ok := args["limit"]; ok {
switch v := rawLimit.(type) {
case float64:
limit = int(v)
case int:
limit = v
}
}
memories := a.memoryStore.Recent(limit)
if len(memories) == 0 {
return "No memories stored.", false
}
var b strings.Builder
fmt.Fprintf(&b, "Stored memories (%d total):\n", a.memoryStore.Count())
for _, mem := range memories {
fmt.Fprintf(&b, "- [%d] %s", mem.ID, mem.Content)
if len(mem.Tags) > 0 {
fmt.Fprintf(&b, " (tags: %s)", strings.Join(mem.Tags, ", "))
}
b.WriteString("\n")
}
return b.String(), false
}

View File

@ -0,0 +1,159 @@
package agent
import (
"path/filepath"
"strings"
"testing"
"ai-agent/internal/llm"
"ai-agent/internal/mcp"
"ai-agent/internal/memory"
)
func newTestAgentWithMemory(t *testing.T) *Agent {
t.Helper()
store := memory.NewStore(filepath.Join(t.TempDir(), "test-memories.json"))
return &Agent{memoryStore: store, registry: mcp.NewRegistry()}
}
func TestHandleMemoryTool(t *testing.T) {
tests := []struct {
name string
toolCall llm.ToolCall
wantSubstr string
wantErr bool
}{
{
name: "dispatch to save",
toolCall: llm.ToolCall{
Name: "memory_save",
Arguments: map[string]any{"content": "test fact", "tags": []any{"tag1"}},
},
wantSubstr: "Memory saved (id:",
wantErr: false,
},
{
name: "dispatch to recall",
toolCall: llm.ToolCall{
Name: "memory_recall",
Arguments: map[string]any{"query": "test"},
},
wantSubstr: "No matching memories found.",
wantErr: false,
},
{
name: "unknown tool",
toolCall: llm.ToolCall{
Name: "unknown",
Arguments: map[string]any{},
},
wantSubstr: "unknown memory tool: unknown",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ag := newTestAgentWithMemory(t)
result, isErr := ag.handleMemoryTool(tt.toolCall)
if isErr != tt.wantErr {
t.Errorf("handleMemoryTool() isErr = %v, want %v", isErr, tt.wantErr)
}
if !strings.Contains(result, tt.wantSubstr) {
t.Errorf("handleMemoryTool() = %q, want substring %q", result, tt.wantSubstr)
}
})
}
}
func TestHandleMemorySave(t *testing.T) {
tests := []struct {
name string
args map[string]any
wantSubstr string
wantErr bool
}{
{
name: "valid with tags as []any",
args: map[string]any{"content": "test fact", "tags": []any{"tag1", "tag2"}},
wantSubstr: "Memory saved (id:",
wantErr: false,
},
{
name: "valid without tags",
args: map[string]any{"content": "another fact"},
wantSubstr: "Memory saved (id:",
wantErr: false,
},
{
name: "missing content",
args: map[string]any{},
wantSubstr: "error: content is required",
wantErr: true,
},
{
name: "empty content",
args: map[string]any{"content": ""},
wantSubstr: "error: content is required",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ag := newTestAgentWithMemory(t)
result, isErr := ag.handleMemorySave(tt.args)
if isErr != tt.wantErr {
t.Errorf("handleMemorySave() isErr = %v, want %v", isErr, tt.wantErr)
}
if !strings.Contains(result, tt.wantSubstr) {
t.Errorf("handleMemorySave() = %q, want substring %q", result, tt.wantSubstr)
}
})
}
}
func TestHandleMemoryRecall(t *testing.T) {
tests := []struct {
name string
setup func(ag *Agent)
args map[string]any
wantSubstr string
wantErr bool
}{
{
name: "valid recall finds saved memory",
setup: func(ag *Agent) {
_, _ = ag.memoryStore.Save("user prefers Go", []string{"language"})
},
args: map[string]any{"query": "Go"},
wantSubstr: "Found 1 matching memories:",
wantErr: false,
},
{
name: "missing query",
setup: func(ag *Agent) {},
args: map[string]any{},
wantSubstr: "error: query is required",
wantErr: true,
},
{
name: "no matches",
setup: func(ag *Agent) {},
args: map[string]any{"query": "nonexistent"},
wantSubstr: "No matching memories found.",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ag := newTestAgentWithMemory(t)
tt.setup(ag)
result, isErr := ag.handleMemoryRecall(tt.args)
if isErr != tt.wantErr {
t.Errorf("handleMemoryRecall() isErr = %v, want %v", isErr, tt.wantErr)
}
if !strings.Contains(result, tt.wantSubstr) {
t.Errorf("handleMemoryRecall() = %q, want substring %q", result, tt.wantSubstr)
}
})
}
}

24
internal/agent/output.go Normal file
View File

@ -0,0 +1,24 @@
package agent
import "time"
// Output is the interface the agent uses to stream results to the UI.
type Output interface {
// StreamText sends incremental text content.
StreamText(text string)
// StreamDone signals that the current response is complete.
StreamDone(evalCount, promptTokens int)
// ToolCallStart signals the beginning of a tool invocation.
ToolCallStart(name string, args map[string]any)
// ToolCallResult delivers the result of a tool invocation.
ToolCallResult(name string, result string, isError bool, duration time.Duration)
// SystemMessage displays a system-level message to the user.
SystemMessage(msg string)
// Error reports a non-fatal error to the user.
Error(msg string)
}

272
internal/agent/system.go Normal file
View File

@ -0,0 +1,272 @@
package agent
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"ai-agent/internal/llm"
"ai-agent/internal/memory"
)
const systemTemplate = `You are a helpful personal assistant running locally on the user's machine.
You have access to tools via MCP servers. You MUST use tools to accomplish tasks do not guess or make up answers when a tool can provide the real information.
%s
Current date: %s
%s%s
%s%s%s
## Available Tools
%s
## Guidelines
- **ALWAYS use your tools** when the user asks you to read, explore, search, or modify files. You have filesystem tools use them.
- When the user says "read this codebase" or similar, use list/read tools starting from the working directory shown above.
- Be concise and direct in your responses.
- When a tool call fails, explain what happened and suggest alternatives.
- For multi-step tasks, explain your plan briefly before executing.
- Format responses in markdown when it improves readability.
- If you're unsure about something, say so rather than guessing.
- Never fabricate tool results always call the actual tool.
- Do NOT claim you cannot access files or the filesystem. You have tools for that use them.
%s`
const smallModelTemplate = `You are a local AI assistant. Use tools to read/write files and run commands.
%sDate: %s
%s%s
%s
## Tools
%s
Guidelines:
- Be concise and direct
- Use tools when needed to complete tasks
- If a tool fails, continue with available information
- Don't guess - use tools to verify
- You can complete tasks even if some tools fail
%s`
func isSmallModel(modelName string) bool {
lower := strings.ToLower(modelName)
if strings.Contains(lower, "0.8b") || strings.Contains(lower, "1b") || strings.Contains(lower, "2b") {
return true
}
return false
}
func buildSystemPrompt(modePrefix string, tools []llm.ToolDef, skillContent, loadedContext string, memStore *memory.Store, iceContext, workDir, ignoreContent string) string {
return buildSystemPromptForModel(modePrefix, tools, skillContent, loadedContext, memStore, iceContext, workDir, ignoreContent, "")
}
func buildSystemPromptForModel(modePrefix string, tools []llm.ToolDef, skillContent, loadedContext string, memStore *memory.Store, iceContext, workDir, ignoreContent string, modelName string) string {
useSmallModel := isSmallModel(modelName)
var toolList string
if len(tools) == 0 {
toolList = "No tools currently available.\n"
} else if useSmallModel {
toolList = simplifyToolsForSmallModel(tools)
} else {
var b strings.Builder
for _, t := range tools {
fmt.Fprintf(&b, "- **%s**: %s\n", t.Name, t.Description)
}
toolList = b.String()
}
envSection := buildEnvironmentSection(workDir)
var skillSection string
if skillContent != "" {
skillSection = fmt.Sprintf("\n## Active Skills\n%s\n", skillContent)
}
var ctxSection string
if loadedContext != "" {
ctxSection = fmt.Sprintf("\n## Loaded Context\n%s\n", loadedContext)
}
var memorySection string
if iceContext != "" {
memorySection = iceContext
} else if memStore != nil {
memorySection = buildMemorySection(memStore)
}
var memoryGuidelines string
if memStore != nil {
memoryGuidelines = `
## Memory Guidelines
- You have access to persistent memory via memory_save and memory_recall tools.
- Proactively save important user preferences, project facts, and key decisions.
- When the user shares personal information (name, preferences, etc.), save it.
- Use memory_recall to look up previously saved information when relevant.
- Don't save trivial or session-specific information.
`
}
var ignoreSection string
if ignoreContent != "" {
ignoreSection = fmt.Sprintf("\n## Ignored Paths\nThe following paths/patterns should be excluded from file operations:\n%s\n", ignoreContent)
}
var modePrefixSection string
if modePrefix != "" {
modePrefixSection = "\n" + modePrefix + "\n"
}
dateStr := time.Now().Format("Monday, January 2, 2006")
if useSmallModel {
return fmt.Sprintf(smallModelTemplate,
modePrefixSection,
dateStr,
envSection,
ignoreSection,
skillSection,
toolList,
memoryGuidelines,
)
}
return fmt.Sprintf(systemTemplate,
modePrefixSection,
dateStr,
envSection,
ignoreSection,
skillSection,
ctxSection,
memorySection,
toolList,
memoryGuidelines,
)
}
func simplifyToolsForSmallModel(tools []llm.ToolDef) string {
var b strings.Builder
for _, t := range tools {
desc := t.Description
if len(desc) > 50 {
desc = desc[:47] + "..."
}
fmt.Fprintf(&b, "- %s: %s\n", t.Name, desc)
}
return b.String()
}
func buildEnvironmentSection(workDir string) string {
if workDir == "" {
return ""
}
var b strings.Builder
b.WriteString("\n## Environment\n")
b.WriteString(fmt.Sprintf("Working directory: %s\n", workDir))
if info := detectProjectInfo(workDir); info != "" {
b.WriteString(info)
}
if gitInfo := detectGitInfo(workDir); gitInfo != "" {
b.WriteString(gitInfo)
}
return b.String()
}
func detectProjectInfo(workDir string) string {
markers := []struct {
file string
desc string
}{
{"go.mod", "Go module"},
{"package.json", "Node.js/JavaScript"},
{"Cargo.toml", "Rust"},
{"pyproject.toml", "Python"},
{"setup.py", "Python"},
{"Makefile", ""},
{"Taskfile.yml", ""},
}
var found []string
for _, m := range markers {
if _, err := os.Stat(filepath.Join(workDir, m.file)); err == nil {
if m.desc != "" {
found = append(found, fmt.Sprintf("%s (%s)", m.file, m.desc))
} else {
found = append(found, m.file)
}
}
}
if len(found) == 0 {
return ""
}
return fmt.Sprintf("Project markers: %s\n", strings.Join(found, ", "))
}
func detectGitInfo(workDir string) string {
gitDir := filepath.Join(workDir, ".git")
if _, err := os.Stat(gitDir); err != nil {
return ""
}
var b strings.Builder
branch := runGitCommand(workDir, "rev-parse", "--abbrev-ref", "HEAD")
if branch != "" {
b.WriteString(fmt.Sprintf("Git branch: %s\n", branch))
}
status := runGitCommand(workDir, "status", "--porcelain")
if status != "" {
lines := strings.Split(strings.TrimSpace(status), "\n")
var modified, added, deleted int
for _, line := range lines {
if len(line) >= 2 {
switch line[0] {
case 'M', 'm':
modified++
case 'A':
added++
case 'D':
deleted++
}
}
}
if modified > 0 || added > 0 || deleted > 0 {
statusParts := []string{}
if modified > 0 {
statusParts = append(statusParts, fmt.Sprintf("%d modified", modified))
}
if added > 0 {
statusParts = append(statusParts, fmt.Sprintf("%d added", added))
}
if deleted > 0 {
statusParts = append(statusParts, fmt.Sprintf("%d deleted", deleted))
}
b.WriteString(fmt.Sprintf("Git status: %s\n", strings.Join(statusParts, ", ")))
}
}
recentLog := runGitCommand(workDir, "log", "-3", "--oneline")
if recentLog != "" {
b.WriteString(fmt.Sprintf("Recent commits:\n"))
for _, line := range strings.Split(strings.TrimSpace(recentLog), "\n") {
b.WriteString(fmt.Sprintf(" - %s\n", line))
}
}
if b.Len() == 0 {
return ""
}
return b.String()
}
func runGitCommand(dir string, args ...string) string {
cmd := exec.Command("git", args...)
cmd.Dir = dir
out, err := cmd.Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
func buildMemorySection(store *memory.Store) string {
if store.Count() == 0 {
return ""
}
recent := store.Recent(10)
if len(recent) == 0 {
return ""
}
var b strings.Builder
b.WriteString("\n## Remembered Facts\n")
for _, mem := range recent {
b.WriteString(fmt.Sprintf("- %s", mem.Content))
if len(mem.Tags) > 0 {
b.WriteString(fmt.Sprintf(" [tags: %s]", strings.Join(mem.Tags, ", ")))
}
b.WriteString("\n")
}
return b.String()
}

View File

@ -0,0 +1,186 @@
package agent
import (
"os"
"path/filepath"
"strings"
"testing"
"ai-agent/internal/llm"
"ai-agent/internal/memory"
)
func TestBuildSystemPrompt(t *testing.T) {
tests := []struct {
name string
tools []llm.ToolDef
skillContent string
loadedCtx string
memStore *memory.Store
iceContext string
contains []string
notContains []string
}{
{
name: "no optional sections",
contains: []string{"No tools currently available.", "Current date:"},
notContains: []string{"Active Skills", "Loaded Context", "Remembered Facts"},
},
{
name: "with tools",
tools: []llm.ToolDef{
{Name: "test_tool", Description: "does stuff"},
},
contains: []string{"test_tool", "does stuff"},
notContains: []string{"No tools currently available."},
},
{
name: "with skills",
skillContent: "skill content here",
contains: []string{"Active Skills", "skill content here"},
},
{
name: "with loaded context",
loadedCtx: "some loaded context",
contains: []string{"Loaded Context", "some loaded context"},
},
{
name: "ICE overrides memory",
iceContext: "ice assembled context",
contains: []string{"ice assembled context"},
notContains: []string{"Remembered Facts"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildSystemPrompt("", tt.tools, tt.skillContent, tt.loadedCtx, tt.memStore, tt.iceContext, "", "")
for _, want := range tt.contains {
if !strings.Contains(result, want) {
t.Errorf("buildSystemPrompt() missing %q", want)
}
}
for _, notWant := range tt.notContains {
if strings.Contains(result, notWant) {
t.Errorf("buildSystemPrompt() should not contain %q", notWant)
}
}
})
}
t.Run("with memory store entries", func(t *testing.T) {
store := memory.NewStore(filepath.Join(t.TempDir(), "test-memories.json"))
_, _ = store.Save("user prefers dark mode", []string{"preference"})
result := buildSystemPrompt("", nil, "", "", store, "", "", "")
if !strings.Contains(result, "Remembered Facts") {
t.Error("expected Remembered Facts section")
}
if !strings.Contains(result, "user prefers dark mode") {
t.Error("expected memory content in prompt")
}
})
}
func TestBuildSystemPrompt_WithWorkDir(t *testing.T) {
result := buildSystemPrompt("", nil, "", "", nil, "", "/home/user/myproject", "")
if !strings.Contains(result, "Working directory: /home/user/myproject") {
t.Error("expected working directory in prompt")
}
if !strings.Contains(result, "Environment") {
t.Error("expected Environment section header")
}
}
func TestBuildSystemPrompt_EmptyWorkDir(t *testing.T) {
result := buildSystemPrompt("", nil, "", "", nil, "", "", "")
if strings.Contains(result, "Working directory") {
t.Error("should not include working directory when empty")
}
}
func TestBuildSystemPrompt_WithIgnoreContent(t *testing.T) {
ignoreContent := "- node_modules\n- *.log\n- build/"
result := buildSystemPrompt("", nil, "", "", nil, "", "", ignoreContent)
if !strings.Contains(result, "Ignored Paths") {
t.Error("expected Ignored Paths section header")
}
if !strings.Contains(result, "node_modules") {
t.Error("expected node_modules in ignore section")
}
if !strings.Contains(result, "*.log") {
t.Error("expected *.log in ignore section")
}
}
func TestBuildSystemPrompt_EmptyIgnoreContent(t *testing.T) {
result := buildSystemPrompt("", nil, "", "", nil, "", "", "")
if strings.Contains(result, "Ignored Paths") {
t.Error("should not include Ignored Paths when content is empty")
}
}
func TestDetectProjectInfo_GoProject(t *testing.T) {
dir := t.TempDir()
_ = os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module test"), 0o644)
info := detectProjectInfo(dir)
if !strings.Contains(info, "go.mod") {
t.Errorf("expected go.mod in project info, got %q", info)
}
if !strings.Contains(info, "Go module") {
t.Errorf("expected 'Go module' in project info, got %q", info)
}
}
func TestDetectProjectInfo_EmptyDir(t *testing.T) {
dir := t.TempDir()
info := detectProjectInfo(dir)
if info != "" {
t.Errorf("expected empty for dir with no markers, got %q", info)
}
}
func TestBuildMemorySection(t *testing.T) {
tests := []struct {
name string
setup func(s *memory.Store)
contains []string
wantEmpty bool
}{
{
name: "empty store",
setup: func(s *memory.Store) {},
wantEmpty: true,
},
{
name: "store with tagged entry",
setup: func(s *memory.Store) {
_, _ = s.Save("likes Go", []string{"lang", "preference"})
},
contains: []string{"Remembered Facts", "likes Go", "[tags: lang, preference]"},
},
{
name: "store with untagged entry",
setup: func(s *memory.Store) {
_, _ = s.Save("project uses modules", nil)
},
contains: []string{"Remembered Facts", "project uses modules"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := memory.NewStore(filepath.Join(t.TempDir(), "mem.json"))
tt.setup(store)
result := buildMemorySection(store)
if tt.wantEmpty {
if result != "" {
t.Errorf("expected empty string, got %q", result)
}
return
}
for _, want := range tt.contains {
if !strings.Contains(result, want) {
t.Errorf("buildMemorySection() missing %q in:\n%s", want, result)
}
}
})
}
}

688
internal/agent/tools.go Normal file
View File

@ -0,0 +1,688 @@
package agent
import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"ai-agent/internal/llm"
"ai-agent/internal/tools"
)
const (
maxTimeout = 120 * time.Second
)
func (a *Agent) toolsBuiltinToolDefs() []llm.ToolDef {
return tools.AllToolDefs()
}
func (a *Agent) isToolsTool(name string) bool {
return tools.IsBuiltinTool(name)
}
func (a *Agent) handleToolsTool(tc llm.ToolCall) (string, bool) {
switch tc.Name {
case "grep":
return a.handleGrep(tc.Arguments)
case "read":
return a.handleRead(tc.Arguments)
case "write":
return a.handleWrite(tc.Arguments)
case "glob":
return a.handleGlob(tc.Arguments)
case "bash":
return a.handleBash(tc.Arguments)
case "ls":
return a.handleLs(tc.Arguments)
case "find":
return a.handleFind(tc.Arguments)
case "diff":
return a.handleDiff(tc.Arguments)
case "edit":
return a.handleEdit(tc.Arguments)
case "mkdir":
return a.handleMkdir(tc.Arguments)
case "remove":
return a.handleRemove(tc.Arguments)
case "copy":
return a.handleCopy(tc.Arguments)
case "move":
return a.handleMove(tc.Arguments)
case "exists":
return a.handleExists(tc.Arguments)
default:
return fmt.Sprintf("unknown tool: %s", tc.Name), true
}
}
func (a *Agent) handleGrep(args map[string]any) (string, bool) {
pattern, _ := args["pattern"].(string)
if pattern == "" {
return "error: pattern is required", true
}
path := a.getArgString(args, "path", a.workDir)
include := a.getArgString(args, "include", "")
context := a.getArgInt(args, "context", 3)
maxResults := a.MaxGrepResults()
if _, err := os.Stat(path); err != nil {
return fmt.Sprintf("error: path does not exist: %s", path), true
}
re, err := regexp.Compile(pattern)
if err != nil {
return fmt.Sprintf("error: invalid regex pattern: %v", err), true
}
var results []string
err = filepath.Walk(path, func(filePath string, info os.FileInfo, err error) error {
if err != nil {
return nil
}
if info.IsDir() {
if shouldSkipDir(info.Name()) {
return filepath.SkipDir
}
return nil
}
if include != "" {
matched, err := filepath.Match(include, info.Name())
if err != nil || !matched {
return nil
}
}
if strings.HasPrefix(info.Name(), ".") {
return nil
}
content, err := os.ReadFile(filePath)
if err != nil {
return nil
}
lines := strings.Split(string(content), "\n")
for i, line := range lines {
if re.MatchString(line) {
relPath, _ := filepath.Rel(path, filePath)
ctxStart := i - context
if ctxStart < 0 {
ctxStart = 0
}
ctxEnd := i + context + 1
if ctxEnd > len(lines) {
ctxEnd = len(lines)
}
results = append(results, fmt.Sprintf("%s:%d: %s", relPath, i+1, line))
if context > 0 && ctxStart < i {
for j := ctxStart; j < i; j++ {
if len(results) < maxResults {
results = append(results, fmt.Sprintf(" %d: %s", j+1, lines[j]))
}
}
}
if context > 0 && i+1 < ctxEnd {
for j := i + 1; j < ctxEnd; j++ {
if len(results) < maxResults {
results = append(results, fmt.Sprintf(" %d: %s", j+1, lines[j]))
}
}
}
if len(results) >= maxResults {
results = append(results, fmt.Sprintf("\n... (truncated, max %d results)", maxResults))
return filepath.SkipAll
}
}
}
return nil
})
if err != nil {
return fmt.Sprintf("error walking directory: %v", err), true
}
if len(results) == 0 {
return fmt.Sprintf("No matches found for pattern: %s", pattern), false
}
return strings.Join(results, "\n"), false
}
func (a *Agent) handleRead(args map[string]any) (string, bool) {
path, _ := args["path"].(string)
if path == "" {
return "error: path is required", true
}
path = a.resolvePath(path)
data, err := os.ReadFile(path)
if err != nil {
return fmt.Sprintf("error reading file: %v", err), true
}
lines := strings.Split(string(data), "\n")
offset := a.getArgInt(args, "offset", 1)
limit := a.getArgInt(args, "limit", 0)
if offset > len(lines) {
return "error: offset beyond file length", true
}
if offset > 1 {
lines = lines[offset-1:]
}
if limit > 0 && len(lines) > limit {
lines = lines[:limit]
content := strings.Join(lines, "\n")
content += fmt.Sprintf("\n\n... (%d more lines)", len(lines)-limit)
return content, false
}
return strings.Join(lines, "\n"), false
}
func (a *Agent) handleWrite(args map[string]any) (string, bool) {
path, _ := args["path"].(string)
content, _ := args["content"].(string)
if path == "" {
return "error: path is required", true
}
path = a.resolvePath(path)
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Sprintf("error creating directory: %v", err), true
}
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
return fmt.Sprintf("error writing file: %v", err), true
}
return fmt.Sprintf("Written to %s (%d bytes)", path, len(content)), false
}
func (a *Agent) handleGlob(args map[string]any) (string, bool) {
pattern, _ := args["pattern"].(string)
if pattern == "" {
return "error: pattern is required", true
}
path := a.getArgString(args, "path", a.workDir)
if _, err := os.Stat(path); err != nil {
return fmt.Sprintf("error: path does not exist: %s", path), true
}
basePattern := filepath.Join(path, pattern)
matches, err := filepath.Glob(basePattern)
if err != nil {
return fmt.Sprintf("error: invalid pattern: %v", err), true
}
if len(matches) == 0 {
return fmt.Sprintf("No files match pattern: %s", pattern), false
}
relMatches := make([]string, 0, len(matches))
for _, m := range matches {
rel, err := filepath.Rel(path, m)
if err != nil {
continue
}
relMatches = append(relMatches, rel)
}
return strings.Join(relMatches, "\n"), false
}
func (a *Agent) handleBash(args map[string]any) (string, bool) {
command, _ := args["command"].(string)
if command == "" {
return "error: command is required", true
}
timeout := a.getArgInt(args, "timeout", int(a.ToolTimeout().Seconds()))
maxTimeoutSecs := int(a.ToolTimeout().Seconds())
if maxTimeoutSecs > 120 {
maxTimeoutSecs = 120
}
if timeout > maxTimeoutSecs {
timeout = maxTimeoutSecs
}
if timeout < 1 {
timeout = 1
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "sh", "-c", command)
cmd.Dir = a.workDir
cmd.Env = os.Environ()
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
output := stdout.String()
if stderr.Len() > 0 {
if output != "" {
output += "\n"
}
output += "STDERR:\n" + stderr.String()
}
if ctx.Err() == context.DeadlineExceeded {
return fmt.Sprintf("error: command timed out after %d seconds", timeout), true
}
if err != nil {
if output == "" {
return fmt.Sprintf("error: %v", err), true
}
return fmt.Sprintf("Command exited with error:\n%s", output), true
}
if output == "" {
return "Command completed successfully (no output)", false
}
return output, false
}
func (a *Agent) handleLs(args map[string]any) (string, bool) {
path := a.getArgString(args, "path", a.workDir)
path = a.resolvePath(path)
entries, err := os.ReadDir(path)
if err != nil {
return fmt.Sprintf("error reading directory: %v", err), true
}
if len(entries) == 0 {
return "Directory is empty", false
}
var dirs []string
var files []string
for _, e := range entries {
name := e.Name()
if e.IsDir() {
dirs = append(dirs, name+"/")
} else {
files = append(files, name)
}
}
var result strings.Builder
for _, d := range dirs {
result.WriteString(d + "\n")
}
for _, f := range files {
result.WriteString(f + "\n")
}
return result.String(), false
}
func (a *Agent) handleFind(args map[string]any) (string, bool) {
name, _ := args["name"].(string)
if name == "" {
return "error: name is required", true
}
path := a.getArgString(args, "path", a.workDir)
fileType := a.getArgString(args, "type", "")
if _, err := os.Stat(path); err != nil {
return fmt.Sprintf("error: path does not exist: %s", path), true
}
re, err := regexp.Compile("^" + strings.ReplaceAll(name, "*", ".*") + "$")
if err != nil {
return fmt.Sprintf("error: invalid name pattern: %v", err), true
}
var results []string
err = filepath.Walk(path, func(filePath string, info os.FileInfo, err error) error {
if err != nil {
return nil
}
if shouldSkipDir(info.Name()) && filePath != path {
if info.IsDir() {
return filepath.SkipDir
}
return nil
}
isDir := info.IsDir()
if fileType == "f" && isDir {
return nil
}
if fileType == "d" && !isDir {
return nil
}
if re.MatchString(info.Name()) {
relPath, _ := filepath.Rel(path, filePath)
if relPath != "." {
if isDir {
results = append(results, relPath+"/")
} else {
results = append(results, relPath)
}
}
}
return nil
})
if err != nil {
return fmt.Sprintf("error walking directory: %v", err), true
}
if len(results) == 0 {
return fmt.Sprintf("No files/directories found matching: %s", name), false
}
return strings.Join(results, "\n"), false
}
func (a *Agent) getArgString(args map[string]any, key, defaultValue string) string {
if v, ok := args[key].(string); ok && v != "" {
return v
}
return defaultValue
}
func (a *Agent) getArgInt(args map[string]any, key string, defaultValue int) int {
if v, ok := args[key]; ok {
switch n := v.(type) {
case float64:
return int(n)
case int:
return n
case string:
if n == "" {
return defaultValue
}
if i, err := strconv.Atoi(n); err == nil {
return i
}
}
}
return defaultValue
}
func (a *Agent) resolvePath(path string) string {
if filepath.IsAbs(path) {
return path
}
return filepath.Join(a.workDir, path)
}
func shouldSkipDir(name string) bool {
switch name {
case "node_modules", ".git", "__pycache__", ".venv", "venv",
"dist", "build", "target", ".cache", ".npm",
".svn", "CVS", ".hg", ".bzr":
return true
}
return strings.HasPrefix(name, ".")
}
func (a *Agent) handleDiff(args map[string]any) (string, bool) {
path, _ := args["path"].(string)
newContent, _ := args["new_content"].(string)
if path == "" {
return "error: path is required", true
}
if newContent == "" {
return "error: new_content is required", true
}
path = a.resolvePath(path)
oldContent, err := os.ReadFile(path)
if err != nil {
return fmt.Sprintf("error reading file: %v", err), true
}
oldLines := strings.Split(string(oldContent), "\n")
newLines := strings.Split(newContent, "\n")
diff := computeDiff(oldLines, newLines)
if diff == "" {
return "No changes (files are identical)", false
}
return diff, false
}
func computeDiff(oldLines, newLines []string) string {
var result strings.Builder
oldLen := len(oldLines)
newLen := len(newLines)
lcs := longestCommonSubsequence(oldLines, newLines)
oldIdx := 0
newIdx := 0
lcsIdx := 0
for oldIdx < oldLen || newIdx < newLen {
if lcsIdx < len(lcs) {
for oldIdx < oldLen && oldLines[oldIdx] != lcs[lcsIdx] {
result.WriteString(fmt.Sprintf("-%s\n", oldLines[oldIdx]))
oldIdx++
}
for newIdx < newLen && newLines[newIdx] != lcs[lcsIdx] {
result.WriteString(fmt.Sprintf("+%s\n", newLines[newIdx]))
newIdx++
}
if oldIdx < oldLen && newIdx < newLen {
result.WriteString(fmt.Sprintf(" %s\n", lcs[lcsIdx]))
oldIdx++
newIdx++
lcsIdx++
}
} else {
for oldIdx < oldLen {
result.WriteString(fmt.Sprintf("-%s\n", oldLines[oldIdx]))
oldIdx++
}
for newIdx < newLen {
result.WriteString(fmt.Sprintf("+%s\n", newLines[newIdx]))
newIdx++
}
}
}
return result.String()
}
func longestCommonSubsequence(a, b []string) []string {
m, n := len(a), len(b)
dp := make([][]int, m+1)
for i := range dp {
dp[i] = make([]int, n+1)
}
for i := 1; i <= m; i++ {
for j := 1; j <= n; j++ {
if a[i-1] == b[j-1] {
dp[i][j] = dp[i-1][j-1] + 1
} else {
if dp[i-1][j] > dp[i][j-1] {
dp[i][j] = dp[i-1][j]
} else {
dp[i][j] = dp[i][j-1]
}
}
}
}
var lcs []string
i, j := m, n
for i > 0 && j > 0 {
if a[i-1] == b[j-1] {
lcs = append([]string{a[i-1]}, lcs...)
i--
j--
} else if dp[i-1][j] > dp[i][j-1] {
i--
} else {
j--
}
}
return lcs
}
func (a *Agent) handleEdit(args map[string]any) (string, bool) {
path, _ := args["path"].(string)
patch, _ := args["patch"].(string)
if path == "" {
return "error: path is required", true
}
if patch == "" {
return "error: patch is required", true
}
path = a.resolvePath(path)
oldContent, err := os.ReadFile(path)
if err != nil {
return fmt.Sprintf("error reading file: %v", err), true
}
newContent, err := applyPatch(string(oldContent), patch)
if err != nil {
return fmt.Sprintf("error applying patch: %v", err), true
}
if err := os.WriteFile(path, []byte(newContent), 0644); err != nil {
return fmt.Sprintf("error writing file: %v", err), true
}
return fmt.Sprintf("Applied patch to %s (%d bytes)", path, len(newContent)), false
}
func (a *Agent) handleMkdir(args map[string]any) (string, bool) {
path, _ := args["path"].(string)
if path == "" {
return "error: path is required", true
}
path = a.resolvePath(path)
if err := os.MkdirAll(path, 0755); err != nil {
return fmt.Sprintf("error creating directory: %v", err), true
}
return fmt.Sprintf("Created directory: %s", path), false
}
func (a *Agent) handleRemove(args map[string]any) (string, bool) {
path, _ := args["path"].(string)
if path == "" {
return "error: path is required", true
}
path = a.resolvePath(path)
recursive := a.getArgBool(args, "recursive", false)
force := a.getArgBool(args, "force", false)
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
if force {
return "Removed (ignored nonexistent)", false
}
return fmt.Sprintf("error: path does not exist: %s", path), true
}
return fmt.Sprintf("error: %v", err), true
}
if info.IsDir() {
if recursive {
err = os.RemoveAll(path)
} else {
err = os.Remove(path)
}
} else {
err = os.Remove(path)
}
if err != nil {
return fmt.Sprintf("error removing: %v", err), true
}
return fmt.Sprintf("Removed: %s", path), false
}
func (a *Agent) handleCopy(args map[string]any) (string, bool) {
source, _ := args["source"].(string)
destination, _ := args["destination"].(string)
if source == "" || destination == "" {
return "error: source and destination are required", true
}
source = a.resolvePath(source)
destination = a.resolvePath(destination)
info, err := os.Stat(source)
if err != nil {
return fmt.Sprintf("error: %v", err), true
}
if info.IsDir() {
return "error: copying directories not supported (use bash with cp -r)", true
}
srcData, err := os.ReadFile(source)
if err != nil {
return fmt.Sprintf("error reading source: %v", err), true
}
dir := filepath.Dir(destination)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Sprintf("error creating destination directory: %v", err), true
}
err = os.WriteFile(destination, srcData, info.Mode())
if err != nil {
return fmt.Sprintf("error writing destination: %v", err), true
}
return fmt.Sprintf("Copied: %s -> %s", source, destination), false
}
func (a *Agent) handleMove(args map[string]any) (string, bool) {
source, _ := args["source"].(string)
destination, _ := args["destination"].(string)
if source == "" || destination == "" {
return "error: source and destination are required", true
}
source = a.resolvePath(source)
destination = a.resolvePath(destination)
dir := filepath.Dir(destination)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Sprintf("error creating destination directory: %v", err), true
}
err := os.Rename(source, destination)
if err != nil {
return fmt.Sprintf("error moving: %v", err), true
}
return fmt.Sprintf("Moved: %s -> %s", source, destination), false
}
func (a *Agent) handleExists(args map[string]any) (string, bool) {
path, _ := args["path"].(string)
if path == "" {
return "error: path is required", true
}
path = a.resolvePath(path)
info, err := os.Stat(path)
if os.IsNotExist(err) {
return fmt.Sprintf("false: %s does not exist", path), false
}
if err != nil {
return fmt.Sprintf("error: %v", err), true
}
if info.IsDir() {
return fmt.Sprintf("true: %s (directory)", path), false
}
return fmt.Sprintf("true: %s (file, %d bytes)", path, info.Size()), false
}
func (a *Agent) getArgBool(args map[string]any, key string, defaultValue bool) bool {
if v, ok := args[key]; ok {
if b, ok := v.(bool); ok {
return b
}
}
return defaultValue
}
func applyPatch(content, patch string) (string, error) {
lines := strings.Split(content, "\n")
patchLines := strings.Split(patch, "\n")
var result []string
i := 0
for i < len(patchLines) {
line := patchLines[i]
if strings.HasPrefix(line, "@@") {
parts := strings.Fields(line)
if len(parts) < 4 {
return "", fmt.Errorf("invalid hunk header: %s", line)
}
oldSpec := strings.TrimPrefix(parts[1], "-")
oldParts := strings.Split(oldSpec, ",")
oldStart, _ := strconv.Atoi(oldParts[0])
newSpec := strings.TrimPrefix(parts[2], "+")
newParts := strings.Split(newSpec, ",")
newStart, _ := strconv.Atoi(newParts[0])
oldIdx := oldStart - 1
newIdx := newStart - 1
i++
for i < len(patchLines) && !strings.HasPrefix(patchLines[i], "@@") {
patchLine := patchLines[i]
if strings.HasPrefix(patchLine, "-") {
if oldIdx < len(lines) {
_ = lines[oldIdx]
oldIdx++
}
} else if strings.HasPrefix(patchLine, "+") {
content := strings.TrimPrefix(patchLine, "+")
result = append(result, content)
newIdx++
} else if strings.HasPrefix(patchLine, " ") || patchLine == "" {
if oldIdx < len(lines) {
result = append(result, lines[oldIdx])
oldIdx++
}
} else {
result = append(result, patchLine)
}
i++
}
continue
}
i++
}
if len(result) == 0 {
return content, nil
}
return strings.Join(result, "\n"), nil
}

View File

@ -0,0 +1,387 @@
package command
import (
"fmt"
"os"
"strings"
)
const maxContextFileSize = 32 * 1024 // 32KB
// RegisterBuiltins adds all built-in slash commands to the registry.
func RegisterBuiltins(r *Registry) {
r.Register(&Command{
Name: "help",
Aliases: []string{"h", "?"},
Description: "Show help overlay with shortcuts and commands",
Handler: func(_ *Context, _ []string) Result {
return Result{Action: ActionShowHelp}
},
})
r.Register(&Command{
Name: "clear",
Description: "Clear conversation history",
Handler: func(_ *Context, _ []string) Result {
return Result{
Text: "Conversation cleared.",
Action: ActionClear,
}
},
})
r.Register(&Command{
Name: "new",
Description: "Start a fresh conversation",
Handler: func(_ *Context, _ []string) Result {
return Result{
Text: "New conversation started.",
Action: ActionClear,
}
},
})
r.Register(&Command{
Name: "model",
Aliases: []string{"m"},
Description: "Show, switch, or list models",
Usage: "/model [name|list|fast|smart]",
Handler: func(ctx *Context, args []string) Result {
if len(args) == 0 {
return Result{Action: ActionShowModelPicker}
}
switch args[0] {
case "list", "ls":
var b strings.Builder
b.WriteString("Available models:\n")
for _, m := range ctx.ModelList {
marker := " "
if m == ctx.Model {
marker = "* "
}
fmt.Fprintf(&b, " %s%s\n", marker, m)
}
b.WriteString("\n* = current")
return Result{Text: b.String()}
case "fast":
if len(ctx.ModelList) > 0 {
return Result{
Text: fmt.Sprintf("Switching to fastest model: %s", ctx.ModelList[0]),
Action: ActionSwitchModel,
Data: ctx.ModelList[0],
}
}
return Result{Error: "No models available"}
case "smart":
if len(ctx.ModelList) > 0 {
smartModel := ctx.ModelList[len(ctx.ModelList)-1]
return Result{
Text: fmt.Sprintf("Switching to smartest model: %s", smartModel),
Action: ActionSwitchModel,
Data: smartModel,
}
}
return Result{Error: "No models available"}
default:
for _, m := range ctx.ModelList {
if m == args[0] {
return Result{
Text: fmt.Sprintf("Switching to model: %s", m),
Action: ActionSwitchModel,
Data: m,
}
}
}
return Result{Error: fmt.Sprintf("Unknown model: %s (use /model list to see available)", args[0])}
}
},
})
r.Register(&Command{
Name: "models",
Aliases: []string{"ml"},
Description: "Open model picker",
Handler: func(_ *Context, _ []string) Result {
return Result{Action: ActionShowModelPicker}
},
})
r.Register(&Command{
Name: "agent",
Aliases: []string{"a"},
Description: "Show or switch agent profile",
Usage: "/agent [name|list]",
Handler: func(ctx *Context, args []string) Result {
if len(args) == 0 || args[0] == "list" {
var b strings.Builder
if len(ctx.AgentList) == 0 {
b.WriteString("No agent profiles found in ~/.agents/agents/")
return Result{Text: b.String()}
}
b.WriteString("Available agent profiles:\n")
for _, a := range ctx.AgentList {
marker := " "
if a == ctx.AgentProfile {
marker = "* "
}
fmt.Fprintf(&b, " %s%s\n", marker, a)
}
b.WriteString("\n* = current")
return Result{Text: b.String()}
}
for _, a := range ctx.AgentList {
if a == args[0] {
return Result{
Text: fmt.Sprintf("Switching to agent: %s", a),
Action: ActionSwitchAgent,
Data: a,
}
}
}
return Result{Error: fmt.Sprintf("Unknown agent: %s (use /agent list to see available)", args[0])}
},
})
r.Register(&Command{
Name: "load",
Aliases: []string{"l"},
Description: "Load a markdown file as context",
Usage: "/load <path>",
Handler: func(_ *Context, args []string) Result {
if len(args) == 0 {
return Result{Error: "Usage: /load <path>"}
}
path := strings.Join(args, " ")
// Expand ~ to home directory.
if strings.HasPrefix(path, "~/") {
if home, err := os.UserHomeDir(); err == nil {
path = home + path[1:]
}
}
info, err := os.Stat(path)
if err != nil {
return Result{Error: fmt.Sprintf("Cannot access %s: %v", path, err)}
}
if info.Size() > maxContextFileSize {
return Result{Error: fmt.Sprintf("File too large (%d bytes, max %d)", info.Size(), maxContextFileSize)}
}
data, err := os.ReadFile(path)
if err != nil {
return Result{Error: fmt.Sprintf("Cannot read %s: %v", path, err)}
}
return Result{
Text: fmt.Sprintf("Loaded context: %s (%d bytes)", path, len(data)),
Action: ActionLoadContext,
Data: path + "\x00" + string(data), // path\0content
}
},
})
r.Register(&Command{
Name: "unload",
Description: "Remove loaded context file",
Handler: func(ctx *Context, _ []string) Result {
if ctx.LoadedFile == "" {
return Result{Text: "No context file loaded."}
}
return Result{
Text: "Context unloaded.",
Action: ActionUnloadContext,
}
},
})
r.Register(&Command{
Name: "skill",
Aliases: []string{"sk"},
Description: "Manage skills (list, activate, deactivate)",
Usage: "/skill [list|activate|deactivate] [name]",
Handler: func(ctx *Context, args []string) Result {
if len(args) == 0 || args[0] == "list" {
return skillList(ctx)
}
if len(args) < 2 {
return Result{Error: "Usage: /skill [list|activate|deactivate] <name>"}
}
switch args[0] {
case "activate", "on":
return Result{
Text: fmt.Sprintf("Activated skill: %s", args[1]),
Action: ActionActivateSkill,
Data: args[1],
}
case "deactivate", "off":
return Result{
Text: fmt.Sprintf("Deactivated skill: %s", args[1]),
Action: ActionDeactivateSkill,
Data: args[1],
}
default:
return Result{Error: fmt.Sprintf("Unknown skill action: %s (use list, activate, or deactivate)", args[0])}
}
},
})
r.Register(&Command{
Name: "servers",
Description: "List connected MCP servers",
Handler: func(ctx *Context, _ []string) Result {
if len(ctx.ServerNames) == 0 {
return Result{Text: "No MCP servers connected."}
}
var b strings.Builder
b.WriteString(fmt.Sprintf("Connected servers (%d):\n", len(ctx.ServerNames)))
for _, name := range ctx.ServerNames {
fmt.Fprintf(&b, " - %s\n", name)
}
b.WriteString(fmt.Sprintf("\nTotal tools: %d", ctx.ToolCount))
return Result{Text: b.String()}
},
})
r.Register(&Command{
Name: "ice",
Description: "Show Infinite Context Engine status",
Handler: func(ctx *Context, _ []string) Result {
if !ctx.ICEEnabled {
return Result{Text: "ICE is not enabled. Add `ice: {enabled: true}` to your config.yaml"}
}
var b strings.Builder
b.WriteString("Infinite Context Engine (ICE)\n")
fmt.Fprintf(&b, " Status: enabled\n")
fmt.Fprintf(&b, " Conversations: %d stored\n", ctx.ICEConversations)
fmt.Fprintf(&b, " Session ID: %s\n", ctx.ICESessionID)
fmt.Fprintf(&b, " Embed model: nomic-embed-text\n")
return Result{Text: b.String()}
},
})
r.Register(&Command{
Name: "sessions",
Aliases: []string{"ss"},
Description: "Browse and restore saved sessions",
Handler: func(_ *Context, _ []string) Result {
return Result{Action: ActionShowSessions}
},
})
r.Register(&Command{
Name: "changes",
Description: "List files modified by the agent this session",
Handler: func(ctx *Context, _ []string) Result {
if len(ctx.FileChanges) == 0 {
return Result{Text: "No files modified this session."}
}
var b strings.Builder
fmt.Fprintf(&b, "Files modified (%d):\n", len(ctx.FileChanges))
for path, count := range ctx.FileChanges {
if count > 1 {
fmt.Fprintf(&b, " ✎ %s (%dx)\n", path, count)
} else {
fmt.Fprintf(&b, " ✎ %s\n", path)
}
}
return Result{Text: b.String()}
},
})
r.Register(&Command{
Name: "commit",
Aliases: []string{"ci"},
Description: "Generate commit message from staged changes and commit",
Handler: func(_ *Context, args []string) Result {
return Result{Action: ActionCommit, Data: strings.Join(args, " ")}
},
})
r.Register(&Command{
Name: "stats",
Description: "Show token usage statistics for this session",
Handler: func(ctx *Context, _ []string) Result {
if ctx.SessionTurnCount == 0 {
return Result{Text: "No token usage recorded yet."}
}
var b strings.Builder
b.WriteString("Session Token Stats\n")
fmt.Fprintf(&b, " Model: %s\n", ctx.CurrentModel)
fmt.Fprintf(&b, " Turns: %d\n", ctx.SessionTurnCount)
fmt.Fprintf(&b, " Output tokens: %d\n", ctx.SessionEvalTotal)
fmt.Fprintf(&b, " Prompt tokens: %d (last turn)\n", ctx.SessionPromptTotal)
if ctx.NumCtx > 0 {
fmt.Fprintf(&b, " Context window: %d\n", ctx.NumCtx)
pct := ctx.SessionPromptTotal * 100 / ctx.NumCtx
fmt.Fprintf(&b, " Context used: %d%%\n", pct)
}
avgOut := ctx.SessionEvalTotal / ctx.SessionTurnCount
fmt.Fprintf(&b, " Avg out/turn: %d\n", avgOut)
return Result{Text: b.String()}
},
})
r.Register(&Command{
Name: "export",
Description: "Export conversation to a markdown file",
Usage: "/export [path]",
Handler: func(_ *Context, args []string) Result {
if len(args) < 1 || args[0] == "" {
return Result{Error: "usage: /export <filepath>"}
}
return Result{
Text: fmt.Sprintf("Exporting conversation to: %s", args[0]),
Action: ActionExport,
Data: args[0],
}
},
})
r.Register(&Command{
Name: "import",
Description: "Import conversation from a markdown file",
Usage: "/import [path]",
Handler: func(_ *Context, args []string) Result {
if len(args) < 1 || args[0] == "" {
return Result{Error: "usage: /import <filepath>"}
}
return Result{
Text: fmt.Sprintf("Importing conversation from: %s", args[0]),
Action: ActionImport,
Data: args[0],
}
},
})
r.Register(&Command{
Name: "exit",
Aliases: []string{"quit", "q"},
Description: "Quit ai-agent",
Handler: func(_ *Context, _ []string) Result {
return Result{Action: ActionQuit}
},
})
}
func skillList(ctx *Context) Result {
if len(ctx.Skills) == 0 {
return Result{Text: "No skills found. Add .md files to ~/.config/ai-agent/skills/"}
}
var b strings.Builder
b.WriteString(fmt.Sprintf("Skills (%d):\n", len(ctx.Skills)))
for _, s := range ctx.Skills {
status := " "
if s.Active {
status = "* "
}
fmt.Fprintf(&b, " %s%s — %s\n", status, s.Name, s.Description)
}
b.WriteString("\n* = active")
return Result{Text: b.String()}
}

View File

@ -0,0 +1,380 @@
package command
import (
"os"
"path/filepath"
"strings"
"testing"
)
func newTestRegistry() *Registry {
r := NewRegistry()
RegisterBuiltins(r)
return r
}
func TestBuiltin_Help(t *testing.T) {
r := newTestRegistry()
result := r.Execute(&Context{}, "help", nil)
if result.Action != ActionShowHelp {
t.Errorf("help action = %d, want %d (ActionShowHelp)", result.Action, ActionShowHelp)
}
}
func TestBuiltin_Clear(t *testing.T) {
r := newTestRegistry()
result := r.Execute(&Context{}, "clear", nil)
if result.Action != ActionClear {
t.Errorf("clear action = %d, want %d (ActionClear)", result.Action, ActionClear)
}
if result.Text == "" {
t.Error("clear should have text")
}
}
func TestBuiltin_New(t *testing.T) {
r := newTestRegistry()
result := r.Execute(&Context{}, "new", nil)
if result.Action != ActionClear {
t.Errorf("new action = %d, want %d (ActionClear)", result.Action, ActionClear)
}
if result.Text == "" {
t.Error("new should have text")
}
}
func TestBuiltin_Model(t *testing.T) {
r := newTestRegistry()
ctx := &Context{
Model: "qwen3.5:0.8b",
ModelList: []string{"qwen3.5:0.8b", "qwen3.5:2b", "qwen3.5:4b", "qwen3.5:9b"},
}
tests := []struct {
name string
args []string
wantAction Action
wantData string
wantErr bool
checkText string
}{
{
name: "no args opens model picker",
args: nil,
wantAction: ActionShowModelPicker,
},
{
name: "list shows models",
args: []string{"list"},
checkText: "Available models",
},
{
name: "fast switches to first",
args: []string{"fast"},
wantAction: ActionSwitchModel,
wantData: "qwen3.5:0.8b",
},
{
name: "smart switches to last",
args: []string{"smart"},
wantAction: ActionSwitchModel,
wantData: "qwen3.5:9b",
},
{
name: "valid name switches",
args: []string{"qwen3.5:2b"},
wantAction: ActionSwitchModel,
wantData: "qwen3.5:2b",
},
{
name: "invalid name errors",
args: []string{"nonexistent"},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := r.Execute(ctx, "model", tt.args)
if tt.wantErr {
if result.Error == "" {
t.Error("expected error")
}
return
}
if result.Error != "" {
t.Errorf("unexpected error: %s", result.Error)
return
}
if tt.wantAction != ActionNone && result.Action != tt.wantAction {
t.Errorf("action = %d, want %d", result.Action, tt.wantAction)
}
if tt.wantData != "" && result.Data != tt.wantData {
t.Errorf("data = %q, want %q", result.Data, tt.wantData)
}
if tt.checkText != "" && !strings.Contains(result.Text, tt.checkText) {
t.Errorf("text %q does not contain %q", result.Text, tt.checkText)
}
})
}
}
func TestBuiltin_Models(t *testing.T) {
r := newTestRegistry()
ctx := &Context{
Model: "qwen3.5:0.8b",
ModelList: []string{"qwen3.5:0.8b", "qwen3.5:2b"},
}
result := r.Execute(ctx, "models", nil)
if result.Action != ActionShowModelPicker {
t.Errorf("expected ActionShowModelPicker, got %d", result.Action)
}
}
func TestBuiltin_Agent(t *testing.T) {
r := newTestRegistry()
t.Run("no args lists agents", func(t *testing.T) {
ctx := &Context{AgentList: []string{"coder", "reviewer"}, AgentProfile: "coder"}
result := r.Execute(ctx, "agent", nil)
if !strings.Contains(result.Text, "Available agent profiles") {
t.Errorf("expected agent list, got %q", result.Text)
}
})
t.Run("list subcommand", func(t *testing.T) {
ctx := &Context{AgentList: []string{"coder"}}
result := r.Execute(ctx, "agent", []string{"list"})
if !strings.Contains(result.Text, "Available agent profiles") {
t.Errorf("expected agent list, got %q", result.Text)
}
})
t.Run("valid switch", func(t *testing.T) {
ctx := &Context{AgentList: []string{"coder", "reviewer"}}
result := r.Execute(ctx, "agent", []string{"reviewer"})
if result.Action != ActionSwitchAgent {
t.Errorf("action = %d, want %d", result.Action, ActionSwitchAgent)
}
if result.Data != "reviewer" {
t.Errorf("data = %q, want %q", result.Data, "reviewer")
}
})
t.Run("invalid errors", func(t *testing.T) {
ctx := &Context{AgentList: []string{"coder"}}
result := r.Execute(ctx, "agent", []string{"unknown"})
if result.Error == "" {
t.Error("expected error for unknown agent")
}
})
t.Run("no agents", func(t *testing.T) {
ctx := &Context{AgentList: []string{}}
result := r.Execute(ctx, "agent", nil)
if !strings.Contains(result.Text, "No agent profiles") {
t.Errorf("expected no agents message, got %q", result.Text)
}
})
}
func TestBuiltin_Load(t *testing.T) {
r := newTestRegistry()
t.Run("no args errors", func(t *testing.T) {
result := r.Execute(&Context{}, "load", nil)
if result.Error == "" {
t.Error("expected error for no args")
}
})
t.Run("valid file loads", func(t *testing.T) {
tmp := t.TempDir()
path := filepath.Join(tmp, "test.md")
if err := os.WriteFile(path, []byte("# Hello"), 0644); err != nil {
t.Fatal(err)
}
result := r.Execute(&Context{}, "load", []string{path})
if result.Error != "" {
t.Errorf("unexpected error: %s", result.Error)
}
if result.Action != ActionLoadContext {
t.Errorf("action = %d, want %d", result.Action, ActionLoadContext)
}
// Data should be path\0content
parts := strings.SplitN(result.Data, "\x00", 2)
if len(parts) != 2 {
t.Fatalf("expected path\\0content, got %q", result.Data)
}
if parts[0] != path {
t.Errorf("data path = %q, want %q", parts[0], path)
}
if parts[1] != "# Hello" {
t.Errorf("data content = %q, want %q", parts[1], "# Hello")
}
})
t.Run("too large errors", func(t *testing.T) {
tmp := t.TempDir()
path := filepath.Join(tmp, "big.md")
data := make([]byte, 33*1024) // > 32KB
if err := os.WriteFile(path, data, 0644); err != nil {
t.Fatal(err)
}
result := r.Execute(&Context{}, "load", []string{path})
if result.Error == "" {
t.Error("expected error for oversized file")
}
if !strings.Contains(result.Error, "too large") {
t.Errorf("error = %q, want containing 'too large'", result.Error)
}
})
t.Run("nonexistent errors", func(t *testing.T) {
result := r.Execute(&Context{}, "load", []string{"/nonexistent/file.md"})
if result.Error == "" {
t.Error("expected error for nonexistent file")
}
})
}
func TestBuiltin_Unload(t *testing.T) {
r := newTestRegistry()
t.Run("no loaded file", func(t *testing.T) {
result := r.Execute(&Context{LoadedFile: ""}, "unload", nil)
if !strings.Contains(result.Text, "No context") {
t.Errorf("expected 'No context' message, got %q", result.Text)
}
})
t.Run("loaded file unloads", func(t *testing.T) {
result := r.Execute(&Context{LoadedFile: "something.md"}, "unload", nil)
if result.Action != ActionUnloadContext {
t.Errorf("action = %d, want %d", result.Action, ActionUnloadContext)
}
})
}
func TestBuiltin_Skill(t *testing.T) {
r := newTestRegistry()
ctx := &Context{
Skills: []SkillInfo{
{Name: "coder", Description: "Code generation", Active: true},
{Name: "reviewer", Description: "Code review", Active: false},
},
}
t.Run("no args lists skills", func(t *testing.T) {
result := r.Execute(ctx, "skill", nil)
if !strings.Contains(result.Text, "Skills") {
t.Errorf("expected skills list, got %q", result.Text)
}
})
t.Run("list subcommand", func(t *testing.T) {
result := r.Execute(ctx, "skill", []string{"list"})
if !strings.Contains(result.Text, "Skills") {
t.Errorf("expected skills list, got %q", result.Text)
}
})
t.Run("activate", func(t *testing.T) {
result := r.Execute(ctx, "skill", []string{"activate", "reviewer"})
if result.Action != ActionActivateSkill {
t.Errorf("action = %d, want %d", result.Action, ActionActivateSkill)
}
if result.Data != "reviewer" {
t.Errorf("data = %q, want %q", result.Data, "reviewer")
}
})
t.Run("deactivate", func(t *testing.T) {
result := r.Execute(ctx, "skill", []string{"deactivate", "coder"})
if result.Action != ActionDeactivateSkill {
t.Errorf("action = %d, want %d", result.Action, ActionDeactivateSkill)
}
if result.Data != "coder" {
t.Errorf("data = %q, want %q", result.Data, "coder")
}
})
t.Run("unknown action errors", func(t *testing.T) {
result := r.Execute(ctx, "skill", []string{"unknown", "foo"})
if result.Error == "" {
t.Error("expected error for unknown skill action")
}
})
t.Run("missing name errors", func(t *testing.T) {
result := r.Execute(ctx, "skill", []string{"activate"})
if result.Error == "" {
t.Error("expected error for missing skill name")
}
})
}
func TestBuiltin_Servers(t *testing.T) {
r := newTestRegistry()
t.Run("no servers", func(t *testing.T) {
result := r.Execute(&Context{ServerNames: nil}, "servers", nil)
if !strings.Contains(result.Text, "No MCP servers") {
t.Errorf("expected no servers message, got %q", result.Text)
}
})
t.Run("with servers", func(t *testing.T) {
ctx := &Context{
ServerNames: []string{"server-a", "server-b"},
ToolCount: 10,
}
result := r.Execute(ctx, "servers", nil)
if !strings.Contains(result.Text, "server-a") {
t.Errorf("expected server-a in output, got %q", result.Text)
}
if !strings.Contains(result.Text, "server-b") {
t.Errorf("expected server-b in output, got %q", result.Text)
}
if !strings.Contains(result.Text, "10") {
t.Errorf("expected tool count in output, got %q", result.Text)
}
})
}
func TestBuiltin_ICE(t *testing.T) {
r := newTestRegistry()
t.Run("disabled", func(t *testing.T) {
result := r.Execute(&Context{ICEEnabled: false}, "ice", nil)
if !strings.Contains(result.Text, "not enabled") {
t.Errorf("expected disabled message, got %q", result.Text)
}
})
t.Run("enabled shows status", func(t *testing.T) {
ctx := &Context{
ICEEnabled: true,
ICEConversations: 5,
ICESessionID: "abc-123",
}
result := r.Execute(ctx, "ice", nil)
if !strings.Contains(result.Text, "enabled") {
t.Errorf("expected enabled status, got %q", result.Text)
}
if !strings.Contains(result.Text, "5") {
t.Errorf("expected conversation count, got %q", result.Text)
}
if !strings.Contains(result.Text, "abc-123") {
t.Errorf("expected session ID, got %q", result.Text)
}
})
}
func TestBuiltin_Exit(t *testing.T) {
r := newTestRegistry()
result := r.Execute(&Context{}, "exit", nil)
if result.Action != ActionQuit {
t.Errorf("exit action = %d, want %d (ActionQuit)", result.Action, ActionQuit)
}
}

116
internal/command/custom.go Normal file
View File

@ -0,0 +1,116 @@
package command
import (
"os"
"path/filepath"
"strings"
)
// CustomCommand represents a user-defined command loaded from a markdown file.
type CustomCommand struct {
Name string
Description string
Template string // prompt template with {{input}} placeholder
}
// LoadCustomCommands reads .md files from the commands directory and returns
// parsed custom commands. Each file should have YAML-like frontmatter:
//
// ---
// name: review
// description: Code review prompt
// ---
// Review this code: {{input}}
func LoadCustomCommands(dir string) []CustomCommand {
entries, err := os.ReadDir(dir)
if err != nil {
return nil
}
var cmds []CustomCommand
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") {
continue
}
data, err := os.ReadFile(filepath.Join(dir, entry.Name()))
if err != nil {
continue
}
if cmd, ok := parseCustomCommand(string(data)); ok {
cmds = append(cmds, cmd)
}
}
return cmds
}
// parseCustomCommand parses a markdown file with YAML frontmatter.
func parseCustomCommand(content string) (CustomCommand, bool) {
content = strings.TrimSpace(content)
if !strings.HasPrefix(content, "---") {
return CustomCommand{}, false
}
// Find end of frontmatter.
rest := content[3:]
idx := strings.Index(rest, "---")
if idx < 0 {
return CustomCommand{}, false
}
frontmatter := rest[:idx]
body := strings.TrimSpace(rest[idx+3:])
cmd := CustomCommand{Template: body}
// Parse simple key: value pairs from frontmatter.
for _, line := range strings.Split(frontmatter, "\n") {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
continue
}
key := strings.TrimSpace(parts[0])
val := strings.TrimSpace(parts[1])
switch key {
case "name":
cmd.Name = val
case "description":
cmd.Description = val
}
}
if cmd.Name == "" || cmd.Template == "" {
return CustomCommand{}, false
}
return cmd, true
}
// RegisterCustomCommands loads and registers custom commands from the given directory.
func RegisterCustomCommands(r *Registry, dir string) {
cmds := LoadCustomCommands(dir)
for _, cc := range cmds {
// Capture for closure.
tmpl := cc.Template
desc := cc.Description
if desc == "" {
desc = "Custom command"
}
r.Register(&Command{
Name: cc.Name,
Description: desc,
Handler: func(_ *Context, args []string) Result {
input := strings.Join(args, " ")
prompt := strings.ReplaceAll(tmpl, "{{input}}", input)
return Result{
Action: ActionSendPrompt,
Data: prompt,
}
},
})
}
}

View File

@ -0,0 +1,148 @@
package command
import (
"os"
"path/filepath"
"testing"
)
func TestParseCustomCommand(t *testing.T) {
tests := []struct {
name string
content string
wantOK bool
wantCmd CustomCommand
}{
{
name: "valid command",
content: `---
name: review
description: Code review prompt
---
Review this code: {{input}}`,
wantOK: true,
wantCmd: CustomCommand{
Name: "review",
Description: "Code review prompt",
Template: "Review this code: {{input}}",
},
},
{
name: "no description",
content: `---
name: explain
---
Explain this: {{input}}`,
wantOK: true,
wantCmd: CustomCommand{
Name: "explain",
Template: "Explain this: {{input}}",
},
},
{
name: "no frontmatter",
content: "just some text",
wantOK: false,
},
{
name: "no name",
content: `---
description: something
---
body`,
wantOK: false,
},
{
name: "no body",
content: `---
name: empty
---`,
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd, ok := parseCustomCommand(tt.content)
if ok != tt.wantOK {
t.Fatalf("parseCustomCommand() ok = %v, want %v", ok, tt.wantOK)
}
if !ok {
return
}
if cmd.Name != tt.wantCmd.Name {
t.Errorf("Name = %q, want %q", cmd.Name, tt.wantCmd.Name)
}
if cmd.Description != tt.wantCmd.Description {
t.Errorf("Description = %q, want %q", cmd.Description, tt.wantCmd.Description)
}
if cmd.Template != tt.wantCmd.Template {
t.Errorf("Template = %q, want %q", cmd.Template, tt.wantCmd.Template)
}
})
}
}
func TestLoadCustomCommands(t *testing.T) {
dir := t.TempDir()
// Write a valid command file.
err := os.WriteFile(filepath.Join(dir, "review.md"), []byte(`---
name: review
description: Review code
---
Review: {{input}}`), 0o644)
if err != nil {
t.Fatal(err)
}
// Write an invalid file (no frontmatter).
err = os.WriteFile(filepath.Join(dir, "invalid.md"), []byte("just text"), 0o644)
if err != nil {
t.Fatal(err)
}
// Write a non-md file (should be ignored).
err = os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("not a command"), 0o644)
if err != nil {
t.Fatal(err)
}
cmds := LoadCustomCommands(dir)
if len(cmds) != 1 {
t.Fatalf("LoadCustomCommands() returned %d commands, want 1", len(cmds))
}
if cmds[0].Name != "review" {
t.Errorf("Name = %q, want %q", cmds[0].Name, "review")
}
}
func TestLoadCustomCommands_MissingDir(t *testing.T) {
cmds := LoadCustomCommands("/nonexistent/path")
if len(cmds) != 0 {
t.Errorf("expected empty result for missing dir, got %d", len(cmds))
}
}
func TestRegisterCustomCommands(t *testing.T) {
dir := t.TempDir()
err := os.WriteFile(filepath.Join(dir, "test.md"), []byte(`---
name: testcmd
description: A test command
---
Do this: {{input}}`), 0o644)
if err != nil {
t.Fatal(err)
}
reg := NewRegistry()
RegisterCustomCommands(reg, dir)
result := reg.Execute(&Context{}, "testcmd", []string{"hello", "world"})
if result.Action != ActionSendPrompt {
t.Errorf("Action = %v, want ActionSendPrompt", result.Action)
}
if result.Data != "Do this: hello world" {
t.Errorf("Data = %q, want %q", result.Data, "Do this: hello world")
}
}

View File

@ -0,0 +1,129 @@
package command
import (
"fmt"
"sort"
"strings"
)
// Command represents a slash command.
type Command struct {
Name string
Aliases []string
Description string
Usage string
Handler func(ctx *Context, args []string) Result
}
// Context provides commands with read access to application state.
type Context struct {
Model string
ModelList []string
AgentProfile string
AgentList []string
ToolCount int
ServerCount int
ServerNames []string
Skills []SkillInfo
LoadedFile string
ICEEnabled bool
ICEConversations int
ICESessionID string
// Token stats
SessionEvalTotal int
SessionPromptTotal int
SessionTurnCount int
NumCtx int
CurrentModel string
// File changes
FileChanges map[string]int // path → modification count
}
// SkillInfo is a read-only view of a skill for command display.
type SkillInfo struct {
Name string
Description string
Active bool
}
// Result is returned by command handlers to describe what to do.
type Result struct {
Text string // Display text (shown as system message)
Action Action // Side effect for the TUI to execute
Data string // Optional payload (e.g. file path, model name)
Error string // Error text (takes priority over Text)
}
// Action describes a side effect the TUI should perform.
type Action int
const (
ActionNone Action = iota
ActionShowHelp // Show help overlay
ActionClear // Clear conversation history
ActionQuit // Exit the application
ActionLoadContext // Load markdown context (Data = path)
ActionUnloadContext // Remove loaded context
ActionActivateSkill // Activate skill (Data = name)
ActionDeactivateSkill // Deactivate skill (Data = name)
ActionSwitchModel // Switch model (Data = model name)
ActionSwitchAgent // Switch agent profile (Data = agent name)
ActionShowSessions // Open sessions picker
ActionShowModelPicker // Open model picker overlay
ActionCommit // Generate commit message and commit
ActionSendPrompt // Send Data as a message to the agent
ActionExport // Export conversation (Data = path)
ActionImport // Import conversation (Data = path)
)
// Registry holds all registered slash commands.
type Registry struct {
commands map[string]*Command // name/alias → command
all []*Command // ordered list
}
// NewRegistry creates an empty command registry.
func NewRegistry() *Registry {
return &Registry{
commands: make(map[string]*Command),
}
}
// Register adds a command to the registry.
func (r *Registry) Register(cmd *Command) {
r.all = append(r.all, cmd)
r.commands[cmd.Name] = cmd
for _, alias := range cmd.Aliases {
r.commands[alias] = cmd
}
}
// Execute dispatches a slash command by name and returns the result.
func (r *Registry) Execute(ctx *Context, name string, args []string) Result {
cmd, ok := r.commands[name]
if !ok {
return Result{Error: fmt.Sprintf("unknown command: /%s — type /help for available commands", name)}
}
return cmd.Handler(ctx, args)
}
// All returns all registered commands in registration order.
func (r *Registry) All() []*Command {
return r.all
}
// Match returns commands whose name starts with the given prefix.
func (r *Registry) Match(prefix string) []*Command {
var matches []*Command
seen := make(map[string]bool)
for _, cmd := range r.all {
if strings.HasPrefix(cmd.Name, prefix) && !seen[cmd.Name] {
matches = append(matches, cmd)
seen[cmd.Name] = true
}
}
sort.Slice(matches, func(i, j int) bool {
return matches[i].Name < matches[j].Name
})
return matches
}

View File

@ -0,0 +1,164 @@
package command
import "testing"
func TestRegistry_Register(t *testing.T) {
r := NewRegistry()
cmd := &Command{
Name: "test",
Description: "A test command",
Handler: func(_ *Context, _ []string) Result {
return Result{Text: "ok"}
},
}
r.Register(cmd)
all := r.All()
if len(all) != 1 {
t.Fatalf("expected 1 command, got %d", len(all))
}
if all[0].Name != "test" {
t.Errorf("command name = %q, want %q", all[0].Name, "test")
}
// Execute to verify it was registered correctly
result := r.Execute(&Context{}, "test", nil)
if result.Text != "ok" {
t.Errorf("result text = %q, want %q", result.Text, "ok")
}
}
func TestRegistry_Execute(t *testing.T) {
r := NewRegistry()
called := false
r.Register(&Command{
Name: "run",
Handler: func(_ *Context, _ []string) Result {
called = true
return Result{Text: "executed"}
},
})
t.Run("found command executes handler", func(t *testing.T) {
result := r.Execute(&Context{}, "run", nil)
if !called {
t.Error("handler was not called")
}
if result.Text != "executed" {
t.Errorf("result text = %q, want %q", result.Text, "executed")
}
})
t.Run("not found returns error", func(t *testing.T) {
result := r.Execute(&Context{}, "nonexistent", nil)
if result.Error == "" {
t.Error("expected error for unknown command")
}
})
}
func TestRegistry_ExecuteByAlias(t *testing.T) {
r := NewRegistry()
r.Register(&Command{
Name: "mycommand",
Aliases: []string{"mc", "m"},
Handler: func(_ *Context, _ []string) Result {
return Result{Text: "alias works"}
},
})
tests := []struct {
name string
cmdName string
wantOk bool
}{
{name: "by name", cmdName: "mycommand", wantOk: true},
{name: "by alias mc", cmdName: "mc", wantOk: true},
{name: "by alias m", cmdName: "m", wantOk: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := r.Execute(&Context{}, tt.cmdName, nil)
if tt.wantOk && result.Error != "" {
t.Errorf("unexpected error: %s", result.Error)
}
if tt.wantOk && result.Text != "alias works" {
t.Errorf("result text = %q, want %q", result.Text, "alias works")
}
})
}
}
func TestRegistry_All(t *testing.T) {
r := NewRegistry()
names := []string{"alpha", "beta", "gamma"}
for _, name := range names {
n := name // capture
r.Register(&Command{
Name: n,
Handler: func(_ *Context, _ []string) Result { return Result{} },
})
}
all := r.All()
if len(all) != len(names) {
t.Fatalf("expected %d commands, got %d", len(names), len(all))
}
for i, cmd := range all {
if cmd.Name != names[i] {
t.Errorf("All()[%d].Name = %q, want %q", i, cmd.Name, names[i])
}
}
}
func TestRegistry_Match(t *testing.T) {
r := NewRegistry()
r.Register(&Command{
Name: "model",
Aliases: []string{"m"},
Handler: func(_ *Context, _ []string) Result { return Result{} },
})
r.Register(&Command{
Name: "models",
Aliases: []string{"ml"},
Handler: func(_ *Context, _ []string) Result { return Result{} },
})
r.Register(&Command{
Name: "help",
Handler: func(_ *Context, _ []string) Result { return Result{} },
})
tests := []struct {
name string
prefix string
want int
}{
{name: "prefix mo matches model and models", prefix: "mo", want: 2},
{name: "prefix model matches model and models", prefix: "model", want: 2},
{name: "prefix models matches only models", prefix: "models", want: 1},
{name: "prefix h matches help", prefix: "h", want: 1},
{name: "no match", prefix: "z", want: 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches := r.Match(tt.prefix)
if len(matches) != tt.want {
t.Errorf("Match(%q) returned %d results, want %d", tt.prefix, len(matches), tt.want)
}
})
}
// Verify no duplicates from aliases
t.Run("aliases dont create dupes", func(t *testing.T) {
matches := r.Match("model")
seen := make(map[string]bool)
for _, m := range matches {
if seen[m.Name] {
t.Errorf("duplicate match for %q", m.Name)
}
seen[m.Name] = true
}
})
}

366
internal/config/agents.go Normal file
View File

@ -0,0 +1,366 @@
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"gopkg.in/yaml.v3"
)
type AgentsDir struct {
Path string
Agents map[string]AgentProfile
MCPServers []ServerConfig
GlobalInstructions string
Skills []SkillDef
}
type AgentProfile struct {
Name string `yaml:"name" json:"name"`
Description string `yaml:"description" json:"description"`
Model string `yaml:"model" json:"model"`
Skills []string `yaml:"skills" json:"skills"`
MCPServers []string `yaml:"mcp_servers" json:"mcp_servers"`
SystemPrompt string `yaml:"system_prompt" json:"system_prompt"`
UseCases []string `yaml:"use_cases" json:"use_cases"`
}
type SkillDef struct {
Name string `yaml:"name" json:"name"`
Description string `yaml:"description" json:"description"`
Path string `yaml:"path" json:"path"`
}
type MCPConfig struct {
Servers []ServerConfig `json:"servers,omitempty"`
}
type ModelsConfig struct {
Models []Model `yaml:"models,omitempty"`
DefaultModel string `yaml:"default_model,omitempty"`
FallbackChain []string `yaml:"fallback_chain,omitempty"`
AutoSelect bool `yaml:"auto_select,omitempty"`
EmbedModel string `yaml:"embed_model,omitempty"`
}
func FindAgentsDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
candidates := []string{
filepath.Join(home, ".agents"),
filepath.Join(home, ".config", "agents"),
}
for _, dir := range candidates {
if _, err := os.Stat(dir); err == nil {
return dir
}
}
return ""
}
func FindAgentsDirWithCreate() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("get home dir: %w", err)
}
dirs := []string{
filepath.Join(home, ".agents"),
filepath.Join(home, ".config", "agents"),
}
for _, dir := range dirs {
if _, err := os.Stat(dir); err == nil {
return dir, nil
}
}
if err := os.MkdirAll(dirs[0], 0755); err != nil {
return "", fmt.Errorf("create agents dir: %w", err)
}
return dirs[0], nil
}
func LoadAgentsDir(path string) (*AgentsDir, error) {
if path == "" {
path = FindAgentsDir()
if path == "" {
return &AgentsDir{
Path: "",
Agents: make(map[string]AgentProfile),
}, nil
}
}
dir := &AgentsDir{
Path: path,
Agents: make(map[string]AgentProfile),
}
if err := dir.loadAgents(path); err != nil {
return nil, fmt.Errorf("load agents: %w", err)
}
if err := dir.loadMCP(path); err != nil {
return nil, fmt.Errorf("load MCP: %w", err)
}
if err := dir.loadGlobalInstructions(path); err != nil {
return nil, fmt.Errorf("load instructions: %w", err)
}
if err := dir.loadSkills(path); err != nil {
return nil, fmt.Errorf("load skills: %w", err)
}
return dir, nil
}
func (d *AgentsDir) loadAgents(path string) error {
agentsDir := filepath.Join(path, "agents")
entries, err := os.ReadDir(agentsDir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
for _, entry := range entries {
if entry.IsDir() {
agentPath := filepath.Join(agentsDir, entry.Name(), "agent.yaml")
if _, err := os.Stat(agentPath); err != nil {
agentPath = filepath.Join(agentsDir, entry.Name(), "agent.md")
}
if _, err := os.Stat(agentPath); err != nil {
continue
}
data, err := os.ReadFile(agentPath)
if err != nil {
continue
}
var profile AgentProfile
if err := yaml.Unmarshal(data, &profile); err != nil {
continue
}
if profile.Name == "" {
profile.Name = entry.Name()
}
d.Agents[profile.Name] = profile
}
}
return nil
}
func (d *AgentsDir) loadMCP(path string) error {
mcpPath := filepath.Join(path, "mcp.json")
data, err := os.ReadFile(mcpPath)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
var mcpCfg MCPConfig
if err := json.Unmarshal(data, &mcpCfg); err != nil {
return fmt.Errorf("parse mcp.json: %w", err)
}
d.MCPServers = mcpCfg.Servers
return nil
}
func (d *AgentsDir) loadGlobalInstructions(path string) error {
paths := []string{
filepath.Join(path, "agents.md"),
filepath.Join(path, "instructions.md"),
}
for _, p := range paths {
data, err := os.ReadFile(p)
if err == nil {
d.GlobalInstructions = string(data)
return nil
}
}
return nil
}
func (d *AgentsDir) loadSkills(path string) error {
skillsDir := filepath.Join(path, "skills")
entries, err := os.ReadDir(skillsDir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
for _, entry := range entries {
if !entry.IsDir() {
continue
}
skillDir := filepath.Join(skillsDir, entry.Name())
// Try both SKILL.md and skill.md (case insensitive check)
skillPath := ""
for _, name := range []string{"SKILL.md", "skill.md"} {
path := filepath.Join(skillDir, name)
if info, err := os.Stat(path); err == nil && !info.IsDir() {
skillPath = path
break
}
}
if skillPath == "" {
continue
}
data, err := os.ReadFile(skillPath)
if err != nil {
continue
}
d.Skills = append(d.Skills, SkillDef{
Name: entry.Name(),
Description: extractDescription(string(data)),
Path: skillPath,
})
}
return nil
}
func extractDescription(content string) string {
for _, line := range splitLines(content) {
line = trimWhitespace(line)
if line == "" || startsWith(line, "#") {
continue
}
return line
}
return ""
}
func splitLines(s string) []string {
var lines []string
start := 0
for i, r := range s {
if r == '\n' {
lines = append(lines, s[start:i])
start = i + 1
}
}
lines = append(lines, s[start:])
return lines
}
func trimWhitespace(s string) string {
start := 0
end := len(s)
for start < end && (s[start] == ' ' || s[start] == '\t') {
start++
}
for end > start && (s[end-1] == ' ' || s[end-1] == '\t') {
end--
}
return s[start:end]
}
func startsWith(s, prefix string) bool {
return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}
func (d *AgentsDir) GetAgent(name string) *AgentProfile {
if agent, ok := d.Agents[name]; ok {
return &agent
}
return nil
}
func (d *AgentsDir) ListAgents() []AgentProfile {
agents := make([]AgentProfile, 0, len(d.Agents))
for _, agent := range d.Agents {
agents = append(agents, agent)
}
return agents
}
func (d *AgentsDir) GetSkills() []SkillDef {
return d.Skills
}
func (d *AgentsDir) HasMCP() bool {
return len(d.MCPServers) > 0
}
func (d *AgentsDir) GetMCPServers() []ServerConfig {
return d.MCPServers
}
func (d *AgentsDir) GetGlobalInstructions() string {
return d.GlobalInstructions
}
func CreateDefaultAgentsDir() error {
dir, err := FindAgentsDirWithCreate()
if err != nil {
return err
}
subdirs := []string{"agents", "skills", "tasks", "memories"}
for _, sub := range subdirs {
path := filepath.Join(dir, sub)
if _, err := os.Stat(path); err != nil {
if err := os.MkdirAll(path, 0755); err != nil {
return fmt.Errorf("create %s: %w", sub, err)
}
}
}
mcpPath := filepath.Join(dir, "mcp.json")
if _, err := os.Stat(mcpPath); err != nil {
defaultMCP := MCPConfig{
Servers: []ServerConfig{},
}
data, _ := json.MarshalIndent(defaultMCP, "", " ")
if err := os.WriteFile(mcpPath, data, 0644); err != nil {
return fmt.Errorf("write mcp.json: %w", err)
}
}
agentsPath := filepath.Join(dir, "agents.md")
if _, err := os.Stat(agentsPath); err != nil {
defaultContent := `# Global Agent Instructions
You are a helpful local AI coding assistant.
## Guidelines
- Be concise and direct
- Explain your reasoning
- Ask for clarification when needed
- Never fabricate information
`
if err := os.WriteFile(agentsPath, []byte(defaultContent), 0644); err != nil {
return fmt.Errorf("write agents.md: %w", err)
}
}
return nil
}

View File

@ -0,0 +1,176 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestExtractDescription(t *testing.T) {
tests := []struct {
name string
content string
want string
}{
{
name: "first non-header non-empty line",
content: "# Title\n\nThis is the description.\nMore text.",
want: "This is the description.",
},
{
name: "header only content",
content: "# Title\n## Subtitle\n### Another",
want: "",
},
{
name: "empty content",
content: "",
want: "",
},
{
name: "whitespace around description",
content: "# Title\n\n Indented description \n",
want: "Indented description",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractDescription(tt.content)
if got != tt.want {
t.Errorf("extractDescription() = %q, want %q", got, tt.want)
}
})
}
}
func TestSplitLines(t *testing.T) {
tests := []struct {
name string
s string
want int // expected number of lines
}{
{name: "normal lines", s: "a\nb\nc", want: 3},
{name: "empty string", s: "", want: 1},
{name: "trailing newline", s: "a\nb\n", want: 3},
{name: "single line", s: "hello", want: 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := splitLines(tt.s)
if len(got) != tt.want {
t.Errorf("splitLines(%q) returned %d lines, want %d (lines: %v)", tt.s, len(got), tt.want, got)
}
})
}
}
func TestTrimWhitespace(t *testing.T) {
tests := []struct {
name string
s string
want string
}{
{name: "tabs", s: "\thello\t", want: "hello"},
{name: "spaces", s: " hello ", want: "hello"},
{name: "mixed", s: "\t hello \t", want: "hello"},
{name: "already trimmed", s: "hello", want: "hello"},
{name: "empty", s: "", want: ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := trimWhitespace(tt.s)
if got != tt.want {
t.Errorf("trimWhitespace(%q) = %q, want %q", tt.s, got, tt.want)
}
})
}
}
func TestStartsWith(t *testing.T) {
tests := []struct {
name string
s string
prefix string
want bool
}{
{name: "match", s: "hello world", prefix: "hello", want: true},
{name: "no match", s: "hello world", prefix: "world", want: false},
{name: "empty prefix", s: "hello", prefix: "", want: true},
{name: "longer prefix", s: "hi", prefix: "hello", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := startsWith(tt.s, tt.prefix)
if got != tt.want {
t.Errorf("startsWith(%q, %q) = %v, want %v", tt.s, tt.prefix, got, tt.want)
}
})
}
}
func TestLoadAgentsDir(t *testing.T) {
t.Run("valid temp structure with agent", func(t *testing.T) {
tmp := t.TempDir()
// Create agents/test-agent/agent.yaml
agentDir := filepath.Join(tmp, "agents", "test-agent")
if err := os.MkdirAll(agentDir, 0755); err != nil {
t.Fatal(err)
}
agentYAML := `name: test-agent
description: A test agent
model: qwen3.5:0.8b
`
if err := os.WriteFile(filepath.Join(agentDir, "agent.yaml"), []byte(agentYAML), 0644); err != nil {
t.Fatal(err)
}
dir, err := LoadAgentsDir(tmp)
if err != nil {
t.Fatalf("LoadAgentsDir() error: %v", err)
}
if dir.Path != tmp {
t.Errorf("Path = %q, want %q", dir.Path, tmp)
}
if len(dir.Agents) != 1 {
t.Errorf("expected 1 agent, got %d", len(dir.Agents))
}
agent, ok := dir.Agents["test-agent"]
if !ok {
t.Fatal("expected agent 'test-agent' to exist")
}
if agent.Description != "A test agent" {
t.Errorf("agent description = %q, want %q", agent.Description, "A test agent")
}
})
t.Run("empty path uses FindAgentsDir", func(t *testing.T) {
dir, err := LoadAgentsDir("")
if err != nil {
t.Fatalf("LoadAgentsDir('') error: %v", err)
}
// Should return a valid AgentsDir (possibly with no agents)
if dir == nil {
t.Fatal("expected non-nil AgentsDir")
}
if dir.Agents == nil {
t.Error("expected Agents map to be initialized")
}
})
t.Run("nonexistent subdirs dont error", func(t *testing.T) {
tmp := t.TempDir()
// Empty temp dir — no agents/, skills/, mcp.json, etc.
dir, err := LoadAgentsDir(tmp)
if err != nil {
t.Fatalf("LoadAgentsDir() error: %v", err)
}
if len(dir.Agents) != 0 {
t.Errorf("expected 0 agents, got %d", len(dir.Agents))
}
})
}

186
internal/config/config.go Normal file
View File

@ -0,0 +1,186 @@
package config
import (
"fmt"
"os"
"path/filepath"
"strconv"
"gopkg.in/yaml.v3"
)
type Config struct {
Ollama OllamaConfig `yaml:"ollama"`
Model ModelConfig `yaml:"model,omitempty"`
Agents AgentsConfig `yaml:"agents,omitempty"`
Servers []ServerConfig `yaml:"servers,omitempty"`
SkillsDir string `yaml:"skills_dir,omitempty"`
ICE ICEConfig `yaml:"ice,omitempty"`
AgentProfile string `yaml:"agent_profile,omitempty"`
Tools ToolsConfig `yaml:"tools,omitempty"`
}
type AgentsConfig struct {
Dir string `yaml:"dir,omitempty"`
AutoLoad bool `yaml:"auto_load"`
}
type ToolsConfig struct {
Timeout string `yaml:"timeout,omitempty"` // e.g., "30s", "2m"
MaxGrepResults int `yaml:"max_grep_results,omitempty"`
MaxIterations int `yaml:"max_iterations,omitempty"`
}
type ICEConfig struct {
Enabled bool `yaml:"enabled"`
EmbedModel string `yaml:"embed_model,omitempty"`
StorePath string `yaml:"store_path,omitempty"`
}
type OllamaConfig struct {
Model string `yaml:"model"`
BaseURL string `yaml:"base_url"`
NumCtx int `yaml:"num_ctx"`
}
type ServerConfig struct {
Name string `yaml:"name"`
Command string `yaml:"command,omitempty"`
Args []string `yaml:"args,omitempty"`
Env []string `yaml:"env,omitempty"`
Transport string `yaml:"transport,omitempty"`
URL string `yaml:"url,omitempty"`
}
func defaults() Config {
modelCfg := DefaultModelConfig()
return Config{
Ollama: OllamaConfig{
Model: "qwen3.5:2b",
BaseURL: "http://localhost:11434",
NumCtx: 262144,
},
Model: modelCfg,
Agents: AgentsConfig{
Dir: "",
AutoLoad: true,
},
Tools: ToolsConfig{
Timeout: "30s",
MaxGrepResults: 500,
MaxIterations: 10,
},
}
}
func Load() (*Config, error) {
cfg := defaults()
localPath := findConfigFile()
if localPath != "" {
data, err := os.ReadFile(localPath)
if err != nil {
return nil, fmt.Errorf("read config %s: %w", localPath, err)
}
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config %s: %w", localPath, err)
}
}
agentsDir := cfg.Agents.Dir
if agentsDir == "" {
agentsDir = FindAgentsDir()
}
var agentsData *AgentsDir
if agentsDir != "" && cfg.Agents.AutoLoad {
var err error
agentsData, err = LoadAgentsDir(agentsDir)
if err != nil {
fmt.Fprintf(os.Stderr, "warning: failed to load .agents directory: %v\n", err)
} else {
if agentsData != nil {
if cfg.Ollama.Model == "" {
cfg.Ollama.Model = cfg.Model.DefaultModel
}
if len(cfg.Servers) == 0 && agentsData.HasMCP() {
cfg.Servers = agentsData.GetMCPServers()
}
}
}
}
applyEnvOverrides(&cfg)
return &cfg, nil
}
func LoadWithAgentsDir() (*Config, *AgentsDir, error) {
cfg, err := Load()
if err != nil {
return nil, nil, err
}
agentsDir := cfg.Agents.Dir
if agentsDir == "" {
agentsDir = FindAgentsDir()
}
var agents *AgentsDir
if agentsDir != "" && cfg.Agents.AutoLoad {
agents, _ = LoadAgentsDir(agentsDir)
}
return cfg, agents, nil
}
func findConfigFile() string {
candidates := []string{
"ai-agent.yaml",
"ai-agent.yml",
"config.yaml",
"config.yml",
}
if home, err := os.UserHomeDir(); err == nil {
candidates = append(candidates,
filepath.Join(home, ".config", "ai-agent", "config.yaml"),
filepath.Join(home, ".config", "ai-agent", "config.yml"),
)
}
for _, path := range candidates {
if _, err := os.Stat(path); err == nil {
return path
}
}
return ""
}
func applyEnvOverrides(cfg *Config) {
if v := os.Getenv("OLLAMA_HOST"); v != "" {
cfg.Ollama.BaseURL = v
}
if v := os.Getenv("LOCAL_AGENT_MODEL"); v != "" {
cfg.Ollama.Model = v
}
if v := os.Getenv("LOCAL_AGENT_AGENTS_DIR"); v != "" {
cfg.Agents.Dir = v
}
if v := os.Getenv("LOCAL_AGENT_TOOLS_TIMEOUT"); v != "" {
cfg.Tools.Timeout = v
}
if v := os.Getenv("LOCAL_AGENT_TOOLS_MAX_GREP"); v != "" {
cfg.Tools.MaxGrepResults = parseEnvInt(v, cfg.Tools.MaxGrepResults)
}
if v := os.Getenv("LOCAL_AGENT_TOOLS_MAX_ITER"); v != "" {
cfg.Tools.MaxIterations = parseEnvInt(v, cfg.Tools.MaxIterations)
}
if v := os.Getenv("LOCAL_AGENT_ICE_EMBED_MODEL"); v != "" {
cfg.ICE.EmbedModel = v
}
}
func parseEnvInt(v string, defaultVal int) int {
if i, err := strconv.Atoi(v); err == nil {
return i
}
return defaultVal
}

View File

@ -0,0 +1,82 @@
package config
import "testing"
func TestDefaults(t *testing.T) {
cfg := defaults()
tests := []struct {
name string
got string
want string
}{
{name: "Ollama.Model", got: cfg.Ollama.Model, want: "qwen3.5:2b"},
{name: "Ollama.BaseURL", got: cfg.Ollama.BaseURL, want: "http://localhost:11434"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.want {
t.Errorf("%s = %q, want %q", tt.name, tt.got, tt.want)
}
})
}
if cfg.Ollama.NumCtx != 262144 {
t.Errorf("Ollama.NumCtx = %d, want %d", cfg.Ollama.NumCtx, 262144)
}
if !cfg.Model.AutoSelect {
t.Error("Model.AutoSelect should be true by default")
}
}
func TestApplyEnvOverrides(t *testing.T) {
tests := []struct {
name string
envKey string
envVal string
checkFn func(cfg *Config) string
want string
}{
{
name: "OLLAMA_HOST overrides BaseURL",
envKey: "OLLAMA_HOST",
envVal: "http://custom:1234",
checkFn: func(cfg *Config) string {
return cfg.Ollama.BaseURL
},
want: "http://custom:1234",
},
{
name: "LOCAL_AGENT_MODEL overrides Model",
envKey: "LOCAL_AGENT_MODEL",
envVal: "custom-model",
checkFn: func(cfg *Config) string {
return cfg.Ollama.Model
},
want: "custom-model",
},
{
name: "LOCAL_AGENT_AGENTS_DIR overrides AgentsDir",
envKey: "LOCAL_AGENT_AGENTS_DIR",
envVal: "/custom/agents",
checkFn: func(cfg *Config) string {
return cfg.Agents.Dir
},
want: "/custom/agents",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv(tt.envKey, tt.envVal)
cfg := defaults()
applyEnvOverrides(&cfg)
got := tt.checkFn(&cfg)
if got != tt.want {
t.Errorf("after setting %s=%q, got %q, want %q", tt.envKey, tt.envVal, got, tt.want)
}
})
}
}

103
internal/config/ignore.go Normal file
View File

@ -0,0 +1,103 @@
package config
import (
"bufio"
"os"
"path/filepath"
"strings"
)
// IgnorePatterns holds parsed .agentignore patterns.
type IgnorePatterns struct {
patterns []string
raw string // original file content for injection into system prompt
}
// LoadIgnoreFile reads and parses an .agentignore file from the given directory.
// Returns nil if no .agentignore file exists (not an error).
func LoadIgnoreFile(dir string) *IgnorePatterns {
path := filepath.Join(dir, ".agentignore")
f, err := os.Open(path)
if err != nil {
return nil
}
defer f.Close()
var patterns []string
var rawLines []string
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
rawLines = append(rawLines, line)
trimmed := strings.TrimSpace(line)
// Skip empty lines and comments.
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
continue
}
patterns = append(patterns, trimmed)
}
return &IgnorePatterns{
patterns: patterns,
raw: strings.Join(rawLines, "\n"),
}
}
// Match returns true if the given path should be ignored.
// Returns false if the receiver is nil.
func (ip *IgnorePatterns) Match(path string) bool {
if ip == nil || len(ip.patterns) == 0 {
return false
}
// Normalise the path separators and remove trailing slashes for comparison.
path = filepath.ToSlash(path)
cleanPath := strings.TrimSuffix(path, "/")
for _, pattern := range ip.patterns {
pat := strings.TrimSuffix(pattern, "/")
// Check each component of the path against the pattern.
// e.g. "node_modules" should match "node_modules", "node_modules/foo",
// and "src/node_modules/bar".
parts := strings.Split(cleanPath, "/")
for _, part := range parts {
if matched, _ := filepath.Match(pat, part); matched {
return true
}
}
// Also try matching the full path with the pattern (for glob patterns
// that include path separators like "build/output").
if matched, _ := filepath.Match(pat, cleanPath); matched {
return true
}
// Prefix match: path starts with the pattern directory.
if strings.HasPrefix(cleanPath, pat+"/") || cleanPath == pat {
return true
}
}
return false
}
// Raw returns the raw file content for system prompt injection.
// Returns an empty string if the receiver is nil.
func (ip *IgnorePatterns) Raw() string {
if ip == nil {
return ""
}
return ip.raw
}
// Patterns returns the list of patterns.
// Returns nil if the receiver is nil.
func (ip *IgnorePatterns) Patterns() []string {
if ip == nil {
return nil
}
return ip.patterns
}

View File

@ -0,0 +1,183 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadIgnoreFile_Valid(t *testing.T) {
dir := t.TempDir()
content := `# Build artifacts
node_modules
*.log
.git
build/
dist/
vendor/
`
if err := os.WriteFile(filepath.Join(dir, ".agentignore"), []byte(content), 0o644); err != nil {
t.Fatal(err)
}
ip := LoadIgnoreFile(dir)
if ip == nil {
t.Fatal("expected non-nil IgnorePatterns")
}
wantPatterns := []string{"node_modules", "*.log", ".git", "build/", "dist/", "vendor/"}
if len(ip.Patterns()) != len(wantPatterns) {
t.Fatalf("got %d patterns, want %d", len(ip.Patterns()), len(wantPatterns))
}
for i, p := range ip.Patterns() {
if p != wantPatterns[i] {
t.Errorf("pattern[%d] = %q, want %q", i, p, wantPatterns[i])
}
}
if ip.Raw() != content[:len(content)-1] { // raw joins lines without trailing newline from Join
// Just check it contains the comment and patterns
if ip.Raw() == "" {
t.Error("Raw() should not be empty")
}
}
}
func TestLoadIgnoreFile_Missing(t *testing.T) {
dir := t.TempDir()
ip := LoadIgnoreFile(dir)
if ip != nil {
t.Error("expected nil for missing .agentignore")
}
}
func TestLoadIgnoreFile_Empty(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, ".agentignore"), []byte(""), 0o644); err != nil {
t.Fatal(err)
}
ip := LoadIgnoreFile(dir)
if ip == nil {
t.Fatal("expected non-nil IgnorePatterns for empty file")
}
if len(ip.Patterns()) != 0 {
t.Errorf("expected 0 patterns, got %d", len(ip.Patterns()))
}
}
func TestLoadIgnoreFile_CommentsOnly(t *testing.T) {
dir := t.TempDir()
content := "# This is a comment\n# Another comment\n\n"
if err := os.WriteFile(filepath.Join(dir, ".agentignore"), []byte(content), 0o644); err != nil {
t.Fatal(err)
}
ip := LoadIgnoreFile(dir)
if ip == nil {
t.Fatal("expected non-nil IgnorePatterns")
}
if len(ip.Patterns()) != 0 {
t.Errorf("expected 0 patterns for comments-only file, got %d", len(ip.Patterns()))
}
}
func TestIgnorePatterns_Match_Exact(t *testing.T) {
ip := &IgnorePatterns{
patterns: []string{"node_modules", ".git", "vendor"},
}
tests := []struct {
path string
want bool
}{
{"node_modules", true},
{"node_modules/package/index.js", true},
{".git", true},
{".git/config", true},
{"vendor", true},
{"vendor/lib/foo.go", true},
{"src/main.go", false},
{"README.md", false},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
if got := ip.Match(tt.path); got != tt.want {
t.Errorf("Match(%q) = %v, want %v", tt.path, got, tt.want)
}
})
}
}
func TestIgnorePatterns_Match_Glob(t *testing.T) {
ip := &IgnorePatterns{
patterns: []string{"*.log", "*.tmp"},
}
tests := []struct {
path string
want bool
}{
{"app.log", true},
{"debug.log", true},
{"temp.tmp", true},
{"logs/app.log", true},
{"main.go", false},
{"log.txt", false},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
if got := ip.Match(tt.path); got != tt.want {
t.Errorf("Match(%q) = %v, want %v", tt.path, got, tt.want)
}
})
}
}
func TestIgnorePatterns_Match_DirectoryPattern(t *testing.T) {
ip := &IgnorePatterns{
patterns: []string{"build/", "dist/"},
}
tests := []struct {
path string
want bool
}{
{"build", true},
{"build/output.js", true},
{"dist", true},
{"dist/bundle.js", true},
{"src/build.go", false},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
if got := ip.Match(tt.path); got != tt.want {
t.Errorf("Match(%q) = %v, want %v", tt.path, got, tt.want)
}
})
}
}
func TestIgnorePatterns_Match_NilReceiver(t *testing.T) {
var ip *IgnorePatterns
if ip.Match("anything") {
t.Error("nil IgnorePatterns should not match anything")
}
}
func TestIgnorePatterns_Raw_NilReceiver(t *testing.T) {
var ip *IgnorePatterns
if ip.Raw() != "" {
t.Error("nil IgnorePatterns Raw() should return empty string")
}
}
func TestIgnorePatterns_Patterns_NilReceiver(t *testing.T) {
var ip *IgnorePatterns
if ip.Patterns() != nil {
t.Error("nil IgnorePatterns Patterns() should return nil")
}
}

167
internal/config/models.go Normal file
View File

@ -0,0 +1,167 @@
package config
import "fmt"
type ModelFamily string
const (
FamilyQwen3 ModelFamily = "qwen3"
FamilyQwen35 ModelFamily = "qwen3.5"
FamilyLlama ModelFamily = "llama"
FamilyMistral ModelFamily = "mistral"
)
type ModelCapability int
const (
CapabilitySimple ModelCapability = iota
CapabilityMedium
CapabilityComplex
CapabilityAdvanced
)
type Model struct {
Name string `yaml:"name"`
Family ModelFamily `yaml:"family"`
DisplayName string `yaml:"display_name"`
Size string `yaml:"size"`
Parameters string `yaml:"parameters"`
ContextSize int `yaml:"context_size"`
Capability ModelCapability `yaml:"capability"`
Speed float64 `yaml:"speed"` // 1.0 = baseline
UseCases []string `yaml:"use_cases"`
Description string `yaml:"description"`
Default bool `yaml:"default,omitempty"`
}
type ModelConfig struct {
Models []Model `yaml:"models"`
DefaultModel string `yaml:"default_model"`
FallbackChain []string `yaml:"fallback_chain"`
AutoSelect bool `yaml:"auto_select"`
EmbedModel string `yaml:"embed_model,omitempty"`
}
func DefaultModels() []Model {
return []Model{
{
Name: "qwen3.5:0.8b",
Family: FamilyQwen35,
DisplayName: "Qwen 3.5 0.8B",
Size: "0.8B",
Parameters: "0.8 billion",
ContextSize: 262144,
Capability: CapabilitySimple,
Speed: 4.0,
UseCases: []string{"quick_answers", "simple_tools", "single_file_edits"},
Description: "Fast, lightweight model for simple tasks and quick answers",
Default: false,
},
{
Name: "qwen3.5:2b",
Family: FamilyQwen35,
DisplayName: "Qwen 3.5 2B",
Size: "2B",
Parameters: "2 billion",
ContextSize: 262144,
Capability: CapabilityMedium,
Speed: 2.5,
UseCases: []string{"code_completion", "simple_refactoring", "explanations"},
Description: "Balanced model for medium complexity tasks",
Default: true,
},
{
Name: "qwen3.5:4b",
Family: FamilyQwen35,
DisplayName: "Qwen 3.5 4B",
Size: "4B",
Parameters: "4 billion",
ContextSize: 262144,
Capability: CapabilityComplex,
Speed: 1.5,
UseCases: []string{"multi_step_reasoning", "code_review", "debugging", "refactoring"},
Description: "Capable model for complex reasoning and code analysis",
Default: false,
},
{
Name: "qwen3.5:9b",
Family: FamilyQwen35,
DisplayName: "Qwen 3.5 9B",
Size: "9B",
Parameters: "9 billion",
ContextSize: 262144,
Capability: CapabilityAdvanced,
Speed: 1.0,
UseCases: []string{"complex_reasoning", "architecture", "full_stack", "advanced_debugging"},
Description: "Full capability model for advanced tasks",
Default: false,
},
}
}
func DefaultModelConfig() ModelConfig {
models := DefaultModels()
return ModelConfig{
Models: models,
DefaultModel: "qwen3.5:2b",
FallbackChain: []string{"qwen3.5:2b", "qwen3.5:0.8b", "qwen3.5:4b", "qwen3.5:9b"},
AutoSelect: true,
EmbedModel: "nomic-embed-text",
}
}
func (m *Model) IsSimpleTask() bool {
return m.Capability <= CapabilityMedium
}
func (m *Model) IsComplexTask() bool {
return m.Capability >= CapabilityComplex
}
func (mc *ModelConfig) GetModel(name string) (*Model, error) {
for _, m := range mc.Models {
if m.Name == name {
return &m, nil
}
}
return nil, fmt.Errorf("model not found: %s", name)
}
func (mc *ModelConfig) GetDefaultModel() *Model {
for _, m := range mc.Models {
if m.Default {
return &m
}
}
if len(mc.Models) > 0 {
return &mc.Models[len(mc.Models)-1]
}
return nil
}
func (mc *ModelConfig) SelectModelForTask(taskComplexity string) string {
if !mc.AutoSelect {
return mc.DefaultModel
}
switch taskComplexity {
case "simple":
return mc.Models[0].Name
case "medium":
for _, m := range mc.Models {
if m.Capability == CapabilityMedium {
return m.Name
}
}
case "complex":
for _, m := range mc.Models {
if m.Capability == CapabilityComplex {
return m.Name
}
}
case "advanced":
return mc.DefaultModel
}
return mc.DefaultModel
}

View File

@ -0,0 +1,159 @@
package config
import "testing"
func TestModel_IsSimpleTask(t *testing.T) {
tests := []struct {
name string
capability ModelCapability
want bool
}{
{name: "CapabilitySimple is simple", capability: CapabilitySimple, want: true},
{name: "CapabilityMedium is simple", capability: CapabilityMedium, want: true},
{name: "CapabilityComplex is not simple", capability: CapabilityComplex, want: false},
{name: "CapabilityAdvanced is not simple", capability: CapabilityAdvanced, want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := &Model{Capability: tt.capability}
if got := m.IsSimpleTask(); got != tt.want {
t.Errorf("Model{Capability: %d}.IsSimpleTask() = %v, want %v", tt.capability, got, tt.want)
}
})
}
}
func TestModel_IsComplexTask(t *testing.T) {
tests := []struct {
name string
capability ModelCapability
want bool
}{
{name: "CapabilitySimple is not complex", capability: CapabilitySimple, want: false},
{name: "CapabilityMedium is not complex", capability: CapabilityMedium, want: false},
{name: "CapabilityComplex is complex", capability: CapabilityComplex, want: true},
{name: "CapabilityAdvanced is complex", capability: CapabilityAdvanced, want: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := &Model{Capability: tt.capability}
if got := m.IsComplexTask(); got != tt.want {
t.Errorf("Model{Capability: %d}.IsComplexTask() = %v, want %v", tt.capability, got, tt.want)
}
})
}
}
func TestModelConfig_GetModel(t *testing.T) {
cfg := DefaultModelConfig()
tests := []struct {
name string
model string
wantErr bool
}{
{name: "found model", model: "qwen3.5:0.8b", wantErr: false},
{name: "not found", model: "nonexistent", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := cfg.GetModel(tt.model)
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if got.Name != tt.model {
t.Errorf("GetModel(%q).Name = %q, want %q", tt.model, got.Name, tt.model)
}
}
})
}
}
func TestModelConfig_GetDefaultModel(t *testing.T) {
tests := []struct {
name string
cfg ModelConfig
want string // empty means nil expected
}{
{
name: "model with Default=true",
cfg: ModelConfig{
Models: []Model{
{Name: "a", Default: false},
{Name: "b", Default: true},
{Name: "c", Default: false},
},
},
want: "b",
},
{
name: "no default returns last",
cfg: ModelConfig{
Models: []Model{
{Name: "a", Default: false},
{Name: "b", Default: false},
},
},
want: "b",
},
{
name: "empty slice returns nil",
cfg: ModelConfig{Models: []Model{}},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.cfg.GetDefaultModel()
if tt.want == "" {
if got != nil {
t.Errorf("expected nil, got %+v", got)
}
} else {
if got == nil {
t.Fatal("expected non-nil model, got nil")
}
if got.Name != tt.want {
t.Errorf("GetDefaultModel().Name = %q, want %q", got.Name, tt.want)
}
}
})
}
}
func TestModelConfig_SelectModelForTask(t *testing.T) {
cfg := DefaultModelConfig()
tests := []struct {
name string
complexity string
autoSelect bool
want string
}{
{name: "auto simple", complexity: "simple", autoSelect: true, want: "qwen3.5:0.8b"},
{name: "auto medium", complexity: "medium", autoSelect: true, want: "qwen3.5:2b"},
{name: "auto complex", complexity: "complex", autoSelect: true, want: "qwen3.5:4b"},
{name: "auto advanced", complexity: "advanced", autoSelect: true, want: cfg.DefaultModel},
{name: "no autoselect simple", complexity: "simple", autoSelect: false, want: cfg.DefaultModel},
{name: "no autoselect complex", complexity: "complex", autoSelect: false, want: cfg.DefaultModel},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg.AutoSelect = tt.autoSelect
got := cfg.SelectModelForTask(tt.complexity)
if got != tt.want {
t.Errorf("SelectModelForTask(%q) = %q, want %q", tt.complexity, got, tt.want)
}
})
}
}

View File

@ -0,0 +1,364 @@
package config
import (
"context"
"strings"
"sync"
"time"
)
type QwenModelRouter struct {
config *ModelConfig
overrideLog []ModelOverride
modeContext ModeContext
mu sync.RWMutex
}
type ModeContext int
const (
ModeAskContext ModeContext = iota
ModePlanContext
ModeBuildContext
)
type QwenComplexity string
const (
QwenTrivial QwenComplexity = "trivial"
QwenSimple QwenComplexity = "simple"
QwenModerate QwenComplexity = "moderate"
QwenAdvanced QwenComplexity = "advanced"
)
var (
qwenTrivialIndicators = []string{
"what is", "who is", "when is", "where is",
"define", "meaning of", "synonym", "antonym",
"list files", "show me", "display",
"yes", "no", "ok", "thanks",
"hello", "hi", "hey",
}
qwenSimpleIndicators = []string{
"how do i", "explain", "what does", "why does",
"find", "search", "get", "read",
"print", "echo", "cat", "ls", "grep",
"simple", "quick", "fast", "brief",
"check", "verify", "test",
"create file", "write file", "save",
}
qwenModerateIndicators = []string{
"create", "generate", "add", "modify", "update",
"fix", "debug", "refactor", "optimize",
"function", "class", "method", "interface",
"test", "unit test", "integration test",
"script", "command", "pipeline",
"compare", "analyze", "review",
"multiple", "several", "across",
}
qwenAdvancedIndicators = []string{
"architecture", "design pattern", "system design",
"infrastructure", "deployment", "scaling",
"security audit", "performance optimization",
"multi-step", "complex", "comprehensive",
"build a", "implement", "develop", "engineer",
"full stack", "end-to-end", "production",
"migration", "refactor entire", "rewrite",
}
qwenCodePatterns = map[string]QwenComplexity{
"variable": QwenSimple,
"constant": QwenSimple,
"function": QwenSimple,
"loop": QwenSimple,
"condition": QwenSimple,
"array": QwenSimple,
"slice": QwenSimple,
"map": QwenSimple,
"struct": QwenModerate,
"interface": QwenModerate,
"generics": QwenModerate,
"concurrency": QwenModerate,
"goroutine": QwenModerate,
"channel": QwenModerate,
"mutex": QwenModerate,
"architecture": QwenAdvanced,
"pattern": QwenAdvanced,
"microservice": QwenAdvanced,
"distributed": QwenAdvanced,
"kubernetes": QwenAdvanced,
}
)
func NewQwenModelRouter(cfg *ModelConfig) *QwenModelRouter {
return &QwenModelRouter{
config: cfg,
overrideLog: make([]ModelOverride, 0),
modeContext: ModeAskContext,
}
}
func (r *QwenModelRouter) SetModeContext(mode ModeContext) {
r.mu.Lock()
defer r.mu.Unlock()
r.modeContext = mode
}
func (r *QwenModelRouter) ClassifyTaskComplexity(query string) QwenComplexity {
return classifyQwenTask(query, r.modeContext)
}
func (r *QwenModelRouter) SelectModel(query string) string {
complexity := r.ClassifyTaskComplexity(query)
return r.config.SelectModelForTask(string(complexity))
}
func (r *QwenModelRouter) SelectModelForMode(query string, mode ModeContext) string {
switch mode {
case ModeAskContext:
return r.selectAskModel(query)
case ModePlanContext:
return r.selectPlanModel(query)
case ModeBuildContext:
return r.selectBuildModel(query)
}
return r.SelectModel(query)
}
func (r *QwenModelRouter) selectAskModel(query string) string {
complexity := classifyQwenTask(query, ModeAskContext)
switch complexity {
case QwenTrivial, QwenSimple:
if r.isModelAvailable("qwen3.5:0.8b") {
return "qwen3.5:0.8b"
}
return "qwen3.5:2b"
case QwenModerate:
return "qwen3.5:2b"
case QwenAdvanced:
return "qwen3.5:4b"
default:
return "qwen3.5:2b"
}
}
func (r *QwenModelRouter) selectPlanModel(query string) string {
complexity := classifyQwenTask(query, ModePlanContext)
switch complexity {
case QwenTrivial, QwenSimple:
return "qwen3.5:2b"
case QwenModerate:
return "qwen3.5:4b"
case QwenAdvanced:
return "qwen3.5:9b"
default:
return "qwen3.5:4b"
}
}
func (r *QwenModelRouter) selectBuildModel(query string) string {
complexity := classifyQwenTask(query, ModeBuildContext)
switch complexity {
case QwenTrivial, QwenSimple:
return "qwen3.5:2b"
case QwenModerate:
return "qwen3.5:4b"
case QwenAdvanced:
return "qwen3.5:9b"
default:
return "qwen3.5:4b"
}
}
func (r *QwenModelRouter) isModelAvailable(name string) bool {
for _, m := range r.config.Models {
if m.Name == name {
return true
}
}
return false
}
func classifyQwenTask(query string, mode ModeContext) QwenComplexity {
lowerQuery := strings.ToLower(query)
words := strings.Fields(lowerQuery)
wordCount := len(words)
score := 0
for _, indicator := range qwenTrivialIndicators {
if strings.Contains(lowerQuery, indicator) {
score -= 4
}
}
for _, indicator := range qwenSimpleIndicators {
if strings.Contains(lowerQuery, indicator) {
score -= 1
}
}
for _, indicator := range qwenModerateIndicators {
if strings.Contains(lowerQuery, indicator) {
score += 2
}
}
for _, indicator := range qwenAdvancedIndicators {
if strings.Contains(lowerQuery, indicator) {
score += 4
}
}
for pattern, complexity := range qwenCodePatterns {
if strings.Contains(lowerQuery, pattern) {
switch complexity {
case QwenSimple:
score -= 1
case QwenModerate:
score += 2
case QwenAdvanced:
score += 4
}
}
}
if wordCount > 50 {
score += 3
} else if wordCount > 30 {
score += 1
} else if wordCount < 5 && score <= 0 {
score -= 2
}
if strings.Contains(lowerQuery, "why") || strings.Contains(lowerQuery, "reason") {
score += 2
}
if strings.Contains(lowerQuery, "how") && wordCount > 10 {
score += 1
}
if strings.Contains(lowerQuery, "?") && wordCount < 10 {
score -= 1
}
switch mode {
case ModeAskContext:
score -= 1
case ModeBuildContext:
score += 1
}
switch {
case score <= -3:
return QwenTrivial
case score <= 1:
return QwenSimple
case score <= 5:
return QwenModerate
default:
return QwenAdvanced
}
}
func (r *QwenModelRouter) RecordOverride(query, userModel string) {
r.mu.Lock()
defer r.mu.Unlock()
routerModel := r.SelectModel(query)
r.overrideLog = append(r.overrideLog, ModelOverride{
Query: query,
UserModel: userModel,
RouterModel: routerModel,
Timestamp: time.Now(),
})
if len(r.overrideLog) > 100 {
r.overrideLog = r.overrideLog[len(r.overrideLog)-100:]
}
}
func (r *QwenModelRouter) GetLearnedPatterns() map[string]QwenComplexity {
r.mu.RLock()
defer r.mu.RUnlock()
if len(r.overrideLog) < 3 {
return nil
}
wordCounts := make(map[string]map[QwenComplexity]int)
for _, o := range r.overrideLog {
if o.Query == "" || o.UserModel == "" {
continue
}
var complexity QwenComplexity
switch {
case strings.Contains(o.UserModel, "0.8b"):
complexity = QwenTrivial
case strings.Contains(o.UserModel, "2b"):
complexity = QwenSimple
case strings.Contains(o.UserModel, "4b"):
complexity = QwenModerate
case strings.Contains(o.UserModel, "9b"):
complexity = QwenAdvanced
default:
continue
}
words := strings.Fields(strings.ToLower(o.Query))
for _, w := range words {
if len(w) < 3 {
continue
}
if _, ok := wordCounts[w]; !ok {
wordCounts[w] = make(map[QwenComplexity]int)
}
wordCounts[w][complexity]++
}
}
wordComplexity := make(map[string]QwenComplexity)
for word, counts := range wordCounts {
var maxCount int
var dominant QwenComplexity
for c, cnt := range counts {
if cnt > maxCount {
maxCount = cnt
dominant = c
}
}
if maxCount >= 2 {
wordComplexity[word] = dominant
}
}
return wordComplexity
}
func (r *QwenModelRouter) SelectAvailableModelForTask(ctx context.Context, pinger ModelPinger, query string) string {
preferred := r.SelectModel(query)
fallbackOrder := []string{
preferred,
"qwen3.5:2b",
"qwen3.5:0.8b",
"qwen3.5:4b",
"qwen3.5:9b",
}
for _, model := range fallbackOrder {
if err := pinger.PingModel(ctx, model); err == nil {
return model
}
}
return r.config.DefaultModel
}
func (r *QwenModelRouter) GetRecommendedModel(query string) (model string, reason string, complexity QwenComplexity) {
r.mu.RLock()
mode := r.modeContext
r.mu.RUnlock()
complexity = classifyQwenTask(query, mode)
switch complexity {
case QwenTrivial:
model = "qwen3.5:0.8b"
reason = "trivial task - ultra-fast response"
case QwenSimple:
model = "qwen3.5:2b"
reason = "simple task - balanced speed/capability"
case QwenModerate:
model = "qwen3.5:4b"
reason = "moderate complexity - multi-step reasoning"
case QwenAdvanced:
model = "qwen3.5:9b"
reason = "advanced task - complex reasoning required"
}
switch mode {
case ModeAskContext:
reason += " (ASK mode - prefer speed)"
case ModePlanContext:
reason += " (PLAN mode - prefer reasoning)"
case ModeBuildContext:
reason += " (BUILD mode - prefer capability)"
}
return model, reason, complexity
}

View File

@ -0,0 +1,254 @@
package config
import (
"testing"
)
func TestQwenRouter_ClassifyTrivial(t *testing.T) {
tests := []struct {
name string
query string
maxComplexity QwenComplexity
}{
{"simple what", "what is go", QwenTrivial},
{"simple who", "who created go", QwenSimple},
{"simple define", "define interface", QwenTrivial},
{"simple greeting", "hello", QwenTrivial},
{"simple thanks", "thanks", QwenTrivial},
{"simple list", "list files", QwenTrivial},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := classifyQwenTask(tt.query, ModeAskContext)
if got > tt.maxComplexity {
t.Errorf("classifyQwenTask(%q) = %v, want <= %v", tt.query, got, tt.maxComplexity)
}
})
}
}
func TestQwenRouter_ClassifySimple(t *testing.T) {
tests := []struct {
name string
query string
}{
{"simple how", "how do i create a file"},
{"simple explain", "explain this code"},
{"simple find", "find all go files"},
{"simple check", "check if file exists"},
{"simple read", "read config file"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := classifyQwenTask(tt.query, ModeAskContext)
t.Logf("%s: %v", tt.query, got)
})
}
}
func TestQwenRouter_ClassifyModerate(t *testing.T) {
tests := []struct {
name string
query string
}{
{"create function", "create a function to parse json"},
{"debug issue", "debug this nil pointer error"},
{"refactor code", "refactor this function"},
{"add test", "add unit tests for handler"},
{"optimize query", "optimize this database query"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := classifyQwenTask(tt.query, ModeBuildContext)
t.Logf("%s: %v", tt.query, got)
})
}
}
func TestQwenRouter_ClassifyAdvanced(t *testing.T) {
tests := []struct {
name string
query string
}{
{"architecture", "design microservice architecture"},
{"system design", "system design for high traffic"},
{"security audit", "security audit of api"},
{"full stack", "build a full stack application"},
{"migration", "migration from mysql to postgres"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := classifyQwenTask(tt.query, ModeBuildContext)
t.Logf("%s: %v", tt.query, got)
})
}
}
func TestQwenRouter_ModeAffectsClassification(t *testing.T) {
query := "how do i fix this bug"
ask := classifyQwenTask(query, ModeAskContext)
build := classifyQwenTask(query, ModeBuildContext)
// BUILD mode should generally prefer equal or larger models than ASK
// Note: This is a soft requirement - the mode adjustment is subtle
t.Logf("ASK mode: %v, BUILD mode: %v", ask, build)
}
func TestQwenRouter_WordCountAffectsClassification(t *testing.T) {
short := "what is go"
long := "what is the go programming language and how does it compare to rust and what are its main features and use cases in modern software development"
shortComplexity := classifyQwenTask(short, ModeAskContext)
longComplexity := classifyQwenTask(long, ModeAskContext)
// Long query should ideally be more complex, but at minimum not less
// Note: This test documents the behavior - word count does affect scoring
t.Logf("short (%d chars): %v, long (%d chars): %v", len(short), shortComplexity, len(long), longComplexity)
}
func TestQwenRouter_CodePatterns(t *testing.T) {
tests := []struct {
name string
query string
maxComplexity QwenComplexity
}{
{"simple variable", "declare a variable", QwenModerate},
{"simple function", "write a function", QwenAdvanced},
{"moderate struct", "define a struct", QwenAdvanced},
{"moderate interface", "implement an interface", QwenAdvanced},
{"moderate concurrency", "add concurrency with goroutines", QwenAdvanced},
{"advanced architecture", "design the architecture", QwenAdvanced},
{"advanced distributed", "distributed system design", QwenAdvanced},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := classifyQwenTask(tt.query, ModeBuildContext)
// All code patterns should classify as something (not panic)
t.Logf("%s: %v", tt.query, got)
})
}
}
func TestQwenRouter_SelectAskModel(t *testing.T) {
cfg := DefaultModelConfig()
router := NewQwenModelRouter(&cfg)
router.SetModeContext(ModeAskContext)
// Simple question should get small model
model := router.SelectModelForMode("what is go", ModeAskContext)
if model != "qwen3.5:0.8b" && model != "qwen3.5:2b" {
t.Errorf("ASK mode simple query should get small model, got %s", model)
}
// Complex question should get capable model (2B or higher)
model = router.SelectModelForMode("design a distributed system", ModeAskContext)
if model == "qwen3.5:0.8b" {
t.Errorf("ASK mode complex query should not get 0.8B model, got %s", model)
}
}
func TestQwenRouter_SelectPlanModel(t *testing.T) {
cfg := DefaultModelConfig()
router := NewQwenModelRouter(&cfg)
router.SetModeContext(ModePlanContext)
// Planning should prefer 4B for reasoning
model := router.SelectModelForMode("plan the architecture", ModePlanContext)
if model != "qwen3.5:4b" && model != "qwen3.5:9b" {
t.Errorf("PLAN mode should prefer 4B or 9B, got %s", model)
}
}
func TestQwenRouter_SelectBuildModel(t *testing.T) {
cfg := DefaultModelConfig()
router := NewQwenModelRouter(&cfg)
router.SetModeContext(ModeBuildContext)
// Building should prefer capable models
model := router.SelectModelForMode("implement the feature", ModeBuildContext)
if model != "qwen3.5:4b" && model != "qwen3.5:9b" {
t.Errorf("BUILD mode should prefer 4B or 9B, got %s", model)
}
}
func TestQwenRouter_GetRecommendedModel(t *testing.T) {
cfg := DefaultModelConfig()
router := NewQwenModelRouter(&cfg)
model, reason, complexity := router.GetRecommendedModel("what is go")
if model == "" {
t.Error("GetRecommendedModel should return a model")
}
if reason == "" {
t.Error("GetRecommendedModel should return a reason")
}
if complexity == "" {
t.Error("GetRecommendedModel should return a complexity")
}
}
func TestQwenRouter_QuestionMarkHandling(t *testing.T) {
// Short questions with ? should be simpler
short := "what is go?"
long := "can you explain what the go programming language is and how it works?"
shortComplexity := classifyQwenTask(short, ModeAskContext)
longComplexity := classifyQwenTask(long, ModeAskContext)
if shortComplexity >= longComplexity {
t.Logf("Note: short question complexity (%v) vs long (%v)", shortComplexity, longComplexity)
}
}
func TestQwenRouter_WhyQuestions(t *testing.T) {
// Why questions need reasoning
why := "why does this code fail"
what := "what does this code do"
whyComplexity := classifyQwenTask(why, ModeAskContext)
whatComplexity := classifyQwenTask(what, ModeAskContext)
if whyComplexity < whatComplexity {
t.Errorf("why questions should be more complex: why=%v, what=%v", whyComplexity, whatComplexity)
}
}
func BenchmarkQwenRouter_ClassifyTask(b *testing.B) {
queries := []string{
"what is go",
"how do i create a file",
"debug this nil pointer error",
"design microservice architecture",
}
for i := 0; i < b.N; i++ {
for _, q := range queries {
_ = classifyQwenTask(q, ModeAskContext)
}
}
}
func BenchmarkQwenRouter_SelectModel(b *testing.B) {
cfg := DefaultModelConfig()
router := NewQwenModelRouter(&cfg)
queries := []string{
"what is go",
"how do i create a file",
"debug this nil pointer error",
"design microservice architecture",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, q := range queries {
_ = router.SelectModel(q)
}
}
}

318
internal/config/router.go Normal file
View File

@ -0,0 +1,318 @@
package config
import (
"context"
"strings"
"sync"
"time"
)
type TaskComplexity string
const (
ComplexitySimple TaskComplexity = "simple"
ComplexityMedium TaskComplexity = "medium"
ComplexityComplex TaskComplexity = "complex"
ComplexityAdvanced TaskComplexity = "advanced"
)
var simpleIndicators = []string{
"what is", "how do i", "explain", "what does",
"find", "search", "list", "show", "get",
"print", "echo", "read", "cat", "ls",
"simple", "quick", "fast",
}
var mediumIndicators = []string{
"create", "write", "generate", "add", "modify",
"change", "update", "fix", "refactor",
"function", "class", "variable", "test",
"script", "command", "file", "directory",
}
var complexIndicators = []string{
"debug", "error", "bug", "issue", "problem",
"refactor", "architecture", "design", "review",
"multiple", "several", "across", "migrate",
"optimize", "performance", "security",
"explain why", "analyze", "compare",
}
var advancedIndicators = []string{
"build a", "create a", "implement", "develop",
"full stack", "system", "infrastructure",
"multi-step", "complex", "comprehensive",
"security audit", "architecture design",
}
// ModelPinger is an interface for checking if a model is available.
type ModelPinger interface {
PingModel(ctx context.Context, model string) error
}
// ModelOverride records when a user explicitly selects a model.
type ModelOverride struct {
Query string
UserModel string
RouterModel string
Timestamp time.Time
}
type Router struct {
config *ModelConfig
overrideLog []ModelOverride
mu sync.RWMutex
}
func NewRouter(cfg *ModelConfig) *Router {
return &Router{
config: cfg,
overrideLog: make([]ModelOverride, 0),
}
}
func (r *Router) ClassifyTaskComplexity(query string) TaskComplexity {
return ClassifyTask(query)
}
func (r *Router) SelectModel(query string) string {
complexity := r.ClassifyTaskComplexity(query)
// Check learned patterns if we have enough data
wordComplexity := r.getLearnedPatterns()
if len(wordComplexity) > 0 {
words := strings.Fields(strings.ToLower(query))
// Count votes from learned patterns
complexityVotes := make(map[TaskComplexity]int)
for _, w := range words {
if len(w) >= 3 { // Skip short words
if c, ok := wordComplexity[w]; ok {
complexityVotes[c]++
}
}
}
// If strong learned signal (>30% words match a pattern), use it
if len(words) > 0 {
matchRatio := float64(complexityVotes[ComplexitySimple]+complexityVotes[ComplexityAdvanced]) / float64(len(words))
if matchRatio > 0.3 {
if complexityVotes[ComplexityAdvanced] > complexityVotes[ComplexitySimple] {
complexity = ComplexityAdvanced
} else if complexityVotes[ComplexitySimple] > complexityVotes[ComplexityAdvanced] {
complexity = ComplexitySimple
}
}
}
}
return r.config.SelectModelForTask(string(complexity))
}
// RecordOverride logs when a user explicitly selects a model.
// This helps the router learn from user preferences.
func (r *Router) RecordOverride(query, userModel string) {
r.mu.Lock()
defer r.mu.Unlock()
routerModel := r.SelectModel(query)
r.overrideLog = append(r.overrideLog, ModelOverride{
Query: query,
UserModel: userModel,
RouterModel: routerModel,
Timestamp: time.Now(),
})
// Keep last 100 overrides
if len(r.overrideLog) > 100 {
r.overrideLog = r.overrideLog[len(r.overrideLog)-100:]
}
}
// getLearnedPatterns analyzes override history to find word->complexity mappings.
func (r *Router) getLearnedPatterns() map[string]TaskComplexity {
r.mu.RLock()
defer r.mu.RUnlock()
if len(r.overrideLog) < 3 {
return nil // Not enough data
}
wordCounts := make(map[string]map[TaskComplexity]int)
for _, o := range r.overrideLog {
if o.Query == "" || o.UserModel == "" {
continue
}
// Determine complexity from user-selected model
var complexity TaskComplexity
switch {
case strings.Contains(o.UserModel, "0.8") || strings.Contains(o.UserModel, "2b"):
complexity = ComplexitySimple
case strings.Contains(o.UserModel, "4b"):
complexity = ComplexityMedium
case strings.Contains(o.UserModel, "9b"):
complexity = ComplexityAdvanced
default:
continue
}
words := strings.Fields(strings.ToLower(o.Query))
for _, w := range words {
if len(w) < 3 {
continue // Skip short words
}
if _, ok := wordCounts[w]; !ok {
wordCounts[w] = make(map[TaskComplexity]int)
}
wordCounts[w][complexity]++
}
}
// For each word, find dominant complexity
wordComplexity := make(map[string]TaskComplexity)
for word, counts := range wordCounts {
var maxCount int
var dominant TaskComplexity
for c, cnt := range counts {
if cnt > maxCount {
maxCount = cnt
dominant = c
}
}
// Only use if we have enough samples (at least 2 overrides)
if maxCount >= 2 {
wordComplexity[word] = dominant
}
}
return wordComplexity
}
func (r *Router) GetFallbackChain(currentModel string) []string {
chain := r.config.FallbackChain
for i, model := range chain {
if model == currentModel {
return chain[i:]
}
}
return chain
}
func (r *Router) GetModelForCapability(capability ModelCapability) string {
for _, m := range r.config.Models {
if m.Capability == capability {
return m.Name
}
}
return r.config.DefaultModel
}
// SelectAvailableModel returns the first available model from the fallback chain.
// It checks each model in order and returns the first one that responds to a ping.
// If no models are available, returns the default model.
func (r *Router) SelectAvailableModel(ctx context.Context, pinger ModelPinger) string {
chain := r.config.FallbackChain
for _, model := range chain {
if err := pinger.PingModel(ctx, model); err == nil {
return model
}
}
// Fallback to default if none available
return r.config.DefaultModel
}
// SelectAvailableModelForTask returns the first available model for the given task complexity.
// It prioritizes models appropriate for the task, then falls back to larger models if unavailable.
func (r *Router) SelectAvailableModelForTask(ctx context.Context, pinger ModelPinger, query string) string {
// First, get the preferred model for this task
preferred := r.SelectModel(query)
// Check if preferred model is available
if err := pinger.PingModel(ctx, preferred); err == nil {
return preferred
}
// Try fallback chain
chain := r.GetFallbackChain(preferred)
for _, model := range chain {
if err := pinger.PingModel(ctx, model); err == nil {
return model
}
}
// Last resort: default model
return r.config.DefaultModel
}
func (r *Router) ForceModel(name string) (*Model, error) {
return r.config.GetModel(name)
}
func (r *Router) ListModels() []Model {
return r.config.Models
}
func (r *Router) GetDefaultModel() string {
return r.config.DefaultModel
}
func ClassifyTask(query string) TaskComplexity {
lowerQuery := strings.ToLower(query)
wordCount := len(strings.Fields(query))
score := 0
for _, indicator := range simpleIndicators {
if strings.Contains(lowerQuery, indicator) {
score -= 2
}
}
for _, indicator := range mediumIndicators {
if strings.Contains(lowerQuery, indicator) {
score += 1
}
}
for _, indicator := range complexIndicators {
if strings.Contains(lowerQuery, indicator) {
score += 2
}
}
for _, indicator := range advancedIndicators {
if strings.Contains(lowerQuery, indicator) {
score += 3
}
}
if wordCount > 50 {
score += 2
}
if strings.Contains(lowerQuery, "why") || strings.Contains(lowerQuery, "reason") {
score += 1
}
if strings.Contains(lowerQuery, "how") && wordCount > 10 {
score += 1
}
switch {
case score <= -2:
return ComplexitySimple
case score <= 1:
return ComplexityMedium
case score <= 4:
return ComplexityComplex
default:
return ComplexityAdvanced
}
}

View File

@ -0,0 +1,166 @@
package config
import (
"strings"
"testing"
)
func TestClassifyTask(t *testing.T) {
tests := []struct {
name string
query string
want TaskComplexity
}{
{name: "empty query", query: "", want: ComplexityMedium},
{name: "simple what is", query: "what is Go", want: ComplexitySimple},
// "create a function": medium "create" +1, "function" +1, advanced "create a" +3 = 5 → advanced
{name: "create a function is advanced due to overlaps", query: "create a function", want: ComplexityAdvanced},
// "debug this error across multiple files": complex "debug" +2, "error" +2, "bug" +2 (substring of debug),
// "multiple" +2, "across" +2 = 10, medium "file" +1 = 11 → advanced
{name: "debug across files is advanced", query: "debug this error across multiple files", want: ComplexityAdvanced},
// "implement a full stack system with infrastructure": advanced "implement" +3, "full stack" +3, "system" +3,
// "infrastructure" +3 = 12 → advanced
{name: "advanced full stack system", query: "implement a full stack system with infrastructure", want: ComplexityAdvanced},
// Boundary: "explain" → simple -2 → score -2 → simple
{name: "boundary simple score -2", query: "explain", want: ComplexitySimple},
// No indicators → score 0 → medium
{name: "boundary medium score 0", query: "hello world", want: ComplexityMedium},
// "create" → medium +1, but also matches advanced "create a"? No, "create" doesn't contain "create a".
// So just +1 → medium
{name: "boundary medium score 1", query: "create", want: ComplexityMedium},
// "debug" alone: complex "debug" +2, "bug" +2 (substring) = 4 → complex
{name: "debug alone is complex", query: "debug", want: ComplexityComplex},
// "debug error": "debug" +2, "error" +2, "bug" +2 (substring of debug) = 6 → advanced
{name: "debug error is advanced", query: "debug error", want: ComplexityAdvanced},
// Word count >50 bonus (+2) with "debug": "debug" +2, "bug" +2 = 4, +2 word bonus = 6 → advanced
{
name: "word count bonus over 50 with debug",
query: strings.Repeat("word ", 51) + "debug",
want: ComplexityAdvanced,
},
// "why does this happen": "why" +1 = 1 → medium
{name: "why bonus", query: "why does this happen", want: ComplexityMedium},
// "reason for the crash": "reason" +1 = 1 → medium
{name: "reason bonus", query: "reason for the crash", want: ComplexityMedium},
// "how about we think...": no indicators, "how" + >10 words +1 = 1 → medium
{name: "how with many words", query: "how about we think about the things that are happening right now in the code base", want: ComplexityMedium},
// Case insensitivity
{name: "case insensitive WHAT IS", query: "WHAT IS Go", want: ComplexitySimple},
{name: "case insensitive EXPLAIN", query: "EXPLAIN this code", want: ComplexitySimple},
// Pure simple: multiple simple indicators
{name: "multiple simple indicators", query: "what is this simple quick search", want: ComplexitySimple},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ClassifyTask(tt.query)
if got != tt.want {
t.Errorf("ClassifyTask(%q) = %q, want %q", tt.query, got, tt.want)
}
})
}
}
func TestRouter_GetFallbackChain(t *testing.T) {
cfg := &ModelConfig{
FallbackChain: []string{"a", "b", "c", "d"},
}
r := NewRouter(cfg)
tests := []struct {
name string
model string
wantLen int
wantAll bool // true means expect full chain
}{
{name: "found at start", model: "a", wantLen: 4},
{name: "found in middle", model: "c", wantLen: 2},
{name: "found at end", model: "d", wantLen: 1},
{name: "not found returns full chain", model: "unknown", wantLen: 4, wantAll: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := r.GetFallbackChain(tt.model)
if len(got) != tt.wantLen {
t.Errorf("GetFallbackChain(%q) returned %d items, want %d", tt.model, len(got), tt.wantLen)
}
if tt.wantAll && got[0] != "a" {
t.Errorf("GetFallbackChain(%q) first element = %q, want %q", tt.model, got[0], "a")
}
})
}
}
func TestRouter_GetModelForCapability(t *testing.T) {
cfg := &ModelConfig{
Models: []Model{
{Name: "fast", Capability: CapabilitySimple},
{Name: "mid", Capability: CapabilityMedium},
{Name: "big", Capability: CapabilityComplex},
},
DefaultModel: "fallback",
}
r := NewRouter(cfg)
tests := []struct {
name string
capability ModelCapability
want string
}{
{name: "match simple", capability: CapabilitySimple, want: "fast"},
{name: "match medium", capability: CapabilityMedium, want: "mid"},
{name: "match complex", capability: CapabilityComplex, want: "big"},
{name: "no match returns default", capability: CapabilityAdvanced, want: "fallback"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := r.GetModelForCapability(tt.capability)
if got != tt.want {
t.Errorf("GetModelForCapability(%d) = %q, want %q", tt.capability, got, tt.want)
}
})
}
}
func TestRouter_SelectModel(t *testing.T) {
cfg := DefaultModelConfig()
r := NewRouter(&cfg)
tests := []struct {
name string
query string
want string
}{
// "what is Go" → simple → first model
{name: "simple query selects first model", query: "what is Go", want: cfg.Models[0].Name},
// "debug" → complex → complex-capable model
{name: "complex query selects complex model", query: "debug", want: "qwen3.5:4b"},
// "implement a system" → advanced → DefaultModel
{name: "advanced query selects default model", query: "implement a full stack system", want: cfg.DefaultModel},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := r.SelectModel(tt.query)
if got != tt.want {
t.Errorf("SelectModel(%q) = %q, want %q", tt.query, got, tt.want)
}
})
}
}

31
internal/db/db.go Normal file
View File

@ -0,0 +1,31 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package db
import (
"context"
"database/sql"
)
type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
}
}

View File

@ -0,0 +1,58 @@
-- Sessions table: replaces noted CLI dependency for session persistence.
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL DEFAULT '',
model TEXT NOT NULL DEFAULT '',
mode TEXT NOT NULL DEFAULT 'BUILD',
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
);
-- Session messages: individual chat entries within a session.
CREATE TABLE IF NOT EXISTS session_messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
role TEXT NOT NULL, -- 'user', 'assistant', 'tool', 'system', 'error'
content TEXT NOT NULL DEFAULT '',
tool_name TEXT NOT NULL DEFAULT '',
tool_args TEXT NOT NULL DEFAULT '',
is_error INTEGER NOT NULL DEFAULT 0,
thinking TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
);
CREATE INDEX IF NOT EXISTS idx_session_messages_session_id ON session_messages(session_id);
-- Tool permissions: per-tool allow/deny/always-allow.
CREATE TABLE IF NOT EXISTS tool_permissions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tool_name TEXT NOT NULL UNIQUE,
policy TEXT NOT NULL DEFAULT 'ask', -- 'allow', 'deny', 'ask'
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
);
-- Token usage stats: per-turn tracking.
CREATE TABLE IF NOT EXISTS token_stats (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
turn INTEGER NOT NULL DEFAULT 0,
eval_count INTEGER NOT NULL DEFAULT 0,
prompt_tokens INTEGER NOT NULL DEFAULT 0,
model TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
);
CREATE INDEX IF NOT EXISTS idx_token_stats_session_id ON token_stats(session_id);
-- File changes: files modified by agent during a session.
CREATE TABLE IF NOT EXISTS file_changes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
file_path TEXT NOT NULL,
tool_name TEXT NOT NULL DEFAULT '',
added INTEGER NOT NULL DEFAULT 0,
removed INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
);
CREATE INDEX IF NOT EXISTS idx_file_changes_session_id ON file_changes(session_id);

53
internal/db/models.go Normal file
View File

@ -0,0 +1,53 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package db
type FileChange struct {
ID int64 `json:"id"`
SessionID int64 `json:"session_id"`
FilePath string `json:"file_path"`
ToolName string `json:"tool_name"`
Added int64 `json:"added"`
Removed int64 `json:"removed"`
CreatedAt string `json:"created_at"`
}
type Session struct {
ID int64 `json:"id"`
Title string `json:"title"`
Model string `json:"model"`
Mode string `json:"mode"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
type SessionMessage struct {
ID int64 `json:"id"`
SessionID int64 `json:"session_id"`
Role string `json:"role"`
Content string `json:"content"`
ToolName string `json:"tool_name"`
ToolArgs string `json:"tool_args"`
IsError int64 `json:"is_error"`
Thinking string `json:"thinking"`
CreatedAt string `json:"created_at"`
}
type TokenStat struct {
ID int64 `json:"id"`
SessionID int64 `json:"session_id"`
Turn int64 `json:"turn"`
EvalCount int64 `json:"eval_count"`
PromptTokens int64 `json:"prompt_tokens"`
Model string `json:"model"`
CreatedAt string `json:"created_at"`
}
type ToolPermission struct {
ID int64 `json:"id"`
ToolName string `json:"tool_name"`
Policy string `json:"policy"`
UpdatedAt string `json:"updated_at"`
}

View File

@ -0,0 +1,100 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: permissions.sql
package db
import (
"context"
)
const deleteToolPermission = `-- name: DeleteToolPermission :exec
DELETE FROM tool_permissions WHERE tool_name = ?
`
func (q *Queries) DeleteToolPermission(ctx context.Context, toolName string) error {
_, err := q.db.ExecContext(ctx, deleteToolPermission, toolName)
return err
}
const getToolPermission = `-- name: GetToolPermission :one
SELECT id, tool_name, policy, updated_at FROM tool_permissions WHERE tool_name = ?
`
func (q *Queries) GetToolPermission(ctx context.Context, toolName string) (ToolPermission, error) {
row := q.db.QueryRowContext(ctx, getToolPermission, toolName)
var i ToolPermission
err := row.Scan(
&i.ID,
&i.ToolName,
&i.Policy,
&i.UpdatedAt,
)
return i, err
}
const listToolPermissions = `-- name: ListToolPermissions :many
SELECT id, tool_name, policy, updated_at FROM tool_permissions ORDER BY tool_name ASC
`
func (q *Queries) ListToolPermissions(ctx context.Context) ([]ToolPermission, error) {
rows, err := q.db.QueryContext(ctx, listToolPermissions)
if err != nil {
return nil, err
}
defer rows.Close()
items := []ToolPermission{}
for rows.Next() {
var i ToolPermission
if err := rows.Scan(
&i.ID,
&i.ToolName,
&i.Policy,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const resetToolPermissions = `-- name: ResetToolPermissions :exec
DELETE FROM tool_permissions
`
func (q *Queries) ResetToolPermissions(ctx context.Context) error {
_, err := q.db.ExecContext(ctx, resetToolPermissions)
return err
}
const upsertToolPermission = `-- name: UpsertToolPermission :one
INSERT INTO tool_permissions (tool_name, policy)
VALUES (?, ?)
ON CONFLICT(tool_name) DO UPDATE SET policy = excluded.policy, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
RETURNING id, tool_name, policy, updated_at
`
type UpsertToolPermissionParams struct {
ToolName string `json:"tool_name"`
Policy string `json:"policy"`
}
func (q *Queries) UpsertToolPermission(ctx context.Context, arg UpsertToolPermissionParams) (ToolPermission, error) {
row := q.db.QueryRowContext(ctx, upsertToolPermission, arg.ToolName, arg.Policy)
var i ToolPermission
err := row.Scan(
&i.ID,
&i.ToolName,
&i.Policy,
&i.UpdatedAt,
)
return i, err
}

View File

@ -0,0 +1,17 @@
-- name: GetToolPermission :one
SELECT * FROM tool_permissions WHERE tool_name = ?;
-- name: UpsertToolPermission :one
INSERT INTO tool_permissions (tool_name, policy)
VALUES (?, ?)
ON CONFLICT(tool_name) DO UPDATE SET policy = excluded.policy, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
RETURNING *;
-- name: ListToolPermissions :many
SELECT * FROM tool_permissions ORDER BY tool_name ASC;
-- name: DeleteToolPermission :exec
DELETE FROM tool_permissions WHERE tool_name = ?;
-- name: ResetToolPermissions :exec
DELETE FROM tool_permissions;

View File

@ -0,0 +1,27 @@
-- name: CreateSession :one
INSERT INTO sessions (title, model, mode) VALUES (?, ?, ?) RETURNING *;
-- name: GetSession :one
SELECT * FROM sessions WHERE id = ?;
-- name: ListSessions :many
SELECT * FROM sessions ORDER BY updated_at DESC LIMIT ?;
-- name: UpdateSessionTitle :exec
UPDATE sessions SET title = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ?;
-- name: UpdateSessionTimestamp :exec
UPDATE sessions SET updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ?;
-- name: DeleteSession :exec
DELETE FROM sessions WHERE id = ?;
-- name: CreateSessionMessage :one
INSERT INTO session_messages (session_id, role, content, tool_name, tool_args, is_error, thinking)
VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING *;
-- name: GetSessionMessages :many
SELECT * FROM session_messages WHERE session_id = ? ORDER BY id ASC;
-- name: CountSessions :one
SELECT COUNT(*) FROM sessions;

View File

@ -0,0 +1,31 @@
-- name: RecordTokenUsage :one
INSERT INTO token_stats (session_id, turn, eval_count, prompt_tokens, model)
VALUES (?, ?, ?, ?, ?) RETURNING *;
-- name: GetSessionTokenStats :many
SELECT * FROM token_stats WHERE session_id = ? ORDER BY turn ASC;
-- name: GetSessionTotalTokens :one
SELECT
CAST(COALESCE(SUM(eval_count), 0) AS INTEGER) AS total_eval,
CAST(COALESCE(SUM(prompt_tokens), 0) AS INTEGER) AS total_prompt,
CAST(COUNT(*) AS INTEGER) AS turn_count
FROM token_stats WHERE session_id = ?;
-- name: RecordFileChange :one
INSERT INTO file_changes (session_id, file_path, tool_name, added, removed)
VALUES (?, ?, ?, ?, ?) RETURNING *;
-- name: GetSessionFileChanges :many
SELECT * FROM file_changes WHERE session_id = ? ORDER BY created_at ASC;
-- name: GetSessionFileChangeSummary :many
SELECT
file_path,
CAST(COALESCE(SUM(added), 0) AS INTEGER) AS total_added,
CAST(COALESCE(SUM(removed), 0) AS INTEGER) AS total_removed,
CAST(COUNT(*) AS INTEGER) AS change_count
FROM file_changes
WHERE session_id = ?
GROUP BY file_path
ORDER BY file_path ASC;

206
internal/db/sessions.sql.go Normal file
View File

@ -0,0 +1,206 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: sessions.sql
package db
import (
"context"
)
const countSessions = `-- name: CountSessions :one
SELECT COUNT(*) FROM sessions
`
func (q *Queries) CountSessions(ctx context.Context) (int64, error) {
row := q.db.QueryRowContext(ctx, countSessions)
var count int64
err := row.Scan(&count)
return count, err
}
const createSession = `-- name: CreateSession :one
INSERT INTO sessions (title, model, mode) VALUES (?, ?, ?) RETURNING id, title, model, mode, created_at, updated_at
`
type CreateSessionParams struct {
Title string `json:"title"`
Model string `json:"model"`
Mode string `json:"mode"`
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
row := q.db.QueryRowContext(ctx, createSession, arg.Title, arg.Model, arg.Mode)
var i Session
err := row.Scan(
&i.ID,
&i.Title,
&i.Model,
&i.Mode,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createSessionMessage = `-- name: CreateSessionMessage :one
INSERT INTO session_messages (session_id, role, content, tool_name, tool_args, is_error, thinking)
VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING id, session_id, role, content, tool_name, tool_args, is_error, thinking, created_at
`
type CreateSessionMessageParams struct {
SessionID int64 `json:"session_id"`
Role string `json:"role"`
Content string `json:"content"`
ToolName string `json:"tool_name"`
ToolArgs string `json:"tool_args"`
IsError int64 `json:"is_error"`
Thinking string `json:"thinking"`
}
func (q *Queries) CreateSessionMessage(ctx context.Context, arg CreateSessionMessageParams) (SessionMessage, error) {
row := q.db.QueryRowContext(ctx, createSessionMessage,
arg.SessionID,
arg.Role,
arg.Content,
arg.ToolName,
arg.ToolArgs,
arg.IsError,
arg.Thinking,
)
var i SessionMessage
err := row.Scan(
&i.ID,
&i.SessionID,
&i.Role,
&i.Content,
&i.ToolName,
&i.ToolArgs,
&i.IsError,
&i.Thinking,
&i.CreatedAt,
)
return i, err
}
const deleteSession = `-- name: DeleteSession :exec
DELETE FROM sessions WHERE id = ?
`
func (q *Queries) DeleteSession(ctx context.Context, id int64) error {
_, err := q.db.ExecContext(ctx, deleteSession, id)
return err
}
const getSession = `-- name: GetSession :one
SELECT id, title, model, mode, created_at, updated_at FROM sessions WHERE id = ?
`
func (q *Queries) GetSession(ctx context.Context, id int64) (Session, error) {
row := q.db.QueryRowContext(ctx, getSession, id)
var i Session
err := row.Scan(
&i.ID,
&i.Title,
&i.Model,
&i.Mode,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getSessionMessages = `-- name: GetSessionMessages :many
SELECT id, session_id, role, content, tool_name, tool_args, is_error, thinking, created_at FROM session_messages WHERE session_id = ? ORDER BY id ASC
`
func (q *Queries) GetSessionMessages(ctx context.Context, sessionID int64) ([]SessionMessage, error) {
rows, err := q.db.QueryContext(ctx, getSessionMessages, sessionID)
if err != nil {
return nil, err
}
defer rows.Close()
items := []SessionMessage{}
for rows.Next() {
var i SessionMessage
if err := rows.Scan(
&i.ID,
&i.SessionID,
&i.Role,
&i.Content,
&i.ToolName,
&i.ToolArgs,
&i.IsError,
&i.Thinking,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listSessions = `-- name: ListSessions :many
SELECT id, title, model, mode, created_at, updated_at FROM sessions ORDER BY updated_at DESC LIMIT ?
`
func (q *Queries) ListSessions(ctx context.Context, limit int64) ([]Session, error) {
rows, err := q.db.QueryContext(ctx, listSessions, limit)
if err != nil {
return nil, err
}
defer rows.Close()
items := []Session{}
for rows.Next() {
var i Session
if err := rows.Scan(
&i.ID,
&i.Title,
&i.Model,
&i.Mode,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateSessionTimestamp = `-- name: UpdateSessionTimestamp :exec
UPDATE sessions SET updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ?
`
func (q *Queries) UpdateSessionTimestamp(ctx context.Context, id int64) error {
_, err := q.db.ExecContext(ctx, updateSessionTimestamp, id)
return err
}
const updateSessionTitle = `-- name: UpdateSessionTitle :exec
UPDATE sessions SET title = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ?
`
type UpdateSessionTitleParams struct {
Title string `json:"title"`
ID int64 `json:"id"`
}
func (q *Queries) UpdateSessionTitle(ctx context.Context, arg UpdateSessionTitleParams) error {
_, err := q.db.ExecContext(ctx, updateSessionTitle, arg.Title, arg.ID)
return err
}

11
internal/db/sqlc.yaml Normal file
View File

@ -0,0 +1,11 @@
version: "2"
sql:
- engine: "sqlite"
queries: "queries"
schema: "migrations"
gen:
go:
package: "db"
out: "."
emit_json_tags: true
emit_empty_slices: true

216
internal/db/stats.sql.go Normal file
View File

@ -0,0 +1,216 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: stats.sql
package db
import (
"context"
)
const getSessionFileChangeSummary = `-- name: GetSessionFileChangeSummary :many
SELECT
file_path,
CAST(COALESCE(SUM(added), 0) AS INTEGER) AS total_added,
CAST(COALESCE(SUM(removed), 0) AS INTEGER) AS total_removed,
CAST(COUNT(*) AS INTEGER) AS change_count
FROM file_changes
WHERE session_id = ?
GROUP BY file_path
ORDER BY file_path ASC
`
type GetSessionFileChangeSummaryRow struct {
FilePath string `json:"file_path"`
TotalAdded int64 `json:"total_added"`
TotalRemoved int64 `json:"total_removed"`
ChangeCount int64 `json:"change_count"`
}
func (q *Queries) GetSessionFileChangeSummary(ctx context.Context, sessionID int64) ([]GetSessionFileChangeSummaryRow, error) {
rows, err := q.db.QueryContext(ctx, getSessionFileChangeSummary, sessionID)
if err != nil {
return nil, err
}
defer rows.Close()
items := []GetSessionFileChangeSummaryRow{}
for rows.Next() {
var i GetSessionFileChangeSummaryRow
if err := rows.Scan(
&i.FilePath,
&i.TotalAdded,
&i.TotalRemoved,
&i.ChangeCount,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getSessionFileChanges = `-- name: GetSessionFileChanges :many
SELECT id, session_id, file_path, tool_name, added, removed, created_at FROM file_changes WHERE session_id = ? ORDER BY created_at ASC
`
func (q *Queries) GetSessionFileChanges(ctx context.Context, sessionID int64) ([]FileChange, error) {
rows, err := q.db.QueryContext(ctx, getSessionFileChanges, sessionID)
if err != nil {
return nil, err
}
defer rows.Close()
items := []FileChange{}
for rows.Next() {
var i FileChange
if err := rows.Scan(
&i.ID,
&i.SessionID,
&i.FilePath,
&i.ToolName,
&i.Added,
&i.Removed,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getSessionTokenStats = `-- name: GetSessionTokenStats :many
SELECT id, session_id, turn, eval_count, prompt_tokens, model, created_at FROM token_stats WHERE session_id = ? ORDER BY turn ASC
`
func (q *Queries) GetSessionTokenStats(ctx context.Context, sessionID int64) ([]TokenStat, error) {
rows, err := q.db.QueryContext(ctx, getSessionTokenStats, sessionID)
if err != nil {
return nil, err
}
defer rows.Close()
items := []TokenStat{}
for rows.Next() {
var i TokenStat
if err := rows.Scan(
&i.ID,
&i.SessionID,
&i.Turn,
&i.EvalCount,
&i.PromptTokens,
&i.Model,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getSessionTotalTokens = `-- name: GetSessionTotalTokens :one
SELECT
CAST(COALESCE(SUM(eval_count), 0) AS INTEGER) AS total_eval,
CAST(COALESCE(SUM(prompt_tokens), 0) AS INTEGER) AS total_prompt,
CAST(COUNT(*) AS INTEGER) AS turn_count
FROM token_stats WHERE session_id = ?
`
type GetSessionTotalTokensRow struct {
TotalEval int64 `json:"total_eval"`
TotalPrompt int64 `json:"total_prompt"`
TurnCount int64 `json:"turn_count"`
}
func (q *Queries) GetSessionTotalTokens(ctx context.Context, sessionID int64) (GetSessionTotalTokensRow, error) {
row := q.db.QueryRowContext(ctx, getSessionTotalTokens, sessionID)
var i GetSessionTotalTokensRow
err := row.Scan(&i.TotalEval, &i.TotalPrompt, &i.TurnCount)
return i, err
}
const recordFileChange = `-- name: RecordFileChange :one
INSERT INTO file_changes (session_id, file_path, tool_name, added, removed)
VALUES (?, ?, ?, ?, ?) RETURNING id, session_id, file_path, tool_name, added, removed, created_at
`
type RecordFileChangeParams struct {
SessionID int64 `json:"session_id"`
FilePath string `json:"file_path"`
ToolName string `json:"tool_name"`
Added int64 `json:"added"`
Removed int64 `json:"removed"`
}
func (q *Queries) RecordFileChange(ctx context.Context, arg RecordFileChangeParams) (FileChange, error) {
row := q.db.QueryRowContext(ctx, recordFileChange,
arg.SessionID,
arg.FilePath,
arg.ToolName,
arg.Added,
arg.Removed,
)
var i FileChange
err := row.Scan(
&i.ID,
&i.SessionID,
&i.FilePath,
&i.ToolName,
&i.Added,
&i.Removed,
&i.CreatedAt,
)
return i, err
}
const recordTokenUsage = `-- name: RecordTokenUsage :one
INSERT INTO token_stats (session_id, turn, eval_count, prompt_tokens, model)
VALUES (?, ?, ?, ?, ?) RETURNING id, session_id, turn, eval_count, prompt_tokens, model, created_at
`
type RecordTokenUsageParams struct {
SessionID int64 `json:"session_id"`
Turn int64 `json:"turn"`
EvalCount int64 `json:"eval_count"`
PromptTokens int64 `json:"prompt_tokens"`
Model string `json:"model"`
}
func (q *Queries) RecordTokenUsage(ctx context.Context, arg RecordTokenUsageParams) (TokenStat, error) {
row := q.db.QueryRowContext(ctx, recordTokenUsage,
arg.SessionID,
arg.Turn,
arg.EvalCount,
arg.PromptTokens,
arg.Model,
)
var i TokenStat
err := row.Scan(
&i.ID,
&i.SessionID,
&i.Turn,
&i.EvalCount,
&i.PromptTokens,
&i.Model,
&i.CreatedAt,
)
return i, err
}

71
internal/db/store.go Normal file
View File

@ -0,0 +1,71 @@
package db
import (
"database/sql"
"embed"
"fmt"
"os"
"path/filepath"
_ "modernc.org/sqlite"
)
//go:embed migrations/*.sql
var migrations embed.FS
type Store struct {
*Queries
db *sql.DB
}
func Open() (*Store, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("home dir: %w", err)
}
dir := filepath.Join(home, ".config", "ai-agent")
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("create config dir: %w", err)
}
return OpenPath(filepath.Join(dir, "ai-agent.db"))
}
func OpenPath(path string) (*Store, error) {
conn, err := sql.Open("sqlite", path+"?_journal_mode=WAL&_busy_timeout=5000&_foreign_keys=ON")
if err != nil {
return nil, fmt.Errorf("open db: %w", err)
}
if err := runMigrations(conn); err != nil {
conn.Close()
return nil, fmt.Errorf("migrations: %w", err)
}
return &Store{Queries: New(conn), db: conn}, nil
}
func (s *Store) Close() error {
return s.db.Close()
}
func (s *Store) DB() *sql.DB {
return s.db
}
func runMigrations(conn *sql.DB) error {
entries, err := migrations.ReadDir("migrations")
if err != nil {
return fmt.Errorf("read migrations dir: %w", err)
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
data, err := migrations.ReadFile("migrations/" + entry.Name())
if err != nil {
return fmt.Errorf("read migration %s: %w", entry.Name(), err)
}
if _, err := conn.Exec(string(data)); err != nil {
return fmt.Errorf("exec migration %s: %w", entry.Name(), err)
}
}
return nil
}

271
internal/db/store_test.go Normal file
View File

@ -0,0 +1,271 @@
package db
import (
"context"
"os"
"path/filepath"
"testing"
)
func testStore(t *testing.T) *Store {
t.Helper()
dir := t.TempDir()
s, err := OpenPath(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("open store: %v", err)
}
t.Cleanup(func() { s.Close() })
return s
}
func TestOpenAndMigrate(t *testing.T) {
s := testStore(t)
// Verify the store is functional by counting sessions.
ctx := context.Background()
count, err := s.CountSessions(ctx)
if err != nil {
t.Fatalf("count sessions: %v", err)
}
if count != 0 {
t.Fatalf("expected 0 sessions, got %d", count)
}
}
func TestSessionCRUD(t *testing.T) {
s := testStore(t)
ctx := context.Background()
// Create.
sess, err := s.CreateSession(ctx, CreateSessionParams{
Title: "Test Session",
Model: "qwen3.5:4b",
Mode: "BUILD",
})
if err != nil {
t.Fatalf("create session: %v", err)
}
if sess.Title != "Test Session" {
t.Fatalf("expected title 'Test Session', got %q", sess.Title)
}
// Read.
got, err := s.GetSession(ctx, sess.ID)
if err != nil {
t.Fatalf("get session: %v", err)
}
if got.Model != "qwen3.5:4b" {
t.Fatalf("expected model 'qwen3.5:4b', got %q", got.Model)
}
// Create message.
msg, err := s.CreateSessionMessage(ctx, CreateSessionMessageParams{
SessionID: sess.ID,
Role: "user",
Content: "Hello world",
})
if err != nil {
t.Fatalf("create message: %v", err)
}
if msg.Content != "Hello world" {
t.Fatalf("expected content 'Hello world', got %q", msg.Content)
}
// List messages.
msgs, err := s.GetSessionMessages(ctx, sess.ID)
if err != nil {
t.Fatalf("get messages: %v", err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1 message, got %d", len(msgs))
}
// Delete.
if err := s.DeleteSession(ctx, sess.ID); err != nil {
t.Fatalf("delete session: %v", err)
}
count, err := s.CountSessions(ctx)
if err != nil {
t.Fatalf("count: %v", err)
}
if count != 0 {
t.Fatalf("expected 0 after delete, got %d", count)
}
}
func TestToolPermissions(t *testing.T) {
s := testStore(t)
ctx := context.Background()
// Upsert.
perm, err := s.UpsertToolPermission(ctx, UpsertToolPermissionParams{
ToolName: "bash",
Policy: "allow",
})
if err != nil {
t.Fatalf("upsert permission: %v", err)
}
if perm.Policy != "allow" {
t.Fatalf("expected policy 'allow', got %q", perm.Policy)
}
// Update via upsert.
perm2, err := s.UpsertToolPermission(ctx, UpsertToolPermissionParams{
ToolName: "bash",
Policy: "deny",
})
if err != nil {
t.Fatalf("upsert update: %v", err)
}
if perm2.Policy != "deny" {
t.Fatalf("expected policy 'deny', got %q", perm2.Policy)
}
// List.
perms, err := s.ListToolPermissions(ctx)
if err != nil {
t.Fatalf("list permissions: %v", err)
}
if len(perms) != 1 {
t.Fatalf("expected 1 permission, got %d", len(perms))
}
// Reset.
if err := s.ResetToolPermissions(ctx); err != nil {
t.Fatalf("reset: %v", err)
}
perms, err = s.ListToolPermissions(ctx)
if err != nil {
t.Fatalf("list after reset: %v", err)
}
if len(perms) != 0 {
t.Fatalf("expected 0 after reset, got %d", len(perms))
}
}
func TestTokenStats(t *testing.T) {
s := testStore(t)
ctx := context.Background()
sess, err := s.CreateSession(ctx, CreateSessionParams{
Title: "Stats Test",
Model: "qwen3.5:4b",
Mode: "ASK",
})
if err != nil {
t.Fatalf("create session: %v", err)
}
// Record usage.
_, err = s.RecordTokenUsage(ctx, RecordTokenUsageParams{
SessionID: sess.ID,
Turn: 1,
EvalCount: 100,
PromptTokens: 500,
Model: "qwen3.5:4b",
})
if err != nil {
t.Fatalf("record usage: %v", err)
}
_, err = s.RecordTokenUsage(ctx, RecordTokenUsageParams{
SessionID: sess.ID,
Turn: 2,
EvalCount: 200,
PromptTokens: 600,
Model: "qwen3.5:4b",
})
if err != nil {
t.Fatalf("record usage 2: %v", err)
}
// Get totals.
totals, err := s.GetSessionTotalTokens(ctx, sess.ID)
if err != nil {
t.Fatalf("get totals: %v", err)
}
if totals.TotalEval != 300 {
t.Fatalf("expected total_eval 300, got %v", totals.TotalEval)
}
if totals.TotalPrompt != 1100 {
t.Fatalf("expected total_prompt 1100, got %v", totals.TotalPrompt)
}
if totals.TurnCount != 2 {
t.Fatalf("expected turn_count 2, got %v", totals.TurnCount)
}
}
func TestFileChanges(t *testing.T) {
s := testStore(t)
ctx := context.Background()
sess, err := s.CreateSession(ctx, CreateSessionParams{
Title: "Changes Test",
Model: "qwen3.5:4b",
Mode: "BUILD",
})
if err != nil {
t.Fatalf("create session: %v", err)
}
_, err = s.RecordFileChange(ctx, RecordFileChangeParams{
SessionID: sess.ID,
FilePath: "main.go",
ToolName: "write_file",
Added: 10,
Removed: 3,
})
if err != nil {
t.Fatalf("record change: %v", err)
}
_, err = s.RecordFileChange(ctx, RecordFileChangeParams{
SessionID: sess.ID,
FilePath: "main.go",
ToolName: "write_file",
Added: 5,
Removed: 2,
})
if err != nil {
t.Fatalf("record change 2: %v", err)
}
summary, err := s.GetSessionFileChangeSummary(ctx, sess.ID)
if err != nil {
t.Fatalf("get summary: %v", err)
}
if len(summary) != 1 {
t.Fatalf("expected 1 file, got %d", len(summary))
}
if summary[0].TotalAdded != 15 {
t.Fatalf("expected 15 added, got %d", summary[0].TotalAdded)
}
}
func TestDoubleOpen(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.db")
s1, err := OpenPath(path)
if err != nil {
t.Fatalf("first open: %v", err)
}
defer s1.Close()
// Running migrations again should be idempotent (IF NOT EXISTS).
s2, err := OpenPath(path)
if err != nil {
t.Fatalf("second open: %v", err)
}
defer s2.Close()
}
func TestOpenDefault(t *testing.T) {
if os.Getenv("CI") != "" {
t.Skip("skip in CI to avoid side effects")
}
s, err := Open()
if err != nil {
t.Fatalf("open default: %v", err)
}
s.Close()
}

126
internal/ice/assembler.go Normal file
View File

@ -0,0 +1,126 @@
package ice
import (
"context"
"fmt"
"strings"
"sync"
"ai-agent/internal/memory"
)
type Assembler struct {
embedder *Embedder
convStore *Store
memStore *memory.Store
budgetCfg BudgetConfig
sessionID string
}
func (a *Assembler) Assemble(ctx context.Context, query string) (string, error) {
budget := a.budgetCfg.Calculate(0)
type convResult struct {
chunks []ContextChunk
err error
}
type memResult struct {
chunks []ContextChunk
}
convCh := make(chan convResult, 1)
memCh := make(chan memResult, 1)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
chunks, err := a.retrieveConversations(ctx, query, budget.Conversation)
convCh <- convResult{chunks: chunks, err: err}
}()
go func() {
defer wg.Done()
chunks := a.retrieveMemories(query, budget.Memory)
memCh <- memResult{chunks: chunks}
}()
wg.Wait()
close(convCh)
close(memCh)
cr := <-convCh
mr := <-memCh
if cr.err != nil {
return "", fmt.Errorf("conversation retrieval: %w", cr.err)
}
return formatContext(cr.chunks, mr.chunks), nil
}
func (a *Assembler) retrieveConversations(ctx context.Context, query string, tokenBudget int) ([]ContextChunk, error) {
if tokenBudget <= 0 {
return nil, nil
}
queryEmb, err := a.embedder.Embed(ctx, query)
if err != nil {
return nil, err
}
results := a.convStore.Search(queryEmb, a.sessionID, 20)
var chunks []ContextChunk
usedTokens := 0
for _, r := range results {
tokens := estimateTokens(r.Entry.Content)
if usedTokens+tokens > tokenBudget {
continue
}
chunks = append(chunks, ContextChunk{
Source: SourceConversation,
Content: r.Entry.Content,
Score: r.Score,
Tokens: tokens,
})
usedTokens += tokens
}
return chunks, nil
}
func (a *Assembler) retrieveMemories(query string, tokenBudget int) []ContextChunk {
if a.memStore == nil || tokenBudget <= 0 {
return nil
}
memories := a.memStore.Recall(query, 10)
var chunks []ContextChunk
usedTokens := 0
for _, m := range memories {
tokens := estimateTokens(m.Content)
if usedTokens+tokens > tokenBudget {
continue
}
content := m.Content
if len(m.Tags) > 0 {
content += " [" + strings.Join(m.Tags, ", ") + "]"
}
chunks = append(chunks, ContextChunk{
Source: SourceMemory,
Content: content,
Tokens: tokens,
})
usedTokens += tokens
}
return chunks
}
func formatContext(convChunks, memChunks []ContextChunk) string {
var sb strings.Builder
if len(convChunks) > 0 {
sb.WriteString("\n## Relevant Past Conversations\n\n")
for _, c := range convChunks {
sb.WriteString("- ")
sb.WriteString(c.Content)
sb.WriteString("\n")
}
}
if len(memChunks) > 0 {
sb.WriteString("\n## Remembered Facts\n\n")
for _, c := range memChunks {
sb.WriteString("- ")
sb.WriteString(c.Content)
sb.WriteString("\n")
}
}
return sb.String()
}

View File

@ -0,0 +1,88 @@
package ice
import (
"strings"
"testing"
)
func TestFormatContext(t *testing.T) {
tests := []struct {
name string
convChunks []ContextChunk
memChunks []ContextChunk
wantConv bool // should contain "Relevant Past Conversations"
wantMem bool // should contain "Remembered Facts"
wantEmpty bool
}{
{
name: "both conversation and memory chunks",
convChunks: []ContextChunk{
{Source: SourceConversation, Content: "past chat about Go"},
},
memChunks: []ContextChunk{
{Source: SourceMemory, Content: "user prefers dark mode"},
},
wantConv: true,
wantMem: true,
},
{
name: "conversations only",
convChunks: []ContextChunk{
{Source: SourceConversation, Content: "previous discussion"},
},
memChunks: nil,
wantConv: true,
wantMem: false,
},
{
name: "memories only",
convChunks: nil,
memChunks: []ContextChunk{
{Source: SourceMemory, Content: "user name is Alice"},
},
wantConv: false,
wantMem: true,
},
{
name: "both empty",
convChunks: nil,
memChunks: nil,
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := formatContext(tt.convChunks, tt.memChunks)
if tt.wantEmpty {
if got != "" {
t.Errorf("expected empty string, got %q", got)
}
return
}
hasConv := strings.Contains(got, "Relevant Past Conversations")
hasMem := strings.Contains(got, "Remembered Facts")
if hasConv != tt.wantConv {
t.Errorf("has conversations section = %v, want %v", hasConv, tt.wantConv)
}
if hasMem != tt.wantMem {
t.Errorf("has memories section = %v, want %v", hasMem, tt.wantMem)
}
// Verify content is present in output.
for _, c := range tt.convChunks {
if !strings.Contains(got, c.Content) {
t.Errorf("output missing conversation content %q", c.Content)
}
}
for _, c := range tt.memChunks {
if !strings.Contains(got, c.Content) {
t.Errorf("output missing memory content %q", c.Content)
}
}
})
}
}

View File

@ -0,0 +1,76 @@
package ice
import (
"context"
"fmt"
"strings"
"ai-agent/internal/llm"
"ai-agent/internal/memory"
)
var autoMemorySystemPrompt = "Extract any important facts, user preferences, decisions, or action items from this exchange.\n" +
"Output one item per line in the format: TYPE: content\n" +
"Where TYPE is one of: FACT, DECISION, PREFERENCE, TODO\n" +
"If there is nothing worth remembering, output exactly: NONE"
var autoMemoryUserTemplate = "User: %s\nAssistant: %s"
type AutoMemory struct {
client llm.Client
memStore *memory.Store
}
func (am *AutoMemory) Detect(ctx context.Context, userMsg, assistantMsg string) error {
if am.memStore == nil {
return nil
}
if len(userMsg) < 20 && len(assistantMsg) < 50 {
return nil
}
prompt := fmt.Sprintf(autoMemoryUserTemplate, userMsg, assistantMsg)
var response strings.Builder
err := am.client.ChatStream(ctx, llm.ChatOptions{
System: autoMemorySystemPrompt,
Messages: []llm.Message{
{Role: "user", Content: prompt},
},
}, func(chunk llm.StreamChunk) error {
response.WriteString(chunk.Text)
return nil
})
if err != nil {
return fmt.Errorf("auto-memory LLM call: %w", err)
}
return am.parseAndSave(response.String())
}
func (am *AutoMemory) parseAndSave(response string) error {
lines := strings.Split(strings.TrimSpace(response), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.EqualFold(line, "NONE") {
continue
}
parts := strings.SplitN(line, ": ", 2)
if len(parts) != 2 {
continue
}
typeName := strings.TrimSpace(parts[0])
content := strings.TrimSpace(parts[1])
if content == "" {
continue
}
tag := strings.ToLower(typeName)
switch tag {
case "fact", "decision", "preference", "todo":
// Valid type.
default:
continue
}
if _, err := am.memStore.Save(content, []string{tag, "auto"}); err != nil {
return fmt.Errorf("save auto-memory: %w", err)
}
}
return nil
}

View File

@ -0,0 +1,104 @@
package ice
import (
"path/filepath"
"testing"
"ai-agent/internal/memory"
)
func TestParseAndSave(t *testing.T) {
tests := []struct {
name string
input string
wantCount int
wantTags [][]string
}{
{
name: "valid FACT and DECISION lines",
input: "FACT: user likes Go\nDECISION: use postgres",
wantCount: 2,
wantTags: [][]string{{"fact", "auto"}, {"decision", "auto"}},
},
{
name: "NONE saves nothing",
input: "NONE",
wantCount: 0,
},
{
name: "empty lines are skipped",
input: "\n\n\n",
wantCount: 0,
},
{
name: "invalid type is skipped",
input: "UNKNOWN: something",
wantCount: 0,
},
{
name: "missing colon format is skipped",
input: "this has no colon",
wantCount: 0,
},
{
name: "PREFERENCE type",
input: "PREFERENCE: dark mode",
wantCount: 1,
wantTags: [][]string{{"preference", "auto"}},
},
{
name: "TODO type",
input: "TODO: fix the bug",
wantCount: 1,
wantTags: [][]string{{"todo", "auto"}},
},
{
name: "mixed valid and invalid",
input: "FACT: real fact\nBAD: not valid\nTODO: real todo",
wantCount: 2,
wantTags: [][]string{{"fact", "auto"}, {"todo", "auto"}},
},
{
name: "empty content after type is skipped",
input: "FACT: ",
wantCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
memPath := filepath.Join(dir, "memories.json")
ms := memory.NewStore(memPath)
am := &AutoMemory{memStore: ms}
err := am.parseAndSave(tt.input)
if err != nil {
t.Fatalf("parseAndSave returned error: %v", err)
}
if ms.Count() != tt.wantCount {
t.Errorf("memory count = %d, want %d", ms.Count(), tt.wantCount)
}
if tt.wantTags != nil {
recent := ms.Recent(tt.wantCount)
for i, j := 0, len(recent)-1; i < j; i, j = i+1, j-1 {
recent[i], recent[j] = recent[j], recent[i]
}
for i, wantTags := range tt.wantTags {
if i >= len(recent) {
t.Errorf("missing memory at index %d", i)
continue
}
got := recent[i].Tags
if len(got) != len(wantTags) {
t.Errorf("memory[%d] tags = %v, want %v", i, got, wantTags)
continue
}
for j := range wantTags {
if got[j] != wantTags[j] {
t.Errorf("memory[%d] tag[%d] = %q, want %q", i, j, got[j], wantTags[j])
}
}
}
}
})
}
}

54
internal/ice/budget.go Normal file
View File

@ -0,0 +1,54 @@
package ice
// BudgetConfig controls how the context window is divided among sources.
type BudgetConfig struct {
NumCtx int
SystemReserve int // tokens reserved for system prompt
RecentReserve int // tokens reserved for recent conversation
ConversationPct float64 // fraction of remaining budget for past conversations
MemoryPct float64 // fraction of remaining budget for memories
CodePct float64 // fraction of remaining budget for code context
}
// DefaultBudgetConfig returns sensible defaults for a given context window.
func DefaultBudgetConfig(numCtx int) BudgetConfig {
return BudgetConfig{
NumCtx: numCtx,
SystemReserve: 1500,
RecentReserve: 2000,
ConversationPct: 0.40,
MemoryPct: 0.20,
CodePct: 0.40,
}
}
// Calculate allocates token budgets given how many tokens the current prompt uses.
func (bc BudgetConfig) Calculate(promptTokens int) Budget {
// Use 75% of numCtx as total available.
available := int(float64(bc.NumCtx) * 0.75)
available -= bc.SystemReserve
available -= bc.RecentReserve
available -= promptTokens
if available < 0 {
available = 0
}
return Budget{
Total: available,
System: bc.SystemReserve,
Recent: bc.RecentReserve,
Conversation: int(float64(available) * bc.ConversationPct),
Memory: int(float64(available) * bc.MemoryPct),
Code: int(float64(available) * bc.CodePct),
}
}
// estimateTokens returns a rough token count for a string (chars / 4).
func estimateTokens(s string) int {
n := len(s) / 4
if n == 0 && len(s) > 0 {
n = 1
}
return n
}

133
internal/ice/budget_test.go Normal file
View File

@ -0,0 +1,133 @@
package ice
import "testing"
func TestBudgetConfig_Calculate(t *testing.T) {
tests := []struct {
name string
cfg BudgetConfig
promptTokens int
wantTotal int
wantConv int
wantMemory int
wantCode int
}{
{
name: "normal allocation",
cfg: BudgetConfig{
NumCtx: 8192,
SystemReserve: 1500,
RecentReserve: 2000,
ConversationPct: 0.40,
MemoryPct: 0.20,
CodePct: 0.40,
},
promptTokens: 500,
// available = int(8192*0.75) - 1500 - 2000 - 500 = 6144 - 4000 = 2144
wantTotal: 2144,
wantConv: 857, // int(2144 * 0.40) = 857
wantMemory: 428, // int(2144 * 0.20) = 428
wantCode: 857, // int(2144 * 0.40) = 857
},
{
name: "large prompt clamps to zero",
cfg: BudgetConfig{
NumCtx: 8192,
SystemReserve: 1500,
RecentReserve: 2000,
ConversationPct: 0.40,
MemoryPct: 0.20,
CodePct: 0.40,
},
promptTokens: 99999,
wantTotal: 0,
wantConv: 0,
wantMemory: 0,
wantCode: 0,
},
{
name: "exact boundary available is zero",
cfg: BudgetConfig{
NumCtx: 8192,
SystemReserve: 1500,
RecentReserve: 2000,
ConversationPct: 0.40,
MemoryPct: 0.20,
CodePct: 0.40,
},
// int(8192*0.75) - 1500 - 2000 = 2644
promptTokens: 2644,
wantTotal: 0,
wantConv: 0,
wantMemory: 0,
wantCode: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b := tt.cfg.Calculate(tt.promptTokens)
if b.Total != tt.wantTotal {
t.Errorf("Total = %d, want %d", b.Total, tt.wantTotal)
}
if b.Conversation != tt.wantConv {
t.Errorf("Conversation = %d, want %d", b.Conversation, tt.wantConv)
}
if b.Memory != tt.wantMemory {
t.Errorf("Memory = %d, want %d", b.Memory, tt.wantMemory)
}
if b.Code != tt.wantCode {
t.Errorf("Code = %d, want %d", b.Code, tt.wantCode)
}
if b.System != tt.cfg.SystemReserve {
t.Errorf("System = %d, want %d", b.System, tt.cfg.SystemReserve)
}
if b.Recent != tt.cfg.RecentReserve {
t.Errorf("Recent = %d, want %d", b.Recent, tt.cfg.RecentReserve)
}
})
}
}
func TestEstimateTokens(t *testing.T) {
tests := []struct {
name string
input string
want int
}{
{
name: "len/4 heuristic",
input: "hello world",
want: 2, // 11/4 = 2
},
{
name: "single char clamps to 1",
input: "a",
want: 1, // 1/4 = 0, clamp to 1
},
{
name: "empty string",
input: "",
want: 0,
},
{
name: "exactly 4 chars",
input: "abcd",
want: 1, // 4/4 = 1
},
{
name: "three chars clamps to 1",
input: "abc",
want: 1, // 3/4 = 0, clamp to 1
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := estimateTokens(tt.input)
if got != tt.want {
t.Errorf("estimateTokens(%q) = %d, want %d", tt.input, got, tt.want)
}
})
}
}

56
internal/ice/embed.go Normal file
View File

@ -0,0 +1,56 @@
package ice
import (
"context"
"fmt"
"ai-agent/internal/llm"
)
const (
defaultEmbedModel = "nomic-embed-text"
maxBatchSize = 32
)
type Embedder struct {
client llm.Client
model string
}
func NewEmbedder(client llm.Client, model string) *Embedder {
if model == "" {
model = defaultEmbedModel
}
return &Embedder{client: client, model: model}
}
func (e *Embedder) Embed(ctx context.Context, text string) ([]float32, error) {
vecs, err := e.EmbedBatch(ctx, []string{text})
if err != nil {
return nil, err
}
if len(vecs) == 0 {
return nil, fmt.Errorf("empty embedding response")
}
return vecs[0], nil
}
func (e *Embedder) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
if len(texts) == 0 {
return nil, nil
}
var all [][]float32
for i := 0; i < len(texts); i += maxBatchSize {
end := i + maxBatchSize
if end > len(texts) {
end = len(texts)
}
batch := texts[i:end]
vecs, err := e.client.Embed(ctx, e.model, batch)
if err != nil {
return nil, fmt.Errorf("embed batch [%d:%d]: %w", i, end, err)
}
all = append(all, vecs...)
}
return all, nil
}

113
internal/ice/engine.go Normal file
View File

@ -0,0 +1,113 @@
package ice
import (
"context"
"fmt"
"os"
"path/filepath"
"time"
"ai-agent/internal/llm"
"ai-agent/internal/memory"
)
type EngineConfig struct {
EmbedModel string
StorePath string
NumCtx int
}
type Engine struct {
embedder *Embedder
store *Store
memStore *memory.Store
budgetCfg BudgetConfig
sessionID string
turnIndex int
autoMemory *AutoMemory
}
func NewEngine(client llm.Client, memStore *memory.Store, cfg EngineConfig) (*Engine, error) {
storePath := cfg.StorePath
if storePath == "" {
home, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("determine home dir: %w", err)
}
storePath = filepath.Join(home, ".config", "ai-agent", "conversations.json")
}
embedModel := cfg.EmbedModel
if embedModel == "" {
embedModel = defaultEmbedModel
}
sessionID := fmt.Sprintf("s_%d", time.Now().UnixNano())
return &Engine{
embedder: NewEmbedder(client, embedModel),
store: NewStore(storePath),
memStore: memStore,
budgetCfg: DefaultBudgetConfig(cfg.NumCtx),
sessionID: sessionID,
autoMemory: &AutoMemory{client: client, memStore: memStore},
}, nil
}
func (e *Engine) AssembleContext(ctx context.Context, query string) (string, error) {
a := &Assembler{
embedder: e.embedder,
convStore: e.store,
memStore: e.memStore,
budgetCfg: e.budgetCfg,
sessionID: e.sessionID,
}
return a.Assemble(ctx, query)
}
func (e *Engine) IndexMessage(ctx context.Context, role, content string) error {
if content == "" {
return nil
}
text := content
if len(text) > 2000 {
text = text[:2000]
}
emb, err := e.embedder.Embed(ctx, text)
if err != nil {
return fmt.Errorf("embed message: %w", err)
}
e.turnIndex++
_, err = e.store.Add(e.sessionID, role, text, emb, e.turnIndex)
return err
}
func (e *Engine) IndexSummary(ctx context.Context, summary string) error {
if summary == "" {
return nil
}
emb, err := e.embedder.Embed(ctx, summary)
if err != nil {
return fmt.Errorf("embed summary: %w", err)
}
_, err = e.store.Add(e.sessionID, "summary", summary, emb, e.turnIndex)
return err
}
func (e *Engine) DetectAutoMemory(ctx context.Context, userMsg, assistantMsg string) {
if e.autoMemory == nil {
return
}
go func() {
_ = e.autoMemory.Detect(ctx, userMsg, assistantMsg)
}()
}
func (e *Engine) Flush() error {
return e.store.Flush()
}
func (e *Engine) Store() *Store {
return e.store
}
func (e *Engine) SessionID() string {
return e.sessionID
}

View File

@ -0,0 +1,73 @@
package ice
import (
"testing"
)
func TestEngineConfigDefaults(t *testing.T) {
// Test embed model default
embedModel := ""
if embedModel == "" {
embedModel = defaultEmbedModel
}
if embedModel != defaultEmbedModel {
t.Errorf("embedModel = %q, want %q", embedModel, defaultEmbedModel)
}
// Test custom embed model
cfg := EngineConfig{
EmbedModel: "custom-model",
}
if cfg.EmbedModel != "custom-model" {
t.Errorf("EmbedModel = %q, want %q", cfg.EmbedModel, "custom-model")
}
}
func TestBudgetConfigCalculate(t *testing.T) {
cfg := DefaultBudgetConfig(16384)
budget := cfg.Calculate(100)
// 16384 * 0.75 = 12288
// 12288 - 1500 - 2000 - 100 = 8688
if budget.Total != 8688 {
t.Errorf("Total = %d, want %d", budget.Total, 8688)
}
if budget.System != 1500 {
t.Errorf("System = %d, want %d", budget.System, 1500)
}
if budget.Recent != 2000 {
t.Errorf("Recent = %d, want %d", budget.Recent, 2000)
}
}
func TestBudgetConfigCalculateNegative(t *testing.T) {
// With small context, should not panic and return zeros
cfg := DefaultBudgetConfig(1000)
budget := cfg.Calculate(500)
// 1000 * 0.75 = 750
// 750 - 1500 - 2000 - 500 = -3250 -> clamped to 0
if budget.Total != 0 {
t.Errorf("Total should be 0 when budget is negative, got %d", budget.Total)
}
}
func TestBudgetConfigPercentages(t *testing.T) {
cfg := DefaultBudgetConfig(16384)
budget := cfg.Calculate(100)
// Check percentages: ConversationPct=0.40, MemoryPct=0.20, CodePct=0.40
// available = 12288 - 1500 - 2000 - 100 = 8688
// Conversation = 8688 * 0.40 = 3475
// Memory = 8688 * 0.20 = 1737
// Code = 8688 * 0.40 = 3475
if budget.Conversation != 3475 {
t.Errorf("Conversation = %d, want %d", budget.Conversation, 3475)
}
if budget.Memory != 1737 {
t.Errorf("Memory = %d, want %d", budget.Memory, 1737)
}
if budget.Code != 3475 {
t.Errorf("Code = %d, want %d", budget.Code, 3475)
}
}

167
internal/ice/store.go Normal file
View File

@ -0,0 +1,167 @@
package ice
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"sort"
"sync"
"time"
)
// timeNow is a variable for testing.
var timeNow = time.Now
const minSimilarityThreshold = 0.3
// Store is a flat-file vector store for conversation history.
// It holds all entries in memory and persists to a JSON file.
type Store struct {
mu sync.Mutex
path string
entries []ConversationEntry
nextID int
dirty bool
}
// NewStore loads an existing store from path or creates an empty one.
func NewStore(path string) *Store {
s := &Store{path: path}
s.load()
return s
}
// Add appends a new conversation entry and returns its ID.
func (s *Store) Add(sessionID, role, content string, embedding []float32, turnIndex int) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.nextID++
entry := ConversationEntry{
ID: s.nextID,
SessionID: sessionID,
Role: role,
Content: content,
Embedding: embedding,
TurnIndex: turnIndex,
}
// Use a zero-value check to set CreatedAt (avoids importing time in every call site).
entry.CreatedAt = timeNow()
s.entries = append(s.entries, entry)
s.dirty = true
return s.nextID, nil
}
// Search returns the top-K entries most similar to queryEmbedding.
// Entries from excludeSession are skipped. Results are sorted by score descending.
func (s *Store) Search(queryEmbedding []float32, excludeSession string, topK int) []ScoredEntry {
s.mu.Lock()
defer s.mu.Unlock()
if len(queryEmbedding) == 0 || len(s.entries) == 0 {
return nil
}
var scored []ScoredEntry
for _, e := range s.entries {
if e.SessionID == excludeSession {
continue
}
if len(e.Embedding) == 0 {
continue
}
sim := cosineSimilarity(queryEmbedding, e.Embedding)
if sim >= minSimilarityThreshold {
scored = append(scored, ScoredEntry{Entry: e, Score: sim})
}
}
sort.Slice(scored, func(i, j int) bool {
return scored[i].Score > scored[j].Score
})
if len(scored) > topK {
scored = scored[:topK]
}
return scored
}
// Flush persists any pending changes to disk.
func (s *Store) Flush() error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.dirty {
return nil
}
return s.persist()
}
// Count returns the total number of stored entries.
func (s *Store) Count() int {
s.mu.Lock()
defer s.mu.Unlock()
return len(s.entries)
}
// load reads entries from the JSON file.
func (s *Store) load() {
data, err := os.ReadFile(s.path)
if err != nil {
return // File doesn't exist yet.
}
var entries []ConversationEntry
if err := json.Unmarshal(data, &entries); err != nil {
return // Corrupt file, start empty.
}
s.entries = entries
for _, e := range s.entries {
if e.ID > s.nextID {
s.nextID = e.ID
}
}
}
// persist writes all entries to the JSON file.
func (s *Store) persist() error {
dir := filepath.Dir(s.path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return fmt.Errorf("create ice store dir: %w", err)
}
data, err := json.Marshal(s.entries)
if err != nil {
return fmt.Errorf("marshal ice store: %w", err)
}
if err := os.WriteFile(s.path, data, 0o644); err != nil {
return fmt.Errorf("write ice store: %w", err)
}
s.dirty = false
return nil
}
// cosineSimilarity computes the cosine similarity between two vectors.
func cosineSimilarity(a, b []float32) float32 {
if len(a) != len(b) || len(a) == 0 {
return 0
}
var dot, normA, normB float64
for i := range a {
dot += float64(a[i]) * float64(b[i])
normA += float64(a[i]) * float64(a[i])
normB += float64(b[i]) * float64(b[i])
}
denom := math.Sqrt(normA) * math.Sqrt(normB)
if denom == 0 {
return 0
}
return float32(dot / denom)
}

232
internal/ice/store_test.go Normal file
View File

@ -0,0 +1,232 @@
package ice
import (
"math"
"os"
"path/filepath"
"testing"
)
func TestCosineSimilarity(t *testing.T) {
tests := []struct {
name string
a, b []float32
want float32
tol float32
}{
{
name: "identical vectors",
a: []float32{1, 2, 3},
b: []float32{1, 2, 3},
want: 1.0,
tol: 1e-6,
},
{
name: "orthogonal vectors",
a: []float32{1, 0},
b: []float32{0, 1},
want: 0.0,
tol: 1e-6,
},
{
name: "opposite vectors",
a: []float32{1, 0},
b: []float32{-1, 0},
want: -1.0,
tol: 1e-6,
},
{
name: "different lengths returns 0",
a: []float32{1, 0},
b: []float32{1, 0, 0},
want: 0,
tol: 0,
},
{
name: "zero vector returns 0",
a: []float32{0, 0},
b: []float32{1, 1},
want: 0,
tol: 0,
},
{
name: "known value with tolerance",
a: []float32{1, 1},
b: []float32{1, 0},
// 1/(sqrt(2)*1) ≈ 0.7071
want: float32(1.0 / math.Sqrt(2)),
tol: 1e-4,
},
{
name: "empty vectors",
a: []float32{},
b: []float32{},
want: 0,
tol: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := cosineSimilarity(tt.a, tt.b)
diff := got - tt.want
if diff < 0 {
diff = -diff
}
if diff > tt.tol {
t.Errorf("cosineSimilarity(%v, %v) = %f, want %f (±%f)",
tt.a, tt.b, got, tt.want, tt.tol)
}
})
}
}
func TestStore_Add_And_Count(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "store.json")
s := NewStore(path)
if s.Count() != 0 {
t.Fatalf("new store Count = %d, want 0", s.Count())
}
id1, err := s.Add("sess1", "user", "hello", []float32{1, 0, 0}, 0)
if err != nil {
t.Fatalf("Add returned error: %v", err)
}
if id1 != 1 {
t.Errorf("first Add returned id=%d, want 1", id1)
}
if s.Count() != 1 {
t.Errorf("Count after first Add = %d, want 1", s.Count())
}
id2, err := s.Add("sess1", "assistant", "world", []float32{0, 1, 0}, 1)
if err != nil {
t.Fatalf("Add returned error: %v", err)
}
if id2 != 2 {
t.Errorf("second Add returned id=%d, want 2", id2)
}
if s.Count() != 2 {
t.Errorf("Count after second Add = %d, want 2", s.Count())
}
}
func TestStore_Search(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "store.json")
s := NewStore(path)
// Add entries with known embeddings.
s.Add("sess1", "user", "entry A", []float32{1, 0, 0}, 0)
s.Add("sess1", "user", "entry B", []float32{0, 1, 0}, 1)
s.Add("sess2", "user", "entry C", []float32{0.9, 0.1, 0}, 0)
s.Add("sess2", "user", "entry D", []float32{0, 0, 1}, 1) // orthogonal to query
t.Run("similarity filtering and sorting", func(t *testing.T) {
// Query similar to entries A and C, exclude no session.
results := s.Search([]float32{1, 0, 0}, "", 10)
if len(results) == 0 {
t.Fatal("expected results, got 0")
}
// Entry A should be highest (identical to query).
if results[0].Entry.Content != "entry A" {
t.Errorf("top result = %q, want 'entry A'", results[0].Entry.Content)
}
})
t.Run("session exclusion", func(t *testing.T) {
results := s.Search([]float32{1, 0, 0}, "sess1", 10)
for _, r := range results {
if r.Entry.SessionID == "sess1" {
t.Errorf("excluded session sess1 should not appear in results")
}
}
})
t.Run("min threshold 0.3", func(t *testing.T) {
// Entry D: [0,0,1] is orthogonal to [1,0,0] → similarity 0.
results := s.Search([]float32{1, 0, 0}, "", 10)
for _, r := range results {
if r.Score < minSimilarityThreshold {
t.Errorf("result %q has score %f below threshold %f",
r.Entry.Content, r.Score, minSimilarityThreshold)
}
}
})
t.Run("topK limit", func(t *testing.T) {
results := s.Search([]float32{1, 0, 0}, "", 1)
if len(results) > 1 {
t.Errorf("topK=1 but got %d results", len(results))
}
})
t.Run("empty store returns nil", func(t *testing.T) {
emptyPath := filepath.Join(dir, "empty.json")
empty := NewStore(emptyPath)
results := empty.Search([]float32{1, 0}, "", 5)
if results != nil {
t.Errorf("empty store search should return nil, got %v", results)
}
})
t.Run("empty query embedding returns nil", func(t *testing.T) {
results := s.Search([]float32{}, "", 5)
if results != nil {
t.Errorf("empty query should return nil, got %v", results)
}
})
}
func TestStore_Flush_Persistence(t *testing.T) {
t.Run("round trip", func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "store.json")
s1 := NewStore(path)
s1.Add("sess1", "user", "hello", []float32{1, 0}, 0)
s1.Add("sess1", "assistant", "world", []float32{0, 1}, 1)
if err := s1.Flush(); err != nil {
t.Fatalf("Flush: %v", err)
}
// Reload from same path.
s2 := NewStore(path)
if s2.Count() != 2 {
t.Errorf("reloaded store Count = %d, want 2", s2.Count())
}
})
t.Run("corrupt JSON recovery", func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "store.json")
// Write corrupt JSON.
os.WriteFile(path, []byte("not valid json{{{"), 0o644)
s := NewStore(path)
if s.Count() != 0 {
t.Errorf("corrupt store Count = %d, want 0", s.Count())
}
})
t.Run("nextID restoration", func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "store.json")
s1 := NewStore(path)
s1.Add("sess1", "user", "first", []float32{1}, 0)
s1.Add("sess1", "user", "second", []float32{1}, 1)
s1.Flush()
s2 := NewStore(path)
id, _ := s2.Add("sess1", "user", "third", []float32{1}, 2)
if id != 3 {
t.Errorf("continued id = %d, want 3", id)
}
})
}

46
internal/ice/types.go Normal file
View File

@ -0,0 +1,46 @@
package ice
import "time"
// SourceKind identifies where a context chunk came from.
type SourceKind int
const (
SourceConversation SourceKind = iota
SourceMemory
)
// ConversationEntry is a single stored message with its embedding.
type ConversationEntry struct {
ID int `json:"id"`
SessionID string `json:"session_id"`
Role string `json:"role"` // "user", "assistant", "summary"
Content string `json:"content"`
Embedding []float32 `json:"embedding"`
CreatedAt time.Time `json:"created_at"`
TurnIndex int `json:"turn_index"`
}
// ScoredEntry pairs a conversation entry with its similarity score.
type ScoredEntry struct {
Entry ConversationEntry
Score float32
}
// ContextChunk is a piece of assembled context ready for the prompt.
type ContextChunk struct {
Source SourceKind
Content string
Score float32
Tokens int
}
// Budget holds the token allocation for each context source.
type Budget struct {
Total int
System int
Conversation int
Memory int
Code int
Recent int
}

184
internal/initcmd/initcmd.go Normal file
View File

@ -0,0 +1,184 @@
package initcmd
import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
)
// projectMarker maps a marker file name to its detected project type.
var projectMarkers = map[string]string{
"go.mod": "Go",
"go.sum": "Go",
"package.json": "Node.js",
"Cargo.toml": "Rust",
"pyproject.toml": "Python",
"requirements.txt": "Python",
"setup.py": "Python",
"Pipfile": "Python",
"Gemfile": "Ruby",
"pom.xml": "Java (Maven)",
"build.gradle": "Java (Gradle)",
"build.gradle.kts": "Kotlin (Gradle)",
"CMakeLists.txt": "C/C++ (CMake)",
"Makefile": "Make",
"Taskfile.yml": "Taskfile",
"Taskfile.yaml": "Taskfile",
"docker-compose.yml": "Docker Compose",
"docker-compose.yaml": "Docker Compose",
"Dockerfile": "Docker",
".gitignore": "Git",
}
// Options configures the behaviour of Run.
type Options struct {
// Force overwrites an existing AGENT.md.
Force bool
}
// Run scans dir for project markers and generates an AGENT.md file.
// It returns an error if AGENT.md already exists unless opts.Force is true.
func Run(dir string, opts Options) error {
agentPath := filepath.Join(dir, "AGENT.md")
if !opts.Force {
if _, err := os.Stat(agentPath); err == nil {
return fmt.Errorf("AGENT.md already exists in %s (use --force to overwrite)", dir)
}
}
entries, err := os.ReadDir(dir)
if err != nil {
return fmt.Errorf("reading directory: %w", err)
}
// Detect project types from marker files.
detectedTypes := detectProjectTypes(entries)
// Build directory listing.
listing := buildDirectoryListing(dir, entries)
// Generate AGENT.md content.
content := generateAgentMD(detectedTypes, listing)
if err := os.WriteFile(agentPath, []byte(content), 0644); err != nil {
return fmt.Errorf("writing AGENT.md: %w", err)
}
return nil
}
// detectProjectTypes returns a deduplicated, sorted list of project types
// found based on marker files in the directory entries.
func detectProjectTypes(entries []os.DirEntry) []string {
seen := make(map[string]bool)
for _, e := range entries {
if e.IsDir() {
continue
}
if pt, ok := projectMarkers[e.Name()]; ok {
seen[pt] = true
}
}
types := make([]string, 0, len(seen))
for t := range seen {
types = append(types, t)
}
sort.Strings(types)
return types
}
// buildDirectoryListing returns a formatted string of top-level files and
// first-level subdirectory contents.
func buildDirectoryListing(dir string, entries []os.DirEntry) string {
var b strings.Builder
var files []string
var dirs []string
for _, e := range entries {
name := e.Name()
// Skip hidden files/dirs except well-known ones.
if strings.HasPrefix(name, ".") && name != ".gitignore" {
continue
}
if e.IsDir() {
dirs = append(dirs, name)
} else {
files = append(files, name)
}
}
sort.Strings(files)
sort.Strings(dirs)
for _, f := range files {
b.WriteString(f)
b.WriteByte('\n')
}
for _, d := range dirs {
b.WriteString(d + "/\n")
subEntries, err := os.ReadDir(filepath.Join(dir, d))
if err != nil {
continue
}
var subNames []string
for _, se := range subEntries {
n := se.Name()
if strings.HasPrefix(n, ".") {
continue
}
if se.IsDir() {
subNames = append(subNames, n+"/")
} else {
subNames = append(subNames, n)
}
}
sort.Strings(subNames)
for _, sn := range subNames {
b.WriteString(" " + sn + "\n")
}
}
return b.String()
}
// generateAgentMD produces the Markdown content for the AGENT.md file.
func generateAgentMD(projectTypes []string, listing string) string {
var b strings.Builder
b.WriteString("# AGENT.md\n\n")
// Project type section.
b.WriteString("## Project Type\n\n")
if len(projectTypes) == 0 {
b.WriteString("Unknown\n")
} else {
b.WriteString(strings.Join(projectTypes, ", ") + "\n")
}
// Directory structure.
b.WriteString("\n## Directory Structure\n\n")
b.WriteString("```\n")
b.WriteString(listing)
b.WriteString("```\n")
// Placeholder sections.
b.WriteString("\n## Build Commands\n\n")
b.WriteString("<!-- Add build, test, and run commands here -->\n")
b.WriteString("\n## Architecture\n\n")
b.WriteString("<!-- Describe the high-level architecture here -->\n")
b.WriteString("\n## Key Files\n\n")
b.WriteString("<!-- List important files and their purposes here -->\n")
b.WriteString("\n## Notes\n\n")
b.WriteString("<!-- Any additional notes for the agent -->\n")
return b.String()
}

View File

@ -0,0 +1,141 @@
package initcmd
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestRun_GoProject(t *testing.T) {
dir := t.TempDir()
// Create a go.mod marker file.
if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com/test\n"), 0644); err != nil {
t.Fatal(err)
}
// Create a source directory with a file.
if err := os.Mkdir(filepath.Join(dir, "cmd"), 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "cmd", "main.go"), []byte("package main\n"), 0644); err != nil {
t.Fatal(err)
}
if err := Run(dir, Options{}); err != nil {
t.Fatalf("Run() returned error: %v", err)
}
data, err := os.ReadFile(filepath.Join(dir, "AGENT.md"))
if err != nil {
t.Fatalf("reading AGENT.md: %v", err)
}
content := string(data)
// Check project type detection.
if !strings.Contains(content, "Go") {
t.Error("expected AGENT.md to contain 'Go' project type")
}
// Check directory listing includes go.mod.
if !strings.Contains(content, "go.mod") {
t.Error("expected AGENT.md to list go.mod")
}
// Check directory listing includes cmd/ subdirectory.
if !strings.Contains(content, "cmd/") {
t.Error("expected AGENT.md to list cmd/ directory")
}
// Check that main.go appears under cmd/.
if !strings.Contains(content, "main.go") {
t.Error("expected AGENT.md to list main.go inside cmd/")
}
// Check placeholder sections exist.
for _, section := range []string{"## Build Commands", "## Architecture", "## Key Files", "## Notes"} {
if !strings.Contains(content, section) {
t.Errorf("expected AGENT.md to contain section %q", section)
}
}
}
func TestRun_EmptyDirectory(t *testing.T) {
dir := t.TempDir()
if err := Run(dir, Options{}); err != nil {
t.Fatalf("Run() returned error: %v", err)
}
data, err := os.ReadFile(filepath.Join(dir, "AGENT.md"))
if err != nil {
t.Fatalf("reading AGENT.md: %v", err)
}
content := string(data)
// With no marker files, project type should be "Unknown".
if !strings.Contains(content, "Unknown") {
t.Error("expected AGENT.md to contain 'Unknown' project type for empty dir")
}
}
func TestRun_ExistingAgentMD_NoOverwrite(t *testing.T) {
dir := t.TempDir()
agentPath := filepath.Join(dir, "AGENT.md")
original := "# Original content\n"
if err := os.WriteFile(agentPath, []byte(original), 0644); err != nil {
t.Fatal(err)
}
err := Run(dir, Options{})
if err == nil {
t.Fatal("expected error when AGENT.md already exists")
}
if !strings.Contains(err.Error(), "already exists") {
t.Errorf("expected 'already exists' in error, got: %v", err)
}
// Verify file was not modified.
data, err := os.ReadFile(agentPath)
if err != nil {
t.Fatal(err)
}
if string(data) != original {
t.Error("AGENT.md was unexpectedly modified")
}
}
func TestRun_Force(t *testing.T) {
dir := t.TempDir()
agentPath := filepath.Join(dir, "AGENT.md")
if err := os.WriteFile(agentPath, []byte("# Old content\n"), 0644); err != nil {
t.Fatal(err)
}
// Create a go.mod so the new content is distinguishable.
if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module test\n"), 0644); err != nil {
t.Fatal(err)
}
if err := Run(dir, Options{Force: true}); err != nil {
t.Fatalf("Run() with Force=true returned error: %v", err)
}
data, err := os.ReadFile(agentPath)
if err != nil {
t.Fatal(err)
}
content := string(data)
if strings.Contains(content, "Old content") {
t.Error("AGENT.md should have been overwritten with Force=true")
}
if !strings.Contains(content, "Go") {
t.Error("expected new AGENT.md to detect Go project type")
}
}

View File

@ -0,0 +1,230 @@
//go:build integration
// +build integration
package integration
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
"ai-agent/internal/agent"
"ai-agent/internal/command"
"ai-agent/internal/config"
"ai-agent/internal/llm"
"ai-agent/internal/mcp"
"ai-agent/internal/tui"
tea "charm.land/bubbletea/v2"
)
func skipIfNoOllama(t *testing.T) {
if os.Getenv("OLLAMA_HOST") == "" {
os.Setenv("OLLAMA_HOST", "http://localhost:11434")
}
client := llm.NewClient(llm.Config{
BaseURL: os.Getenv("OLLAMA_HOST"),
Model: "qwen3.5:2b",
NumCtx: 262144,
})
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := client.Ping(ctx); err != nil {
t.Skip("Ollama not available: skipping integration test")
}
}
func TestTUI_Initialization(t *testing.T) {
skipIfNoOllama(t)
reg := command.NewRegistry()
command.RegisterBuiltins(reg)
cfg := config.DefaultModelConfig()
router := config.NewRouter(&cfg)
modelManager := llm.NewModelManager("http://localhost:11434", 262144)
modelManager.SetCurrentModel("qwen3.5:2b")
ag := agent.New(modelManager, mcp.NewRegistry(), cfg.Ollama.NumCtx)
ag.SetRouter(router)
completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil)
m := tui.New(ag, reg, nil, completer, modelManager, router, nil)
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
m = updated.(*tui.Model)
if !m.Ready() {
t.Error("TUI should be ready after WindowSizeMsg")
}
}
func TestTUI_ScrollAnchorDuringStreaming(t *testing.T) {
skipIfNoOllama(t)
reg := command.NewRegistry()
command.RegisterBuiltins(reg)
cfg := config.DefaultModelConfig()
router := config.NewRouter(&cfg)
modelManager := llm.NewModelManager("http://localhost:11434", 262144)
ag := agent.New(modelManager, mcp.NewRegistry(), 262144)
ag.SetRouter(router)
completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil)
m := tui.New(ag, reg, nil, completer, modelManager, router, nil)
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
m = updated.(*tui.Model)
if !m.AnchorActive() {
t.Error("anchorActive should be true after initialization")
}
updated, _ = m.Update(tui.StreamTextMsg{Text: "Hello"})
m = updated.(*tui.Model)
if !m.AnchorActive() {
t.Error("anchorActive should remain true during streaming")
}
}
func TestTUI_OverlayRendering(t *testing.T) {
reg := command.NewRegistry()
command.RegisterBuiltins(reg)
cfg := config.DefaultModelConfig()
router := config.NewRouter(&cfg)
modelManager := llm.NewModelManager("http://localhost:11434", 262144)
ag := agent.New(modelManager, mcp.NewRegistry(), 262144)
ag.SetRouter(router)
completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil)
m := tui.New(ag, reg, nil, completer, modelManager, router, nil)
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
m = updated.(*tui.Model)
updated, _ = m.Update(tui.KeyPressMsg{Code: '?'})
m = updated.(*tui.Model)
view := m.View()
if view == nil {
t.Error("View should not be nil")
}
updated, _ = m.Update(tui.KeyPressMsg{Code: tea.KeyEscape})
m = updated.(*tui.Model)
}
func TestTUI_ToolCardRendering(t *testing.T) {
reg := command.NewRegistry()
command.RegisterBuiltins(reg)
cfg := config.DefaultModelConfig()
router := config.NewRouter(&cfg)
modelManager := llm.NewModelManager("http://localhost:11434", 262144)
ag := agent.New(modelManager, mcp.NewRegistry(), 262144)
ag.SetRouter(router)
completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil)
m := tui.New(ag, reg, nil, completer, modelManager, router, nil)
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
m = updated.(*tui.Model)
startTime := time.Now()
updated, _ = m.Update(tui.ToolCallStartMsg{
Name: "read_file",
Args: map[string]any{"path": "test.go"},
StartTime: startTime,
})
m = updated.(*tui.Model)
updated, _ = m.Update(tui.ToolCallResultMsg{
Name: "read_file",
Result: "file content",
IsError: false,
Duration: 100 * time.Millisecond,
})
m = updated.(*tui.Model)
view := m.View()
if view == nil {
t.Error("View should not be nil after tool execution")
}
}
func TestQwenRouter_Integration(t *testing.T) {
cfg := config.DefaultModelConfig()
router := config.NewQwenModelRouter(&cfg)
tests := []struct {
query string
mode config.ModeContext
expectSmaller string
expectLarger string
}{
{"what is go?", config.ModeAskContext, "qwen3.5:2b", ""},
{"design architecture", config.ModeBuildContext, "", "qwen3.5:4b"},
{"plan the system", config.ModePlanContext, "", "qwen3.5:4b"},
}
for _, tt := range tests {
t.Run(tt.query, func(t *testing.T) {
model := router.SelectModelForMode(tt.query, tt.mode)
if tt.expectSmaller != "" {
if modelRank(model) > modelRank(tt.expectSmaller) {
t.Errorf("model %s is larger than expected %s", model, tt.expectSmaller)
}
}
if tt.expectLarger != "" {
if modelRank(model) < modelRank(tt.expectLarger) {
t.Errorf("model %s is smaller than expected %s", model, tt.expectLarger)
}
}
})
}
}
func modelRank(model string) int {
switch {
case strings.Contains(model, "0.8b"):
return 1
case strings.Contains(model, "2b"):
return 2
case strings.Contains(model, "4b"):
return 3
case strings.Contains(model, "9b"):
return 4
default:
return 0
}
}
func TestFileOperations_Integration(t *testing.T) {
skipIfNoOllama(t)
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
if err := os.WriteFile(testFile, []byte("hello world"), 0644); err != nil {
t.Fatalf("failed to create test file: %v", err)
}
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("failed to read test file: %v", err)
}
if string(content) != "hello world" {
t.Errorf("unexpected file content: %q", string(content))
}
}
func BenchmarkTUI_Render(b *testing.B) {
reg := command.NewRegistry()
command.RegisterBuiltins(reg)
cfg := config.DefaultModelConfig()
router := config.NewRouter(&cfg)
modelManager := llm.NewModelManager("http://localhost:11434", 262144)
ag := agent.New(modelManager, mcp.NewRegistry(), 262144)
ag.SetRouter(router)
completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil)
m := tui.New(ag, reg, nil, completer, modelManager, router, nil)
updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
m = updated.(*tui.Model)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = m.View()
}
}
func BenchmarkQwenRouter_Classification(b *testing.B) {
cfg := config.DefaultModelConfig()
router := config.NewQwenModelRouter(&cfg)
queries := []string{
"what is go",
"how do i create a file",
"debug this nil pointer error",
"design microservice architecture",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, q := range queries {
_ = router.SelectModelForMode(q, config.ModeAskContext)
}
}
}

58
internal/llm/client.go Normal file
View File

@ -0,0 +1,58 @@
package llm
import "context"
// Client is the interface for LLM providers.
type Client interface {
// ChatStream sends messages to the LLM and streams the response.
// The callback is called for each chunk. Return a non-nil error to abort.
ChatStream(ctx context.Context, opts ChatOptions, fn func(StreamChunk) error) error
// Ping checks if the LLM is reachable and the model is available.
Ping() error
// Model returns the current model name.
Model() string
// Embed generates embeddings for the given texts using the specified model.
Embed(ctx context.Context, model string, texts []string) ([][]float32, error)
}
// ChatOptions holds parameters for a chat request.
type ChatOptions struct {
Messages []Message
Tools []ToolDef
System string
}
// Message represents a conversation message.
type Message struct {
Role string `json:"role"` // system, user, assistant, tool
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolName string `json:"tool_name,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// StreamChunk is a piece of a streaming response.
type StreamChunk struct {
Text string // incremental text content
ToolCalls []ToolCall // tool calls (usually in final chunk)
Done bool // true on the last chunk
EvalCount int // tokens generated (only on Done)
PromptEvalCount int // prompt tokens evaluated (only on Done)
}
// ToolCall represents a tool invocation requested by the LLM.
type ToolCall struct {
ID string `json:"id"`
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
// ToolDef defines a tool the LLM can call.
type ToolDef struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]any `json:"parameters"` // JSON Schema
}

163
internal/llm/manager.go Normal file
View File

@ -0,0 +1,163 @@
package llm
import (
"context"
"fmt"
"sync"
)
type ModelManager struct {
baseURL string
numCtx int
clients map[string]*OllamaClient
currentModel string
mu sync.RWMutex
}
var _ Client = (*ModelManager)(nil)
func NewModelManager(baseURL string, numCtx int) *ModelManager {
return &ModelManager{
baseURL: baseURL,
numCtx: numCtx,
clients: make(map[string]*OllamaClient),
}
}
func (m *ModelManager) GetClient(modelName string) (*OllamaClient, error) {
m.mu.RLock()
client, exists := m.clients[modelName]
m.mu.RUnlock()
if exists {
return client, nil
}
m.mu.Lock()
defer m.mu.Unlock()
if client, exists := m.clients[modelName]; exists {
return client, nil
}
client, err := NewOllamaClient(m.baseURL, modelName, m.numCtx)
if err != nil {
return nil, fmt.Errorf("create client for %s: %w", modelName, err)
}
m.clients[modelName] = client
return client, nil
}
func (m *ModelManager) SetCurrentModel(model string) error {
m.mu.Lock()
defer m.mu.Unlock()
client, err := NewOllamaClient(m.baseURL, model, m.numCtx)
if err != nil {
return fmt.Errorf("create client for %s: %w", model, err)
}
m.clients[model] = client
m.currentModel = model
return nil
}
func (m *ModelManager) CurrentModel() string {
m.mu.RLock()
defer m.mu.RUnlock()
return m.currentModel
}
func (m *ModelManager) ChatStream(ctx context.Context, opts ChatOptions, fn func(StreamChunk) error) error {
m.mu.RLock()
model := m.currentModel
m.mu.RUnlock()
if model == "" {
return fmt.Errorf("no model selected")
}
client, err := m.GetClient(model)
if err != nil {
return err
}
return client.ChatStream(ctx, opts, fn)
}
func (m *ModelManager) ChatStreamForModel(ctx context.Context, model string, opts ChatOptions, fn func(StreamChunk) error) error {
client, err := m.GetClient(model)
if err != nil {
return err
}
return client.ChatStream(ctx, opts, fn)
}
func (m *ModelManager) Ping() error {
m.mu.RLock()
model := m.currentModel
m.mu.RUnlock()
if model == "" {
return fmt.Errorf("no model selected")
}
client, err := m.GetClient(model)
if err != nil {
return err
}
return client.Ping()
}
func (m *ModelManager) PingModel(model string) error {
client, err := m.GetClient(model)
if err != nil {
return err
}
return client.Ping()
}
func (m *ModelManager) Embed(ctx context.Context, model string, texts []string) ([][]float32, error) {
client, err := m.GetClient(model)
if err != nil {
return nil, err
}
return client.Embed(ctx, model, texts)
}
func (m *ModelManager) EmbedWithCurrentModel(ctx context.Context, texts []string) ([][]float32, error) {
m.mu.RLock()
model := m.currentModel
m.mu.RUnlock()
if model == "" {
return nil, fmt.Errorf("no model selected")
}
return m.Embed(ctx, model, texts)
}
func (m *ModelManager) Close() {
m.mu.Lock()
defer m.mu.Unlock()
for range m.clients {
}
m.clients = make(map[string]*OllamaClient)
}
func (m *ModelManager) BaseURL() string {
return m.baseURL
}
func (m *ModelManager) NumCtx() int {
return m.numCtx
}
func (m *ModelManager) Model() string {
return m.CurrentModel()
}
// ListModels returns model names available in Ollama at the manager's base URL.
func (m *ModelManager) ListModels(ctx context.Context) ([]string, error) {
return ListModels(ctx, m.baseURL)
}

View File

@ -0,0 +1,92 @@
package llm
import (
"testing"
)
func TestNewModelManager(t *testing.T) {
m := NewModelManager("http://localhost:11434", 4096)
if m.baseURL != "http://localhost:11434" {
t.Errorf("baseURL = %q, want %q", m.baseURL, "http://localhost:11434")
}
if m.numCtx != 4096 {
t.Errorf("numCtx = %d, want %d", m.numCtx, 4096)
}
if m.clients == nil {
t.Error("clients map should be initialized")
}
}
func TestModelManagerBaseURL(t *testing.T) {
m := NewModelManager("http://custom:9999", 2048)
if m.BaseURL() != "http://custom:9999" {
t.Errorf("BaseURL() = %q, want %q", m.BaseURL(), "http://custom:9999")
}
}
func TestModelManagerNumCtx(t *testing.T) {
m := NewModelManager("http://localhost:11434", 8192)
if m.NumCtx() != 8192 {
t.Errorf("NumCtx() = %d, want %d", m.NumCtx(), 8192)
}
}
func TestModelManagerCurrentModel(t *testing.T) {
m := NewModelManager("http://localhost:11434", 4096)
// Should return empty when no model set
if m.CurrentModel() != "" {
t.Errorf("CurrentModel() = %q, want %q", m.CurrentModel(), "")
}
// Set a model
m.SetCurrentModel("llama3")
if m.CurrentModel() != "llama3" {
t.Errorf("CurrentModel() = %q, want %q", m.CurrentModel(), "llama3")
}
}
func TestModelManagerChatStreamNoModel(t *testing.T) {
m := NewModelManager("http://localhost:11434", 4096)
err := m.ChatStream(nil, ChatOptions{}, func(chunk StreamChunk) error {
return nil
})
if err == nil {
t.Error("ChatStream should fail when no model is set")
}
}
func TestModelManagerPingNoModel(t *testing.T) {
m := NewModelManager("http://localhost:11434", 4096)
err := m.Ping()
if err == nil {
t.Error("Ping should fail when no model is set")
}
}
func TestModelManagerEmbedWithCurrentModelNoModel(t *testing.T) {
m := NewModelManager("http://localhost:11434", 4096)
_, err := m.EmbedWithCurrentModel(nil, []string{"test"})
if err == nil {
t.Error("EmbedWithCurrentModel should fail when no model is set")
}
}
func TestModelManagerClose(t *testing.T) {
m := NewModelManager("http://localhost:11434", 4096)
// Should not panic
m.Close()
if len(m.clients) != 0 {
t.Errorf("after Close, clients map should be empty, got %d", len(m.clients))
}
}

222
internal/llm/ollama.go Normal file
View File

@ -0,0 +1,222 @@
package llm
import (
"context"
"fmt"
"net/url"
"os"
ollamaapi "github.com/ollama/ollama/api"
)
// OllamaClient implements Client using the official Ollama Go library.
type OllamaClient struct {
client *ollamaapi.Client
model string
numCtx int
}
// NewOllamaClient creates a new Ollama client.
func NewOllamaClient(baseURL, model string, numCtx int) (*OllamaClient, error) {
// The official client reads OLLAMA_HOST, but we want to support our config too.
if baseURL != "" {
os.Setenv("OLLAMA_HOST", baseURL)
}
client, err := ollamaapi.ClientFromEnvironment()
if err != nil {
return nil, fmt.Errorf("create ollama client: %w", err)
}
return &OllamaClient{
client: client,
model: model,
numCtx: numCtx,
}, nil
}
func (o *OllamaClient) Model() string { return o.model }
// Ping checks Ollama is running and the model exists.
func (o *OllamaClient) Ping() error {
ctx := context.Background()
// Check the model is available by requesting a show.
req := &ollamaapi.ShowRequest{Model: o.model}
_, err := o.client.Show(ctx, req)
if err != nil {
return fmt.Errorf("model %q not available: %w", o.model, err)
}
return nil
}
// ChatStream sends a chat request and streams the response via callback.
func (o *OllamaClient) ChatStream(ctx context.Context, opts ChatOptions, fn func(StreamChunk) error) error {
messages := make([]ollamaapi.Message, 0, len(opts.Messages)+1)
if opts.System != "" {
messages = append(messages, ollamaapi.Message{
Role: "system",
Content: opts.System,
})
}
for _, m := range opts.Messages {
msg := ollamaapi.Message{
Role: m.Role,
Content: m.Content,
ToolName: m.ToolName,
ToolCallID: m.ToolCallID,
}
// Convert tool calls for assistant messages.
for _, tc := range m.ToolCalls {
args := ollamaapi.NewToolCallFunctionArguments()
for k, v := range tc.Arguments {
args.Set(k, v)
}
msg.ToolCalls = append(msg.ToolCalls, ollamaapi.ToolCall{
ID: tc.ID,
Function: ollamaapi.ToolCallFunction{
Name: tc.Name,
Arguments: args,
},
})
}
messages = append(messages, msg)
}
tools := convertTools(opts.Tools)
req := &ollamaapi.ChatRequest{
Model: o.model,
Messages: messages,
Tools: tools,
Options: map[string]any{
"num_ctx": o.numCtx,
},
}
return o.client.Chat(ctx, req, func(resp ollamaapi.ChatResponse) error {
chunk := StreamChunk{
Text: resp.Message.Content,
Done: resp.Done,
}
if resp.Done {
chunk.EvalCount = resp.EvalCount
chunk.PromptEvalCount = resp.PromptEvalCount
}
// Collect tool calls from the response.
for _, tc := range resp.Message.ToolCalls {
chunk.ToolCalls = append(chunk.ToolCalls, ToolCall{
ID: tc.ID,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments.ToMap(),
})
}
return fn(chunk)
})
}
// Embed generates embeddings for the given texts using the specified model.
func (o *OllamaClient) Embed(ctx context.Context, model string, texts []string) ([][]float32, error) {
resp, err := o.client.Embed(ctx, &ollamaapi.EmbedRequest{
Model: model,
Input: texts,
})
if err != nil {
return nil, fmt.Errorf("embedding failed: %w", err)
}
return resp.Embeddings, nil
}
// convertTools transforms our ToolDef slice into Ollama's Tools format.
func convertTools(defs []ToolDef) ollamaapi.Tools {
if len(defs) == 0 {
return nil
}
tools := make(ollamaapi.Tools, 0, len(defs))
for _, d := range defs {
props := ollamaapi.NewToolPropertiesMap()
var required []string
// Extract properties from JSON Schema.
if propsRaw, ok := d.Parameters["properties"].(map[string]any); ok {
for name, schema := range propsRaw {
schemaMap, _ := schema.(map[string]any)
prop := ollamaapi.ToolProperty{
Description: strFromMap(schemaMap, "description"),
}
if t, ok := schemaMap["type"].(string); ok {
prop.Type = ollamaapi.PropertyType{t}
}
if enumRaw, ok := schemaMap["enum"].([]any); ok {
prop.Enum = enumRaw
}
props.Set(name, prop)
}
}
// Extract required fields.
if reqRaw, ok := d.Parameters["required"].([]any); ok {
for _, r := range reqRaw {
if s, ok := r.(string); ok {
required = append(required, s)
}
}
}
tools = append(tools, ollamaapi.Tool{
Type: "function",
Function: ollamaapi.ToolFunction{
Name: d.Name,
Description: d.Description,
Parameters: ollamaapi.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: required,
},
},
})
}
return tools
}
func strFromMap(m map[string]any, key string) string {
if m == nil {
return ""
}
s, _ := m[key].(string)
return s
}
// BaseURL returns the configured Ollama base URL for display.
func (o *OllamaClient) BaseURL() string {
if v := os.Getenv("OLLAMA_HOST"); v != "" {
return v
}
return "http://localhost:11434"
}
// ParseBaseURL validates the Ollama URL.
func ParseBaseURL(rawURL string) (*url.URL, error) {
return url.Parse(rawURL)
}
// ListModels returns model names available in Ollama at baseURL.
func ListModels(ctx context.Context, baseURL string) ([]string, error) {
if baseURL != "" {
os.Setenv("OLLAMA_HOST", baseURL)
}
client, err := ollamaapi.ClientFromEnvironment()
if err != nil {
return nil, fmt.Errorf("ollama client: %w", err)
}
resp, err := client.List(ctx)
if err != nil {
return nil, fmt.Errorf("ollama list: %w", err)
}
names := make([]string, 0, len(resp.Models))
for _, m := range resp.Models {
names = append(names, m.Name)
}
return names, nil
}

121
internal/llm/ollama_test.go Normal file
View File

@ -0,0 +1,121 @@
package llm
import (
"testing"
)
func TestConvertTools(t *testing.T) {
tests := []struct {
name string
input []ToolDef
wantNil bool
wantCount int
}{
{
name: "nil input",
input: nil,
wantNil: true,
},
{
name: "single tool with properties and required",
input: []ToolDef{
{
Name: "read_file",
Description: "Read a file",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "file path",
},
},
"required": []any{"path"},
},
},
},
wantCount: 1,
},
{
name: "tool without properties in parameters",
input: []ToolDef{
{
Name: "noop",
Description: "Does nothing",
Parameters: map[string]any{"type": "object"},
},
},
wantCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertTools(tt.input)
if tt.wantNil {
if result != nil {
t.Errorf("convertTools() = %v, want nil", result)
}
return
}
if len(result) != tt.wantCount {
t.Errorf("convertTools() returned %d tools, want %d", len(result), tt.wantCount)
}
if tt.wantCount > 0 {
tool := result[0]
if tool.Function.Name != tt.input[0].Name {
t.Errorf("tool name = %q, want %q", tool.Function.Name, tt.input[0].Name)
}
if tool.Function.Description != tt.input[0].Description {
t.Errorf("tool description = %q, want %q", tool.Function.Description, tt.input[0].Description)
}
if tool.Type != "function" {
t.Errorf("tool type = %q, want %q", tool.Type, "function")
}
}
})
}
}
func TestStrFromMap(t *testing.T) {
tests := []struct {
name string
m map[string]any
key string
want string
}{
{
name: "key present",
m: map[string]any{"description": "a desc"},
key: "description",
want: "a desc",
},
{
name: "key missing",
m: map[string]any{"other": "value"},
key: "description",
want: "",
},
{
name: "nil map",
m: nil,
key: "description",
want: "",
},
{
name: "non-string value",
m: map[string]any{"count": 42},
key: "count",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := strFromMap(tt.m, tt.key)
if got != tt.want {
t.Errorf("strFromMap() = %q, want %q", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,33 @@
package logging
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/charmbracelet/log"
)
func NewSessionLogger() (*log.Logger, *os.File, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, nil, fmt.Errorf("home dir: %w", err)
}
logDir := filepath.Join(home, ".config", "ai-agent", "logs")
if err := os.MkdirAll(logDir, 0o755); err != nil {
return nil, nil, fmt.Errorf("create log dir: %w", err)
}
filename := time.Now().Format("2006-01-02_15-04-05") + ".log"
f, err := os.Create(filepath.Join(logDir, filename))
if err != nil {
return nil, nil, fmt.Errorf("create log file: %w", err)
}
logger := log.NewWithOptions(f, log.Options{
ReportTimestamp: true,
TimeFormat: time.RFC3339,
Prefix: "ai-agent",
Level: log.DebugLevel,
})
return logger, f, nil
}

View File

@ -0,0 +1,42 @@
package logging
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestNewSessionLogger(t *testing.T) {
logger, f, err := NewSessionLogger()
if err != nil {
t.Fatalf("NewSessionLogger() error: %v", err)
}
if f != nil {
defer f.Close()
defer os.Remove(f.Name())
}
if logger == nil {
t.Fatal("logger should not be nil")
}
if f == nil {
t.Fatal("file should not be nil")
}
dir := filepath.Dir(f.Name())
if !strings.Contains(dir, filepath.Join(".config", "ai-agent", "logs")) {
t.Errorf("log file should be in ~/.config/ai-agent/logs/, got %q", dir)
}
logger.Info("test message", "key", "value")
}
func TestNilLoggerNoPanic(t *testing.T) {
var called bool
logger, f, err := NewSessionLogger()
if err == nil && f != nil {
defer f.Close()
defer os.Remove(f.Name())
logger.Info("test")
called = true
}
_ = called
}

View File

@ -0,0 +1,92 @@
package logging
import (
"bufio"
"fmt"
"os"
"path/filepath"
"sort"
"time"
)
type LogEntry struct {
Path string
ModTime time.Time
Size int64
}
func LogDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".config", "ai-agent", "logs")
}
func ListLogs(n int) ([]LogEntry, error) {
return listLogsIn(LogDir(), n)
}
func listLogsIn(dir string, n int) ([]LogEntry, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("read log dir: %w", err)
}
var logs []LogEntry
for _, e := range entries {
if e.IsDir() {
continue
}
info, err := e.Info()
if err != nil {
continue
}
logs = append(logs, LogEntry{
Path: filepath.Join(dir, e.Name()),
ModTime: info.ModTime(),
Size: info.Size(),
})
}
sort.Slice(logs, func(i, j int) bool {
return logs[i].ModTime.After(logs[j].ModTime)
})
if n > 0 && n < len(logs) {
logs = logs[:n]
}
return logs, nil
}
func LatestLogPath() (string, error) {
return latestLogPathIn(LogDir())
}
func latestLogPathIn(dir string) (string, error) {
logs, err := listLogsIn(dir, 1)
if err != nil {
return "", err
}
if len(logs) == 0 {
return "", fmt.Errorf("no log files found in %s", dir)
}
return logs[0].Path, nil
}
func TailLog(path string, n int) ([]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open log: %w", err)
}
defer f.Close()
var lines []string
scanner := bufio.NewScanner(f)
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("read log: %w", err)
}
if n > 0 && n < len(lines) {
lines = lines[len(lines)-n:]
}
return lines, nil
}

View File

@ -0,0 +1,157 @@
package logging
import (
"os"
"path/filepath"
"testing"
"time"
)
// helper creates n temp log files in dir with distinct mod times.
func createFakeLogs(t *testing.T, dir string, n int) []string {
t.Helper()
var paths []string
for i := range n {
name := filepath.Join(dir, "2025-01-01_00-00-0"+string(rune('0'+i))+".log")
if err := os.WriteFile(name, []byte("line "+string(rune('0'+i))+"\n"), 0o644); err != nil {
t.Fatal(err)
}
// Stagger mod times so ordering is deterministic.
ts := time.Now().Add(time.Duration(i) * time.Second)
if err := os.Chtimes(name, ts, ts); err != nil {
t.Fatal(err)
}
paths = append(paths, name)
}
return paths
}
func TestListLogs(t *testing.T) {
dir := t.TempDir()
createFakeLogs(t, dir, 5)
logs, err := listLogsIn(dir, 3)
if err != nil {
t.Fatalf("listLogsIn error: %v", err)
}
if len(logs) != 3 {
t.Fatalf("expected 3 entries, got %d", len(logs))
}
// Verify newest-first ordering.
for i := 1; i < len(logs); i++ {
if logs[i].ModTime.After(logs[i-1].ModTime) {
t.Errorf("entry %d (%v) is newer than entry %d (%v)", i, logs[i].ModTime, i-1, logs[i-1].ModTime)
}
}
}
func TestListLogs_All(t *testing.T) {
dir := t.TempDir()
createFakeLogs(t, dir, 4)
logs, err := listLogsIn(dir, 0)
if err != nil {
t.Fatalf("listLogsIn error: %v", err)
}
if len(logs) != 4 {
t.Fatalf("expected 4 entries, got %d", len(logs))
}
}
func TestListLogs_EmptyDir(t *testing.T) {
dir := t.TempDir()
logs, err := listLogsIn(dir, 5)
if err != nil {
t.Fatalf("listLogsIn error: %v", err)
}
if len(logs) != 0 {
t.Fatalf("expected 0 entries, got %d", len(logs))
}
}
func TestListLogs_MissingDir(t *testing.T) {
_, err := listLogsIn("/tmp/nonexistent-log-dir-test-xyz", 5)
if err == nil {
t.Fatal("expected error for missing dir")
}
}
func TestTailLog(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.log")
content := "line1\nline2\nline3\nline4\nline5\n"
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
t.Fatal(err)
}
lines, err := TailLog(path, 3)
if err != nil {
t.Fatalf("TailLog error: %v", err)
}
if len(lines) != 3 {
t.Fatalf("expected 3 lines, got %d", len(lines))
}
if lines[0] != "line3" {
t.Errorf("expected 'line3', got %q", lines[0])
}
if lines[2] != "line5" {
t.Errorf("expected 'line5', got %q", lines[2])
}
}
func TestTailLog_FewerLines(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "short.log")
if err := os.WriteFile(path, []byte("only\n"), 0o644); err != nil {
t.Fatal(err)
}
lines, err := TailLog(path, 100)
if err != nil {
t.Fatalf("TailLog error: %v", err)
}
if len(lines) != 1 {
t.Fatalf("expected 1 line, got %d", len(lines))
}
}
func TestTailLog_MissingFile(t *testing.T) {
_, err := TailLog("/tmp/nonexistent-file-test-xyz.log", 10)
if err == nil {
t.Fatal("expected error for missing file")
}
}
func TestLatestLogPath(t *testing.T) {
dir := t.TempDir()
paths := createFakeLogs(t, dir, 3)
latest, err := latestLogPathIn(dir)
if err != nil {
t.Fatalf("latestLogPathIn error: %v", err)
}
// The last created file has the newest mod time.
expected := paths[len(paths)-1]
if latest != expected {
t.Errorf("expected %q, got %q", expected, latest)
}
}
func TestLatestLogPath_EmptyDir(t *testing.T) {
dir := t.TempDir()
_, err := latestLogPathIn(dir)
if err == nil {
t.Fatal("expected error for empty dir")
}
}
func TestLogDir(t *testing.T) {
dir := LogDir()
if dir == "" {
t.Fatal("LogDir should not be empty")
}
if filepath.Base(dir) != "logs" {
t.Errorf("expected dir to end in 'logs', got %q", dir)
}
}

110
internal/mcp/client.go Normal file
View File

@ -0,0 +1,110 @@
package mcp
import (
"context"
"fmt"
"os/exec"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
)
type MCPClient struct {
name string
client *sdkmcp.Client
session *sdkmcp.ClientSession
cmd *exec.Cmd
}
func Connect(ctx context.Context, name, command string, args []string, env []string, transport, url string) (*MCPClient, error) {
client := sdkmcp.NewClient(
&sdkmcp.Implementation{Name: "ai-agent", Version: "0.2.0"},
nil,
)
var t sdkmcp.Transport
switch transport {
case "sse":
if url == "" {
return nil, fmt.Errorf("sse transport requires url for %s", name)
}
t = &sdkmcp.SSEClientTransport{Endpoint: url}
case "streamable-http":
if url == "" {
return nil, fmt.Errorf("streamable-http transport requires url for %s", name)
}
t = &sdkmcp.StreamableClientTransport{Endpoint: url}
default:
if command == "" {
return nil, fmt.Errorf("stdio transport requires command for %s", name)
}
cmd := exec.Command(command, args...)
if len(env) > 0 {
cmd.Env = append(cmd.Environ(), env...)
}
t = &sdkmcp.CommandTransport{Command: cmd}
}
session, err := client.Connect(ctx, t, nil)
if err != nil {
return nil, fmt.Errorf("connect to %s: %w", name, err)
}
return &MCPClient{
name: name,
client: client,
session: session,
}, nil
}
func (c *MCPClient) Name() string { return c.name }
func (c *MCPClient) ListTools(ctx context.Context) ([]*sdkmcp.Tool, error) {
caps := c.session.InitializeResult()
if caps == nil || caps.Capabilities.Tools == nil {
return nil, nil
}
var tools []*sdkmcp.Tool
for tool, err := range c.session.Tools(ctx, nil) {
if err != nil {
return tools, fmt.Errorf("list tools from %s: %w", c.name, err)
}
tools = append(tools, tool)
}
return tools, nil
}
func (c *MCPClient) CallTool(ctx context.Context, name string, args map[string]any) (*ToolResult, error) {
result, err := c.session.CallTool(ctx, &sdkmcp.CallToolParams{
Name: name,
Arguments: args,
})
if err != nil {
return nil, fmt.Errorf("call tool %s on %s: %w", name, c.name, err)
}
var text string
for _, ct := range result.Content {
if tc, ok := ct.(*sdkmcp.TextContent); ok {
if text != "" {
text += "\n"
}
text += tc.Text
}
}
return &ToolResult{Content: text, IsError: result.IsError}, nil
}
func (c *MCPClient) Close() error {
if c.session != nil {
return c.session.Close()
}
return nil
}
func (c *MCPClient) IsConnected() bool {
return c.session != nil
}
func (c *MCPClient) Ping(ctx context.Context) error {
if c.session == nil {
return fmt.Errorf("no session")
}
_, err := c.ListTools(ctx)
return err
}

241
internal/mcp/registry.go Normal file
View File

@ -0,0 +1,241 @@
package mcp
import (
"context"
"fmt"
"sync"
"time"
"ai-agent/internal/config"
"ai-agent/internal/llm"
)
type FailedServer struct {
Name string
Reason string
}
type ServerStatus struct {
Name string
Connected bool
LastError string
LastPing time.Time
}
type Registry struct {
mu sync.RWMutex
clients []*MCPClient
toolMap map[string]*MCPClient
toolDefs []llm.ToolDef
failedServers []FailedServer
serverConfigs map[string]config.ServerConfig
}
func NewRegistry() *Registry {
return &Registry{toolMap: make(map[string]*MCPClient), serverConfigs: make(map[string]config.ServerConfig)}
}
const connectTimeout = 5 * time.Second
func (r *Registry) ConnectServer(ctx context.Context, srv config.ServerConfig) (int, error) {
connCtx, cancel := context.WithTimeout(ctx, connectTimeout)
defer cancel()
client, err := Connect(connCtx, srv.Name, srv.Command, srv.Args, srv.Env, srv.Transport, srv.URL)
if err != nil {
r.mu.Lock()
r.failedServers = append(r.failedServers, FailedServer{Name: srv.Name, Reason: err.Error()})
r.mu.Unlock()
return 0, fmt.Errorf("connect to %s: %w", srv.Name, err)
}
tools, err := client.ListTools(connCtx)
if err != nil {
client.Close()
r.mu.Lock()
r.failedServers = append(r.failedServers, FailedServer{Name: srv.Name, Reason: err.Error()})
r.mu.Unlock()
return 0, fmt.Errorf("%s tools: %w", srv.Name, err)
}
r.mu.Lock()
r.clients = append(r.clients, client)
for _, tool := range tools {
r.toolMap[tool.Name] = client
r.toolDefs = append(r.toolDefs, ToLLMToolDef(tool.Name, tool.Description, tool.InputSchema))
}
r.serverConfigs[srv.Name] = srv
r.mu.Unlock()
return len(tools), nil
}
func (r *Registry) ConnectAll(ctx context.Context, servers []config.ServerConfig, logFn func(string)) {
for _, srv := range servers {
toolCount, err := r.ConnectServer(ctx, srv)
if err != nil {
logFn(fmt.Sprintf("skip %s: %v", srv.Name, err))
continue
}
logFn(fmt.Sprintf("connected %s (%d tools)", srv.Name, toolCount))
}
}
func (r *Registry) Tools() []llm.ToolDef {
r.mu.RLock()
defer r.mu.RUnlock()
return r.toolDefs
}
func (r *Registry) ToolCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.toolDefs)
}
func (r *Registry) ServerCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.clients)
}
func (r *Registry) ServerNames() []string {
r.mu.RLock()
defer r.mu.RUnlock()
names := make([]string, len(r.clients))
for i, c := range r.clients {
names[i] = c.Name()
}
return names
}
func (r *Registry) FailedServers() []FailedServer {
r.mu.RLock()
defer r.mu.RUnlock()
return r.failedServers
}
func (r *Registry) CallTool(ctx context.Context, name string, args map[string]any) (*ToolResult, error) {
r.mu.RLock()
client, ok := r.toolMap[name]
r.mu.RUnlock()
if !ok {
return &ToolResult{
Content: fmt.Sprintf("unknown tool: %s", name),
IsError: true,
}, nil
}
return client.CallTool(ctx, name, args)
}
func (r *Registry) Close() {
r.mu.Lock()
defer r.mu.Unlock()
for _, c := range r.clients {
c.Close()
}
r.clients = nil
r.toolMap = make(map[string]*MCPClient)
r.toolDefs = nil
}
func (r *Registry) HealthCheck(ctx context.Context) []ServerStatus {
r.mu.RLock()
defer r.mu.RUnlock()
var results []ServerStatus
for _, client := range r.clients {
status := ServerStatus{Name: client.Name()}
if client.IsConnected() {
pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
err := client.Ping(pingCtx)
cancel()
status.Connected = err == nil
if err != nil {
status.LastError = err.Error()
}
status.LastPing = time.Now()
}
results = append(results, status)
}
for _, failed := range r.failedServers {
results = append(results, ServerStatus{
Name: failed.Name,
Connected: false,
LastError: failed.Reason,
})
}
return results
}
func (r *Registry) ReconnectServer(ctx context.Context, name string) (int, error) {
r.mu.RLock()
srv, ok := r.serverConfigs[name]
r.mu.RUnlock()
if !ok {
return 0, fmt.Errorf("no config found for server: %s", name)
}
r.mu.Lock()
var remainingFailed []FailedServer
for _, f := range r.failedServers {
if f.Name != name {
remainingFailed = append(remainingFailed, f)
}
}
r.failedServers = remainingFailed
r.mu.Unlock()
return r.ConnectServer(ctx, srv)
}
type MonitorConfig struct {
Interval time.Duration
MaxRetries int
BackoffBase time.Duration
}
var defaultMonitorConfig = MonitorConfig{
Interval: 30 * time.Second,
MaxRetries: 3,
BackoffBase: 5 * time.Second,
}
func (r *Registry) StartHealthMonitor(ctx context.Context, cfg MonitorConfig, logFn func(string)) context.CancelFunc {
if cfg.Interval == 0 {
cfg = defaultMonitorConfig
}
monitorCtx, cancel := context.WithCancel(ctx)
go func() {
ticker := time.NewTicker(cfg.Interval)
defer ticker.Stop()
for {
select {
case <-monitorCtx.Done():
return
case <-ticker.C:
r.healthCheckRound(monitorCtx, cfg, logFn)
}
}
}()
return cancel
}
func (r *Registry) healthCheckRound(ctx context.Context, cfg MonitorConfig, logFn func(string)) {
statuses := r.HealthCheck(ctx)
for _, status := range statuses {
if status.Connected {
continue
}
logFn(fmt.Sprintf("server %s unhealthy, attempting reconnect...", status.Name))
for attempt := 1; attempt <= cfg.MaxRetries; attempt++ {
backoff := cfg.BackoffBase * time.Duration(attempt)
select {
case <-ctx.Done():
return
case <-time.After(backoff):
}
_, err := r.ReconnectServer(ctx, status.Name)
if err == nil {
logFn(fmt.Sprintf("server %s reconnected", status.Name))
break
}
if attempt == cfg.MaxRetries {
logFn(fmt.Sprintf("server %s reconnection failed after %d attempts: %v", status.Name, cfg.MaxRetries, err))
}
}
}
}

View File

@ -0,0 +1,73 @@
package mcp
import (
"context"
"strings"
"testing"
)
func TestNewRegistry(t *testing.T) {
r := NewRegistry()
if r.ToolCount() != 0 {
t.Errorf("ToolCount() = %d, want 0", r.ToolCount())
}
if r.ServerCount() != 0 {
t.Errorf("ServerCount() = %d, want 0", r.ServerCount())
}
if tools := r.Tools(); len(tools) != 0 {
t.Errorf("Tools() = %v, want empty", tools)
}
}
func TestRegistry_CallTool_Unknown(t *testing.T) {
r := NewRegistry()
result, err := r.CallTool(context.Background(), "nonexistent_tool", nil)
if err != nil {
t.Fatalf("CallTool() unexpected error: %v", err)
}
if !result.IsError {
t.Error("CallTool() IsError = false, want true for unknown tool")
}
if !strings.Contains(result.Content, "unknown tool") {
t.Errorf("CallTool() Content = %q, want to contain 'unknown tool'", result.Content)
}
}
func TestRegistry_HealthCheck_Empty(t *testing.T) {
r := NewRegistry()
statuses := r.HealthCheck(context.Background())
if len(statuses) != 0 {
t.Errorf("HealthCheck() returned %d statuses, want 0", len(statuses))
}
}
func TestRegistry_HealthCheck_TracksFailedServers(t *testing.T) {
r := NewRegistry()
// Simulate a failed server by directly adding to failedServers
r.mu.Lock()
r.failedServers = append(r.failedServers, FailedServer{
Name: "failed-server",
Reason: "connection refused",
})
r.mu.Unlock()
statuses := r.HealthCheck(context.Background())
if len(statuses) != 1 {
t.Fatalf("HealthCheck() returned %d statuses, want 1", len(statuses))
}
status := statuses[0]
if status.Name != "failed-server" {
t.Errorf("status.Name = %q, want 'failed-server'", status.Name)
}
if status.Connected {
t.Error("status.Connected = true, want false")
}
if status.LastError != "connection refused" {
t.Errorf("status.LastError = %q, want 'connection refused'", status.LastError)
}
}

27
internal/mcp/types.go Normal file
View File

@ -0,0 +1,27 @@
package mcp
import (
"ai-agent/internal/llm"
)
type ServerInfo struct {
Name string
ToolCount int
}
type ToolResult struct {
Content string
IsError bool
}
func ToLLMToolDef(name, description string, inputSchema any) llm.ToolDef {
params, _ := inputSchema.(map[string]any)
if params == nil {
params = map[string]any{"type": "object", "properties": map[string]any{}}
}
return llm.ToolDef{
Name: name,
Description: description,
Parameters: params,
}
}

View File

@ -0,0 +1,71 @@
package mcp
import (
"testing"
)
func TestToLLMToolDef(t *testing.T) {
tests := []struct {
name string
toolName string
description string
inputSchema any
wantName string
wantDesc string
wantParams bool // true = should have non-nil params
}{
{
name: "normal with valid schema",
toolName: "read_file",
description: "Read a file",
inputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{"type": "string"},
},
},
wantName: "read_file",
wantDesc: "Read a file",
wantParams: true,
},
{
name: "nil schema uses default",
toolName: "noop",
description: "No-op tool",
inputSchema: nil,
wantName: "noop",
wantDesc: "No-op tool",
wantParams: true,
},
{
name: "non-map schema uses default",
toolName: "bad_schema",
description: "Bad schema tool",
inputSchema: "not a map",
wantName: "bad_schema",
wantDesc: "Bad schema tool",
wantParams: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ToLLMToolDef(tt.toolName, tt.description, tt.inputSchema)
if result.Name != tt.wantName {
t.Errorf("Name = %q, want %q", result.Name, tt.wantName)
}
if result.Description != tt.wantDesc {
t.Errorf("Description = %q, want %q", result.Description, tt.wantDesc)
}
if tt.wantParams && result.Parameters == nil {
t.Error("Parameters should not be nil")
}
// Nil and non-map schemas should get the default object schema.
if tt.inputSchema == nil || func() bool { _, ok := tt.inputSchema.(map[string]any); return !ok }() {
if result.Parameters["type"] != "object" {
t.Errorf("default schema type = %v, want 'object'", result.Parameters["type"])
}
}
})
}
}

259
internal/memory/store.go Normal file
View File

@ -0,0 +1,259 @@
package memory
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
)
type Memory struct {
ID int `json:"id"`
Content string `json:"content"`
Tags []string `json:"tags,omitempty"`
CreatedAt time.Time `json:"created_at"`
LastUsed time.Time `json:"last_used"`
}
type Store struct {
mu sync.Mutex
path string
memories []Memory
nextID int
}
func NewStore(path string) *Store {
if path == "" {
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
path = filepath.Join(home, ".config", "ai-agent", "memories.json")
}
s := &Store{path: path}
s.load()
return s
}
func (s *Store) Save(content string, tags []string) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.nextID++
mem := Memory{
ID: s.nextID,
Content: content,
Tags: tags,
CreatedAt: time.Now(),
LastUsed: time.Now(),
}
s.memories = append(s.memories, mem)
if err := s.persist(); err != nil {
return 0, err
}
return mem.ID, nil
}
func (s *Store) Recall(query string, maxResults int) []Memory {
s.mu.Lock()
defer s.mu.Unlock()
if maxResults <= 0 {
maxResults = 5
}
queryLower := strings.ToLower(query)
words := strings.Fields(queryLower)
type scored struct {
mem Memory
score int
}
var results []scored
for i := range s.memories {
mem := s.memories[i]
score := 0
contentLower := strings.ToLower(mem.Content)
for _, w := range words {
if strings.Contains(contentLower, w) {
score += 2
}
}
for _, tag := range mem.Tags {
tagLower := strings.ToLower(tag)
for _, w := range words {
if strings.Contains(tagLower, w) {
score += 3
}
}
}
if score > 0 {
results = append(results, scored{mem: mem, score: score})
}
}
sort.Slice(results, func(i, j int) bool {
if results[i].score != results[j].score {
return results[i].score > results[j].score
}
return results[i].mem.LastUsed.After(results[j].mem.LastUsed)
})
if len(results) > maxResults {
results = results[:maxResults]
}
now := time.Now()
out := make([]Memory, len(results))
for i, r := range results {
out[i] = r.mem
for j := range s.memories {
if s.memories[j].ID == r.mem.ID {
s.memories[j].LastUsed = now
break
}
}
}
_ = s.persist()
return out
}
func (s *Store) Recent(n int) []Memory {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.memories) == 0 {
return nil
}
sorted := make([]Memory, len(s.memories))
copy(sorted, s.memories)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].LastUsed.After(sorted[j].LastUsed)
})
if n > len(sorted) {
n = len(sorted)
}
return sorted[:n]
}
func (s *Store) Count() int {
s.mu.Lock()
defer s.mu.Unlock()
return len(s.memories)
}
func (s *Store) Delete(id int) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
for i, mem := range s.memories {
if mem.ID == id {
s.memories = append(s.memories[:i], s.memories[i+1:]...)
return true, s.persist()
}
}
return false, nil
}
func (s *Store) DeleteByTag(tag string) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
tagLower := strings.ToLower(tag)
var remaining []Memory
deleted := 0
for _, mem := range s.memories {
found := false
for _, t := range mem.Tags {
if strings.ToLower(t) == tagLower {
found = true
break
}
}
if found {
deleted++
} else {
remaining = append(remaining, mem)
}
}
s.memories = remaining
if deleted > 0 {
return deleted, s.persist()
}
return deleted, nil
}
func (s *Store) Update(id int, content string, tags []string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
for i, mem := range s.memories {
if mem.ID == id {
if content != "" {
s.memories[i].Content = content
}
if tags != nil {
s.memories[i].Tags = tags
}
s.memories[i].LastUsed = time.Now()
return true, s.persist()
}
}
return false, nil
}
func (s *Store) Prune(olderThan time.Duration) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
cutoff := time.Now().Add(-olderThan)
var remaining []Memory
deleted := 0
for _, mem := range s.memories {
if mem.CreatedAt.Before(cutoff) {
deleted++
} else {
remaining = append(remaining, mem)
}
}
s.memories = remaining
if deleted > 0 {
return deleted, s.persist()
}
return deleted, nil
}
func (s *Store) Get(id int) (Memory, bool) {
s.mu.Lock()
defer s.mu.Unlock()
for _, mem := range s.memories {
if mem.ID == id {
return mem, true
}
}
return Memory{}, false
}
func (s *Store) load() {
data, err := os.ReadFile(s.path)
if err != nil {
return
}
var memories []Memory
if err := json.Unmarshal(data, &memories); err != nil {
return
}
s.memories = memories
for _, m := range s.memories {
if m.ID > s.nextID {
s.nextID = m.ID
}
}
}
func (s *Store) persist() error {
dir := filepath.Dir(s.path)
if err := os.MkdirAll(dir, 0o755); err != nil {
return fmt.Errorf("create memory dir: %w", err)
}
data, err := json.MarshalIndent(s.memories, "", " ")
if err != nil {
return fmt.Errorf("marshal memories: %w", err)
}
if err := os.WriteFile(s.path, data, 0o644); err != nil {
return fmt.Errorf("write memories: %w", err)
}
return nil
}

View File

@ -0,0 +1,381 @@
package memory
import (
"path/filepath"
"testing"
"time"
)
func TestStore_Save_And_Count(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "memories.json")
s := NewStore(path)
if s.Count() != 0 {
t.Fatalf("new store Count = %d, want 0", s.Count())
}
id1, err := s.Save("first memory", []string{"tag1"})
if err != nil {
t.Fatalf("Save returned error: %v", err)
}
if id1 != 1 {
t.Errorf("first Save id = %d, want 1", id1)
}
if s.Count() != 1 {
t.Errorf("Count after first Save = %d, want 1", s.Count())
}
id2, err := s.Save("second memory", []string{"tag2"})
if err != nil {
t.Fatalf("Save returned error: %v", err)
}
if id2 != 2 {
t.Errorf("second Save id = %d, want 2", id2)
}
if s.Count() != 2 {
t.Errorf("Count after second Save = %d, want 2", s.Count())
}
}
func TestStore_Recall(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "memories.json")
s := NewStore(path)
s.Save("the user prefers Go language", []string{"preference", "golang"})
s.Save("project uses PostgreSQL database", []string{"tech", "database"})
s.Save("user name is Alice", []string{"name"})
t.Run("content match", func(t *testing.T) {
results := s.Recall("Go", 10)
if len(results) == 0 {
t.Fatal("expected results for 'Go' query")
}
found := false
for _, r := range results {
if r.Content == "the user prefers Go language" {
found = true
}
}
if !found {
t.Error("expected to find 'the user prefers Go language'")
}
})
t.Run("tag match", func(t *testing.T) {
results := s.Recall("golang", 10)
if len(results) == 0 {
t.Fatal("expected results for 'golang' tag query")
}
if results[0].Content != "the user prefers Go language" {
t.Errorf("top result = %q, want 'the user prefers Go language'", results[0].Content)
}
})
t.Run("combined scoring", func(t *testing.T) {
// "database" matches both content and tag for PostgreSQL entry.
results := s.Recall("database", 10)
if len(results) == 0 {
t.Fatal("expected results for 'database' query")
}
if results[0].Content != "project uses PostgreSQL database" {
t.Errorf("top result = %q, want 'project uses PostgreSQL database'",
results[0].Content)
}
})
t.Run("maxResults limit", func(t *testing.T) {
results := s.Recall("user", 1)
if len(results) > 1 {
t.Errorf("maxResults=1 but got %d results", len(results))
}
})
t.Run("default maxResults 5 when 0", func(t *testing.T) {
// With 3 memories, should return all 3 (default limit is 5).
results := s.Recall("user", 0)
if len(results) > 5 {
t.Errorf("default maxResults should be 5, got %d results", len(results))
}
})
t.Run("case insensitive", func(t *testing.T) {
results := s.Recall("ALICE", 10)
if len(results) == 0 {
t.Fatal("expected case-insensitive match for 'ALICE'")
}
if results[0].Content != "user name is Alice" {
t.Errorf("result = %q, want 'user name is Alice'", results[0].Content)
}
})
t.Run("no matches", func(t *testing.T) {
results := s.Recall("xyzzyzxyz", 10)
if len(results) != 0 {
t.Errorf("expected no results for nonsense query, got %d", len(results))
}
})
}
func TestStore_Recall_TieBreakByRecency(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "memories.json")
s := NewStore(path)
// Save two memories with the same scoring potential.
s.Save("alpha topic info", []string{"info"})
// Small delay so LastUsed differs.
time.Sleep(10 * time.Millisecond)
s.Save("beta topic info", []string{"info"})
results := s.Recall("info", 10)
if len(results) < 2 {
t.Fatalf("expected at least 2 results, got %d", len(results))
}
// Both match tag "info" equally (+3), so more recent (beta) should come first.
if results[0].Content != "beta topic info" {
t.Errorf("expected more recent 'beta topic info' first, got %q", results[0].Content)
}
}
func TestStore_Recent(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "memories.json")
s := NewStore(path)
s.Save("old memory", nil)
time.Sleep(10 * time.Millisecond)
s.Save("new memory", nil)
t.Run("ordering by LastUsed", func(t *testing.T) {
recent := s.Recent(2)
if len(recent) != 2 {
t.Fatalf("Recent(2) returned %d, want 2", len(recent))
}
if recent[0].Content != "new memory" {
t.Errorf("first recent = %q, want 'new memory'", recent[0].Content)
}
if recent[1].Content != "old memory" {
t.Errorf("second recent = %q, want 'old memory'", recent[1].Content)
}
})
t.Run("limit exceeds count returns all", func(t *testing.T) {
recent := s.Recent(100)
if len(recent) != 2 {
t.Errorf("Recent(100) returned %d, want 2", len(recent))
}
})
t.Run("empty store", func(t *testing.T) {
emptyPath := filepath.Join(dir, "empty.json")
empty := NewStore(emptyPath)
recent := empty.Recent(5)
if recent != nil {
t.Errorf("empty Recent should return nil, got %v", recent)
}
})
}
func TestStore_Persistence_RoundTrip(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "memories.json")
s1 := NewStore(path)
s1.Save("persistent memory", []string{"test"})
s1.Save("another memory", []string{"test2"})
// Create new store from same path.
s2 := NewStore(path)
if s2.Count() != 2 {
t.Errorf("reloaded Count = %d, want 2", s2.Count())
}
// Verify data is intact.
recent := s2.Recent(2)
contents := map[string]bool{}
for _, m := range recent {
contents[m.Content] = true
}
if !contents["persistent memory"] {
t.Error("missing 'persistent memory' after reload")
}
if !contents["another memory"] {
t.Error("missing 'another memory' after reload")
}
// Verify IDs continue.
id, err := s2.Save("third", nil)
if err != nil {
t.Fatalf("Save after reload: %v", err)
}
if id != 3 {
t.Errorf("continued id = %d, want 3", id)
}
}
func TestStore_Delete(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "memories.json")
s := NewStore(path)
id, _ := s.Save("to be deleted", []string{"temp"})
if s.Count() != 1 {
t.Fatalf("expected 1 memory, got %d", s.Count())
}
deleted, err := s.Delete(id)
if err != nil {
t.Fatalf("Delete returned error: %v", err)
}
if !deleted {
t.Error("Delete returned false for existing memory")
}
if s.Count() != 0 {
t.Errorf("Count after delete = %d, want 0", s.Count())
}
// Try deleting non-existent.
deleted, err = s.Delete(999)
if err != nil {
t.Fatalf("Delete returned error: %v", err)
}
if deleted {
t.Error("Delete should return false for non-existent memory")
}
}
func TestStore_Update(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "memories.json")
s := NewStore(path)
id, _ := s.Save("original content", []string{"original"})
updated, err := s.Update(id, "updated content", []string{"updated"})
if err != nil {
t.Fatalf("Update returned error: %v", err)
}
if !updated {
t.Error("Update returned false for existing memory")
}
// Verify update.
mem, found := s.Get(id)
if !found {
t.Fatal("memory not found after update")
}
if mem.Content != "updated content" {
t.Errorf("Content = %q, want 'updated content'", mem.Content)
}
if len(mem.Tags) != 1 || mem.Tags[0] != "updated" {
t.Errorf("Tags = %v, want ['updated']", mem.Tags)
}
// Try updating non-existent.
updated, err = s.Update(999, "test", nil)
if err != nil {
t.Fatalf("Update returned error: %v", err)
}
if updated {
t.Error("Update should return false for non-existent memory")
}
}
func TestStore_DeleteByTag(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "memories.json")
s := NewStore(path)
s.Save("keep this 1", []string{"keep"})
s.Save("delete this", []string{"temp"})
s.Save("keep this 2", []string{"keep"})
s.Save("delete this too", []string{"temp"})
s.Save("also keep", []string{"permanent"})
deleted, err := s.DeleteByTag("temp")
if err != nil {
t.Fatalf("DeleteByTag returned error: %v", err)
}
if deleted != 2 {
t.Errorf("DeleteByTag deleted = %d, want 2", deleted)
}
if s.Count() != 3 {
t.Errorf("Count after delete = %d, want 3", s.Count())
}
// Verify only temp memories are gone.
results := s.Recall("keep", 10)
if len(results) != 3 {
t.Errorf("Recall returned %d, want 3", len(results))
}
}
func TestStore_Get(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "memories.json")
s := NewStore(path)
id, _ := s.Save("test memory", []string{"tag"})
mem, found := s.Get(id)
if !found {
t.Fatal("Get returned false for existing memory")
}
if mem.Content != "test memory" {
t.Errorf("Content = %q, want 'test memory'", mem.Content)
}
if len(mem.Tags) != 1 || mem.Tags[0] != "tag" {
t.Errorf("Tags = %v, want ['tag']", mem.Tags)
}
// Try getting non-existent.
_, found = s.Get(999)
if found {
t.Error("Get should return false for non-existent memory")
}
}
func TestStore_UpdatePartial(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "memories.json")
s := NewStore(path)
id, _ := s.Save("original content", []string{"original", "tags"})
// Update only content, keep tags.
updated, err := s.Update(id, "new content", nil)
if err != nil {
t.Fatalf("Update returned error: %v", err)
}
if !updated {
t.Error("Update returned false")
}
mem, _ := s.Get(id)
if mem.Content != "new content" {
t.Errorf("Content = %q, want 'new content'", mem.Content)
}
// Tags should remain unchanged when nil is passed.
if len(mem.Tags) != 2 {
t.Errorf("Tags = %v, want 2 tags", mem.Tags)
}
// Update only tags, keep content.
updated, err = s.Update(id, "", []string{"only", "tags"})
if err != nil {
t.Fatalf("Update returned error: %v", err)
}
if !updated {
t.Error("Update returned false")
}
mem, _ = s.Get(id)
if mem.Content != "new content" {
t.Errorf("Content changed unexpectedly to %q", mem.Content)
}
if len(mem.Tags) != 2 || mem.Tags[0] != "only" || mem.Tags[1] != "tags" {
t.Errorf("Tags = %v, want ['only', 'tags']", mem.Tags)
}
}

102
internal/memory/tools.go Normal file
View File

@ -0,0 +1,102 @@
package memory
import (
"ai-agent/internal/llm"
)
func BuiltinToolDefs() []llm.ToolDef {
return []llm.ToolDef{
{
Name: "memory_save",
Description: "Save an important fact, user preference, or piece of context to persistent memory. Use this proactively when the user shares information worth remembering across sessions.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"content": map[string]any{
"type": "string",
"description": "The fact or information to remember.",
},
"tags": map[string]any{
"type": "array",
"items": map[string]any{"type": "string"},
"description": "Optional tags for categorization (e.g., 'preference', 'project', 'name').",
},
},
"required": []string{"content"},
},
},
{
Name: "memory_recall",
Description: "Search persistent memory for previously saved facts. Use this when you need to recall user preferences, project details, or other saved context.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{
"type": "string",
"description": "Search query to find relevant memories.",
},
},
"required": []string{"query"},
},
},
{
Name: "memory_delete",
Description: "Delete a memory by its ID. Use memory_recall or memory_list first to find the ID of the memory you want to delete.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"id": map[string]any{
"type": "integer",
"description": "The ID of the memory to delete (use memory_list or memory_recall to find IDs).",
},
},
"required": []string{"id"},
},
},
{
Name: "memory_update",
Description: "Update an existing memory's content or tags. Use memory_recall or memory_list first to find the ID.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"id": map[string]any{
"type": "integer",
"description": "The ID of the memory to update (use memory_list or memory_recall to find IDs).",
},
"content": map[string]any{
"type": "string",
"description": "New content for the memory.",
},
"tags": map[string]any{
"type": "array",
"items": map[string]any{"type": "string"},
"description": "New tags for the memory.",
},
},
"required": []string{"id"},
},
},
{
Name: "memory_list",
Description: "List all stored memories with their IDs, content, and tags. Use this to see what has been saved.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"limit": map[string]any{
"type": "integer",
"description": "Maximum number of memories to return (default: 20).",
},
},
},
},
}
}
func IsBuiltinTool(name string) bool {
switch name {
case "memory_save", "memory_recall", "memory_delete", "memory_update", "memory_list":
return true
default:
return false
}
}

View File

@ -0,0 +1,48 @@
package memory
import "testing"
func TestBuiltinToolDefs(t *testing.T) {
defs := BuiltinToolDefs()
if len(defs) != 5 {
t.Fatalf("BuiltinToolDefs() returned %d defs, want 5", len(defs))
}
names := map[string]bool{}
for _, d := range defs {
names[d.Name] = true
}
expected := []string{"memory_save", "memory_recall", "memory_delete", "memory_update", "memory_list"}
for _, name := range expected {
if !names[name] {
t.Errorf("missing %s tool definition", name)
}
}
}
func TestIsBuiltinTool(t *testing.T) {
tests := []struct {
name string
tool string
want bool
}{
{name: "memory_save", tool: "memory_save", want: true},
{name: "memory_recall", tool: "memory_recall", want: true},
{name: "memory_delete", tool: "memory_delete", want: true},
{name: "memory_update", tool: "memory_update", want: true},
{name: "memory_list", tool: "memory_list", want: true},
{name: "unknown tool", tool: "unknown", want: false},
{name: "empty string", tool: "", want: false},
{name: "partial match", tool: "memory_", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsBuiltinTool(tt.tool)
if got != tt.want {
t.Errorf("IsBuiltinTool(%q) = %v, want %v", tt.tool, got, tt.want)
}
})
}
}

View File

@ -0,0 +1,158 @@
package permission
import (
"context"
"sync"
"ai-agent/internal/db"
)
type Policy string
const (
PolicyAllow Policy = "allow"
PolicyDeny Policy = "deny"
PolicyAsk Policy = "ask"
)
type Checker struct {
store *db.Store
cache map[string]Policy
mu sync.RWMutex
yolo bool
}
func NewChecker(store *db.Store, yolo bool) *Checker {
c := &Checker{
store: store,
cache: make(map[string]Policy),
yolo: yolo,
}
if store != nil {
c.loadFromDB()
}
return c
}
func (c *Checker) Check(toolName string) Policy {
if c.yolo {
return PolicyAllow
}
c.mu.RLock()
defer c.mu.RUnlock()
if p, ok := c.cache[toolName]; ok {
return p
}
return PolicyAsk
}
func (c *Checker) SetPolicy(toolName string, policy Policy) {
c.mu.Lock()
c.cache[toolName] = policy
c.mu.Unlock()
if c.store != nil {
c.store.UpsertToolPermission(context.Background(), db.UpsertToolPermissionParams{
ToolName: toolName,
Policy: string(policy),
})
}
}
func (c *Checker) IsYolo() bool {
return c.yolo
}
func (c *Checker) AllPolicies() map[string]Policy {
c.mu.RLock()
defer c.mu.RUnlock()
result := make(map[string]Policy, len(c.cache))
for k, v := range c.cache {
result[k] = v
}
return result
}
func (c *Checker) Reset() {
c.mu.Lock()
c.cache = make(map[string]Policy)
c.mu.Unlock()
if c.store != nil {
c.store.ResetToolPermissions(context.Background())
}
}
func (c *Checker) loadFromDB() {
perms, err := c.store.ListToolPermissions(context.Background())
if err != nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
for _, p := range perms {
switch Policy(p.Policy) {
case PolicyAllow, PolicyDeny, PolicyAsk:
c.cache[p.ToolName] = Policy(p.Policy)
}
}
}
type ApprovalRequest struct {
ToolName string
Args map[string]any
Response chan ApprovalResponse
}
type ApprovalResponse struct {
Allowed bool
Always bool
}
func RequestApproval(toolName string, args map[string]any, callback func(ApprovalRequest)) (bool, bool) {
if callback == nil {
return true, false
}
ch := make(chan ApprovalResponse, 1)
callback(ApprovalRequest{
ToolName: toolName,
Args: args,
Response: ch,
})
resp := <-ch
return resp.Allowed, resp.Always
}
type CheckResult int
const (
CheckAllow CheckResult = iota
CheckDeny
CheckAsk
)
func (c *Checker) ToCheckResult(toolName string) CheckResult {
if c == nil || c.yolo {
return CheckAllow
}
switch c.Check(toolName) {
case PolicyAllow:
return CheckAllow
case PolicyDeny:
return CheckDeny
default:
return CheckAsk
}
}
func NilSafe(store *db.Store, yolo bool) *Checker {
return NewChecker(store, yolo)
}
var AlwaysAllow = func(_ ApprovalRequest) {}
type ErrDenied struct {
ToolName string
}
func (e *ErrDenied) Error() string {
return "tool call denied by permission policy: " + e.ToolName
}

View File

@ -0,0 +1,98 @@
package permission
import (
"path/filepath"
"testing"
"ai-agent/internal/db"
)
func TestChecker_DefaultPolicy(t *testing.T) {
c := NewChecker(nil, false)
if got := c.Check("some_tool"); got != PolicyAsk {
t.Errorf("Check() = %q, want %q", got, PolicyAsk)
}
}
func TestChecker_Yolo(t *testing.T) {
c := NewChecker(nil, true)
if got := c.Check("any_tool"); got != PolicyAllow {
t.Errorf("Check() = %q, want %q", got, PolicyAllow)
}
if !c.IsYolo() {
t.Error("expected IsYolo() = true")
}
}
func TestChecker_SetPolicy(t *testing.T) {
c := NewChecker(nil, false)
c.SetPolicy("bash", PolicyAllow)
if got := c.Check("bash"); got != PolicyAllow {
t.Errorf("Check() = %q, want %q", got, PolicyAllow)
}
c.SetPolicy("bash", PolicyDeny)
if got := c.Check("bash"); got != PolicyDeny {
t.Errorf("Check() = %q, want %q", got, PolicyDeny)
}
}
func TestChecker_WithDB(t *testing.T) {
store, err := db.OpenPath(filepath.Join(t.TempDir(), "test.db"))
if err != nil {
t.Fatal(err)
}
defer store.Close()
c := NewChecker(store, false)
c.SetPolicy("file_write", PolicyAllow)
c2 := NewChecker(store, false)
if got := c2.Check("file_write"); got != PolicyAllow {
t.Errorf("persisted Check() = %q, want %q", got, PolicyAllow)
}
}
func TestChecker_Reset(t *testing.T) {
c := NewChecker(nil, false)
c.SetPolicy("tool1", PolicyAllow)
c.SetPolicy("tool2", PolicyDeny)
c.Reset()
if got := c.Check("tool1"); got != PolicyAsk {
t.Errorf("after reset Check() = %q, want %q", got, PolicyAsk)
}
}
func TestChecker_AllPolicies(t *testing.T) {
c := NewChecker(nil, false)
c.SetPolicy("a", PolicyAllow)
c.SetPolicy("b", PolicyDeny)
policies := c.AllPolicies()
if len(policies) != 2 {
t.Errorf("AllPolicies() len = %d, want 2", len(policies))
}
if policies["a"] != PolicyAllow {
t.Errorf("policies[a] = %q, want %q", policies["a"], PolicyAllow)
}
}
func TestToCheckResult(t *testing.T) {
c := NewChecker(nil, false)
c.SetPolicy("allowed", PolicyAllow)
c.SetPolicy("denied", PolicyDeny)
if c.ToCheckResult("allowed") != CheckAllow {
t.Error("expected CheckAllow for allowed tool")
}
if c.ToCheckResult("denied") != CheckDeny {
t.Error("expected CheckDeny for denied tool")
}
if c.ToCheckResult("unknown") != CheckAsk {
t.Error("expected CheckAsk for unknown tool")
}
}
func TestToCheckResult_Nil(t *testing.T) {
var c *Checker
if c.ToCheckResult("anything") != CheckAllow {
t.Error("nil checker should return CheckAllow")
}
}

118
internal/skill/manager.go Normal file
View File

@ -0,0 +1,118 @@
package skill
import (
"fmt"
"os"
"path/filepath"
"strings"
)
type Manager struct {
skills []*Skill
dirs []string
}
func NewManager(dir string) *Manager {
dirs := []string{}
if dir != "" {
dirs = append(dirs, dir)
} else {
if home, err := os.UserHomeDir(); err == nil {
dirs = append(dirs, filepath.Join(home, ".config", "ai-agent", "skills"))
}
}
return &Manager{dirs: dirs}
}
func (m *Manager) AddSearchPath(dir string) {
for _, d := range m.dirs {
if d == dir {
return
}
}
m.dirs = append(m.dirs, dir)
}
func (m *Manager) Names() []string {
var names []string
for _, s := range m.skills {
names = append(names, s.Name)
}
return names
}
func (m *Manager) LoadAll() error {
for _, dir := range m.dirs {
if err := m.loadFromDir(dir); err != nil {
return err
}
}
return nil
}
func (m *Manager) loadFromDir(dir string) error {
if dir == "" {
return nil
}
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("read skills dir: %w", err)
}
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") {
continue
}
path := filepath.Join(dir, entry.Name())
data, err := os.ReadFile(path)
if err != nil {
continue
}
skill, err := parseFrontmatter(string(data))
if err != nil {
continue
}
skill.Path = path
if skill.Name == "" {
skill.Name = strings.TrimSuffix(entry.Name(), ".md")
}
m.skills = append(m.skills, skill)
}
return nil
}
func (m *Manager) All() []*Skill {
return m.skills
}
func (m *Manager) Activate(name string) error {
for _, s := range m.skills {
if s.Name == name {
s.Active = true
return nil
}
}
return fmt.Errorf("skill not found: %s", name)
}
func (m *Manager) Deactivate(name string) error {
for _, s := range m.skills {
if s.Name == name {
s.Active = false
return nil
}
}
return fmt.Errorf("skill not found: %s", name)
}
func (m *Manager) ActiveContent() string {
var parts []string
for _, s := range m.skills {
if s.Active && s.Content != "" {
parts = append(parts, fmt.Sprintf("### %s\n%s", s.Name, s.Content))
}
}
return strings.Join(parts, "\n\n")
}

View File

@ -0,0 +1,169 @@
package skill
import (
"os"
"path/filepath"
"testing"
)
func TestManager_LoadAll(t *testing.T) {
dir := t.TempDir()
// Create valid skill files.
os.WriteFile(filepath.Join(dir, "greeting.md"), []byte("---\nname: greeting\ndescription: Say hello\n---\nHello!"), 0o644)
os.WriteFile(filepath.Join(dir, "farewell.md"), []byte("---\nname: farewell\ndescription: Say bye\n---\nGoodbye!"), 0o644)
// Create a non-.md file (should be skipped).
os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("not a skill"), 0o644)
// Create a subdirectory (should be skipped).
os.MkdirAll(filepath.Join(dir, "subdir"), 0o755)
m := NewManager(dir)
if err := m.LoadAll(); err != nil {
t.Fatalf("LoadAll: %v", err)
}
skills := m.All()
if len(skills) != 2 {
t.Fatalf("loaded %d skills, want 2", len(skills))
}
names := map[string]bool{}
for _, s := range skills {
names[s.Name] = true
}
if !names["greeting"] {
t.Error("missing 'greeting' skill")
}
if !names["farewell"] {
t.Error("missing 'farewell' skill")
}
}
func TestManager_LoadAll_NoFrontmatter(t *testing.T) {
dir := t.TempDir()
// File without frontmatter uses filename as name.
os.WriteFile(filepath.Join(dir, "plain.md"), []byte("Just content, no frontmatter"), 0o644)
m := NewManager(dir)
if err := m.LoadAll(); err != nil {
t.Fatalf("LoadAll: %v", err)
}
skills := m.All()
if len(skills) != 1 {
t.Fatalf("loaded %d skills, want 1", len(skills))
}
if skills[0].Name != "plain" {
t.Errorf("Name = %q, want 'plain'", skills[0].Name)
}
}
func TestManager_LoadAll_NonexistentDir(t *testing.T) {
m := NewManager("/nonexistent/path/that/does/not/exist")
if err := m.LoadAll(); err != nil {
t.Fatalf("LoadAll on nonexistent dir should not error, got: %v", err)
}
if len(m.All()) != 0 {
t.Errorf("expected 0 skills from nonexistent dir, got %d", len(m.All()))
}
}
func TestManager_Activate_Deactivate(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "test.md"), []byte("---\nname: test\n---\nTest content"), 0o644)
m := NewManager(dir)
m.LoadAll()
t.Run("activate found", func(t *testing.T) {
err := m.Activate("test")
if err != nil {
t.Fatalf("Activate: %v", err)
}
skill := m.All()[0]
if !skill.Active {
t.Error("skill should be active after Activate")
}
})
t.Run("activate not found", func(t *testing.T) {
err := m.Activate("nonexistent")
if err == nil {
t.Error("expected error for nonexistent skill")
}
})
t.Run("deactivate found", func(t *testing.T) {
err := m.Deactivate("test")
if err != nil {
t.Fatalf("Deactivate: %v", err)
}
skill := m.All()[0]
if skill.Active {
t.Error("skill should be inactive after Deactivate")
}
})
t.Run("deactivate not found", func(t *testing.T) {
err := m.Deactivate("nonexistent")
if err == nil {
t.Error("expected error for nonexistent skill")
}
})
}
func TestManager_ActiveContent(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "alpha.md"), []byte("---\nname: alpha\n---\nAlpha content"), 0o644)
os.WriteFile(filepath.Join(dir, "beta.md"), []byte("---\nname: beta\n---\nBeta content"), 0o644)
m := NewManager(dir)
m.LoadAll()
t.Run("none active returns empty", func(t *testing.T) {
content := m.ActiveContent()
if content != "" {
t.Errorf("expected empty content, got %q", content)
}
})
t.Run("one active returns its content", func(t *testing.T) {
m.Activate("alpha")
content := m.ActiveContent()
if content == "" {
t.Fatal("expected non-empty content")
}
if !contains(content, "Alpha content") {
t.Errorf("content missing 'Alpha content': %q", content)
}
if contains(content, "Beta content") {
t.Errorf("content should not contain inactive 'Beta content': %q", content)
}
m.Deactivate("alpha")
})
t.Run("multiple active returns combined", func(t *testing.T) {
m.Activate("alpha")
m.Activate("beta")
content := m.ActiveContent()
if !contains(content, "Alpha content") || !contains(content, "Beta content") {
t.Errorf("combined content missing expected parts: %q", content)
}
})
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && searchString(s, substr)
}
func searchString(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

65
internal/skill/types.go Normal file
View File

@ -0,0 +1,65 @@
package skill
import (
"bufio"
"strings"
"gopkg.in/yaml.v3"
)
// Skill represents a loadable skill definition.
type Skill struct {
Name string `yaml:"name"`
Description string `yaml:"description"`
Active bool `yaml:"-"`
Content string `yaml:"-"` // markdown body after frontmatter
Path string `yaml:"-"` // file path
}
// parseFrontmatter extracts YAML frontmatter and markdown body from a skill file.
// Frontmatter is delimited by "---" on the first and closing lines.
func parseFrontmatter(data string) (*Skill, error) {
scanner := bufio.NewScanner(strings.NewReader(data))
// Check for opening "---".
if !scanner.Scan() || strings.TrimSpace(scanner.Text()) != "---" {
// No frontmatter — treat entire content as body.
return &Skill{Content: data}, nil
}
// Read YAML lines until closing "---".
var yamlBuf strings.Builder
foundEnd := false
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) == "---" {
foundEnd = true
break
}
yamlBuf.WriteString(line)
yamlBuf.WriteString("\n")
}
if !foundEnd {
// No closing delimiter — treat as body only.
return &Skill{Content: data}, nil
}
// Parse YAML frontmatter.
s := &Skill{}
if err := yaml.Unmarshal([]byte(yamlBuf.String()), s); err != nil {
return nil, err
}
// Remaining content is the markdown body.
var bodyBuf strings.Builder
for scanner.Scan() {
if bodyBuf.Len() > 0 {
bodyBuf.WriteString("\n")
}
bodyBuf.WriteString(scanner.Text())
}
s.Content = strings.TrimSpace(bodyBuf.String())
return s, nil
}

View File

@ -0,0 +1,78 @@
package skill
import "testing"
func TestParseFrontmatter(t *testing.T) {
tests := []struct {
name string
input string
wantName string
wantDesc string
wantContent string
wantErr bool
}{
{
name: "valid frontmatter",
input: "---\nname: test\ndescription: desc\n---\nBody content",
wantName: "test",
wantDesc: "desc",
wantContent: "Body content",
},
{
name: "no frontmatter",
input: "Just body",
wantContent: "Just body",
},
{
name: "missing closing delimiter",
input: "---\nname: test\nBody",
wantContent: "---\nname: test\nBody",
},
{
name: "invalid YAML",
input: "---\n: :\n---\nbody",
wantErr: true,
},
{
name: "empty body",
input: "---\nname: test\n---\n",
wantName: "test",
wantContent: "",
},
{
name: "empty input",
input: "",
wantContent: "",
},
{
name: "multiline body",
input: "---\nname: multi\n---\nline 1\nline 2\nline 3",
wantName: "multi",
wantContent: "line 1\nline 2\nline 3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
skill, err := parseFrontmatter(tt.input)
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if skill.Name != tt.wantName {
t.Errorf("Name = %q, want %q", skill.Name, tt.wantName)
}
if skill.Description != tt.wantDesc {
t.Errorf("Description = %q, want %q", skill.Description, tt.wantDesc)
}
if skill.Content != tt.wantContent {
t.Errorf("Content = %q, want %q", skill.Content, tt.wantContent)
}
})
}
}

View File

@ -0,0 +1,45 @@
package tools
import (
"ai-agent/internal/llm"
)
var builtinToolNames = map[string]bool{
"grep": true,
"read": true,
"write": true,
"glob": true,
"bash": true,
"ls": true,
"find": true,
"diff": true,
"edit": true,
"mkdir": true,
"remove": true,
"copy": true,
"move": true,
"exists": true,
}
func AllToolDefs() []llm.ToolDef {
return []llm.ToolDef{
GrepToolDef(),
ReadToolDef(),
WriteToolDef(),
GlobToolDef(),
BashToolDef(),
LsToolDef(),
FindToolDef(),
DiffToolDef(),
EditToolDef(),
MkdirToolDef(),
RemoveToolDef(),
CopyToolDef(),
MoveToolDef(),
ExistsToolDef(),
}
}
func IsBuiltinTool(name string) bool {
return builtinToolNames[name]
}

306
internal/tools/tools.go Normal file
View File

@ -0,0 +1,306 @@
package tools
import (
"ai-agent/internal/llm"
)
func GrepToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "grep",
Description: "Search for a pattern in files. Use this to find code, text, or values across multiple files.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"pattern": map[string]any{
"type": "string",
"description": "The regex pattern to search for.",
},
"path": map[string]any{
"type": "string",
"description": "Directory path to search in (defaults to current directory).",
},
"include": map[string]any{
"type": "string",
"description": "File pattern to include (e.g., '*.go', '*.ts').",
},
"context": map[string]any{
"type": "integer",
"description": "Number of lines of context to show around matches (default: 3).",
},
},
"required": []string{"pattern"},
},
}
}
func ReadToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "read",
Description: "Read the contents of a file. Use this to view source code, configuration files, or any text file.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the file to read.",
},
"limit": map[string]any{
"type": "integer",
"description": "Maximum number of lines to read (optional).",
},
"offset": map[string]any{
"type": "integer",
"description": "Line number to start reading from (optional, 1-indexed).",
},
},
"required": []string{"path"},
},
}
}
func WriteToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "write",
Description: "Write content to a file. Use this to create new files or overwrite existing ones. Creates parent directories if needed.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the file to write.",
},
"content": map[string]any{
"type": "string",
"description": "Content to write to the file.",
},
},
"required": []string{"path", "content"},
},
}
}
func GlobToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "glob",
Description: "Find files matching a pattern. Use this to discover files by name patterns like '*.go', '**/*.ts', etc.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"pattern": map[string]any{
"type": "string",
"description": "Glob pattern to match (e.g., '**/*.go', 'src/**/*.ts').",
},
"path": map[string]any{
"type": "string",
"description": "Directory to search in (defaults to current directory).",
},
},
"required": []string{"pattern"},
},
}
}
func BashToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "bash",
Description: "Execute a shell command. Use this to run git, npm, go, or other command-line tools. Output is returned after completion.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"command": map[string]any{
"type": "string",
"description": "The shell command to execute.",
},
"timeout": map[string]any{
"type": "integer",
"description": "Timeout in seconds (default: 30, max: 120).",
},
},
"required": []string{"command"},
},
}
}
func LsToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "ls",
Description: "List files and directories. Use this to see what's in a directory.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Directory path to list (defaults to current directory).",
},
},
},
}
}
func FindToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "find",
Description: "Find files or directories by name. Use this to locate specific files when you know all or part of the filename.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
"description": "Name or pattern to search for (supports * and ? wildcards).",
},
"path": map[string]any{
"type": "string",
"description": "Directory to search in (defaults to current directory).",
},
"type": map[string]any{
"type": "string",
"description": "Type to find: 'f' for files, 'd' for directories (default: both).",
},
},
"required": []string{"name"},
},
}
}
func DiffToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "diff",
Description: "Show the differences between the current file content and new content. Use this to preview changes before writing.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the file to diff.",
},
"new_content": map[string]any{
"type": "string",
"description": "The new content to compare against the current file.",
},
},
"required": []string{"path", "new_content"},
},
}
}
func EditToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "edit",
Description: "Apply a patch to a file. Use this to make targeted edits to specific lines without overwriting the entire file. The patch format is: @@ -start,count +new_start,new_count @@\nfollowed by lines starting with - (remove), + (add), or (context).",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the file to edit.",
},
"patch": map[string]any{
"type": "string",
"description": "Unified diff patch to apply. Format: @@ -start,count +new_start,new_count @@ followed by -line (remove), +line (add), or context line.",
},
},
"required": []string{"path", "patch"},
},
}
}
func MkdirToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "mkdir",
Description: "Create one or more directories. Creates parent directories as needed.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the directory to create.",
},
},
"required": []string{"path"},
},
}
}
func RemoveToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "remove",
Description: "Remove files or directories. Use with caution - this permanently deletes files.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to remove (file or directory).",
},
"recursive": map[string]any{
"type": "boolean",
"description": "Remove directories recursively (default: false).",
},
"force": map[string]any{
"type": "boolean",
"description": "Ignore nonexistent files (default: false).",
},
},
"required": []string{"path"},
},
}
}
func CopyToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "copy",
Description: "Copy a file from source to destination.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"source": map[string]any{
"type": "string",
"description": "Source path to copy from.",
},
"destination": map[string]any{
"type": "string",
"description": "Destination path to copy to.",
},
},
"required": []string{"source", "destination"},
},
}
}
func MoveToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "move",
Description: "Move or rename a file or directory.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"source": map[string]any{
"type": "string",
"description": "Source path to move from.",
},
"destination": map[string]any{
"type": "string",
"description": "Destination path to move to.",
},
},
"required": []string{"source", "destination"},
},
}
}
func ExistsToolDef() llm.ToolDef {
return llm.ToolDef{
Name: "exists",
Description: "Check if a file or directory exists and get information about it.",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to check.",
},
},
"required": []string{"path"},
},
}
}

View File

@ -0,0 +1,156 @@
package tools
import (
"testing"
)
func TestGrepToolDef(t *testing.T) {
tool := GrepToolDef()
if tool.Name != "grep" {
t.Errorf("Name = %q, want %q", tool.Name, "grep")
}
if tool.Description == "" {
t.Error("Description should not be empty")
}
if tool.Parameters == nil {
t.Error("Parameters should not be nil")
}
}
func TestReadToolDef(t *testing.T) {
tool := ReadToolDef()
if tool.Name != "read" {
t.Errorf("Name = %q, want %q", tool.Name, "read")
}
props := tool.Parameters["properties"].(map[string]any)
if _, ok := props["path"]; !ok {
t.Error("should have path property")
}
}
func TestWriteToolDef(t *testing.T) {
tool := WriteToolDef()
if tool.Name != "write" {
t.Errorf("Name = %q, want %q", tool.Name, "write")
}
props := tool.Parameters["properties"].(map[string]any)
if _, ok := props["path"]; !ok {
t.Error("should have path property")
}
if _, ok := props["content"]; !ok {
t.Error("should have content property")
}
}
func TestGlobToolDef(t *testing.T) {
tool := GlobToolDef()
if tool.Name != "glob" {
t.Errorf("Name = %q, want %q", tool.Name, "glob")
}
}
func TestBashToolDef(t *testing.T) {
tool := BashToolDef()
if tool.Name != "bash" {
t.Errorf("Name = %q, want %q", tool.Name, "bash")
}
props := tool.Parameters["properties"].(map[string]any)
if _, ok := props["command"]; !ok {
t.Error("should have command property")
}
}
func TestLsToolDef(t *testing.T) {
tool := LsToolDef()
if tool.Name != "ls" {
t.Errorf("Name = %q, want %q", tool.Name, "ls")
}
}
func TestFindToolDef(t *testing.T) {
tool := FindToolDef()
if tool.Name != "find" {
t.Errorf("Name = %q, want %q", tool.Name, "find")
}
props := tool.Parameters["properties"].(map[string]any)
if _, ok := props["name"]; !ok {
t.Error("should have name property")
}
}
func TestDiffToolDef(t *testing.T) {
tool := DiffToolDef()
if tool.Name != "diff" {
t.Errorf("Name = %q, want %q", tool.Name, "diff")
}
}
func TestEditToolDef(t *testing.T) {
tool := EditToolDef()
if tool.Name != "edit" {
t.Errorf("Name = %q, want %q", tool.Name, "edit")
}
}
func TestMkdirToolDef(t *testing.T) {
tool := MkdirToolDef()
if tool.Name != "mkdir" {
t.Errorf("Name = %q, want %q", tool.Name, "mkdir")
}
}
func TestRemoveToolDef(t *testing.T) {
tool := RemoveToolDef()
if tool.Name != "remove" {
t.Errorf("Name = %q, want %q", tool.Name, "remove")
}
props := tool.Parameters["properties"].(map[string]any)
if _, ok := props["recursive"]; !ok {
t.Error("should have recursive property")
}
if _, ok := props["force"]; !ok {
t.Error("should have force property")
}
}
func TestCopyToolDef(t *testing.T) {
tool := CopyToolDef()
if tool.Name != "copy" {
t.Errorf("Name = %q, want %q", tool.Name, "copy")
}
props := tool.Parameters["properties"].(map[string]any)
if _, ok := props["source"]; !ok {
t.Error("should have source property")
}
if _, ok := props["destination"]; !ok {
t.Error("should have destination property")
}
}
func TestMoveToolDef(t *testing.T) {
tool := MoveToolDef()
if tool.Name != "move" {
t.Errorf("Name = %q, want %q", tool.Name, "move")
}
}
func TestExistsToolDef(t *testing.T) {
tool := ExistsToolDef()
if tool.Name != "exists" {
t.Errorf("Name = %q, want %q", tool.Name, "exists")
}
}

View File

@ -0,0 +1,162 @@
# Responsive Width Implementation
## Overview
This document describes the responsive width calculations implemented to prevent horizontal scrolling in the TUI chat interface.
## Width Calculation Hierarchy
### 1. Viewport Width (Primary Constraint)
The viewport is the main container for chat content. All other widths derive from this.
**Formula** (from `model.go:373-380`):
```go
viewportWidth := screenWidth - 1
if sidePanel.IsVisible() {
viewportWidth = screenWidth - panelWidth - 2
}
if viewportWidth < 20 {
viewportWidth = 20 // minimum width
}
```
**Breakdown**:
- `screenWidth - 1`: Full width minus right edge padding (when panel hidden)
- `screenWidth - panelWidth - 2`: Width minus panel and separator line (when panel visible)
- Minimum 20 characters to ensure readability
### 2. Content Width (Text Wrapping)
Used for wrapping text in `renderEntries()`, `renderUserMsg()`, `renderAssistantMsg()`, etc.
**Formula** (from `view.go:422-429`):
```go
contentW := screenWidth - 4
if sidePanel.IsVisible() {
contentW = screenWidth - panelWidth - 5
}
if contentW < 20 {
contentW = 20
}
```
**Breakdown**:
- `screenWidth - 4`: Full width with 2-char padding on each side
- `screenWidth - panelWidth - 5`: Accounts for panel, separator, and padding
- Minimum 20 characters
### 3. Markdown Width (Glamour Rendering)
Used for rendering markdown content via Glamour.
**Formula** (from `model.go:382-386`):
```go
markdownWidth := viewportWidth - 3
if markdownWidth < 20 {
markdownWidth = 20
}
```
**Breakdown**:
- Derived from viewport width minus 3 chars for padding/indentation
- Minimum 20 characters
### 4. Input Width
Matches viewport width exactly for unified appearance.
**Formula** (from `model.go:431`):
```go
input.SetWidth(viewportWidth)
```
## Panel Width Calculation
Panel width is dynamic based on screen size (from `model.go:365-371`):
```go
panelWidth := 30 // default
if screenWidth < 100 {
panelWidth = 25
} else if screenWidth > 160 {
panelWidth = 40
}
```
## Layout Constraints
### With Panel Visible
```
┌─────────────────────────────────────────────────┐
│ Panel (25-40) ││ Chat Viewport │
│ ││ (screen - panel - 2) │
│ ││ │
│ ││ Content wrapped to: │
│ ││ (screen - panel - 5) │
└─────────────────────────────────────────────────┘
```
### Without Panel
```
┌─────────────────────────────────────────────────┐
│ Chat Viewport (screen - 1) │
│ │
│ Content wrapped to: (screen - 4) │
└─────────────────────────────────────────────────┘
```
## Critical Invariants
The following invariants are enforced to prevent horizontal scrolling:
1. **viewportWidth ≤ screenWidth - 1** (or `screenWidth - panelWidth - 1` when panel visible)
2. **contentWidth ≤ viewportWidth**
3. **markdownWidth ≤ viewportWidth**
4. **All widths ≥ 20** (minimum readability)
## Test Coverage
Comprehensive tests in `width_test.go` verify:
- `TestViewportWidthCalculation`: Validates width calculations for various screen sizes
- `TestResponsiveWidthToggle`: Ensures widths adjust correctly when panel is toggled
- `TestMinimumWidthConstraints`: Verifies minimum width enforcement on small screens
- `TestRenderedTextWidth`: Tests actual text wrapping behavior
- `TestLayoutConsistency`: Exhaustive testing across screen sizes 40-200 chars
## Example Calculations
### 120-char screen with panel (30 chars)
```
Viewport: 120 - 30 - 2 = 88 chars
Content: 120 - 30 - 5 = 85 chars
Markdown: 88 - 3 = 85 chars
Input: 88 chars
Total: 30 (panel) + 1 (separator) + 88 (viewport) = 119 ✓
```
### 80-char screen without panel
```
Viewport: 80 - 1 = 79 chars
Content: 80 - 4 = 76 chars
Markdown: 79 - 3 = 76 chars
Input: 79 chars
Total: 79 chars ✓
```
### 40-char screen with panel (25 chars) - Edge Case
```
Viewport: 40 - 25 - 2 = 13 → 20 (minimum enforced)
Content: 40 - 25 - 5 = 10 → 20 (minimum enforced)
Markdown: 20 - 3 = 17 → 20 (minimum enforced)
Input: 20 chars
Total: 25 + 1 + 20 = 46 (exceeds screen, but minimum width takes priority)
```
**Note**: On very small screens (< 46 chars with panel), the minimum width constraints take precedence. Users should be advised to use larger terminal windows for optimal experience.
## Responsive Behavior
When the side panel is toggled:
1. Viewport width recalculates immediately
2. Content is re-wrapped to new width via `invalidateRenderedCache()`
3. Markdown renderer is recreated with new width
4. Input field resizes to match viewport
This ensures seamless responsive behavior without horizontal scrolling.

View File

@ -0,0 +1,240 @@
package tui
import (
"fmt"
"strings"
"charm.land/lipgloss/v2"
)
// AccessibilityHelper provides accessibility features like screen reader support.
type AccessibilityHelper struct {
isDark bool
styles AccessibilityStyles
speakFunc func(string) // Function to speak text (for screen readers)
announceFunc func(string) // Function to announce changes
}
// AccessibilityStyles holds styling.
type AccessibilityStyles struct {
Announce lipgloss.Style
}
// DefaultAccessibilityStyles returns default styles.
func DefaultAccessibilityStyles(isDark bool) AccessibilityStyles {
return AccessibilityStyles{
Announce: lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0")),
}
}
// NewAccessibilityHelper creates a new accessibility helper.
func NewAccessibilityHelper(isDark bool) *AccessibilityHelper {
return &AccessibilityHelper{
isDark: isDark,
styles: DefaultAccessibilityStyles(isDark),
}
}
// SetDark updates theme.
func (ah *AccessibilityHelper) SetDark(isDark bool) {
ah.isDark = isDark
ah.styles = DefaultAccessibilityStyles(isDark)
}
// SetSpeakFunc sets the function to speak text.
func (ah *AccessibilityHelper) SetSpeakFunc(f func(string)) {
ah.speakFunc = f
}
// SetAnnounceFunc sets the function to announce changes.
func (ah *AccessibilityHelper) SetAnnounceFunc(f func(string)) {
ah.announceFunc = f
}
// Announce announces a message to the user.
func (ah *AccessibilityHelper) Announce(format string, args ...string) {
if ah.announceFunc != nil {
msg := format
if len(args) > 0 {
msg = fmt.Sprintf(format, args)
}
ah.announceFunc(msg)
}
}
// Speak speaks text directly.
func (ah *AccessibilityHelper) Speak(text string) {
if ah.speakFunc != nil {
ah.speakFunc(text)
}
}
// DescribeEntry creates an accessibility description for a chat entry.
func (ah *AccessibilityHelper) DescribeEntry(entry ChatEntry, index int, toolCount int) string {
var desc strings.Builder
switch entry.Kind {
case "user":
desc.WriteString("User message")
case "assistant":
desc.WriteString("Assistant response")
if entry.ThinkingContent != "" {
desc.WriteString(", has thinking")
}
case "tool_group":
desc.WriteString("Tool execution")
if index >= 0 && index < toolCount {
desc.WriteString(", tool result")
}
case "system":
desc.WriteString("System message")
case "error":
desc.WriteString("Error")
}
// Add content preview
if entry.Content != "" {
preview := truncateStr(entry.Content, 50)
desc.WriteString(": ")
desc.WriteString(preview)
}
return desc.String()
}
// DescribeState creates an accessibility description of the current state.
func (ah *AccessibilityHelper) DescribeState(state State, model, mode string) string {
var desc string
switch state {
case StateIdle:
desc = "Ready"
case StateWaiting:
desc = "Waiting for response"
case StateStreaming:
desc = "Receiving response"
}
if model != "" {
desc += ", model: " + model
}
if mode != "" {
desc += ", mode: " + mode
}
return desc
}
// DescribeOverlay creates an accessibility description of the current overlay.
func (ah *AccessibilityHelper) DescribeOverlay(overlay OverlayKind) string {
switch overlay {
case OverlayNone:
return ""
case OverlayHelp:
return "Help overlay open"
case OverlayCompletion:
return "Completion menu open"
case OverlayModelPicker:
return "Model picker open"
case OverlayPlanForm:
return "Plan form open"
case OverlaySessionsPicker:
return "Sessions picker open"
default:
return "Overlay open"
}
}
// DescribeTools creates an accessibility description of tool status.
func (ah *AccessibilityHelper) DescribeTools(pending, total int) string {
if pending == 0 && total == 0 {
return "No tools running"
}
if pending > 0 {
return fmt.Sprintf("%d tool running", pending)
}
return fmt.Sprintf("%d tools completed", total)
}
// truncate truncates a string to maxLength.
func truncateStr(s string, maxLength int) string {
if len(s) <= maxLength {
return s
}
return s[:maxLength-3] + "..."
}
// AccessibilityLabel returns an accessibility label for a view element.
func AccessibilityLabel(role, name string, props ...string) string {
var b strings.Builder
b.WriteString(role)
b.WriteString(": ")
b.WriteString(name)
for _, p := range props {
b.WriteString(", ")
b.WriteString(p)
}
return b.String()
}
// FocusOrder represents the focus order for keyboard navigation.
type FocusOrder struct {
Current int
Items []Focusable
}
// Focusable is an interface for focusable elements.
type Focusable interface {
Focus() error
Blur() error
IsFocused() bool
}
// NewFocusOrder creates a new focus order.
func NewFocusOrder(items []Focusable) *FocusOrder {
return &FocusOrder{
Current: 0,
Items: items,
}
}
// Next moves focus to the next item.
func (fo *FocusOrder) Next() {
if len(fo.Items) == 0 {
return
}
fo.Current = (fo.Current + 1) % len(fo.Items)
fo.focusCurrent()
}
// Prev moves focus to the previous item.
func (fo *FocusOrder) Prev() {
if len(fo.Items) == 0 {
return
}
fo.Current--
if fo.Current < 0 {
fo.Current = len(fo.Items) - 1
}
fo.focusCurrent()
}
// Current returns the currently focused item.
func (fo *FocusOrder) CurrentItem() Focusable {
if fo.Current >= 0 && fo.Current < len(fo.Items) {
return fo.Items[fo.Current]
}
return nil
}
func (fo *FocusOrder) focusCurrent() {
for i, item := range fo.Items {
if i == fo.Current {
item.Focus()
} else {
item.Blur()
}
}
}

50
internal/tui/adapter.go Normal file
View File

@ -0,0 +1,50 @@
package tui
import (
"time"
tea "charm.land/bubbletea/v2"
)
// Adapter bridges the agent.Output interface to BubbleTea messages.
type Adapter struct {
program *tea.Program
}
// NewAdapter creates an Adapter that sends messages to the given program.
func NewAdapter(p *tea.Program) *Adapter {
return &Adapter{program: p}
}
func (a *Adapter) StreamText(text string) {
sendMsg(a.program, StreamTextMsg{Text: text})
}
func (a *Adapter) StreamDone(evalCount, promptTokens int) {
sendMsg(a.program, StreamDoneMsg{EvalCount: evalCount, PromptTokens: promptTokens})
}
func (a *Adapter) ToolCallStart(name string, args map[string]any) {
sendMsg(a.program, ToolCallStartMsg{Name: name, Args: args, StartTime: time.Now()})
}
func (a *Adapter) ToolCallResult(name string, result string, isError bool, duration time.Duration) {
sendMsg(a.program, ToolCallResultMsg{Name: name, Result: result, IsError: isError, Duration: duration})
}
func (a *Adapter) SystemMessage(msg string) {
sendMsg(a.program, SystemMessageMsg{Msg: msg})
}
func (a *Adapter) Error(msg string) {
// Log error for debugging
if len(msg) > 100 {
msg = msg[:97] + "..."
}
sendMsg(a.program, ErrorMsg{Msg: msg})
}
// Done sends the final completion message.
func (a *Adapter) Done() {
sendMsg(a.program, AgentDoneMsg{})
}

View File

@ -0,0 +1,127 @@
package tui
import (
"testing"
)
func TestLastAssistantContent(t *testing.T) {
t.Run("found", func(t *testing.T) {
m := newTestModel(t)
m.entries = []ChatEntry{
{Kind: "user", Content: "hello"},
{Kind: "assistant", Content: "world"},
}
got := m.lastAssistantContent()
if got != "world" {
t.Errorf("expected 'world', got %q", got)
}
})
t.Run("not_found", func(t *testing.T) {
m := newTestModel(t)
m.entries = []ChatEntry{
{Kind: "user", Content: "hello"},
{Kind: "system", Content: "info"},
}
got := m.lastAssistantContent()
if got != "" {
t.Errorf("expected empty string, got %q", got)
}
})
t.Run("returns_last", func(t *testing.T) {
m := newTestModel(t)
m.entries = []ChatEntry{
{Kind: "assistant", Content: "first"},
{Kind: "user", Content: "question"},
{Kind: "assistant", Content: "second"},
}
got := m.lastAssistantContent()
if got != "second" {
t.Errorf("expected 'second', got %q", got)
}
})
t.Run("empty_entries", func(t *testing.T) {
m := newTestModel(t)
m.entries = nil
got := m.lastAssistantContent()
if got != "" {
t.Errorf("expected empty string, got %q", got)
}
})
}
func TestCopyLast_OnlyWhenIdleAndEmpty(t *testing.T) {
t.Run("idle_empty_with_assistant", func(t *testing.T) {
m := newTestModel(t)
m.state = StateIdle
m.entries = []ChatEntry{
{Kind: "assistant", Content: "response text"},
}
m.input.SetValue("")
_, cmd := m.Update(ctrlKey('y'))
if cmd == nil {
t.Error("expected a command to be returned for copy")
}
})
t.Run("non_empty_input_no_trigger", func(t *testing.T) {
m := newTestModel(t)
m.state = StateIdle
m.entries = []ChatEntry{
{Kind: "assistant", Content: "response text"},
}
m.input.SetValue("some text")
_, cmd := m.Update(ctrlKey('y'))
// When input is non-empty, ctrl+y should not trigger copy.
// The cmd may be non-nil (textarea update), but no copy should occur.
// Verify no system message about clipboard appears.
if cmd != nil {
msg := cmd()
if sysMsg, ok := msg.(SystemMessageMsg); ok {
if sysMsg.Msg == "Copied to clipboard." {
t.Error("should not trigger copy when input is non-empty")
}
}
}
})
t.Run("non_idle_no_trigger", func(t *testing.T) {
m := newTestModel(t)
m.state = StateStreaming
m.entries = []ChatEntry{
{Kind: "assistant", Content: "response text"},
}
m.input.SetValue("")
initialEntryCount := len(m.entries)
m.Update(ctrlKey('y'))
// Should not add any system message about clipboard
if len(m.entries) > initialEntryCount {
t.Error("should not trigger copy when not idle")
}
})
t.Run("no_assistant_entries", func(t *testing.T) {
m := newTestModel(t)
m.state = StateIdle
m.entries = []ChatEntry{
{Kind: "user", Content: "hello"},
}
m.input.SetValue("")
_, cmd := m.Update(ctrlKey('y'))
// Should not return a copy command when there's no assistant content
if cmd != nil {
msg := cmd()
if sysMsg, ok := msg.(SystemMessageMsg); ok {
if sysMsg.Msg == "Copied to clipboard." {
t.Error("should not trigger copy when no assistant content")
}
}
}
})
}

79
internal/tui/commit.go Normal file
View File

@ -0,0 +1,79 @@
package tui
import (
"bytes"
"context"
"fmt"
"os/exec"
"strings"
"ai-agent/internal/llm"
tea "charm.land/bubbletea/v2"
)
func runCommit(client llm.Client, model string, extraMsg string) tea.Cmd {
return func() tea.Msg {
diff, err := gitDiff()
if err != nil {
return CommitResultMsg{Err: fmt.Errorf("git diff: %w", err)}
}
if strings.TrimSpace(diff) == "" {
return CommitResultMsg{Err: fmt.Errorf("no staged changes (use `git add` first)")}
}
if len(diff) > 8000 {
diff = diff[:8000] + "\n... (truncated)"
}
prompt := "Write a concise git commit message for the following staged diff. " +
"Return ONLY the commit message, no explanation or markdown. " +
"Use conventional commit style (e.g. feat:, fix:, refactor:). " +
"Keep the first line under 72 characters."
if extraMsg != "" {
prompt += "\n\nAdditional context: " + extraMsg
}
prompt += "\n\nDiff:\n" + diff
var msgBuf strings.Builder
err = client.ChatStream(context.Background(), llm.ChatOptions{
Messages: []llm.Message{{Role: "user", Content: prompt}},
System: "You are a helpful assistant that writes git commit messages.",
}, func(chunk llm.StreamChunk) error {
if chunk.Text != "" {
msgBuf.WriteString(chunk.Text)
}
return nil
})
if err != nil {
return CommitResultMsg{Err: fmt.Errorf("LLM error: %w", err)}
}
commitMsg := strings.TrimSpace(msgBuf.String())
if commitMsg == "" {
return CommitResultMsg{Err: fmt.Errorf("LLM returned empty commit message")}
}
commitMsg += fmt.Sprintf("\n\nAssisted-by: ai-agent (%s)", model)
if err := gitCommit(commitMsg); err != nil {
return CommitResultMsg{Err: fmt.Errorf("git commit: %w", err)}
}
return CommitResultMsg{Message: commitMsg}
}
}
func gitDiff() (string, error) {
cmd := exec.Command("git", "diff", "--cached", "--stat")
stat, _ := cmd.Output()
cmd = exec.Command("git", "diff", "--cached")
out, err := cmd.Output()
if err != nil {
return "", err
}
return string(stat) + "\n" + string(out), nil
}
func gitCommit(msg string) error {
cmd := exec.Command("git", "commit", "-m", msg)
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("%s: %s", err, stderr.String())
}
return nil
}

Some files were not shown because too many files have changed in this diff Show More