commit 8dc496b62616d6e24f45ab00a0b788fc64a27572 Author: admin Date: Sun Mar 8 15:40:34 2026 +0700 first commit diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml new file mode 100644 index 0000000..67b09be --- /dev/null +++ b/.gitea/workflows/ci.yml @@ -0,0 +1,31 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.25" + cache: true + + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: latest + + - name: Run go vet + run: go vet ./... + + - name: Run tests + run: go test ./... -count=1 -race diff --git a/.gitea/workflows/release.yml b/.gitea/workflows/release.yml new file mode 100644 index 0000000..e424fdb --- /dev/null +++ b/.gitea/workflows/release.yml @@ -0,0 +1,37 @@ +name: Release + +on: + push: + tags: + - "v*" + +permissions: + contents: write + +jobs: + release: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.25" + cache: true + + - name: Run tests + run: go test ./... -count=1 + + - name: Run GoReleaser + uses: goreleaser/goreleaser-action@v6 + with: + distribution: goreleaser + version: latest + args: release --clean + env: + GITHUB_TOKEN: ${{ secrets.GIT_TOKEN }} + HOMEBREW_TAP_TOKEN: ${{ secrets.HOMEBREW_TAP_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..203ad47 --- /dev/null +++ b/.gitignore @@ -0,0 +1,32 @@ +# Binaries +bin/ +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.test +*.out +ai-agent + +# Vector embeddings +.vecgrep/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# OS +.DS_Store +Thumbs.db + +# Logs +*.log + +# Temporary files +tmp/ +temp/ +*.tmp diff --git a/.goreleaser.yaml b/.goreleaser.yaml new file mode 100644 index 0000000..e11b9a9 --- /dev/null +++ b/.goreleaser.yaml @@ -0,0 +1,71 @@ +version: 2 + +project_name: ai-agent + +before: + hooks: + - go mod tidy + +builds: + - id: ai-agent + main: ./cmd/ai-agent + binary: ai-agent + goos: + - linux + - darwin + - windows + goarch: + - amd64 + - arm64 + env: + - CGO_ENABLED=0 + ldflags: + - -s -w -X main.version={{.Version}} + +archives: + - id: default + name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}" + format_overrides: + - goos: windows + formats: [zip] + +checksum: + name_template: "checksums.txt" + +snapshot: + version_template: "{{ incpatch .Version }}-next" + +changelog: + sort: asc + filters: + exclude: + - "^docs:" + - "^test:" + - "^ci:" + - "^chore:" + - Merge pull request + - Merge branch + +release: + github: + owner: abdul-hamid-achik + name: ai-agent + draft: false + prerelease: auto + name_template: "v{{.Version}}" + +brews: + - name: ai-agent + homepage: https://github.com/abdul-hamid-achik/ai-agent + description: "Local AI agent with TUI, powered by Ollama and MCP servers" + license: MIT + directory: Formula + repository: + owner: abdul-hamid-achik + name: homebrew-tap + token: "{{ .Env.HOMEBREW_TAP_TOKEN }}" + install: | + bin.install "ai-agent" + test: | + system "#{bin}/ai-agent", "--version" + skip_upload: "{{ if .Env.HOMEBREW_TAP_TOKEN }}false{{ else }}true{{ end }}" diff --git a/README.md b/README.md new file mode 100644 index 0000000..597929c --- /dev/null +++ b/README.md @@ -0,0 +1,375 @@ +# ai-agent + +A fully local AI coding agent for the terminal -- powered by Ollama and small models, with intelligent routing, cross-session memory, and MCP tool integration. + +``` + ╭──────────────────────────────────────────╮ + │ ai-agent │ + │ 100% local. Your data never leaves. │ + │ │ + │ ASK -- PLAN -- BUILD │ + │ 0.8B 4B 9B │ + ╰──────────────────────────────────────────╯ +``` + +--- + +## What is ai-agent? + +- **100% local** -- runs entirely on your machine via Ollama. No API keys, no cloud, no data leaving your device. +- **Small model optimized** -- intelligent routing across Qwen 3.5 variants (0.8B / 2B / 4B / 9B) based on task complexity. +- **Three operational modes** -- ASK for quick answers, PLAN for design and reasoning, BUILD for full execution with tools. +- **MCP native** -- first-class Model Context Protocol support (STDIO, SSE, Streamable HTTP) for extensible tool integration. +- **Beautiful TUI** -- built with Charm's BubbleTea v2, Lip Gloss v2, and Glamour for rich markdown rendering in the terminal. +- **Infinite Context Engine (ICE)** -- cross-session vector retrieval that surfaces relevant past conversations automatically. +- **Auto-Memory Detection** -- the LLM extracts facts, decisions, preferences, and TODOs from conversations and persists them. +- **Thinking/CoT extraction** -- chain-of-thought reasoning is captured and displayed in collapsible blocks. +- **Skills system** -- load `.md` skill files with YAML frontmatter to inject domain-specific instructions into the system prompt. +- **Agent profiles** -- configure per-project agents with custom system prompts, skills, and MCP servers. + +--- + +## Quick Start + +### Prerequisites + +- [Go 1.25+](https://go.dev/dl/) +- [Ollama](https://ollama.ai/) running locally +- [Task](https://taskfile.dev/) (optional, for build commands) + +### Install + +Pull the required model, then install: + +```bash +ollama pull qwen3.5:2b + +go install github.com/abdul-hamid-achik/ai-agent/cmd/ai-agent@latest +``` + +For the full model routing suite (optional): + +```bash +ollama pull qwen3.5:0.8b +ollama pull qwen3.5:4b +ollama pull qwen3.5:9b +ollama pull nomic-embed-text # for ICE vector embeddings +``` + +### Configure + +Create a config file (optional -- defaults work out of the box): + +```bash +mkdir -p ~/.config/ai-agent +cp config.example.yaml ~/.config/ai-agent/config.yaml +``` + +### Run + +```bash +ai-agent +``` + +Or from source: + +```bash +task dev +``` + +--- + +## Features + +### Model Routing + +ai-agent automatically selects the right model size for the task at hand. Simple questions go to the fast 2B model; complex multi-step reasoning escalates to the 9B model. The router analyzes query complexity using keyword heuristics and word count. + +| Complexity | Model | Speed | Use Cases | +|------------|---------------|--------|----------------------------------------------| +| Simple | qwen3.5:2b | 2.5x | Quick answers, simple tool use, single edits | +| Medium | qwen3.5:4b | 1.5x | Code completion, refactoring, explanations | +| Complex | qwen3.5:9b | 1.0x | Multi-step reasoning, debugging, code review | + +The fallback chain ensures graceful degradation if a model is not available: `2b -> 4b -> 9b`. + +### Three Modes: ASK / PLAN / BUILD + +Cycle between modes with `shift+tab`. Each mode configures a different system prompt and preferred model tier. + +- **ASK** -- Direct, concise answers. Routes to the fastest available model. Tools available for file reads and searches. +- **PLAN** -- Design and planning. Breaks tasks into steps. Reads and explores with tools but does not modify files. +- **BUILD** -- Full execution mode. Uses the most capable model. All tools enabled including writes and modifications. + +### MCP Tool Integration + +Connect any MCP-compatible tool server. Supports all three transport protocols: + +- **STDIO** -- Launch tools as subprocesses (default). +- **SSE** -- Connect to Server-Sent Events endpoints. +- **Streamable HTTP** -- Connect to HTTP-based MCP servers. + +Tool calls execute in parallel when possible. The registry handles graceful failure if a server becomes unavailable. + +### Infinite Context Engine (ICE) + +ICE embeds each conversation turn using `nomic-embed-text` and stores them persistently. On every new message, it retrieves the most relevant past conversations via cosine similarity and injects them into the system prompt -- giving the agent memory that spans across sessions. + +### Auto-Memory Detection + +After each conversation turn, a background process analyzes the exchange and extracts structured memories: + +- **FACT** -- objective information the user shared +- **DECISION** -- choices made during the conversation +- **PREFERENCE** -- user preferences and working styles +- **TODO** -- action items and follow-ups + +Memories are stored in `~/.config/ai-agent/memories.json` with tag-weighted search scoring (tags weighted 3x over content). + +### Thinking/CoT Display + +When the model produces chain-of-thought reasoning, ai-agent captures it and renders it in collapsible blocks. Toggle the display with `ctrl+t`. + +### Skills System + +Drop `.md` files with YAML frontmatter into the skills directory to inject domain-specific instructions: + +``` +~/.config/ai-agent/skills/ +``` + +Manage active skills with `/skill list`, `/skill activate `, and `/skill deactivate `. + +### Agent Profiles + +Create per-project or per-domain agent profiles: + +``` +~/.agents// + AGENT.md # System prompt additions + SKILL.md # Agent-specific skills + mcp.yaml # Agent-specific MCP servers +``` + +Switch profiles with `/agent ` or `/agent list`. + +--- + +## Configuration + +### File Locations + +Config is searched in order (first match wins): + +1. `./ai-agent.yaml` (project-local) +2. `~/.config/ai-agent/config.yaml` (user-global) + +### Annotated Example + +```yaml +ollama: + model: "qwen3.5:2b" # Default model + base_url: "http://localhost:11434" # Ollama API endpoint + num_ctx: 262144 # Context window size + +# Skills directory (default: ~/.config/ai-agent/skills/) +# skills_dir: "/path/to/custom/skills" + +# MCP tool servers +servers: + # STDIO transport (default) + - name: noted + command: noted + args: [mcp] + + # SSE transport + # - name: remote-server + # transport: sse + # url: "http://localhost:8811" + + # Streamable HTTP transport + # - name: streamable-server + # transport: streamable-http + # url: "http://localhost:8812/mcp" + +# ICE configuration +# ice: +# enabled: true +# embed_model: "nomic-embed-text" +# store_path: "~/.config/ai-agent/conversations.json" +``` + +### Environment Variables + +| Variable | Description | Overrides | +|--------------------------|------------------------------|----------------------| +| `OLLAMA_HOST` | Ollama API base URL | `ollama.base_url` | +| `LOCAL_AGENT_MODEL` | Default model name | `ollama.model` | +| `LOCAL_AGENT_AGENTS_DIR` | Path to agents directory | `agents.dir` | + +--- + +## Keyboard Shortcuts + +### Input + +| Key | Action | +|-----------------|-------------------------------| +| `enter` | Send message | +| `shift+enter` | Insert new line | +| `up` / `down` | Browse input history | +| `shift+tab` | Cycle mode (ASK/PLAN/BUILD) | +| `ctrl+m` | Quick model switch | + +### Navigation + +| Key | Action | +|------------------|------------------------------| +| `pgup` / `pgdown`| Scroll conversation | +| `ctrl+u` | Half-page scroll up | +| `ctrl+d` | Half-page scroll down | + +### Display + +| Key | Action | +|-----------------|-------------------------------| +| `?` | Toggle help overlay | +| `t` | Expand/collapse tool calls | +| `space` | Toggle last tool details | +| `ctrl+t` | Toggle thinking/CoT display | +| `ctrl+y` | Copy last response | + +### Control + +| Key | Action | +|-----------------|-------------------------------| +| `esc` | Cancel streaming / close overlay | +| `ctrl+c` | Quit | +| `ctrl+l` | Clear screen | +| `ctrl+n` | New conversation | + +--- + +## Slash Commands + +| Command | Description | +|--------------------------------------|-----------------------------------| +| `/help` | Show help overlay | +| `/clear` | Clear conversation history | +| `/new` | Start a fresh conversation | +| `/model [name\|list\|fast\|smart]` | Show or switch models | +| `/models` | Open model picker | +| `/agent [name\|list]` | Show or switch agent profile | +| `/load ` | Load markdown file as context | +| `/unload` | Remove loaded context | +| `/skill [list\|activate\|deactivate] [name]` | Manage skills | +| `/servers` | List connected MCP servers | +| `/ice` | Show ICE engine status | +| `/sessions` | Browse saved sessions | +| `/exit` | Quit | + +--- + +## Architecture + +``` +cmd/ai-agent/ Entry point +internal/ + agent/ ReAct loop orchestration + llm/ LLM abstraction (OllamaClient, ModelManager) + mcp/ MCP server registry + config/ YAML config, env overrides, Router + ice/ Infinite Context Engine + memory/ Persistent key-value store + skill/ Skill file loader + command/ Slash command registry + tui/ BubbleTea v2 terminal UI + logging/ Structured logging +``` + +### Request Flow + +``` +User Input + | + v +agent.AddUserMessage() + | + v +ICE embeds message, retrieves relevant past context + | + v +System prompt assembled (tools + skills + context + ICE + memory) + | + v +Router selects model based on task complexity + | + v +LLM streams response via ChatStream() + | + v +Tool calls routed through MCP registry (parallel execution) + | + v +ReAct loop continues (up to 10 iterations) until final text + | + v +Conversation compacted if token budget exceeded +Auto-memory detection runs in background +``` + +### Key Interfaces + +- `llm.Client` -- pluggable LLM provider (`ChatStream`, `Ping`, `Embed`) +- `agent.Output` -- streaming callbacks for TUI rendering +- `command.Registry` -- extensible slash command dispatch + +### Concurrency + +`sync.RWMutex` protects shared state in `ModelManager`, `mcp.Registry`, and `memory.Store`. Auto-memory detection and MCP connections run as background goroutines. Tool calls execute in parallel when independent. + +--- + +## Comparison + +| Feature | ai-agent | opencode | crush | +|----------------------------------|:-----------:|:--------:|:-----:| +| 100% local (no API keys) | Yes | No | Yes | +| Model routing by task complexity | Yes | No | No | +| Operational modes (ASK/PLAN/BUILD)| Yes | No | No | +| Cross-session memory (ICE) | Yes | No | No | +| Auto-memory detection | Yes | No | No | +| Thinking/CoT extraction | Yes | Yes | No | +| MCP tool support | Yes | Yes | Yes | +| Skills system | Yes | No | No | +| Plan form overlay | Yes | No | No | +| Small model optimized | Yes | No | No | +| TUI chat interface | Yes | Yes | Yes | +| Language | Go | TypeScript| Go | + +--- + +## Building + +This project uses [Task](https://taskfile.dev/) as its build tool. + +```bash +task build # Compile to bin/ai-agent +task run # Build and run +task dev # Quick run via go run ./cmd/ai-agent +task test # Run all tests: go test ./... +task lint # Run golangci-lint run ./... +task clean # Remove bin/ directory +``` + +Run a single test: + +```bash +go test ./internal/agent/ -run TestFunctionName +``` + +--- + +## License + +MIT diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 0000000..dbc0cf8 --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,33 @@ +version: '3' + +tasks: + build: + desc: Build the binary + cmds: + - go build -o bin/ai-agent ./cmd/ai-agent + + run: + desc: Build and run + deps: [build] + cmds: + - ./bin/ai-agent + + dev: + desc: Run with go run + cmds: + - go run ./cmd/ai-agent + + lint: + desc: Run linter + cmds: + - golangci-lint run ./... + + test: + desc: Run tests + cmds: + - go test ./... + + clean: + desc: Clean build artifacts + cmds: + - rm -rf bin/ diff --git a/config.example.yaml b/config.example.yaml new file mode 100644 index 0000000..712267e --- /dev/null +++ b/config.example.yaml @@ -0,0 +1,70 @@ +# ai-agent configuration +# Place at ~/.config/ai-agent/config.yaml or ./ai-agent.yaml + +ollama: + model: "qwen3.5:2b" + base_url: "http://localhost:11434" + # num_ctx: контекст в токенах. Большие значения (например 262144) требуют много RAM/VRAM; + # при ошибке "requires more system memory" уменьшите до 32768 или 8192. + num_ctx: 262144 + +# Model routing suite (optional - for automatic model tier selection) +# models: +# - name: "qwen3.5:0.8b" +# size: "0.8B" +# capability: "simple" +# - name: "qwen3.5:2b" +# size: "2B" +# capability: "medium" +# - name: "qwen3.5:4b" +# size: "4B" +# capability: "complex" +# - name: "qwen3.5:9b" +# size: "9B" +# capability: "advanced" + +# Embedding model for ICE (Infinite Context Engine) +# embed_model: "nomic-embed-text" + +# UI Configuration (optional) +# ui: +# # Theme: "dark", "light", or "auto" (default: auto - detects system theme) +# # The Nord theme will be applied automatically: +# # - Dark mode: Nord Polar Night (dark blues) + Frost (light text) + Aurora (colorful accents) +# # - Light mode: Nord Aurora Light (white background + same colorful accents) +# theme: "auto" +# # Syntax highlighting theme for code blocks (via Chroma) +# # Options: monokai, github, dracula, one-dark, solarized-dark, etc. +# # Default: monokai (dark), github (light) +# code_theme: "monokai" +# # Show line numbers in code blocks (default: false) +# code_line_numbers: false +# # Compact mode threshold (terminal width < this value triggers compact mode) +# compact_threshold: 80 + +# MCP tool servers +servers: + # STDIO transport (default) + - name: noted + command: noted + args: [mcp] + + # tinyvault — requires `tvault init` before first use + # - name: tinyvault + # command: tvault + # args: [mcp-server] + + # Docker MCP Gateway (requires Docker Desktop with MCP enabled) + # - name: docker-gateway + # command: docker + # args: ["mcp", "gateway", "run"] + + # SSE transport + # - name: remote-server + # transport: sse + # url: "http://localhost:8811" + + # Streamable HTTP transport + # - name: streamable-server + # transport: streamable-http + # url: "http://localhost:8812/mcp" diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..a75e5aa --- /dev/null +++ b/config.yaml @@ -0,0 +1,65 @@ +ollama: + model: "qwen3.5:27b" + base_url: "http://10.2.18.188:11434" + # 262144 требует ~22.6 GiB; при 20.4 GiB доступно — уменьшаем контекст (32768 укладывается в память) + num_ctx: 32768 + +# Model routing suite (optional - for automatic model tier selection) +# models: +# - name: "qwen3.5:0.8b" +# size: "0.8B" +# capability: "simple" +# - name: "qwen3.5:2b" +# size: "2B" +# capability: "medium" +# - name: "qwen3.5:4b" +# size: "4B" +# capability: "complex" +# - name: "qwen3.5:9b" +# size: "9B" +# capability: "advanced" + +# Embedding model for ICE (Infinite Context Engine) +# embed_model: "nomic-embed-text" + +# UI Configuration (optional) +# ui: +# # Theme: "dark", "light", or "auto" (default: auto - detects system theme) +# # The Nord theme will be applied automatically: +# # - Dark mode: Nord Polar Night (dark blues) + Frost (light text) + Aurora (colorful accents) +# # - Light mode: Nord Aurora Light (white background + same colorful accents) +# theme: "auto" +# # Syntax highlighting theme for code blocks (via Chroma) +# # Options: monokai, github, dracula, one-dark, solarized-dark, etc. +# # Default: monokai (dark), github (light) +# code_theme: "monokai" +# # Show line numbers in code blocks (default: false) +# code_line_numbers: false +# # Compact mode threshold (terminal width < this value triggers compact mode) +# compact_threshold: 80 + +# MCP tool servers +servers: + - name: noted + command: noted + args: [mcp] + + # tinyvault — requires `tvault init` before first use + # - name: tinyvault + # command: tvault + # args: [mcp-server] + + # Docker MCP Gateway (requires Docker Desktop with MCP enabled) + # - name: docker-gateway + # command: docker + # args: ["mcp", "gateway", "run"] + + # SSE transport + # - name: remote-server + # transport: sse + # url: "http://localhost:8811" + + # Streamable HTTP transport + # - name: streamable-server + # transport: streamable-http + # url: "http://localhost:8812/mcp" diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3994c08 --- /dev/null +++ b/go.mod @@ -0,0 +1,72 @@ +module ai-agent + +go 1.25.5 + +require ( + charm.land/bubbles/v2 v2.0.0 + charm.land/bubbletea/v2 v2.0.1 + charm.land/lipgloss/v2 v2.0.0 + github.com/atotto/clipboard v0.1.4 + github.com/charmbracelet/glamour v0.10.0 + github.com/charmbracelet/log v0.4.2 + github.com/lucasb-eyer/go-colorful v1.3.0 + github.com/modelcontextprotocol/go-sdk v1.3.1 + github.com/ollama/ollama v0.17.4 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/alecthomas/chroma/v2 v2.14.0 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/aymerick/douceur v0.2.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/charmbracelet/colorprofile v0.4.2 // indirect + github.com/charmbracelet/harmonica v0.2.0 // indirect + github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect + github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8 // indirect + github.com/charmbracelet/x/ansi v0.11.6 // indirect + github.com/charmbracelet/x/cellbuf v0.0.15 // indirect + github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect + github.com/charmbracelet/x/term v0.2.2 // indirect + github.com/charmbracelet/x/termios v0.1.1 // indirect + github.com/charmbracelet/x/windows v0.2.2 // indirect + github.com/clipperhouse/displaywidth v0.11.0 // indirect + github.com/clipperhouse/uax29/v2 v2.7.0 // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-logfmt/logfmt v0.6.0 // indirect + github.com/google/jsonschema-go v0.4.2 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/css v1.0.1 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.20 // indirect + github.com/microcosm-cc/bluemonday v1.0.27 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/reflow v0.3.0 // indirect + github.com/muesli/termenv v0.16.0 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/sahilm/fuzzy v0.1.1 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.3 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + github.com/yuin/goldmark v1.7.8 // indirect + github.com/yuin/goldmark-emoji v1.0.5 // indirect + golang.org/x/crypto v0.43.0 // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect + golang.org/x/net v0.46.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sync v0.19.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/term v0.36.0 // indirect + golang.org/x/text v0.30.0 // indirect + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect + modernc.org/sqlite v1.46.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4b3597f --- /dev/null +++ b/go.sum @@ -0,0 +1,164 @@ +charm.land/bubbles/v2 v2.0.0 h1:tE3eK/pHjmtrDiRdoC9uGNLgpopOd8fjhEe31B/ai5s= +charm.land/bubbles/v2 v2.0.0/go.mod h1:rCHoleP2XhU8um45NTuOWBPNVHxnkXKTiZqcclL/qOI= +charm.land/bubbletea/v2 v2.0.0 h1:p0d6CtWyJXJ9GfzMpUUqbP/XUUhhlk06+vCKWmox1wQ= +charm.land/bubbletea/v2 v2.0.0/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ= +charm.land/bubbletea/v2 v2.0.1 h1:B8e9zzK7x9JJ+XvHGF4xnYu9Xa0E0y0MyggY6dbaCfQ= +charm.land/bubbletea/v2 v2.0.1/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ= +charm.land/lipgloss/v2 v2.0.0 h1:sd8N/B3x892oiOjFfBQdXBQp3cAkvjGaU5TvVZC3ivo= +charm.land/lipgloss/v2 v2.0.0/go.mod h1:w6SnmsBFBmEFBodiEDurGS/sdUY/u1+v72DqUzc6J14= +github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= +github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= +github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE= +github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E= +github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I= +github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= +github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/aymanbagabas/go-udiff v0.4.0 h1:TKnLPh7IbnizJIBKFWa9mKayRUBQ9Kh1BPCk6w2PnYM= +github.com/aymanbagabas/go-udiff v0.4.0/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w= +github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= +github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/charmbracelet/colorprofile v0.4.2 h1:BdSNuMjRbotnxHSfxy+PCSa4xAmz7szw70ktAtWRYrY= +github.com/charmbracelet/colorprofile v0.4.2/go.mod h1:0rTi81QpwDElInthtrQ6Ni7cG0sDtwAd4C4le060fT8= +github.com/charmbracelet/glamour v0.10.0 h1:MtZvfwsYCx8jEPFJm3rIBFIMZUfUJ765oX8V6kXldcY= +github.com/charmbracelet/glamour v0.10.0/go.mod h1:f+uf+I/ChNmqo087elLnVdCiVgjSKWuXa/l6NU2ndYk= +github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ= +github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao= +github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE= +github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA= +github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig= +github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw= +github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8 h1:eyFRbAmexyt43hVfeyBofiGSEmJ7krjLOYt/9CF5NKA= +github.com/charmbracelet/ultraviolet v0.0.0-20260205113103-524a6607adb8/go.mod h1:SQpCTRNBtzJkwku5ye4S3HEuthAlGy2n9VXZnWkEW98= +github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8= +github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ= +github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI= +github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q= +github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA= +github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I= +github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf h1:rLG0Yb6MQSDKdB52aGX55JT1oi0P0Kuaj7wi1bLUpnI= +github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU= +github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= +github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= +github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY= +github.com/charmbracelet/x/termios v0.1.1/go.mod h1:rB7fnv1TgOPOyyKRJ9o+AsTU/vK5WHJ2ivHeut/Pcwo= +github.com/charmbracelet/x/windows v0.2.2 h1:IofanmuvaxnKHuV04sC0eBy/smG6kIKrWG2/jYn2GuM= +github.com/charmbracelet/x/windows v0.2.2/go.mod h1:/8XtdKZzedat74NQFn0NGlGL4soHB0YQZrETF96h75k= +github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8= +github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0= +github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk= +github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= +github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= +github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= +github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= +github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= +github.com/mattn/go-runewidth v0.0.20 h1:WcT52H91ZUAwy8+HUkdM3THM6gXqXuLJi9O3rjcQQaQ= +github.com/mattn/go-runewidth v0.0.20/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= +github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk= +github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= +github.com/modelcontextprotocol/go-sdk v1.3.1 h1:TfqtNKOIWN4Z1oqmPAiWDC2Jq7K9OdJaooe0teoXASI= +github.com/modelcontextprotocol/go-sdk v1.3.1/go.mod h1:DgVX498dMD8UJlseK1S5i1T4tFz2fkBk4xogC3D15nw= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= +github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/ollama/ollama v0.17.4 h1:X3KNm9x4BlHqk/AXMGtC7pBAFd46nmJmy8yDaBZLo9s= +github.com/ollama/ollama v0.17.4/go.mod h1:tCX4IMV8DHjl3zY0THxuEkpWDZSOchJpzTuLACpMwFw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= +github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.3 h1:OjMgICtcSFuNvQCdwqMCv9Tg7lEOXGwm1J5RPQccx6w= +github.com/segmentio/encoding v0.5.3/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= +github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC4cIk= +github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U= +golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= +golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= +golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4= +golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= +golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= +golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q= +golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= +golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= +modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= diff --git a/internal/agent/agent.go b/internal/agent/agent.go new file mode 100644 index 0000000..c4fb2c8 --- /dev/null +++ b/internal/agent/agent.go @@ -0,0 +1,194 @@ +package agent + +import ( + "sync" + "time" + + "ai-agent/internal/config" + "ai-agent/internal/ice" + "ai-agent/internal/llm" + "ai-agent/internal/mcp" + "ai-agent/internal/memory" + "ai-agent/internal/permission" +) + +type Agent struct { + mu sync.RWMutex + llmClient llm.Client + registry *mcp.Registry + messages []llm.Message + skillContent string + loadedCtx string + numCtx int + memoryStore *memory.Store + iceEngine *ice.Engine + router *config.Router + modePrefix string + toolsEnabled bool + workDir string + ignoreContent string + permChecker *permission.Checker + approvalCallback func(permission.ApprovalRequest) + toolsConfig config.ToolsConfig +} + +func New(llmClient llm.Client, registry *mcp.Registry, numCtx int) *Agent { + return &Agent{ + llmClient: llmClient, + registry: registry, + numCtx: numCtx, + toolsEnabled: true, + } +} + +func (a *Agent) SetRouter(router *config.Router) { + a.router = router +} + +func (a *Agent) SetModeContext(prefix string, allowTools bool) { + a.modePrefix = prefix + a.toolsEnabled = allowTools +} + +func (a *Agent) AppendLoadedContext(content string) { + if a.loadedCtx == "" { + a.loadedCtx = content + } else { + a.loadedCtx += content + } +} + +func (a *Agent) Router() *config.Router { + return a.router +} + +func (a *Agent) NumCtx() int { + return a.numCtx +} + +func (a *Agent) SetMemoryStore(store *memory.Store) { + a.memoryStore = store +} + +func (a *Agent) AddUserMessage(content string) { + a.mu.Lock() + defer a.mu.Unlock() + a.messages = append(a.messages, llm.Message{ + Role: "user", + Content: content, + }) +} + +func (a *Agent) Messages() []llm.Message { + a.mu.RLock() + defer a.mu.RUnlock() + return a.messages +} + +func (a *Agent) ClearHistory() { + a.mu.Lock() + defer a.mu.Unlock() + a.messages = nil +} + +func (a *Agent) AppendMessage(msg llm.Message) { + a.mu.Lock() + defer a.mu.Unlock() + a.messages = append(a.messages, msg) +} + +func (a *Agent) ReplaceMessages(msgs []llm.Message) { + a.mu.Lock() + defer a.mu.Unlock() + a.messages = msgs +} + +func (a *Agent) SetSkillContent(content string) { + a.skillContent = content +} + +func (a *Agent) SetLoadedContext(content string) { + a.loadedCtx = content +} + +func (a *Agent) Model() string { + return a.llmClient.Model() +} + +func (a *Agent) LLMClient() llm.Client { + return a.llmClient +} + +func (a *Agent) ToolCount() int { + count := a.registry.ToolCount() + if a.memoryStore != nil { + count += 2 + } + return count +} + +func (a *Agent) ServerCount() int { + return a.registry.ServerCount() +} + +func (a *Agent) ServerNames() []string { + return a.registry.ServerNames() +} + +func (a *Agent) SetWorkDir(dir string) { + a.workDir = dir +} + +func (a *Agent) SetIgnoreContent(content string) { + a.ignoreContent = content +} + +func (a *Agent) SetPermissionChecker(checker *permission.Checker) { + a.permChecker = checker +} + +func (a *Agent) SetApprovalCallback(cb func(permission.ApprovalRequest)) { + a.approvalCallback = cb +} + +func (a *Agent) SetICEEngine(engine *ice.Engine) { + a.iceEngine = engine +} + +func (a *Agent) ICEEngine() *ice.Engine { + return a.iceEngine +} + +func (a *Agent) SetToolsConfig(cfg config.ToolsConfig) { + a.toolsConfig = cfg +} + +func (a *Agent) MaxIterations() int { + if a.toolsConfig.MaxIterations > 0 { + return a.toolsConfig.MaxIterations + } + return 10 +} + +func (a *Agent) ToolTimeout() time.Duration { + if a.toolsConfig.Timeout != "" { + if d, err := time.ParseDuration(a.toolsConfig.Timeout); err == nil { + return d + } + } + return 30 * time.Second +} + +func (a *Agent) MaxGrepResults() int { + if a.toolsConfig.MaxGrepResults > 0 { + return a.toolsConfig.MaxGrepResults + } + return 500 +} + +func (a *Agent) Close() { + if a.iceEngine != nil { + _ = a.iceEngine.Flush() + } + a.registry.Close() +} diff --git a/internal/agent/compact.go b/internal/agent/compact.go new file mode 100644 index 0000000..30bb6e8 --- /dev/null +++ b/internal/agent/compact.go @@ -0,0 +1,95 @@ +package agent + +import ( + "context" + "fmt" + "strings" + + "ai-agent/internal/llm" +) + +const compactThreshold = 0.75 +const keepMessages = 4 + +func (a *Agent) shouldCompact(promptTokens int) bool { + if a.numCtx <= 0 || promptTokens <= 0 { + return false + } + return float64(promptTokens) > float64(a.numCtx)*compactThreshold +} + +func (a *Agent) compact(ctx context.Context, out Output) bool { + a.mu.RLock() + msgCount := len(a.messages) + a.mu.RUnlock() + if msgCount <= keepMessages+1 { + return false + } + a.mu.RLock() + splitAt := msgCount - keepMessages + older := make([]llm.Message, splitAt) + copy(older, a.messages[:splitAt]) + recent := make([]llm.Message, keepMessages) + copy(recent, a.messages[splitAt:]) + a.mu.RUnlock() + summary := summarizeMessages(older) + var summaryBuf strings.Builder + err := a.llmClient.ChatStream(ctx, llm.ChatOptions{ + Messages: []llm.Message{ + {Role: "user", Content: summary}, + }, + System: "You are a conversation summarizer. Produce a concise summary of the conversation so far, capturing all key facts, decisions, tool results, and user requests. Keep it under 500 words. Output only the summary, no preamble.", + }, func(chunk llm.StreamChunk) error { + if chunk.Text != "" { + summaryBuf.WriteString(chunk.Text) + } + return nil + }) + if err != nil { + out.Error(fmt.Sprintf("compaction failed: %v", err)) + return false + } + summaryText := summaryBuf.String() + if summaryText == "" { + return false + } + if a.iceEngine != nil { + if err := a.iceEngine.IndexSummary(ctx, summaryText); err != nil { + out.Error(fmt.Sprintf("ICE summary indexing failed: %v", err)) + } + } + compacted := make([]llm.Message, 0, 1+len(recent)) + compacted = append(compacted, llm.Message{ + Role: "user", + Content: fmt.Sprintf("[Conversation summary: %s]", summaryText), + }) + compacted = append(compacted, recent...) + a.ReplaceMessages(compacted) + out.SystemMessage(fmt.Sprintf("Context compacted: %d messages summarized, %d kept", len(older), len(recent))) + return true +} + +func summarizeMessages(msgs []llm.Message) string { + var b strings.Builder + b.WriteString("Summarize this conversation:\n\n") + for _, msg := range msgs { + switch msg.Role { + case "user": + fmt.Fprintf(&b, "User: %s\n", msg.Content) + case "assistant": + if msg.Content != "" { + fmt.Fprintf(&b, "Assistant: %s\n", msg.Content) + } + for _, tc := range msg.ToolCalls { + fmt.Fprintf(&b, "Assistant called tool %s(%s)\n", tc.Name, FormatToolArgs(tc.Arguments)) + } + case "tool": + content := msg.Content + if len(content) > 300 { + content = content[:297] + "..." + } + fmt.Fprintf(&b, "Tool %s result: %s\n", msg.ToolName, content) + } + } + return b.String() +} diff --git a/internal/agent/compact_test.go b/internal/agent/compact_test.go new file mode 100644 index 0000000..1c24d96 --- /dev/null +++ b/internal/agent/compact_test.go @@ -0,0 +1,143 @@ +package agent + +import ( + "strings" + "testing" + "time" + + "ai-agent/internal/llm" + "ai-agent/internal/mcp" +) + +type mockOutput struct { + texts []string + errors []string + sysMsgs []string +} + +func (m *mockOutput) StreamText(text string) { + m.texts = append(m.texts, text) +} + +func (m *mockOutput) StreamDone(_, _ int) {} + +func (m *mockOutput) ToolCallStart(_ string, _ map[string]any) {} + +func (m *mockOutput) ToolCallResult(_ string, _ string, _ bool, _ time.Duration) {} + +func (m *mockOutput) SystemMessage(msg string) { + m.sysMsgs = append(m.sysMsgs, msg) +} + +func (m *mockOutput) Error(msg string) { + m.errors = append(m.errors, msg) +} + +func TestShouldCompact(t *testing.T) { + tests := []struct { + name string + numCtx int + promptTokens int + want bool + }{ + {"below 75%", 1000, 749, false}, + {"above 75%", 1000, 751, true}, + {"exactly 75% (strict >)", 1000, 750, false}, + {"numCtx zero", 0, 500, false}, + {"promptTokens zero", 1000, 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ag := &Agent{ + numCtx: tt.numCtx, + registry: mcp.NewRegistry(), + } + got := ag.shouldCompact(tt.promptTokens) + if got != tt.want { + t.Errorf("shouldCompact(%d) with numCtx=%d = %v, want %v", + tt.promptTokens, tt.numCtx, got, tt.want) + } + }) + } +} + +func TestSummarizeMessages(t *testing.T) { + tests := []struct { + name string + msgs []llm.Message + contains []string + }{ + { + name: "user message", + msgs: []llm.Message{ + {Role: "user", Content: "hello"}, + }, + contains: []string{"User: hello"}, + }, + { + name: "assistant message", + msgs: []llm.Message{ + {Role: "assistant", Content: "hi there"}, + }, + contains: []string{"Assistant: hi there"}, + }, + { + name: "tool message", + msgs: []llm.Message{ + {Role: "tool", Content: "result data", ToolName: "read_file"}, + }, + contains: []string{"Tool read_file result: result data"}, + }, + { + name: "tool content truncation at 300 chars", + msgs: []llm.Message{ + {Role: "tool", Content: strings.Repeat("x", 400), ToolName: "big_tool"}, + }, + contains: []string{"Tool big_tool result: " + strings.Repeat("x", 297) + "..."}, + }, + { + name: "empty slice", + msgs: []llm.Message{}, + contains: []string{"Summarize this conversation:"}, + }, + { + name: "assistant with tool calls", + msgs: []llm.Message{ + { + Role: "assistant", + ToolCalls: []llm.ToolCall{ + {Name: "search", Arguments: map[string]any{"q": "test"}}, + }, + }, + }, + contains: []string{"Assistant called tool search("}, + }, + { + name: "mixed messages", + msgs: []llm.Message{ + {Role: "user", Content: "find files"}, + {Role: "assistant", Content: "", ToolCalls: []llm.ToolCall{ + {Name: "glob", Arguments: map[string]any{"pattern": "*.go"}}, + }}, + {Role: "tool", Content: "file1.go\nfile2.go", ToolName: "glob"}, + {Role: "assistant", Content: "Found 2 files"}, + }, + contains: []string{ + "User: find files", + "Assistant called tool glob(", + "Tool glob result:", + "Assistant: Found 2 files", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := summarizeMessages(tt.msgs) + for _, want := range tt.contains { + if !strings.Contains(result, want) { + t.Errorf("summarizeMessages() missing %q in:\n%s", want, result) + } + } + }) + } +} diff --git a/internal/agent/headless_output.go b/internal/agent/headless_output.go new file mode 100644 index 0000000..c4f8276 --- /dev/null +++ b/internal/agent/headless_output.go @@ -0,0 +1,71 @@ +package agent + +import ( + "fmt" + "io" + "os" + "time" +) + +// HeadlessOutput implements the Output interface for non-interactive / pipe mode. +// Text is written to stdout; tool calls, system messages, and errors go to stderr. +type HeadlessOutput struct { + stdout io.Writer + stderr io.Writer +} + +// NewHeadlessOutput creates a HeadlessOutput that writes text to os.Stdout +// and diagnostics to os.Stderr. +func NewHeadlessOutput() *HeadlessOutput { + return &HeadlessOutput{ + stdout: os.Stdout, + stderr: os.Stderr, + } +} + +// newHeadlessOutput creates a HeadlessOutput with custom writers (for testing). +func newHeadlessOutput(stdout, stderr io.Writer) *HeadlessOutput { + return &HeadlessOutput{ + stdout: stdout, + stderr: stderr, + } +} + +// StreamText writes incremental text content to stdout. +func (h *HeadlessOutput) StreamText(text string) { + fmt.Fprint(h.stdout, text) +} + +// StreamDone writes a trailing newline to ensure output is terminated. +func (h *HeadlessOutput) StreamDone(evalCount, promptTokens int) { + fmt.Fprintln(h.stdout) +} + +// ToolCallStart writes a brief tool invocation notice to stderr. +func (h *HeadlessOutput) ToolCallStart(name string, args map[string]any) { + fmt.Fprintf(h.stderr, "→ %s %s\n", name, FormatToolArgs(args)) +} + +// ToolCallResult writes the tool result summary to stderr. +func (h *HeadlessOutput) ToolCallResult(name string, result string, isError bool, duration time.Duration) { + status := "ok" + if isError { + status = "ERROR" + } + // Truncate long results for stderr display. + display := result + if len(display) > 200 { + display = display[:197] + "..." + } + fmt.Fprintf(h.stderr, "← %s [%s %s] %s\n", name, status, duration.Round(time.Millisecond), display) +} + +// SystemMessage writes a system message to stderr. +func (h *HeadlessOutput) SystemMessage(msg string) { + fmt.Fprintf(h.stderr, "[system] %s\n", msg) +} + +// Error writes an error message to stderr. +func (h *HeadlessOutput) Error(msg string) { + fmt.Fprintf(h.stderr, "[error] %s\n", msg) +} diff --git a/internal/agent/headless_output_test.go b/internal/agent/headless_output_test.go new file mode 100644 index 0000000..8d6b76f --- /dev/null +++ b/internal/agent/headless_output_test.go @@ -0,0 +1,154 @@ +package agent + +import ( + "bytes" + "strings" + "testing" + "time" +) + +// Verify HeadlessOutput satisfies the Output interface at compile time. +var _ Output = (*HeadlessOutput)(nil) + +func TestHeadlessOutput_StreamText(t *testing.T) { + var stdout, stderr bytes.Buffer + out := newHeadlessOutput(&stdout, &stderr) + + out.StreamText("hello ") + out.StreamText("world") + + if got := stdout.String(); got != "hello world" { + t.Errorf("StreamText: stdout = %q, want %q", got, "hello world") + } + if stderr.Len() != 0 { + t.Errorf("StreamText: unexpected stderr output: %q", stderr.String()) + } +} + +func TestHeadlessOutput_StreamDone(t *testing.T) { + var stdout, stderr bytes.Buffer + out := newHeadlessOutput(&stdout, &stderr) + + out.StreamText("response") + out.StreamDone(100, 50) + + if got := stdout.String(); got != "response\n" { + t.Errorf("StreamDone: stdout = %q, want %q", got, "response\n") + } + if stderr.Len() != 0 { + t.Errorf("StreamDone: unexpected stderr output: %q", stderr.String()) + } +} + +func TestHeadlessOutput_ToolCallStart(t *testing.T) { + var stdout, stderr bytes.Buffer + out := newHeadlessOutput(&stdout, &stderr) + + out.ToolCallStart("read_file", map[string]any{"path": "/tmp/test.go"}) + + if stdout.Len() != 0 { + t.Errorf("ToolCallStart: unexpected stdout output: %q", stdout.String()) + } + got := stderr.String() + if !strings.Contains(got, "read_file") { + t.Errorf("ToolCallStart: stderr = %q, missing tool name", got) + } + if !strings.HasPrefix(got, "→ ") { + t.Errorf("ToolCallStart: stderr = %q, missing arrow prefix", got) + } +} + +func TestHeadlessOutput_ToolCallResult(t *testing.T) { + var stdout, stderr bytes.Buffer + out := newHeadlessOutput(&stdout, &stderr) + + out.ToolCallResult("read_file", "file contents here", false, 150*time.Millisecond) + + if stdout.Len() != 0 { + t.Errorf("ToolCallResult: unexpected stdout output: %q", stdout.String()) + } + got := stderr.String() + if !strings.Contains(got, "read_file") { + t.Errorf("ToolCallResult: stderr = %q, missing tool name", got) + } + if !strings.Contains(got, "ok") { + t.Errorf("ToolCallResult: stderr = %q, missing ok status", got) + } + if !strings.Contains(got, "file contents here") { + t.Errorf("ToolCallResult: stderr = %q, missing result content", got) + } +} + +func TestHeadlessOutput_ToolCallResult_Error(t *testing.T) { + var stdout, stderr bytes.Buffer + out := newHeadlessOutput(&stdout, &stderr) + + out.ToolCallResult("write_file", "permission denied", true, 50*time.Millisecond) + + got := stderr.String() + if !strings.Contains(got, "ERROR") { + t.Errorf("ToolCallResult error: stderr = %q, missing ERROR status", got) + } +} + +func TestHeadlessOutput_ToolCallResult_LongResult(t *testing.T) { + var stdout, stderr bytes.Buffer + out := newHeadlessOutput(&stdout, &stderr) + + longResult := strings.Repeat("x", 300) + out.ToolCallResult("search", longResult, false, 100*time.Millisecond) + + got := stderr.String() + if strings.Contains(got, strings.Repeat("x", 300)) { + t.Error("ToolCallResult: long result should be truncated") + } + if !strings.Contains(got, "...") { + t.Error("ToolCallResult: truncated result should end with ...") + } +} + +func TestHeadlessOutput_SystemMessage(t *testing.T) { + var stdout, stderr bytes.Buffer + out := newHeadlessOutput(&stdout, &stderr) + + out.SystemMessage("compacting conversation") + + if stdout.Len() != 0 { + t.Errorf("SystemMessage: unexpected stdout output: %q", stdout.String()) + } + got := stderr.String() + if !strings.Contains(got, "[system]") { + t.Errorf("SystemMessage: stderr = %q, missing [system] prefix", got) + } + if !strings.Contains(got, "compacting conversation") { + t.Errorf("SystemMessage: stderr = %q, missing message", got) + } +} + +func TestHeadlessOutput_Error(t *testing.T) { + var stdout, stderr bytes.Buffer + out := newHeadlessOutput(&stdout, &stderr) + + out.Error("something went wrong") + + if stdout.Len() != 0 { + t.Errorf("Error: unexpected stdout output: %q", stdout.String()) + } + got := stderr.String() + if !strings.Contains(got, "[error]") { + t.Errorf("Error: stderr = %q, missing [error] prefix", got) + } + if !strings.Contains(got, "something went wrong") { + t.Errorf("Error: stderr = %q, missing message", got) + } +} + +func TestNewHeadlessOutput(t *testing.T) { + out := NewHeadlessOutput() + if out == nil { + t.Fatal("NewHeadlessOutput returned nil") + } + if out.stdout == nil || out.stderr == nil { + t.Error("NewHeadlessOutput: writers should not be nil") + } +} diff --git a/internal/agent/loop.go b/internal/agent/loop.go new file mode 100644 index 0000000..6dd5861 --- /dev/null +++ b/internal/agent/loop.go @@ -0,0 +1,277 @@ +package agent + +import ( + "context" + "fmt" + "sort" + "strings" + "sync" + "time" + + "ai-agent/internal/llm" + permissionPkg "ai-agent/internal/permission" +) + +func (a *Agent) Run(ctx context.Context, out Output) { + var tools []llm.ToolDef + if a.toolsEnabled { + tools = a.registry.Tools() + if a.memoryStore != nil { + tools = append(tools, a.memoryBuiltinToolDefs()...) + } + tools = append(tools, a.toolsBuiltinToolDefs()...) + } + var iceContext string + a.mu.RLock() + hasMessages := len(a.messages) > 0 + var lastMsg llm.Message + if hasMessages { + lastMsg = a.messages[len(a.messages)-1] + } + a.mu.RUnlock() + if a.iceEngine != nil && hasMessages { + if lastMsg.Role == "user" { + if err := a.iceEngine.IndexMessage(ctx, "user", lastMsg.Content); err != nil { + out.Error(fmt.Sprintf("ICE indexing failed: %v", err)) + } + if assembled, err := a.iceEngine.AssembleContext(ctx, lastMsg.Content); err == nil { + iceContext = assembled + } + } + } + system := buildSystemPromptForModel(a.modePrefix, tools, a.skillContent, a.loadedCtx, a.memoryStore, iceContext, a.workDir, a.ignoreContent, a.llmClient.Model()) + const maxRetries = 2 + var lastPromptTokens int + var retryCount int + maxIters := a.MaxIterations() + for i := 0; i < maxIters; i++ { + select { + case <-ctx.Done(): + return + default: + } + var textBuf strings.Builder + var toolCalls []llm.ToolCall + err := a.llmClient.ChatStream(ctx, llm.ChatOptions{ + Messages: a.messages, + Tools: tools, + System: system, + }, func(chunk llm.StreamChunk) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if chunk.Text != "" { + textBuf.WriteString(chunk.Text) + out.StreamText(chunk.Text) + } + if len(chunk.ToolCalls) > 0 { + toolCalls = append(toolCalls, chunk.ToolCalls...) + } + if chunk.Done { + lastPromptTokens = chunk.PromptEvalCount + out.StreamDone(chunk.EvalCount, chunk.PromptEvalCount) + } + return nil + }) + if err != nil { + if ctx.Err() != nil { + return + } + if retryCount < maxRetries && isRetryableError(err) { + retryCount++ + out.Error(fmt.Sprintf("LLM produced malformed output, retrying (%d/%d)...", retryCount, maxRetries)) + textBuf.Reset() + toolCalls = nil + continue + } + out.Error(fmt.Sprintf("LLM error: %v", err)) + out.SystemMessage(fmt.Sprintf("⚠️ Model response failed: %v\n\nYou can try:\n- Checking if Ollama is running (`ollama ps`)\n- Switching to a different model (ctrl+m)\n- Reducing context size\n\nTool results are still available above.", err)) + return + } + retryCount = 0 + assistantMsg := llm.Message{ + Role: "assistant", + Content: textBuf.String(), + ToolCalls: toolCalls, + } + a.AppendMessage(assistantMsg) + if a.iceEngine != nil && assistantMsg.Content != "" { + if err := a.iceEngine.IndexMessage(ctx, "assistant", assistantMsg.Content); err != nil { + out.Error(fmt.Sprintf("ICE indexing failed: %v", err)) + } + } + if len(toolCalls) == 0 { + a.mu.RLock() + hasEnoughMessages := len(a.messages) >= 2 + var userContent string + if hasEnoughMessages { + for idx := len(a.messages) - 2; idx >= 0; idx-- { + if a.messages[idx].Role == "user" { + userContent = a.messages[idx].Content + break + } + } + } + a.mu.RUnlock() + if a.iceEngine != nil && hasEnoughMessages && userContent != "" { + a.iceEngine.DetectAutoMemory(ctx, userContent, assistantMsg.Content) + } + return + } + type pendingTool struct { + tc llm.ToolCall + isMemoryTool bool + isMCPTool bool + } + var pending []pendingTool + for _, tc := range toolCalls { + if a.memoryStore != nil && a.isMemoryTool(tc.Name) { + pending = append(pending, pendingTool{tc: tc, isMemoryTool: true}) + continue + } + if a.isToolsTool(tc.Name) { + out.ToolCallStart(tc.Name, tc.Arguments) + startTime := time.Now() + result, isErr := a.handleToolsTool(tc) + duration := time.Since(startTime) + out.ToolCallResult(tc.Name, result, isErr, duration) + a.AppendMessage(llm.Message{ + Role: "tool", + Content: result, + ToolName: tc.Name, + ToolCallID: tc.ID, + }) + continue + } + if a.permChecker != nil { + switch a.permChecker.ToCheckResult(tc.Name) { + case permissionPkg.CheckDeny: + errMsg := "tool call blocked by permission policy" + out.ToolCallStart(tc.Name, tc.Arguments) + out.ToolCallResult(tc.Name, errMsg, true, 0) + a.AppendMessage(llm.Message{ + Role: "tool", + Content: errMsg, + ToolName: tc.Name, + ToolCallID: tc.ID, + }) + continue + case permissionPkg.CheckAsk: + if a.approvalCallback != nil { + allowed, always := permissionPkg.RequestApproval(tc.Name, tc.Arguments, a.approvalCallback) + if always { + a.permChecker.SetPolicy(tc.Name, permissionPkg.PolicyAllow) + } + if !allowed { + errMsg := "tool call denied by user" + out.ToolCallStart(tc.Name, tc.Arguments) + out.ToolCallResult(tc.Name, errMsg, true, 0) + a.AppendMessage(llm.Message{ + Role: "tool", + Content: errMsg, + ToolName: tc.Name, + ToolCallID: tc.ID, + }) + continue + } + } + } + } + pending = append(pending, pendingTool{tc: tc, isMCPTool: true}) + } + if len(pending) > 0 { + var wg sync.WaitGroup + mu := sync.Mutex{} + results := make([]llm.Message, len(pending)) + for i, p := range pending { + wg.Add(1) + go func(idx int, tool pendingTool) { + defer wg.Done() + tc := tool.tc + out.ToolCallStart(tc.Name, tc.Arguments) + startTime := time.Now() + var result string + var isErr bool + if tool.isMemoryTool { + result, isErr = a.handleMemoryTool(tc) + } else if tool.isMCPTool { + toolResult, err := a.registry.CallTool(ctx, tc.Name, tc.Arguments) + if err != nil { + result = fmt.Sprintf("ERROR: Tool '%s' failed: %v\nThis tool call failed but you can still complete the task with other available information.", tc.Name, err) + isErr = true + } else { + result = toolResult.Content + isErr = toolResult.IsError + } + } + duration := time.Since(startTime) + out.ToolCallResult(tc.Name, result, isErr, duration) + mu.Lock() + results[idx] = llm.Message{ + Role: "tool", + Content: result, + ToolName: tc.Name, + ToolCallID: tc.ID, + } + mu.Unlock() + }(i, p) + } + wg.Wait() + for _, msg := range results { + if msg.ToolName != "" { + a.AppendMessage(msg) + } + } + } + if a.shouldCompact(lastPromptTokens) { + if a.compact(ctx, out) { + system = buildSystemPromptForModel(a.modePrefix, tools, a.skillContent, a.loadedCtx, a.memoryStore, iceContext, a.workDir, a.ignoreContent, a.llmClient.Model()) + } + } + if i == maxIters-2 { + out.Error(fmt.Sprintf("approaching iteration limit (%d/%d)", i+2, maxIters)) + } + } + out.Error(fmt.Sprintf("reached max iterations (%d)", maxIters)) +} + +func isRetryableError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "parse JSON") || strings.Contains(msg, "unexpected end of JSON") +} + +func FormatToolArgs(args map[string]any) string { + if len(args) == 0 { + return "" + } + var parts []string + for key, value := range args { + var valStr string + switch v := value.(type) { + case string: + if len(v) > 47 { + valStr = `"` + v[:44] + `..."` + } else { + valStr = `"` + v + `"` + } + case int, float64, bool: + valStr = fmt.Sprintf("%v", v) + case []any: + valStr = fmt.Sprintf("[%d items]", len(v)) + case map[string]any: + valStr = fmt.Sprintf("{%d fields}", len(v)) + default: + valStr = fmt.Sprintf("%v", v) + } + parts = append(parts, fmt.Sprintf("%s=%s", key, valStr)) + } + sort.Strings(parts) + result := strings.Join(parts, " ") + if len(result) > 60 { + return result[:57] + "..." + } + return result +} diff --git a/internal/agent/loop_test.go b/internal/agent/loop_test.go new file mode 100644 index 0000000..760a05d --- /dev/null +++ b/internal/agent/loop_test.go @@ -0,0 +1,73 @@ +package agent + +import ( + "strings" + "testing" +) + +func TestFormatToolArgs(t *testing.T) { + tests := []struct { + name string + args map[string]any + want string + contains []string + maxLen int + }{ + { + name: "empty map", + args: map[string]any{}, + want: "", + }, + { + name: "simple map", + args: map[string]any{"key": "value"}, + contains: []string{"key=", `"value"`}, + }, + { + name: "long args truncated at 60", + args: map[string]any{"data": strings.Repeat("a", 300)}, + maxLen: 60, + }, + { + name: "multiple args", + args: map[string]any{"path": "/tmp/test", "command": "ls"}, + contains: []string{"path=", "command="}, + }, + { + name: "numeric args", + args: map[string]any{"count": 42, "ratio": 3.14}, + contains: []string{"count=42", "ratio=3.14"}, + }, + { + name: "array args", + args: map[string]any{"items": []any{1, 2, 3}}, + contains: []string{"items=", "[3 items]"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FormatToolArgs(tt.args) + + if tt.want != "" && got != tt.want { + t.Errorf("FormatToolArgs() = %q, want %q", got, tt.want) + } + + for _, substr := range tt.contains { + if !strings.Contains(got, substr) { + t.Errorf("FormatToolArgs() = %q, missing %q", got, substr) + } + } + + if tt.maxLen > 0 { + if len(got) > tt.maxLen { + t.Errorf("FormatToolArgs() len = %d, want <= %d", len(got), tt.maxLen) + } + // Check for truncation indicator (either "..." in value or at end) + if !strings.Contains(got, "...") { + t.Errorf("FormatToolArgs() should contain '...' when truncated, got %q", got) + } + } + }) + } +} diff --git a/internal/agent/memory.go b/internal/agent/memory.go new file mode 100644 index 0000000..b4ce231 --- /dev/null +++ b/internal/agent/memory.go @@ -0,0 +1,171 @@ +package agent + +import ( + "fmt" + "strings" + + "ai-agent/internal/llm" + "ai-agent/internal/memory" +) + +func (a *Agent) memoryBuiltinToolDefs() []llm.ToolDef { + return memory.BuiltinToolDefs() +} + +func (a *Agent) isMemoryTool(name string) bool { + return memory.IsBuiltinTool(name) +} + +func (a *Agent) handleMemoryTool(tc llm.ToolCall) (string, bool) { + switch tc.Name { + case "memory_save": + return a.handleMemorySave(tc.Arguments) + case "memory_recall": + return a.handleMemoryRecall(tc.Arguments) + case "memory_delete": + return a.handleMemoryDelete(tc.Arguments) + case "memory_update": + return a.handleMemoryUpdate(tc.Arguments) + case "memory_list": + return a.handleMemoryList(tc.Arguments) + default: + return fmt.Sprintf("unknown memory tool: %s", tc.Name), true + } +} + +func (a *Agent) handleMemorySave(args map[string]any) (string, bool) { + content, _ := args["content"].(string) + if content == "" { + return "error: content is required", true + } + var tags []string + if rawTags, ok := args["tags"]; ok { + switch v := rawTags.(type) { + case []any: + for _, t := range v { + if s, ok := t.(string); ok { + tags = append(tags, s) + } + } + case []string: + tags = v + } + } + id, err := a.memoryStore.Save(content, tags) + if err != nil { + return fmt.Sprintf("error saving memory: %v", err), true + } + return fmt.Sprintf("Memory saved (id: %d)", id), false +} + +func (a *Agent) handleMemoryRecall(args map[string]any) (string, bool) { + query, _ := args["query"].(string) + if query == "" { + return "error: query is required", true + } + memories := a.memoryStore.Recall(query, 5) + if len(memories) == 0 { + return "No matching memories found.", false + } + var b strings.Builder + fmt.Fprintf(&b, "Found %d matching memories:\n", len(memories)) + for _, mem := range memories { + fmt.Fprintf(&b, "- [%d] %s", mem.ID, mem.Content) + if len(mem.Tags) > 0 { + fmt.Fprintf(&b, " (tags: %s)", strings.Join(mem.Tags, ", ")) + } + b.WriteString("\n") + } + return b.String(), false +} + +func (a *Agent) handleMemoryDelete(args map[string]any) (string, bool) { + idVal, ok := args["id"] + if !ok { + return "error: id is required", true + } + var id int + switch v := idVal.(type) { + case float64: + id = int(v) + case int: + id = v + default: + return "error: id must be a number", true + } + deleted, err := a.memoryStore.Delete(id) + if err != nil { + return fmt.Sprintf("error deleting memory: %v", err), true + } + if !deleted { + return fmt.Sprintf("memory with id %d not found", id), true + } + return fmt.Sprintf("Memory %d deleted", id), false +} + +func (a *Agent) handleMemoryUpdate(args map[string]any) (string, bool) { + idVal, ok := args["id"] + if !ok { + return "error: id is required", true + } + var id int + switch v := idVal.(type) { + case float64: + id = int(v) + case int: + id = v + default: + return "error: id must be a number", true + } + content, _ := args["content"].(string) + var tags []string + if rawTags, ok := args["tags"]; ok { + switch v := rawTags.(type) { + case []any: + for _, t := range v { + if s, ok := t.(string); ok { + tags = append(tags, s) + } + } + case []string: + tags = v + } + } + if content == "" && len(tags) == 0 { + return "error: at least one of content or tags is required", true + } + updated, err := a.memoryStore.Update(id, content, tags) + if err != nil { + return fmt.Sprintf("error updating memory: %v", err), true + } + if !updated { + return fmt.Sprintf("memory with id %d not found", id), true + } + return fmt.Sprintf("Memory %d updated", id), false +} + +func (a *Agent) handleMemoryList(args map[string]any) (string, bool) { + limit := 20 + if rawLimit, ok := args["limit"]; ok { + switch v := rawLimit.(type) { + case float64: + limit = int(v) + case int: + limit = v + } + } + memories := a.memoryStore.Recent(limit) + if len(memories) == 0 { + return "No memories stored.", false + } + var b strings.Builder + fmt.Fprintf(&b, "Stored memories (%d total):\n", a.memoryStore.Count()) + for _, mem := range memories { + fmt.Fprintf(&b, "- [%d] %s", mem.ID, mem.Content) + if len(mem.Tags) > 0 { + fmt.Fprintf(&b, " (tags: %s)", strings.Join(mem.Tags, ", ")) + } + b.WriteString("\n") + } + return b.String(), false +} diff --git a/internal/agent/memory_test.go b/internal/agent/memory_test.go new file mode 100644 index 0000000..241d540 --- /dev/null +++ b/internal/agent/memory_test.go @@ -0,0 +1,159 @@ +package agent + +import ( + "path/filepath" + "strings" + "testing" + + "ai-agent/internal/llm" + "ai-agent/internal/mcp" + "ai-agent/internal/memory" +) + +func newTestAgentWithMemory(t *testing.T) *Agent { + t.Helper() + store := memory.NewStore(filepath.Join(t.TempDir(), "test-memories.json")) + return &Agent{memoryStore: store, registry: mcp.NewRegistry()} +} + +func TestHandleMemoryTool(t *testing.T) { + tests := []struct { + name string + toolCall llm.ToolCall + wantSubstr string + wantErr bool + }{ + { + name: "dispatch to save", + toolCall: llm.ToolCall{ + Name: "memory_save", + Arguments: map[string]any{"content": "test fact", "tags": []any{"tag1"}}, + }, + wantSubstr: "Memory saved (id:", + wantErr: false, + }, + { + name: "dispatch to recall", + toolCall: llm.ToolCall{ + Name: "memory_recall", + Arguments: map[string]any{"query": "test"}, + }, + wantSubstr: "No matching memories found.", + wantErr: false, + }, + { + name: "unknown tool", + toolCall: llm.ToolCall{ + Name: "unknown", + Arguments: map[string]any{}, + }, + wantSubstr: "unknown memory tool: unknown", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ag := newTestAgentWithMemory(t) + result, isErr := ag.handleMemoryTool(tt.toolCall) + if isErr != tt.wantErr { + t.Errorf("handleMemoryTool() isErr = %v, want %v", isErr, tt.wantErr) + } + if !strings.Contains(result, tt.wantSubstr) { + t.Errorf("handleMemoryTool() = %q, want substring %q", result, tt.wantSubstr) + } + }) + } +} + +func TestHandleMemorySave(t *testing.T) { + tests := []struct { + name string + args map[string]any + wantSubstr string + wantErr bool + }{ + { + name: "valid with tags as []any", + args: map[string]any{"content": "test fact", "tags": []any{"tag1", "tag2"}}, + wantSubstr: "Memory saved (id:", + wantErr: false, + }, + { + name: "valid without tags", + args: map[string]any{"content": "another fact"}, + wantSubstr: "Memory saved (id:", + wantErr: false, + }, + { + name: "missing content", + args: map[string]any{}, + wantSubstr: "error: content is required", + wantErr: true, + }, + { + name: "empty content", + args: map[string]any{"content": ""}, + wantSubstr: "error: content is required", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ag := newTestAgentWithMemory(t) + result, isErr := ag.handleMemorySave(tt.args) + if isErr != tt.wantErr { + t.Errorf("handleMemorySave() isErr = %v, want %v", isErr, tt.wantErr) + } + if !strings.Contains(result, tt.wantSubstr) { + t.Errorf("handleMemorySave() = %q, want substring %q", result, tt.wantSubstr) + } + }) + } +} + +func TestHandleMemoryRecall(t *testing.T) { + tests := []struct { + name string + setup func(ag *Agent) + args map[string]any + wantSubstr string + wantErr bool + }{ + { + name: "valid recall finds saved memory", + setup: func(ag *Agent) { + _, _ = ag.memoryStore.Save("user prefers Go", []string{"language"}) + }, + args: map[string]any{"query": "Go"}, + wantSubstr: "Found 1 matching memories:", + wantErr: false, + }, + { + name: "missing query", + setup: func(ag *Agent) {}, + args: map[string]any{}, + wantSubstr: "error: query is required", + wantErr: true, + }, + { + name: "no matches", + setup: func(ag *Agent) {}, + args: map[string]any{"query": "nonexistent"}, + wantSubstr: "No matching memories found.", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ag := newTestAgentWithMemory(t) + tt.setup(ag) + result, isErr := ag.handleMemoryRecall(tt.args) + if isErr != tt.wantErr { + t.Errorf("handleMemoryRecall() isErr = %v, want %v", isErr, tt.wantErr) + } + if !strings.Contains(result, tt.wantSubstr) { + t.Errorf("handleMemoryRecall() = %q, want substring %q", result, tt.wantSubstr) + } + }) + } +} diff --git a/internal/agent/output.go b/internal/agent/output.go new file mode 100644 index 0000000..a1f871c --- /dev/null +++ b/internal/agent/output.go @@ -0,0 +1,24 @@ +package agent + +import "time" + +// Output is the interface the agent uses to stream results to the UI. +type Output interface { + // StreamText sends incremental text content. + StreamText(text string) + + // StreamDone signals that the current response is complete. + StreamDone(evalCount, promptTokens int) + + // ToolCallStart signals the beginning of a tool invocation. + ToolCallStart(name string, args map[string]any) + + // ToolCallResult delivers the result of a tool invocation. + ToolCallResult(name string, result string, isError bool, duration time.Duration) + + // SystemMessage displays a system-level message to the user. + SystemMessage(msg string) + + // Error reports a non-fatal error to the user. + Error(msg string) +} diff --git a/internal/agent/system.go b/internal/agent/system.go new file mode 100644 index 0000000..15a59e1 --- /dev/null +++ b/internal/agent/system.go @@ -0,0 +1,272 @@ +package agent + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "ai-agent/internal/llm" + "ai-agent/internal/memory" +) + +const systemTemplate = `You are a helpful personal assistant running locally on the user's machine. +You have access to tools via MCP servers. You MUST use tools to accomplish tasks — do not guess or make up answers when a tool can provide the real information. +%s +Current date: %s +%s%s +%s%s%s +## Available Tools +%s +## Guidelines +- **ALWAYS use your tools** when the user asks you to read, explore, search, or modify files. You have filesystem tools — use them. +- When the user says "read this codebase" or similar, use list/read tools starting from the working directory shown above. +- Be concise and direct in your responses. +- When a tool call fails, explain what happened and suggest alternatives. +- For multi-step tasks, explain your plan briefly before executing. +- Format responses in markdown when it improves readability. +- If you're unsure about something, say so rather than guessing. +- Never fabricate tool results — always call the actual tool. +- Do NOT claim you cannot access files or the filesystem. You have tools for that — use them. +%s` + +const smallModelTemplate = `You are a local AI assistant. Use tools to read/write files and run commands. +%sDate: %s +%s%s +%s +## Tools +%s +Guidelines: +- Be concise and direct +- Use tools when needed to complete tasks +- If a tool fails, continue with available information +- Don't guess - use tools to verify +- You can complete tasks even if some tools fail +%s` + +func isSmallModel(modelName string) bool { + lower := strings.ToLower(modelName) + if strings.Contains(lower, "0.8b") || strings.Contains(lower, "1b") || strings.Contains(lower, "2b") { + return true + } + return false +} + +func buildSystemPrompt(modePrefix string, tools []llm.ToolDef, skillContent, loadedContext string, memStore *memory.Store, iceContext, workDir, ignoreContent string) string { + return buildSystemPromptForModel(modePrefix, tools, skillContent, loadedContext, memStore, iceContext, workDir, ignoreContent, "") +} + +func buildSystemPromptForModel(modePrefix string, tools []llm.ToolDef, skillContent, loadedContext string, memStore *memory.Store, iceContext, workDir, ignoreContent string, modelName string) string { + useSmallModel := isSmallModel(modelName) + var toolList string + if len(tools) == 0 { + toolList = "No tools currently available.\n" + } else if useSmallModel { + toolList = simplifyToolsForSmallModel(tools) + } else { + var b strings.Builder + for _, t := range tools { + fmt.Fprintf(&b, "- **%s**: %s\n", t.Name, t.Description) + } + toolList = b.String() + } + envSection := buildEnvironmentSection(workDir) + var skillSection string + if skillContent != "" { + skillSection = fmt.Sprintf("\n## Active Skills\n%s\n", skillContent) + } + var ctxSection string + if loadedContext != "" { + ctxSection = fmt.Sprintf("\n## Loaded Context\n%s\n", loadedContext) + } + var memorySection string + if iceContext != "" { + memorySection = iceContext + } else if memStore != nil { + memorySection = buildMemorySection(memStore) + } + var memoryGuidelines string + if memStore != nil { + memoryGuidelines = ` +## Memory Guidelines +- You have access to persistent memory via memory_save and memory_recall tools. +- Proactively save important user preferences, project facts, and key decisions. +- When the user shares personal information (name, preferences, etc.), save it. +- Use memory_recall to look up previously saved information when relevant. +- Don't save trivial or session-specific information. +` + } + var ignoreSection string + if ignoreContent != "" { + ignoreSection = fmt.Sprintf("\n## Ignored Paths\nThe following paths/patterns should be excluded from file operations:\n%s\n", ignoreContent) + } + var modePrefixSection string + if modePrefix != "" { + modePrefixSection = "\n" + modePrefix + "\n" + } + dateStr := time.Now().Format("Monday, January 2, 2006") + if useSmallModel { + return fmt.Sprintf(smallModelTemplate, + modePrefixSection, + dateStr, + envSection, + ignoreSection, + skillSection, + toolList, + memoryGuidelines, + ) + } + return fmt.Sprintf(systemTemplate, + modePrefixSection, + dateStr, + envSection, + ignoreSection, + skillSection, + ctxSection, + memorySection, + toolList, + memoryGuidelines, + ) +} + +func simplifyToolsForSmallModel(tools []llm.ToolDef) string { + var b strings.Builder + for _, t := range tools { + desc := t.Description + if len(desc) > 50 { + desc = desc[:47] + "..." + } + fmt.Fprintf(&b, "- %s: %s\n", t.Name, desc) + } + return b.String() +} + +func buildEnvironmentSection(workDir string) string { + if workDir == "" { + return "" + } + var b strings.Builder + b.WriteString("\n## Environment\n") + b.WriteString(fmt.Sprintf("Working directory: %s\n", workDir)) + if info := detectProjectInfo(workDir); info != "" { + b.WriteString(info) + } + if gitInfo := detectGitInfo(workDir); gitInfo != "" { + b.WriteString(gitInfo) + } + return b.String() +} + +func detectProjectInfo(workDir string) string { + markers := []struct { + file string + desc string + }{ + {"go.mod", "Go module"}, + {"package.json", "Node.js/JavaScript"}, + {"Cargo.toml", "Rust"}, + {"pyproject.toml", "Python"}, + {"setup.py", "Python"}, + {"Makefile", ""}, + {"Taskfile.yml", ""}, + } + var found []string + for _, m := range markers { + if _, err := os.Stat(filepath.Join(workDir, m.file)); err == nil { + if m.desc != "" { + found = append(found, fmt.Sprintf("%s (%s)", m.file, m.desc)) + } else { + found = append(found, m.file) + } + } + } + if len(found) == 0 { + return "" + } + return fmt.Sprintf("Project markers: %s\n", strings.Join(found, ", ")) +} + +func detectGitInfo(workDir string) string { + gitDir := filepath.Join(workDir, ".git") + if _, err := os.Stat(gitDir); err != nil { + return "" + } + var b strings.Builder + branch := runGitCommand(workDir, "rev-parse", "--abbrev-ref", "HEAD") + if branch != "" { + b.WriteString(fmt.Sprintf("Git branch: %s\n", branch)) + } + status := runGitCommand(workDir, "status", "--porcelain") + if status != "" { + lines := strings.Split(strings.TrimSpace(status), "\n") + var modified, added, deleted int + for _, line := range lines { + if len(line) >= 2 { + switch line[0] { + case 'M', 'm': + modified++ + case 'A': + added++ + case 'D': + deleted++ + } + } + } + if modified > 0 || added > 0 || deleted > 0 { + statusParts := []string{} + if modified > 0 { + statusParts = append(statusParts, fmt.Sprintf("%d modified", modified)) + } + if added > 0 { + statusParts = append(statusParts, fmt.Sprintf("%d added", added)) + } + if deleted > 0 { + statusParts = append(statusParts, fmt.Sprintf("%d deleted", deleted)) + } + b.WriteString(fmt.Sprintf("Git status: %s\n", strings.Join(statusParts, ", "))) + } + } + recentLog := runGitCommand(workDir, "log", "-3", "--oneline") + if recentLog != "" { + b.WriteString(fmt.Sprintf("Recent commits:\n")) + for _, line := range strings.Split(strings.TrimSpace(recentLog), "\n") { + b.WriteString(fmt.Sprintf(" - %s\n", line)) + } + } + if b.Len() == 0 { + return "" + } + return b.String() +} + +func runGitCommand(dir string, args ...string) string { + cmd := exec.Command("git", args...) + cmd.Dir = dir + out, err := cmd.Output() + if err != nil { + return "" + } + return strings.TrimSpace(string(out)) +} + +func buildMemorySection(store *memory.Store) string { + if store.Count() == 0 { + return "" + } + recent := store.Recent(10) + if len(recent) == 0 { + return "" + } + var b strings.Builder + b.WriteString("\n## Remembered Facts\n") + for _, mem := range recent { + b.WriteString(fmt.Sprintf("- %s", mem.Content)) + if len(mem.Tags) > 0 { + b.WriteString(fmt.Sprintf(" [tags: %s]", strings.Join(mem.Tags, ", "))) + } + b.WriteString("\n") + } + return b.String() +} diff --git a/internal/agent/system_test.go b/internal/agent/system_test.go new file mode 100644 index 0000000..12990dd --- /dev/null +++ b/internal/agent/system_test.go @@ -0,0 +1,186 @@ +package agent + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "ai-agent/internal/llm" + "ai-agent/internal/memory" +) + +func TestBuildSystemPrompt(t *testing.T) { + tests := []struct { + name string + tools []llm.ToolDef + skillContent string + loadedCtx string + memStore *memory.Store + iceContext string + contains []string + notContains []string + }{ + { + name: "no optional sections", + contains: []string{"No tools currently available.", "Current date:"}, + notContains: []string{"Active Skills", "Loaded Context", "Remembered Facts"}, + }, + { + name: "with tools", + tools: []llm.ToolDef{ + {Name: "test_tool", Description: "does stuff"}, + }, + contains: []string{"test_tool", "does stuff"}, + notContains: []string{"No tools currently available."}, + }, + { + name: "with skills", + skillContent: "skill content here", + contains: []string{"Active Skills", "skill content here"}, + }, + { + name: "with loaded context", + loadedCtx: "some loaded context", + contains: []string{"Loaded Context", "some loaded context"}, + }, + { + name: "ICE overrides memory", + iceContext: "ice assembled context", + contains: []string{"ice assembled context"}, + notContains: []string{"Remembered Facts"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildSystemPrompt("", tt.tools, tt.skillContent, tt.loadedCtx, tt.memStore, tt.iceContext, "", "") + for _, want := range tt.contains { + if !strings.Contains(result, want) { + t.Errorf("buildSystemPrompt() missing %q", want) + } + } + for _, notWant := range tt.notContains { + if strings.Contains(result, notWant) { + t.Errorf("buildSystemPrompt() should not contain %q", notWant) + } + } + }) + } + t.Run("with memory store entries", func(t *testing.T) { + store := memory.NewStore(filepath.Join(t.TempDir(), "test-memories.json")) + _, _ = store.Save("user prefers dark mode", []string{"preference"}) + result := buildSystemPrompt("", nil, "", "", store, "", "", "") + if !strings.Contains(result, "Remembered Facts") { + t.Error("expected Remembered Facts section") + } + if !strings.Contains(result, "user prefers dark mode") { + t.Error("expected memory content in prompt") + } + }) +} + +func TestBuildSystemPrompt_WithWorkDir(t *testing.T) { + result := buildSystemPrompt("", nil, "", "", nil, "", "/home/user/myproject", "") + if !strings.Contains(result, "Working directory: /home/user/myproject") { + t.Error("expected working directory in prompt") + } + if !strings.Contains(result, "Environment") { + t.Error("expected Environment section header") + } +} + +func TestBuildSystemPrompt_EmptyWorkDir(t *testing.T) { + result := buildSystemPrompt("", nil, "", "", nil, "", "", "") + if strings.Contains(result, "Working directory") { + t.Error("should not include working directory when empty") + } +} + +func TestBuildSystemPrompt_WithIgnoreContent(t *testing.T) { + ignoreContent := "- node_modules\n- *.log\n- build/" + result := buildSystemPrompt("", nil, "", "", nil, "", "", ignoreContent) + if !strings.Contains(result, "Ignored Paths") { + t.Error("expected Ignored Paths section header") + } + if !strings.Contains(result, "node_modules") { + t.Error("expected node_modules in ignore section") + } + if !strings.Contains(result, "*.log") { + t.Error("expected *.log in ignore section") + } +} + +func TestBuildSystemPrompt_EmptyIgnoreContent(t *testing.T) { + result := buildSystemPrompt("", nil, "", "", nil, "", "", "") + if strings.Contains(result, "Ignored Paths") { + t.Error("should not include Ignored Paths when content is empty") + } +} + +func TestDetectProjectInfo_GoProject(t *testing.T) { + dir := t.TempDir() + _ = os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module test"), 0o644) + + info := detectProjectInfo(dir) + if !strings.Contains(info, "go.mod") { + t.Errorf("expected go.mod in project info, got %q", info) + } + if !strings.Contains(info, "Go module") { + t.Errorf("expected 'Go module' in project info, got %q", info) + } +} + +func TestDetectProjectInfo_EmptyDir(t *testing.T) { + dir := t.TempDir() + info := detectProjectInfo(dir) + if info != "" { + t.Errorf("expected empty for dir with no markers, got %q", info) + } +} + +func TestBuildMemorySection(t *testing.T) { + tests := []struct { + name string + setup func(s *memory.Store) + contains []string + wantEmpty bool + }{ + { + name: "empty store", + setup: func(s *memory.Store) {}, + wantEmpty: true, + }, + { + name: "store with tagged entry", + setup: func(s *memory.Store) { + _, _ = s.Save("likes Go", []string{"lang", "preference"}) + }, + contains: []string{"Remembered Facts", "likes Go", "[tags: lang, preference]"}, + }, + { + name: "store with untagged entry", + setup: func(s *memory.Store) { + _, _ = s.Save("project uses modules", nil) + }, + contains: []string{"Remembered Facts", "project uses modules"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := memory.NewStore(filepath.Join(t.TempDir(), "mem.json")) + tt.setup(store) + result := buildMemorySection(store) + if tt.wantEmpty { + if result != "" { + t.Errorf("expected empty string, got %q", result) + } + return + } + for _, want := range tt.contains { + if !strings.Contains(result, want) { + t.Errorf("buildMemorySection() missing %q in:\n%s", want, result) + } + } + }) + } +} diff --git a/internal/agent/tools.go b/internal/agent/tools.go new file mode 100644 index 0000000..3a33949 --- /dev/null +++ b/internal/agent/tools.go @@ -0,0 +1,688 @@ +package agent + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "ai-agent/internal/llm" + "ai-agent/internal/tools" +) + +const ( + maxTimeout = 120 * time.Second +) + +func (a *Agent) toolsBuiltinToolDefs() []llm.ToolDef { + return tools.AllToolDefs() +} + +func (a *Agent) isToolsTool(name string) bool { + return tools.IsBuiltinTool(name) +} + +func (a *Agent) handleToolsTool(tc llm.ToolCall) (string, bool) { + switch tc.Name { + case "grep": + return a.handleGrep(tc.Arguments) + case "read": + return a.handleRead(tc.Arguments) + case "write": + return a.handleWrite(tc.Arguments) + case "glob": + return a.handleGlob(tc.Arguments) + case "bash": + return a.handleBash(tc.Arguments) + case "ls": + return a.handleLs(tc.Arguments) + case "find": + return a.handleFind(tc.Arguments) + case "diff": + return a.handleDiff(tc.Arguments) + case "edit": + return a.handleEdit(tc.Arguments) + case "mkdir": + return a.handleMkdir(tc.Arguments) + case "remove": + return a.handleRemove(tc.Arguments) + case "copy": + return a.handleCopy(tc.Arguments) + case "move": + return a.handleMove(tc.Arguments) + case "exists": + return a.handleExists(tc.Arguments) + default: + return fmt.Sprintf("unknown tool: %s", tc.Name), true + } +} + +func (a *Agent) handleGrep(args map[string]any) (string, bool) { + pattern, _ := args["pattern"].(string) + if pattern == "" { + return "error: pattern is required", true + } + path := a.getArgString(args, "path", a.workDir) + include := a.getArgString(args, "include", "") + context := a.getArgInt(args, "context", 3) + maxResults := a.MaxGrepResults() + if _, err := os.Stat(path); err != nil { + return fmt.Sprintf("error: path does not exist: %s", path), true + } + re, err := regexp.Compile(pattern) + if err != nil { + return fmt.Sprintf("error: invalid regex pattern: %v", err), true + } + var results []string + err = filepath.Walk(path, func(filePath string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + if info.IsDir() { + if shouldSkipDir(info.Name()) { + return filepath.SkipDir + } + return nil + } + if include != "" { + matched, err := filepath.Match(include, info.Name()) + if err != nil || !matched { + return nil + } + } + if strings.HasPrefix(info.Name(), ".") { + return nil + } + content, err := os.ReadFile(filePath) + if err != nil { + return nil + } + lines := strings.Split(string(content), "\n") + for i, line := range lines { + if re.MatchString(line) { + relPath, _ := filepath.Rel(path, filePath) + ctxStart := i - context + if ctxStart < 0 { + ctxStart = 0 + } + ctxEnd := i + context + 1 + if ctxEnd > len(lines) { + ctxEnd = len(lines) + } + results = append(results, fmt.Sprintf("%s:%d: %s", relPath, i+1, line)) + if context > 0 && ctxStart < i { + for j := ctxStart; j < i; j++ { + if len(results) < maxResults { + results = append(results, fmt.Sprintf(" %d: %s", j+1, lines[j])) + } + } + } + if context > 0 && i+1 < ctxEnd { + for j := i + 1; j < ctxEnd; j++ { + if len(results) < maxResults { + results = append(results, fmt.Sprintf(" %d: %s", j+1, lines[j])) + } + } + } + if len(results) >= maxResults { + results = append(results, fmt.Sprintf("\n... (truncated, max %d results)", maxResults)) + return filepath.SkipAll + } + } + } + return nil + }) + if err != nil { + return fmt.Sprintf("error walking directory: %v", err), true + } + if len(results) == 0 { + return fmt.Sprintf("No matches found for pattern: %s", pattern), false + } + return strings.Join(results, "\n"), false +} + +func (a *Agent) handleRead(args map[string]any) (string, bool) { + path, _ := args["path"].(string) + if path == "" { + return "error: path is required", true + } + path = a.resolvePath(path) + data, err := os.ReadFile(path) + if err != nil { + return fmt.Sprintf("error reading file: %v", err), true + } + lines := strings.Split(string(data), "\n") + offset := a.getArgInt(args, "offset", 1) + limit := a.getArgInt(args, "limit", 0) + if offset > len(lines) { + return "error: offset beyond file length", true + } + if offset > 1 { + lines = lines[offset-1:] + } + if limit > 0 && len(lines) > limit { + lines = lines[:limit] + content := strings.Join(lines, "\n") + content += fmt.Sprintf("\n\n... (%d more lines)", len(lines)-limit) + return content, false + } + return strings.Join(lines, "\n"), false +} + +func (a *Agent) handleWrite(args map[string]any) (string, bool) { + path, _ := args["path"].(string) + content, _ := args["content"].(string) + if path == "" { + return "error: path is required", true + } + path = a.resolvePath(path) + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Sprintf("error creating directory: %v", err), true + } + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + return fmt.Sprintf("error writing file: %v", err), true + } + return fmt.Sprintf("Written to %s (%d bytes)", path, len(content)), false +} + +func (a *Agent) handleGlob(args map[string]any) (string, bool) { + pattern, _ := args["pattern"].(string) + if pattern == "" { + return "error: pattern is required", true + } + path := a.getArgString(args, "path", a.workDir) + if _, err := os.Stat(path); err != nil { + return fmt.Sprintf("error: path does not exist: %s", path), true + } + basePattern := filepath.Join(path, pattern) + matches, err := filepath.Glob(basePattern) + if err != nil { + return fmt.Sprintf("error: invalid pattern: %v", err), true + } + if len(matches) == 0 { + return fmt.Sprintf("No files match pattern: %s", pattern), false + } + relMatches := make([]string, 0, len(matches)) + for _, m := range matches { + rel, err := filepath.Rel(path, m) + if err != nil { + continue + } + relMatches = append(relMatches, rel) + } + return strings.Join(relMatches, "\n"), false +} + +func (a *Agent) handleBash(args map[string]any) (string, bool) { + command, _ := args["command"].(string) + if command == "" { + return "error: command is required", true + } + timeout := a.getArgInt(args, "timeout", int(a.ToolTimeout().Seconds())) + maxTimeoutSecs := int(a.ToolTimeout().Seconds()) + if maxTimeoutSecs > 120 { + maxTimeoutSecs = 120 + } + if timeout > maxTimeoutSecs { + timeout = maxTimeoutSecs + } + if timeout < 1 { + timeout = 1 + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + cmd := exec.CommandContext(ctx, "sh", "-c", command) + cmd.Dir = a.workDir + cmd.Env = os.Environ() + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err := cmd.Run() + output := stdout.String() + if stderr.Len() > 0 { + if output != "" { + output += "\n" + } + output += "STDERR:\n" + stderr.String() + } + if ctx.Err() == context.DeadlineExceeded { + return fmt.Sprintf("error: command timed out after %d seconds", timeout), true + } + if err != nil { + if output == "" { + return fmt.Sprintf("error: %v", err), true + } + return fmt.Sprintf("Command exited with error:\n%s", output), true + } + if output == "" { + return "Command completed successfully (no output)", false + } + return output, false +} + +func (a *Agent) handleLs(args map[string]any) (string, bool) { + path := a.getArgString(args, "path", a.workDir) + path = a.resolvePath(path) + entries, err := os.ReadDir(path) + if err != nil { + return fmt.Sprintf("error reading directory: %v", err), true + } + if len(entries) == 0 { + return "Directory is empty", false + } + var dirs []string + var files []string + for _, e := range entries { + name := e.Name() + if e.IsDir() { + dirs = append(dirs, name+"/") + } else { + files = append(files, name) + } + } + var result strings.Builder + for _, d := range dirs { + result.WriteString(d + "\n") + } + for _, f := range files { + result.WriteString(f + "\n") + } + return result.String(), false +} + +func (a *Agent) handleFind(args map[string]any) (string, bool) { + name, _ := args["name"].(string) + if name == "" { + return "error: name is required", true + } + path := a.getArgString(args, "path", a.workDir) + fileType := a.getArgString(args, "type", "") + if _, err := os.Stat(path); err != nil { + return fmt.Sprintf("error: path does not exist: %s", path), true + } + re, err := regexp.Compile("^" + strings.ReplaceAll(name, "*", ".*") + "$") + if err != nil { + return fmt.Sprintf("error: invalid name pattern: %v", err), true + } + var results []string + err = filepath.Walk(path, func(filePath string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + if shouldSkipDir(info.Name()) && filePath != path { + if info.IsDir() { + return filepath.SkipDir + } + return nil + } + isDir := info.IsDir() + if fileType == "f" && isDir { + return nil + } + if fileType == "d" && !isDir { + return nil + } + if re.MatchString(info.Name()) { + relPath, _ := filepath.Rel(path, filePath) + if relPath != "." { + if isDir { + results = append(results, relPath+"/") + } else { + results = append(results, relPath) + } + } + } + return nil + }) + if err != nil { + return fmt.Sprintf("error walking directory: %v", err), true + } + if len(results) == 0 { + return fmt.Sprintf("No files/directories found matching: %s", name), false + } + return strings.Join(results, "\n"), false +} + +func (a *Agent) getArgString(args map[string]any, key, defaultValue string) string { + if v, ok := args[key].(string); ok && v != "" { + return v + } + return defaultValue +} + +func (a *Agent) getArgInt(args map[string]any, key string, defaultValue int) int { + if v, ok := args[key]; ok { + switch n := v.(type) { + case float64: + return int(n) + case int: + return n + case string: + if n == "" { + return defaultValue + } + if i, err := strconv.Atoi(n); err == nil { + return i + } + } + } + return defaultValue +} + +func (a *Agent) resolvePath(path string) string { + if filepath.IsAbs(path) { + return path + } + return filepath.Join(a.workDir, path) +} + +func shouldSkipDir(name string) bool { + switch name { + case "node_modules", ".git", "__pycache__", ".venv", "venv", + "dist", "build", "target", ".cache", ".npm", + ".svn", "CVS", ".hg", ".bzr": + return true + } + return strings.HasPrefix(name, ".") +} + +func (a *Agent) handleDiff(args map[string]any) (string, bool) { + path, _ := args["path"].(string) + newContent, _ := args["new_content"].(string) + if path == "" { + return "error: path is required", true + } + if newContent == "" { + return "error: new_content is required", true + } + path = a.resolvePath(path) + oldContent, err := os.ReadFile(path) + if err != nil { + return fmt.Sprintf("error reading file: %v", err), true + } + oldLines := strings.Split(string(oldContent), "\n") + newLines := strings.Split(newContent, "\n") + diff := computeDiff(oldLines, newLines) + if diff == "" { + return "No changes (files are identical)", false + } + return diff, false +} + +func computeDiff(oldLines, newLines []string) string { + var result strings.Builder + oldLen := len(oldLines) + newLen := len(newLines) + lcs := longestCommonSubsequence(oldLines, newLines) + oldIdx := 0 + newIdx := 0 + lcsIdx := 0 + for oldIdx < oldLen || newIdx < newLen { + if lcsIdx < len(lcs) { + for oldIdx < oldLen && oldLines[oldIdx] != lcs[lcsIdx] { + result.WriteString(fmt.Sprintf("-%s\n", oldLines[oldIdx])) + oldIdx++ + } + for newIdx < newLen && newLines[newIdx] != lcs[lcsIdx] { + result.WriteString(fmt.Sprintf("+%s\n", newLines[newIdx])) + newIdx++ + } + if oldIdx < oldLen && newIdx < newLen { + result.WriteString(fmt.Sprintf(" %s\n", lcs[lcsIdx])) + oldIdx++ + newIdx++ + lcsIdx++ + } + } else { + for oldIdx < oldLen { + result.WriteString(fmt.Sprintf("-%s\n", oldLines[oldIdx])) + oldIdx++ + } + for newIdx < newLen { + result.WriteString(fmt.Sprintf("+%s\n", newLines[newIdx])) + newIdx++ + } + } + } + return result.String() +} + +func longestCommonSubsequence(a, b []string) []string { + m, n := len(a), len(b) + dp := make([][]int, m+1) + for i := range dp { + dp[i] = make([]int, n+1) + } + for i := 1; i <= m; i++ { + for j := 1; j <= n; j++ { + if a[i-1] == b[j-1] { + dp[i][j] = dp[i-1][j-1] + 1 + } else { + if dp[i-1][j] > dp[i][j-1] { + dp[i][j] = dp[i-1][j] + } else { + dp[i][j] = dp[i][j-1] + } + } + } + } + var lcs []string + i, j := m, n + for i > 0 && j > 0 { + if a[i-1] == b[j-1] { + lcs = append([]string{a[i-1]}, lcs...) + i-- + j-- + } else if dp[i-1][j] > dp[i][j-1] { + i-- + } else { + j-- + } + } + return lcs +} + +func (a *Agent) handleEdit(args map[string]any) (string, bool) { + path, _ := args["path"].(string) + patch, _ := args["patch"].(string) + if path == "" { + return "error: path is required", true + } + if patch == "" { + return "error: patch is required", true + } + path = a.resolvePath(path) + oldContent, err := os.ReadFile(path) + if err != nil { + return fmt.Sprintf("error reading file: %v", err), true + } + newContent, err := applyPatch(string(oldContent), patch) + if err != nil { + return fmt.Sprintf("error applying patch: %v", err), true + } + if err := os.WriteFile(path, []byte(newContent), 0644); err != nil { + return fmt.Sprintf("error writing file: %v", err), true + } + return fmt.Sprintf("Applied patch to %s (%d bytes)", path, len(newContent)), false +} + +func (a *Agent) handleMkdir(args map[string]any) (string, bool) { + path, _ := args["path"].(string) + if path == "" { + return "error: path is required", true + } + path = a.resolvePath(path) + if err := os.MkdirAll(path, 0755); err != nil { + return fmt.Sprintf("error creating directory: %v", err), true + } + return fmt.Sprintf("Created directory: %s", path), false +} + +func (a *Agent) handleRemove(args map[string]any) (string, bool) { + path, _ := args["path"].(string) + if path == "" { + return "error: path is required", true + } + path = a.resolvePath(path) + recursive := a.getArgBool(args, "recursive", false) + force := a.getArgBool(args, "force", false) + info, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + if force { + return "Removed (ignored nonexistent)", false + } + return fmt.Sprintf("error: path does not exist: %s", path), true + } + return fmt.Sprintf("error: %v", err), true + } + if info.IsDir() { + if recursive { + err = os.RemoveAll(path) + } else { + err = os.Remove(path) + } + } else { + err = os.Remove(path) + } + if err != nil { + return fmt.Sprintf("error removing: %v", err), true + } + return fmt.Sprintf("Removed: %s", path), false +} + +func (a *Agent) handleCopy(args map[string]any) (string, bool) { + source, _ := args["source"].(string) + destination, _ := args["destination"].(string) + if source == "" || destination == "" { + return "error: source and destination are required", true + } + source = a.resolvePath(source) + destination = a.resolvePath(destination) + info, err := os.Stat(source) + if err != nil { + return fmt.Sprintf("error: %v", err), true + } + if info.IsDir() { + return "error: copying directories not supported (use bash with cp -r)", true + } + srcData, err := os.ReadFile(source) + if err != nil { + return fmt.Sprintf("error reading source: %v", err), true + } + dir := filepath.Dir(destination) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Sprintf("error creating destination directory: %v", err), true + } + err = os.WriteFile(destination, srcData, info.Mode()) + if err != nil { + return fmt.Sprintf("error writing destination: %v", err), true + } + return fmt.Sprintf("Copied: %s -> %s", source, destination), false +} + +func (a *Agent) handleMove(args map[string]any) (string, bool) { + source, _ := args["source"].(string) + destination, _ := args["destination"].(string) + if source == "" || destination == "" { + return "error: source and destination are required", true + } + source = a.resolvePath(source) + destination = a.resolvePath(destination) + dir := filepath.Dir(destination) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Sprintf("error creating destination directory: %v", err), true + } + err := os.Rename(source, destination) + if err != nil { + return fmt.Sprintf("error moving: %v", err), true + } + return fmt.Sprintf("Moved: %s -> %s", source, destination), false +} + +func (a *Agent) handleExists(args map[string]any) (string, bool) { + path, _ := args["path"].(string) + if path == "" { + return "error: path is required", true + } + path = a.resolvePath(path) + info, err := os.Stat(path) + if os.IsNotExist(err) { + return fmt.Sprintf("false: %s does not exist", path), false + } + if err != nil { + return fmt.Sprintf("error: %v", err), true + } + if info.IsDir() { + return fmt.Sprintf("true: %s (directory)", path), false + } + return fmt.Sprintf("true: %s (file, %d bytes)", path, info.Size()), false +} + +func (a *Agent) getArgBool(args map[string]any, key string, defaultValue bool) bool { + if v, ok := args[key]; ok { + if b, ok := v.(bool); ok { + return b + } + } + return defaultValue +} + +func applyPatch(content, patch string) (string, error) { + lines := strings.Split(content, "\n") + patchLines := strings.Split(patch, "\n") + var result []string + i := 0 + for i < len(patchLines) { + line := patchLines[i] + if strings.HasPrefix(line, "@@") { + parts := strings.Fields(line) + if len(parts) < 4 { + return "", fmt.Errorf("invalid hunk header: %s", line) + } + oldSpec := strings.TrimPrefix(parts[1], "-") + oldParts := strings.Split(oldSpec, ",") + oldStart, _ := strconv.Atoi(oldParts[0]) + newSpec := strings.TrimPrefix(parts[2], "+") + newParts := strings.Split(newSpec, ",") + newStart, _ := strconv.Atoi(newParts[0]) + oldIdx := oldStart - 1 + newIdx := newStart - 1 + i++ + for i < len(patchLines) && !strings.HasPrefix(patchLines[i], "@@") { + patchLine := patchLines[i] + if strings.HasPrefix(patchLine, "-") { + if oldIdx < len(lines) { + _ = lines[oldIdx] + oldIdx++ + } + } else if strings.HasPrefix(patchLine, "+") { + content := strings.TrimPrefix(patchLine, "+") + result = append(result, content) + newIdx++ + } else if strings.HasPrefix(patchLine, " ") || patchLine == "" { + if oldIdx < len(lines) { + result = append(result, lines[oldIdx]) + oldIdx++ + } + } else { + result = append(result, patchLine) + } + i++ + } + continue + } + i++ + } + if len(result) == 0 { + return content, nil + } + return strings.Join(result, "\n"), nil +} diff --git a/internal/command/commands.go b/internal/command/commands.go new file mode 100644 index 0000000..bf305d9 --- /dev/null +++ b/internal/command/commands.go @@ -0,0 +1,387 @@ +package command + +import ( + "fmt" + "os" + "strings" +) + +const maxContextFileSize = 32 * 1024 // 32KB + +// RegisterBuiltins adds all built-in slash commands to the registry. +func RegisterBuiltins(r *Registry) { + r.Register(&Command{ + Name: "help", + Aliases: []string{"h", "?"}, + Description: "Show help overlay with shortcuts and commands", + Handler: func(_ *Context, _ []string) Result { + return Result{Action: ActionShowHelp} + }, + }) + + r.Register(&Command{ + Name: "clear", + Description: "Clear conversation history", + Handler: func(_ *Context, _ []string) Result { + return Result{ + Text: "Conversation cleared.", + Action: ActionClear, + } + }, + }) + + r.Register(&Command{ + Name: "new", + Description: "Start a fresh conversation", + Handler: func(_ *Context, _ []string) Result { + return Result{ + Text: "New conversation started.", + Action: ActionClear, + } + }, + }) + + r.Register(&Command{ + Name: "model", + Aliases: []string{"m"}, + Description: "Show, switch, or list models", + Usage: "/model [name|list|fast|smart]", + Handler: func(ctx *Context, args []string) Result { + if len(args) == 0 { + return Result{Action: ActionShowModelPicker} + } + + switch args[0] { + case "list", "ls": + var b strings.Builder + b.WriteString("Available models:\n") + for _, m := range ctx.ModelList { + marker := " " + if m == ctx.Model { + marker = "* " + } + fmt.Fprintf(&b, " %s%s\n", marker, m) + } + b.WriteString("\n* = current") + return Result{Text: b.String()} + + case "fast": + if len(ctx.ModelList) > 0 { + return Result{ + Text: fmt.Sprintf("Switching to fastest model: %s", ctx.ModelList[0]), + Action: ActionSwitchModel, + Data: ctx.ModelList[0], + } + } + return Result{Error: "No models available"} + + case "smart": + if len(ctx.ModelList) > 0 { + smartModel := ctx.ModelList[len(ctx.ModelList)-1] + return Result{ + Text: fmt.Sprintf("Switching to smartest model: %s", smartModel), + Action: ActionSwitchModel, + Data: smartModel, + } + } + return Result{Error: "No models available"} + + default: + for _, m := range ctx.ModelList { + if m == args[0] { + return Result{ + Text: fmt.Sprintf("Switching to model: %s", m), + Action: ActionSwitchModel, + Data: m, + } + } + } + return Result{Error: fmt.Sprintf("Unknown model: %s (use /model list to see available)", args[0])} + } + }, + }) + + r.Register(&Command{ + Name: "models", + Aliases: []string{"ml"}, + Description: "Open model picker", + Handler: func(_ *Context, _ []string) Result { + return Result{Action: ActionShowModelPicker} + }, + }) + + r.Register(&Command{ + Name: "agent", + Aliases: []string{"a"}, + Description: "Show or switch agent profile", + Usage: "/agent [name|list]", + Handler: func(ctx *Context, args []string) Result { + if len(args) == 0 || args[0] == "list" { + var b strings.Builder + if len(ctx.AgentList) == 0 { + b.WriteString("No agent profiles found in ~/.agents/agents/") + return Result{Text: b.String()} + } + b.WriteString("Available agent profiles:\n") + for _, a := range ctx.AgentList { + marker := " " + if a == ctx.AgentProfile { + marker = "* " + } + fmt.Fprintf(&b, " %s%s\n", marker, a) + } + b.WriteString("\n* = current") + return Result{Text: b.String()} + } + + for _, a := range ctx.AgentList { + if a == args[0] { + return Result{ + Text: fmt.Sprintf("Switching to agent: %s", a), + Action: ActionSwitchAgent, + Data: a, + } + } + } + return Result{Error: fmt.Sprintf("Unknown agent: %s (use /agent list to see available)", args[0])} + }, + }) + + r.Register(&Command{ + Name: "load", + Aliases: []string{"l"}, + Description: "Load a markdown file as context", + Usage: "/load ", + Handler: func(_ *Context, args []string) Result { + if len(args) == 0 { + return Result{Error: "Usage: /load "} + } + path := strings.Join(args, " ") + + // Expand ~ to home directory. + if strings.HasPrefix(path, "~/") { + if home, err := os.UserHomeDir(); err == nil { + path = home + path[1:] + } + } + + info, err := os.Stat(path) + if err != nil { + return Result{Error: fmt.Sprintf("Cannot access %s: %v", path, err)} + } + if info.Size() > maxContextFileSize { + return Result{Error: fmt.Sprintf("File too large (%d bytes, max %d)", info.Size(), maxContextFileSize)} + } + + data, err := os.ReadFile(path) + if err != nil { + return Result{Error: fmt.Sprintf("Cannot read %s: %v", path, err)} + } + + return Result{ + Text: fmt.Sprintf("Loaded context: %s (%d bytes)", path, len(data)), + Action: ActionLoadContext, + Data: path + "\x00" + string(data), // path\0content + } + }, + }) + + r.Register(&Command{ + Name: "unload", + Description: "Remove loaded context file", + Handler: func(ctx *Context, _ []string) Result { + if ctx.LoadedFile == "" { + return Result{Text: "No context file loaded."} + } + return Result{ + Text: "Context unloaded.", + Action: ActionUnloadContext, + } + }, + }) + + r.Register(&Command{ + Name: "skill", + Aliases: []string{"sk"}, + Description: "Manage skills (list, activate, deactivate)", + Usage: "/skill [list|activate|deactivate] [name]", + Handler: func(ctx *Context, args []string) Result { + if len(args) == 0 || args[0] == "list" { + return skillList(ctx) + } + if len(args) < 2 { + return Result{Error: "Usage: /skill [list|activate|deactivate] "} + } + switch args[0] { + case "activate", "on": + return Result{ + Text: fmt.Sprintf("Activated skill: %s", args[1]), + Action: ActionActivateSkill, + Data: args[1], + } + case "deactivate", "off": + return Result{ + Text: fmt.Sprintf("Deactivated skill: %s", args[1]), + Action: ActionDeactivateSkill, + Data: args[1], + } + default: + return Result{Error: fmt.Sprintf("Unknown skill action: %s (use list, activate, or deactivate)", args[0])} + } + }, + }) + + r.Register(&Command{ + Name: "servers", + Description: "List connected MCP servers", + Handler: func(ctx *Context, _ []string) Result { + if len(ctx.ServerNames) == 0 { + return Result{Text: "No MCP servers connected."} + } + var b strings.Builder + b.WriteString(fmt.Sprintf("Connected servers (%d):\n", len(ctx.ServerNames))) + for _, name := range ctx.ServerNames { + fmt.Fprintf(&b, " - %s\n", name) + } + b.WriteString(fmt.Sprintf("\nTotal tools: %d", ctx.ToolCount)) + return Result{Text: b.String()} + }, + }) + + r.Register(&Command{ + Name: "ice", + Description: "Show Infinite Context Engine status", + Handler: func(ctx *Context, _ []string) Result { + if !ctx.ICEEnabled { + return Result{Text: "ICE is not enabled. Add `ice: {enabled: true}` to your config.yaml"} + } + var b strings.Builder + b.WriteString("Infinite Context Engine (ICE)\n") + fmt.Fprintf(&b, " Status: enabled\n") + fmt.Fprintf(&b, " Conversations: %d stored\n", ctx.ICEConversations) + fmt.Fprintf(&b, " Session ID: %s\n", ctx.ICESessionID) + fmt.Fprintf(&b, " Embed model: nomic-embed-text\n") + return Result{Text: b.String()} + }, + }) + + r.Register(&Command{ + Name: "sessions", + Aliases: []string{"ss"}, + Description: "Browse and restore saved sessions", + Handler: func(_ *Context, _ []string) Result { + return Result{Action: ActionShowSessions} + }, + }) + + r.Register(&Command{ + Name: "changes", + Description: "List files modified by the agent this session", + Handler: func(ctx *Context, _ []string) Result { + if len(ctx.FileChanges) == 0 { + return Result{Text: "No files modified this session."} + } + var b strings.Builder + fmt.Fprintf(&b, "Files modified (%d):\n", len(ctx.FileChanges)) + for path, count := range ctx.FileChanges { + if count > 1 { + fmt.Fprintf(&b, " ✎ %s (%dx)\n", path, count) + } else { + fmt.Fprintf(&b, " ✎ %s\n", path) + } + } + return Result{Text: b.String()} + }, + }) + + r.Register(&Command{ + Name: "commit", + Aliases: []string{"ci"}, + Description: "Generate commit message from staged changes and commit", + Handler: func(_ *Context, args []string) Result { + return Result{Action: ActionCommit, Data: strings.Join(args, " ")} + }, + }) + + r.Register(&Command{ + Name: "stats", + Description: "Show token usage statistics for this session", + Handler: func(ctx *Context, _ []string) Result { + if ctx.SessionTurnCount == 0 { + return Result{Text: "No token usage recorded yet."} + } + var b strings.Builder + b.WriteString("Session Token Stats\n") + fmt.Fprintf(&b, " Model: %s\n", ctx.CurrentModel) + fmt.Fprintf(&b, " Turns: %d\n", ctx.SessionTurnCount) + fmt.Fprintf(&b, " Output tokens: %d\n", ctx.SessionEvalTotal) + fmt.Fprintf(&b, " Prompt tokens: %d (last turn)\n", ctx.SessionPromptTotal) + if ctx.NumCtx > 0 { + fmt.Fprintf(&b, " Context window: %d\n", ctx.NumCtx) + pct := ctx.SessionPromptTotal * 100 / ctx.NumCtx + fmt.Fprintf(&b, " Context used: %d%%\n", pct) + } + avgOut := ctx.SessionEvalTotal / ctx.SessionTurnCount + fmt.Fprintf(&b, " Avg out/turn: %d\n", avgOut) + return Result{Text: b.String()} + }, + }) + + r.Register(&Command{ + Name: "export", + Description: "Export conversation to a markdown file", + Usage: "/export [path]", + Handler: func(_ *Context, args []string) Result { + if len(args) < 1 || args[0] == "" { + return Result{Error: "usage: /export "} + } + return Result{ + Text: fmt.Sprintf("Exporting conversation to: %s", args[0]), + Action: ActionExport, + Data: args[0], + } + }, + }) + + r.Register(&Command{ + Name: "import", + Description: "Import conversation from a markdown file", + Usage: "/import [path]", + Handler: func(_ *Context, args []string) Result { + if len(args) < 1 || args[0] == "" { + return Result{Error: "usage: /import "} + } + return Result{ + Text: fmt.Sprintf("Importing conversation from: %s", args[0]), + Action: ActionImport, + Data: args[0], + } + }, + }) + + r.Register(&Command{ + Name: "exit", + Aliases: []string{"quit", "q"}, + Description: "Quit ai-agent", + Handler: func(_ *Context, _ []string) Result { + return Result{Action: ActionQuit} + }, + }) +} + +func skillList(ctx *Context) Result { + if len(ctx.Skills) == 0 { + return Result{Text: "No skills found. Add .md files to ~/.config/ai-agent/skills/"} + } + var b strings.Builder + b.WriteString(fmt.Sprintf("Skills (%d):\n", len(ctx.Skills))) + for _, s := range ctx.Skills { + status := " " + if s.Active { + status = "* " + } + fmt.Fprintf(&b, " %s%s — %s\n", status, s.Name, s.Description) + } + b.WriteString("\n* = active") + return Result{Text: b.String()} +} diff --git a/internal/command/commands_test.go b/internal/command/commands_test.go new file mode 100644 index 0000000..6adba91 --- /dev/null +++ b/internal/command/commands_test.go @@ -0,0 +1,380 @@ +package command + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func newTestRegistry() *Registry { + r := NewRegistry() + RegisterBuiltins(r) + return r +} + +func TestBuiltin_Help(t *testing.T) { + r := newTestRegistry() + result := r.Execute(&Context{}, "help", nil) + if result.Action != ActionShowHelp { + t.Errorf("help action = %d, want %d (ActionShowHelp)", result.Action, ActionShowHelp) + } +} + +func TestBuiltin_Clear(t *testing.T) { + r := newTestRegistry() + result := r.Execute(&Context{}, "clear", nil) + if result.Action != ActionClear { + t.Errorf("clear action = %d, want %d (ActionClear)", result.Action, ActionClear) + } + if result.Text == "" { + t.Error("clear should have text") + } +} + +func TestBuiltin_New(t *testing.T) { + r := newTestRegistry() + result := r.Execute(&Context{}, "new", nil) + if result.Action != ActionClear { + t.Errorf("new action = %d, want %d (ActionClear)", result.Action, ActionClear) + } + if result.Text == "" { + t.Error("new should have text") + } +} + +func TestBuiltin_Model(t *testing.T) { + r := newTestRegistry() + ctx := &Context{ + Model: "qwen3.5:0.8b", + ModelList: []string{"qwen3.5:0.8b", "qwen3.5:2b", "qwen3.5:4b", "qwen3.5:9b"}, + } + + tests := []struct { + name string + args []string + wantAction Action + wantData string + wantErr bool + checkText string + }{ + { + name: "no args opens model picker", + args: nil, + wantAction: ActionShowModelPicker, + }, + { + name: "list shows models", + args: []string{"list"}, + checkText: "Available models", + }, + { + name: "fast switches to first", + args: []string{"fast"}, + wantAction: ActionSwitchModel, + wantData: "qwen3.5:0.8b", + }, + { + name: "smart switches to last", + args: []string{"smart"}, + wantAction: ActionSwitchModel, + wantData: "qwen3.5:9b", + }, + { + name: "valid name switches", + args: []string{"qwen3.5:2b"}, + wantAction: ActionSwitchModel, + wantData: "qwen3.5:2b", + }, + { + name: "invalid name errors", + args: []string{"nonexistent"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Execute(ctx, "model", tt.args) + if tt.wantErr { + if result.Error == "" { + t.Error("expected error") + } + return + } + if result.Error != "" { + t.Errorf("unexpected error: %s", result.Error) + return + } + if tt.wantAction != ActionNone && result.Action != tt.wantAction { + t.Errorf("action = %d, want %d", result.Action, tt.wantAction) + } + if tt.wantData != "" && result.Data != tt.wantData { + t.Errorf("data = %q, want %q", result.Data, tt.wantData) + } + if tt.checkText != "" && !strings.Contains(result.Text, tt.checkText) { + t.Errorf("text %q does not contain %q", result.Text, tt.checkText) + } + }) + } +} + +func TestBuiltin_Models(t *testing.T) { + r := newTestRegistry() + ctx := &Context{ + Model: "qwen3.5:0.8b", + ModelList: []string{"qwen3.5:0.8b", "qwen3.5:2b"}, + } + result := r.Execute(ctx, "models", nil) + if result.Action != ActionShowModelPicker { + t.Errorf("expected ActionShowModelPicker, got %d", result.Action) + } +} + +func TestBuiltin_Agent(t *testing.T) { + r := newTestRegistry() + + t.Run("no args lists agents", func(t *testing.T) { + ctx := &Context{AgentList: []string{"coder", "reviewer"}, AgentProfile: "coder"} + result := r.Execute(ctx, "agent", nil) + if !strings.Contains(result.Text, "Available agent profiles") { + t.Errorf("expected agent list, got %q", result.Text) + } + }) + + t.Run("list subcommand", func(t *testing.T) { + ctx := &Context{AgentList: []string{"coder"}} + result := r.Execute(ctx, "agent", []string{"list"}) + if !strings.Contains(result.Text, "Available agent profiles") { + t.Errorf("expected agent list, got %q", result.Text) + } + }) + + t.Run("valid switch", func(t *testing.T) { + ctx := &Context{AgentList: []string{"coder", "reviewer"}} + result := r.Execute(ctx, "agent", []string{"reviewer"}) + if result.Action != ActionSwitchAgent { + t.Errorf("action = %d, want %d", result.Action, ActionSwitchAgent) + } + if result.Data != "reviewer" { + t.Errorf("data = %q, want %q", result.Data, "reviewer") + } + }) + + t.Run("invalid errors", func(t *testing.T) { + ctx := &Context{AgentList: []string{"coder"}} + result := r.Execute(ctx, "agent", []string{"unknown"}) + if result.Error == "" { + t.Error("expected error for unknown agent") + } + }) + + t.Run("no agents", func(t *testing.T) { + ctx := &Context{AgentList: []string{}} + result := r.Execute(ctx, "agent", nil) + if !strings.Contains(result.Text, "No agent profiles") { + t.Errorf("expected no agents message, got %q", result.Text) + } + }) +} + +func TestBuiltin_Load(t *testing.T) { + r := newTestRegistry() + + t.Run("no args errors", func(t *testing.T) { + result := r.Execute(&Context{}, "load", nil) + if result.Error == "" { + t.Error("expected error for no args") + } + }) + + t.Run("valid file loads", func(t *testing.T) { + tmp := t.TempDir() + path := filepath.Join(tmp, "test.md") + if err := os.WriteFile(path, []byte("# Hello"), 0644); err != nil { + t.Fatal(err) + } + result := r.Execute(&Context{}, "load", []string{path}) + if result.Error != "" { + t.Errorf("unexpected error: %s", result.Error) + } + if result.Action != ActionLoadContext { + t.Errorf("action = %d, want %d", result.Action, ActionLoadContext) + } + // Data should be path\0content + parts := strings.SplitN(result.Data, "\x00", 2) + if len(parts) != 2 { + t.Fatalf("expected path\\0content, got %q", result.Data) + } + if parts[0] != path { + t.Errorf("data path = %q, want %q", parts[0], path) + } + if parts[1] != "# Hello" { + t.Errorf("data content = %q, want %q", parts[1], "# Hello") + } + }) + + t.Run("too large errors", func(t *testing.T) { + tmp := t.TempDir() + path := filepath.Join(tmp, "big.md") + data := make([]byte, 33*1024) // > 32KB + if err := os.WriteFile(path, data, 0644); err != nil { + t.Fatal(err) + } + result := r.Execute(&Context{}, "load", []string{path}) + if result.Error == "" { + t.Error("expected error for oversized file") + } + if !strings.Contains(result.Error, "too large") { + t.Errorf("error = %q, want containing 'too large'", result.Error) + } + }) + + t.Run("nonexistent errors", func(t *testing.T) { + result := r.Execute(&Context{}, "load", []string{"/nonexistent/file.md"}) + if result.Error == "" { + t.Error("expected error for nonexistent file") + } + }) +} + +func TestBuiltin_Unload(t *testing.T) { + r := newTestRegistry() + + t.Run("no loaded file", func(t *testing.T) { + result := r.Execute(&Context{LoadedFile: ""}, "unload", nil) + if !strings.Contains(result.Text, "No context") { + t.Errorf("expected 'No context' message, got %q", result.Text) + } + }) + + t.Run("loaded file unloads", func(t *testing.T) { + result := r.Execute(&Context{LoadedFile: "something.md"}, "unload", nil) + if result.Action != ActionUnloadContext { + t.Errorf("action = %d, want %d", result.Action, ActionUnloadContext) + } + }) +} + +func TestBuiltin_Skill(t *testing.T) { + r := newTestRegistry() + ctx := &Context{ + Skills: []SkillInfo{ + {Name: "coder", Description: "Code generation", Active: true}, + {Name: "reviewer", Description: "Code review", Active: false}, + }, + } + + t.Run("no args lists skills", func(t *testing.T) { + result := r.Execute(ctx, "skill", nil) + if !strings.Contains(result.Text, "Skills") { + t.Errorf("expected skills list, got %q", result.Text) + } + }) + + t.Run("list subcommand", func(t *testing.T) { + result := r.Execute(ctx, "skill", []string{"list"}) + if !strings.Contains(result.Text, "Skills") { + t.Errorf("expected skills list, got %q", result.Text) + } + }) + + t.Run("activate", func(t *testing.T) { + result := r.Execute(ctx, "skill", []string{"activate", "reviewer"}) + if result.Action != ActionActivateSkill { + t.Errorf("action = %d, want %d", result.Action, ActionActivateSkill) + } + if result.Data != "reviewer" { + t.Errorf("data = %q, want %q", result.Data, "reviewer") + } + }) + + t.Run("deactivate", func(t *testing.T) { + result := r.Execute(ctx, "skill", []string{"deactivate", "coder"}) + if result.Action != ActionDeactivateSkill { + t.Errorf("action = %d, want %d", result.Action, ActionDeactivateSkill) + } + if result.Data != "coder" { + t.Errorf("data = %q, want %q", result.Data, "coder") + } + }) + + t.Run("unknown action errors", func(t *testing.T) { + result := r.Execute(ctx, "skill", []string{"unknown", "foo"}) + if result.Error == "" { + t.Error("expected error for unknown skill action") + } + }) + + t.Run("missing name errors", func(t *testing.T) { + result := r.Execute(ctx, "skill", []string{"activate"}) + if result.Error == "" { + t.Error("expected error for missing skill name") + } + }) +} + +func TestBuiltin_Servers(t *testing.T) { + r := newTestRegistry() + + t.Run("no servers", func(t *testing.T) { + result := r.Execute(&Context{ServerNames: nil}, "servers", nil) + if !strings.Contains(result.Text, "No MCP servers") { + t.Errorf("expected no servers message, got %q", result.Text) + } + }) + + t.Run("with servers", func(t *testing.T) { + ctx := &Context{ + ServerNames: []string{"server-a", "server-b"}, + ToolCount: 10, + } + result := r.Execute(ctx, "servers", nil) + if !strings.Contains(result.Text, "server-a") { + t.Errorf("expected server-a in output, got %q", result.Text) + } + if !strings.Contains(result.Text, "server-b") { + t.Errorf("expected server-b in output, got %q", result.Text) + } + if !strings.Contains(result.Text, "10") { + t.Errorf("expected tool count in output, got %q", result.Text) + } + }) +} + +func TestBuiltin_ICE(t *testing.T) { + r := newTestRegistry() + + t.Run("disabled", func(t *testing.T) { + result := r.Execute(&Context{ICEEnabled: false}, "ice", nil) + if !strings.Contains(result.Text, "not enabled") { + t.Errorf("expected disabled message, got %q", result.Text) + } + }) + + t.Run("enabled shows status", func(t *testing.T) { + ctx := &Context{ + ICEEnabled: true, + ICEConversations: 5, + ICESessionID: "abc-123", + } + result := r.Execute(ctx, "ice", nil) + if !strings.Contains(result.Text, "enabled") { + t.Errorf("expected enabled status, got %q", result.Text) + } + if !strings.Contains(result.Text, "5") { + t.Errorf("expected conversation count, got %q", result.Text) + } + if !strings.Contains(result.Text, "abc-123") { + t.Errorf("expected session ID, got %q", result.Text) + } + }) +} + +func TestBuiltin_Exit(t *testing.T) { + r := newTestRegistry() + result := r.Execute(&Context{}, "exit", nil) + if result.Action != ActionQuit { + t.Errorf("exit action = %d, want %d (ActionQuit)", result.Action, ActionQuit) + } +} diff --git a/internal/command/custom.go b/internal/command/custom.go new file mode 100644 index 0000000..52c1557 --- /dev/null +++ b/internal/command/custom.go @@ -0,0 +1,116 @@ +package command + +import ( + "os" + "path/filepath" + "strings" +) + +// CustomCommand represents a user-defined command loaded from a markdown file. +type CustomCommand struct { + Name string + Description string + Template string // prompt template with {{input}} placeholder +} + +// LoadCustomCommands reads .md files from the commands directory and returns +// parsed custom commands. Each file should have YAML-like frontmatter: +// +// --- +// name: review +// description: Code review prompt +// --- +// Review this code: {{input}} +func LoadCustomCommands(dir string) []CustomCommand { + entries, err := os.ReadDir(dir) + if err != nil { + return nil + } + + var cmds []CustomCommand + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") { + continue + } + data, err := os.ReadFile(filepath.Join(dir, entry.Name())) + if err != nil { + continue + } + if cmd, ok := parseCustomCommand(string(data)); ok { + cmds = append(cmds, cmd) + } + } + return cmds +} + +// parseCustomCommand parses a markdown file with YAML frontmatter. +func parseCustomCommand(content string) (CustomCommand, bool) { + content = strings.TrimSpace(content) + if !strings.HasPrefix(content, "---") { + return CustomCommand{}, false + } + + // Find end of frontmatter. + rest := content[3:] + idx := strings.Index(rest, "---") + if idx < 0 { + return CustomCommand{}, false + } + + frontmatter := rest[:idx] + body := strings.TrimSpace(rest[idx+3:]) + + cmd := CustomCommand{Template: body} + + // Parse simple key: value pairs from frontmatter. + for _, line := range strings.Split(frontmatter, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + parts := strings.SplitN(line, ":", 2) + if len(parts) != 2 { + continue + } + key := strings.TrimSpace(parts[0]) + val := strings.TrimSpace(parts[1]) + switch key { + case "name": + cmd.Name = val + case "description": + cmd.Description = val + } + } + + if cmd.Name == "" || cmd.Template == "" { + return CustomCommand{}, false + } + + return cmd, true +} + +// RegisterCustomCommands loads and registers custom commands from the given directory. +func RegisterCustomCommands(r *Registry, dir string) { + cmds := LoadCustomCommands(dir) + for _, cc := range cmds { + // Capture for closure. + tmpl := cc.Template + desc := cc.Description + if desc == "" { + desc = "Custom command" + } + + r.Register(&Command{ + Name: cc.Name, + Description: desc, + Handler: func(_ *Context, args []string) Result { + input := strings.Join(args, " ") + prompt := strings.ReplaceAll(tmpl, "{{input}}", input) + return Result{ + Action: ActionSendPrompt, + Data: prompt, + } + }, + }) + } +} diff --git a/internal/command/custom_test.go b/internal/command/custom_test.go new file mode 100644 index 0000000..c77e7de --- /dev/null +++ b/internal/command/custom_test.go @@ -0,0 +1,148 @@ +package command + +import ( + "os" + "path/filepath" + "testing" +) + +func TestParseCustomCommand(t *testing.T) { + tests := []struct { + name string + content string + wantOK bool + wantCmd CustomCommand + }{ + { + name: "valid command", + content: `--- +name: review +description: Code review prompt +--- +Review this code: {{input}}`, + wantOK: true, + wantCmd: CustomCommand{ + Name: "review", + Description: "Code review prompt", + Template: "Review this code: {{input}}", + }, + }, + { + name: "no description", + content: `--- +name: explain +--- +Explain this: {{input}}`, + wantOK: true, + wantCmd: CustomCommand{ + Name: "explain", + Template: "Explain this: {{input}}", + }, + }, + { + name: "no frontmatter", + content: "just some text", + wantOK: false, + }, + { + name: "no name", + content: `--- +description: something +--- +body`, + wantOK: false, + }, + { + name: "no body", + content: `--- +name: empty +---`, + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd, ok := parseCustomCommand(tt.content) + if ok != tt.wantOK { + t.Fatalf("parseCustomCommand() ok = %v, want %v", ok, tt.wantOK) + } + if !ok { + return + } + if cmd.Name != tt.wantCmd.Name { + t.Errorf("Name = %q, want %q", cmd.Name, tt.wantCmd.Name) + } + if cmd.Description != tt.wantCmd.Description { + t.Errorf("Description = %q, want %q", cmd.Description, tt.wantCmd.Description) + } + if cmd.Template != tt.wantCmd.Template { + t.Errorf("Template = %q, want %q", cmd.Template, tt.wantCmd.Template) + } + }) + } +} + +func TestLoadCustomCommands(t *testing.T) { + dir := t.TempDir() + + // Write a valid command file. + err := os.WriteFile(filepath.Join(dir, "review.md"), []byte(`--- +name: review +description: Review code +--- +Review: {{input}}`), 0o644) + if err != nil { + t.Fatal(err) + } + + // Write an invalid file (no frontmatter). + err = os.WriteFile(filepath.Join(dir, "invalid.md"), []byte("just text"), 0o644) + if err != nil { + t.Fatal(err) + } + + // Write a non-md file (should be ignored). + err = os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("not a command"), 0o644) + if err != nil { + t.Fatal(err) + } + + cmds := LoadCustomCommands(dir) + if len(cmds) != 1 { + t.Fatalf("LoadCustomCommands() returned %d commands, want 1", len(cmds)) + } + if cmds[0].Name != "review" { + t.Errorf("Name = %q, want %q", cmds[0].Name, "review") + } +} + +func TestLoadCustomCommands_MissingDir(t *testing.T) { + cmds := LoadCustomCommands("/nonexistent/path") + if len(cmds) != 0 { + t.Errorf("expected empty result for missing dir, got %d", len(cmds)) + } +} + +func TestRegisterCustomCommands(t *testing.T) { + dir := t.TempDir() + err := os.WriteFile(filepath.Join(dir, "test.md"), []byte(`--- +name: testcmd +description: A test command +--- +Do this: {{input}}`), 0o644) + if err != nil { + t.Fatal(err) + } + + reg := NewRegistry() + RegisterCustomCommands(reg, dir) + + result := reg.Execute(&Context{}, "testcmd", []string{"hello", "world"}) + if result.Action != ActionSendPrompt { + t.Errorf("Action = %v, want ActionSendPrompt", result.Action) + } + if result.Data != "Do this: hello world" { + t.Errorf("Data = %q, want %q", result.Data, "Do this: hello world") + } +} diff --git a/internal/command/registry.go b/internal/command/registry.go new file mode 100644 index 0000000..ee956a3 --- /dev/null +++ b/internal/command/registry.go @@ -0,0 +1,129 @@ +package command + +import ( + "fmt" + "sort" + "strings" +) + +// Command represents a slash command. +type Command struct { + Name string + Aliases []string + Description string + Usage string + Handler func(ctx *Context, args []string) Result +} + +// Context provides commands with read access to application state. +type Context struct { + Model string + ModelList []string + AgentProfile string + AgentList []string + ToolCount int + ServerCount int + ServerNames []string + Skills []SkillInfo + LoadedFile string + ICEEnabled bool + ICEConversations int + ICESessionID string + // Token stats + SessionEvalTotal int + SessionPromptTotal int + SessionTurnCount int + NumCtx int + CurrentModel string + // File changes + FileChanges map[string]int // path → modification count +} + +// SkillInfo is a read-only view of a skill for command display. +type SkillInfo struct { + Name string + Description string + Active bool +} + +// Result is returned by command handlers to describe what to do. +type Result struct { + Text string // Display text (shown as system message) + Action Action // Side effect for the TUI to execute + Data string // Optional payload (e.g. file path, model name) + Error string // Error text (takes priority over Text) +} + +// Action describes a side effect the TUI should perform. +type Action int + +const ( + ActionNone Action = iota + ActionShowHelp // Show help overlay + ActionClear // Clear conversation history + ActionQuit // Exit the application + ActionLoadContext // Load markdown context (Data = path) + ActionUnloadContext // Remove loaded context + ActionActivateSkill // Activate skill (Data = name) + ActionDeactivateSkill // Deactivate skill (Data = name) + ActionSwitchModel // Switch model (Data = model name) + ActionSwitchAgent // Switch agent profile (Data = agent name) + ActionShowSessions // Open sessions picker + ActionShowModelPicker // Open model picker overlay + ActionCommit // Generate commit message and commit + ActionSendPrompt // Send Data as a message to the agent + ActionExport // Export conversation (Data = path) + ActionImport // Import conversation (Data = path) +) + +// Registry holds all registered slash commands. +type Registry struct { + commands map[string]*Command // name/alias → command + all []*Command // ordered list +} + +// NewRegistry creates an empty command registry. +func NewRegistry() *Registry { + return &Registry{ + commands: make(map[string]*Command), + } +} + +// Register adds a command to the registry. +func (r *Registry) Register(cmd *Command) { + r.all = append(r.all, cmd) + r.commands[cmd.Name] = cmd + for _, alias := range cmd.Aliases { + r.commands[alias] = cmd + } +} + +// Execute dispatches a slash command by name and returns the result. +func (r *Registry) Execute(ctx *Context, name string, args []string) Result { + cmd, ok := r.commands[name] + if !ok { + return Result{Error: fmt.Sprintf("unknown command: /%s — type /help for available commands", name)} + } + return cmd.Handler(ctx, args) +} + +// All returns all registered commands in registration order. +func (r *Registry) All() []*Command { + return r.all +} + +// Match returns commands whose name starts with the given prefix. +func (r *Registry) Match(prefix string) []*Command { + var matches []*Command + seen := make(map[string]bool) + for _, cmd := range r.all { + if strings.HasPrefix(cmd.Name, prefix) && !seen[cmd.Name] { + matches = append(matches, cmd) + seen[cmd.Name] = true + } + } + sort.Slice(matches, func(i, j int) bool { + return matches[i].Name < matches[j].Name + }) + return matches +} diff --git a/internal/command/registry_test.go b/internal/command/registry_test.go new file mode 100644 index 0000000..e7d81c3 --- /dev/null +++ b/internal/command/registry_test.go @@ -0,0 +1,164 @@ +package command + +import "testing" + +func TestRegistry_Register(t *testing.T) { + r := NewRegistry() + cmd := &Command{ + Name: "test", + Description: "A test command", + Handler: func(_ *Context, _ []string) Result { + return Result{Text: "ok"} + }, + } + r.Register(cmd) + + all := r.All() + if len(all) != 1 { + t.Fatalf("expected 1 command, got %d", len(all)) + } + if all[0].Name != "test" { + t.Errorf("command name = %q, want %q", all[0].Name, "test") + } + + // Execute to verify it was registered correctly + result := r.Execute(&Context{}, "test", nil) + if result.Text != "ok" { + t.Errorf("result text = %q, want %q", result.Text, "ok") + } +} + +func TestRegistry_Execute(t *testing.T) { + r := NewRegistry() + called := false + r.Register(&Command{ + Name: "run", + Handler: func(_ *Context, _ []string) Result { + called = true + return Result{Text: "executed"} + }, + }) + + t.Run("found command executes handler", func(t *testing.T) { + result := r.Execute(&Context{}, "run", nil) + if !called { + t.Error("handler was not called") + } + if result.Text != "executed" { + t.Errorf("result text = %q, want %q", result.Text, "executed") + } + }) + + t.Run("not found returns error", func(t *testing.T) { + result := r.Execute(&Context{}, "nonexistent", nil) + if result.Error == "" { + t.Error("expected error for unknown command") + } + }) +} + +func TestRegistry_ExecuteByAlias(t *testing.T) { + r := NewRegistry() + r.Register(&Command{ + Name: "mycommand", + Aliases: []string{"mc", "m"}, + Handler: func(_ *Context, _ []string) Result { + return Result{Text: "alias works"} + }, + }) + + tests := []struct { + name string + cmdName string + wantOk bool + }{ + {name: "by name", cmdName: "mycommand", wantOk: true}, + {name: "by alias mc", cmdName: "mc", wantOk: true}, + {name: "by alias m", cmdName: "m", wantOk: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := r.Execute(&Context{}, tt.cmdName, nil) + if tt.wantOk && result.Error != "" { + t.Errorf("unexpected error: %s", result.Error) + } + if tt.wantOk && result.Text != "alias works" { + t.Errorf("result text = %q, want %q", result.Text, "alias works") + } + }) + } +} + +func TestRegistry_All(t *testing.T) { + r := NewRegistry() + names := []string{"alpha", "beta", "gamma"} + for _, name := range names { + n := name // capture + r.Register(&Command{ + Name: n, + Handler: func(_ *Context, _ []string) Result { return Result{} }, + }) + } + + all := r.All() + if len(all) != len(names) { + t.Fatalf("expected %d commands, got %d", len(names), len(all)) + } + for i, cmd := range all { + if cmd.Name != names[i] { + t.Errorf("All()[%d].Name = %q, want %q", i, cmd.Name, names[i]) + } + } +} + +func TestRegistry_Match(t *testing.T) { + r := NewRegistry() + r.Register(&Command{ + Name: "model", + Aliases: []string{"m"}, + Handler: func(_ *Context, _ []string) Result { return Result{} }, + }) + r.Register(&Command{ + Name: "models", + Aliases: []string{"ml"}, + Handler: func(_ *Context, _ []string) Result { return Result{} }, + }) + r.Register(&Command{ + Name: "help", + Handler: func(_ *Context, _ []string) Result { return Result{} }, + }) + + tests := []struct { + name string + prefix string + want int + }{ + {name: "prefix mo matches model and models", prefix: "mo", want: 2}, + {name: "prefix model matches model and models", prefix: "model", want: 2}, + {name: "prefix models matches only models", prefix: "models", want: 1}, + {name: "prefix h matches help", prefix: "h", want: 1}, + {name: "no match", prefix: "z", want: 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := r.Match(tt.prefix) + if len(matches) != tt.want { + t.Errorf("Match(%q) returned %d results, want %d", tt.prefix, len(matches), tt.want) + } + }) + } + + // Verify no duplicates from aliases + t.Run("aliases dont create dupes", func(t *testing.T) { + matches := r.Match("model") + seen := make(map[string]bool) + for _, m := range matches { + if seen[m.Name] { + t.Errorf("duplicate match for %q", m.Name) + } + seen[m.Name] = true + } + }) +} diff --git a/internal/config/agents.go b/internal/config/agents.go new file mode 100644 index 0000000..74e7434 --- /dev/null +++ b/internal/config/agents.go @@ -0,0 +1,366 @@ +package config + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "gopkg.in/yaml.v3" +) + +type AgentsDir struct { + Path string + Agents map[string]AgentProfile + MCPServers []ServerConfig + GlobalInstructions string + Skills []SkillDef +} + +type AgentProfile struct { + Name string `yaml:"name" json:"name"` + Description string `yaml:"description" json:"description"` + Model string `yaml:"model" json:"model"` + Skills []string `yaml:"skills" json:"skills"` + MCPServers []string `yaml:"mcp_servers" json:"mcp_servers"` + SystemPrompt string `yaml:"system_prompt" json:"system_prompt"` + UseCases []string `yaml:"use_cases" json:"use_cases"` +} + +type SkillDef struct { + Name string `yaml:"name" json:"name"` + Description string `yaml:"description" json:"description"` + Path string `yaml:"path" json:"path"` +} + +type MCPConfig struct { + Servers []ServerConfig `json:"servers,omitempty"` +} + +type ModelsConfig struct { + Models []Model `yaml:"models,omitempty"` + DefaultModel string `yaml:"default_model,omitempty"` + FallbackChain []string `yaml:"fallback_chain,omitempty"` + AutoSelect bool `yaml:"auto_select,omitempty"` + EmbedModel string `yaml:"embed_model,omitempty"` +} + +func FindAgentsDir() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + + candidates := []string{ + filepath.Join(home, ".agents"), + filepath.Join(home, ".config", "agents"), + } + + for _, dir := range candidates { + if _, err := os.Stat(dir); err == nil { + return dir + } + } + + return "" +} + +func FindAgentsDirWithCreate() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("get home dir: %w", err) + } + + dirs := []string{ + filepath.Join(home, ".agents"), + filepath.Join(home, ".config", "agents"), + } + + for _, dir := range dirs { + if _, err := os.Stat(dir); err == nil { + return dir, nil + } + } + + if err := os.MkdirAll(dirs[0], 0755); err != nil { + return "", fmt.Errorf("create agents dir: %w", err) + } + + return dirs[0], nil +} + +func LoadAgentsDir(path string) (*AgentsDir, error) { + if path == "" { + path = FindAgentsDir() + if path == "" { + return &AgentsDir{ + Path: "", + Agents: make(map[string]AgentProfile), + }, nil + } + } + + dir := &AgentsDir{ + Path: path, + Agents: make(map[string]AgentProfile), + } + + if err := dir.loadAgents(path); err != nil { + return nil, fmt.Errorf("load agents: %w", err) + } + + if err := dir.loadMCP(path); err != nil { + return nil, fmt.Errorf("load MCP: %w", err) + } + + if err := dir.loadGlobalInstructions(path); err != nil { + return nil, fmt.Errorf("load instructions: %w", err) + } + + if err := dir.loadSkills(path); err != nil { + return nil, fmt.Errorf("load skills: %w", err) + } + + return dir, nil +} + +func (d *AgentsDir) loadAgents(path string) error { + agentsDir := filepath.Join(path, "agents") + entries, err := os.ReadDir(agentsDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + for _, entry := range entries { + if entry.IsDir() { + agentPath := filepath.Join(agentsDir, entry.Name(), "agent.yaml") + if _, err := os.Stat(agentPath); err != nil { + agentPath = filepath.Join(agentsDir, entry.Name(), "agent.md") + } + if _, err := os.Stat(agentPath); err != nil { + continue + } + + data, err := os.ReadFile(agentPath) + if err != nil { + continue + } + + var profile AgentProfile + if err := yaml.Unmarshal(data, &profile); err != nil { + continue + } + + if profile.Name == "" { + profile.Name = entry.Name() + } + + d.Agents[profile.Name] = profile + } + } + + return nil +} + +func (d *AgentsDir) loadMCP(path string) error { + mcpPath := filepath.Join(path, "mcp.json") + data, err := os.ReadFile(mcpPath) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + var mcpCfg MCPConfig + if err := json.Unmarshal(data, &mcpCfg); err != nil { + return fmt.Errorf("parse mcp.json: %w", err) + } + + d.MCPServers = mcpCfg.Servers + return nil +} + +func (d *AgentsDir) loadGlobalInstructions(path string) error { + paths := []string{ + filepath.Join(path, "agents.md"), + filepath.Join(path, "instructions.md"), + } + + for _, p := range paths { + data, err := os.ReadFile(p) + if err == nil { + d.GlobalInstructions = string(data) + return nil + } + } + + return nil +} + +func (d *AgentsDir) loadSkills(path string) error { + skillsDir := filepath.Join(path, "skills") + entries, err := os.ReadDir(skillsDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + skillDir := filepath.Join(skillsDir, entry.Name()) + + // Try both SKILL.md and skill.md (case insensitive check) + skillPath := "" + for _, name := range []string{"SKILL.md", "skill.md"} { + path := filepath.Join(skillDir, name) + if info, err := os.Stat(path); err == nil && !info.IsDir() { + skillPath = path + break + } + } + + if skillPath == "" { + continue + } + + data, err := os.ReadFile(skillPath) + if err != nil { + continue + } + + d.Skills = append(d.Skills, SkillDef{ + Name: entry.Name(), + Description: extractDescription(string(data)), + Path: skillPath, + }) + } + + return nil +} + +func extractDescription(content string) string { + for _, line := range splitLines(content) { + line = trimWhitespace(line) + if line == "" || startsWith(line, "#") { + continue + } + return line + } + return "" +} + +func splitLines(s string) []string { + var lines []string + start := 0 + for i, r := range s { + if r == '\n' { + lines = append(lines, s[start:i]) + start = i + 1 + } + } + lines = append(lines, s[start:]) + return lines +} + +func trimWhitespace(s string) string { + start := 0 + end := len(s) + for start < end && (s[start] == ' ' || s[start] == '\t') { + start++ + } + for end > start && (s[end-1] == ' ' || s[end-1] == '\t') { + end-- + } + return s[start:end] +} + +func startsWith(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +} + +func (d *AgentsDir) GetAgent(name string) *AgentProfile { + if agent, ok := d.Agents[name]; ok { + return &agent + } + return nil +} + +func (d *AgentsDir) ListAgents() []AgentProfile { + agents := make([]AgentProfile, 0, len(d.Agents)) + for _, agent := range d.Agents { + agents = append(agents, agent) + } + return agents +} + +func (d *AgentsDir) GetSkills() []SkillDef { + return d.Skills +} + +func (d *AgentsDir) HasMCP() bool { + return len(d.MCPServers) > 0 +} + +func (d *AgentsDir) GetMCPServers() []ServerConfig { + return d.MCPServers +} + +func (d *AgentsDir) GetGlobalInstructions() string { + return d.GlobalInstructions +} + +func CreateDefaultAgentsDir() error { + dir, err := FindAgentsDirWithCreate() + if err != nil { + return err + } + + subdirs := []string{"agents", "skills", "tasks", "memories"} + for _, sub := range subdirs { + path := filepath.Join(dir, sub) + if _, err := os.Stat(path); err != nil { + if err := os.MkdirAll(path, 0755); err != nil { + return fmt.Errorf("create %s: %w", sub, err) + } + } + } + + mcpPath := filepath.Join(dir, "mcp.json") + if _, err := os.Stat(mcpPath); err != nil { + defaultMCP := MCPConfig{ + Servers: []ServerConfig{}, + } + data, _ := json.MarshalIndent(defaultMCP, "", " ") + if err := os.WriteFile(mcpPath, data, 0644); err != nil { + return fmt.Errorf("write mcp.json: %w", err) + } + } + + agentsPath := filepath.Join(dir, "agents.md") + if _, err := os.Stat(agentsPath); err != nil { + defaultContent := `# Global Agent Instructions + +You are a helpful local AI coding assistant. + +## Guidelines +- Be concise and direct +- Explain your reasoning +- Ask for clarification when needed +- Never fabricate information +` + if err := os.WriteFile(agentsPath, []byte(defaultContent), 0644); err != nil { + return fmt.Errorf("write agents.md: %w", err) + } + } + + return nil +} diff --git a/internal/config/agents_test.go b/internal/config/agents_test.go new file mode 100644 index 0000000..7d0f11a --- /dev/null +++ b/internal/config/agents_test.go @@ -0,0 +1,176 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestExtractDescription(t *testing.T) { + tests := []struct { + name string + content string + want string + }{ + { + name: "first non-header non-empty line", + content: "# Title\n\nThis is the description.\nMore text.", + want: "This is the description.", + }, + { + name: "header only content", + content: "# Title\n## Subtitle\n### Another", + want: "", + }, + { + name: "empty content", + content: "", + want: "", + }, + { + name: "whitespace around description", + content: "# Title\n\n Indented description \n", + want: "Indented description", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractDescription(tt.content) + if got != tt.want { + t.Errorf("extractDescription() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestSplitLines(t *testing.T) { + tests := []struct { + name string + s string + want int // expected number of lines + }{ + {name: "normal lines", s: "a\nb\nc", want: 3}, + {name: "empty string", s: "", want: 1}, + {name: "trailing newline", s: "a\nb\n", want: 3}, + {name: "single line", s: "hello", want: 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := splitLines(tt.s) + if len(got) != tt.want { + t.Errorf("splitLines(%q) returned %d lines, want %d (lines: %v)", tt.s, len(got), tt.want, got) + } + }) + } +} + +func TestTrimWhitespace(t *testing.T) { + tests := []struct { + name string + s string + want string + }{ + {name: "tabs", s: "\thello\t", want: "hello"}, + {name: "spaces", s: " hello ", want: "hello"}, + {name: "mixed", s: "\t hello \t", want: "hello"}, + {name: "already trimmed", s: "hello", want: "hello"}, + {name: "empty", s: "", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := trimWhitespace(tt.s) + if got != tt.want { + t.Errorf("trimWhitespace(%q) = %q, want %q", tt.s, got, tt.want) + } + }) + } +} + +func TestStartsWith(t *testing.T) { + tests := []struct { + name string + s string + prefix string + want bool + }{ + {name: "match", s: "hello world", prefix: "hello", want: true}, + {name: "no match", s: "hello world", prefix: "world", want: false}, + {name: "empty prefix", s: "hello", prefix: "", want: true}, + {name: "longer prefix", s: "hi", prefix: "hello", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := startsWith(tt.s, tt.prefix) + if got != tt.want { + t.Errorf("startsWith(%q, %q) = %v, want %v", tt.s, tt.prefix, got, tt.want) + } + }) + } +} + +func TestLoadAgentsDir(t *testing.T) { + t.Run("valid temp structure with agent", func(t *testing.T) { + tmp := t.TempDir() + + // Create agents/test-agent/agent.yaml + agentDir := filepath.Join(tmp, "agents", "test-agent") + if err := os.MkdirAll(agentDir, 0755); err != nil { + t.Fatal(err) + } + agentYAML := `name: test-agent +description: A test agent +model: qwen3.5:0.8b +` + if err := os.WriteFile(filepath.Join(agentDir, "agent.yaml"), []byte(agentYAML), 0644); err != nil { + t.Fatal(err) + } + + dir, err := LoadAgentsDir(tmp) + if err != nil { + t.Fatalf("LoadAgentsDir() error: %v", err) + } + if dir.Path != tmp { + t.Errorf("Path = %q, want %q", dir.Path, tmp) + } + if len(dir.Agents) != 1 { + t.Errorf("expected 1 agent, got %d", len(dir.Agents)) + } + agent, ok := dir.Agents["test-agent"] + if !ok { + t.Fatal("expected agent 'test-agent' to exist") + } + if agent.Description != "A test agent" { + t.Errorf("agent description = %q, want %q", agent.Description, "A test agent") + } + }) + + t.Run("empty path uses FindAgentsDir", func(t *testing.T) { + dir, err := LoadAgentsDir("") + if err != nil { + t.Fatalf("LoadAgentsDir('') error: %v", err) + } + // Should return a valid AgentsDir (possibly with no agents) + if dir == nil { + t.Fatal("expected non-nil AgentsDir") + } + if dir.Agents == nil { + t.Error("expected Agents map to be initialized") + } + }) + + t.Run("nonexistent subdirs dont error", func(t *testing.T) { + tmp := t.TempDir() + // Empty temp dir — no agents/, skills/, mcp.json, etc. + dir, err := LoadAgentsDir(tmp) + if err != nil { + t.Fatalf("LoadAgentsDir() error: %v", err) + } + if len(dir.Agents) != 0 { + t.Errorf("expected 0 agents, got %d", len(dir.Agents)) + } + }) +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..558bbc0 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,186 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "strconv" + + "gopkg.in/yaml.v3" +) + +type Config struct { + Ollama OllamaConfig `yaml:"ollama"` + Model ModelConfig `yaml:"model,omitempty"` + Agents AgentsConfig `yaml:"agents,omitempty"` + Servers []ServerConfig `yaml:"servers,omitempty"` + SkillsDir string `yaml:"skills_dir,omitempty"` + ICE ICEConfig `yaml:"ice,omitempty"` + AgentProfile string `yaml:"agent_profile,omitempty"` + Tools ToolsConfig `yaml:"tools,omitempty"` +} + +type AgentsConfig struct { + Dir string `yaml:"dir,omitempty"` + AutoLoad bool `yaml:"auto_load"` +} + +type ToolsConfig struct { + Timeout string `yaml:"timeout,omitempty"` // e.g., "30s", "2m" + MaxGrepResults int `yaml:"max_grep_results,omitempty"` + MaxIterations int `yaml:"max_iterations,omitempty"` +} + +type ICEConfig struct { + Enabled bool `yaml:"enabled"` + EmbedModel string `yaml:"embed_model,omitempty"` + StorePath string `yaml:"store_path,omitempty"` +} + +type OllamaConfig struct { + Model string `yaml:"model"` + BaseURL string `yaml:"base_url"` + NumCtx int `yaml:"num_ctx"` +} + +type ServerConfig struct { + Name string `yaml:"name"` + Command string `yaml:"command,omitempty"` + Args []string `yaml:"args,omitempty"` + Env []string `yaml:"env,omitempty"` + Transport string `yaml:"transport,omitempty"` + URL string `yaml:"url,omitempty"` +} + +func defaults() Config { + modelCfg := DefaultModelConfig() + return Config{ + Ollama: OllamaConfig{ + Model: "qwen3.5:2b", + BaseURL: "http://localhost:11434", + NumCtx: 262144, + }, + Model: modelCfg, + Agents: AgentsConfig{ + Dir: "", + AutoLoad: true, + }, + Tools: ToolsConfig{ + Timeout: "30s", + MaxGrepResults: 500, + MaxIterations: 10, + }, + } +} + +func Load() (*Config, error) { + cfg := defaults() + + localPath := findConfigFile() + if localPath != "" { + data, err := os.ReadFile(localPath) + if err != nil { + return nil, fmt.Errorf("read config %s: %w", localPath, err) + } + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config %s: %w", localPath, err) + } + } + + agentsDir := cfg.Agents.Dir + if agentsDir == "" { + agentsDir = FindAgentsDir() + } + + var agentsData *AgentsDir + if agentsDir != "" && cfg.Agents.AutoLoad { + var err error + agentsData, err = LoadAgentsDir(agentsDir) + if err != nil { + fmt.Fprintf(os.Stderr, "warning: failed to load .agents directory: %v\n", err) + } else { + if agentsData != nil { + if cfg.Ollama.Model == "" { + cfg.Ollama.Model = cfg.Model.DefaultModel + } + + if len(cfg.Servers) == 0 && agentsData.HasMCP() { + cfg.Servers = agentsData.GetMCPServers() + } + } + } + } + + applyEnvOverrides(&cfg) + return &cfg, nil +} + +func LoadWithAgentsDir() (*Config, *AgentsDir, error) { + cfg, err := Load() + if err != nil { + return nil, nil, err + } + + agentsDir := cfg.Agents.Dir + if agentsDir == "" { + agentsDir = FindAgentsDir() + } + var agents *AgentsDir + if agentsDir != "" && cfg.Agents.AutoLoad { + agents, _ = LoadAgentsDir(agentsDir) + } + + return cfg, agents, nil +} + +func findConfigFile() string { + candidates := []string{ + "ai-agent.yaml", + "ai-agent.yml", + "config.yaml", + "config.yml", + } + if home, err := os.UserHomeDir(); err == nil { + candidates = append(candidates, + filepath.Join(home, ".config", "ai-agent", "config.yaml"), + filepath.Join(home, ".config", "ai-agent", "config.yml"), + ) + } + for _, path := range candidates { + if _, err := os.Stat(path); err == nil { + return path + } + } + return "" +} + +func applyEnvOverrides(cfg *Config) { + if v := os.Getenv("OLLAMA_HOST"); v != "" { + cfg.Ollama.BaseURL = v + } + if v := os.Getenv("LOCAL_AGENT_MODEL"); v != "" { + cfg.Ollama.Model = v + } + if v := os.Getenv("LOCAL_AGENT_AGENTS_DIR"); v != "" { + cfg.Agents.Dir = v + } + if v := os.Getenv("LOCAL_AGENT_TOOLS_TIMEOUT"); v != "" { + cfg.Tools.Timeout = v + } + if v := os.Getenv("LOCAL_AGENT_TOOLS_MAX_GREP"); v != "" { + cfg.Tools.MaxGrepResults = parseEnvInt(v, cfg.Tools.MaxGrepResults) + } + if v := os.Getenv("LOCAL_AGENT_TOOLS_MAX_ITER"); v != "" { + cfg.Tools.MaxIterations = parseEnvInt(v, cfg.Tools.MaxIterations) + } + if v := os.Getenv("LOCAL_AGENT_ICE_EMBED_MODEL"); v != "" { + cfg.ICE.EmbedModel = v + } +} + +func parseEnvInt(v string, defaultVal int) int { + if i, err := strconv.Atoi(v); err == nil { + return i + } + return defaultVal +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..5af3e75 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,82 @@ +package config + +import "testing" + +func TestDefaults(t *testing.T) { + cfg := defaults() + + tests := []struct { + name string + got string + want string + }{ + {name: "Ollama.Model", got: cfg.Ollama.Model, want: "qwen3.5:2b"}, + {name: "Ollama.BaseURL", got: cfg.Ollama.BaseURL, want: "http://localhost:11434"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.got != tt.want { + t.Errorf("%s = %q, want %q", tt.name, tt.got, tt.want) + } + }) + } + + if cfg.Ollama.NumCtx != 262144 { + t.Errorf("Ollama.NumCtx = %d, want %d", cfg.Ollama.NumCtx, 262144) + } + + if !cfg.Model.AutoSelect { + t.Error("Model.AutoSelect should be true by default") + } +} + +func TestApplyEnvOverrides(t *testing.T) { + tests := []struct { + name string + envKey string + envVal string + checkFn func(cfg *Config) string + want string + }{ + { + name: "OLLAMA_HOST overrides BaseURL", + envKey: "OLLAMA_HOST", + envVal: "http://custom:1234", + checkFn: func(cfg *Config) string { + return cfg.Ollama.BaseURL + }, + want: "http://custom:1234", + }, + { + name: "LOCAL_AGENT_MODEL overrides Model", + envKey: "LOCAL_AGENT_MODEL", + envVal: "custom-model", + checkFn: func(cfg *Config) string { + return cfg.Ollama.Model + }, + want: "custom-model", + }, + { + name: "LOCAL_AGENT_AGENTS_DIR overrides AgentsDir", + envKey: "LOCAL_AGENT_AGENTS_DIR", + envVal: "/custom/agents", + checkFn: func(cfg *Config) string { + return cfg.Agents.Dir + }, + want: "/custom/agents", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv(tt.envKey, tt.envVal) + cfg := defaults() + applyEnvOverrides(&cfg) + got := tt.checkFn(&cfg) + if got != tt.want { + t.Errorf("after setting %s=%q, got %q, want %q", tt.envKey, tt.envVal, got, tt.want) + } + }) + } +} diff --git a/internal/config/ignore.go b/internal/config/ignore.go new file mode 100644 index 0000000..5017390 --- /dev/null +++ b/internal/config/ignore.go @@ -0,0 +1,103 @@ +package config + +import ( + "bufio" + "os" + "path/filepath" + "strings" +) + +// IgnorePatterns holds parsed .agentignore patterns. +type IgnorePatterns struct { + patterns []string + raw string // original file content for injection into system prompt +} + +// LoadIgnoreFile reads and parses an .agentignore file from the given directory. +// Returns nil if no .agentignore file exists (not an error). +func LoadIgnoreFile(dir string) *IgnorePatterns { + path := filepath.Join(dir, ".agentignore") + f, err := os.Open(path) + if err != nil { + return nil + } + defer f.Close() + + var patterns []string + var rawLines []string + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + rawLines = append(rawLines, line) + + trimmed := strings.TrimSpace(line) + // Skip empty lines and comments. + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + patterns = append(patterns, trimmed) + } + + return &IgnorePatterns{ + patterns: patterns, + raw: strings.Join(rawLines, "\n"), + } +} + +// Match returns true if the given path should be ignored. +// Returns false if the receiver is nil. +func (ip *IgnorePatterns) Match(path string) bool { + if ip == nil || len(ip.patterns) == 0 { + return false + } + + // Normalise the path separators and remove trailing slashes for comparison. + path = filepath.ToSlash(path) + cleanPath := strings.TrimSuffix(path, "/") + + for _, pattern := range ip.patterns { + pat := strings.TrimSuffix(pattern, "/") + + // Check each component of the path against the pattern. + // e.g. "node_modules" should match "node_modules", "node_modules/foo", + // and "src/node_modules/bar". + parts := strings.Split(cleanPath, "/") + for _, part := range parts { + if matched, _ := filepath.Match(pat, part); matched { + return true + } + } + + // Also try matching the full path with the pattern (for glob patterns + // that include path separators like "build/output"). + if matched, _ := filepath.Match(pat, cleanPath); matched { + return true + } + + // Prefix match: path starts with the pattern directory. + if strings.HasPrefix(cleanPath, pat+"/") || cleanPath == pat { + return true + } + } + + return false +} + +// Raw returns the raw file content for system prompt injection. +// Returns an empty string if the receiver is nil. +func (ip *IgnorePatterns) Raw() string { + if ip == nil { + return "" + } + return ip.raw +} + +// Patterns returns the list of patterns. +// Returns nil if the receiver is nil. +func (ip *IgnorePatterns) Patterns() []string { + if ip == nil { + return nil + } + return ip.patterns +} diff --git a/internal/config/ignore_test.go b/internal/config/ignore_test.go new file mode 100644 index 0000000..c707c30 --- /dev/null +++ b/internal/config/ignore_test.go @@ -0,0 +1,183 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadIgnoreFile_Valid(t *testing.T) { + dir := t.TempDir() + content := `# Build artifacts +node_modules +*.log +.git +build/ +dist/ +vendor/ +` + if err := os.WriteFile(filepath.Join(dir, ".agentignore"), []byte(content), 0o644); err != nil { + t.Fatal(err) + } + + ip := LoadIgnoreFile(dir) + if ip == nil { + t.Fatal("expected non-nil IgnorePatterns") + } + + wantPatterns := []string{"node_modules", "*.log", ".git", "build/", "dist/", "vendor/"} + if len(ip.Patterns()) != len(wantPatterns) { + t.Fatalf("got %d patterns, want %d", len(ip.Patterns()), len(wantPatterns)) + } + for i, p := range ip.Patterns() { + if p != wantPatterns[i] { + t.Errorf("pattern[%d] = %q, want %q", i, p, wantPatterns[i]) + } + } + + if ip.Raw() != content[:len(content)-1] { // raw joins lines without trailing newline from Join + // Just check it contains the comment and patterns + if ip.Raw() == "" { + t.Error("Raw() should not be empty") + } + } +} + +func TestLoadIgnoreFile_Missing(t *testing.T) { + dir := t.TempDir() + ip := LoadIgnoreFile(dir) + if ip != nil { + t.Error("expected nil for missing .agentignore") + } +} + +func TestLoadIgnoreFile_Empty(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, ".agentignore"), []byte(""), 0o644); err != nil { + t.Fatal(err) + } + + ip := LoadIgnoreFile(dir) + if ip == nil { + t.Fatal("expected non-nil IgnorePatterns for empty file") + } + if len(ip.Patterns()) != 0 { + t.Errorf("expected 0 patterns, got %d", len(ip.Patterns())) + } +} + +func TestLoadIgnoreFile_CommentsOnly(t *testing.T) { + dir := t.TempDir() + content := "# This is a comment\n# Another comment\n\n" + if err := os.WriteFile(filepath.Join(dir, ".agentignore"), []byte(content), 0o644); err != nil { + t.Fatal(err) + } + + ip := LoadIgnoreFile(dir) + if ip == nil { + t.Fatal("expected non-nil IgnorePatterns") + } + if len(ip.Patterns()) != 0 { + t.Errorf("expected 0 patterns for comments-only file, got %d", len(ip.Patterns())) + } +} + +func TestIgnorePatterns_Match_Exact(t *testing.T) { + ip := &IgnorePatterns{ + patterns: []string{"node_modules", ".git", "vendor"}, + } + + tests := []struct { + path string + want bool + }{ + {"node_modules", true}, + {"node_modules/package/index.js", true}, + {".git", true}, + {".git/config", true}, + {"vendor", true}, + {"vendor/lib/foo.go", true}, + {"src/main.go", false}, + {"README.md", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + if got := ip.Match(tt.path); got != tt.want { + t.Errorf("Match(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} + +func TestIgnorePatterns_Match_Glob(t *testing.T) { + ip := &IgnorePatterns{ + patterns: []string{"*.log", "*.tmp"}, + } + + tests := []struct { + path string + want bool + }{ + {"app.log", true}, + {"debug.log", true}, + {"temp.tmp", true}, + {"logs/app.log", true}, + {"main.go", false}, + {"log.txt", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + if got := ip.Match(tt.path); got != tt.want { + t.Errorf("Match(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} + +func TestIgnorePatterns_Match_DirectoryPattern(t *testing.T) { + ip := &IgnorePatterns{ + patterns: []string{"build/", "dist/"}, + } + + tests := []struct { + path string + want bool + }{ + {"build", true}, + {"build/output.js", true}, + {"dist", true}, + {"dist/bundle.js", true}, + {"src/build.go", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + if got := ip.Match(tt.path); got != tt.want { + t.Errorf("Match(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} + +func TestIgnorePatterns_Match_NilReceiver(t *testing.T) { + var ip *IgnorePatterns + if ip.Match("anything") { + t.Error("nil IgnorePatterns should not match anything") + } +} + +func TestIgnorePatterns_Raw_NilReceiver(t *testing.T) { + var ip *IgnorePatterns + if ip.Raw() != "" { + t.Error("nil IgnorePatterns Raw() should return empty string") + } +} + +func TestIgnorePatterns_Patterns_NilReceiver(t *testing.T) { + var ip *IgnorePatterns + if ip.Patterns() != nil { + t.Error("nil IgnorePatterns Patterns() should return nil") + } +} diff --git a/internal/config/models.go b/internal/config/models.go new file mode 100644 index 0000000..aa54569 --- /dev/null +++ b/internal/config/models.go @@ -0,0 +1,167 @@ +package config + +import "fmt" + +type ModelFamily string + +const ( + FamilyQwen3 ModelFamily = "qwen3" + FamilyQwen35 ModelFamily = "qwen3.5" + FamilyLlama ModelFamily = "llama" + FamilyMistral ModelFamily = "mistral" +) + +type ModelCapability int + +const ( + CapabilitySimple ModelCapability = iota + CapabilityMedium + CapabilityComplex + CapabilityAdvanced +) + +type Model struct { + Name string `yaml:"name"` + Family ModelFamily `yaml:"family"` + DisplayName string `yaml:"display_name"` + Size string `yaml:"size"` + Parameters string `yaml:"parameters"` + ContextSize int `yaml:"context_size"` + Capability ModelCapability `yaml:"capability"` + Speed float64 `yaml:"speed"` // 1.0 = baseline + UseCases []string `yaml:"use_cases"` + Description string `yaml:"description"` + Default bool `yaml:"default,omitempty"` +} + +type ModelConfig struct { + Models []Model `yaml:"models"` + DefaultModel string `yaml:"default_model"` + FallbackChain []string `yaml:"fallback_chain"` + AutoSelect bool `yaml:"auto_select"` + EmbedModel string `yaml:"embed_model,omitempty"` +} + +func DefaultModels() []Model { + return []Model{ + { + Name: "qwen3.5:0.8b", + Family: FamilyQwen35, + DisplayName: "Qwen 3.5 0.8B", + Size: "0.8B", + Parameters: "0.8 billion", + ContextSize: 262144, + Capability: CapabilitySimple, + Speed: 4.0, + UseCases: []string{"quick_answers", "simple_tools", "single_file_edits"}, + Description: "Fast, lightweight model for simple tasks and quick answers", + Default: false, + }, + { + Name: "qwen3.5:2b", + Family: FamilyQwen35, + DisplayName: "Qwen 3.5 2B", + Size: "2B", + Parameters: "2 billion", + ContextSize: 262144, + Capability: CapabilityMedium, + Speed: 2.5, + UseCases: []string{"code_completion", "simple_refactoring", "explanations"}, + Description: "Balanced model for medium complexity tasks", + Default: true, + }, + { + Name: "qwen3.5:4b", + Family: FamilyQwen35, + DisplayName: "Qwen 3.5 4B", + Size: "4B", + Parameters: "4 billion", + ContextSize: 262144, + Capability: CapabilityComplex, + Speed: 1.5, + UseCases: []string{"multi_step_reasoning", "code_review", "debugging", "refactoring"}, + Description: "Capable model for complex reasoning and code analysis", + Default: false, + }, + { + Name: "qwen3.5:9b", + Family: FamilyQwen35, + DisplayName: "Qwen 3.5 9B", + Size: "9B", + Parameters: "9 billion", + ContextSize: 262144, + Capability: CapabilityAdvanced, + Speed: 1.0, + UseCases: []string{"complex_reasoning", "architecture", "full_stack", "advanced_debugging"}, + Description: "Full capability model for advanced tasks", + Default: false, + }, + } +} + +func DefaultModelConfig() ModelConfig { + models := DefaultModels() + return ModelConfig{ + Models: models, + DefaultModel: "qwen3.5:2b", + FallbackChain: []string{"qwen3.5:2b", "qwen3.5:0.8b", "qwen3.5:4b", "qwen3.5:9b"}, + AutoSelect: true, + EmbedModel: "nomic-embed-text", + } +} + +func (m *Model) IsSimpleTask() bool { + return m.Capability <= CapabilityMedium +} + +func (m *Model) IsComplexTask() bool { + return m.Capability >= CapabilityComplex +} + +func (mc *ModelConfig) GetModel(name string) (*Model, error) { + for _, m := range mc.Models { + if m.Name == name { + return &m, nil + } + } + return nil, fmt.Errorf("model not found: %s", name) +} + +func (mc *ModelConfig) GetDefaultModel() *Model { + for _, m := range mc.Models { + if m.Default { + return &m + } + } + if len(mc.Models) > 0 { + return &mc.Models[len(mc.Models)-1] + } + return nil +} + +func (mc *ModelConfig) SelectModelForTask(taskComplexity string) string { + if !mc.AutoSelect { + return mc.DefaultModel + } + + switch taskComplexity { + case "simple": + return mc.Models[0].Name + case "medium": + for _, m := range mc.Models { + if m.Capability == CapabilityMedium { + return m.Name + } + } + case "complex": + for _, m := range mc.Models { + if m.Capability == CapabilityComplex { + return m.Name + } + } + case "advanced": + return mc.DefaultModel + } + + return mc.DefaultModel +} diff --git a/internal/config/models_test.go b/internal/config/models_test.go new file mode 100644 index 0000000..23ede07 --- /dev/null +++ b/internal/config/models_test.go @@ -0,0 +1,159 @@ +package config + +import "testing" + +func TestModel_IsSimpleTask(t *testing.T) { + tests := []struct { + name string + capability ModelCapability + want bool + }{ + {name: "CapabilitySimple is simple", capability: CapabilitySimple, want: true}, + {name: "CapabilityMedium is simple", capability: CapabilityMedium, want: true}, + {name: "CapabilityComplex is not simple", capability: CapabilityComplex, want: false}, + {name: "CapabilityAdvanced is not simple", capability: CapabilityAdvanced, want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &Model{Capability: tt.capability} + if got := m.IsSimpleTask(); got != tt.want { + t.Errorf("Model{Capability: %d}.IsSimpleTask() = %v, want %v", tt.capability, got, tt.want) + } + }) + } +} + +func TestModel_IsComplexTask(t *testing.T) { + tests := []struct { + name string + capability ModelCapability + want bool + }{ + {name: "CapabilitySimple is not complex", capability: CapabilitySimple, want: false}, + {name: "CapabilityMedium is not complex", capability: CapabilityMedium, want: false}, + {name: "CapabilityComplex is complex", capability: CapabilityComplex, want: true}, + {name: "CapabilityAdvanced is complex", capability: CapabilityAdvanced, want: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &Model{Capability: tt.capability} + if got := m.IsComplexTask(); got != tt.want { + t.Errorf("Model{Capability: %d}.IsComplexTask() = %v, want %v", tt.capability, got, tt.want) + } + }) + } +} + +func TestModelConfig_GetModel(t *testing.T) { + cfg := DefaultModelConfig() + + tests := []struct { + name string + model string + wantErr bool + }{ + {name: "found model", model: "qwen3.5:0.8b", wantErr: false}, + {name: "not found", model: "nonexistent", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := cfg.GetModel(tt.model) + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if got.Name != tt.model { + t.Errorf("GetModel(%q).Name = %q, want %q", tt.model, got.Name, tt.model) + } + } + }) + } +} + +func TestModelConfig_GetDefaultModel(t *testing.T) { + tests := []struct { + name string + cfg ModelConfig + want string // empty means nil expected + }{ + { + name: "model with Default=true", + cfg: ModelConfig{ + Models: []Model{ + {Name: "a", Default: false}, + {Name: "b", Default: true}, + {Name: "c", Default: false}, + }, + }, + want: "b", + }, + { + name: "no default returns last", + cfg: ModelConfig{ + Models: []Model{ + {Name: "a", Default: false}, + {Name: "b", Default: false}, + }, + }, + want: "b", + }, + { + name: "empty slice returns nil", + cfg: ModelConfig{Models: []Model{}}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.cfg.GetDefaultModel() + if tt.want == "" { + if got != nil { + t.Errorf("expected nil, got %+v", got) + } + } else { + if got == nil { + t.Fatal("expected non-nil model, got nil") + } + if got.Name != tt.want { + t.Errorf("GetDefaultModel().Name = %q, want %q", got.Name, tt.want) + } + } + }) + } +} + +func TestModelConfig_SelectModelForTask(t *testing.T) { + cfg := DefaultModelConfig() + + tests := []struct { + name string + complexity string + autoSelect bool + want string + }{ + {name: "auto simple", complexity: "simple", autoSelect: true, want: "qwen3.5:0.8b"}, + {name: "auto medium", complexity: "medium", autoSelect: true, want: "qwen3.5:2b"}, + {name: "auto complex", complexity: "complex", autoSelect: true, want: "qwen3.5:4b"}, + {name: "auto advanced", complexity: "advanced", autoSelect: true, want: cfg.DefaultModel}, + {name: "no autoselect simple", complexity: "simple", autoSelect: false, want: cfg.DefaultModel}, + {name: "no autoselect complex", complexity: "complex", autoSelect: false, want: cfg.DefaultModel}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg.AutoSelect = tt.autoSelect + got := cfg.SelectModelForTask(tt.complexity) + if got != tt.want { + t.Errorf("SelectModelForTask(%q) = %q, want %q", tt.complexity, got, tt.want) + } + }) + } +} diff --git a/internal/config/qwen_router.go b/internal/config/qwen_router.go new file mode 100644 index 0000000..7271a36 --- /dev/null +++ b/internal/config/qwen_router.go @@ -0,0 +1,364 @@ +package config + +import ( + "context" + "strings" + "sync" + "time" +) + +type QwenModelRouter struct { + config *ModelConfig + overrideLog []ModelOverride + modeContext ModeContext + mu sync.RWMutex +} + +type ModeContext int + +const ( + ModeAskContext ModeContext = iota + ModePlanContext + ModeBuildContext +) + +type QwenComplexity string + +const ( + QwenTrivial QwenComplexity = "trivial" + QwenSimple QwenComplexity = "simple" + QwenModerate QwenComplexity = "moderate" + QwenAdvanced QwenComplexity = "advanced" +) + +var ( + qwenTrivialIndicators = []string{ + "what is", "who is", "when is", "where is", + "define", "meaning of", "synonym", "antonym", + "list files", "show me", "display", + "yes", "no", "ok", "thanks", + "hello", "hi", "hey", + } + qwenSimpleIndicators = []string{ + "how do i", "explain", "what does", "why does", + "find", "search", "get", "read", + "print", "echo", "cat", "ls", "grep", + "simple", "quick", "fast", "brief", + "check", "verify", "test", + "create file", "write file", "save", + } + qwenModerateIndicators = []string{ + "create", "generate", "add", "modify", "update", + "fix", "debug", "refactor", "optimize", + "function", "class", "method", "interface", + "test", "unit test", "integration test", + "script", "command", "pipeline", + "compare", "analyze", "review", + "multiple", "several", "across", + } + qwenAdvancedIndicators = []string{ + "architecture", "design pattern", "system design", + "infrastructure", "deployment", "scaling", + "security audit", "performance optimization", + "multi-step", "complex", "comprehensive", + "build a", "implement", "develop", "engineer", + "full stack", "end-to-end", "production", + "migration", "refactor entire", "rewrite", + } + qwenCodePatterns = map[string]QwenComplexity{ + "variable": QwenSimple, + "constant": QwenSimple, + "function": QwenSimple, + "loop": QwenSimple, + "condition": QwenSimple, + "array": QwenSimple, + "slice": QwenSimple, + "map": QwenSimple, + "struct": QwenModerate, + "interface": QwenModerate, + "generics": QwenModerate, + "concurrency": QwenModerate, + "goroutine": QwenModerate, + "channel": QwenModerate, + "mutex": QwenModerate, + "architecture": QwenAdvanced, + "pattern": QwenAdvanced, + "microservice": QwenAdvanced, + "distributed": QwenAdvanced, + "kubernetes": QwenAdvanced, + } +) + +func NewQwenModelRouter(cfg *ModelConfig) *QwenModelRouter { + return &QwenModelRouter{ + config: cfg, + overrideLog: make([]ModelOverride, 0), + modeContext: ModeAskContext, + } +} + +func (r *QwenModelRouter) SetModeContext(mode ModeContext) { + r.mu.Lock() + defer r.mu.Unlock() + r.modeContext = mode +} + +func (r *QwenModelRouter) ClassifyTaskComplexity(query string) QwenComplexity { + return classifyQwenTask(query, r.modeContext) +} + +func (r *QwenModelRouter) SelectModel(query string) string { + complexity := r.ClassifyTaskComplexity(query) + return r.config.SelectModelForTask(string(complexity)) +} + +func (r *QwenModelRouter) SelectModelForMode(query string, mode ModeContext) string { + switch mode { + case ModeAskContext: + return r.selectAskModel(query) + case ModePlanContext: + return r.selectPlanModel(query) + case ModeBuildContext: + return r.selectBuildModel(query) + } + return r.SelectModel(query) +} + +func (r *QwenModelRouter) selectAskModel(query string) string { + complexity := classifyQwenTask(query, ModeAskContext) + switch complexity { + case QwenTrivial, QwenSimple: + if r.isModelAvailable("qwen3.5:0.8b") { + return "qwen3.5:0.8b" + } + return "qwen3.5:2b" + case QwenModerate: + return "qwen3.5:2b" + case QwenAdvanced: + return "qwen3.5:4b" + default: + return "qwen3.5:2b" + } +} + +func (r *QwenModelRouter) selectPlanModel(query string) string { + complexity := classifyQwenTask(query, ModePlanContext) + switch complexity { + case QwenTrivial, QwenSimple: + return "qwen3.5:2b" + case QwenModerate: + return "qwen3.5:4b" + case QwenAdvanced: + return "qwen3.5:9b" + default: + return "qwen3.5:4b" + } +} + +func (r *QwenModelRouter) selectBuildModel(query string) string { + complexity := classifyQwenTask(query, ModeBuildContext) + switch complexity { + case QwenTrivial, QwenSimple: + return "qwen3.5:2b" + case QwenModerate: + return "qwen3.5:4b" + case QwenAdvanced: + return "qwen3.5:9b" + default: + return "qwen3.5:4b" + } +} + +func (r *QwenModelRouter) isModelAvailable(name string) bool { + for _, m := range r.config.Models { + if m.Name == name { + return true + } + } + return false +} + +func classifyQwenTask(query string, mode ModeContext) QwenComplexity { + lowerQuery := strings.ToLower(query) + words := strings.Fields(lowerQuery) + wordCount := len(words) + score := 0 + for _, indicator := range qwenTrivialIndicators { + if strings.Contains(lowerQuery, indicator) { + score -= 4 + } + } + for _, indicator := range qwenSimpleIndicators { + if strings.Contains(lowerQuery, indicator) { + score -= 1 + } + } + for _, indicator := range qwenModerateIndicators { + if strings.Contains(lowerQuery, indicator) { + score += 2 + } + } + for _, indicator := range qwenAdvancedIndicators { + if strings.Contains(lowerQuery, indicator) { + score += 4 + } + } + for pattern, complexity := range qwenCodePatterns { + if strings.Contains(lowerQuery, pattern) { + switch complexity { + case QwenSimple: + score -= 1 + case QwenModerate: + score += 2 + case QwenAdvanced: + score += 4 + } + } + } + if wordCount > 50 { + score += 3 + } else if wordCount > 30 { + score += 1 + } else if wordCount < 5 && score <= 0 { + score -= 2 + } + if strings.Contains(lowerQuery, "why") || strings.Contains(lowerQuery, "reason") { + score += 2 + } + if strings.Contains(lowerQuery, "how") && wordCount > 10 { + score += 1 + } + if strings.Contains(lowerQuery, "?") && wordCount < 10 { + score -= 1 + } + switch mode { + case ModeAskContext: + score -= 1 + case ModeBuildContext: + score += 1 + } + switch { + case score <= -3: + return QwenTrivial + case score <= 1: + return QwenSimple + case score <= 5: + return QwenModerate + default: + return QwenAdvanced + } +} + +func (r *QwenModelRouter) RecordOverride(query, userModel string) { + r.mu.Lock() + defer r.mu.Unlock() + routerModel := r.SelectModel(query) + r.overrideLog = append(r.overrideLog, ModelOverride{ + Query: query, + UserModel: userModel, + RouterModel: routerModel, + Timestamp: time.Now(), + }) + if len(r.overrideLog) > 100 { + r.overrideLog = r.overrideLog[len(r.overrideLog)-100:] + } +} + +func (r *QwenModelRouter) GetLearnedPatterns() map[string]QwenComplexity { + r.mu.RLock() + defer r.mu.RUnlock() + if len(r.overrideLog) < 3 { + return nil + } + wordCounts := make(map[string]map[QwenComplexity]int) + for _, o := range r.overrideLog { + if o.Query == "" || o.UserModel == "" { + continue + } + var complexity QwenComplexity + switch { + case strings.Contains(o.UserModel, "0.8b"): + complexity = QwenTrivial + case strings.Contains(o.UserModel, "2b"): + complexity = QwenSimple + case strings.Contains(o.UserModel, "4b"): + complexity = QwenModerate + case strings.Contains(o.UserModel, "9b"): + complexity = QwenAdvanced + default: + continue + } + words := strings.Fields(strings.ToLower(o.Query)) + for _, w := range words { + if len(w) < 3 { + continue + } + if _, ok := wordCounts[w]; !ok { + wordCounts[w] = make(map[QwenComplexity]int) + } + wordCounts[w][complexity]++ + } + } + wordComplexity := make(map[string]QwenComplexity) + for word, counts := range wordCounts { + var maxCount int + var dominant QwenComplexity + for c, cnt := range counts { + if cnt > maxCount { + maxCount = cnt + dominant = c + } + } + if maxCount >= 2 { + wordComplexity[word] = dominant + } + } + return wordComplexity +} + +func (r *QwenModelRouter) SelectAvailableModelForTask(ctx context.Context, pinger ModelPinger, query string) string { + preferred := r.SelectModel(query) + fallbackOrder := []string{ + preferred, + "qwen3.5:2b", + "qwen3.5:0.8b", + "qwen3.5:4b", + "qwen3.5:9b", + } + for _, model := range fallbackOrder { + if err := pinger.PingModel(ctx, model); err == nil { + return model + } + } + return r.config.DefaultModel +} + +func (r *QwenModelRouter) GetRecommendedModel(query string) (model string, reason string, complexity QwenComplexity) { + r.mu.RLock() + mode := r.modeContext + r.mu.RUnlock() + complexity = classifyQwenTask(query, mode) + switch complexity { + case QwenTrivial: + model = "qwen3.5:0.8b" + reason = "trivial task - ultra-fast response" + case QwenSimple: + model = "qwen3.5:2b" + reason = "simple task - balanced speed/capability" + case QwenModerate: + model = "qwen3.5:4b" + reason = "moderate complexity - multi-step reasoning" + case QwenAdvanced: + model = "qwen3.5:9b" + reason = "advanced task - complex reasoning required" + } + switch mode { + case ModeAskContext: + reason += " (ASK mode - prefer speed)" + case ModePlanContext: + reason += " (PLAN mode - prefer reasoning)" + case ModeBuildContext: + reason += " (BUILD mode - prefer capability)" + } + return model, reason, complexity +} diff --git a/internal/config/qwen_router_test.go b/internal/config/qwen_router_test.go new file mode 100644 index 0000000..aa0aa72 --- /dev/null +++ b/internal/config/qwen_router_test.go @@ -0,0 +1,254 @@ +package config + +import ( + "testing" +) + +func TestQwenRouter_ClassifyTrivial(t *testing.T) { + tests := []struct { + name string + query string + maxComplexity QwenComplexity + }{ + {"simple what", "what is go", QwenTrivial}, + {"simple who", "who created go", QwenSimple}, + {"simple define", "define interface", QwenTrivial}, + {"simple greeting", "hello", QwenTrivial}, + {"simple thanks", "thanks", QwenTrivial}, + {"simple list", "list files", QwenTrivial}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyQwenTask(tt.query, ModeAskContext) + if got > tt.maxComplexity { + t.Errorf("classifyQwenTask(%q) = %v, want <= %v", tt.query, got, tt.maxComplexity) + } + }) + } +} + +func TestQwenRouter_ClassifySimple(t *testing.T) { + tests := []struct { + name string + query string + }{ + {"simple how", "how do i create a file"}, + {"simple explain", "explain this code"}, + {"simple find", "find all go files"}, + {"simple check", "check if file exists"}, + {"simple read", "read config file"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyQwenTask(tt.query, ModeAskContext) + t.Logf("%s: %v", tt.query, got) + }) + } +} + +func TestQwenRouter_ClassifyModerate(t *testing.T) { + tests := []struct { + name string + query string + }{ + {"create function", "create a function to parse json"}, + {"debug issue", "debug this nil pointer error"}, + {"refactor code", "refactor this function"}, + {"add test", "add unit tests for handler"}, + {"optimize query", "optimize this database query"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyQwenTask(tt.query, ModeBuildContext) + t.Logf("%s: %v", tt.query, got) + }) + } +} + +func TestQwenRouter_ClassifyAdvanced(t *testing.T) { + tests := []struct { + name string + query string + }{ + {"architecture", "design microservice architecture"}, + {"system design", "system design for high traffic"}, + {"security audit", "security audit of api"}, + {"full stack", "build a full stack application"}, + {"migration", "migration from mysql to postgres"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyQwenTask(tt.query, ModeBuildContext) + t.Logf("%s: %v", tt.query, got) + }) + } +} + +func TestQwenRouter_ModeAffectsClassification(t *testing.T) { + query := "how do i fix this bug" + + ask := classifyQwenTask(query, ModeAskContext) + build := classifyQwenTask(query, ModeBuildContext) + + // BUILD mode should generally prefer equal or larger models than ASK + // Note: This is a soft requirement - the mode adjustment is subtle + t.Logf("ASK mode: %v, BUILD mode: %v", ask, build) +} + +func TestQwenRouter_WordCountAffectsClassification(t *testing.T) { + short := "what is go" + long := "what is the go programming language and how does it compare to rust and what are its main features and use cases in modern software development" + + shortComplexity := classifyQwenTask(short, ModeAskContext) + longComplexity := classifyQwenTask(long, ModeAskContext) + + // Long query should ideally be more complex, but at minimum not less + // Note: This test documents the behavior - word count does affect scoring + t.Logf("short (%d chars): %v, long (%d chars): %v", len(short), shortComplexity, len(long), longComplexity) +} + +func TestQwenRouter_CodePatterns(t *testing.T) { + tests := []struct { + name string + query string + maxComplexity QwenComplexity + }{ + {"simple variable", "declare a variable", QwenModerate}, + {"simple function", "write a function", QwenAdvanced}, + {"moderate struct", "define a struct", QwenAdvanced}, + {"moderate interface", "implement an interface", QwenAdvanced}, + {"moderate concurrency", "add concurrency with goroutines", QwenAdvanced}, + {"advanced architecture", "design the architecture", QwenAdvanced}, + {"advanced distributed", "distributed system design", QwenAdvanced}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyQwenTask(tt.query, ModeBuildContext) + // All code patterns should classify as something (not panic) + t.Logf("%s: %v", tt.query, got) + }) + } +} + +func TestQwenRouter_SelectAskModel(t *testing.T) { + cfg := DefaultModelConfig() + router := NewQwenModelRouter(&cfg) + router.SetModeContext(ModeAskContext) + + // Simple question should get small model + model := router.SelectModelForMode("what is go", ModeAskContext) + if model != "qwen3.5:0.8b" && model != "qwen3.5:2b" { + t.Errorf("ASK mode simple query should get small model, got %s", model) + } + + // Complex question should get capable model (2B or higher) + model = router.SelectModelForMode("design a distributed system", ModeAskContext) + if model == "qwen3.5:0.8b" { + t.Errorf("ASK mode complex query should not get 0.8B model, got %s", model) + } +} + +func TestQwenRouter_SelectPlanModel(t *testing.T) { + cfg := DefaultModelConfig() + router := NewQwenModelRouter(&cfg) + router.SetModeContext(ModePlanContext) + + // Planning should prefer 4B for reasoning + model := router.SelectModelForMode("plan the architecture", ModePlanContext) + if model != "qwen3.5:4b" && model != "qwen3.5:9b" { + t.Errorf("PLAN mode should prefer 4B or 9B, got %s", model) + } +} + +func TestQwenRouter_SelectBuildModel(t *testing.T) { + cfg := DefaultModelConfig() + router := NewQwenModelRouter(&cfg) + router.SetModeContext(ModeBuildContext) + + // Building should prefer capable models + model := router.SelectModelForMode("implement the feature", ModeBuildContext) + if model != "qwen3.5:4b" && model != "qwen3.5:9b" { + t.Errorf("BUILD mode should prefer 4B or 9B, got %s", model) + } +} + +func TestQwenRouter_GetRecommendedModel(t *testing.T) { + cfg := DefaultModelConfig() + router := NewQwenModelRouter(&cfg) + + model, reason, complexity := router.GetRecommendedModel("what is go") + + if model == "" { + t.Error("GetRecommendedModel should return a model") + } + if reason == "" { + t.Error("GetRecommendedModel should return a reason") + } + if complexity == "" { + t.Error("GetRecommendedModel should return a complexity") + } +} + +func TestQwenRouter_QuestionMarkHandling(t *testing.T) { + // Short questions with ? should be simpler + short := "what is go?" + long := "can you explain what the go programming language is and how it works?" + + shortComplexity := classifyQwenTask(short, ModeAskContext) + longComplexity := classifyQwenTask(long, ModeAskContext) + + if shortComplexity >= longComplexity { + t.Logf("Note: short question complexity (%v) vs long (%v)", shortComplexity, longComplexity) + } +} + +func TestQwenRouter_WhyQuestions(t *testing.T) { + // Why questions need reasoning + why := "why does this code fail" + what := "what does this code do" + + whyComplexity := classifyQwenTask(why, ModeAskContext) + whatComplexity := classifyQwenTask(what, ModeAskContext) + + if whyComplexity < whatComplexity { + t.Errorf("why questions should be more complex: why=%v, what=%v", whyComplexity, whatComplexity) + } +} + +func BenchmarkQwenRouter_ClassifyTask(b *testing.B) { + queries := []string{ + "what is go", + "how do i create a file", + "debug this nil pointer error", + "design microservice architecture", + } + + for i := 0; i < b.N; i++ { + for _, q := range queries { + _ = classifyQwenTask(q, ModeAskContext) + } + } +} + +func BenchmarkQwenRouter_SelectModel(b *testing.B) { + cfg := DefaultModelConfig() + router := NewQwenModelRouter(&cfg) + queries := []string{ + "what is go", + "how do i create a file", + "debug this nil pointer error", + "design microservice architecture", + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, q := range queries { + _ = router.SelectModel(q) + } + } +} diff --git a/internal/config/router.go b/internal/config/router.go new file mode 100644 index 0000000..c322f7f --- /dev/null +++ b/internal/config/router.go @@ -0,0 +1,318 @@ +package config + +import ( + "context" + "strings" + "sync" + "time" +) + +type TaskComplexity string + +const ( + ComplexitySimple TaskComplexity = "simple" + ComplexityMedium TaskComplexity = "medium" + ComplexityComplex TaskComplexity = "complex" + ComplexityAdvanced TaskComplexity = "advanced" +) + +var simpleIndicators = []string{ + "what is", "how do i", "explain", "what does", + "find", "search", "list", "show", "get", + "print", "echo", "read", "cat", "ls", + "simple", "quick", "fast", +} + +var mediumIndicators = []string{ + "create", "write", "generate", "add", "modify", + "change", "update", "fix", "refactor", + "function", "class", "variable", "test", + "script", "command", "file", "directory", +} + +var complexIndicators = []string{ + "debug", "error", "bug", "issue", "problem", + "refactor", "architecture", "design", "review", + "multiple", "several", "across", "migrate", + "optimize", "performance", "security", + "explain why", "analyze", "compare", +} + +var advancedIndicators = []string{ + "build a", "create a", "implement", "develop", + "full stack", "system", "infrastructure", + "multi-step", "complex", "comprehensive", + "security audit", "architecture design", +} + +// ModelPinger is an interface for checking if a model is available. +type ModelPinger interface { + PingModel(ctx context.Context, model string) error +} + +// ModelOverride records when a user explicitly selects a model. +type ModelOverride struct { + Query string + UserModel string + RouterModel string + Timestamp time.Time +} + +type Router struct { + config *ModelConfig + overrideLog []ModelOverride + mu sync.RWMutex +} + +func NewRouter(cfg *ModelConfig) *Router { + return &Router{ + config: cfg, + overrideLog: make([]ModelOverride, 0), + } +} + +func (r *Router) ClassifyTaskComplexity(query string) TaskComplexity { + return ClassifyTask(query) +} + +func (r *Router) SelectModel(query string) string { + complexity := r.ClassifyTaskComplexity(query) + + // Check learned patterns if we have enough data + wordComplexity := r.getLearnedPatterns() + if len(wordComplexity) > 0 { + words := strings.Fields(strings.ToLower(query)) + + // Count votes from learned patterns + complexityVotes := make(map[TaskComplexity]int) + for _, w := range words { + if len(w) >= 3 { // Skip short words + if c, ok := wordComplexity[w]; ok { + complexityVotes[c]++ + } + } + } + + // If strong learned signal (>30% words match a pattern), use it + if len(words) > 0 { + matchRatio := float64(complexityVotes[ComplexitySimple]+complexityVotes[ComplexityAdvanced]) / float64(len(words)) + if matchRatio > 0.3 { + if complexityVotes[ComplexityAdvanced] > complexityVotes[ComplexitySimple] { + complexity = ComplexityAdvanced + } else if complexityVotes[ComplexitySimple] > complexityVotes[ComplexityAdvanced] { + complexity = ComplexitySimple + } + } + } + } + + return r.config.SelectModelForTask(string(complexity)) +} + +// RecordOverride logs when a user explicitly selects a model. +// This helps the router learn from user preferences. +func (r *Router) RecordOverride(query, userModel string) { + r.mu.Lock() + defer r.mu.Unlock() + + routerModel := r.SelectModel(query) + + r.overrideLog = append(r.overrideLog, ModelOverride{ + Query: query, + UserModel: userModel, + RouterModel: routerModel, + Timestamp: time.Now(), + }) + + // Keep last 100 overrides + if len(r.overrideLog) > 100 { + r.overrideLog = r.overrideLog[len(r.overrideLog)-100:] + } +} + +// getLearnedPatterns analyzes override history to find word->complexity mappings. +func (r *Router) getLearnedPatterns() map[string]TaskComplexity { + r.mu.RLock() + defer r.mu.RUnlock() + + if len(r.overrideLog) < 3 { + return nil // Not enough data + } + + wordCounts := make(map[string]map[TaskComplexity]int) + + for _, o := range r.overrideLog { + if o.Query == "" || o.UserModel == "" { + continue + } + + // Determine complexity from user-selected model + var complexity TaskComplexity + switch { + case strings.Contains(o.UserModel, "0.8") || strings.Contains(o.UserModel, "2b"): + complexity = ComplexitySimple + case strings.Contains(o.UserModel, "4b"): + complexity = ComplexityMedium + case strings.Contains(o.UserModel, "9b"): + complexity = ComplexityAdvanced + default: + continue + } + + words := strings.Fields(strings.ToLower(o.Query)) + for _, w := range words { + if len(w) < 3 { + continue // Skip short words + } + if _, ok := wordCounts[w]; !ok { + wordCounts[w] = make(map[TaskComplexity]int) + } + wordCounts[w][complexity]++ + } + } + + // For each word, find dominant complexity + wordComplexity := make(map[string]TaskComplexity) + for word, counts := range wordCounts { + var maxCount int + var dominant TaskComplexity + for c, cnt := range counts { + if cnt > maxCount { + maxCount = cnt + dominant = c + } + } + // Only use if we have enough samples (at least 2 overrides) + if maxCount >= 2 { + wordComplexity[word] = dominant + } + } + + return wordComplexity +} + +func (r *Router) GetFallbackChain(currentModel string) []string { + chain := r.config.FallbackChain + + for i, model := range chain { + if model == currentModel { + return chain[i:] + } + } + + return chain +} + +func (r *Router) GetModelForCapability(capability ModelCapability) string { + for _, m := range r.config.Models { + if m.Capability == capability { + return m.Name + } + } + return r.config.DefaultModel +} + +// SelectAvailableModel returns the first available model from the fallback chain. +// It checks each model in order and returns the first one that responds to a ping. +// If no models are available, returns the default model. +func (r *Router) SelectAvailableModel(ctx context.Context, pinger ModelPinger) string { + chain := r.config.FallbackChain + + for _, model := range chain { + if err := pinger.PingModel(ctx, model); err == nil { + return model + } + } + + // Fallback to default if none available + return r.config.DefaultModel +} + +// SelectAvailableModelForTask returns the first available model for the given task complexity. +// It prioritizes models appropriate for the task, then falls back to larger models if unavailable. +func (r *Router) SelectAvailableModelForTask(ctx context.Context, pinger ModelPinger, query string) string { + // First, get the preferred model for this task + preferred := r.SelectModel(query) + + // Check if preferred model is available + if err := pinger.PingModel(ctx, preferred); err == nil { + return preferred + } + + // Try fallback chain + chain := r.GetFallbackChain(preferred) + for _, model := range chain { + if err := pinger.PingModel(ctx, model); err == nil { + return model + } + } + + // Last resort: default model + return r.config.DefaultModel +} + +func (r *Router) ForceModel(name string) (*Model, error) { + return r.config.GetModel(name) +} + +func (r *Router) ListModels() []Model { + return r.config.Models +} + +func (r *Router) GetDefaultModel() string { + return r.config.DefaultModel +} + +func ClassifyTask(query string) TaskComplexity { + lowerQuery := strings.ToLower(query) + wordCount := len(strings.Fields(query)) + + score := 0 + + for _, indicator := range simpleIndicators { + if strings.Contains(lowerQuery, indicator) { + score -= 2 + } + } + + for _, indicator := range mediumIndicators { + if strings.Contains(lowerQuery, indicator) { + score += 1 + } + } + + for _, indicator := range complexIndicators { + if strings.Contains(lowerQuery, indicator) { + score += 2 + } + } + + for _, indicator := range advancedIndicators { + if strings.Contains(lowerQuery, indicator) { + score += 3 + } + } + + if wordCount > 50 { + score += 2 + } + + if strings.Contains(lowerQuery, "why") || strings.Contains(lowerQuery, "reason") { + score += 1 + } + + if strings.Contains(lowerQuery, "how") && wordCount > 10 { + score += 1 + } + + switch { + case score <= -2: + return ComplexitySimple + case score <= 1: + return ComplexityMedium + case score <= 4: + return ComplexityComplex + default: + return ComplexityAdvanced + } +} diff --git a/internal/config/router_test.go b/internal/config/router_test.go new file mode 100644 index 0000000..df9ef1c --- /dev/null +++ b/internal/config/router_test.go @@ -0,0 +1,166 @@ +package config + +import ( + "strings" + "testing" +) + +func TestClassifyTask(t *testing.T) { + tests := []struct { + name string + query string + want TaskComplexity + }{ + {name: "empty query", query: "", want: ComplexityMedium}, + {name: "simple what is", query: "what is Go", want: ComplexitySimple}, + + // "create a function": medium "create" +1, "function" +1, advanced "create a" +3 = 5 → advanced + {name: "create a function is advanced due to overlaps", query: "create a function", want: ComplexityAdvanced}, + + // "debug this error across multiple files": complex "debug" +2, "error" +2, "bug" +2 (substring of debug), + // "multiple" +2, "across" +2 = 10, medium "file" +1 = 11 → advanced + {name: "debug across files is advanced", query: "debug this error across multiple files", want: ComplexityAdvanced}, + + // "implement a full stack system with infrastructure": advanced "implement" +3, "full stack" +3, "system" +3, + // "infrastructure" +3 = 12 → advanced + {name: "advanced full stack system", query: "implement a full stack system with infrastructure", want: ComplexityAdvanced}, + + // Boundary: "explain" → simple -2 → score -2 → simple + {name: "boundary simple score -2", query: "explain", want: ComplexitySimple}, + + // No indicators → score 0 → medium + {name: "boundary medium score 0", query: "hello world", want: ComplexityMedium}, + + // "create" → medium +1, but also matches advanced "create a"? No, "create" doesn't contain "create a". + // So just +1 → medium + {name: "boundary medium score 1", query: "create", want: ComplexityMedium}, + + // "debug" alone: complex "debug" +2, "bug" +2 (substring) = 4 → complex + {name: "debug alone is complex", query: "debug", want: ComplexityComplex}, + + // "debug error": "debug" +2, "error" +2, "bug" +2 (substring of debug) = 6 → advanced + {name: "debug error is advanced", query: "debug error", want: ComplexityAdvanced}, + + // Word count >50 bonus (+2) with "debug": "debug" +2, "bug" +2 = 4, +2 word bonus = 6 → advanced + { + name: "word count bonus over 50 with debug", + query: strings.Repeat("word ", 51) + "debug", + want: ComplexityAdvanced, + }, + + // "why does this happen": "why" +1 = 1 → medium + {name: "why bonus", query: "why does this happen", want: ComplexityMedium}, + + // "reason for the crash": "reason" +1 = 1 → medium + {name: "reason bonus", query: "reason for the crash", want: ComplexityMedium}, + + // "how about we think...": no indicators, "how" + >10 words +1 = 1 → medium + {name: "how with many words", query: "how about we think about the things that are happening right now in the code base", want: ComplexityMedium}, + + // Case insensitivity + {name: "case insensitive WHAT IS", query: "WHAT IS Go", want: ComplexitySimple}, + {name: "case insensitive EXPLAIN", query: "EXPLAIN this code", want: ComplexitySimple}, + + // Pure simple: multiple simple indicators + {name: "multiple simple indicators", query: "what is this simple quick search", want: ComplexitySimple}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ClassifyTask(tt.query) + if got != tt.want { + t.Errorf("ClassifyTask(%q) = %q, want %q", tt.query, got, tt.want) + } + }) + } +} + +func TestRouter_GetFallbackChain(t *testing.T) { + cfg := &ModelConfig{ + FallbackChain: []string{"a", "b", "c", "d"}, + } + r := NewRouter(cfg) + + tests := []struct { + name string + model string + wantLen int + wantAll bool // true means expect full chain + }{ + {name: "found at start", model: "a", wantLen: 4}, + {name: "found in middle", model: "c", wantLen: 2}, + {name: "found at end", model: "d", wantLen: 1}, + {name: "not found returns full chain", model: "unknown", wantLen: 4, wantAll: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := r.GetFallbackChain(tt.model) + if len(got) != tt.wantLen { + t.Errorf("GetFallbackChain(%q) returned %d items, want %d", tt.model, len(got), tt.wantLen) + } + if tt.wantAll && got[0] != "a" { + t.Errorf("GetFallbackChain(%q) first element = %q, want %q", tt.model, got[0], "a") + } + }) + } +} + +func TestRouter_GetModelForCapability(t *testing.T) { + cfg := &ModelConfig{ + Models: []Model{ + {Name: "fast", Capability: CapabilitySimple}, + {Name: "mid", Capability: CapabilityMedium}, + {Name: "big", Capability: CapabilityComplex}, + }, + DefaultModel: "fallback", + } + r := NewRouter(cfg) + + tests := []struct { + name string + capability ModelCapability + want string + }{ + {name: "match simple", capability: CapabilitySimple, want: "fast"}, + {name: "match medium", capability: CapabilityMedium, want: "mid"}, + {name: "match complex", capability: CapabilityComplex, want: "big"}, + {name: "no match returns default", capability: CapabilityAdvanced, want: "fallback"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := r.GetModelForCapability(tt.capability) + if got != tt.want { + t.Errorf("GetModelForCapability(%d) = %q, want %q", tt.capability, got, tt.want) + } + }) + } +} + +func TestRouter_SelectModel(t *testing.T) { + cfg := DefaultModelConfig() + r := NewRouter(&cfg) + + tests := []struct { + name string + query string + want string + }{ + // "what is Go" → simple → first model + {name: "simple query selects first model", query: "what is Go", want: cfg.Models[0].Name}, + // "debug" → complex → complex-capable model + {name: "complex query selects complex model", query: "debug", want: "qwen3.5:4b"}, + // "implement a system" → advanced → DefaultModel + {name: "advanced query selects default model", query: "implement a full stack system", want: cfg.DefaultModel}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := r.SelectModel(tt.query) + if got != tt.want { + t.Errorf("SelectModel(%q) = %q, want %q", tt.query, got, tt.want) + } + }) + } +} diff --git a/internal/db/db.go b/internal/db/db.go new file mode 100644 index 0000000..3c038a4 --- /dev/null +++ b/internal/db/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/db/migrations/001_init.sql b/internal/db/migrations/001_init.sql new file mode 100644 index 0000000..ba3f611 --- /dev/null +++ b/internal/db/migrations/001_init.sql @@ -0,0 +1,58 @@ +-- Sessions table: replaces noted CLI dependency for session persistence. +CREATE TABLE IF NOT EXISTS sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + title TEXT NOT NULL DEFAULT '', + model TEXT NOT NULL DEFAULT '', + mode TEXT NOT NULL DEFAULT 'BUILD', + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')) +); + +-- Session messages: individual chat entries within a session. +CREATE TABLE IF NOT EXISTS session_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, + role TEXT NOT NULL, -- 'user', 'assistant', 'tool', 'system', 'error' + content TEXT NOT NULL DEFAULT '', + tool_name TEXT NOT NULL DEFAULT '', + tool_args TEXT NOT NULL DEFAULT '', + is_error INTEGER NOT NULL DEFAULT 0, + thinking TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')) +); + +CREATE INDEX IF NOT EXISTS idx_session_messages_session_id ON session_messages(session_id); + +-- Tool permissions: per-tool allow/deny/always-allow. +CREATE TABLE IF NOT EXISTS tool_permissions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tool_name TEXT NOT NULL UNIQUE, + policy TEXT NOT NULL DEFAULT 'ask', -- 'allow', 'deny', 'ask' + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')) +); + +-- Token usage stats: per-turn tracking. +CREATE TABLE IF NOT EXISTS token_stats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, + turn INTEGER NOT NULL DEFAULT 0, + eval_count INTEGER NOT NULL DEFAULT 0, + prompt_tokens INTEGER NOT NULL DEFAULT 0, + model TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')) +); + +CREATE INDEX IF NOT EXISTS idx_token_stats_session_id ON token_stats(session_id); + +-- File changes: files modified by agent during a session. +CREATE TABLE IF NOT EXISTS file_changes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, + file_path TEXT NOT NULL, + tool_name TEXT NOT NULL DEFAULT '', + added INTEGER NOT NULL DEFAULT 0, + removed INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')) +); + +CREATE INDEX IF NOT EXISTS idx_file_changes_session_id ON file_changes(session_id); diff --git a/internal/db/models.go b/internal/db/models.go new file mode 100644 index 0000000..8a73c40 --- /dev/null +++ b/internal/db/models.go @@ -0,0 +1,53 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package db + +type FileChange struct { + ID int64 `json:"id"` + SessionID int64 `json:"session_id"` + FilePath string `json:"file_path"` + ToolName string `json:"tool_name"` + Added int64 `json:"added"` + Removed int64 `json:"removed"` + CreatedAt string `json:"created_at"` +} + +type Session struct { + ID int64 `json:"id"` + Title string `json:"title"` + Model string `json:"model"` + Mode string `json:"mode"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type SessionMessage struct { + ID int64 `json:"id"` + SessionID int64 `json:"session_id"` + Role string `json:"role"` + Content string `json:"content"` + ToolName string `json:"tool_name"` + ToolArgs string `json:"tool_args"` + IsError int64 `json:"is_error"` + Thinking string `json:"thinking"` + CreatedAt string `json:"created_at"` +} + +type TokenStat struct { + ID int64 `json:"id"` + SessionID int64 `json:"session_id"` + Turn int64 `json:"turn"` + EvalCount int64 `json:"eval_count"` + PromptTokens int64 `json:"prompt_tokens"` + Model string `json:"model"` + CreatedAt string `json:"created_at"` +} + +type ToolPermission struct { + ID int64 `json:"id"` + ToolName string `json:"tool_name"` + Policy string `json:"policy"` + UpdatedAt string `json:"updated_at"` +} diff --git a/internal/db/permissions.sql.go b/internal/db/permissions.sql.go new file mode 100644 index 0000000..c345633 --- /dev/null +++ b/internal/db/permissions.sql.go @@ -0,0 +1,100 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: permissions.sql + +package db + +import ( + "context" +) + +const deleteToolPermission = `-- name: DeleteToolPermission :exec +DELETE FROM tool_permissions WHERE tool_name = ? +` + +func (q *Queries) DeleteToolPermission(ctx context.Context, toolName string) error { + _, err := q.db.ExecContext(ctx, deleteToolPermission, toolName) + return err +} + +const getToolPermission = `-- name: GetToolPermission :one +SELECT id, tool_name, policy, updated_at FROM tool_permissions WHERE tool_name = ? +` + +func (q *Queries) GetToolPermission(ctx context.Context, toolName string) (ToolPermission, error) { + row := q.db.QueryRowContext(ctx, getToolPermission, toolName) + var i ToolPermission + err := row.Scan( + &i.ID, + &i.ToolName, + &i.Policy, + &i.UpdatedAt, + ) + return i, err +} + +const listToolPermissions = `-- name: ListToolPermissions :many +SELECT id, tool_name, policy, updated_at FROM tool_permissions ORDER BY tool_name ASC +` + +func (q *Queries) ListToolPermissions(ctx context.Context) ([]ToolPermission, error) { + rows, err := q.db.QueryContext(ctx, listToolPermissions) + if err != nil { + return nil, err + } + defer rows.Close() + items := []ToolPermission{} + for rows.Next() { + var i ToolPermission + if err := rows.Scan( + &i.ID, + &i.ToolName, + &i.Policy, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const resetToolPermissions = `-- name: ResetToolPermissions :exec +DELETE FROM tool_permissions +` + +func (q *Queries) ResetToolPermissions(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, resetToolPermissions) + return err +} + +const upsertToolPermission = `-- name: UpsertToolPermission :one +INSERT INTO tool_permissions (tool_name, policy) +VALUES (?, ?) +ON CONFLICT(tool_name) DO UPDATE SET policy = excluded.policy, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') +RETURNING id, tool_name, policy, updated_at +` + +type UpsertToolPermissionParams struct { + ToolName string `json:"tool_name"` + Policy string `json:"policy"` +} + +func (q *Queries) UpsertToolPermission(ctx context.Context, arg UpsertToolPermissionParams) (ToolPermission, error) { + row := q.db.QueryRowContext(ctx, upsertToolPermission, arg.ToolName, arg.Policy) + var i ToolPermission + err := row.Scan( + &i.ID, + &i.ToolName, + &i.Policy, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/db/queries/permissions.sql b/internal/db/queries/permissions.sql new file mode 100644 index 0000000..3715997 --- /dev/null +++ b/internal/db/queries/permissions.sql @@ -0,0 +1,17 @@ +-- name: GetToolPermission :one +SELECT * FROM tool_permissions WHERE tool_name = ?; + +-- name: UpsertToolPermission :one +INSERT INTO tool_permissions (tool_name, policy) +VALUES (?, ?) +ON CONFLICT(tool_name) DO UPDATE SET policy = excluded.policy, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') +RETURNING *; + +-- name: ListToolPermissions :many +SELECT * FROM tool_permissions ORDER BY tool_name ASC; + +-- name: DeleteToolPermission :exec +DELETE FROM tool_permissions WHERE tool_name = ?; + +-- name: ResetToolPermissions :exec +DELETE FROM tool_permissions; diff --git a/internal/db/queries/sessions.sql b/internal/db/queries/sessions.sql new file mode 100644 index 0000000..0f7018b --- /dev/null +++ b/internal/db/queries/sessions.sql @@ -0,0 +1,27 @@ +-- name: CreateSession :one +INSERT INTO sessions (title, model, mode) VALUES (?, ?, ?) RETURNING *; + +-- name: GetSession :one +SELECT * FROM sessions WHERE id = ?; + +-- name: ListSessions :many +SELECT * FROM sessions ORDER BY updated_at DESC LIMIT ?; + +-- name: UpdateSessionTitle :exec +UPDATE sessions SET title = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ?; + +-- name: UpdateSessionTimestamp :exec +UPDATE sessions SET updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ?; + +-- name: DeleteSession :exec +DELETE FROM sessions WHERE id = ?; + +-- name: CreateSessionMessage :one +INSERT INTO session_messages (session_id, role, content, tool_name, tool_args, is_error, thinking) +VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING *; + +-- name: GetSessionMessages :many +SELECT * FROM session_messages WHERE session_id = ? ORDER BY id ASC; + +-- name: CountSessions :one +SELECT COUNT(*) FROM sessions; diff --git a/internal/db/queries/stats.sql b/internal/db/queries/stats.sql new file mode 100644 index 0000000..7665ac7 --- /dev/null +++ b/internal/db/queries/stats.sql @@ -0,0 +1,31 @@ +-- name: RecordTokenUsage :one +INSERT INTO token_stats (session_id, turn, eval_count, prompt_tokens, model) +VALUES (?, ?, ?, ?, ?) RETURNING *; + +-- name: GetSessionTokenStats :many +SELECT * FROM token_stats WHERE session_id = ? ORDER BY turn ASC; + +-- name: GetSessionTotalTokens :one +SELECT + CAST(COALESCE(SUM(eval_count), 0) AS INTEGER) AS total_eval, + CAST(COALESCE(SUM(prompt_tokens), 0) AS INTEGER) AS total_prompt, + CAST(COUNT(*) AS INTEGER) AS turn_count +FROM token_stats WHERE session_id = ?; + +-- name: RecordFileChange :one +INSERT INTO file_changes (session_id, file_path, tool_name, added, removed) +VALUES (?, ?, ?, ?, ?) RETURNING *; + +-- name: GetSessionFileChanges :many +SELECT * FROM file_changes WHERE session_id = ? ORDER BY created_at ASC; + +-- name: GetSessionFileChangeSummary :many +SELECT + file_path, + CAST(COALESCE(SUM(added), 0) AS INTEGER) AS total_added, + CAST(COALESCE(SUM(removed), 0) AS INTEGER) AS total_removed, + CAST(COUNT(*) AS INTEGER) AS change_count +FROM file_changes +WHERE session_id = ? +GROUP BY file_path +ORDER BY file_path ASC; diff --git a/internal/db/sessions.sql.go b/internal/db/sessions.sql.go new file mode 100644 index 0000000..5402a8b --- /dev/null +++ b/internal/db/sessions.sql.go @@ -0,0 +1,206 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: sessions.sql + +package db + +import ( + "context" +) + +const countSessions = `-- name: CountSessions :one +SELECT COUNT(*) FROM sessions +` + +func (q *Queries) CountSessions(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, countSessions) + var count int64 + err := row.Scan(&count) + return count, err +} + +const createSession = `-- name: CreateSession :one +INSERT INTO sessions (title, model, mode) VALUES (?, ?, ?) RETURNING id, title, model, mode, created_at, updated_at +` + +type CreateSessionParams struct { + Title string `json:"title"` + Model string `json:"model"` + Mode string `json:"mode"` +} + +func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { + row := q.db.QueryRowContext(ctx, createSession, arg.Title, arg.Model, arg.Mode) + var i Session + err := row.Scan( + &i.ID, + &i.Title, + &i.Model, + &i.Mode, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const createSessionMessage = `-- name: CreateSessionMessage :one +INSERT INTO session_messages (session_id, role, content, tool_name, tool_args, is_error, thinking) +VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING id, session_id, role, content, tool_name, tool_args, is_error, thinking, created_at +` + +type CreateSessionMessageParams struct { + SessionID int64 `json:"session_id"` + Role string `json:"role"` + Content string `json:"content"` + ToolName string `json:"tool_name"` + ToolArgs string `json:"tool_args"` + IsError int64 `json:"is_error"` + Thinking string `json:"thinking"` +} + +func (q *Queries) CreateSessionMessage(ctx context.Context, arg CreateSessionMessageParams) (SessionMessage, error) { + row := q.db.QueryRowContext(ctx, createSessionMessage, + arg.SessionID, + arg.Role, + arg.Content, + arg.ToolName, + arg.ToolArgs, + arg.IsError, + arg.Thinking, + ) + var i SessionMessage + err := row.Scan( + &i.ID, + &i.SessionID, + &i.Role, + &i.Content, + &i.ToolName, + &i.ToolArgs, + &i.IsError, + &i.Thinking, + &i.CreatedAt, + ) + return i, err +} + +const deleteSession = `-- name: DeleteSession :exec +DELETE FROM sessions WHERE id = ? +` + +func (q *Queries) DeleteSession(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, deleteSession, id) + return err +} + +const getSession = `-- name: GetSession :one +SELECT id, title, model, mode, created_at, updated_at FROM sessions WHERE id = ? +` + +func (q *Queries) GetSession(ctx context.Context, id int64) (Session, error) { + row := q.db.QueryRowContext(ctx, getSession, id) + var i Session + err := row.Scan( + &i.ID, + &i.Title, + &i.Model, + &i.Mode, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getSessionMessages = `-- name: GetSessionMessages :many +SELECT id, session_id, role, content, tool_name, tool_args, is_error, thinking, created_at FROM session_messages WHERE session_id = ? ORDER BY id ASC +` + +func (q *Queries) GetSessionMessages(ctx context.Context, sessionID int64) ([]SessionMessage, error) { + rows, err := q.db.QueryContext(ctx, getSessionMessages, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []SessionMessage{} + for rows.Next() { + var i SessionMessage + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.Role, + &i.Content, + &i.ToolName, + &i.ToolArgs, + &i.IsError, + &i.Thinking, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listSessions = `-- name: ListSessions :many +SELECT id, title, model, mode, created_at, updated_at FROM sessions ORDER BY updated_at DESC LIMIT ? +` + +func (q *Queries) ListSessions(ctx context.Context, limit int64) ([]Session, error) { + rows, err := q.db.QueryContext(ctx, listSessions, limit) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Session{} + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Title, + &i.Model, + &i.Mode, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateSessionTimestamp = `-- name: UpdateSessionTimestamp :exec +UPDATE sessions SET updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ? +` + +func (q *Queries) UpdateSessionTimestamp(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, updateSessionTimestamp, id) + return err +} + +const updateSessionTitle = `-- name: UpdateSessionTitle :exec +UPDATE sessions SET title = ?, updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') WHERE id = ? +` + +type UpdateSessionTitleParams struct { + Title string `json:"title"` + ID int64 `json:"id"` +} + +func (q *Queries) UpdateSessionTitle(ctx context.Context, arg UpdateSessionTitleParams) error { + _, err := q.db.ExecContext(ctx, updateSessionTitle, arg.Title, arg.ID) + return err +} diff --git a/internal/db/sqlc.yaml b/internal/db/sqlc.yaml new file mode 100644 index 0000000..f99c8da --- /dev/null +++ b/internal/db/sqlc.yaml @@ -0,0 +1,11 @@ +version: "2" +sql: + - engine: "sqlite" + queries: "queries" + schema: "migrations" + gen: + go: + package: "db" + out: "." + emit_json_tags: true + emit_empty_slices: true diff --git a/internal/db/stats.sql.go b/internal/db/stats.sql.go new file mode 100644 index 0000000..3152121 --- /dev/null +++ b/internal/db/stats.sql.go @@ -0,0 +1,216 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: stats.sql + +package db + +import ( + "context" +) + +const getSessionFileChangeSummary = `-- name: GetSessionFileChangeSummary :many +SELECT + file_path, + CAST(COALESCE(SUM(added), 0) AS INTEGER) AS total_added, + CAST(COALESCE(SUM(removed), 0) AS INTEGER) AS total_removed, + CAST(COUNT(*) AS INTEGER) AS change_count +FROM file_changes +WHERE session_id = ? +GROUP BY file_path +ORDER BY file_path ASC +` + +type GetSessionFileChangeSummaryRow struct { + FilePath string `json:"file_path"` + TotalAdded int64 `json:"total_added"` + TotalRemoved int64 `json:"total_removed"` + ChangeCount int64 `json:"change_count"` +} + +func (q *Queries) GetSessionFileChangeSummary(ctx context.Context, sessionID int64) ([]GetSessionFileChangeSummaryRow, error) { + rows, err := q.db.QueryContext(ctx, getSessionFileChangeSummary, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []GetSessionFileChangeSummaryRow{} + for rows.Next() { + var i GetSessionFileChangeSummaryRow + if err := rows.Scan( + &i.FilePath, + &i.TotalAdded, + &i.TotalRemoved, + &i.ChangeCount, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionFileChanges = `-- name: GetSessionFileChanges :many +SELECT id, session_id, file_path, tool_name, added, removed, created_at FROM file_changes WHERE session_id = ? ORDER BY created_at ASC +` + +func (q *Queries) GetSessionFileChanges(ctx context.Context, sessionID int64) ([]FileChange, error) { + rows, err := q.db.QueryContext(ctx, getSessionFileChanges, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []FileChange{} + for rows.Next() { + var i FileChange + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.FilePath, + &i.ToolName, + &i.Added, + &i.Removed, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionTokenStats = `-- name: GetSessionTokenStats :many +SELECT id, session_id, turn, eval_count, prompt_tokens, model, created_at FROM token_stats WHERE session_id = ? ORDER BY turn ASC +` + +func (q *Queries) GetSessionTokenStats(ctx context.Context, sessionID int64) ([]TokenStat, error) { + rows, err := q.db.QueryContext(ctx, getSessionTokenStats, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []TokenStat{} + for rows.Next() { + var i TokenStat + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.Turn, + &i.EvalCount, + &i.PromptTokens, + &i.Model, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSessionTotalTokens = `-- name: GetSessionTotalTokens :one +SELECT + CAST(COALESCE(SUM(eval_count), 0) AS INTEGER) AS total_eval, + CAST(COALESCE(SUM(prompt_tokens), 0) AS INTEGER) AS total_prompt, + CAST(COUNT(*) AS INTEGER) AS turn_count +FROM token_stats WHERE session_id = ? +` + +type GetSessionTotalTokensRow struct { + TotalEval int64 `json:"total_eval"` + TotalPrompt int64 `json:"total_prompt"` + TurnCount int64 `json:"turn_count"` +} + +func (q *Queries) GetSessionTotalTokens(ctx context.Context, sessionID int64) (GetSessionTotalTokensRow, error) { + row := q.db.QueryRowContext(ctx, getSessionTotalTokens, sessionID) + var i GetSessionTotalTokensRow + err := row.Scan(&i.TotalEval, &i.TotalPrompt, &i.TurnCount) + return i, err +} + +const recordFileChange = `-- name: RecordFileChange :one +INSERT INTO file_changes (session_id, file_path, tool_name, added, removed) +VALUES (?, ?, ?, ?, ?) RETURNING id, session_id, file_path, tool_name, added, removed, created_at +` + +type RecordFileChangeParams struct { + SessionID int64 `json:"session_id"` + FilePath string `json:"file_path"` + ToolName string `json:"tool_name"` + Added int64 `json:"added"` + Removed int64 `json:"removed"` +} + +func (q *Queries) RecordFileChange(ctx context.Context, arg RecordFileChangeParams) (FileChange, error) { + row := q.db.QueryRowContext(ctx, recordFileChange, + arg.SessionID, + arg.FilePath, + arg.ToolName, + arg.Added, + arg.Removed, + ) + var i FileChange + err := row.Scan( + &i.ID, + &i.SessionID, + &i.FilePath, + &i.ToolName, + &i.Added, + &i.Removed, + &i.CreatedAt, + ) + return i, err +} + +const recordTokenUsage = `-- name: RecordTokenUsage :one +INSERT INTO token_stats (session_id, turn, eval_count, prompt_tokens, model) +VALUES (?, ?, ?, ?, ?) RETURNING id, session_id, turn, eval_count, prompt_tokens, model, created_at +` + +type RecordTokenUsageParams struct { + SessionID int64 `json:"session_id"` + Turn int64 `json:"turn"` + EvalCount int64 `json:"eval_count"` + PromptTokens int64 `json:"prompt_tokens"` + Model string `json:"model"` +} + +func (q *Queries) RecordTokenUsage(ctx context.Context, arg RecordTokenUsageParams) (TokenStat, error) { + row := q.db.QueryRowContext(ctx, recordTokenUsage, + arg.SessionID, + arg.Turn, + arg.EvalCount, + arg.PromptTokens, + arg.Model, + ) + var i TokenStat + err := row.Scan( + &i.ID, + &i.SessionID, + &i.Turn, + &i.EvalCount, + &i.PromptTokens, + &i.Model, + &i.CreatedAt, + ) + return i, err +} diff --git a/internal/db/store.go b/internal/db/store.go new file mode 100644 index 0000000..16bb9fc --- /dev/null +++ b/internal/db/store.go @@ -0,0 +1,71 @@ +package db + +import ( + "database/sql" + "embed" + "fmt" + "os" + "path/filepath" + + _ "modernc.org/sqlite" +) + +//go:embed migrations/*.sql +var migrations embed.FS + +type Store struct { + *Queries + db *sql.DB +} + +func Open() (*Store, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("home dir: %w", err) + } + dir := filepath.Join(home, ".config", "ai-agent") + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, fmt.Errorf("create config dir: %w", err) + } + return OpenPath(filepath.Join(dir, "ai-agent.db")) +} + +func OpenPath(path string) (*Store, error) { + conn, err := sql.Open("sqlite", path+"?_journal_mode=WAL&_busy_timeout=5000&_foreign_keys=ON") + if err != nil { + return nil, fmt.Errorf("open db: %w", err) + } + if err := runMigrations(conn); err != nil { + conn.Close() + return nil, fmt.Errorf("migrations: %w", err) + } + return &Store{Queries: New(conn), db: conn}, nil +} + +func (s *Store) Close() error { + return s.db.Close() +} + +func (s *Store) DB() *sql.DB { + return s.db +} + +func runMigrations(conn *sql.DB) error { + entries, err := migrations.ReadDir("migrations") + if err != nil { + return fmt.Errorf("read migrations dir: %w", err) + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + data, err := migrations.ReadFile("migrations/" + entry.Name()) + if err != nil { + return fmt.Errorf("read migration %s: %w", entry.Name(), err) + } + if _, err := conn.Exec(string(data)); err != nil { + return fmt.Errorf("exec migration %s: %w", entry.Name(), err) + } + } + return nil +} diff --git a/internal/db/store_test.go b/internal/db/store_test.go new file mode 100644 index 0000000..4645f71 --- /dev/null +++ b/internal/db/store_test.go @@ -0,0 +1,271 @@ +package db + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +func testStore(t *testing.T) *Store { + t.Helper() + dir := t.TempDir() + s, err := OpenPath(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatalf("open store: %v", err) + } + t.Cleanup(func() { s.Close() }) + return s +} + +func TestOpenAndMigrate(t *testing.T) { + s := testStore(t) + // Verify the store is functional by counting sessions. + ctx := context.Background() + count, err := s.CountSessions(ctx) + if err != nil { + t.Fatalf("count sessions: %v", err) + } + if count != 0 { + t.Fatalf("expected 0 sessions, got %d", count) + } +} + +func TestSessionCRUD(t *testing.T) { + s := testStore(t) + ctx := context.Background() + + // Create. + sess, err := s.CreateSession(ctx, CreateSessionParams{ + Title: "Test Session", + Model: "qwen3.5:4b", + Mode: "BUILD", + }) + if err != nil { + t.Fatalf("create session: %v", err) + } + if sess.Title != "Test Session" { + t.Fatalf("expected title 'Test Session', got %q", sess.Title) + } + + // Read. + got, err := s.GetSession(ctx, sess.ID) + if err != nil { + t.Fatalf("get session: %v", err) + } + if got.Model != "qwen3.5:4b" { + t.Fatalf("expected model 'qwen3.5:4b', got %q", got.Model) + } + + // Create message. + msg, err := s.CreateSessionMessage(ctx, CreateSessionMessageParams{ + SessionID: sess.ID, + Role: "user", + Content: "Hello world", + }) + if err != nil { + t.Fatalf("create message: %v", err) + } + if msg.Content != "Hello world" { + t.Fatalf("expected content 'Hello world', got %q", msg.Content) + } + + // List messages. + msgs, err := s.GetSessionMessages(ctx, sess.ID) + if err != nil { + t.Fatalf("get messages: %v", err) + } + if len(msgs) != 1 { + t.Fatalf("expected 1 message, got %d", len(msgs)) + } + + // Delete. + if err := s.DeleteSession(ctx, sess.ID); err != nil { + t.Fatalf("delete session: %v", err) + } + count, err := s.CountSessions(ctx) + if err != nil { + t.Fatalf("count: %v", err) + } + if count != 0 { + t.Fatalf("expected 0 after delete, got %d", count) + } +} + +func TestToolPermissions(t *testing.T) { + s := testStore(t) + ctx := context.Background() + + // Upsert. + perm, err := s.UpsertToolPermission(ctx, UpsertToolPermissionParams{ + ToolName: "bash", + Policy: "allow", + }) + if err != nil { + t.Fatalf("upsert permission: %v", err) + } + if perm.Policy != "allow" { + t.Fatalf("expected policy 'allow', got %q", perm.Policy) + } + + // Update via upsert. + perm2, err := s.UpsertToolPermission(ctx, UpsertToolPermissionParams{ + ToolName: "bash", + Policy: "deny", + }) + if err != nil { + t.Fatalf("upsert update: %v", err) + } + if perm2.Policy != "deny" { + t.Fatalf("expected policy 'deny', got %q", perm2.Policy) + } + + // List. + perms, err := s.ListToolPermissions(ctx) + if err != nil { + t.Fatalf("list permissions: %v", err) + } + if len(perms) != 1 { + t.Fatalf("expected 1 permission, got %d", len(perms)) + } + + // Reset. + if err := s.ResetToolPermissions(ctx); err != nil { + t.Fatalf("reset: %v", err) + } + perms, err = s.ListToolPermissions(ctx) + if err != nil { + t.Fatalf("list after reset: %v", err) + } + if len(perms) != 0 { + t.Fatalf("expected 0 after reset, got %d", len(perms)) + } +} + +func TestTokenStats(t *testing.T) { + s := testStore(t) + ctx := context.Background() + + sess, err := s.CreateSession(ctx, CreateSessionParams{ + Title: "Stats Test", + Model: "qwen3.5:4b", + Mode: "ASK", + }) + if err != nil { + t.Fatalf("create session: %v", err) + } + + // Record usage. + _, err = s.RecordTokenUsage(ctx, RecordTokenUsageParams{ + SessionID: sess.ID, + Turn: 1, + EvalCount: 100, + PromptTokens: 500, + Model: "qwen3.5:4b", + }) + if err != nil { + t.Fatalf("record usage: %v", err) + } + + _, err = s.RecordTokenUsage(ctx, RecordTokenUsageParams{ + SessionID: sess.ID, + Turn: 2, + EvalCount: 200, + PromptTokens: 600, + Model: "qwen3.5:4b", + }) + if err != nil { + t.Fatalf("record usage 2: %v", err) + } + + // Get totals. + totals, err := s.GetSessionTotalTokens(ctx, sess.ID) + if err != nil { + t.Fatalf("get totals: %v", err) + } + if totals.TotalEval != 300 { + t.Fatalf("expected total_eval 300, got %v", totals.TotalEval) + } + if totals.TotalPrompt != 1100 { + t.Fatalf("expected total_prompt 1100, got %v", totals.TotalPrompt) + } + if totals.TurnCount != 2 { + t.Fatalf("expected turn_count 2, got %v", totals.TurnCount) + } +} + +func TestFileChanges(t *testing.T) { + s := testStore(t) + ctx := context.Background() + + sess, err := s.CreateSession(ctx, CreateSessionParams{ + Title: "Changes Test", + Model: "qwen3.5:4b", + Mode: "BUILD", + }) + if err != nil { + t.Fatalf("create session: %v", err) + } + + _, err = s.RecordFileChange(ctx, RecordFileChangeParams{ + SessionID: sess.ID, + FilePath: "main.go", + ToolName: "write_file", + Added: 10, + Removed: 3, + }) + if err != nil { + t.Fatalf("record change: %v", err) + } + + _, err = s.RecordFileChange(ctx, RecordFileChangeParams{ + SessionID: sess.ID, + FilePath: "main.go", + ToolName: "write_file", + Added: 5, + Removed: 2, + }) + if err != nil { + t.Fatalf("record change 2: %v", err) + } + + summary, err := s.GetSessionFileChangeSummary(ctx, sess.ID) + if err != nil { + t.Fatalf("get summary: %v", err) + } + if len(summary) != 1 { + t.Fatalf("expected 1 file, got %d", len(summary)) + } + if summary[0].TotalAdded != 15 { + t.Fatalf("expected 15 added, got %d", summary[0].TotalAdded) + } +} + +func TestDoubleOpen(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.db") + + s1, err := OpenPath(path) + if err != nil { + t.Fatalf("first open: %v", err) + } + defer s1.Close() + + // Running migrations again should be idempotent (IF NOT EXISTS). + s2, err := OpenPath(path) + if err != nil { + t.Fatalf("second open: %v", err) + } + defer s2.Close() +} + +func TestOpenDefault(t *testing.T) { + if os.Getenv("CI") != "" { + t.Skip("skip in CI to avoid side effects") + } + s, err := Open() + if err != nil { + t.Fatalf("open default: %v", err) + } + s.Close() +} diff --git a/internal/ice/assembler.go b/internal/ice/assembler.go new file mode 100644 index 0000000..0a7c478 --- /dev/null +++ b/internal/ice/assembler.go @@ -0,0 +1,126 @@ +package ice + +import ( + "context" + "fmt" + "strings" + "sync" + + "ai-agent/internal/memory" +) + +type Assembler struct { + embedder *Embedder + convStore *Store + memStore *memory.Store + budgetCfg BudgetConfig + sessionID string +} + +func (a *Assembler) Assemble(ctx context.Context, query string) (string, error) { + budget := a.budgetCfg.Calculate(0) + type convResult struct { + chunks []ContextChunk + err error + } + type memResult struct { + chunks []ContextChunk + } + convCh := make(chan convResult, 1) + memCh := make(chan memResult, 1) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + chunks, err := a.retrieveConversations(ctx, query, budget.Conversation) + convCh <- convResult{chunks: chunks, err: err} + }() + go func() { + defer wg.Done() + chunks := a.retrieveMemories(query, budget.Memory) + memCh <- memResult{chunks: chunks} + }() + wg.Wait() + close(convCh) + close(memCh) + cr := <-convCh + mr := <-memCh + if cr.err != nil { + return "", fmt.Errorf("conversation retrieval: %w", cr.err) + } + return formatContext(cr.chunks, mr.chunks), nil +} + +func (a *Assembler) retrieveConversations(ctx context.Context, query string, tokenBudget int) ([]ContextChunk, error) { + if tokenBudget <= 0 { + return nil, nil + } + queryEmb, err := a.embedder.Embed(ctx, query) + if err != nil { + return nil, err + } + results := a.convStore.Search(queryEmb, a.sessionID, 20) + var chunks []ContextChunk + usedTokens := 0 + for _, r := range results { + tokens := estimateTokens(r.Entry.Content) + if usedTokens+tokens > tokenBudget { + continue + } + chunks = append(chunks, ContextChunk{ + Source: SourceConversation, + Content: r.Entry.Content, + Score: r.Score, + Tokens: tokens, + }) + usedTokens += tokens + } + return chunks, nil +} + +func (a *Assembler) retrieveMemories(query string, tokenBudget int) []ContextChunk { + if a.memStore == nil || tokenBudget <= 0 { + return nil + } + memories := a.memStore.Recall(query, 10) + var chunks []ContextChunk + usedTokens := 0 + for _, m := range memories { + tokens := estimateTokens(m.Content) + if usedTokens+tokens > tokenBudget { + continue + } + content := m.Content + if len(m.Tags) > 0 { + content += " [" + strings.Join(m.Tags, ", ") + "]" + } + chunks = append(chunks, ContextChunk{ + Source: SourceMemory, + Content: content, + Tokens: tokens, + }) + usedTokens += tokens + } + return chunks +} + +func formatContext(convChunks, memChunks []ContextChunk) string { + var sb strings.Builder + if len(convChunks) > 0 { + sb.WriteString("\n## Relevant Past Conversations\n\n") + for _, c := range convChunks { + sb.WriteString("- ") + sb.WriteString(c.Content) + sb.WriteString("\n") + } + } + if len(memChunks) > 0 { + sb.WriteString("\n## Remembered Facts\n\n") + for _, c := range memChunks { + sb.WriteString("- ") + sb.WriteString(c.Content) + sb.WriteString("\n") + } + } + return sb.String() +} diff --git a/internal/ice/assembler_test.go b/internal/ice/assembler_test.go new file mode 100644 index 0000000..4933ade --- /dev/null +++ b/internal/ice/assembler_test.go @@ -0,0 +1,88 @@ +package ice + +import ( + "strings" + "testing" +) + +func TestFormatContext(t *testing.T) { + tests := []struct { + name string + convChunks []ContextChunk + memChunks []ContextChunk + wantConv bool // should contain "Relevant Past Conversations" + wantMem bool // should contain "Remembered Facts" + wantEmpty bool + }{ + { + name: "both conversation and memory chunks", + convChunks: []ContextChunk{ + {Source: SourceConversation, Content: "past chat about Go"}, + }, + memChunks: []ContextChunk{ + {Source: SourceMemory, Content: "user prefers dark mode"}, + }, + wantConv: true, + wantMem: true, + }, + { + name: "conversations only", + convChunks: []ContextChunk{ + {Source: SourceConversation, Content: "previous discussion"}, + }, + memChunks: nil, + wantConv: true, + wantMem: false, + }, + { + name: "memories only", + convChunks: nil, + memChunks: []ContextChunk{ + {Source: SourceMemory, Content: "user name is Alice"}, + }, + wantConv: false, + wantMem: true, + }, + { + name: "both empty", + convChunks: nil, + memChunks: nil, + wantEmpty: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatContext(tt.convChunks, tt.memChunks) + + if tt.wantEmpty { + if got != "" { + t.Errorf("expected empty string, got %q", got) + } + return + } + + hasConv := strings.Contains(got, "Relevant Past Conversations") + hasMem := strings.Contains(got, "Remembered Facts") + + if hasConv != tt.wantConv { + t.Errorf("has conversations section = %v, want %v", hasConv, tt.wantConv) + } + if hasMem != tt.wantMem { + t.Errorf("has memories section = %v, want %v", hasMem, tt.wantMem) + } + + // Verify content is present in output. + for _, c := range tt.convChunks { + if !strings.Contains(got, c.Content) { + t.Errorf("output missing conversation content %q", c.Content) + } + } + for _, c := range tt.memChunks { + if !strings.Contains(got, c.Content) { + t.Errorf("output missing memory content %q", c.Content) + } + } + }) + } +} diff --git a/internal/ice/automemory.go b/internal/ice/automemory.go new file mode 100644 index 0000000..4485657 --- /dev/null +++ b/internal/ice/automemory.go @@ -0,0 +1,76 @@ +package ice + +import ( + "context" + "fmt" + "strings" + + "ai-agent/internal/llm" + "ai-agent/internal/memory" +) + +var autoMemorySystemPrompt = "Extract any important facts, user preferences, decisions, or action items from this exchange.\n" + + "Output one item per line in the format: TYPE: content\n" + + "Where TYPE is one of: FACT, DECISION, PREFERENCE, TODO\n" + + "If there is nothing worth remembering, output exactly: NONE" + +var autoMemoryUserTemplate = "User: %s\nAssistant: %s" + +type AutoMemory struct { + client llm.Client + memStore *memory.Store +} + +func (am *AutoMemory) Detect(ctx context.Context, userMsg, assistantMsg string) error { + if am.memStore == nil { + return nil + } + if len(userMsg) < 20 && len(assistantMsg) < 50 { + return nil + } + prompt := fmt.Sprintf(autoMemoryUserTemplate, userMsg, assistantMsg) + var response strings.Builder + err := am.client.ChatStream(ctx, llm.ChatOptions{ + System: autoMemorySystemPrompt, + Messages: []llm.Message{ + {Role: "user", Content: prompt}, + }, + }, func(chunk llm.StreamChunk) error { + response.WriteString(chunk.Text) + return nil + }) + if err != nil { + return fmt.Errorf("auto-memory LLM call: %w", err) + } + return am.parseAndSave(response.String()) +} + +func (am *AutoMemory) parseAndSave(response string) error { + lines := strings.Split(strings.TrimSpace(response), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || strings.EqualFold(line, "NONE") { + continue + } + parts := strings.SplitN(line, ": ", 2) + if len(parts) != 2 { + continue + } + typeName := strings.TrimSpace(parts[0]) + content := strings.TrimSpace(parts[1]) + if content == "" { + continue + } + tag := strings.ToLower(typeName) + switch tag { + case "fact", "decision", "preference", "todo": + // Valid type. + default: + continue + } + if _, err := am.memStore.Save(content, []string{tag, "auto"}); err != nil { + return fmt.Errorf("save auto-memory: %w", err) + } + } + return nil +} diff --git a/internal/ice/automemory_test.go b/internal/ice/automemory_test.go new file mode 100644 index 0000000..2b62b66 --- /dev/null +++ b/internal/ice/automemory_test.go @@ -0,0 +1,104 @@ +package ice + +import ( + "path/filepath" + "testing" + + "ai-agent/internal/memory" +) + +func TestParseAndSave(t *testing.T) { + tests := []struct { + name string + input string + wantCount int + wantTags [][]string + }{ + { + name: "valid FACT and DECISION lines", + input: "FACT: user likes Go\nDECISION: use postgres", + wantCount: 2, + wantTags: [][]string{{"fact", "auto"}, {"decision", "auto"}}, + }, + { + name: "NONE saves nothing", + input: "NONE", + wantCount: 0, + }, + { + name: "empty lines are skipped", + input: "\n\n\n", + wantCount: 0, + }, + { + name: "invalid type is skipped", + input: "UNKNOWN: something", + wantCount: 0, + }, + { + name: "missing colon format is skipped", + input: "this has no colon", + wantCount: 0, + }, + { + name: "PREFERENCE type", + input: "PREFERENCE: dark mode", + wantCount: 1, + wantTags: [][]string{{"preference", "auto"}}, + }, + { + name: "TODO type", + input: "TODO: fix the bug", + wantCount: 1, + wantTags: [][]string{{"todo", "auto"}}, + }, + { + name: "mixed valid and invalid", + input: "FACT: real fact\nBAD: not valid\nTODO: real todo", + wantCount: 2, + wantTags: [][]string{{"fact", "auto"}, {"todo", "auto"}}, + }, + { + name: "empty content after type is skipped", + input: "FACT: ", + wantCount: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + memPath := filepath.Join(dir, "memories.json") + ms := memory.NewStore(memPath) + am := &AutoMemory{memStore: ms} + err := am.parseAndSave(tt.input) + if err != nil { + t.Fatalf("parseAndSave returned error: %v", err) + } + if ms.Count() != tt.wantCount { + t.Errorf("memory count = %d, want %d", ms.Count(), tt.wantCount) + } + if tt.wantTags != nil { + recent := ms.Recent(tt.wantCount) + for i, j := 0, len(recent)-1; i < j; i, j = i+1, j-1 { + recent[i], recent[j] = recent[j], recent[i] + } + for i, wantTags := range tt.wantTags { + if i >= len(recent) { + t.Errorf("missing memory at index %d", i) + continue + } + got := recent[i].Tags + if len(got) != len(wantTags) { + t.Errorf("memory[%d] tags = %v, want %v", i, got, wantTags) + continue + } + for j := range wantTags { + if got[j] != wantTags[j] { + t.Errorf("memory[%d] tag[%d] = %q, want %q", i, j, got[j], wantTags[j]) + } + } + } + } + }) + } +} diff --git a/internal/ice/budget.go b/internal/ice/budget.go new file mode 100644 index 0000000..4f98988 --- /dev/null +++ b/internal/ice/budget.go @@ -0,0 +1,54 @@ +package ice + +// BudgetConfig controls how the context window is divided among sources. +type BudgetConfig struct { + NumCtx int + SystemReserve int // tokens reserved for system prompt + RecentReserve int // tokens reserved for recent conversation + ConversationPct float64 // fraction of remaining budget for past conversations + MemoryPct float64 // fraction of remaining budget for memories + CodePct float64 // fraction of remaining budget for code context +} + +// DefaultBudgetConfig returns sensible defaults for a given context window. +func DefaultBudgetConfig(numCtx int) BudgetConfig { + return BudgetConfig{ + NumCtx: numCtx, + SystemReserve: 1500, + RecentReserve: 2000, + ConversationPct: 0.40, + MemoryPct: 0.20, + CodePct: 0.40, + } +} + +// Calculate allocates token budgets given how many tokens the current prompt uses. +func (bc BudgetConfig) Calculate(promptTokens int) Budget { + // Use 75% of numCtx as total available. + available := int(float64(bc.NumCtx) * 0.75) + available -= bc.SystemReserve + available -= bc.RecentReserve + available -= promptTokens + + if available < 0 { + available = 0 + } + + return Budget{ + Total: available, + System: bc.SystemReserve, + Recent: bc.RecentReserve, + Conversation: int(float64(available) * bc.ConversationPct), + Memory: int(float64(available) * bc.MemoryPct), + Code: int(float64(available) * bc.CodePct), + } +} + +// estimateTokens returns a rough token count for a string (chars / 4). +func estimateTokens(s string) int { + n := len(s) / 4 + if n == 0 && len(s) > 0 { + n = 1 + } + return n +} diff --git a/internal/ice/budget_test.go b/internal/ice/budget_test.go new file mode 100644 index 0000000..fbb6f67 --- /dev/null +++ b/internal/ice/budget_test.go @@ -0,0 +1,133 @@ +package ice + +import "testing" + +func TestBudgetConfig_Calculate(t *testing.T) { + tests := []struct { + name string + cfg BudgetConfig + promptTokens int + wantTotal int + wantConv int + wantMemory int + wantCode int + }{ + { + name: "normal allocation", + cfg: BudgetConfig{ + NumCtx: 8192, + SystemReserve: 1500, + RecentReserve: 2000, + ConversationPct: 0.40, + MemoryPct: 0.20, + CodePct: 0.40, + }, + promptTokens: 500, + // available = int(8192*0.75) - 1500 - 2000 - 500 = 6144 - 4000 = 2144 + wantTotal: 2144, + wantConv: 857, // int(2144 * 0.40) = 857 + wantMemory: 428, // int(2144 * 0.20) = 428 + wantCode: 857, // int(2144 * 0.40) = 857 + }, + { + name: "large prompt clamps to zero", + cfg: BudgetConfig{ + NumCtx: 8192, + SystemReserve: 1500, + RecentReserve: 2000, + ConversationPct: 0.40, + MemoryPct: 0.20, + CodePct: 0.40, + }, + promptTokens: 99999, + wantTotal: 0, + wantConv: 0, + wantMemory: 0, + wantCode: 0, + }, + { + name: "exact boundary available is zero", + cfg: BudgetConfig{ + NumCtx: 8192, + SystemReserve: 1500, + RecentReserve: 2000, + ConversationPct: 0.40, + MemoryPct: 0.20, + CodePct: 0.40, + }, + // int(8192*0.75) - 1500 - 2000 = 2644 + promptTokens: 2644, + wantTotal: 0, + wantConv: 0, + wantMemory: 0, + wantCode: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := tt.cfg.Calculate(tt.promptTokens) + if b.Total != tt.wantTotal { + t.Errorf("Total = %d, want %d", b.Total, tt.wantTotal) + } + if b.Conversation != tt.wantConv { + t.Errorf("Conversation = %d, want %d", b.Conversation, tt.wantConv) + } + if b.Memory != tt.wantMemory { + t.Errorf("Memory = %d, want %d", b.Memory, tt.wantMemory) + } + if b.Code != tt.wantCode { + t.Errorf("Code = %d, want %d", b.Code, tt.wantCode) + } + if b.System != tt.cfg.SystemReserve { + t.Errorf("System = %d, want %d", b.System, tt.cfg.SystemReserve) + } + if b.Recent != tt.cfg.RecentReserve { + t.Errorf("Recent = %d, want %d", b.Recent, tt.cfg.RecentReserve) + } + }) + } +} + +func TestEstimateTokens(t *testing.T) { + tests := []struct { + name string + input string + want int + }{ + { + name: "len/4 heuristic", + input: "hello world", + want: 2, // 11/4 = 2 + }, + { + name: "single char clamps to 1", + input: "a", + want: 1, // 1/4 = 0, clamp to 1 + }, + { + name: "empty string", + input: "", + want: 0, + }, + { + name: "exactly 4 chars", + input: "abcd", + want: 1, // 4/4 = 1 + }, + { + name: "three chars clamps to 1", + input: "abc", + want: 1, // 3/4 = 0, clamp to 1 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := estimateTokens(tt.input) + if got != tt.want { + t.Errorf("estimateTokens(%q) = %d, want %d", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/ice/embed.go b/internal/ice/embed.go new file mode 100644 index 0000000..ec0ee46 --- /dev/null +++ b/internal/ice/embed.go @@ -0,0 +1,56 @@ +package ice + +import ( + "context" + "fmt" + + "ai-agent/internal/llm" +) + +const ( + defaultEmbedModel = "nomic-embed-text" + maxBatchSize = 32 +) + +type Embedder struct { + client llm.Client + model string +} + +func NewEmbedder(client llm.Client, model string) *Embedder { + if model == "" { + model = defaultEmbedModel + } + return &Embedder{client: client, model: model} +} + +func (e *Embedder) Embed(ctx context.Context, text string) ([]float32, error) { + vecs, err := e.EmbedBatch(ctx, []string{text}) + if err != nil { + return nil, err + } + if len(vecs) == 0 { + return nil, fmt.Errorf("empty embedding response") + } + return vecs[0], nil +} + +func (e *Embedder) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) { + if len(texts) == 0 { + return nil, nil + } + var all [][]float32 + for i := 0; i < len(texts); i += maxBatchSize { + end := i + maxBatchSize + if end > len(texts) { + end = len(texts) + } + batch := texts[i:end] + vecs, err := e.client.Embed(ctx, e.model, batch) + if err != nil { + return nil, fmt.Errorf("embed batch [%d:%d]: %w", i, end, err) + } + all = append(all, vecs...) + } + return all, nil +} diff --git a/internal/ice/engine.go b/internal/ice/engine.go new file mode 100644 index 0000000..3b57775 --- /dev/null +++ b/internal/ice/engine.go @@ -0,0 +1,113 @@ +package ice + +import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + "ai-agent/internal/llm" + "ai-agent/internal/memory" +) + +type EngineConfig struct { + EmbedModel string + StorePath string + NumCtx int +} + +type Engine struct { + embedder *Embedder + store *Store + memStore *memory.Store + budgetCfg BudgetConfig + sessionID string + turnIndex int + autoMemory *AutoMemory +} + +func NewEngine(client llm.Client, memStore *memory.Store, cfg EngineConfig) (*Engine, error) { + storePath := cfg.StorePath + if storePath == "" { + home, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("determine home dir: %w", err) + } + storePath = filepath.Join(home, ".config", "ai-agent", "conversations.json") + } + embedModel := cfg.EmbedModel + if embedModel == "" { + embedModel = defaultEmbedModel + } + sessionID := fmt.Sprintf("s_%d", time.Now().UnixNano()) + return &Engine{ + embedder: NewEmbedder(client, embedModel), + store: NewStore(storePath), + memStore: memStore, + budgetCfg: DefaultBudgetConfig(cfg.NumCtx), + sessionID: sessionID, + autoMemory: &AutoMemory{client: client, memStore: memStore}, + }, nil +} + +func (e *Engine) AssembleContext(ctx context.Context, query string) (string, error) { + a := &Assembler{ + embedder: e.embedder, + convStore: e.store, + memStore: e.memStore, + budgetCfg: e.budgetCfg, + sessionID: e.sessionID, + } + return a.Assemble(ctx, query) +} + +func (e *Engine) IndexMessage(ctx context.Context, role, content string) error { + if content == "" { + return nil + } + text := content + if len(text) > 2000 { + text = text[:2000] + } + emb, err := e.embedder.Embed(ctx, text) + if err != nil { + return fmt.Errorf("embed message: %w", err) + } + e.turnIndex++ + _, err = e.store.Add(e.sessionID, role, text, emb, e.turnIndex) + return err +} + +func (e *Engine) IndexSummary(ctx context.Context, summary string) error { + if summary == "" { + return nil + } + emb, err := e.embedder.Embed(ctx, summary) + if err != nil { + return fmt.Errorf("embed summary: %w", err) + } + _, err = e.store.Add(e.sessionID, "summary", summary, emb, e.turnIndex) + return err +} + +func (e *Engine) DetectAutoMemory(ctx context.Context, userMsg, assistantMsg string) { + if e.autoMemory == nil { + return + } + go func() { + _ = e.autoMemory.Detect(ctx, userMsg, assistantMsg) + }() +} + +func (e *Engine) Flush() error { + return e.store.Flush() +} + +func (e *Engine) Store() *Store { + return e.store +} + +func (e *Engine) SessionID() string { + return e.sessionID +} diff --git a/internal/ice/engine_test.go b/internal/ice/engine_test.go new file mode 100644 index 0000000..11d900f --- /dev/null +++ b/internal/ice/engine_test.go @@ -0,0 +1,73 @@ +package ice + +import ( + "testing" +) + +func TestEngineConfigDefaults(t *testing.T) { + // Test embed model default + embedModel := "" + if embedModel == "" { + embedModel = defaultEmbedModel + } + if embedModel != defaultEmbedModel { + t.Errorf("embedModel = %q, want %q", embedModel, defaultEmbedModel) + } + + // Test custom embed model + cfg := EngineConfig{ + EmbedModel: "custom-model", + } + if cfg.EmbedModel != "custom-model" { + t.Errorf("EmbedModel = %q, want %q", cfg.EmbedModel, "custom-model") + } +} + +func TestBudgetConfigCalculate(t *testing.T) { + cfg := DefaultBudgetConfig(16384) + + budget := cfg.Calculate(100) + // 16384 * 0.75 = 12288 + // 12288 - 1500 - 2000 - 100 = 8688 + if budget.Total != 8688 { + t.Errorf("Total = %d, want %d", budget.Total, 8688) + } + if budget.System != 1500 { + t.Errorf("System = %d, want %d", budget.System, 1500) + } + if budget.Recent != 2000 { + t.Errorf("Recent = %d, want %d", budget.Recent, 2000) + } +} + +func TestBudgetConfigCalculateNegative(t *testing.T) { + // With small context, should not panic and return zeros + cfg := DefaultBudgetConfig(1000) + budget := cfg.Calculate(500) + + // 1000 * 0.75 = 750 + // 750 - 1500 - 2000 - 500 = -3250 -> clamped to 0 + if budget.Total != 0 { + t.Errorf("Total should be 0 when budget is negative, got %d", budget.Total) + } +} + +func TestBudgetConfigPercentages(t *testing.T) { + cfg := DefaultBudgetConfig(16384) + budget := cfg.Calculate(100) + + // Check percentages: ConversationPct=0.40, MemoryPct=0.20, CodePct=0.40 + // available = 12288 - 1500 - 2000 - 100 = 8688 + // Conversation = 8688 * 0.40 = 3475 + // Memory = 8688 * 0.20 = 1737 + // Code = 8688 * 0.40 = 3475 + if budget.Conversation != 3475 { + t.Errorf("Conversation = %d, want %d", budget.Conversation, 3475) + } + if budget.Memory != 1737 { + t.Errorf("Memory = %d, want %d", budget.Memory, 1737) + } + if budget.Code != 3475 { + t.Errorf("Code = %d, want %d", budget.Code, 3475) + } +} diff --git a/internal/ice/store.go b/internal/ice/store.go new file mode 100644 index 0000000..bded706 --- /dev/null +++ b/internal/ice/store.go @@ -0,0 +1,167 @@ +package ice + +import ( + "encoding/json" + "fmt" + "math" + "os" + "path/filepath" + "sort" + "sync" + "time" +) + +// timeNow is a variable for testing. +var timeNow = time.Now + +const minSimilarityThreshold = 0.3 + +// Store is a flat-file vector store for conversation history. +// It holds all entries in memory and persists to a JSON file. +type Store struct { + mu sync.Mutex + path string + entries []ConversationEntry + nextID int + dirty bool +} + +// NewStore loads an existing store from path or creates an empty one. +func NewStore(path string) *Store { + s := &Store{path: path} + s.load() + return s +} + +// Add appends a new conversation entry and returns its ID. +func (s *Store) Add(sessionID, role, content string, embedding []float32, turnIndex int) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + + s.nextID++ + entry := ConversationEntry{ + ID: s.nextID, + SessionID: sessionID, + Role: role, + Content: content, + Embedding: embedding, + TurnIndex: turnIndex, + } + // Use a zero-value check to set CreatedAt (avoids importing time in every call site). + entry.CreatedAt = timeNow() + s.entries = append(s.entries, entry) + s.dirty = true + return s.nextID, nil +} + +// Search returns the top-K entries most similar to queryEmbedding. +// Entries from excludeSession are skipped. Results are sorted by score descending. +func (s *Store) Search(queryEmbedding []float32, excludeSession string, topK int) []ScoredEntry { + s.mu.Lock() + defer s.mu.Unlock() + + if len(queryEmbedding) == 0 || len(s.entries) == 0 { + return nil + } + + var scored []ScoredEntry + for _, e := range s.entries { + if e.SessionID == excludeSession { + continue + } + if len(e.Embedding) == 0 { + continue + } + sim := cosineSimilarity(queryEmbedding, e.Embedding) + if sim >= minSimilarityThreshold { + scored = append(scored, ScoredEntry{Entry: e, Score: sim}) + } + } + + sort.Slice(scored, func(i, j int) bool { + return scored[i].Score > scored[j].Score + }) + + if len(scored) > topK { + scored = scored[:topK] + } + return scored +} + +// Flush persists any pending changes to disk. +func (s *Store) Flush() error { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.dirty { + return nil + } + return s.persist() +} + +// Count returns the total number of stored entries. +func (s *Store) Count() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.entries) +} + +// load reads entries from the JSON file. +func (s *Store) load() { + data, err := os.ReadFile(s.path) + if err != nil { + return // File doesn't exist yet. + } + + var entries []ConversationEntry + if err := json.Unmarshal(data, &entries); err != nil { + return // Corrupt file, start empty. + } + + s.entries = entries + for _, e := range s.entries { + if e.ID > s.nextID { + s.nextID = e.ID + } + } +} + +// persist writes all entries to the JSON file. +func (s *Store) persist() error { + dir := filepath.Dir(s.path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("create ice store dir: %w", err) + } + + data, err := json.Marshal(s.entries) + if err != nil { + return fmt.Errorf("marshal ice store: %w", err) + } + + if err := os.WriteFile(s.path, data, 0o644); err != nil { + return fmt.Errorf("write ice store: %w", err) + } + + s.dirty = false + return nil +} + +// cosineSimilarity computes the cosine similarity between two vectors. +func cosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) || len(a) == 0 { + return 0 + } + + var dot, normA, normB float64 + for i := range a { + dot += float64(a[i]) * float64(b[i]) + normA += float64(a[i]) * float64(a[i]) + normB += float64(b[i]) * float64(b[i]) + } + + denom := math.Sqrt(normA) * math.Sqrt(normB) + if denom == 0 { + return 0 + } + return float32(dot / denom) +} diff --git a/internal/ice/store_test.go b/internal/ice/store_test.go new file mode 100644 index 0000000..c7f25a2 --- /dev/null +++ b/internal/ice/store_test.go @@ -0,0 +1,232 @@ +package ice + +import ( + "math" + "os" + "path/filepath" + "testing" +) + +func TestCosineSimilarity(t *testing.T) { + tests := []struct { + name string + a, b []float32 + want float32 + tol float32 + }{ + { + name: "identical vectors", + a: []float32{1, 2, 3}, + b: []float32{1, 2, 3}, + want: 1.0, + tol: 1e-6, + }, + { + name: "orthogonal vectors", + a: []float32{1, 0}, + b: []float32{0, 1}, + want: 0.0, + tol: 1e-6, + }, + { + name: "opposite vectors", + a: []float32{1, 0}, + b: []float32{-1, 0}, + want: -1.0, + tol: 1e-6, + }, + { + name: "different lengths returns 0", + a: []float32{1, 0}, + b: []float32{1, 0, 0}, + want: 0, + tol: 0, + }, + { + name: "zero vector returns 0", + a: []float32{0, 0}, + b: []float32{1, 1}, + want: 0, + tol: 0, + }, + { + name: "known value with tolerance", + a: []float32{1, 1}, + b: []float32{1, 0}, + // 1/(sqrt(2)*1) ≈ 0.7071 + want: float32(1.0 / math.Sqrt(2)), + tol: 1e-4, + }, + { + name: "empty vectors", + a: []float32{}, + b: []float32{}, + want: 0, + tol: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := cosineSimilarity(tt.a, tt.b) + diff := got - tt.want + if diff < 0 { + diff = -diff + } + if diff > tt.tol { + t.Errorf("cosineSimilarity(%v, %v) = %f, want %f (±%f)", + tt.a, tt.b, got, tt.want, tt.tol) + } + }) + } +} + +func TestStore_Add_And_Count(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "store.json") + + s := NewStore(path) + + if s.Count() != 0 { + t.Fatalf("new store Count = %d, want 0", s.Count()) + } + + id1, err := s.Add("sess1", "user", "hello", []float32{1, 0, 0}, 0) + if err != nil { + t.Fatalf("Add returned error: %v", err) + } + if id1 != 1 { + t.Errorf("first Add returned id=%d, want 1", id1) + } + if s.Count() != 1 { + t.Errorf("Count after first Add = %d, want 1", s.Count()) + } + + id2, err := s.Add("sess1", "assistant", "world", []float32{0, 1, 0}, 1) + if err != nil { + t.Fatalf("Add returned error: %v", err) + } + if id2 != 2 { + t.Errorf("second Add returned id=%d, want 2", id2) + } + if s.Count() != 2 { + t.Errorf("Count after second Add = %d, want 2", s.Count()) + } +} + +func TestStore_Search(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "store.json") + s := NewStore(path) + + // Add entries with known embeddings. + s.Add("sess1", "user", "entry A", []float32{1, 0, 0}, 0) + s.Add("sess1", "user", "entry B", []float32{0, 1, 0}, 1) + s.Add("sess2", "user", "entry C", []float32{0.9, 0.1, 0}, 0) + s.Add("sess2", "user", "entry D", []float32{0, 0, 1}, 1) // orthogonal to query + + t.Run("similarity filtering and sorting", func(t *testing.T) { + // Query similar to entries A and C, exclude no session. + results := s.Search([]float32{1, 0, 0}, "", 10) + if len(results) == 0 { + t.Fatal("expected results, got 0") + } + // Entry A should be highest (identical to query). + if results[0].Entry.Content != "entry A" { + t.Errorf("top result = %q, want 'entry A'", results[0].Entry.Content) + } + }) + + t.Run("session exclusion", func(t *testing.T) { + results := s.Search([]float32{1, 0, 0}, "sess1", 10) + for _, r := range results { + if r.Entry.SessionID == "sess1" { + t.Errorf("excluded session sess1 should not appear in results") + } + } + }) + + t.Run("min threshold 0.3", func(t *testing.T) { + // Entry D: [0,0,1] is orthogonal to [1,0,0] → similarity 0. + results := s.Search([]float32{1, 0, 0}, "", 10) + for _, r := range results { + if r.Score < minSimilarityThreshold { + t.Errorf("result %q has score %f below threshold %f", + r.Entry.Content, r.Score, minSimilarityThreshold) + } + } + }) + + t.Run("topK limit", func(t *testing.T) { + results := s.Search([]float32{1, 0, 0}, "", 1) + if len(results) > 1 { + t.Errorf("topK=1 but got %d results", len(results)) + } + }) + + t.Run("empty store returns nil", func(t *testing.T) { + emptyPath := filepath.Join(dir, "empty.json") + empty := NewStore(emptyPath) + results := empty.Search([]float32{1, 0}, "", 5) + if results != nil { + t.Errorf("empty store search should return nil, got %v", results) + } + }) + + t.Run("empty query embedding returns nil", func(t *testing.T) { + results := s.Search([]float32{}, "", 5) + if results != nil { + t.Errorf("empty query should return nil, got %v", results) + } + }) +} + +func TestStore_Flush_Persistence(t *testing.T) { + t.Run("round trip", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "store.json") + + s1 := NewStore(path) + s1.Add("sess1", "user", "hello", []float32{1, 0}, 0) + s1.Add("sess1", "assistant", "world", []float32{0, 1}, 1) + + if err := s1.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + + // Reload from same path. + s2 := NewStore(path) + if s2.Count() != 2 { + t.Errorf("reloaded store Count = %d, want 2", s2.Count()) + } + }) + + t.Run("corrupt JSON recovery", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "store.json") + + // Write corrupt JSON. + os.WriteFile(path, []byte("not valid json{{{"), 0o644) + + s := NewStore(path) + if s.Count() != 0 { + t.Errorf("corrupt store Count = %d, want 0", s.Count()) + } + }) + + t.Run("nextID restoration", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "store.json") + + s1 := NewStore(path) + s1.Add("sess1", "user", "first", []float32{1}, 0) + s1.Add("sess1", "user", "second", []float32{1}, 1) + s1.Flush() + + s2 := NewStore(path) + id, _ := s2.Add("sess1", "user", "third", []float32{1}, 2) + if id != 3 { + t.Errorf("continued id = %d, want 3", id) + } + }) +} diff --git a/internal/ice/types.go b/internal/ice/types.go new file mode 100644 index 0000000..f946b6a --- /dev/null +++ b/internal/ice/types.go @@ -0,0 +1,46 @@ +package ice + +import "time" + +// SourceKind identifies where a context chunk came from. +type SourceKind int + +const ( + SourceConversation SourceKind = iota + SourceMemory +) + +// ConversationEntry is a single stored message with its embedding. +type ConversationEntry struct { + ID int `json:"id"` + SessionID string `json:"session_id"` + Role string `json:"role"` // "user", "assistant", "summary" + Content string `json:"content"` + Embedding []float32 `json:"embedding"` + CreatedAt time.Time `json:"created_at"` + TurnIndex int `json:"turn_index"` +} + +// ScoredEntry pairs a conversation entry with its similarity score. +type ScoredEntry struct { + Entry ConversationEntry + Score float32 +} + +// ContextChunk is a piece of assembled context ready for the prompt. +type ContextChunk struct { + Source SourceKind + Content string + Score float32 + Tokens int +} + +// Budget holds the token allocation for each context source. +type Budget struct { + Total int + System int + Conversation int + Memory int + Code int + Recent int +} diff --git a/internal/initcmd/initcmd.go b/internal/initcmd/initcmd.go new file mode 100644 index 0000000..0e4ac4d --- /dev/null +++ b/internal/initcmd/initcmd.go @@ -0,0 +1,184 @@ +package initcmd + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strings" +) + +// projectMarker maps a marker file name to its detected project type. +var projectMarkers = map[string]string{ + "go.mod": "Go", + "go.sum": "Go", + "package.json": "Node.js", + "Cargo.toml": "Rust", + "pyproject.toml": "Python", + "requirements.txt": "Python", + "setup.py": "Python", + "Pipfile": "Python", + "Gemfile": "Ruby", + "pom.xml": "Java (Maven)", + "build.gradle": "Java (Gradle)", + "build.gradle.kts": "Kotlin (Gradle)", + "CMakeLists.txt": "C/C++ (CMake)", + "Makefile": "Make", + "Taskfile.yml": "Taskfile", + "Taskfile.yaml": "Taskfile", + "docker-compose.yml": "Docker Compose", + "docker-compose.yaml": "Docker Compose", + "Dockerfile": "Docker", + ".gitignore": "Git", +} + +// Options configures the behaviour of Run. +type Options struct { + // Force overwrites an existing AGENT.md. + Force bool +} + +// Run scans dir for project markers and generates an AGENT.md file. +// It returns an error if AGENT.md already exists unless opts.Force is true. +func Run(dir string, opts Options) error { + agentPath := filepath.Join(dir, "AGENT.md") + + if !opts.Force { + if _, err := os.Stat(agentPath); err == nil { + return fmt.Errorf("AGENT.md already exists in %s (use --force to overwrite)", dir) + } + } + + entries, err := os.ReadDir(dir) + if err != nil { + return fmt.Errorf("reading directory: %w", err) + } + + // Detect project types from marker files. + detectedTypes := detectProjectTypes(entries) + + // Build directory listing. + listing := buildDirectoryListing(dir, entries) + + // Generate AGENT.md content. + content := generateAgentMD(detectedTypes, listing) + + if err := os.WriteFile(agentPath, []byte(content), 0644); err != nil { + return fmt.Errorf("writing AGENT.md: %w", err) + } + + return nil +} + +// detectProjectTypes returns a deduplicated, sorted list of project types +// found based on marker files in the directory entries. +func detectProjectTypes(entries []os.DirEntry) []string { + seen := make(map[string]bool) + for _, e := range entries { + if e.IsDir() { + continue + } + if pt, ok := projectMarkers[e.Name()]; ok { + seen[pt] = true + } + } + + types := make([]string, 0, len(seen)) + for t := range seen { + types = append(types, t) + } + sort.Strings(types) + return types +} + +// buildDirectoryListing returns a formatted string of top-level files and +// first-level subdirectory contents. +func buildDirectoryListing(dir string, entries []os.DirEntry) string { + var b strings.Builder + + var files []string + var dirs []string + + for _, e := range entries { + name := e.Name() + // Skip hidden files/dirs except well-known ones. + if strings.HasPrefix(name, ".") && name != ".gitignore" { + continue + } + if e.IsDir() { + dirs = append(dirs, name) + } else { + files = append(files, name) + } + } + + sort.Strings(files) + sort.Strings(dirs) + + for _, f := range files { + b.WriteString(f) + b.WriteByte('\n') + } + + for _, d := range dirs { + b.WriteString(d + "/\n") + subEntries, err := os.ReadDir(filepath.Join(dir, d)) + if err != nil { + continue + } + var subNames []string + for _, se := range subEntries { + n := se.Name() + if strings.HasPrefix(n, ".") { + continue + } + if se.IsDir() { + subNames = append(subNames, n+"/") + } else { + subNames = append(subNames, n) + } + } + sort.Strings(subNames) + for _, sn := range subNames { + b.WriteString(" " + sn + "\n") + } + } + + return b.String() +} + +// generateAgentMD produces the Markdown content for the AGENT.md file. +func generateAgentMD(projectTypes []string, listing string) string { + var b strings.Builder + + b.WriteString("# AGENT.md\n\n") + + // Project type section. + b.WriteString("## Project Type\n\n") + if len(projectTypes) == 0 { + b.WriteString("Unknown\n") + } else { + b.WriteString(strings.Join(projectTypes, ", ") + "\n") + } + + // Directory structure. + b.WriteString("\n## Directory Structure\n\n") + b.WriteString("```\n") + b.WriteString(listing) + b.WriteString("```\n") + + // Placeholder sections. + b.WriteString("\n## Build Commands\n\n") + b.WriteString("\n") + + b.WriteString("\n## Architecture\n\n") + b.WriteString("\n") + + b.WriteString("\n## Key Files\n\n") + b.WriteString("\n") + + b.WriteString("\n## Notes\n\n") + b.WriteString("\n") + + return b.String() +} diff --git a/internal/initcmd/initcmd_test.go b/internal/initcmd/initcmd_test.go new file mode 100644 index 0000000..400d03f --- /dev/null +++ b/internal/initcmd/initcmd_test.go @@ -0,0 +1,141 @@ +package initcmd + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestRun_GoProject(t *testing.T) { + dir := t.TempDir() + + // Create a go.mod marker file. + if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module example.com/test\n"), 0644); err != nil { + t.Fatal(err) + } + // Create a source directory with a file. + if err := os.Mkdir(filepath.Join(dir, "cmd"), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "cmd", "main.go"), []byte("package main\n"), 0644); err != nil { + t.Fatal(err) + } + + if err := Run(dir, Options{}); err != nil { + t.Fatalf("Run() returned error: %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "AGENT.md")) + if err != nil { + t.Fatalf("reading AGENT.md: %v", err) + } + + content := string(data) + + // Check project type detection. + if !strings.Contains(content, "Go") { + t.Error("expected AGENT.md to contain 'Go' project type") + } + + // Check directory listing includes go.mod. + if !strings.Contains(content, "go.mod") { + t.Error("expected AGENT.md to list go.mod") + } + + // Check directory listing includes cmd/ subdirectory. + if !strings.Contains(content, "cmd/") { + t.Error("expected AGENT.md to list cmd/ directory") + } + + // Check that main.go appears under cmd/. + if !strings.Contains(content, "main.go") { + t.Error("expected AGENT.md to list main.go inside cmd/") + } + + // Check placeholder sections exist. + for _, section := range []string{"## Build Commands", "## Architecture", "## Key Files", "## Notes"} { + if !strings.Contains(content, section) { + t.Errorf("expected AGENT.md to contain section %q", section) + } + } +} + +func TestRun_EmptyDirectory(t *testing.T) { + dir := t.TempDir() + + if err := Run(dir, Options{}); err != nil { + t.Fatalf("Run() returned error: %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "AGENT.md")) + if err != nil { + t.Fatalf("reading AGENT.md: %v", err) + } + + content := string(data) + + // With no marker files, project type should be "Unknown". + if !strings.Contains(content, "Unknown") { + t.Error("expected AGENT.md to contain 'Unknown' project type for empty dir") + } +} + +func TestRun_ExistingAgentMD_NoOverwrite(t *testing.T) { + dir := t.TempDir() + + agentPath := filepath.Join(dir, "AGENT.md") + original := "# Original content\n" + if err := os.WriteFile(agentPath, []byte(original), 0644); err != nil { + t.Fatal(err) + } + + err := Run(dir, Options{}) + if err == nil { + t.Fatal("expected error when AGENT.md already exists") + } + + if !strings.Contains(err.Error(), "already exists") { + t.Errorf("expected 'already exists' in error, got: %v", err) + } + + // Verify file was not modified. + data, err := os.ReadFile(agentPath) + if err != nil { + t.Fatal(err) + } + if string(data) != original { + t.Error("AGENT.md was unexpectedly modified") + } +} + +func TestRun_Force(t *testing.T) { + dir := t.TempDir() + + agentPath := filepath.Join(dir, "AGENT.md") + if err := os.WriteFile(agentPath, []byte("# Old content\n"), 0644); err != nil { + t.Fatal(err) + } + + // Create a go.mod so the new content is distinguishable. + if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module test\n"), 0644); err != nil { + t.Fatal(err) + } + + if err := Run(dir, Options{Force: true}); err != nil { + t.Fatalf("Run() with Force=true returned error: %v", err) + } + + data, err := os.ReadFile(agentPath) + if err != nil { + t.Fatal(err) + } + + content := string(data) + if strings.Contains(content, "Old content") { + t.Error("AGENT.md should have been overwritten with Force=true") + } + if !strings.Contains(content, "Go") { + t.Error("expected new AGENT.md to detect Go project type") + } +} diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go new file mode 100644 index 0000000..5bcea36 --- /dev/null +++ b/internal/integration/integration_test.go @@ -0,0 +1,230 @@ +//go:build integration +// +build integration + +package integration + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "ai-agent/internal/agent" + "ai-agent/internal/command" + "ai-agent/internal/config" + "ai-agent/internal/llm" + "ai-agent/internal/mcp" + "ai-agent/internal/tui" + + tea "charm.land/bubbletea/v2" +) + +func skipIfNoOllama(t *testing.T) { + if os.Getenv("OLLAMA_HOST") == "" { + os.Setenv("OLLAMA_HOST", "http://localhost:11434") + } + client := llm.NewClient(llm.Config{ + BaseURL: os.Getenv("OLLAMA_HOST"), + Model: "qwen3.5:2b", + NumCtx: 262144, + }) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := client.Ping(ctx); err != nil { + t.Skip("Ollama not available: skipping integration test") + } +} + +func TestTUI_Initialization(t *testing.T) { + skipIfNoOllama(t) + reg := command.NewRegistry() + command.RegisterBuiltins(reg) + cfg := config.DefaultModelConfig() + router := config.NewRouter(&cfg) + modelManager := llm.NewModelManager("http://localhost:11434", 262144) + modelManager.SetCurrentModel("qwen3.5:2b") + ag := agent.New(modelManager, mcp.NewRegistry(), cfg.Ollama.NumCtx) + ag.SetRouter(router) + completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil) + m := tui.New(ag, reg, nil, completer, modelManager, router, nil) + updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40}) + m = updated.(*tui.Model) + if !m.Ready() { + t.Error("TUI should be ready after WindowSizeMsg") + } +} + +func TestTUI_ScrollAnchorDuringStreaming(t *testing.T) { + skipIfNoOllama(t) + reg := command.NewRegistry() + command.RegisterBuiltins(reg) + cfg := config.DefaultModelConfig() + router := config.NewRouter(&cfg) + modelManager := llm.NewModelManager("http://localhost:11434", 262144) + ag := agent.New(modelManager, mcp.NewRegistry(), 262144) + ag.SetRouter(router) + completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil) + m := tui.New(ag, reg, nil, completer, modelManager, router, nil) + updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40}) + m = updated.(*tui.Model) + if !m.AnchorActive() { + t.Error("anchorActive should be true after initialization") + } + updated, _ = m.Update(tui.StreamTextMsg{Text: "Hello"}) + m = updated.(*tui.Model) + if !m.AnchorActive() { + t.Error("anchorActive should remain true during streaming") + } +} + +func TestTUI_OverlayRendering(t *testing.T) { + reg := command.NewRegistry() + command.RegisterBuiltins(reg) + cfg := config.DefaultModelConfig() + router := config.NewRouter(&cfg) + modelManager := llm.NewModelManager("http://localhost:11434", 262144) + ag := agent.New(modelManager, mcp.NewRegistry(), 262144) + ag.SetRouter(router) + completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil) + m := tui.New(ag, reg, nil, completer, modelManager, router, nil) + updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40}) + m = updated.(*tui.Model) + updated, _ = m.Update(tui.KeyPressMsg{Code: '?'}) + m = updated.(*tui.Model) + view := m.View() + if view == nil { + t.Error("View should not be nil") + } + updated, _ = m.Update(tui.KeyPressMsg{Code: tea.KeyEscape}) + m = updated.(*tui.Model) +} + +func TestTUI_ToolCardRendering(t *testing.T) { + reg := command.NewRegistry() + command.RegisterBuiltins(reg) + cfg := config.DefaultModelConfig() + router := config.NewRouter(&cfg) + modelManager := llm.NewModelManager("http://localhost:11434", 262144) + ag := agent.New(modelManager, mcp.NewRegistry(), 262144) + ag.SetRouter(router) + completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil) + m := tui.New(ag, reg, nil, completer, modelManager, router, nil) + updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40}) + m = updated.(*tui.Model) + startTime := time.Now() + updated, _ = m.Update(tui.ToolCallStartMsg{ + Name: "read_file", + Args: map[string]any{"path": "test.go"}, + StartTime: startTime, + }) + m = updated.(*tui.Model) + updated, _ = m.Update(tui.ToolCallResultMsg{ + Name: "read_file", + Result: "file content", + IsError: false, + Duration: 100 * time.Millisecond, + }) + m = updated.(*tui.Model) + view := m.View() + if view == nil { + t.Error("View should not be nil after tool execution") + } +} + +func TestQwenRouter_Integration(t *testing.T) { + cfg := config.DefaultModelConfig() + router := config.NewQwenModelRouter(&cfg) + tests := []struct { + query string + mode config.ModeContext + expectSmaller string + expectLarger string + }{ + {"what is go?", config.ModeAskContext, "qwen3.5:2b", ""}, + {"design architecture", config.ModeBuildContext, "", "qwen3.5:4b"}, + {"plan the system", config.ModePlanContext, "", "qwen3.5:4b"}, + } + for _, tt := range tests { + t.Run(tt.query, func(t *testing.T) { + model := router.SelectModelForMode(tt.query, tt.mode) + if tt.expectSmaller != "" { + if modelRank(model) > modelRank(tt.expectSmaller) { + t.Errorf("model %s is larger than expected %s", model, tt.expectSmaller) + } + } + if tt.expectLarger != "" { + if modelRank(model) < modelRank(tt.expectLarger) { + t.Errorf("model %s is smaller than expected %s", model, tt.expectLarger) + } + } + }) + } +} + +func modelRank(model string) int { + switch { + case strings.Contains(model, "0.8b"): + return 1 + case strings.Contains(model, "2b"): + return 2 + case strings.Contains(model, "4b"): + return 3 + case strings.Contains(model, "9b"): + return 4 + default: + return 0 + } +} + +func TestFileOperations_Integration(t *testing.T) { + skipIfNoOllama(t) + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + if err := os.WriteFile(testFile, []byte("hello world"), 0644); err != nil { + t.Fatalf("failed to create test file: %v", err) + } + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("failed to read test file: %v", err) + } + if string(content) != "hello world" { + t.Errorf("unexpected file content: %q", string(content)) + } +} + +func BenchmarkTUI_Render(b *testing.B) { + reg := command.NewRegistry() + command.RegisterBuiltins(reg) + cfg := config.DefaultModelConfig() + router := config.NewRouter(&cfg) + modelManager := llm.NewModelManager("http://localhost:11434", 262144) + ag := agent.New(modelManager, mcp.NewRegistry(), 262144) + ag.SetRouter(router) + completer := tui.NewCompleter(reg, []string{"qwen3.5:2b"}, nil, nil, nil) + m := tui.New(ag, reg, nil, completer, modelManager, router, nil) + updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40}) + m = updated.(*tui.Model) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = m.View() + } +} + +func BenchmarkQwenRouter_Classification(b *testing.B) { + cfg := config.DefaultModelConfig() + router := config.NewQwenModelRouter(&cfg) + queries := []string{ + "what is go", + "how do i create a file", + "debug this nil pointer error", + "design microservice architecture", + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, q := range queries { + _ = router.SelectModelForMode(q, config.ModeAskContext) + } + } +} diff --git a/internal/llm/client.go b/internal/llm/client.go new file mode 100644 index 0000000..0162b6e --- /dev/null +++ b/internal/llm/client.go @@ -0,0 +1,58 @@ +package llm + +import "context" + +// Client is the interface for LLM providers. +type Client interface { + // ChatStream sends messages to the LLM and streams the response. + // The callback is called for each chunk. Return a non-nil error to abort. + ChatStream(ctx context.Context, opts ChatOptions, fn func(StreamChunk) error) error + + // Ping checks if the LLM is reachable and the model is available. + Ping() error + + // Model returns the current model name. + Model() string + + // Embed generates embeddings for the given texts using the specified model. + Embed(ctx context.Context, model string, texts []string) ([][]float32, error) +} + +// ChatOptions holds parameters for a chat request. +type ChatOptions struct { + Messages []Message + Tools []ToolDef + System string +} + +// Message represents a conversation message. +type Message struct { + Role string `json:"role"` // system, user, assistant, tool + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolName string `json:"tool_name,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +// StreamChunk is a piece of a streaming response. +type StreamChunk struct { + Text string // incremental text content + ToolCalls []ToolCall // tool calls (usually in final chunk) + Done bool // true on the last chunk + EvalCount int // tokens generated (only on Done) + PromptEvalCount int // prompt tokens evaluated (only on Done) +} + +// ToolCall represents a tool invocation requested by the LLM. +type ToolCall struct { + ID string `json:"id"` + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` +} + +// ToolDef defines a tool the LLM can call. +type ToolDef struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]any `json:"parameters"` // JSON Schema +} diff --git a/internal/llm/manager.go b/internal/llm/manager.go new file mode 100644 index 0000000..b8ffbd1 --- /dev/null +++ b/internal/llm/manager.go @@ -0,0 +1,163 @@ +package llm + +import ( + "context" + "fmt" + "sync" +) + +type ModelManager struct { + baseURL string + numCtx int + clients map[string]*OllamaClient + currentModel string + mu sync.RWMutex +} + +var _ Client = (*ModelManager)(nil) + +func NewModelManager(baseURL string, numCtx int) *ModelManager { + return &ModelManager{ + baseURL: baseURL, + numCtx: numCtx, + clients: make(map[string]*OllamaClient), + } +} + +func (m *ModelManager) GetClient(modelName string) (*OllamaClient, error) { + m.mu.RLock() + client, exists := m.clients[modelName] + m.mu.RUnlock() + + if exists { + return client, nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if client, exists := m.clients[modelName]; exists { + return client, nil + } + + client, err := NewOllamaClient(m.baseURL, modelName, m.numCtx) + if err != nil { + return nil, fmt.Errorf("create client for %s: %w", modelName, err) + } + + m.clients[modelName] = client + return client, nil +} + +func (m *ModelManager) SetCurrentModel(model string) error { + m.mu.Lock() + defer m.mu.Unlock() + + client, err := NewOllamaClient(m.baseURL, model, m.numCtx) + if err != nil { + return fmt.Errorf("create client for %s: %w", model, err) + } + + m.clients[model] = client + m.currentModel = model + return nil +} + +func (m *ModelManager) CurrentModel() string { + m.mu.RLock() + defer m.mu.RUnlock() + return m.currentModel +} + +func (m *ModelManager) ChatStream(ctx context.Context, opts ChatOptions, fn func(StreamChunk) error) error { + m.mu.RLock() + model := m.currentModel + m.mu.RUnlock() + + if model == "" { + return fmt.Errorf("no model selected") + } + + client, err := m.GetClient(model) + if err != nil { + return err + } + return client.ChatStream(ctx, opts, fn) +} + +func (m *ModelManager) ChatStreamForModel(ctx context.Context, model string, opts ChatOptions, fn func(StreamChunk) error) error { + client, err := m.GetClient(model) + if err != nil { + return err + } + return client.ChatStream(ctx, opts, fn) +} + +func (m *ModelManager) Ping() error { + m.mu.RLock() + model := m.currentModel + m.mu.RUnlock() + + if model == "" { + return fmt.Errorf("no model selected") + } + + client, err := m.GetClient(model) + if err != nil { + return err + } + return client.Ping() +} + +func (m *ModelManager) PingModel(model string) error { + client, err := m.GetClient(model) + if err != nil { + return err + } + return client.Ping() +} + +func (m *ModelManager) Embed(ctx context.Context, model string, texts []string) ([][]float32, error) { + client, err := m.GetClient(model) + if err != nil { + return nil, err + } + return client.Embed(ctx, model, texts) +} + +func (m *ModelManager) EmbedWithCurrentModel(ctx context.Context, texts []string) ([][]float32, error) { + m.mu.RLock() + model := m.currentModel + m.mu.RUnlock() + + if model == "" { + return nil, fmt.Errorf("no model selected") + } + return m.Embed(ctx, model, texts) +} + +func (m *ModelManager) Close() { + m.mu.Lock() + defer m.mu.Unlock() + + for range m.clients { + } + m.clients = make(map[string]*OllamaClient) +} + +func (m *ModelManager) BaseURL() string { + return m.baseURL +} + +func (m *ModelManager) NumCtx() int { + return m.numCtx +} + +func (m *ModelManager) Model() string { + return m.CurrentModel() +} + +// ListModels returns model names available in Ollama at the manager's base URL. +func (m *ModelManager) ListModels(ctx context.Context) ([]string, error) { + return ListModels(ctx, m.baseURL) +} diff --git a/internal/llm/manager_test.go b/internal/llm/manager_test.go new file mode 100644 index 0000000..a78ac26 --- /dev/null +++ b/internal/llm/manager_test.go @@ -0,0 +1,92 @@ +package llm + +import ( + "testing" +) + +func TestNewModelManager(t *testing.T) { + m := NewModelManager("http://localhost:11434", 4096) + + if m.baseURL != "http://localhost:11434" { + t.Errorf("baseURL = %q, want %q", m.baseURL, "http://localhost:11434") + } + if m.numCtx != 4096 { + t.Errorf("numCtx = %d, want %d", m.numCtx, 4096) + } + if m.clients == nil { + t.Error("clients map should be initialized") + } +} + +func TestModelManagerBaseURL(t *testing.T) { + m := NewModelManager("http://custom:9999", 2048) + if m.BaseURL() != "http://custom:9999" { + t.Errorf("BaseURL() = %q, want %q", m.BaseURL(), "http://custom:9999") + } +} + +func TestModelManagerNumCtx(t *testing.T) { + m := NewModelManager("http://localhost:11434", 8192) + if m.NumCtx() != 8192 { + t.Errorf("NumCtx() = %d, want %d", m.NumCtx(), 8192) + } +} + +func TestModelManagerCurrentModel(t *testing.T) { + m := NewModelManager("http://localhost:11434", 4096) + + // Should return empty when no model set + if m.CurrentModel() != "" { + t.Errorf("CurrentModel() = %q, want %q", m.CurrentModel(), "") + } + + // Set a model + m.SetCurrentModel("llama3") + + if m.CurrentModel() != "llama3" { + t.Errorf("CurrentModel() = %q, want %q", m.CurrentModel(), "llama3") + } +} + +func TestModelManagerChatStreamNoModel(t *testing.T) { + m := NewModelManager("http://localhost:11434", 4096) + + err := m.ChatStream(nil, ChatOptions{}, func(chunk StreamChunk) error { + return nil + }) + + if err == nil { + t.Error("ChatStream should fail when no model is set") + } +} + +func TestModelManagerPingNoModel(t *testing.T) { + m := NewModelManager("http://localhost:11434", 4096) + + err := m.Ping() + + if err == nil { + t.Error("Ping should fail when no model is set") + } +} + +func TestModelManagerEmbedWithCurrentModelNoModel(t *testing.T) { + m := NewModelManager("http://localhost:11434", 4096) + + _, err := m.EmbedWithCurrentModel(nil, []string{"test"}) + + if err == nil { + t.Error("EmbedWithCurrentModel should fail when no model is set") + } +} + +func TestModelManagerClose(t *testing.T) { + m := NewModelManager("http://localhost:11434", 4096) + + // Should not panic + m.Close() + + if len(m.clients) != 0 { + t.Errorf("after Close, clients map should be empty, got %d", len(m.clients)) + } +} diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go new file mode 100644 index 0000000..5214395 --- /dev/null +++ b/internal/llm/ollama.go @@ -0,0 +1,222 @@ +package llm + +import ( + "context" + "fmt" + "net/url" + "os" + + ollamaapi "github.com/ollama/ollama/api" +) + +// OllamaClient implements Client using the official Ollama Go library. +type OllamaClient struct { + client *ollamaapi.Client + model string + numCtx int +} + +// NewOllamaClient creates a new Ollama client. +func NewOllamaClient(baseURL, model string, numCtx int) (*OllamaClient, error) { + // The official client reads OLLAMA_HOST, but we want to support our config too. + if baseURL != "" { + os.Setenv("OLLAMA_HOST", baseURL) + } + + client, err := ollamaapi.ClientFromEnvironment() + if err != nil { + return nil, fmt.Errorf("create ollama client: %w", err) + } + + return &OllamaClient{ + client: client, + model: model, + numCtx: numCtx, + }, nil +} + +func (o *OllamaClient) Model() string { return o.model } + +// Ping checks Ollama is running and the model exists. +func (o *OllamaClient) Ping() error { + ctx := context.Background() + + // Check the model is available by requesting a show. + req := &ollamaapi.ShowRequest{Model: o.model} + _, err := o.client.Show(ctx, req) + if err != nil { + return fmt.Errorf("model %q not available: %w", o.model, err) + } + return nil +} + +// ChatStream sends a chat request and streams the response via callback. +func (o *OllamaClient) ChatStream(ctx context.Context, opts ChatOptions, fn func(StreamChunk) error) error { + + messages := make([]ollamaapi.Message, 0, len(opts.Messages)+1) + if opts.System != "" { + messages = append(messages, ollamaapi.Message{ + Role: "system", + Content: opts.System, + }) + } + for _, m := range opts.Messages { + msg := ollamaapi.Message{ + Role: m.Role, + Content: m.Content, + ToolName: m.ToolName, + ToolCallID: m.ToolCallID, + } + // Convert tool calls for assistant messages. + for _, tc := range m.ToolCalls { + args := ollamaapi.NewToolCallFunctionArguments() + for k, v := range tc.Arguments { + args.Set(k, v) + } + msg.ToolCalls = append(msg.ToolCalls, ollamaapi.ToolCall{ + ID: tc.ID, + Function: ollamaapi.ToolCallFunction{ + Name: tc.Name, + Arguments: args, + }, + }) + } + messages = append(messages, msg) + } + + tools := convertTools(opts.Tools) + req := &ollamaapi.ChatRequest{ + Model: o.model, + Messages: messages, + Tools: tools, + Options: map[string]any{ + "num_ctx": o.numCtx, + }, + } + + return o.client.Chat(ctx, req, func(resp ollamaapi.ChatResponse) error { + chunk := StreamChunk{ + Text: resp.Message.Content, + Done: resp.Done, + } + if resp.Done { + chunk.EvalCount = resp.EvalCount + chunk.PromptEvalCount = resp.PromptEvalCount + } + // Collect tool calls from the response. + for _, tc := range resp.Message.ToolCalls { + chunk.ToolCalls = append(chunk.ToolCalls, ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Arguments: tc.Function.Arguments.ToMap(), + }) + } + return fn(chunk) + }) +} + +// Embed generates embeddings for the given texts using the specified model. +func (o *OllamaClient) Embed(ctx context.Context, model string, texts []string) ([][]float32, error) { + resp, err := o.client.Embed(ctx, &ollamaapi.EmbedRequest{ + Model: model, + Input: texts, + }) + if err != nil { + return nil, fmt.Errorf("embedding failed: %w", err) + } + return resp.Embeddings, nil +} + +// convertTools transforms our ToolDef slice into Ollama's Tools format. +func convertTools(defs []ToolDef) ollamaapi.Tools { + if len(defs) == 0 { + return nil + } + + tools := make(ollamaapi.Tools, 0, len(defs)) + for _, d := range defs { + props := ollamaapi.NewToolPropertiesMap() + var required []string + + // Extract properties from JSON Schema. + if propsRaw, ok := d.Parameters["properties"].(map[string]any); ok { + for name, schema := range propsRaw { + schemaMap, _ := schema.(map[string]any) + prop := ollamaapi.ToolProperty{ + Description: strFromMap(schemaMap, "description"), + } + if t, ok := schemaMap["type"].(string); ok { + prop.Type = ollamaapi.PropertyType{t} + } + if enumRaw, ok := schemaMap["enum"].([]any); ok { + prop.Enum = enumRaw + } + props.Set(name, prop) + } + } + + // Extract required fields. + if reqRaw, ok := d.Parameters["required"].([]any); ok { + for _, r := range reqRaw { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + } + + tools = append(tools, ollamaapi.Tool{ + Type: "function", + Function: ollamaapi.ToolFunction{ + Name: d.Name, + Description: d.Description, + Parameters: ollamaapi.ToolFunctionParameters{ + Type: "object", + Properties: props, + Required: required, + }, + }, + }) + } + return tools +} + +func strFromMap(m map[string]any, key string) string { + if m == nil { + return "" + } + s, _ := m[key].(string) + return s +} + +// BaseURL returns the configured Ollama base URL for display. +func (o *OllamaClient) BaseURL() string { + if v := os.Getenv("OLLAMA_HOST"); v != "" { + return v + } + return "http://localhost:11434" +} + +// ParseBaseURL validates the Ollama URL. +func ParseBaseURL(rawURL string) (*url.URL, error) { + return url.Parse(rawURL) +} + +// ListModels returns model names available in Ollama at baseURL. +func ListModels(ctx context.Context, baseURL string) ([]string, error) { + if baseURL != "" { + os.Setenv("OLLAMA_HOST", baseURL) + } + client, err := ollamaapi.ClientFromEnvironment() + if err != nil { + return nil, fmt.Errorf("ollama client: %w", err) + } + resp, err := client.List(ctx) + if err != nil { + return nil, fmt.Errorf("ollama list: %w", err) + } + names := make([]string, 0, len(resp.Models)) + for _, m := range resp.Models { + names = append(names, m.Name) + } + return names, nil +} diff --git a/internal/llm/ollama_test.go b/internal/llm/ollama_test.go new file mode 100644 index 0000000..89f50dd --- /dev/null +++ b/internal/llm/ollama_test.go @@ -0,0 +1,121 @@ +package llm + +import ( + "testing" +) + +func TestConvertTools(t *testing.T) { + tests := []struct { + name string + input []ToolDef + wantNil bool + wantCount int + }{ + { + name: "nil input", + input: nil, + wantNil: true, + }, + { + name: "single tool with properties and required", + input: []ToolDef{ + { + Name: "read_file", + Description: "Read a file", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "file path", + }, + }, + "required": []any{"path"}, + }, + }, + }, + wantCount: 1, + }, + { + name: "tool without properties in parameters", + input: []ToolDef{ + { + Name: "noop", + Description: "Does nothing", + Parameters: map[string]any{"type": "object"}, + }, + }, + wantCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertTools(tt.input) + if tt.wantNil { + if result != nil { + t.Errorf("convertTools() = %v, want nil", result) + } + return + } + if len(result) != tt.wantCount { + t.Errorf("convertTools() returned %d tools, want %d", len(result), tt.wantCount) + } + if tt.wantCount > 0 { + tool := result[0] + if tool.Function.Name != tt.input[0].Name { + t.Errorf("tool name = %q, want %q", tool.Function.Name, tt.input[0].Name) + } + if tool.Function.Description != tt.input[0].Description { + t.Errorf("tool description = %q, want %q", tool.Function.Description, tt.input[0].Description) + } + if tool.Type != "function" { + t.Errorf("tool type = %q, want %q", tool.Type, "function") + } + } + }) + } +} + +func TestStrFromMap(t *testing.T) { + tests := []struct { + name string + m map[string]any + key string + want string + }{ + { + name: "key present", + m: map[string]any{"description": "a desc"}, + key: "description", + want: "a desc", + }, + { + name: "key missing", + m: map[string]any{"other": "value"}, + key: "description", + want: "", + }, + { + name: "nil map", + m: nil, + key: "description", + want: "", + }, + { + name: "non-string value", + m: map[string]any{"count": 42}, + key: "count", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := strFromMap(tt.m, tt.key) + if got != tt.want { + t.Errorf("strFromMap() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/logging/logger.go b/internal/logging/logger.go new file mode 100644 index 0000000..49cadeb --- /dev/null +++ b/internal/logging/logger.go @@ -0,0 +1,33 @@ +package logging + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "github.com/charmbracelet/log" +) + +func NewSessionLogger() (*log.Logger, *os.File, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, nil, fmt.Errorf("home dir: %w", err) + } + logDir := filepath.Join(home, ".config", "ai-agent", "logs") + if err := os.MkdirAll(logDir, 0o755); err != nil { + return nil, nil, fmt.Errorf("create log dir: %w", err) + } + filename := time.Now().Format("2006-01-02_15-04-05") + ".log" + f, err := os.Create(filepath.Join(logDir, filename)) + if err != nil { + return nil, nil, fmt.Errorf("create log file: %w", err) + } + logger := log.NewWithOptions(f, log.Options{ + ReportTimestamp: true, + TimeFormat: time.RFC3339, + Prefix: "ai-agent", + Level: log.DebugLevel, + }) + return logger, f, nil +} diff --git a/internal/logging/logger_test.go b/internal/logging/logger_test.go new file mode 100644 index 0000000..1e1d326 --- /dev/null +++ b/internal/logging/logger_test.go @@ -0,0 +1,42 @@ +package logging + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestNewSessionLogger(t *testing.T) { + logger, f, err := NewSessionLogger() + if err != nil { + t.Fatalf("NewSessionLogger() error: %v", err) + } + if f != nil { + defer f.Close() + defer os.Remove(f.Name()) + } + if logger == nil { + t.Fatal("logger should not be nil") + } + if f == nil { + t.Fatal("file should not be nil") + } + dir := filepath.Dir(f.Name()) + if !strings.Contains(dir, filepath.Join(".config", "ai-agent", "logs")) { + t.Errorf("log file should be in ~/.config/ai-agent/logs/, got %q", dir) + } + logger.Info("test message", "key", "value") +} + +func TestNilLoggerNoPanic(t *testing.T) { + var called bool + logger, f, err := NewSessionLogger() + if err == nil && f != nil { + defer f.Close() + defer os.Remove(f.Name()) + logger.Info("test") + called = true + } + _ = called +} diff --git a/internal/logging/reader.go b/internal/logging/reader.go new file mode 100644 index 0000000..0eff8af --- /dev/null +++ b/internal/logging/reader.go @@ -0,0 +1,92 @@ +package logging + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "sort" + "time" +) + +type LogEntry struct { + Path string + ModTime time.Time + Size int64 +} + +func LogDir() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return filepath.Join(home, ".config", "ai-agent", "logs") +} + +func ListLogs(n int) ([]LogEntry, error) { + return listLogsIn(LogDir(), n) +} + +func listLogsIn(dir string, n int) ([]LogEntry, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("read log dir: %w", err) + } + var logs []LogEntry + for _, e := range entries { + if e.IsDir() { + continue + } + info, err := e.Info() + if err != nil { + continue + } + logs = append(logs, LogEntry{ + Path: filepath.Join(dir, e.Name()), + ModTime: info.ModTime(), + Size: info.Size(), + }) + } + sort.Slice(logs, func(i, j int) bool { + return logs[i].ModTime.After(logs[j].ModTime) + }) + if n > 0 && n < len(logs) { + logs = logs[:n] + } + return logs, nil +} + +func LatestLogPath() (string, error) { + return latestLogPathIn(LogDir()) +} + +func latestLogPathIn(dir string) (string, error) { + logs, err := listLogsIn(dir, 1) + if err != nil { + return "", err + } + if len(logs) == 0 { + return "", fmt.Errorf("no log files found in %s", dir) + } + return logs[0].Path, nil +} + +func TailLog(path string, n int) ([]string, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open log: %w", err) + } + defer f.Close() + var lines []string + scanner := bufio.NewScanner(f) + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read log: %w", err) + } + if n > 0 && n < len(lines) { + lines = lines[len(lines)-n:] + } + return lines, nil +} diff --git a/internal/logging/reader_test.go b/internal/logging/reader_test.go new file mode 100644 index 0000000..36d38da --- /dev/null +++ b/internal/logging/reader_test.go @@ -0,0 +1,157 @@ +package logging + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +// helper creates n temp log files in dir with distinct mod times. +func createFakeLogs(t *testing.T, dir string, n int) []string { + t.Helper() + var paths []string + for i := range n { + name := filepath.Join(dir, "2025-01-01_00-00-0"+string(rune('0'+i))+".log") + if err := os.WriteFile(name, []byte("line "+string(rune('0'+i))+"\n"), 0o644); err != nil { + t.Fatal(err) + } + // Stagger mod times so ordering is deterministic. + ts := time.Now().Add(time.Duration(i) * time.Second) + if err := os.Chtimes(name, ts, ts); err != nil { + t.Fatal(err) + } + paths = append(paths, name) + } + return paths +} + +func TestListLogs(t *testing.T) { + dir := t.TempDir() + createFakeLogs(t, dir, 5) + + logs, err := listLogsIn(dir, 3) + if err != nil { + t.Fatalf("listLogsIn error: %v", err) + } + if len(logs) != 3 { + t.Fatalf("expected 3 entries, got %d", len(logs)) + } + + // Verify newest-first ordering. + for i := 1; i < len(logs); i++ { + if logs[i].ModTime.After(logs[i-1].ModTime) { + t.Errorf("entry %d (%v) is newer than entry %d (%v)", i, logs[i].ModTime, i-1, logs[i-1].ModTime) + } + } +} + +func TestListLogs_All(t *testing.T) { + dir := t.TempDir() + createFakeLogs(t, dir, 4) + + logs, err := listLogsIn(dir, 0) + if err != nil { + t.Fatalf("listLogsIn error: %v", err) + } + if len(logs) != 4 { + t.Fatalf("expected 4 entries, got %d", len(logs)) + } +} + +func TestListLogs_EmptyDir(t *testing.T) { + dir := t.TempDir() + logs, err := listLogsIn(dir, 5) + if err != nil { + t.Fatalf("listLogsIn error: %v", err) + } + if len(logs) != 0 { + t.Fatalf("expected 0 entries, got %d", len(logs)) + } +} + +func TestListLogs_MissingDir(t *testing.T) { + _, err := listLogsIn("/tmp/nonexistent-log-dir-test-xyz", 5) + if err == nil { + t.Fatal("expected error for missing dir") + } +} + +func TestTailLog(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.log") + content := "line1\nline2\nline3\nline4\nline5\n" + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatal(err) + } + + lines, err := TailLog(path, 3) + if err != nil { + t.Fatalf("TailLog error: %v", err) + } + if len(lines) != 3 { + t.Fatalf("expected 3 lines, got %d", len(lines)) + } + if lines[0] != "line3" { + t.Errorf("expected 'line3', got %q", lines[0]) + } + if lines[2] != "line5" { + t.Errorf("expected 'line5', got %q", lines[2]) + } +} + +func TestTailLog_FewerLines(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "short.log") + if err := os.WriteFile(path, []byte("only\n"), 0o644); err != nil { + t.Fatal(err) + } + + lines, err := TailLog(path, 100) + if err != nil { + t.Fatalf("TailLog error: %v", err) + } + if len(lines) != 1 { + t.Fatalf("expected 1 line, got %d", len(lines)) + } +} + +func TestTailLog_MissingFile(t *testing.T) { + _, err := TailLog("/tmp/nonexistent-file-test-xyz.log", 10) + if err == nil { + t.Fatal("expected error for missing file") + } +} + +func TestLatestLogPath(t *testing.T) { + dir := t.TempDir() + paths := createFakeLogs(t, dir, 3) + + latest, err := latestLogPathIn(dir) + if err != nil { + t.Fatalf("latestLogPathIn error: %v", err) + } + // The last created file has the newest mod time. + expected := paths[len(paths)-1] + if latest != expected { + t.Errorf("expected %q, got %q", expected, latest) + } +} + +func TestLatestLogPath_EmptyDir(t *testing.T) { + dir := t.TempDir() + _, err := latestLogPathIn(dir) + if err == nil { + t.Fatal("expected error for empty dir") + } +} + +func TestLogDir(t *testing.T) { + dir := LogDir() + if dir == "" { + t.Fatal("LogDir should not be empty") + } + if filepath.Base(dir) != "logs" { + t.Errorf("expected dir to end in 'logs', got %q", dir) + } +} diff --git a/internal/mcp/client.go b/internal/mcp/client.go new file mode 100644 index 0000000..ec7393b --- /dev/null +++ b/internal/mcp/client.go @@ -0,0 +1,110 @@ +package mcp + +import ( + "context" + "fmt" + "os/exec" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type MCPClient struct { + name string + client *sdkmcp.Client + session *sdkmcp.ClientSession + cmd *exec.Cmd +} + +func Connect(ctx context.Context, name, command string, args []string, env []string, transport, url string) (*MCPClient, error) { + client := sdkmcp.NewClient( + &sdkmcp.Implementation{Name: "ai-agent", Version: "0.2.0"}, + nil, + ) + var t sdkmcp.Transport + switch transport { + case "sse": + if url == "" { + return nil, fmt.Errorf("sse transport requires url for %s", name) + } + t = &sdkmcp.SSEClientTransport{Endpoint: url} + case "streamable-http": + if url == "" { + return nil, fmt.Errorf("streamable-http transport requires url for %s", name) + } + t = &sdkmcp.StreamableClientTransport{Endpoint: url} + default: + if command == "" { + return nil, fmt.Errorf("stdio transport requires command for %s", name) + } + cmd := exec.Command(command, args...) + if len(env) > 0 { + cmd.Env = append(cmd.Environ(), env...) + } + t = &sdkmcp.CommandTransport{Command: cmd} + } + session, err := client.Connect(ctx, t, nil) + if err != nil { + return nil, fmt.Errorf("connect to %s: %w", name, err) + } + return &MCPClient{ + name: name, + client: client, + session: session, + }, nil +} + +func (c *MCPClient) Name() string { return c.name } + +func (c *MCPClient) ListTools(ctx context.Context) ([]*sdkmcp.Tool, error) { + caps := c.session.InitializeResult() + if caps == nil || caps.Capabilities.Tools == nil { + return nil, nil + } + var tools []*sdkmcp.Tool + for tool, err := range c.session.Tools(ctx, nil) { + if err != nil { + return tools, fmt.Errorf("list tools from %s: %w", c.name, err) + } + tools = append(tools, tool) + } + return tools, nil +} + +func (c *MCPClient) CallTool(ctx context.Context, name string, args map[string]any) (*ToolResult, error) { + result, err := c.session.CallTool(ctx, &sdkmcp.CallToolParams{ + Name: name, + Arguments: args, + }) + if err != nil { + return nil, fmt.Errorf("call tool %s on %s: %w", name, c.name, err) + } + var text string + for _, ct := range result.Content { + if tc, ok := ct.(*sdkmcp.TextContent); ok { + if text != "" { + text += "\n" + } + text += tc.Text + } + } + return &ToolResult{Content: text, IsError: result.IsError}, nil +} + +func (c *MCPClient) Close() error { + if c.session != nil { + return c.session.Close() + } + return nil +} + +func (c *MCPClient) IsConnected() bool { + return c.session != nil +} + +func (c *MCPClient) Ping(ctx context.Context) error { + if c.session == nil { + return fmt.Errorf("no session") + } + _, err := c.ListTools(ctx) + return err +} diff --git a/internal/mcp/registry.go b/internal/mcp/registry.go new file mode 100644 index 0000000..2f849ca --- /dev/null +++ b/internal/mcp/registry.go @@ -0,0 +1,241 @@ +package mcp + +import ( + "context" + "fmt" + "sync" + "time" + + "ai-agent/internal/config" + "ai-agent/internal/llm" +) + +type FailedServer struct { + Name string + Reason string +} + +type ServerStatus struct { + Name string + Connected bool + LastError string + LastPing time.Time +} + +type Registry struct { + mu sync.RWMutex + clients []*MCPClient + toolMap map[string]*MCPClient + toolDefs []llm.ToolDef + failedServers []FailedServer + serverConfigs map[string]config.ServerConfig +} + +func NewRegistry() *Registry { + return &Registry{toolMap: make(map[string]*MCPClient), serverConfigs: make(map[string]config.ServerConfig)} +} + +const connectTimeout = 5 * time.Second + +func (r *Registry) ConnectServer(ctx context.Context, srv config.ServerConfig) (int, error) { + connCtx, cancel := context.WithTimeout(ctx, connectTimeout) + defer cancel() + client, err := Connect(connCtx, srv.Name, srv.Command, srv.Args, srv.Env, srv.Transport, srv.URL) + if err != nil { + r.mu.Lock() + r.failedServers = append(r.failedServers, FailedServer{Name: srv.Name, Reason: err.Error()}) + r.mu.Unlock() + return 0, fmt.Errorf("connect to %s: %w", srv.Name, err) + } + tools, err := client.ListTools(connCtx) + if err != nil { + client.Close() + r.mu.Lock() + r.failedServers = append(r.failedServers, FailedServer{Name: srv.Name, Reason: err.Error()}) + r.mu.Unlock() + return 0, fmt.Errorf("%s tools: %w", srv.Name, err) + } + r.mu.Lock() + r.clients = append(r.clients, client) + for _, tool := range tools { + r.toolMap[tool.Name] = client + r.toolDefs = append(r.toolDefs, ToLLMToolDef(tool.Name, tool.Description, tool.InputSchema)) + } + r.serverConfigs[srv.Name] = srv + r.mu.Unlock() + return len(tools), nil +} + +func (r *Registry) ConnectAll(ctx context.Context, servers []config.ServerConfig, logFn func(string)) { + for _, srv := range servers { + toolCount, err := r.ConnectServer(ctx, srv) + if err != nil { + logFn(fmt.Sprintf("skip %s: %v", srv.Name, err)) + continue + } + logFn(fmt.Sprintf("connected %s (%d tools)", srv.Name, toolCount)) + } +} + +func (r *Registry) Tools() []llm.ToolDef { + r.mu.RLock() + defer r.mu.RUnlock() + return r.toolDefs +} + +func (r *Registry) ToolCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.toolDefs) +} + +func (r *Registry) ServerCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.clients) +} + +func (r *Registry) ServerNames() []string { + r.mu.RLock() + defer r.mu.RUnlock() + names := make([]string, len(r.clients)) + for i, c := range r.clients { + names[i] = c.Name() + } + return names +} + +func (r *Registry) FailedServers() []FailedServer { + r.mu.RLock() + defer r.mu.RUnlock() + return r.failedServers +} + +func (r *Registry) CallTool(ctx context.Context, name string, args map[string]any) (*ToolResult, error) { + r.mu.RLock() + client, ok := r.toolMap[name] + r.mu.RUnlock() + if !ok { + return &ToolResult{ + Content: fmt.Sprintf("unknown tool: %s", name), + IsError: true, + }, nil + } + return client.CallTool(ctx, name, args) +} + +func (r *Registry) Close() { + r.mu.Lock() + defer r.mu.Unlock() + for _, c := range r.clients { + c.Close() + } + r.clients = nil + r.toolMap = make(map[string]*MCPClient) + r.toolDefs = nil +} + +func (r *Registry) HealthCheck(ctx context.Context) []ServerStatus { + r.mu.RLock() + defer r.mu.RUnlock() + var results []ServerStatus + for _, client := range r.clients { + status := ServerStatus{Name: client.Name()} + if client.IsConnected() { + pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + err := client.Ping(pingCtx) + cancel() + status.Connected = err == nil + if err != nil { + status.LastError = err.Error() + } + status.LastPing = time.Now() + } + results = append(results, status) + } + for _, failed := range r.failedServers { + results = append(results, ServerStatus{ + Name: failed.Name, + Connected: false, + LastError: failed.Reason, + }) + } + return results +} + +func (r *Registry) ReconnectServer(ctx context.Context, name string) (int, error) { + r.mu.RLock() + srv, ok := r.serverConfigs[name] + r.mu.RUnlock() + if !ok { + return 0, fmt.Errorf("no config found for server: %s", name) + } + r.mu.Lock() + var remainingFailed []FailedServer + for _, f := range r.failedServers { + if f.Name != name { + remainingFailed = append(remainingFailed, f) + } + } + r.failedServers = remainingFailed + r.mu.Unlock() + return r.ConnectServer(ctx, srv) +} + +type MonitorConfig struct { + Interval time.Duration + MaxRetries int + BackoffBase time.Duration +} + +var defaultMonitorConfig = MonitorConfig{ + Interval: 30 * time.Second, + MaxRetries: 3, + BackoffBase: 5 * time.Second, +} + +func (r *Registry) StartHealthMonitor(ctx context.Context, cfg MonitorConfig, logFn func(string)) context.CancelFunc { + if cfg.Interval == 0 { + cfg = defaultMonitorConfig + } + monitorCtx, cancel := context.WithCancel(ctx) + go func() { + ticker := time.NewTicker(cfg.Interval) + defer ticker.Stop() + for { + select { + case <-monitorCtx.Done(): + return + case <-ticker.C: + r.healthCheckRound(monitorCtx, cfg, logFn) + } + } + }() + return cancel +} + +func (r *Registry) healthCheckRound(ctx context.Context, cfg MonitorConfig, logFn func(string)) { + statuses := r.HealthCheck(ctx) + for _, status := range statuses { + if status.Connected { + continue + } + logFn(fmt.Sprintf("server %s unhealthy, attempting reconnect...", status.Name)) + for attempt := 1; attempt <= cfg.MaxRetries; attempt++ { + backoff := cfg.BackoffBase * time.Duration(attempt) + select { + case <-ctx.Done(): + return + case <-time.After(backoff): + } + _, err := r.ReconnectServer(ctx, status.Name) + if err == nil { + logFn(fmt.Sprintf("server %s reconnected", status.Name)) + break + } + if attempt == cfg.MaxRetries { + logFn(fmt.Sprintf("server %s reconnection failed after %d attempts: %v", status.Name, cfg.MaxRetries, err)) + } + } + } +} diff --git a/internal/mcp/registry_test.go b/internal/mcp/registry_test.go new file mode 100644 index 0000000..8dbc60d --- /dev/null +++ b/internal/mcp/registry_test.go @@ -0,0 +1,73 @@ +package mcp + +import ( + "context" + "strings" + "testing" +) + +func TestNewRegistry(t *testing.T) { + r := NewRegistry() + + if r.ToolCount() != 0 { + t.Errorf("ToolCount() = %d, want 0", r.ToolCount()) + } + if r.ServerCount() != 0 { + t.Errorf("ServerCount() = %d, want 0", r.ServerCount()) + } + if tools := r.Tools(); len(tools) != 0 { + t.Errorf("Tools() = %v, want empty", tools) + } +} + +func TestRegistry_CallTool_Unknown(t *testing.T) { + r := NewRegistry() + + result, err := r.CallTool(context.Background(), "nonexistent_tool", nil) + if err != nil { + t.Fatalf("CallTool() unexpected error: %v", err) + } + if !result.IsError { + t.Error("CallTool() IsError = false, want true for unknown tool") + } + if !strings.Contains(result.Content, "unknown tool") { + t.Errorf("CallTool() Content = %q, want to contain 'unknown tool'", result.Content) + } +} + +func TestRegistry_HealthCheck_Empty(t *testing.T) { + r := NewRegistry() + + statuses := r.HealthCheck(context.Background()) + if len(statuses) != 0 { + t.Errorf("HealthCheck() returned %d statuses, want 0", len(statuses)) + } +} + +func TestRegistry_HealthCheck_TracksFailedServers(t *testing.T) { + r := NewRegistry() + + // Simulate a failed server by directly adding to failedServers + r.mu.Lock() + r.failedServers = append(r.failedServers, FailedServer{ + Name: "failed-server", + Reason: "connection refused", + }) + r.mu.Unlock() + + statuses := r.HealthCheck(context.Background()) + if len(statuses) != 1 { + t.Fatalf("HealthCheck() returned %d statuses, want 1", len(statuses)) + } + + status := statuses[0] + if status.Name != "failed-server" { + t.Errorf("status.Name = %q, want 'failed-server'", status.Name) + } + if status.Connected { + t.Error("status.Connected = true, want false") + } + if status.LastError != "connection refused" { + t.Errorf("status.LastError = %q, want 'connection refused'", status.LastError) + } +} diff --git a/internal/mcp/types.go b/internal/mcp/types.go new file mode 100644 index 0000000..477e16f --- /dev/null +++ b/internal/mcp/types.go @@ -0,0 +1,27 @@ +package mcp + +import ( + "ai-agent/internal/llm" +) + +type ServerInfo struct { + Name string + ToolCount int +} + +type ToolResult struct { + Content string + IsError bool +} + +func ToLLMToolDef(name, description string, inputSchema any) llm.ToolDef { + params, _ := inputSchema.(map[string]any) + if params == nil { + params = map[string]any{"type": "object", "properties": map[string]any{}} + } + return llm.ToolDef{ + Name: name, + Description: description, + Parameters: params, + } +} diff --git a/internal/mcp/types_test.go b/internal/mcp/types_test.go new file mode 100644 index 0000000..3f4df67 --- /dev/null +++ b/internal/mcp/types_test.go @@ -0,0 +1,71 @@ +package mcp + +import ( + "testing" +) + +func TestToLLMToolDef(t *testing.T) { + tests := []struct { + name string + toolName string + description string + inputSchema any + wantName string + wantDesc string + wantParams bool // true = should have non-nil params + }{ + { + name: "normal with valid schema", + toolName: "read_file", + description: "Read a file", + inputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + }, + wantName: "read_file", + wantDesc: "Read a file", + wantParams: true, + }, + { + name: "nil schema uses default", + toolName: "noop", + description: "No-op tool", + inputSchema: nil, + wantName: "noop", + wantDesc: "No-op tool", + wantParams: true, + }, + { + name: "non-map schema uses default", + toolName: "bad_schema", + description: "Bad schema tool", + inputSchema: "not a map", + wantName: "bad_schema", + wantDesc: "Bad schema tool", + wantParams: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ToLLMToolDef(tt.toolName, tt.description, tt.inputSchema) + if result.Name != tt.wantName { + t.Errorf("Name = %q, want %q", result.Name, tt.wantName) + } + if result.Description != tt.wantDesc { + t.Errorf("Description = %q, want %q", result.Description, tt.wantDesc) + } + if tt.wantParams && result.Parameters == nil { + t.Error("Parameters should not be nil") + } + // Nil and non-map schemas should get the default object schema. + if tt.inputSchema == nil || func() bool { _, ok := tt.inputSchema.(map[string]any); return !ok }() { + if result.Parameters["type"] != "object" { + t.Errorf("default schema type = %v, want 'object'", result.Parameters["type"]) + } + } + }) + } +} diff --git a/internal/memory/store.go b/internal/memory/store.go new file mode 100644 index 0000000..5c92301 --- /dev/null +++ b/internal/memory/store.go @@ -0,0 +1,259 @@ +package memory + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "time" +) + +type Memory struct { + ID int `json:"id"` + Content string `json:"content"` + Tags []string `json:"tags,omitempty"` + CreatedAt time.Time `json:"created_at"` + LastUsed time.Time `json:"last_used"` +} + +type Store struct { + mu sync.Mutex + path string + memories []Memory + nextID int +} + +func NewStore(path string) *Store { + if path == "" { + home, err := os.UserHomeDir() + if err != nil { + home = "." + } + path = filepath.Join(home, ".config", "ai-agent", "memories.json") + } + s := &Store{path: path} + s.load() + return s +} + +func (s *Store) Save(content string, tags []string) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.nextID++ + mem := Memory{ + ID: s.nextID, + Content: content, + Tags: tags, + CreatedAt: time.Now(), + LastUsed: time.Now(), + } + s.memories = append(s.memories, mem) + if err := s.persist(); err != nil { + return 0, err + } + return mem.ID, nil +} + +func (s *Store) Recall(query string, maxResults int) []Memory { + s.mu.Lock() + defer s.mu.Unlock() + if maxResults <= 0 { + maxResults = 5 + } + queryLower := strings.ToLower(query) + words := strings.Fields(queryLower) + type scored struct { + mem Memory + score int + } + var results []scored + for i := range s.memories { + mem := s.memories[i] + score := 0 + contentLower := strings.ToLower(mem.Content) + for _, w := range words { + if strings.Contains(contentLower, w) { + score += 2 + } + } + for _, tag := range mem.Tags { + tagLower := strings.ToLower(tag) + for _, w := range words { + if strings.Contains(tagLower, w) { + score += 3 + } + } + } + if score > 0 { + results = append(results, scored{mem: mem, score: score}) + } + } + sort.Slice(results, func(i, j int) bool { + if results[i].score != results[j].score { + return results[i].score > results[j].score + } + return results[i].mem.LastUsed.After(results[j].mem.LastUsed) + }) + if len(results) > maxResults { + results = results[:maxResults] + } + now := time.Now() + out := make([]Memory, len(results)) + for i, r := range results { + out[i] = r.mem + for j := range s.memories { + if s.memories[j].ID == r.mem.ID { + s.memories[j].LastUsed = now + break + } + } + } + _ = s.persist() + return out +} + +func (s *Store) Recent(n int) []Memory { + s.mu.Lock() + defer s.mu.Unlock() + if len(s.memories) == 0 { + return nil + } + sorted := make([]Memory, len(s.memories)) + copy(sorted, s.memories) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].LastUsed.After(sorted[j].LastUsed) + }) + if n > len(sorted) { + n = len(sorted) + } + return sorted[:n] +} + +func (s *Store) Count() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.memories) +} + +func (s *Store) Delete(id int) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + for i, mem := range s.memories { + if mem.ID == id { + s.memories = append(s.memories[:i], s.memories[i+1:]...) + return true, s.persist() + } + } + return false, nil +} + +func (s *Store) DeleteByTag(tag string) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + tagLower := strings.ToLower(tag) + var remaining []Memory + deleted := 0 + for _, mem := range s.memories { + found := false + for _, t := range mem.Tags { + if strings.ToLower(t) == tagLower { + found = true + break + } + } + if found { + deleted++ + } else { + remaining = append(remaining, mem) + } + } + s.memories = remaining + if deleted > 0 { + return deleted, s.persist() + } + return deleted, nil +} + +func (s *Store) Update(id int, content string, tags []string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + for i, mem := range s.memories { + if mem.ID == id { + if content != "" { + s.memories[i].Content = content + } + if tags != nil { + s.memories[i].Tags = tags + } + s.memories[i].LastUsed = time.Now() + return true, s.persist() + } + } + return false, nil +} + +func (s *Store) Prune(olderThan time.Duration) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + cutoff := time.Now().Add(-olderThan) + var remaining []Memory + deleted := 0 + for _, mem := range s.memories { + if mem.CreatedAt.Before(cutoff) { + deleted++ + } else { + remaining = append(remaining, mem) + } + } + s.memories = remaining + if deleted > 0 { + return deleted, s.persist() + } + return deleted, nil +} + +func (s *Store) Get(id int) (Memory, bool) { + s.mu.Lock() + defer s.mu.Unlock() + for _, mem := range s.memories { + if mem.ID == id { + return mem, true + } + } + return Memory{}, false +} + +func (s *Store) load() { + data, err := os.ReadFile(s.path) + if err != nil { + return + } + var memories []Memory + if err := json.Unmarshal(data, &memories); err != nil { + return + } + s.memories = memories + for _, m := range s.memories { + if m.ID > s.nextID { + s.nextID = m.ID + } + } +} + +func (s *Store) persist() error { + dir := filepath.Dir(s.path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("create memory dir: %w", err) + } + data, err := json.MarshalIndent(s.memories, "", " ") + if err != nil { + return fmt.Errorf("marshal memories: %w", err) + } + if err := os.WriteFile(s.path, data, 0o644); err != nil { + return fmt.Errorf("write memories: %w", err) + } + return nil +} diff --git a/internal/memory/store_test.go b/internal/memory/store_test.go new file mode 100644 index 0000000..fc78526 --- /dev/null +++ b/internal/memory/store_test.go @@ -0,0 +1,381 @@ +package memory + +import ( + "path/filepath" + "testing" + "time" +) + +func TestStore_Save_And_Count(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "memories.json") + + s := NewStore(path) + if s.Count() != 0 { + t.Fatalf("new store Count = %d, want 0", s.Count()) + } + + id1, err := s.Save("first memory", []string{"tag1"}) + if err != nil { + t.Fatalf("Save returned error: %v", err) + } + if id1 != 1 { + t.Errorf("first Save id = %d, want 1", id1) + } + if s.Count() != 1 { + t.Errorf("Count after first Save = %d, want 1", s.Count()) + } + + id2, err := s.Save("second memory", []string{"tag2"}) + if err != nil { + t.Fatalf("Save returned error: %v", err) + } + if id2 != 2 { + t.Errorf("second Save id = %d, want 2", id2) + } + if s.Count() != 2 { + t.Errorf("Count after second Save = %d, want 2", s.Count()) + } +} + +func TestStore_Recall(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "memories.json") + s := NewStore(path) + + s.Save("the user prefers Go language", []string{"preference", "golang"}) + s.Save("project uses PostgreSQL database", []string{"tech", "database"}) + s.Save("user name is Alice", []string{"name"}) + + t.Run("content match", func(t *testing.T) { + results := s.Recall("Go", 10) + if len(results) == 0 { + t.Fatal("expected results for 'Go' query") + } + found := false + for _, r := range results { + if r.Content == "the user prefers Go language" { + found = true + } + } + if !found { + t.Error("expected to find 'the user prefers Go language'") + } + }) + + t.Run("tag match", func(t *testing.T) { + results := s.Recall("golang", 10) + if len(results) == 0 { + t.Fatal("expected results for 'golang' tag query") + } + if results[0].Content != "the user prefers Go language" { + t.Errorf("top result = %q, want 'the user prefers Go language'", results[0].Content) + } + }) + + t.Run("combined scoring", func(t *testing.T) { + // "database" matches both content and tag for PostgreSQL entry. + results := s.Recall("database", 10) + if len(results) == 0 { + t.Fatal("expected results for 'database' query") + } + if results[0].Content != "project uses PostgreSQL database" { + t.Errorf("top result = %q, want 'project uses PostgreSQL database'", + results[0].Content) + } + }) + + t.Run("maxResults limit", func(t *testing.T) { + results := s.Recall("user", 1) + if len(results) > 1 { + t.Errorf("maxResults=1 but got %d results", len(results)) + } + }) + + t.Run("default maxResults 5 when 0", func(t *testing.T) { + // With 3 memories, should return all 3 (default limit is 5). + results := s.Recall("user", 0) + if len(results) > 5 { + t.Errorf("default maxResults should be 5, got %d results", len(results)) + } + }) + + t.Run("case insensitive", func(t *testing.T) { + results := s.Recall("ALICE", 10) + if len(results) == 0 { + t.Fatal("expected case-insensitive match for 'ALICE'") + } + if results[0].Content != "user name is Alice" { + t.Errorf("result = %q, want 'user name is Alice'", results[0].Content) + } + }) + + t.Run("no matches", func(t *testing.T) { + results := s.Recall("xyzzyzxyz", 10) + if len(results) != 0 { + t.Errorf("expected no results for nonsense query, got %d", len(results)) + } + }) +} + +func TestStore_Recall_TieBreakByRecency(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "memories.json") + s := NewStore(path) + + // Save two memories with the same scoring potential. + s.Save("alpha topic info", []string{"info"}) + // Small delay so LastUsed differs. + time.Sleep(10 * time.Millisecond) + s.Save("beta topic info", []string{"info"}) + + results := s.Recall("info", 10) + if len(results) < 2 { + t.Fatalf("expected at least 2 results, got %d", len(results)) + } + // Both match tag "info" equally (+3), so more recent (beta) should come first. + if results[0].Content != "beta topic info" { + t.Errorf("expected more recent 'beta topic info' first, got %q", results[0].Content) + } +} + +func TestStore_Recent(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "memories.json") + s := NewStore(path) + + s.Save("old memory", nil) + time.Sleep(10 * time.Millisecond) + s.Save("new memory", nil) + + t.Run("ordering by LastUsed", func(t *testing.T) { + recent := s.Recent(2) + if len(recent) != 2 { + t.Fatalf("Recent(2) returned %d, want 2", len(recent)) + } + if recent[0].Content != "new memory" { + t.Errorf("first recent = %q, want 'new memory'", recent[0].Content) + } + if recent[1].Content != "old memory" { + t.Errorf("second recent = %q, want 'old memory'", recent[1].Content) + } + }) + + t.Run("limit exceeds count returns all", func(t *testing.T) { + recent := s.Recent(100) + if len(recent) != 2 { + t.Errorf("Recent(100) returned %d, want 2", len(recent)) + } + }) + + t.Run("empty store", func(t *testing.T) { + emptyPath := filepath.Join(dir, "empty.json") + empty := NewStore(emptyPath) + recent := empty.Recent(5) + if recent != nil { + t.Errorf("empty Recent should return nil, got %v", recent) + } + }) +} + +func TestStore_Persistence_RoundTrip(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "memories.json") + + s1 := NewStore(path) + s1.Save("persistent memory", []string{"test"}) + s1.Save("another memory", []string{"test2"}) + + // Create new store from same path. + s2 := NewStore(path) + if s2.Count() != 2 { + t.Errorf("reloaded Count = %d, want 2", s2.Count()) + } + + // Verify data is intact. + recent := s2.Recent(2) + contents := map[string]bool{} + for _, m := range recent { + contents[m.Content] = true + } + if !contents["persistent memory"] { + t.Error("missing 'persistent memory' after reload") + } + if !contents["another memory"] { + t.Error("missing 'another memory' after reload") + } + + // Verify IDs continue. + id, err := s2.Save("third", nil) + if err != nil { + t.Fatalf("Save after reload: %v", err) + } + if id != 3 { + t.Errorf("continued id = %d, want 3", id) + } +} + +func TestStore_Delete(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "memories.json") + s := NewStore(path) + + id, _ := s.Save("to be deleted", []string{"temp"}) + if s.Count() != 1 { + t.Fatalf("expected 1 memory, got %d", s.Count()) + } + + deleted, err := s.Delete(id) + if err != nil { + t.Fatalf("Delete returned error: %v", err) + } + if !deleted { + t.Error("Delete returned false for existing memory") + } + if s.Count() != 0 { + t.Errorf("Count after delete = %d, want 0", s.Count()) + } + + // Try deleting non-existent. + deleted, err = s.Delete(999) + if err != nil { + t.Fatalf("Delete returned error: %v", err) + } + if deleted { + t.Error("Delete should return false for non-existent memory") + } +} + +func TestStore_Update(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "memories.json") + s := NewStore(path) + + id, _ := s.Save("original content", []string{"original"}) + + updated, err := s.Update(id, "updated content", []string{"updated"}) + if err != nil { + t.Fatalf("Update returned error: %v", err) + } + if !updated { + t.Error("Update returned false for existing memory") + } + + // Verify update. + mem, found := s.Get(id) + if !found { + t.Fatal("memory not found after update") + } + if mem.Content != "updated content" { + t.Errorf("Content = %q, want 'updated content'", mem.Content) + } + if len(mem.Tags) != 1 || mem.Tags[0] != "updated" { + t.Errorf("Tags = %v, want ['updated']", mem.Tags) + } + + // Try updating non-existent. + updated, err = s.Update(999, "test", nil) + if err != nil { + t.Fatalf("Update returned error: %v", err) + } + if updated { + t.Error("Update should return false for non-existent memory") + } +} + +func TestStore_DeleteByTag(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "memories.json") + s := NewStore(path) + + s.Save("keep this 1", []string{"keep"}) + s.Save("delete this", []string{"temp"}) + s.Save("keep this 2", []string{"keep"}) + s.Save("delete this too", []string{"temp"}) + s.Save("also keep", []string{"permanent"}) + + deleted, err := s.DeleteByTag("temp") + if err != nil { + t.Fatalf("DeleteByTag returned error: %v", err) + } + if deleted != 2 { + t.Errorf("DeleteByTag deleted = %d, want 2", deleted) + } + if s.Count() != 3 { + t.Errorf("Count after delete = %d, want 3", s.Count()) + } + + // Verify only temp memories are gone. + results := s.Recall("keep", 10) + if len(results) != 3 { + t.Errorf("Recall returned %d, want 3", len(results)) + } +} + +func TestStore_Get(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "memories.json") + s := NewStore(path) + + id, _ := s.Save("test memory", []string{"tag"}) + + mem, found := s.Get(id) + if !found { + t.Fatal("Get returned false for existing memory") + } + if mem.Content != "test memory" { + t.Errorf("Content = %q, want 'test memory'", mem.Content) + } + if len(mem.Tags) != 1 || mem.Tags[0] != "tag" { + t.Errorf("Tags = %v, want ['tag']", mem.Tags) + } + + // Try getting non-existent. + _, found = s.Get(999) + if found { + t.Error("Get should return false for non-existent memory") + } +} + +func TestStore_UpdatePartial(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "memories.json") + s := NewStore(path) + + id, _ := s.Save("original content", []string{"original", "tags"}) + + // Update only content, keep tags. + updated, err := s.Update(id, "new content", nil) + if err != nil { + t.Fatalf("Update returned error: %v", err) + } + if !updated { + t.Error("Update returned false") + } + + mem, _ := s.Get(id) + if mem.Content != "new content" { + t.Errorf("Content = %q, want 'new content'", mem.Content) + } + // Tags should remain unchanged when nil is passed. + if len(mem.Tags) != 2 { + t.Errorf("Tags = %v, want 2 tags", mem.Tags) + } + + // Update only tags, keep content. + updated, err = s.Update(id, "", []string{"only", "tags"}) + if err != nil { + t.Fatalf("Update returned error: %v", err) + } + if !updated { + t.Error("Update returned false") + } + + mem, _ = s.Get(id) + if mem.Content != "new content" { + t.Errorf("Content changed unexpectedly to %q", mem.Content) + } + if len(mem.Tags) != 2 || mem.Tags[0] != "only" || mem.Tags[1] != "tags" { + t.Errorf("Tags = %v, want ['only', 'tags']", mem.Tags) + } +} diff --git a/internal/memory/tools.go b/internal/memory/tools.go new file mode 100644 index 0000000..8cb866f --- /dev/null +++ b/internal/memory/tools.go @@ -0,0 +1,102 @@ +package memory + +import ( + "ai-agent/internal/llm" +) + +func BuiltinToolDefs() []llm.ToolDef { + return []llm.ToolDef{ + { + Name: "memory_save", + Description: "Save an important fact, user preference, or piece of context to persistent memory. Use this proactively when the user shares information worth remembering across sessions.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "content": map[string]any{ + "type": "string", + "description": "The fact or information to remember.", + }, + "tags": map[string]any{ + "type": "array", + "items": map[string]any{"type": "string"}, + "description": "Optional tags for categorization (e.g., 'preference', 'project', 'name').", + }, + }, + "required": []string{"content"}, + }, + }, + { + Name: "memory_recall", + Description: "Search persistent memory for previously saved facts. Use this when you need to recall user preferences, project details, or other saved context.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "Search query to find relevant memories.", + }, + }, + "required": []string{"query"}, + }, + }, + { + Name: "memory_delete", + Description: "Delete a memory by its ID. Use memory_recall or memory_list first to find the ID of the memory you want to delete.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{ + "type": "integer", + "description": "The ID of the memory to delete (use memory_list or memory_recall to find IDs).", + }, + }, + "required": []string{"id"}, + }, + }, + { + Name: "memory_update", + Description: "Update an existing memory's content or tags. Use memory_recall or memory_list first to find the ID.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{ + "type": "integer", + "description": "The ID of the memory to update (use memory_list or memory_recall to find IDs).", + }, + "content": map[string]any{ + "type": "string", + "description": "New content for the memory.", + }, + "tags": map[string]any{ + "type": "array", + "items": map[string]any{"type": "string"}, + "description": "New tags for the memory.", + }, + }, + "required": []string{"id"}, + }, + }, + { + Name: "memory_list", + Description: "List all stored memories with their IDs, content, and tags. Use this to see what has been saved.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of memories to return (default: 20).", + }, + }, + }, + }, + } +} + +func IsBuiltinTool(name string) bool { + switch name { + case "memory_save", "memory_recall", "memory_delete", "memory_update", "memory_list": + return true + default: + return false + } +} diff --git a/internal/memory/tools_test.go b/internal/memory/tools_test.go new file mode 100644 index 0000000..9d30579 --- /dev/null +++ b/internal/memory/tools_test.go @@ -0,0 +1,48 @@ +package memory + +import "testing" + +func TestBuiltinToolDefs(t *testing.T) { + defs := BuiltinToolDefs() + if len(defs) != 5 { + t.Fatalf("BuiltinToolDefs() returned %d defs, want 5", len(defs)) + } + + names := map[string]bool{} + for _, d := range defs { + names[d.Name] = true + } + + expected := []string{"memory_save", "memory_recall", "memory_delete", "memory_update", "memory_list"} + for _, name := range expected { + if !names[name] { + t.Errorf("missing %s tool definition", name) + } + } +} + +func TestIsBuiltinTool(t *testing.T) { + tests := []struct { + name string + tool string + want bool + }{ + {name: "memory_save", tool: "memory_save", want: true}, + {name: "memory_recall", tool: "memory_recall", want: true}, + {name: "memory_delete", tool: "memory_delete", want: true}, + {name: "memory_update", tool: "memory_update", want: true}, + {name: "memory_list", tool: "memory_list", want: true}, + {name: "unknown tool", tool: "unknown", want: false}, + {name: "empty string", tool: "", want: false}, + {name: "partial match", tool: "memory_", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsBuiltinTool(tt.tool) + if got != tt.want { + t.Errorf("IsBuiltinTool(%q) = %v, want %v", tt.tool, got, tt.want) + } + }) + } +} diff --git a/internal/permission/checker.go b/internal/permission/checker.go new file mode 100644 index 0000000..5963b76 --- /dev/null +++ b/internal/permission/checker.go @@ -0,0 +1,158 @@ +package permission + +import ( + "context" + "sync" + + "ai-agent/internal/db" +) + +type Policy string + +const ( + PolicyAllow Policy = "allow" + PolicyDeny Policy = "deny" + PolicyAsk Policy = "ask" +) + +type Checker struct { + store *db.Store + cache map[string]Policy + mu sync.RWMutex + yolo bool +} + +func NewChecker(store *db.Store, yolo bool) *Checker { + c := &Checker{ + store: store, + cache: make(map[string]Policy), + yolo: yolo, + } + if store != nil { + c.loadFromDB() + } + return c +} + +func (c *Checker) Check(toolName string) Policy { + if c.yolo { + return PolicyAllow + } + c.mu.RLock() + defer c.mu.RUnlock() + if p, ok := c.cache[toolName]; ok { + return p + } + return PolicyAsk +} + +func (c *Checker) SetPolicy(toolName string, policy Policy) { + c.mu.Lock() + c.cache[toolName] = policy + c.mu.Unlock() + if c.store != nil { + c.store.UpsertToolPermission(context.Background(), db.UpsertToolPermissionParams{ + ToolName: toolName, + Policy: string(policy), + }) + } +} + +func (c *Checker) IsYolo() bool { + return c.yolo +} + +func (c *Checker) AllPolicies() map[string]Policy { + c.mu.RLock() + defer c.mu.RUnlock() + result := make(map[string]Policy, len(c.cache)) + for k, v := range c.cache { + result[k] = v + } + return result +} + +func (c *Checker) Reset() { + c.mu.Lock() + c.cache = make(map[string]Policy) + c.mu.Unlock() + if c.store != nil { + c.store.ResetToolPermissions(context.Background()) + } +} + +func (c *Checker) loadFromDB() { + perms, err := c.store.ListToolPermissions(context.Background()) + if err != nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + for _, p := range perms { + switch Policy(p.Policy) { + case PolicyAllow, PolicyDeny, PolicyAsk: + c.cache[p.ToolName] = Policy(p.Policy) + } + } +} + +type ApprovalRequest struct { + ToolName string + Args map[string]any + Response chan ApprovalResponse +} + +type ApprovalResponse struct { + Allowed bool + Always bool +} + +func RequestApproval(toolName string, args map[string]any, callback func(ApprovalRequest)) (bool, bool) { + if callback == nil { + return true, false + } + ch := make(chan ApprovalResponse, 1) + callback(ApprovalRequest{ + ToolName: toolName, + Args: args, + Response: ch, + }) + resp := <-ch + return resp.Allowed, resp.Always +} + +type CheckResult int + +const ( + CheckAllow CheckResult = iota + CheckDeny + CheckAsk +) + +func (c *Checker) ToCheckResult(toolName string) CheckResult { + if c == nil || c.yolo { + return CheckAllow + } + switch c.Check(toolName) { + case PolicyAllow: + return CheckAllow + case PolicyDeny: + return CheckDeny + default: + return CheckAsk + } +} + +func NilSafe(store *db.Store, yolo bool) *Checker { + return NewChecker(store, yolo) +} + +var AlwaysAllow = func(_ ApprovalRequest) {} + +type ErrDenied struct { + ToolName string +} + +func (e *ErrDenied) Error() string { + return "tool call denied by permission policy: " + e.ToolName +} diff --git a/internal/permission/checker_test.go b/internal/permission/checker_test.go new file mode 100644 index 0000000..93e7889 --- /dev/null +++ b/internal/permission/checker_test.go @@ -0,0 +1,98 @@ +package permission + +import ( + "path/filepath" + "testing" + + "ai-agent/internal/db" +) + +func TestChecker_DefaultPolicy(t *testing.T) { + c := NewChecker(nil, false) + if got := c.Check("some_tool"); got != PolicyAsk { + t.Errorf("Check() = %q, want %q", got, PolicyAsk) + } +} + +func TestChecker_Yolo(t *testing.T) { + c := NewChecker(nil, true) + if got := c.Check("any_tool"); got != PolicyAllow { + t.Errorf("Check() = %q, want %q", got, PolicyAllow) + } + if !c.IsYolo() { + t.Error("expected IsYolo() = true") + } +} + +func TestChecker_SetPolicy(t *testing.T) { + c := NewChecker(nil, false) + c.SetPolicy("bash", PolicyAllow) + if got := c.Check("bash"); got != PolicyAllow { + t.Errorf("Check() = %q, want %q", got, PolicyAllow) + } + c.SetPolicy("bash", PolicyDeny) + if got := c.Check("bash"); got != PolicyDeny { + t.Errorf("Check() = %q, want %q", got, PolicyDeny) + } +} + +func TestChecker_WithDB(t *testing.T) { + store, err := db.OpenPath(filepath.Join(t.TempDir(), "test.db")) + if err != nil { + t.Fatal(err) + } + defer store.Close() + c := NewChecker(store, false) + c.SetPolicy("file_write", PolicyAllow) + c2 := NewChecker(store, false) + if got := c2.Check("file_write"); got != PolicyAllow { + t.Errorf("persisted Check() = %q, want %q", got, PolicyAllow) + } +} + +func TestChecker_Reset(t *testing.T) { + c := NewChecker(nil, false) + c.SetPolicy("tool1", PolicyAllow) + c.SetPolicy("tool2", PolicyDeny) + c.Reset() + if got := c.Check("tool1"); got != PolicyAsk { + t.Errorf("after reset Check() = %q, want %q", got, PolicyAsk) + } +} + +func TestChecker_AllPolicies(t *testing.T) { + c := NewChecker(nil, false) + c.SetPolicy("a", PolicyAllow) + c.SetPolicy("b", PolicyDeny) + + policies := c.AllPolicies() + if len(policies) != 2 { + t.Errorf("AllPolicies() len = %d, want 2", len(policies)) + } + if policies["a"] != PolicyAllow { + t.Errorf("policies[a] = %q, want %q", policies["a"], PolicyAllow) + } +} + +func TestToCheckResult(t *testing.T) { + c := NewChecker(nil, false) + c.SetPolicy("allowed", PolicyAllow) + c.SetPolicy("denied", PolicyDeny) + + if c.ToCheckResult("allowed") != CheckAllow { + t.Error("expected CheckAllow for allowed tool") + } + if c.ToCheckResult("denied") != CheckDeny { + t.Error("expected CheckDeny for denied tool") + } + if c.ToCheckResult("unknown") != CheckAsk { + t.Error("expected CheckAsk for unknown tool") + } +} + +func TestToCheckResult_Nil(t *testing.T) { + var c *Checker + if c.ToCheckResult("anything") != CheckAllow { + t.Error("nil checker should return CheckAllow") + } +} diff --git a/internal/skill/manager.go b/internal/skill/manager.go new file mode 100644 index 0000000..872811d --- /dev/null +++ b/internal/skill/manager.go @@ -0,0 +1,118 @@ +package skill + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +type Manager struct { + skills []*Skill + dirs []string +} + +func NewManager(dir string) *Manager { + dirs := []string{} + if dir != "" { + dirs = append(dirs, dir) + } else { + if home, err := os.UserHomeDir(); err == nil { + dirs = append(dirs, filepath.Join(home, ".config", "ai-agent", "skills")) + } + } + return &Manager{dirs: dirs} +} + +func (m *Manager) AddSearchPath(dir string) { + for _, d := range m.dirs { + if d == dir { + return + } + } + m.dirs = append(m.dirs, dir) +} + +func (m *Manager) Names() []string { + var names []string + for _, s := range m.skills { + names = append(names, s.Name) + } + return names +} + +func (m *Manager) LoadAll() error { + for _, dir := range m.dirs { + if err := m.loadFromDir(dir); err != nil { + return err + } + } + return nil +} + +func (m *Manager) loadFromDir(dir string) error { + if dir == "" { + return nil + } + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("read skills dir: %w", err) + } + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") { + continue + } + path := filepath.Join(dir, entry.Name()) + data, err := os.ReadFile(path) + if err != nil { + continue + } + skill, err := parseFrontmatter(string(data)) + if err != nil { + continue + } + skill.Path = path + if skill.Name == "" { + skill.Name = strings.TrimSuffix(entry.Name(), ".md") + } + m.skills = append(m.skills, skill) + } + return nil +} + +func (m *Manager) All() []*Skill { + return m.skills +} + +func (m *Manager) Activate(name string) error { + for _, s := range m.skills { + if s.Name == name { + s.Active = true + return nil + } + } + return fmt.Errorf("skill not found: %s", name) +} + +func (m *Manager) Deactivate(name string) error { + for _, s := range m.skills { + if s.Name == name { + s.Active = false + return nil + } + } + return fmt.Errorf("skill not found: %s", name) +} + +func (m *Manager) ActiveContent() string { + var parts []string + for _, s := range m.skills { + if s.Active && s.Content != "" { + parts = append(parts, fmt.Sprintf("### %s\n%s", s.Name, s.Content)) + } + } + return strings.Join(parts, "\n\n") +} diff --git a/internal/skill/manager_test.go b/internal/skill/manager_test.go new file mode 100644 index 0000000..be35df6 --- /dev/null +++ b/internal/skill/manager_test.go @@ -0,0 +1,169 @@ +package skill + +import ( + "os" + "path/filepath" + "testing" +) + +func TestManager_LoadAll(t *testing.T) { + dir := t.TempDir() + + // Create valid skill files. + os.WriteFile(filepath.Join(dir, "greeting.md"), []byte("---\nname: greeting\ndescription: Say hello\n---\nHello!"), 0o644) + os.WriteFile(filepath.Join(dir, "farewell.md"), []byte("---\nname: farewell\ndescription: Say bye\n---\nGoodbye!"), 0o644) + + // Create a non-.md file (should be skipped). + os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("not a skill"), 0o644) + + // Create a subdirectory (should be skipped). + os.MkdirAll(filepath.Join(dir, "subdir"), 0o755) + + m := NewManager(dir) + if err := m.LoadAll(); err != nil { + t.Fatalf("LoadAll: %v", err) + } + + skills := m.All() + if len(skills) != 2 { + t.Fatalf("loaded %d skills, want 2", len(skills)) + } + + names := map[string]bool{} + for _, s := range skills { + names[s.Name] = true + } + if !names["greeting"] { + t.Error("missing 'greeting' skill") + } + if !names["farewell"] { + t.Error("missing 'farewell' skill") + } +} + +func TestManager_LoadAll_NoFrontmatter(t *testing.T) { + dir := t.TempDir() + + // File without frontmatter uses filename as name. + os.WriteFile(filepath.Join(dir, "plain.md"), []byte("Just content, no frontmatter"), 0o644) + + m := NewManager(dir) + if err := m.LoadAll(); err != nil { + t.Fatalf("LoadAll: %v", err) + } + + skills := m.All() + if len(skills) != 1 { + t.Fatalf("loaded %d skills, want 1", len(skills)) + } + if skills[0].Name != "plain" { + t.Errorf("Name = %q, want 'plain'", skills[0].Name) + } +} + +func TestManager_LoadAll_NonexistentDir(t *testing.T) { + m := NewManager("/nonexistent/path/that/does/not/exist") + if err := m.LoadAll(); err != nil { + t.Fatalf("LoadAll on nonexistent dir should not error, got: %v", err) + } + if len(m.All()) != 0 { + t.Errorf("expected 0 skills from nonexistent dir, got %d", len(m.All())) + } +} + +func TestManager_Activate_Deactivate(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "test.md"), []byte("---\nname: test\n---\nTest content"), 0o644) + + m := NewManager(dir) + m.LoadAll() + + t.Run("activate found", func(t *testing.T) { + err := m.Activate("test") + if err != nil { + t.Fatalf("Activate: %v", err) + } + skill := m.All()[0] + if !skill.Active { + t.Error("skill should be active after Activate") + } + }) + + t.Run("activate not found", func(t *testing.T) { + err := m.Activate("nonexistent") + if err == nil { + t.Error("expected error for nonexistent skill") + } + }) + + t.Run("deactivate found", func(t *testing.T) { + err := m.Deactivate("test") + if err != nil { + t.Fatalf("Deactivate: %v", err) + } + skill := m.All()[0] + if skill.Active { + t.Error("skill should be inactive after Deactivate") + } + }) + + t.Run("deactivate not found", func(t *testing.T) { + err := m.Deactivate("nonexistent") + if err == nil { + t.Error("expected error for nonexistent skill") + } + }) +} + +func TestManager_ActiveContent(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "alpha.md"), []byte("---\nname: alpha\n---\nAlpha content"), 0o644) + os.WriteFile(filepath.Join(dir, "beta.md"), []byte("---\nname: beta\n---\nBeta content"), 0o644) + + m := NewManager(dir) + m.LoadAll() + + t.Run("none active returns empty", func(t *testing.T) { + content := m.ActiveContent() + if content != "" { + t.Errorf("expected empty content, got %q", content) + } + }) + + t.Run("one active returns its content", func(t *testing.T) { + m.Activate("alpha") + content := m.ActiveContent() + if content == "" { + t.Fatal("expected non-empty content") + } + if !contains(content, "Alpha content") { + t.Errorf("content missing 'Alpha content': %q", content) + } + if contains(content, "Beta content") { + t.Errorf("content should not contain inactive 'Beta content': %q", content) + } + m.Deactivate("alpha") + }) + + t.Run("multiple active returns combined", func(t *testing.T) { + m.Activate("alpha") + m.Activate("beta") + content := m.ActiveContent() + if !contains(content, "Alpha content") || !contains(content, "Beta content") { + t.Errorf("combined content missing expected parts: %q", content) + } + }) +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && searchString(s, substr) +} + +func searchString(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/internal/skill/types.go b/internal/skill/types.go new file mode 100644 index 0000000..1bddb80 --- /dev/null +++ b/internal/skill/types.go @@ -0,0 +1,65 @@ +package skill + +import ( + "bufio" + "strings" + + "gopkg.in/yaml.v3" +) + +// Skill represents a loadable skill definition. +type Skill struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Active bool `yaml:"-"` + Content string `yaml:"-"` // markdown body after frontmatter + Path string `yaml:"-"` // file path +} + +// parseFrontmatter extracts YAML frontmatter and markdown body from a skill file. +// Frontmatter is delimited by "---" on the first and closing lines. +func parseFrontmatter(data string) (*Skill, error) { + scanner := bufio.NewScanner(strings.NewReader(data)) + + // Check for opening "---". + if !scanner.Scan() || strings.TrimSpace(scanner.Text()) != "---" { + // No frontmatter — treat entire content as body. + return &Skill{Content: data}, nil + } + + // Read YAML lines until closing "---". + var yamlBuf strings.Builder + foundEnd := false + for scanner.Scan() { + line := scanner.Text() + if strings.TrimSpace(line) == "---" { + foundEnd = true + break + } + yamlBuf.WriteString(line) + yamlBuf.WriteString("\n") + } + + if !foundEnd { + // No closing delimiter — treat as body only. + return &Skill{Content: data}, nil + } + + // Parse YAML frontmatter. + s := &Skill{} + if err := yaml.Unmarshal([]byte(yamlBuf.String()), s); err != nil { + return nil, err + } + + // Remaining content is the markdown body. + var bodyBuf strings.Builder + for scanner.Scan() { + if bodyBuf.Len() > 0 { + bodyBuf.WriteString("\n") + } + bodyBuf.WriteString(scanner.Text()) + } + s.Content = strings.TrimSpace(bodyBuf.String()) + + return s, nil +} diff --git a/internal/skill/types_test.go b/internal/skill/types_test.go new file mode 100644 index 0000000..2db5372 --- /dev/null +++ b/internal/skill/types_test.go @@ -0,0 +1,78 @@ +package skill + +import "testing" + +func TestParseFrontmatter(t *testing.T) { + tests := []struct { + name string + input string + wantName string + wantDesc string + wantContent string + wantErr bool + }{ + { + name: "valid frontmatter", + input: "---\nname: test\ndescription: desc\n---\nBody content", + wantName: "test", + wantDesc: "desc", + wantContent: "Body content", + }, + { + name: "no frontmatter", + input: "Just body", + wantContent: "Just body", + }, + { + name: "missing closing delimiter", + input: "---\nname: test\nBody", + wantContent: "---\nname: test\nBody", + }, + { + name: "invalid YAML", + input: "---\n: :\n---\nbody", + wantErr: true, + }, + { + name: "empty body", + input: "---\nname: test\n---\n", + wantName: "test", + wantContent: "", + }, + { + name: "empty input", + input: "", + wantContent: "", + }, + { + name: "multiline body", + input: "---\nname: multi\n---\nline 1\nline 2\nline 3", + wantName: "multi", + wantContent: "line 1\nline 2\nline 3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + skill, err := parseFrontmatter(tt.input) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if skill.Name != tt.wantName { + t.Errorf("Name = %q, want %q", skill.Name, tt.wantName) + } + if skill.Description != tt.wantDesc { + t.Errorf("Description = %q, want %q", skill.Description, tt.wantDesc) + } + if skill.Content != tt.wantContent { + t.Errorf("Content = %q, want %q", skill.Content, tt.wantContent) + } + }) + } +} diff --git a/internal/tools/definitions.go b/internal/tools/definitions.go new file mode 100644 index 0000000..dd53d44 --- /dev/null +++ b/internal/tools/definitions.go @@ -0,0 +1,45 @@ +package tools + +import ( + "ai-agent/internal/llm" +) + +var builtinToolNames = map[string]bool{ + "grep": true, + "read": true, + "write": true, + "glob": true, + "bash": true, + "ls": true, + "find": true, + "diff": true, + "edit": true, + "mkdir": true, + "remove": true, + "copy": true, + "move": true, + "exists": true, +} + +func AllToolDefs() []llm.ToolDef { + return []llm.ToolDef{ + GrepToolDef(), + ReadToolDef(), + WriteToolDef(), + GlobToolDef(), + BashToolDef(), + LsToolDef(), + FindToolDef(), + DiffToolDef(), + EditToolDef(), + MkdirToolDef(), + RemoveToolDef(), + CopyToolDef(), + MoveToolDef(), + ExistsToolDef(), + } +} + +func IsBuiltinTool(name string) bool { + return builtinToolNames[name] +} diff --git a/internal/tools/tools.go b/internal/tools/tools.go new file mode 100644 index 0000000..65b5787 --- /dev/null +++ b/internal/tools/tools.go @@ -0,0 +1,306 @@ +package tools + +import ( + "ai-agent/internal/llm" +) + +func GrepToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "grep", + Description: "Search for a pattern in files. Use this to find code, text, or values across multiple files.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "pattern": map[string]any{ + "type": "string", + "description": "The regex pattern to search for.", + }, + "path": map[string]any{ + "type": "string", + "description": "Directory path to search in (defaults to current directory).", + }, + "include": map[string]any{ + "type": "string", + "description": "File pattern to include (e.g., '*.go', '*.ts').", + }, + "context": map[string]any{ + "type": "integer", + "description": "Number of lines of context to show around matches (default: 3).", + }, + }, + "required": []string{"pattern"}, + }, + } +} + +func ReadToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "read", + Description: "Read the contents of a file. Use this to view source code, configuration files, or any text file.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to the file to read.", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of lines to read (optional).", + }, + "offset": map[string]any{ + "type": "integer", + "description": "Line number to start reading from (optional, 1-indexed).", + }, + }, + "required": []string{"path"}, + }, + } +} + +func WriteToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "write", + Description: "Write content to a file. Use this to create new files or overwrite existing ones. Creates parent directories if needed.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to the file to write.", + }, + "content": map[string]any{ + "type": "string", + "description": "Content to write to the file.", + }, + }, + "required": []string{"path", "content"}, + }, + } +} + +func GlobToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "glob", + Description: "Find files matching a pattern. Use this to discover files by name patterns like '*.go', '**/*.ts', etc.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "pattern": map[string]any{ + "type": "string", + "description": "Glob pattern to match (e.g., '**/*.go', 'src/**/*.ts').", + }, + "path": map[string]any{ + "type": "string", + "description": "Directory to search in (defaults to current directory).", + }, + }, + "required": []string{"pattern"}, + }, + } +} + +func BashToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "bash", + Description: "Execute a shell command. Use this to run git, npm, go, or other command-line tools. Output is returned after completion.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "command": map[string]any{ + "type": "string", + "description": "The shell command to execute.", + }, + "timeout": map[string]any{ + "type": "integer", + "description": "Timeout in seconds (default: 30, max: 120).", + }, + }, + "required": []string{"command"}, + }, + } +} + +func LsToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "ls", + Description: "List files and directories. Use this to see what's in a directory.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Directory path to list (defaults to current directory).", + }, + }, + }, + } +} + +func FindToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "find", + Description: "Find files or directories by name. Use this to locate specific files when you know all or part of the filename.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + "description": "Name or pattern to search for (supports * and ? wildcards).", + }, + "path": map[string]any{ + "type": "string", + "description": "Directory to search in (defaults to current directory).", + }, + "type": map[string]any{ + "type": "string", + "description": "Type to find: 'f' for files, 'd' for directories (default: both).", + }, + }, + "required": []string{"name"}, + }, + } +} + +func DiffToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "diff", + Description: "Show the differences between the current file content and new content. Use this to preview changes before writing.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to the file to diff.", + }, + "new_content": map[string]any{ + "type": "string", + "description": "The new content to compare against the current file.", + }, + }, + "required": []string{"path", "new_content"}, + }, + } +} + +func EditToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "edit", + Description: "Apply a patch to a file. Use this to make targeted edits to specific lines without overwriting the entire file. The patch format is: @@ -start,count +new_start,new_count @@\nfollowed by lines starting with - (remove), + (add), or (context).", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to the file to edit.", + }, + "patch": map[string]any{ + "type": "string", + "description": "Unified diff patch to apply. Format: @@ -start,count +new_start,new_count @@ followed by -line (remove), +line (add), or context line.", + }, + }, + "required": []string{"path", "patch"}, + }, + } +} + +func MkdirToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "mkdir", + Description: "Create one or more directories. Creates parent directories as needed.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to the directory to create.", + }, + }, + "required": []string{"path"}, + }, + } +} + +func RemoveToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "remove", + Description: "Remove files or directories. Use with caution - this permanently deletes files.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to remove (file or directory).", + }, + "recursive": map[string]any{ + "type": "boolean", + "description": "Remove directories recursively (default: false).", + }, + "force": map[string]any{ + "type": "boolean", + "description": "Ignore nonexistent files (default: false).", + }, + }, + "required": []string{"path"}, + }, + } +} + +func CopyToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "copy", + Description: "Copy a file from source to destination.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "source": map[string]any{ + "type": "string", + "description": "Source path to copy from.", + }, + "destination": map[string]any{ + "type": "string", + "description": "Destination path to copy to.", + }, + }, + "required": []string{"source", "destination"}, + }, + } +} + +func MoveToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "move", + Description: "Move or rename a file or directory.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "source": map[string]any{ + "type": "string", + "description": "Source path to move from.", + }, + "destination": map[string]any{ + "type": "string", + "description": "Destination path to move to.", + }, + }, + "required": []string{"source", "destination"}, + }, + } +} + +func ExistsToolDef() llm.ToolDef { + return llm.ToolDef{ + Name: "exists", + Description: "Check if a file or directory exists and get information about it.", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to check.", + }, + }, + "required": []string{"path"}, + }, + } +} diff --git a/internal/tools/tools_test.go b/internal/tools/tools_test.go new file mode 100644 index 0000000..1531b0d --- /dev/null +++ b/internal/tools/tools_test.go @@ -0,0 +1,156 @@ +package tools + +import ( + "testing" +) + +func TestGrepToolDef(t *testing.T) { + tool := GrepToolDef() + + if tool.Name != "grep" { + t.Errorf("Name = %q, want %q", tool.Name, "grep") + } + if tool.Description == "" { + t.Error("Description should not be empty") + } + if tool.Parameters == nil { + t.Error("Parameters should not be nil") + } +} + +func TestReadToolDef(t *testing.T) { + tool := ReadToolDef() + + if tool.Name != "read" { + t.Errorf("Name = %q, want %q", tool.Name, "read") + } + props := tool.Parameters["properties"].(map[string]any) + if _, ok := props["path"]; !ok { + t.Error("should have path property") + } +} + +func TestWriteToolDef(t *testing.T) { + tool := WriteToolDef() + + if tool.Name != "write" { + t.Errorf("Name = %q, want %q", tool.Name, "write") + } + props := tool.Parameters["properties"].(map[string]any) + if _, ok := props["path"]; !ok { + t.Error("should have path property") + } + if _, ok := props["content"]; !ok { + t.Error("should have content property") + } +} + +func TestGlobToolDef(t *testing.T) { + tool := GlobToolDef() + + if tool.Name != "glob" { + t.Errorf("Name = %q, want %q", tool.Name, "glob") + } +} + +func TestBashToolDef(t *testing.T) { + tool := BashToolDef() + + if tool.Name != "bash" { + t.Errorf("Name = %q, want %q", tool.Name, "bash") + } + props := tool.Parameters["properties"].(map[string]any) + if _, ok := props["command"]; !ok { + t.Error("should have command property") + } +} + +func TestLsToolDef(t *testing.T) { + tool := LsToolDef() + + if tool.Name != "ls" { + t.Errorf("Name = %q, want %q", tool.Name, "ls") + } +} + +func TestFindToolDef(t *testing.T) { + tool := FindToolDef() + + if tool.Name != "find" { + t.Errorf("Name = %q, want %q", tool.Name, "find") + } + props := tool.Parameters["properties"].(map[string]any) + if _, ok := props["name"]; !ok { + t.Error("should have name property") + } +} + +func TestDiffToolDef(t *testing.T) { + tool := DiffToolDef() + + if tool.Name != "diff" { + t.Errorf("Name = %q, want %q", tool.Name, "diff") + } +} + +func TestEditToolDef(t *testing.T) { + tool := EditToolDef() + + if tool.Name != "edit" { + t.Errorf("Name = %q, want %q", tool.Name, "edit") + } +} + +func TestMkdirToolDef(t *testing.T) { + tool := MkdirToolDef() + + if tool.Name != "mkdir" { + t.Errorf("Name = %q, want %q", tool.Name, "mkdir") + } +} + +func TestRemoveToolDef(t *testing.T) { + tool := RemoveToolDef() + + if tool.Name != "remove" { + t.Errorf("Name = %q, want %q", tool.Name, "remove") + } + props := tool.Parameters["properties"].(map[string]any) + if _, ok := props["recursive"]; !ok { + t.Error("should have recursive property") + } + if _, ok := props["force"]; !ok { + t.Error("should have force property") + } +} + +func TestCopyToolDef(t *testing.T) { + tool := CopyToolDef() + + if tool.Name != "copy" { + t.Errorf("Name = %q, want %q", tool.Name, "copy") + } + props := tool.Parameters["properties"].(map[string]any) + if _, ok := props["source"]; !ok { + t.Error("should have source property") + } + if _, ok := props["destination"]; !ok { + t.Error("should have destination property") + } +} + +func TestMoveToolDef(t *testing.T) { + tool := MoveToolDef() + + if tool.Name != "move" { + t.Errorf("Name = %q, want %q", tool.Name, "move") + } +} + +func TestExistsToolDef(t *testing.T) { + tool := ExistsToolDef() + + if tool.Name != "exists" { + t.Errorf("Name = %q, want %q", tool.Name, "exists") + } +} diff --git a/internal/tui/RESPONSIVE_WIDTH.md b/internal/tui/RESPONSIVE_WIDTH.md new file mode 100644 index 0000000..5230879 --- /dev/null +++ b/internal/tui/RESPONSIVE_WIDTH.md @@ -0,0 +1,162 @@ +# Responsive Width Implementation + +## Overview +This document describes the responsive width calculations implemented to prevent horizontal scrolling in the TUI chat interface. + +## Width Calculation Hierarchy + +### 1. Viewport Width (Primary Constraint) +The viewport is the main container for chat content. All other widths derive from this. + +**Formula** (from `model.go:373-380`): +```go +viewportWidth := screenWidth - 1 +if sidePanel.IsVisible() { + viewportWidth = screenWidth - panelWidth - 2 +} +if viewportWidth < 20 { + viewportWidth = 20 // minimum width +} +``` + +**Breakdown**: +- `screenWidth - 1`: Full width minus right edge padding (when panel hidden) +- `screenWidth - panelWidth - 2`: Width minus panel and separator line (when panel visible) +- Minimum 20 characters to ensure readability + +### 2. Content Width (Text Wrapping) +Used for wrapping text in `renderEntries()`, `renderUserMsg()`, `renderAssistantMsg()`, etc. + +**Formula** (from `view.go:422-429`): +```go +contentW := screenWidth - 4 +if sidePanel.IsVisible() { + contentW = screenWidth - panelWidth - 5 +} +if contentW < 20 { + contentW = 20 +} +``` + +**Breakdown**: +- `screenWidth - 4`: Full width with 2-char padding on each side +- `screenWidth - panelWidth - 5`: Accounts for panel, separator, and padding +- Minimum 20 characters + +### 3. Markdown Width (Glamour Rendering) +Used for rendering markdown content via Glamour. + +**Formula** (from `model.go:382-386`): +```go +markdownWidth := viewportWidth - 3 +if markdownWidth < 20 { + markdownWidth = 20 +} +``` + +**Breakdown**: +- Derived from viewport width minus 3 chars for padding/indentation +- Minimum 20 characters + +### 4. Input Width +Matches viewport width exactly for unified appearance. + +**Formula** (from `model.go:431`): +```go +input.SetWidth(viewportWidth) +``` + +## Panel Width Calculation + +Panel width is dynamic based on screen size (from `model.go:365-371`): + +```go +panelWidth := 30 // default +if screenWidth < 100 { + panelWidth = 25 +} else if screenWidth > 160 { + panelWidth = 40 +} +``` + +## Layout Constraints + +### With Panel Visible +``` +┌─────────────────────────────────────────────────┐ +│ Panel (25-40) ││ Chat Viewport │ +│ ││ (screen - panel - 2) │ +│ ││ │ +│ ││ Content wrapped to: │ +│ ││ (screen - panel - 5) │ +└─────────────────────────────────────────────────┘ +``` + +### Without Panel +``` +┌─────────────────────────────────────────────────┐ +│ Chat Viewport (screen - 1) │ +│ │ +│ Content wrapped to: (screen - 4) │ +└─────────────────────────────────────────────────┘ +``` + +## Critical Invariants + +The following invariants are enforced to prevent horizontal scrolling: + +1. **viewportWidth ≤ screenWidth - 1** (or `screenWidth - panelWidth - 1` when panel visible) +2. **contentWidth ≤ viewportWidth** +3. **markdownWidth ≤ viewportWidth** +4. **All widths ≥ 20** (minimum readability) + +## Test Coverage + +Comprehensive tests in `width_test.go` verify: + +- `TestViewportWidthCalculation`: Validates width calculations for various screen sizes +- `TestResponsiveWidthToggle`: Ensures widths adjust correctly when panel is toggled +- `TestMinimumWidthConstraints`: Verifies minimum width enforcement on small screens +- `TestRenderedTextWidth`: Tests actual text wrapping behavior +- `TestLayoutConsistency`: Exhaustive testing across screen sizes 40-200 chars + +## Example Calculations + +### 120-char screen with panel (30 chars) +``` +Viewport: 120 - 30 - 2 = 88 chars +Content: 120 - 30 - 5 = 85 chars +Markdown: 88 - 3 = 85 chars +Input: 88 chars +Total: 30 (panel) + 1 (separator) + 88 (viewport) = 119 ✓ +``` + +### 80-char screen without panel +``` +Viewport: 80 - 1 = 79 chars +Content: 80 - 4 = 76 chars +Markdown: 79 - 3 = 76 chars +Input: 79 chars +Total: 79 chars ✓ +``` + +### 40-char screen with panel (25 chars) - Edge Case +``` +Viewport: 40 - 25 - 2 = 13 → 20 (minimum enforced) +Content: 40 - 25 - 5 = 10 → 20 (minimum enforced) +Markdown: 20 - 3 = 17 → 20 (minimum enforced) +Input: 20 chars +Total: 25 + 1 + 20 = 46 (exceeds screen, but minimum width takes priority) +``` + +**Note**: On very small screens (< 46 chars with panel), the minimum width constraints take precedence. Users should be advised to use larger terminal windows for optimal experience. + +## Responsive Behavior + +When the side panel is toggled: +1. Viewport width recalculates immediately +2. Content is re-wrapped to new width via `invalidateRenderedCache()` +3. Markdown renderer is recreated with new width +4. Input field resizes to match viewport + +This ensures seamless responsive behavior without horizontal scrolling. diff --git a/internal/tui/accessibility.go b/internal/tui/accessibility.go new file mode 100644 index 0000000..617c104 --- /dev/null +++ b/internal/tui/accessibility.go @@ -0,0 +1,240 @@ +package tui + +import ( + "fmt" + "strings" + + "charm.land/lipgloss/v2" +) + +// AccessibilityHelper provides accessibility features like screen reader support. +type AccessibilityHelper struct { + isDark bool + styles AccessibilityStyles + speakFunc func(string) // Function to speak text (for screen readers) + announceFunc func(string) // Function to announce changes +} + +// AccessibilityStyles holds styling. +type AccessibilityStyles struct { + Announce lipgloss.Style +} + +// DefaultAccessibilityStyles returns default styles. +func DefaultAccessibilityStyles(isDark bool) AccessibilityStyles { + return AccessibilityStyles{ + Announce: lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0")), + } +} + +// NewAccessibilityHelper creates a new accessibility helper. +func NewAccessibilityHelper(isDark bool) *AccessibilityHelper { + return &AccessibilityHelper{ + isDark: isDark, + styles: DefaultAccessibilityStyles(isDark), + } +} + +// SetDark updates theme. +func (ah *AccessibilityHelper) SetDark(isDark bool) { + ah.isDark = isDark + ah.styles = DefaultAccessibilityStyles(isDark) +} + +// SetSpeakFunc sets the function to speak text. +func (ah *AccessibilityHelper) SetSpeakFunc(f func(string)) { + ah.speakFunc = f +} + +// SetAnnounceFunc sets the function to announce changes. +func (ah *AccessibilityHelper) SetAnnounceFunc(f func(string)) { + ah.announceFunc = f +} + +// Announce announces a message to the user. +func (ah *AccessibilityHelper) Announce(format string, args ...string) { + if ah.announceFunc != nil { + msg := format + if len(args) > 0 { + msg = fmt.Sprintf(format, args) + } + ah.announceFunc(msg) + } +} + +// Speak speaks text directly. +func (ah *AccessibilityHelper) Speak(text string) { + if ah.speakFunc != nil { + ah.speakFunc(text) + } +} + +// DescribeEntry creates an accessibility description for a chat entry. +func (ah *AccessibilityHelper) DescribeEntry(entry ChatEntry, index int, toolCount int) string { + var desc strings.Builder + + switch entry.Kind { + case "user": + desc.WriteString("User message") + case "assistant": + desc.WriteString("Assistant response") + if entry.ThinkingContent != "" { + desc.WriteString(", has thinking") + } + case "tool_group": + desc.WriteString("Tool execution") + if index >= 0 && index < toolCount { + desc.WriteString(", tool result") + } + case "system": + desc.WriteString("System message") + case "error": + desc.WriteString("Error") + } + + // Add content preview + if entry.Content != "" { + preview := truncateStr(entry.Content, 50) + desc.WriteString(": ") + desc.WriteString(preview) + } + + return desc.String() +} + +// DescribeState creates an accessibility description of the current state. +func (ah *AccessibilityHelper) DescribeState(state State, model, mode string) string { + var desc string + + switch state { + case StateIdle: + desc = "Ready" + case StateWaiting: + desc = "Waiting for response" + case StateStreaming: + desc = "Receiving response" + } + + if model != "" { + desc += ", model: " + model + } + if mode != "" { + desc += ", mode: " + mode + } + + return desc +} + +// DescribeOverlay creates an accessibility description of the current overlay. +func (ah *AccessibilityHelper) DescribeOverlay(overlay OverlayKind) string { + switch overlay { + case OverlayNone: + return "" + case OverlayHelp: + return "Help overlay open" + case OverlayCompletion: + return "Completion menu open" + case OverlayModelPicker: + return "Model picker open" + case OverlayPlanForm: + return "Plan form open" + case OverlaySessionsPicker: + return "Sessions picker open" + default: + return "Overlay open" + } +} + +// DescribeTools creates an accessibility description of tool status. +func (ah *AccessibilityHelper) DescribeTools(pending, total int) string { + if pending == 0 && total == 0 { + return "No tools running" + } + if pending > 0 { + return fmt.Sprintf("%d tool running", pending) + } + return fmt.Sprintf("%d tools completed", total) +} + +// truncate truncates a string to maxLength. +func truncateStr(s string, maxLength int) string { + if len(s) <= maxLength { + return s + } + return s[:maxLength-3] + "..." +} + +// AccessibilityLabel returns an accessibility label for a view element. +func AccessibilityLabel(role, name string, props ...string) string { + var b strings.Builder + b.WriteString(role) + b.WriteString(": ") + b.WriteString(name) + + for _, p := range props { + b.WriteString(", ") + b.WriteString(p) + } + + return b.String() +} + +// FocusOrder represents the focus order for keyboard navigation. +type FocusOrder struct { + Current int + Items []Focusable +} + +// Focusable is an interface for focusable elements. +type Focusable interface { + Focus() error + Blur() error + IsFocused() bool +} + +// NewFocusOrder creates a new focus order. +func NewFocusOrder(items []Focusable) *FocusOrder { + return &FocusOrder{ + Current: 0, + Items: items, + } +} + +// Next moves focus to the next item. +func (fo *FocusOrder) Next() { + if len(fo.Items) == 0 { + return + } + fo.Current = (fo.Current + 1) % len(fo.Items) + fo.focusCurrent() +} + +// Prev moves focus to the previous item. +func (fo *FocusOrder) Prev() { + if len(fo.Items) == 0 { + return + } + fo.Current-- + if fo.Current < 0 { + fo.Current = len(fo.Items) - 1 + } + fo.focusCurrent() +} + +// Current returns the currently focused item. +func (fo *FocusOrder) CurrentItem() Focusable { + if fo.Current >= 0 && fo.Current < len(fo.Items) { + return fo.Items[fo.Current] + } + return nil +} + +func (fo *FocusOrder) focusCurrent() { + for i, item := range fo.Items { + if i == fo.Current { + item.Focus() + } else { + item.Blur() + } + } +} diff --git a/internal/tui/adapter.go b/internal/tui/adapter.go new file mode 100644 index 0000000..29fdba5 --- /dev/null +++ b/internal/tui/adapter.go @@ -0,0 +1,50 @@ +package tui + +import ( + "time" + + tea "charm.land/bubbletea/v2" +) + +// Adapter bridges the agent.Output interface to BubbleTea messages. +type Adapter struct { + program *tea.Program +} + +// NewAdapter creates an Adapter that sends messages to the given program. +func NewAdapter(p *tea.Program) *Adapter { + return &Adapter{program: p} +} + +func (a *Adapter) StreamText(text string) { + sendMsg(a.program, StreamTextMsg{Text: text}) +} + +func (a *Adapter) StreamDone(evalCount, promptTokens int) { + sendMsg(a.program, StreamDoneMsg{EvalCount: evalCount, PromptTokens: promptTokens}) +} + +func (a *Adapter) ToolCallStart(name string, args map[string]any) { + sendMsg(a.program, ToolCallStartMsg{Name: name, Args: args, StartTime: time.Now()}) +} + +func (a *Adapter) ToolCallResult(name string, result string, isError bool, duration time.Duration) { + sendMsg(a.program, ToolCallResultMsg{Name: name, Result: result, IsError: isError, Duration: duration}) +} + +func (a *Adapter) SystemMessage(msg string) { + sendMsg(a.program, SystemMessageMsg{Msg: msg}) +} + +func (a *Adapter) Error(msg string) { + // Log error for debugging + if len(msg) > 100 { + msg = msg[:97] + "..." + } + sendMsg(a.program, ErrorMsg{Msg: msg}) +} + +// Done sends the final completion message. +func (a *Adapter) Done() { + sendMsg(a.program, AgentDoneMsg{}) +} diff --git a/internal/tui/clipboard_test.go b/internal/tui/clipboard_test.go new file mode 100644 index 0000000..fd0727c --- /dev/null +++ b/internal/tui/clipboard_test.go @@ -0,0 +1,127 @@ +package tui + +import ( + "testing" +) + +func TestLastAssistantContent(t *testing.T) { + t.Run("found", func(t *testing.T) { + m := newTestModel(t) + m.entries = []ChatEntry{ + {Kind: "user", Content: "hello"}, + {Kind: "assistant", Content: "world"}, + } + got := m.lastAssistantContent() + if got != "world" { + t.Errorf("expected 'world', got %q", got) + } + }) + + t.Run("not_found", func(t *testing.T) { + m := newTestModel(t) + m.entries = []ChatEntry{ + {Kind: "user", Content: "hello"}, + {Kind: "system", Content: "info"}, + } + got := m.lastAssistantContent() + if got != "" { + t.Errorf("expected empty string, got %q", got) + } + }) + + t.Run("returns_last", func(t *testing.T) { + m := newTestModel(t) + m.entries = []ChatEntry{ + {Kind: "assistant", Content: "first"}, + {Kind: "user", Content: "question"}, + {Kind: "assistant", Content: "second"}, + } + got := m.lastAssistantContent() + if got != "second" { + t.Errorf("expected 'second', got %q", got) + } + }) + + t.Run("empty_entries", func(t *testing.T) { + m := newTestModel(t) + m.entries = nil + got := m.lastAssistantContent() + if got != "" { + t.Errorf("expected empty string, got %q", got) + } + }) +} + +func TestCopyLast_OnlyWhenIdleAndEmpty(t *testing.T) { + t.Run("idle_empty_with_assistant", func(t *testing.T) { + m := newTestModel(t) + m.state = StateIdle + m.entries = []ChatEntry{ + {Kind: "assistant", Content: "response text"}, + } + m.input.SetValue("") + + _, cmd := m.Update(ctrlKey('y')) + if cmd == nil { + t.Error("expected a command to be returned for copy") + } + }) + + t.Run("non_empty_input_no_trigger", func(t *testing.T) { + m := newTestModel(t) + m.state = StateIdle + m.entries = []ChatEntry{ + {Kind: "assistant", Content: "response text"}, + } + m.input.SetValue("some text") + + _, cmd := m.Update(ctrlKey('y')) + // When input is non-empty, ctrl+y should not trigger copy. + // The cmd may be non-nil (textarea update), but no copy should occur. + // Verify no system message about clipboard appears. + if cmd != nil { + msg := cmd() + if sysMsg, ok := msg.(SystemMessageMsg); ok { + if sysMsg.Msg == "Copied to clipboard." { + t.Error("should not trigger copy when input is non-empty") + } + } + } + }) + + t.Run("non_idle_no_trigger", func(t *testing.T) { + m := newTestModel(t) + m.state = StateStreaming + m.entries = []ChatEntry{ + {Kind: "assistant", Content: "response text"}, + } + m.input.SetValue("") + + initialEntryCount := len(m.entries) + m.Update(ctrlKey('y')) + // Should not add any system message about clipboard + if len(m.entries) > initialEntryCount { + t.Error("should not trigger copy when not idle") + } + }) + + t.Run("no_assistant_entries", func(t *testing.T) { + m := newTestModel(t) + m.state = StateIdle + m.entries = []ChatEntry{ + {Kind: "user", Content: "hello"}, + } + m.input.SetValue("") + + _, cmd := m.Update(ctrlKey('y')) + // Should not return a copy command when there's no assistant content + if cmd != nil { + msg := cmd() + if sysMsg, ok := msg.(SystemMessageMsg); ok { + if sysMsg.Msg == "Copied to clipboard." { + t.Error("should not trigger copy when no assistant content") + } + } + } + }) +} diff --git a/internal/tui/commit.go b/internal/tui/commit.go new file mode 100644 index 0000000..bc49d8d --- /dev/null +++ b/internal/tui/commit.go @@ -0,0 +1,79 @@ +package tui + +import ( + "bytes" + "context" + "fmt" + "os/exec" + "strings" + + "ai-agent/internal/llm" + + tea "charm.land/bubbletea/v2" +) + +func runCommit(client llm.Client, model string, extraMsg string) tea.Cmd { + return func() tea.Msg { + diff, err := gitDiff() + if err != nil { + return CommitResultMsg{Err: fmt.Errorf("git diff: %w", err)} + } + if strings.TrimSpace(diff) == "" { + return CommitResultMsg{Err: fmt.Errorf("no staged changes (use `git add` first)")} + } + if len(diff) > 8000 { + diff = diff[:8000] + "\n... (truncated)" + } + prompt := "Write a concise git commit message for the following staged diff. " + + "Return ONLY the commit message, no explanation or markdown. " + + "Use conventional commit style (e.g. feat:, fix:, refactor:). " + + "Keep the first line under 72 characters." + if extraMsg != "" { + prompt += "\n\nAdditional context: " + extraMsg + } + prompt += "\n\nDiff:\n" + diff + var msgBuf strings.Builder + err = client.ChatStream(context.Background(), llm.ChatOptions{ + Messages: []llm.Message{{Role: "user", Content: prompt}}, + System: "You are a helpful assistant that writes git commit messages.", + }, func(chunk llm.StreamChunk) error { + if chunk.Text != "" { + msgBuf.WriteString(chunk.Text) + } + return nil + }) + if err != nil { + return CommitResultMsg{Err: fmt.Errorf("LLM error: %w", err)} + } + commitMsg := strings.TrimSpace(msgBuf.String()) + if commitMsg == "" { + return CommitResultMsg{Err: fmt.Errorf("LLM returned empty commit message")} + } + commitMsg += fmt.Sprintf("\n\nAssisted-by: ai-agent (%s)", model) + if err := gitCommit(commitMsg); err != nil { + return CommitResultMsg{Err: fmt.Errorf("git commit: %w", err)} + } + return CommitResultMsg{Message: commitMsg} + } +} + +func gitDiff() (string, error) { + cmd := exec.Command("git", "diff", "--cached", "--stat") + stat, _ := cmd.Output() + cmd = exec.Command("git", "diff", "--cached") + out, err := cmd.Output() + if err != nil { + return "", err + } + return string(stat) + "\n" + string(out), nil +} + +func gitCommit(msg string) error { + cmd := exec.Command("git", "commit", "-m", msg) + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("%s: %s", err, stderr.String()) + } + return nil +} diff --git a/internal/tui/complete.go b/internal/tui/complete.go new file mode 100644 index 0000000..0fd233e --- /dev/null +++ b/internal/tui/complete.go @@ -0,0 +1,339 @@ +package tui + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + + "ai-agent/internal/command" + "ai-agent/internal/config" + "ai-agent/internal/mcp" +) + +type Completion struct { + Label string + Insert string + Category string + Description string + Index int +} + +type Completer struct { + commands []*command.Command + models []string + skills []string + agents []string + workDir string + registry *mcp.Registry + ignorePatterns *config.IgnorePatterns +} + +func NewCompleter(cmdReg *command.Registry, models, skills, agents []string, registry *mcp.Registry) *Completer { + workDir, _ := os.Getwd() + return &Completer{ + commands: cmdReg.All(), + models: models, + skills: skills, + agents: agents, + workDir: workDir, + registry: registry, + } +} + +func (c *Completer) Complete(input string) []Completion { + var completions []Completion + + if strings.HasPrefix(input, "/") { + completions = c.completeCommand(input) + } else if strings.HasPrefix(input, "@") { + completions = c.completeAgentOrFile(input) + } else if strings.HasPrefix(input, "#") { + completions = c.completeSkill(input) + } + + return completions +} + +func (c *Completer) completeCommand(input string) []Completion { + var completions []Completion + input = strings.TrimPrefix(input, "/") + + for _, cmd := range c.commands { + if strings.HasPrefix(cmd.Name, input) { + comp := Completion{ + Label: "/" + cmd.Name, + Insert: "/" + cmd.Name + " ", + Category: "command", + } + if cmd.Usage != "" { + parts := strings.Fields(cmd.Usage) + if len(parts) > 1 { + comp.Label = "/" + cmd.Name + " " + parts[1] + } + } + completions = append(completions, comp) + } + + for _, alias := range cmd.Aliases { + if strings.HasPrefix(alias, input) { + completions = append(completions, Completion{ + Label: "/" + alias, + Insert: "/" + alias + " ", + Category: "command", + }) + } + } + } + + return completions +} + +func (c *Completer) completeAgentOrFile(input string) []Completion { + var completions []Completion + input = strings.TrimPrefix(input, "@") + + // Always show agents first + for _, agent := range c.agents { + if strings.HasPrefix(agent, input) { + completions = append(completions, Completion{ + Label: "@" + agent, + Insert: "@" + agent + " ", + Category: "agent", + }) + } + } + + // Always append file results (not just when no agents match) + completions = append(completions, c.completeFile(input)...) + + return completions +} + +func (c *Completer) completeFile(input string) []Completion { + var completions []Completion + + // Determine the directory to list + dir := c.workDir + if strings.Contains(input, "/") { + // User is typing a path + lastSlash := strings.LastIndex(input, "/") + dirPart := input[:lastSlash] + if !strings.HasPrefix(dirPart, "/") { + dirPart = filepath.Join(c.workDir, dirPart) + } + if info, err := os.Stat(dirPart); err == nil && info.IsDir() { + dir = dirPart + } + } + + // Read directory entries + entries, err := os.ReadDir(dir) + if err != nil { + return completions + } + + prefix := input + if strings.Contains(input, "/") { + prefix = input[strings.LastIndex(input, "/")+1:] + } + + for _, entry := range entries { + name := entry.Name() + // Skip hidden files unless user explicitly types . + if strings.HasPrefix(name, ".") && !strings.HasPrefix(prefix, ".") { + continue + } + // Skip entries matching ignore patterns. + if c.ignorePatterns.Match(name) { + continue + } + + if strings.HasPrefix(name, prefix) { + isDir := entry.IsDir() + displayName := name + insertName := name + + if isDir { + displayName += "/" + insertName += "/" + } + + // Build full path relative to input + if strings.Contains(input, "/") { + dirPath := input[:strings.LastIndex(input, "/")+1] + displayName = dirPath + displayName + insertName = dirPath + insertName + } else if dir != c.workDir { + relPath, _ := filepath.Rel(c.workDir, dir) + if relPath != "." { + displayName = relPath + "/" + name + if isDir { + displayName += "/" + } else { + insertName = relPath + "/" + insertName + } + } + } + + category := "file" + if isDir { + category = "folder" + } + + completions = append(completions, Completion{ + Label: "@" + displayName, + Insert: "@" + insertName + " ", + Category: category, + }) + } + } + + return completions +} + +// CompleteFilePath lists directory contents at a given relative path. +// Used for folder drill-down in the completion modal. +func (c *Completer) CompleteFilePath(relPath string) []Completion { + var completions []Completion + + dir := filepath.Join(c.workDir, relPath) + entries, err := os.ReadDir(dir) + if err != nil { + return completions + } + + for _, entry := range entries { + name := entry.Name() + if strings.HasPrefix(name, ".") { + continue + } + // Skip entries matching ignore patterns. + if c.ignorePatterns.Match(name) { + continue + } + + isDir := entry.IsDir() + displayName := name + insertPath := relPath + if insertPath != "" && !strings.HasSuffix(insertPath, "/") { + insertPath += "/" + } + insertPath += name + + if isDir { + displayName += "/" + } + + category := "file" + if isDir { + category = "folder" + } + + completions = append(completions, Completion{ + Label: displayName, + Insert: "@" + insertPath + " ", + Category: category, + }) + } + + return completions +} + +func (c *Completer) completeSkill(input string) []Completion { + var completions []Completion + input = strings.TrimPrefix(input, "#") + + for _, skill := range c.skills { + if strings.HasPrefix(skill, input) { + completions = append(completions, Completion{ + Label: "#" + skill, + Insert: "#" + skill + " ", + Category: "skill", + }) + } + } + + return completions +} + +// FilterCompletions filters completions by case-insensitive substring match on Label. +func FilterCompletions(items []Completion, query string) []Completion { + if query == "" { + return items + } + q := strings.ToLower(query) + var filtered []Completion + for _, item := range items { + if strings.Contains(strings.ToLower(item.Label), q) { + filtered = append(filtered, item) + } + } + return filtered +} + +// SearchFiles performs an async vecgrep search via the MCP registry. +func (c *Completer) SearchFiles(ctx context.Context, query string) []Completion { + if c.registry == nil || query == "" { + return nil + } + + result, err := c.registry.CallTool(ctx, "vecgrep_search", map[string]any{ + "query": query, + "limit": 10, + }) + if err != nil { + return nil + } + + var results []Completion + // Parse the result content as JSON array of file paths or objects + var searchResults []struct { + Path string `json:"path"` + Score float64 `json:"score"` + } + if err := json.Unmarshal([]byte(result.Content), &searchResults); err != nil { + // Try as simple string lines + for _, line := range strings.Split(result.Content, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + results = append(results, Completion{ + Label: "@" + line, + Insert: "@" + line + " ", + Category: "search_result", + Description: "vecgrep match", + }) + } + return results + } + + for _, sr := range searchResults { + results = append(results, Completion{ + Label: "@" + sr.Path, + Insert: "@" + sr.Path + " ", + Category: "search_result", + Description: "vecgrep match", + }) + } + return results +} + +func (c *Completer) UpdateModels(models []string) { + c.models = models +} + +func (c *Completer) UpdateSkills(skills []string) { + c.skills = skills +} + +func (c *Completer) UpdateAgents(agents []string) { + c.agents = agents +} + +// SetIgnorePatterns sets the ignore patterns used to filter file completions. +func (c *Completer) SetIgnorePatterns(patterns *config.IgnorePatterns) { + c.ignorePatterns = patterns +} diff --git a/internal/tui/complete_test.go b/internal/tui/complete_test.go new file mode 100644 index 0000000..76b6f86 --- /dev/null +++ b/internal/tui/complete_test.go @@ -0,0 +1,169 @@ +package tui + +import ( + "testing" + + "ai-agent/internal/command" +) + +func TestCompleter_Complete(t *testing.T) { + reg := command.NewRegistry() + command.RegisterBuiltins(reg) + c := NewCompleter(reg, []string{"model-a"}, []string{"skill-a", "skill-b"}, []string{"agent-x"}, nil) + + t.Run("slash_dispatches_to_commands", func(t *testing.T) { + results := c.Complete("/h") + if len(results) == 0 { + t.Error("expected command completions for /h") + } + for _, r := range results { + if r.Category != "command" { + t.Errorf("expected category 'command', got %q", r.Category) + } + } + }) + + t.Run("at_dispatches_to_agents", func(t *testing.T) { + results := c.Complete("@agent") + found := false + for _, r := range results { + if r.Category == "agent" { + found = true + } + } + if !found { + t.Error("expected agent completions for @agent") + } + }) + + t.Run("hash_dispatches_to_skills", func(t *testing.T) { + results := c.Complete("#skill") + if len(results) == 0 { + t.Error("expected skill completions for #skill") + } + for _, r := range results { + if r.Category != "skill" { + t.Errorf("expected category 'skill', got %q", r.Category) + } + } + }) + + t.Run("plain_returns_nothing", func(t *testing.T) { + results := c.Complete("hello") + if len(results) != 0 { + t.Errorf("expected no completions for plain text, got %d", len(results)) + } + }) +} + +func TestCompleteCommand(t *testing.T) { + reg := command.NewRegistry() + command.RegisterBuiltins(reg) + c := NewCompleter(reg, nil, nil, nil, nil) + + t.Run("prefix_matching", func(t *testing.T) { + results := c.Complete("/hel") + found := false + for _, r := range results { + if r.Insert == "/help " { + found = true + } + } + if !found { + t.Error("expected /help completion for prefix /hel") + } + }) + + t.Run("alias_matching", func(t *testing.T) { + // /h is an alias for /help + results := c.Complete("/h") + if len(results) == 0 { + t.Error("expected completions for /h (alias)") + } + }) + + t.Run("usage_suffix_in_label", func(t *testing.T) { + // /model has Usage: "/model [name|list|fast|smart]" + results := c.Complete("/model") + for _, r := range results { + if r.Insert == "/model " { + // The label should include usage args from the Usage field. + if r.Label == "/model" { + // Label should have usage suffix if Usage has args. + // Actually, let's check what the code does: + // The code checks if cmd.Usage has >1 field. + // "/model [name|list|fast|smart]" -> fields: ["/model", "[name|list|fast|smart]"] + // So label should be "/model [name|list|fast|smart]" + t.Error("label should include usage args") + } + } + } + }) + + t.Run("no_matches", func(t *testing.T) { + results := c.Complete("/zzzzz") + if len(results) != 0 { + t.Errorf("expected no completions for /zzzzz, got %d", len(results)) + } + }) +} + +func TestCompleteSkill(t *testing.T) { + reg := command.NewRegistry() + c := NewCompleter(reg, nil, []string{"coding", "writing", "debugging"}, nil, nil) + + t.Run("prefix_matching", func(t *testing.T) { + results := c.Complete("#cod") + if len(results) != 1 { + t.Fatalf("expected 1 match for #cod, got %d", len(results)) + } + if results[0].Label != "#coding" { + t.Errorf("expected '#coding', got %q", results[0].Label) + } + if results[0].Category != "skill" { + t.Errorf("expected category 'skill', got %q", results[0].Category) + } + }) + + t.Run("all_match_empty_prefix", func(t *testing.T) { + results := c.Complete("#") + if len(results) != 3 { + t.Errorf("expected 3 matches for #, got %d", len(results)) + } + }) + + t.Run("no_matches", func(t *testing.T) { + results := c.Complete("#zzz") + if len(results) != 0 { + t.Errorf("expected no matches for #zzz, got %d", len(results)) + } + }) +} + +func TestCompleterUpdateModels(t *testing.T) { + reg := command.NewRegistry() + c := NewCompleter(reg, []string{"old-model"}, nil, nil, nil) + + c.UpdateModels([]string{"new-model-a", "new-model-b"}) + + if len(c.models) != 2 { + t.Errorf("expected 2 models, got %d", len(c.models)) + } + if c.models[0] != "new-model-a" { + t.Errorf("expected 'new-model-a', got %q", c.models[0]) + } +} + +func TestCompleterUpdateAgents(t *testing.T) { + reg := command.NewRegistry() + c := NewCompleter(reg, nil, nil, []string{"old-agent"}, nil) + + c.UpdateAgents([]string{"new-agent"}) + + if len(c.agents) != 1 { + t.Errorf("expected 1 agent, got %d", len(c.agents)) + } + if c.agents[0] != "new-agent" { + t.Errorf("expected 'new-agent', got %q", c.agents[0]) + } +} diff --git a/internal/tui/contextmenu.go b/internal/tui/contextmenu.go new file mode 100644 index 0000000..ff66fc1 --- /dev/null +++ b/internal/tui/contextmenu.go @@ -0,0 +1,133 @@ +package tui + +import ( + "charm.land/lipgloss/v2" +) + +// ContextMenuItem represents an item in a context menu. +type ContextMenuItem struct { + Label string + Action string + Shortcut string +} + +// ContextMenuState holds the state for a context menu. +type ContextMenuState struct { + X, Y int + Items []ContextMenuItem + Selected int + Active bool + isDark bool + styles ContextMenuStyles +} + +// ContextMenuStyles holds styling for context menus. +type ContextMenuStyles struct { + Item lipgloss.Style + Selected lipgloss.Style + Shortcut lipgloss.Style + Border lipgloss.Style +} + +// DefaultContextMenuStyles returns default styles. +func DefaultContextMenuStyles(isDark bool) ContextMenuStyles { + if isDark { + return ContextMenuStyles{ + Item: lipgloss.NewStyle().Foreground(lipgloss.Color("#d8dee9")), + Selected: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#88c0d0")).Background(lipgloss.Color("#3b4252")), + Shortcut: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Border: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + } + } + return ContextMenuStyles{ + Item: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Selected: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#4f8f8f")).Background(lipgloss.Color("#e5e9f0")), + Shortcut: lipgloss.NewStyle().Foreground(lipgloss.Color("#9ca0a8")), + Border: lipgloss.NewStyle().Foreground(lipgloss.Color("#9ca0a8")), + } +} + +// NewContextMenuState creates a new context menu state. +func NewContextMenuState(items []ContextMenuItem, x, y int, isDark bool) *ContextMenuState { + return &ContextMenuState{ + X: x, + Y: y, + Items: items, + Selected: 0, + Active: true, + isDark: isDark, + styles: DefaultContextMenuStyles(isDark), + } +} + +// Activate shows the context menu at position. +func (cm *ContextMenuState) Activate(x, y int, items []ContextMenuItem) { + cm.X = x + cm.Y = y + cm.Items = items + cm.Selected = 0 + cm.Active = true +} + +// Deactivate hides the context menu. +func (cm *ContextMenuState) Deactivate() { + cm.Active = false +} + +// IsActive returns true if the menu is visible. +func (cm *ContextMenuState) IsActive() bool { + return cm.Active +} + +// SelectedAction returns the action of the selected item. +func (cm *ContextMenuState) SelectedAction() string { + if cm.Selected >= 0 && cm.Selected < len(cm.Items) { + return cm.Items[cm.Selected].Action + } + return "" +} + +// MoveUp selects the previous item. +func (cm *ContextMenuState) MoveUp() { + if cm.Selected > 0 { + cm.Selected-- + } +} + +// MoveDown selects the next item. +func (cm *ContextMenuState) MoveDown() { + if cm.Selected < len(cm.Items)-1 { + cm.Selected++ + } +} + +// Render returns the context menu view. +func (cm *ContextMenuState) Render(width int) string { + if !cm.Active { + return "" + } + + styles := DefaultContextMenuStyles(cm.isDark) + + var b string + for i, item := range cm.Items { + row := " " + item.Label + if item.Shortcut != "" { + row += " " + styles.Shortcut.Render(item.Shortcut) + } + + if i == cm.Selected { + b += styles.Selected.Render(row) + "\n" + } else { + b += styles.Item.Render(row) + "\n" + } + } + + // Wrap in border + box := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color("#4c566a")). + Padding(0, 1) + + return box.Render(b) +} diff --git a/internal/tui/diff.go b/internal/tui/diff.go new file mode 100644 index 0000000..1490be6 --- /dev/null +++ b/internal/tui/diff.go @@ -0,0 +1,196 @@ +package tui + +import ( + "fmt" + "os" + "strings" +) + +// DiffLineKind represents the type of a diff line. +type DiffLineKind int + +const ( + DiffContext DiffLineKind = iota + DiffAdded + DiffRemoved +) + +// DiffLine is a single line in a unified diff. +type DiffLine struct { + Kind DiffLineKind + Content string +} + +// readFileForDiff extracts a file path from tool args and reads its content. +func readFileForDiff(rawArgs map[string]any) string { + for _, key := range []string{"path", "file_path", "filename", "file"} { + if p, ok := rawArgs[key].(string); ok { + data, err := os.ReadFile(p) + if err != nil { + return "" + } + return string(data) + } + } + return "" +} + +// computeDiff computes a line-level diff between before and after text. +// Returns nil if the texts are identical. +func computeDiff(before, after string) []DiffLine { + if before == after { + return nil + } + + beforeLines := splitLines(before) + afterLines := splitLines(after) + + lcs := lcsLines(beforeLines, afterLines) + + var all []DiffLine + bi, ai, li := 0, 0, 0 + + for li < len(lcs) { + for bi < len(beforeLines) && beforeLines[bi] != lcs[li] { + all = append(all, DiffLine{DiffRemoved, beforeLines[bi]}) + bi++ + } + for ai < len(afterLines) && afterLines[ai] != lcs[li] { + all = append(all, DiffLine{DiffAdded, afterLines[ai]}) + ai++ + } + all = append(all, DiffLine{DiffContext, lcs[li]}) + bi++ + ai++ + li++ + } + for bi < len(beforeLines) { + all = append(all, DiffLine{DiffRemoved, beforeLines[bi]}) + bi++ + } + for ai < len(afterLines) { + all = append(all, DiffLine{DiffAdded, afterLines[ai]}) + ai++ + } + + return filterContext(all, 3) +} + +// renderDiff renders diff lines with styles, capping output at maxLines. +func renderDiff(lines []DiffLine, styles Styles, maxLines int) string { + if len(lines) == 0 { + return "" + } + + var b strings.Builder + displayed := 0 + + for _, line := range lines { + if maxLines > 0 && displayed >= maxLines { + b.WriteString(styles.DiffHeader.Render(fmt.Sprintf(" ... %d more lines", len(lines)-displayed))) + b.WriteString("\n") + break + } + + switch line.Kind { + case DiffAdded: + b.WriteString(styles.DiffAdded.Render("+ " + line.Content)) + case DiffRemoved: + b.WriteString(styles.DiffRemoved.Render("- " + line.Content)) + case DiffContext: + b.WriteString(styles.DiffContext.Render(" " + line.Content)) + } + b.WriteString("\n") + displayed++ + } + + return b.String() +} + +// lcsLines computes the longest common subsequence of two string slices. +func lcsLines(a, b []string) []string { + m, n := len(a), len(b) + if m == 0 || n == 0 { + return nil + } + + dp := make([][]int, m+1) + for i := range dp { + dp[i] = make([]int, n+1) + } + + for i := 1; i <= m; i++ { + for j := 1; j <= n; j++ { + if a[i-1] == b[j-1] { + dp[i][j] = dp[i-1][j-1] + 1 + } else if dp[i-1][j] >= dp[i][j-1] { + dp[i][j] = dp[i-1][j] + } else { + dp[i][j] = dp[i][j-1] + } + } + } + + result := make([]string, dp[m][n]) + k := dp[m][n] - 1 + i, j := m, n + for i > 0 && j > 0 { + if a[i-1] == b[j-1] { + result[k] = a[i-1] + k-- + i-- + j-- + } else if dp[i-1][j] >= dp[i][j-1] { + i-- + } else { + j-- + } + } + + return result +} + +// filterContext keeps only diff lines near changes, with contextLines of context. +func filterContext(lines []DiffLine, contextLines int) []DiffLine { + if len(lines) == 0 { + return nil + } + + keep := make([]bool, len(lines)) + for i, line := range lines { + if line.Kind != DiffContext { + lo := i - contextLines + if lo < 0 { + lo = 0 + } + hi := i + contextLines + if hi >= len(lines) { + hi = len(lines) - 1 + } + for j := lo; j <= hi; j++ { + keep[j] = true + } + } + } + + var result []DiffLine + for i, line := range lines { + if keep[i] { + result = append(result, line) + } + } + + return result +} + +// splitLines splits text into lines, removing a trailing empty line from a trailing newline. +func splitLines(s string) []string { + if s == "" { + return nil + } + lines := strings.Split(s, "\n") + if len(lines) > 0 && lines[len(lines)-1] == "" { + lines = lines[:len(lines)-1] + } + return lines +} diff --git a/internal/tui/diff_test.go b/internal/tui/diff_test.go new file mode 100644 index 0000000..c3f67c4 --- /dev/null +++ b/internal/tui/diff_test.go @@ -0,0 +1,187 @@ +package tui + +import ( + "strings" + "testing" +) + +func TestComputeDiff_Identical(t *testing.T) { + result := computeDiff("hello\nworld\n", "hello\nworld\n") + if result != nil { + t.Errorf("identical texts should return nil, got %d lines", len(result)) + } +} + +func TestComputeDiff_EmptyBefore(t *testing.T) { + result := computeDiff("", "line1\nline2\n") + if len(result) == 0 { + t.Fatal("expected diff lines for new file") + } + for _, line := range result { + if line.Kind != DiffAdded { + t.Errorf("new file should have only added lines, got kind %d", line.Kind) + } + } +} + +func TestComputeDiff_EmptyAfter(t *testing.T) { + result := computeDiff("line1\nline2\n", "") + if len(result) == 0 { + t.Fatal("expected diff lines for deleted file") + } + for _, line := range result { + if line.Kind != DiffRemoved { + t.Errorf("deleted file should have only removed lines, got kind %d", line.Kind) + } + } +} + +func TestComputeDiff_Modification(t *testing.T) { + before := "line1\nline2\nline3\n" + after := "line1\nline2-modified\nline3\n" + result := computeDiff(before, after) + + if len(result) == 0 { + t.Fatal("expected diff lines for modification") + } + + // Should contain removed and added lines. + var hasAdded, hasRemoved, hasContext bool + for _, line := range result { + switch line.Kind { + case DiffAdded: + hasAdded = true + if line.Content != "line2-modified" { + t.Errorf("added line should be 'line2-modified', got %q", line.Content) + } + case DiffRemoved: + hasRemoved = true + if line.Content != "line2" { + t.Errorf("removed line should be 'line2', got %q", line.Content) + } + case DiffContext: + hasContext = true + } + } + if !hasAdded || !hasRemoved { + t.Error("modification should produce both added and removed lines") + } + if !hasContext { + t.Error("modification should have context lines") + } +} + +func TestComputeDiff_ContextLimiting(t *testing.T) { + // Create a file with many lines and a change in the middle. + var before, after strings.Builder + for i := 0; i < 50; i++ { + before.WriteString("line" + strings.Repeat("x", i) + "\n") + after.WriteString("line" + strings.Repeat("x", i) + "\n") + } + // Change line 25 + beforeStr := strings.Replace(before.String(), "line"+strings.Repeat("x", 25), "CHANGED", 1) + afterStr := strings.Replace(after.String(), "line"+strings.Repeat("x", 25), "MODIFIED", 1) + + result := computeDiff(beforeStr, afterStr) + if len(result) == 0 { + t.Fatal("expected diff lines") + } + + // Should not include all 50 lines — context filtering should limit output. + if len(result) > 20 { + t.Errorf("context limiting should reduce output, got %d lines", len(result)) + } +} + +func TestFilterContext_EmptyInput(t *testing.T) { + result := filterContext(nil, 3) + if result != nil { + t.Errorf("empty input should return nil, got %d lines", len(result)) + } +} + +func TestFilterContext_AllChanges(t *testing.T) { + lines := []DiffLine{ + {DiffAdded, "a"}, + {DiffAdded, "b"}, + {DiffRemoved, "c"}, + } + result := filterContext(lines, 3) + if len(result) != 3 { + t.Errorf("all changes should be kept, got %d lines", len(result)) + } +} + +func TestSplitLines_Empty(t *testing.T) { + result := splitLines("") + if result != nil { + t.Errorf("empty string should return nil, got %v", result) + } +} + +func TestSplitLines_TrailingNewline(t *testing.T) { + result := splitLines("a\nb\n") + if len(result) != 2 { + t.Errorf("should have 2 lines, got %d: %v", len(result), result) + } +} + +func TestLcsLines_Empty(t *testing.T) { + result := lcsLines(nil, []string{"a"}) + if result != nil { + t.Errorf("LCS with empty input should be nil, got %v", result) + } +} + +func TestLcsLines_Basic(t *testing.T) { + a := []string{"a", "b", "c", "d"} + b := []string{"a", "c", "d", "e"} + lcs := lcsLines(a, b) + expected := []string{"a", "c", "d"} + if len(lcs) != len(expected) { + t.Fatalf("LCS length mismatch: got %v, want %v", lcs, expected) + } + for i, v := range lcs { + if v != expected[i] { + t.Errorf("LCS[%d] = %q, want %q", i, v, expected[i]) + } + } +} + +func TestRenderDiff_Empty(t *testing.T) { + s := NewStyles(true) + result := renderDiff(nil, s, 10) + if result != "" { + t.Errorf("empty diff should render empty, got %q", result) + } +} + +func TestRenderDiff_MaxLines(t *testing.T) { + lines := []DiffLine{ + {DiffAdded, "a"}, + {DiffAdded, "b"}, + {DiffAdded, "c"}, + {DiffAdded, "d"}, + {DiffAdded, "e"}, + } + s := NewStyles(true) + result := renderDiff(lines, s, 3) + // Should contain "more lines" indicator. + if !strings.Contains(result, "more lines") { + t.Error("should show 'more lines' when truncating") + } +} + +func TestReadFileForDiff_NoArgs(t *testing.T) { + result := readFileForDiff(nil) + if result != "" { + t.Errorf("nil args should return empty, got %q", result) + } +} + +func TestReadFileForDiff_NonexistentFile(t *testing.T) { + result := readFileForDiff(map[string]any{"path": "/nonexistent/file/path"}) + if result != "" { + t.Errorf("nonexistent file should return empty, got %q", result) + } +} diff --git a/internal/tui/help.go b/internal/tui/help.go new file mode 100644 index 0000000..1afaf1b --- /dev/null +++ b/internal/tui/help.go @@ -0,0 +1,204 @@ +package tui + +import ( + "fmt" + "strings" + + "charm.land/bubbles/v2/viewport" + "charm.land/lipgloss/v2" + + "ai-agent/internal/command" +) + +// helpContentWidth returns the inner width for the help modal content. +func (m *Model) helpContentWidth() int { + maxW := 60 + if m.width < maxW+8 { + maxW = m.width - 8 + } + if maxW < 30 { + maxW = 30 + } + return maxW +} + +// helpViewportHeight returns the viewport height for the help modal. +func (m *Model) helpViewportHeight() int { + // Leave room for border (2), padding (2), title (2), footer (1) + h := m.height - 10 + if h < 5 { + h = 5 + } + return h +} + +// buildHelpContent builds the raw help text (without border/viewport wrapper). +func (m *Model) buildHelpContent(innerW int) string { + var b strings.Builder + + loc := m.tr() + b.WriteString(m.styles.OverlayAccent.Render(loc.KeyboardShortcuts)) + b.WriteString("\n") + + shortcuts := []struct{ key, desc string }{ + {"enter", loc.SendMessage}, + {"shift+enter", loc.NewLineInInput}, + {"shift+tab", loc.CycleMode}, + {"F6", loc.QuickModelSwitch}, + {"esc", loc.CancelStreaming}, + {"ctrl+c / ctrl+q / F10", loc.QuitKeys}, + {"ctrl+l", loc.ClearScreen}, + {"ctrl+n", loc.NewConversation}, + {"?", loc.ToggleHelp}, + {"t", loc.ExpandTools}, + {"space", loc.ToggleToolDetails}, + {"ctrl+y", loc.CopyLastResponse}, + {"ctrl+t", loc.ToggleThinking}, + {"ctrl+k", loc.ToggleCompact}, + {"ctrl+e", loc.OpenInEditor}, + {"↑/↓", loc.BrowseHistory}, + {"pgup/pgdown", loc.ScrollViewport}, + {"ctrl+u/d", loc.HalfPageScroll}, + {"tab", loc.Autocomplete}, + {"F2", loc.LanguageF2}, + } + + for _, s := range shortcuts { + fmt.Fprintf(&b, " %s %s\n", + m.styles.FocusIndicator.Width(16).Render(s.key), + m.styles.OverlayDim.Render(s.desc), + ) + } + + b.WriteString("\n") + b.WriteString(m.styles.OverlayAccent.Render(loc.InputShortcuts)) + b.WriteString("\n") + + inputShortcuts := []struct{ key, desc string }{ + {"@file", loc.AttachFile}, + {"#skill", loc.ActivateSkill}, + {"/cmd", loc.RunSlashCommand}, + } + + for _, s := range inputShortcuts { + fmt.Fprintf(&b, " %s %s\n", + m.styles.FocusIndicator.Width(16).Render(s.key), + m.styles.OverlayDim.Render(s.desc), + ) + } + + b.WriteString("\n") + b.WriteString(m.styles.OverlayAccent.Render(loc.SlashCommands)) + b.WriteString("\n") + + // Slash commands. + if m.cmdRegistry != nil { + for _, cmd := range m.cmdRegistry.All() { + fmt.Fprintf(&b, " %s %s\n", + m.styles.FocusIndicator.Width(16).Render("/"+cmd.Name), + m.styles.OverlayDim.Render(cmd.Description), + ) + } + } + + return b.String() +} + +// initHelpViewport creates and populates the help viewport for scrolling. +func (m *Model) initHelpViewport() { + innerW := m.helpContentWidth() + vpH := m.helpViewportHeight() + + m.helpViewport = viewport.New( + viewport.WithWidth(innerW), + viewport.WithHeight(vpH), + ) + // Disable default arrow key bindings (we handle j/k/up/down ourselves via parent) + m.helpViewport.KeyMap.Up.SetEnabled(false) + m.helpViewport.KeyMap.Down.SetEnabled(false) + m.helpViewport.KeyMap.PageUp.SetEnabled(false) + m.helpViewport.KeyMap.PageDown.SetEnabled(false) + m.helpViewport.KeyMap.HalfPageUp.SetEnabled(false) + m.helpViewport.KeyMap.HalfPageDown.SetEnabled(false) + + content := m.buildHelpContent(innerW) + m.helpViewport.SetContent(content) +} + +// renderHelpOverlay builds a centered, scrollable help modal. +func (m *Model) renderHelpOverlay(contentWidth int) string { + innerW := m.helpContentWidth() + + var b strings.Builder + + loc := m.tr() + b.WriteString(m.styles.OverlayTitle.Render(loc.Help)) + b.WriteString("\n\n") + + // Viewport content (scrollable). + b.WriteString(m.helpViewport.View()) + b.WriteString("\n") + + pct := m.helpViewport.ScrollPercent() + var hint string + if pct <= 0 { + hint = loc.ScrollMore + } else if pct >= 1.0 { + hint = loc.ScrollClose + } else { + hint = fmt.Sprintf(loc.ScrollPct, pct*100) + } + b.WriteString(m.styles.OverlayDim.Render(hint)) + + // Wrap in a box. + box := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(m.styles.FocusIndicator.GetForeground()). + Padding(1, 2). + Width(innerW + 6) // +6 for padding (2*2) + border (2) + + return box.Render(b.String()) +} + +// overlayOnContent renders the overlay centered on the viewport area. +func (m *Model) overlayOnContent(base, overlay string) string { + baseLines := strings.Split(base, "\n") + overlayLines := strings.Split(overlay, "\n") + + // Center vertically. + startY := (len(baseLines) - len(overlayLines)) / 2 + if startY < 0 { + startY = 0 + } + + for i, ol := range overlayLines { + row := startY + i + if row >= len(baseLines) { + break + } + // Center horizontally. + olW := lipgloss.Width(ol) + padLeft := (m.width - olW) / 2 + if padLeft < 0 { + padLeft = 0 + } + baseLines[row] = strings.Repeat(" ", padLeft) + ol + } + + return strings.Join(baseLines, "\n") +} + +// commandHelpEntries extracts SkillInfo from commands for display. +func commandHelpEntries(reg *command.Registry) []struct{ Name, Desc string } { + var entries []struct{ Name, Desc string } + if reg == nil { + return entries + } + for _, cmd := range reg.All() { + entries = append(entries, struct{ Name, Desc string }{ + Name: "/" + cmd.Name, + Desc: cmd.Description, + }) + } + return entries +} diff --git a/internal/tui/helpers_test.go b/internal/tui/helpers_test.go new file mode 100644 index 0000000..855ad93 --- /dev/null +++ b/internal/tui/helpers_test.go @@ -0,0 +1,73 @@ +package tui + +import ( + "testing" + "time" + + "ai-agent/internal/agent" + "ai-agent/internal/command" + + tea "charm.land/bubbletea/v2" +) + +var ( + testTime = time.Now() + testDuration = 100 * time.Millisecond +) + +func newTestModel(t *testing.T) *Model { + t.Helper() + reg := command.NewRegistry() + command.RegisterBuiltins(reg) + completer := NewCompleter(reg, []string{"model-a", "model-b"}, []string{"skill-a"}, []string{"agent-x"}, nil) + ag := agent.New(nil, nil, 0) + m := New(ag, reg, nil, completer, nil, nil, nil) + m.promptHistoryPath = "" + m.promptHistory = nil + m.lang = LangEn + m.initializing = false + updated, _ := m.Update(tea.WindowSizeMsg{Width: 80, Height: 24}) + return updated.(*Model) +} + +func escKey() tea.KeyPressMsg { + return tea.KeyPressMsg{Code: tea.KeyEscape} +} + +func enterKey() tea.KeyPressMsg { + return tea.KeyPressMsg{Code: tea.KeyEnter} +} + +func tabKey() tea.KeyPressMsg {return tea.KeyPressMsg{Code: tea.KeyTab} } + +func upKey() tea.KeyPressMsg { + return tea.KeyPressMsg{Code: tea.KeyUp} +} + +func downKey() tea.KeyPressMsg { + return tea.KeyPressMsg{Code: tea.KeyDown} +} + +func leftKey() tea.KeyPressMsg { + return tea.KeyPressMsg{Code: tea.KeyLeft} +} + +func rightKey() tea.KeyPressMsg { + return tea.KeyPressMsg{Code: tea.KeyRight} +} + +func spaceKey() tea.KeyPressMsg { + return tea.KeyPressMsg{Code: tea.KeySpace} +} + +func charKey(r rune) tea.KeyPressMsg { + return tea.KeyPressMsg{Code: r, Text: string(r)} +} + +func ctrlKey(r rune) tea.KeyPressMsg { + return tea.KeyPressMsg{Code: r, Mod: tea.ModCtrl} +} + +func shiftTabKey() tea.KeyPressMsg { + return tea.KeyPressMsg{Code: tea.KeyTab, Mod: tea.ModShift} +} diff --git a/internal/tui/history_test.go b/internal/tui/history_test.go new file mode 100644 index 0000000..910005b --- /dev/null +++ b/internal/tui/history_test.go @@ -0,0 +1,229 @@ +package tui + +import "testing" + +func TestPushHistory_Basic(t *testing.T) { + m := newTestModel(t) + m.pushHistory("hello") + m.pushHistory("world") + + if len(m.promptHistory) != 2 { + t.Fatalf("expected 2 history entries, got %d", len(m.promptHistory)) + } + if m.promptHistory[0] != "hello" || m.promptHistory[1] != "world" { + t.Errorf("unexpected history: %v", m.promptHistory) + } +} + +func TestPushHistory_Empty(t *testing.T) { + m := newTestModel(t) + m.pushHistory("") + if len(m.promptHistory) != 0 { + t.Error("empty string should not be added to history") + } +} + +func TestPushHistory_DedupConsecutive(t *testing.T) { + m := newTestModel(t) + m.pushHistory("hello") + m.pushHistory("hello") + m.pushHistory("hello") + + if len(m.promptHistory) != 1 { + t.Errorf("expected 1 entry after dedup, got %d", len(m.promptHistory)) + } +} + +func TestPushHistory_DedupNonConsecutive(t *testing.T) { + m := newTestModel(t) + m.pushHistory("hello") + m.pushHistory("world") + m.pushHistory("hello") + + if len(m.promptHistory) != 3 { + t.Errorf("non-consecutive duplicates should be kept, got %d", len(m.promptHistory)) + } +} + +func TestPushHistory_CapAt100(t *testing.T) { + m := newTestModel(t) + for i := 0; i < 110; i++ { + m.pushHistory(string(rune('a' + i%26)) + string(rune('0'+i/26))) + } + + if len(m.promptHistory) > 100 { + t.Errorf("history should be capped at 100, got %d", len(m.promptHistory)) + } +} + +func TestNavigateHistory_EmptyHistory(t *testing.T) { + m := newTestModel(t) + if m.navigateHistory(-1) { + t.Error("up on empty history should return false") + } + if m.navigateHistory(1) { + t.Error("down on empty history should return false") + } +} + +func TestNavigateHistory_UpDown(t *testing.T) { + m := newTestModel(t) + m.pushHistory("first") + m.pushHistory("second") + m.pushHistory("third") + + // Set current input + m.input.SetValue("current") + + // Press up: should go to "third" (most recent) + if !m.navigateHistory(-1) { + t.Fatal("up should succeed") + } + if m.input.Value() != "third" { + t.Errorf("expected 'third', got %q", m.input.Value()) + } + if m.historySaved != "current" { + t.Errorf("current input should be saved, got %q", m.historySaved) + } + + // Press up again: "second" + if !m.navigateHistory(-1) { + t.Fatal("up should succeed") + } + if m.input.Value() != "second" { + t.Errorf("expected 'second', got %q", m.input.Value()) + } + + // Press up again: "first" + if !m.navigateHistory(-1) { + t.Fatal("up should succeed") + } + if m.input.Value() != "first" { + t.Errorf("expected 'first', got %q", m.input.Value()) + } + + // Press up again: at oldest, should fail + if m.navigateHistory(-1) { + t.Error("up at oldest should return false") + } + + // Press down: "second" + if !m.navigateHistory(1) { + t.Fatal("down should succeed") + } + if m.input.Value() != "second" { + t.Errorf("expected 'second', got %q", m.input.Value()) + } + + // Press down: "third" + if !m.navigateHistory(1) { + t.Fatal("down should succeed") + } + if m.input.Value() != "third" { + t.Errorf("expected 'third', got %q", m.input.Value()) + } + + // Press down past newest: restore saved input + if !m.navigateHistory(1) { + t.Fatal("down past newest should succeed") + } + if m.input.Value() != "current" { + t.Errorf("expected restored 'current', got %q", m.input.Value()) + } + if m.historyIndex != -1 { + t.Errorf("historyIndex should be -1 after exiting history, got %d", m.historyIndex) + } +} + +func TestNavigateHistory_DownNotBrowsing(t *testing.T) { + m := newTestModel(t) + m.pushHistory("hello") + + // Down without first pressing up should return false + if m.navigateHistory(1) { + t.Error("down when not browsing should return false") + } +} + +func TestHistoryKey_OnlyWhenIdleAndEmpty(t *testing.T) { + t.Run("idle_empty_with_history", func(t *testing.T) { + m := newTestModel(t) + m.pushHistory("hello") + m.state = StateIdle + m.overlay = OverlayNone + + updated, _ := m.Update(upKey()) + m = updated.(*Model) + + if m.input.Value() != "hello" { + t.Errorf("up key should navigate history, got %q", m.input.Value()) + } + }) + + t.Run("idle_nonempty_no_history", func(t *testing.T) { + m := newTestModel(t) + m.pushHistory("hello") + m.state = StateIdle + m.overlay = OverlayNone + m.input.SetValue("typing something") + + updated, _ := m.Update(upKey()) + m = updated.(*Model) + + // Should NOT navigate history when input has content and not already browsing + if m.historyIndex != -1 { + t.Error("up key should not navigate history when input is non-empty") + } + }) + + t.Run("waiting_no_history", func(t *testing.T) { + m := newTestModel(t) + m.pushHistory("hello") + m.state = StateWaiting + + updated, _ := m.Update(upKey()) + m = updated.(*Model) + + if m.historyIndex != -1 { + t.Error("up key should not navigate history when not idle") + } + }) + + t.Run("overlay_no_history", func(t *testing.T) { + m := newTestModel(t) + m.pushHistory("hello") + m.state = StateIdle + m.overlay = OverlayHelp + + updated, _ := m.Update(upKey()) + m = updated.(*Model) + + if m.historyIndex != -1 { + t.Error("up key should not navigate history when overlay is open") + } + }) +} + +func TestHistoryKey_AlreadyBrowsing(t *testing.T) { + m := newTestModel(t) + m.pushHistory("first") + m.pushHistory("second") + m.state = StateIdle + m.overlay = OverlayNone + // Input must be empty to start browsing + m.input.SetValue("") + + // Press up — enters history (input is empty, so allowed) + updated, _ := m.Update(upKey()) + m = updated.(*Model) + if m.input.Value() != "second" { + t.Fatalf("expected 'second', got %q", m.input.Value()) + } + + // Now input is non-empty (from history), up should still work because historyIndex != -1 + updated, _ = m.Update(upKey()) + m = updated.(*Model) + if m.input.Value() != "first" { + t.Errorf("expected 'first', got %q", m.input.Value()) + } +} diff --git a/internal/tui/i18n.go b/internal/tui/i18n.go new file mode 100644 index 0000000..d26a4ac --- /dev/null +++ b/internal/tui/i18n.go @@ -0,0 +1,295 @@ +package tui + +import ( + "os" + "path/filepath" + "strings" +) + +// Lang is the UI language code. +type Lang string + +const ( + LangEn Lang = "en" + LangRu Lang = "ru" +) + +// L holds all localizable UI strings. +type L struct { + // General + Help string + Quit string + Cancel string + Send string + New string + Clear string + Complete string + ScrollMore string + ScrollClose string + ScrollPct string + + // Placeholder & input + Placeholder string + + // Help overlay + KeyboardShortcuts string + InputShortcuts string + SlashCommands string + SendMessage string + NewLineInInput string + CycleMode string + QuickModelSwitch string + CancelStreaming string + QuitKeys string + ClearScreen string + NewConversation string + ToggleHelp string + ExpandTools string + ToggleToolDetails string + CopyLastResponse string + ToggleThinking string + ToggleCompact string + OpenInEditor string + BrowseHistory string + ScrollViewport string + HalfPageScroll string + Autocomplete string + AttachFile string + ActivateSkill string + RunSlashCommand string + Language string + LanguageF2 string + + // Side panel + SidePanelAIAgent string + SidePanelTagline string + SidePanelModels string + SidePanelServers string + SidePanelICE string + SidePanelQuickActions string + SidePanelHelp string + SidePanelHelpDesc string + SidePanelServersDesc string + SidePanelModelDesc string + SidePanelLoadDesc string + SidePanelLoad string + ToolsConnected string + NoServersConnected string + ICEConversations string + ICECrossSessionActive string + ICEDisabled string + ICECrossSessionInactive string + + // Model picker + SelectModel string + + // Modes + ModeAsk string + ModePlan string + ModeBuild string + + // Window title + WindowTitle string + WindowTitleThink string + WindowTitleStream string + WindowTitleDone string + + // Key hints (short action names) + HintSend string + HintComplete string + HintHelp string + HintCancel string + HintQuit string + HintNew string + HintClear string + HintCommands string + HintFiles string + HintSkills string + + // Toasts / messages + NoModelsAvailable string + LanguageSet string +} + +var localeEn = L{ + Help: "Help", Quit: "quit", Cancel: "cancel", Send: "send", New: "new", Clear: "clear", Complete: "complete", + ScrollMore: "↓ scroll for more", ScrollClose: "Esc or q to close", ScrollPct: "%.0f%% · j/k to scroll", + Placeholder: "Ask anything... (Enter to send, ctrl+b for sidebar)", + KeyboardShortcuts: "Keyboard Shortcuts", + InputShortcuts: "Input Shortcuts", + SlashCommands: "Slash Commands", + SendMessage: "Send message", + NewLineInInput: "New line in input", + CycleMode: "Cycle mode (ASK/PLAN/BUILD)", + QuickModelSwitch: "Quick model switch", + CancelStreaming: "Cancel streaming / close overlay", + QuitKeys: "Quit", + ClearScreen: "Clear screen (keep history)", + NewConversation: "New conversation", + ToggleHelp: "Toggle this help (when input empty)", + ExpandTools: "Expand/collapse all tools", + ToggleToolDetails: "Toggle last tool details", + CopyLastResponse: "Copy last response", + ToggleThinking: "Toggle thinking display", + ToggleCompact: "Toggle compact mode", + OpenInEditor: "Open input in $EDITOR", + BrowseHistory: "Browse input history", + ScrollViewport: "Scroll viewport", + HalfPageScroll: "Half-page scroll", + Autocomplete: "Autocomplete (commands/files/skills)", + AttachFile: "Attach file or agent", + ActivateSkill: "Activate skill", + RunSlashCommand: "Run slash command", + Language: "Language", + LanguageF2: "Switch interface language (F2)", + SidePanelAIAgent: "AI AGENT", + SidePanelTagline: "100% local · Your data never leaves", + SidePanelModels: "Models", + SidePanelServers: "Servers", + SidePanelICE: "ICE", + SidePanelQuickActions: "Quick Actions", + SidePanelHelp: "Help", + SidePanelHelpDesc: "Keyboard shortcuts", + SidePanelServersDesc: "List connected tools", + SidePanelModelDesc: "Switch model", + SidePanelLoad: "Load", + SidePanelLoadDesc: "Add context from file", + ToolsConnected: "%d tools connected", + NoServersConnected: "No servers connected", + ICEConversations: "%d conversations", + ICECrossSessionActive: "Cross-session memory active", + ICEDisabled: "ICE disabled", + ICECrossSessionInactive: "Cross-session memory inactive", + SelectModel: "Select Model", + ModeAsk: "ASK", ModePlan: "PLAN", ModeBuild: "BUILD", + WindowTitle: "AI AGENT", WindowTitleThink: "AI AGENT · thinking...", + WindowTitleStream: "AI AGENT · streaming...", WindowTitleDone: "AI AGENT · done", + HintSend: "send", HintComplete: "complete", HintHelp: "help", HintCancel: "cancel", HintQuit: "quit", + HintNew: "new", HintClear: "clear", HintCommands: "commands", HintFiles: "files", HintSkills: "skills", + NoModelsAvailable: "No models available. Check Ollama connection.", + LanguageSet: "Language: %s", +} + +var localeRu = L{ + Help: "Справка", Quit: "выход", Cancel: "отмена", Send: "отправить", New: "новый", Clear: "очистить", Complete: "дополнение", + ScrollMore: "↓ листать вниз", ScrollClose: "Esc или q — закрыть", ScrollPct: "%.0f%% · j/k листать", + Placeholder: "Спросите что угодно... (Enter — отправить, ctrl+b — панель)", + KeyboardShortcuts: "Горячие клавиши", + InputShortcuts: "Клавиши ввода", + SlashCommands: "Слэш-команды", + SendMessage: "Отправить сообщение", + NewLineInInput: "Новая строка в поле ввода", + CycleMode: "Режим (ASK/PLAN/BUILD)", + QuickModelSwitch: "Быстрая смена модели", + CancelStreaming: "Отмена / закрыть окно", + QuitKeys: "Выход", + ClearScreen: "Очистить экран (история сохраняется)", + NewConversation: "Новый диалог", + ToggleHelp: "Показать справку (при пустом вводе)", + ExpandTools: "Развернуть/свернуть инструменты", + ToggleToolDetails: "Детали последнего инструмента", + CopyLastResponse: "Копировать последний ответ", + ToggleThinking: "Показать процесс размышления", + ToggleCompact: "Компактный режим", + OpenInEditor: "Открыть в $EDITOR", + BrowseHistory: "История ввода", + ScrollViewport: "Прокрутка", + HalfPageScroll: "На полстраницы", + Autocomplete: "Дополнение (команды/файлы/навыки)", + AttachFile: "Прикрепить файл или агента", + ActivateSkill: "Подключить навык", + RunSlashCommand: "Выполнить слэш-команду", + Language: "Язык", + LanguageF2: "Язык интерфейса (F2)", + SidePanelAIAgent: "AI AGENT", + SidePanelTagline: "100% локально · Ваши данные не покидают устройство", + SidePanelModels: "Модели", + SidePanelServers: "Серверы", + SidePanelICE: "ICE", + SidePanelQuickActions: "Быстрые действия", + SidePanelHelp: "Справка", + SidePanelHelpDesc: "Горячие клавиши", + SidePanelServersDesc: "Подключённые инструменты", + SidePanelModelDesc: "Сменить модель", + SidePanelLoad: "Загрузить", + SidePanelLoadDesc: "Добавить контекст из файла", + ToolsConnected: "Подключено инструментов: %d", + NoServersConnected: "Серверы не подключены", + ICEConversations: "Диалогов: %d", + ICECrossSessionActive: "Память между сессиями активна", + ICEDisabled: "ICE выключен", + ICECrossSessionInactive: "Память между сессиями неактивна", + SelectModel: "Выбор модели", + ModeAsk: "ASK", ModePlan: "PLAN", ModeBuild: "BUILD", + WindowTitle: "AI AGENT", WindowTitleThink: "AI AGENT · думает...", + WindowTitleStream: "AI AGENT · отвечает...", WindowTitleDone: "AI AGENT · готово", + HintSend: "отправить", HintComplete: "дополнение", HintHelp: "справка", HintCancel: "отмена", HintQuit: "выход", + HintNew: "новый", HintClear: "очистить", HintCommands: "команды", HintFiles: "файлы", HintSkills: "навыки", + NoModelsAvailable: "Нет моделей. Проверьте подключение к Ollama.", + LanguageSet: "Язык: %s", +} + +// Locale returns the strings for the given language. Unknown lang falls back to English. +func Locale(lang Lang) L { + switch lang { + case LangRu: + return localeRu + default: + return localeEn + } +} + +// LangName returns a display name for the language. +func LangName(lang Lang) string { + switch lang { + case LangRu: + return "Русский" + default: + return "English" + } +} + +// NextLang cycles to the next language (en -> ru -> en). +func NextLang(lang Lang) Lang { + switch lang { + case LangEn: + return LangRu + case LangRu: + return LangEn + default: + return LangEn + } +} + +// DefaultLangPath returns the path for storing UI language preference. +func DefaultLangPath() string { + home, err := os.UserHomeDir() + if err != nil { + return "lang" + } + return filepath.Join(home, ".config", "ai-agent", "lang") +} + +// LoadLang reads the saved language from DefaultLangPath(). Returns LangEn if missing or invalid. +func LoadLang() Lang { + data, err := os.ReadFile(DefaultLangPath()) + if err != nil { + return LangEn + } + switch strings.TrimSpace(strings.ToLower(string(data))) { + case "ru", "русский": + return LangRu + default: + return LangEn + } +} + +// SaveLang writes the language to DefaultLangPath(). +func SaveLang(lang Lang) error { + path := DefaultLangPath() + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + return os.WriteFile(path, []byte(lang), 0o644) +} diff --git a/internal/tui/keyhints.go b/internal/tui/keyhints.go new file mode 100644 index 0000000..576a61b --- /dev/null +++ b/internal/tui/keyhints.go @@ -0,0 +1,147 @@ +package tui + +import ( + "strings" + + "charm.land/lipgloss/v2" +) + +// KeyHint displays a keyboard shortcut hint. +type KeyHint struct { + Key string + Action string +} + +// KeyHints renders a row of key hints. +type KeyHints struct { + hints []KeyHint + styles KeyHintStyles + maxWidth int +} + +// KeyHintStyles holds styling for key hints. +type KeyHintStyles struct { + Key lipgloss.Style + Action lipgloss.Style + Divider lipgloss.Style +} + +// DefaultKeyHintStyles returns default styles. +func DefaultKeyHintStyles(isDark bool) KeyHintStyles { + if isDark { + return KeyHintStyles{ + Key: lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0")).Background(lipgloss.Color("#3b4252")).Padding(0, 1), + Action: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Divider: lipgloss.NewStyle().Foreground(lipgloss.Color("#3b4252")), + } + } + return KeyHintStyles{ + Key: lipgloss.NewStyle().Foreground(lipgloss.Color("#4f8f8f")).Background(lipgloss.Color("#e5e9f0")).Padding(0, 1), + Action: lipgloss.NewStyle().Foreground(lipgloss.Color("#9ca0a8")), + Divider: lipgloss.NewStyle().Foreground(lipgloss.Color("#d8dee9")), + } +} + +// NewKeyHints creates a new key hints component. +func NewKeyHints(hints []KeyHint, maxWidth int, isDark bool) *KeyHints { + return &KeyHints{ + hints: hints, + styles: DefaultKeyHintStyles(isDark), + maxWidth: maxWidth, + } +} + +// SetDark updates theme. +func (kh *KeyHints) SetDark(isDark bool) { + kh.styles = DefaultKeyHintStyles(isDark) +} + +// SetHints updates the hints. +func (kh *KeyHints) SetHints(hints []KeyHint) { + kh.hints = hints +} + +// Render returns the key hints as a single line. +func (kh *KeyHints) Render() string { + if len(kh.hints) == 0 { + return "" + } + + var b strings.Builder + b.WriteString(kh.styles.Divider.Render("│")) + + for i, hint := range kh.hints { + if i > 0 { + b.WriteString(" ") + } + b.WriteString(kh.styles.Key.Render(hint.Key)) + b.WriteString(" ") + b.WriteString(kh.styles.Action.Render(hint.Action)) + } + + return b.String() +} + +// RenderInline renders hints as inline text (no key box). +func (kh *KeyHints) RenderInline() string { + if len(kh.hints) == 0 { + return "" + } + + var b strings.Builder + + for i, hint := range kh.hints { + if i > 0 { + b.WriteString(" · ") + } + b.WriteString(hint.Key) + b.WriteString(" ") + b.WriteString(kh.styles.Action.Render(hint.Action)) + } + + return b.String() +} + +// SetMaxWidth sets the maximum width for wrapping. +func (kh *KeyHints) SetMaxWidth(w int) { + kh.maxWidth = w +} + +func defaultHintsForLang(lang Lang) []KeyHint { + loc := Locale(lang) + return []KeyHint{ + {Key: "Enter", Action: loc.HintSend}, + {Key: "Tab", Action: loc.HintComplete}, + {Key: "?", Action: loc.HintHelp}, + {Key: "Esc", Action: loc.HintCancel}, + {Key: "Ctrl+C / F10", Action: loc.HintQuit}, + } +} + +// DefaultKeyHints returns common key hints for the application. +func DefaultKeyHints(lang Lang, isDark bool) *KeyHints { + return NewKeyHints(defaultHintsForLang(lang), 60, isDark) +} + +// FooterHints returns hints shown in the footer. +func FooterHints(lang Lang, keys KeyMap, isDark bool) *KeyHints { + loc := Locale(lang) + hints := []KeyHint{ + {Key: "?", Action: loc.HintHelp}, + {Key: "Ctrl+N", Action: loc.HintNew}, + {Key: "Ctrl+L", Action: loc.HintClear}, + } + return NewKeyHints(hints, 40, isDark) +} + +// InputHints returns hints shown when typing. +func InputHints(lang Lang, keys KeyMap, isDark bool) *KeyHints { + loc := Locale(lang) + hints := []KeyHint{ + {Key: "Tab", Action: loc.HintComplete}, + {Key: "/", Action: loc.HintCommands}, + {Key: "@", Action: loc.HintFiles}, + {Key: "#", Action: loc.HintSkills}, + } + return NewKeyHints(hints, 40, isDark) +} diff --git a/internal/tui/keys.go b/internal/tui/keys.go new file mode 100644 index 0000000..047f982 --- /dev/null +++ b/internal/tui/keys.go @@ -0,0 +1,170 @@ +package tui + +import "charm.land/bubbles/v2/key" + +// KeyMap defines all keyboard shortcuts for the application. +type KeyMap struct { + Send key.Binding + NewLine key.Binding + Cancel key.Binding + Quit key.Binding + ClearView key.Binding + NewConvo key.Binding + Help key.Binding + ToggleTools key.Binding + PageUp key.Binding + PageDown key.Binding + HalfPageUp key.Binding + HalfPageDn key.Binding + Complete key.Binding + CompleteUp key.Binding + CompleteDown key.Binding + CompleteToggle key.Binding + CompleteSelect key.Binding + CopyLast key.Binding + CycleMode key.Binding + ModelPicker key.Binding + HistoryUp key.Binding + HistoryDown key.Binding + ToggleFocusedTool key.Binding + ToggleThinking key.Binding + CompactToggle key.Binding + ExternalEditor key.Binding + ToggleSidePanel key.Binding + LanguageCycle key.Binding +} + +// DefaultKeyMap returns the default keybindings. +func DefaultKeyMap() KeyMap { + return KeyMap{ + Send: key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "send message"), + ), + NewLine: key.NewBinding( + key.WithKeys("shift+enter"), + key.WithHelp("shift+enter", "new line"), + ), + Cancel: key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "cancel / close overlay"), + ), + Quit: key.NewBinding( + key.WithKeys("ctrl+c", "ctrl+q", "f10"), + key.WithHelp("ctrl+c / ctrl+q / F10", "quit"), + ), + ClearView: key.NewBinding( + key.WithKeys("ctrl+l"), + key.WithHelp("ctrl+l", "clear screen"), + ), + NewConvo: key.NewBinding( + key.WithKeys("ctrl+n"), + key.WithHelp("ctrl+n", "new conversation"), + ), + Help: key.NewBinding( + key.WithKeys("?"), + key.WithHelp("?", "toggle help"), + ), + ToggleTools: key.NewBinding( + key.WithKeys("t"), + key.WithHelp("t", "expand/collapse tool details"), + ), + PageUp: key.NewBinding( + key.WithKeys("pgup"), + key.WithHelp("pgup", "scroll up"), + ), + PageDown: key.NewBinding( + key.WithKeys("pgdown"), + key.WithHelp("pgdown", "scroll down"), + ), + HalfPageUp: key.NewBinding( + key.WithKeys("ctrl+u"), + key.WithHelp("ctrl+u", "half page up"), + ), + HalfPageDn: key.NewBinding( + key.WithKeys("ctrl+d"), + key.WithHelp("ctrl+d", "half page down"), + ), + Complete: key.NewBinding( + key.WithKeys("tab", "ctrl+i"), + key.WithHelp("tab", "autocomplete"), + ), + CompleteUp: key.NewBinding( + key.WithKeys("up"), + key.WithHelp("up", "previous completion"), + ), + CompleteDown: key.NewBinding( + key.WithKeys("down"), + key.WithHelp("down", "next completion"), + ), + CompleteToggle: key.NewBinding( + key.WithKeys("tab", "ctrl+i"), + key.WithHelp("tab", "toggle selection"), + ), + CompleteSelect: key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "select item"), + ), + CopyLast: key.NewBinding( + key.WithKeys("ctrl+y"), + key.WithHelp("ctrl+y", "copy last response"), + ), + CycleMode: key.NewBinding( + key.WithKeys("shift+tab"), + key.WithHelp("shift+tab", "cycle mode (ASK/PLAN/BUILD)"), + ), + ModelPicker: key.NewBinding( + key.WithKeys("f6", "ctrl+m"), + key.WithHelp("F6 / ctrl+m", "quick model switch"), + ), + HistoryUp: key.NewBinding( + key.WithKeys("up"), + key.WithHelp("↑", "previous input"), + ), + HistoryDown: key.NewBinding( + key.WithKeys("down"), + key.WithHelp("↓", "next input"), + ), + ToggleFocusedTool: key.NewBinding( + key.WithKeys(" "), + key.WithHelp("space", "toggle last tool details"), + ), + ToggleThinking: key.NewBinding( + key.WithKeys("ctrl+t"), + key.WithHelp("ctrl+t", "toggle thinking display"), + ), + CompactToggle: key.NewBinding( + key.WithKeys("ctrl+k"), + key.WithHelp("ctrl+k", "toggle compact mode"), + ), + ExternalEditor: key.NewBinding( + key.WithKeys("ctrl+e"), + key.WithHelp("ctrl+e", "open in $EDITOR"), + ), + ToggleSidePanel: key.NewBinding( + key.WithKeys("ctrl+b"), + key.WithHelp("ctrl+b", "toggle side panel"), + ), + LanguageCycle: key.NewBinding( + key.WithKeys("f2"), + key.WithHelp("F2", "language"), + ), + } +} + +// ShortHelp returns the key groups for the short help view. +func (k KeyMap) ShortHelp() []key.Binding { + return []key.Binding{k.Send, k.NewLine, k.Cancel, k.Quit, k.Help} +} + +// FullHelp returns the key groups for the full help view. +func (k KeyMap) FullHelp() [][]key.Binding { + return [][]key.Binding{ + {k.Send, k.NewLine, k.Cancel, k.Quit}, + {k.ClearView, k.NewConvo, k.Help, k.ToggleTools, k.CopyLast}, + {k.PageUp, k.PageDown, k.HalfPageUp, k.HalfPageDn}, + {k.CycleMode, k.ModelPicker, k.ToggleSidePanel}, + {k.HistoryUp, k.HistoryDown}, + {k.ToggleFocusedTool, k.ToggleThinking, k.CompactToggle, k.ExternalEditor}, + } +} diff --git a/internal/tui/layout.go b/internal/tui/layout.go new file mode 100644 index 0000000..48845c4 --- /dev/null +++ b/internal/tui/layout.go @@ -0,0 +1,62 @@ +package tui + +import ( + "fmt" + "strings" +) + +// layoutConfig holds adaptive layout parameters based on terminal size. +type layoutConfig struct { + ContentPad int + ToolIndent string + ToolSummaryMax int + ArgsTruncMax int + ResultTruncMax int + HeaderMode string // "full" or "compact" +} + +// currentLayout returns layout parameters adapted to the current terminal size +// and user compact preference. +func (m *Model) currentLayout() layoutConfig { + if m.forceCompact || m.width < 80 || m.height < 24 { + return layoutConfig{ + ContentPad: 2, + ToolIndent: " ", + ToolSummaryMax: 40, + ArgsTruncMax: 100, + ResultTruncMax: 150, + HeaderMode: "compact", + } + } + if m.width > 120 { + return layoutConfig{ + ContentPad: 4, + ToolIndent: " ", + ToolSummaryMax: 80, + ArgsTruncMax: 300, + ResultTruncMax: 500, + HeaderMode: "full", + } + } + return layoutConfig{ + ContentPad: 4, + ToolIndent: " ", + ToolSummaryMax: 60, + ArgsTruncMax: 200, + ResultTruncMax: 300, + HeaderMode: "full", + } +} + +// contextProgressBar renders a mini progress bar: █████░░░░░ 42% +func contextProgressBar(pct int) string { + const barWidth = 10 + filled := pct * barWidth / 100 + if filled > barWidth { + filled = barWidth + } + if filled < 0 { + filled = 0 + } + return strings.Repeat("█", filled) + strings.Repeat("░", barWidth-filled) + fmt.Sprintf(" %d%%", pct) +} diff --git a/internal/tui/layout_test.go b/internal/tui/layout_test.go new file mode 100644 index 0000000..67442e8 --- /dev/null +++ b/internal/tui/layout_test.go @@ -0,0 +1,73 @@ +package tui + +import "testing" + +func TestCurrentLayout_Compact(t *testing.T) { + m := &Model{width: 60, height: 20} + layout := m.currentLayout() + if layout.HeaderMode != "compact" { + t.Errorf("small terminal should use compact mode, got %q", layout.HeaderMode) + } + if layout.ArgsTruncMax != 100 { + t.Errorf("compact ArgsTruncMax = %d, want 100", layout.ArgsTruncMax) + } +} + +func TestCurrentLayout_Normal(t *testing.T) { + m := &Model{width: 100, height: 30} + layout := m.currentLayout() + if layout.HeaderMode != "full" { + t.Errorf("normal terminal should use full mode, got %q", layout.HeaderMode) + } + if layout.ArgsTruncMax != 200 { + t.Errorf("normal ArgsTruncMax = %d, want 200", layout.ArgsTruncMax) + } +} + +func TestCurrentLayout_Wide(t *testing.T) { + m := &Model{width: 150, height: 40} + layout := m.currentLayout() + if layout.HeaderMode != "full" { + t.Errorf("wide terminal should use full mode, got %q", layout.HeaderMode) + } + if layout.ArgsTruncMax != 300 { + t.Errorf("wide ArgsTruncMax = %d, want 300", layout.ArgsTruncMax) + } + if layout.ResultTruncMax != 500 { + t.Errorf("wide ResultTruncMax = %d, want 500", layout.ResultTruncMax) + } +} + +func TestCurrentLayout_CompactHeight(t *testing.T) { + // Wide but short terminal should be compact. + m := &Model{width: 120, height: 20} + layout := m.currentLayout() + if layout.HeaderMode != "compact" { + t.Errorf("short terminal should use compact mode, got %q", layout.HeaderMode) + } +} + +func TestContextProgressBar(t *testing.T) { + tests := []struct { + pct int + want string + }{ + {0, "░░░░░░░░░░ 0%"}, + {50, "█████░░░░░ 50%"}, + {100, "██████████ 100%"}, + } + for _, tt := range tests { + got := contextProgressBar(tt.pct) + if got != tt.want { + t.Errorf("contextProgressBar(%d) = %q, want %q", tt.pct, got, tt.want) + } + } +} + +func TestContextProgressBar_Overflow(t *testing.T) { + // Should not panic or produce weird output for >100% + result := contextProgressBar(150) + if result == "" { + t.Error("overflow should still produce output") + } +} diff --git a/internal/tui/logo.go b/internal/tui/logo.go new file mode 100644 index 0000000..b36366a --- /dev/null +++ b/internal/tui/logo.go @@ -0,0 +1,121 @@ +package tui + +import ( + "strings" + "time" + + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "github.com/charmbracelet/harmonica" +) + +const ( + LogoPhaseHidden = iota + LogoPhaseAnimating + LogoPhaseVisible + LogoPhaseDone +) + +type LogoTickMsg struct{} + +type LogoModel struct { + phase int + alpha float64 + vel float64 + spring harmonica.Spring + isDark bool + frame int + displayLogo bool +} + +func logoLines() []string { + return []string{ + ``, + ` ╔═╗╦ ╔═╗╔═╗╔═╗╔╗╔╔╦╗`, + ` ╠═╣║ ╠═╣║ ╦║╣ ║║║ ║ `, + ` ╩ ╩╩ ╩ ╩╚═╝╚═╝╝╚╝ ╩ `, + ``, + ` 100% local · Your data never leaves`, + ``, + } +} + +func NewLogoModel(isDark bool) LogoModel { + return LogoModel{ + spring: harmonica.NewSpring(harmonica.FPS(60), 4.0, 0.9), + isDark: isDark, + phase: LogoPhaseHidden, + } +} + +func (m *LogoModel) Start() { + m.phase = LogoPhaseAnimating + m.alpha = 0 + m.vel = 0 + m.frame = 0 +} + +func (m LogoModel) Init() tea.Cmd { + return tea.Tick(16*time.Millisecond, func(time.Time) tea.Msg { + return LogoTickMsg{} + }) +} + +func (m LogoModel) Update(msg tea.Msg) (LogoModel, tea.Cmd) { + if _, ok := msg.(LogoTickMsg); ok { + m.frame++ + if m.phase == LogoPhaseAnimating { + target := 1.0 + m.alpha, m.vel = m.spring.Update(m.alpha, m.vel, target) + + if m.alpha >= 0.95 && m.frame > 60 { + m.phase = LogoPhaseVisible + m.displayLogo = true + } + if m.phase == LogoPhaseVisible && m.frame > 180 { + m.phase = LogoPhaseDone + m.displayLogo = false + } + } + if m.phase < LogoPhaseDone { + return m, tea.Tick(16*time.Millisecond, func(time.Time) tea.Msg { + return LogoTickMsg{} + }) + } + } + return m, nil +} + +func (m LogoModel) View() string { + if m.phase == LogoPhaseHidden || m.phase == LogoPhaseDone || !m.displayLogo { + return "" + } + lines := logoLines() + var b strings.Builder + for _, line := range lines { + if m.alpha < 0.1 { + b.WriteString(lipgloss.NewStyle(). + Foreground(lipgloss.Color("#4c566a")). + Render(line)) + } else if m.alpha < 0.5 { + b.WriteString(lipgloss.NewStyle(). + Foreground(lipgloss.Color("#81a1c1")). + Render(line)) + } else { + b.WriteString(lipgloss.NewStyle(). + Foreground(lipgloss.Color("#88c0d0")). + Bold(true). + Render(line)) + } + b.WriteString("\n") + } + return b.String() +} + +func (m LogoModel) IsDone() bool { + return m.phase == LogoPhaseDone +} + +func (m LogoModel) ShouldShow() bool { + return m.displayLogo && m.phase < LogoPhaseDone +} diff --git a/internal/tui/markdown.go b/internal/tui/markdown.go new file mode 100644 index 0000000..3d8b3e4 --- /dev/null +++ b/internal/tui/markdown.go @@ -0,0 +1,73 @@ +package tui + +import ( + "strings" + + "github.com/charmbracelet/glamour" +) + +// MarkdownRenderer handles markdown rendering with caching support. +type MarkdownRenderer struct { + renderer *glamour.TermRenderer + width int + isDark bool +} + +func glamourStyle(isDark bool) string { + if noColor { + return "notty" + } + if isDark { + return "dark" + } + return "light" +} + +// NewMarkdownRenderer creates a renderer for the given terminal width and theme. +func NewMarkdownRenderer(width int, isDark bool) *MarkdownRenderer { + // Use standard glamour style with word wrapping + // Glamour automatically handles syntax highlighting via Chroma + r, _ := glamour.NewTermRenderer( + glamour.WithStandardStyle(glamourStyle(isDark)), + glamour.WithWordWrap(width-4), + ) + + return &MarkdownRenderer{ + renderer: r, + width: width, + isDark: isDark, + } +} + +// RenderFull renders a complete markdown document (for finished messages). +// This is the "format-on-complete" path used when streaming ends. +func (mr *MarkdownRenderer) RenderFull(content string) string { + if content == "" || mr.renderer == nil { + return content + } + + rendered, err := mr.renderer.Render(content) + if err != nil { + return content + } + + return strings.TrimRight(rendered, "\n") +} + +// RenderStreaming renders content during streaming (plain text, no Glamour). +// This avoids jitter from re-rendering incomplete markdown. +func (mr *MarkdownRenderer) RenderStreaming(content string) string { + return content +} + +// SetWidth updates the renderer for a new terminal width. +func (mr *MarkdownRenderer) SetWidth(width int) { + mr.width = width + r, err := glamour.NewTermRenderer( + glamour.WithStandardStyle(glamourStyle(mr.isDark)), + glamour.WithWordWrap(width-4), + ) + if err == nil { + mr.renderer = r + } +} diff --git a/internal/tui/messages.go b/internal/tui/messages.go new file mode 100644 index 0000000..0c3154d --- /dev/null +++ b/internal/tui/messages.go @@ -0,0 +1,125 @@ +package tui + +import ( + "time" + + tea "charm.land/bubbletea/v2" +) + +type StreamTextMsg struct { + Text string +} + +type StreamDoneMsg struct { + EvalCount int + PromptTokens int +} + +type ToolCallStartMsg struct { + Name string + Args map[string]any + StartTime time.Time +} + +type ToolCallResultMsg struct { + Name string + Result string + IsError bool + Duration time.Duration +} + +type ErrorMsg struct { + Msg string +} + +type SystemMessageMsg struct { + Msg string +} + +type AgentDoneMsg struct{} + +type FailedServer struct { + Name string + Reason string +} + +type InitCompleteMsg struct { + Model string + ModelList []string + AgentProfile string + AgentList []string + ToolCount int + ServerCount int + NumCtx int + FailedServers []FailedServer + ICEEnabled bool + ICEConversations int + ICESessionID string +} + +type CommandResultMsg struct { + Text string +} + +type StartupStatusMsg struct { + ID string + Label string + Status string + Detail string +} + +type CompletionSearchResultMsg struct { + Tag int + Results []Completion +} + +type CompletionDebounceTickMsg struct { + Tag int + Query string +} + +type spinnerTickMsg struct{} + +type PlanFormCompletedMsg struct { + Prompt string +} + +type DoneFlashExpiredMsg struct{} + +type SessionCreatedMsg struct { + NoteID int + Err error +} + +type SessionListMsg struct { + Sessions []SessionListItem + Err error +} + +type SessionLoadedMsg struct { + Entries []ChatEntry + Title string + Err error +} + +type ToolApprovalMsg struct { + ToolName string + Args map[string]any + Response chan<- ToolApprovalResponse +} + +type ToolApprovalResponse struct { + Allowed bool + Always bool +} + +type CommitResultMsg struct { + Message string + Err error +} + +func sendMsg(p *tea.Program, msg tea.Msg) { + if p != nil { + p.Send(msg) + } +} diff --git a/internal/tui/modal.go b/internal/tui/modal.go new file mode 100644 index 0000000..4e9710e --- /dev/null +++ b/internal/tui/modal.go @@ -0,0 +1,160 @@ +package tui + +import ( + "strings" + + "charm.land/lipgloss/v2" +) + +type ModalConfig struct { + Title string + Content string + Footer string + Width int + MaxWidth int + BorderStyle lipgloss.Border + PaddingTop int + PaddingBottom int + PaddingLeft int + PaddingRight int +} + +func DefaultModalConfig() ModalConfig { + return ModalConfig{ + MaxWidth: 60, + BorderStyle: lipgloss.RoundedBorder(), + PaddingTop: 1, + PaddingBottom: 1, + PaddingLeft: 2, + PaddingRight: 2, + } +} + +func RenderModal(baseContent string, config ModalConfig, styles Styles, viewportWidth, viewportHeight int) string { + cfg := DefaultModalConfig() + if config.Title != "" { + cfg.Title = config.Title + } + if config.Content != "" { + cfg.Content = config.Content + } + if config.Footer != "" { + cfg.Footer = config.Footer + } + if config.Width > 0 { + cfg.Width = config.Width + } + if config.MaxWidth > 0 { + cfg.MaxWidth = config.MaxWidth + } + if config.BorderStyle != (lipgloss.Border{}) { + cfg.BorderStyle = config.BorderStyle + } + cfg.PaddingTop = config.PaddingTop + cfg.PaddingBottom = config.PaddingBottom + cfg.PaddingLeft = config.PaddingLeft + cfg.PaddingRight = config.PaddingRight + var b strings.Builder + if cfg.Title != "" { + b.WriteString(styles.OverlayTitle.Render(cfg.Title)) + b.WriteString("\n") + } + if cfg.Content != "" { + b.WriteString(cfg.Content) + if cfg.Footer != "" { + b.WriteString("\n\n") + } + } + if cfg.Footer != "" { + b.WriteString(styles.OverlayDim.Render(cfg.Footer)) + } + contentW := cfg.Width + if contentW == 0 { + lines := strings.Split(b.String(), "\n") + for _, line := range lines { + w := lipgloss.Width(line) + if w+cfg.PaddingLeft+cfg.PaddingRight+2 > contentW { + contentW = w + cfg.PaddingLeft + cfg.PaddingRight + 2 + } + } + } + if contentW > cfg.MaxWidth { + contentW = cfg.MaxWidth + } + if contentW < 30 { + contentW = 30 + } + if contentW >= viewportWidth-4 { + contentW = viewportWidth - 4 + } + box := lipgloss.NewStyle(). + Border(cfg.BorderStyle). + BorderForeground(lipgloss.Color(styles.OverlayBorder)). + Padding(cfg.PaddingTop, cfg.PaddingLeft, cfg.PaddingBottom, cfg.PaddingRight). + Width(contentW) + return box.Render(b.String()) +} + +func CenterOverlay(baseContent, overlay string, viewportWidth, viewportHeight int) string { + baseLines := strings.Split(baseContent, "\n") + overlayLines := strings.Split(overlay, "\n") + startY := (len(baseLines) - len(overlayLines)) / 2 + if startY < 0 { + startY = 0 + } + for i, ol := range overlayLines { + row := startY + i + if row >= len(baseLines) { + break + } + olW := lipgloss.Width(ol) + padLeft := (viewportWidth - olW) / 2 + if padLeft < 0 { + padLeft = 0 + } + baseLines[row] = strings.Repeat(" ", padLeft) + ol + } + return strings.Join(baseLines, "\n") +} + +type ModalBuilder struct { + config ModalConfig +} + +func NewModal() *ModalBuilder { + return &ModalBuilder{config: DefaultModalConfig()} +} + +func (mb *ModalBuilder) Title(title string) *ModalBuilder { + mb.config.Title = title + return mb +} + +func (mb *ModalBuilder) Content(content string) *ModalBuilder { + mb.config.Content = content + return mb +} + +func (mb *ModalBuilder) Footer(footer string) *ModalBuilder { + mb.config.Footer = footer + return mb +} + +func (mb *ModalBuilder) Width(width int) *ModalBuilder { + mb.config.Width = width + return mb +} + +func (mb *ModalBuilder) MaxWidth(maxWidth int) *ModalBuilder { + mb.config.MaxWidth = maxWidth + return mb +} + +func (mb *ModalBuilder) Build(styles Styles, viewportWidth, viewportHeight int) string { + return RenderModal("", mb.config, styles, viewportWidth, viewportHeight) +} + +func (mb *ModalBuilder) BuildOnContent(baseContent string, styles Styles, viewportWidth, viewportHeight int) string { + modal := RenderModal("", mb.config, styles, viewportWidth, viewportHeight) + return CenterOverlay(baseContent, modal, viewportWidth, viewportHeight) +} diff --git a/internal/tui/mode.go b/internal/tui/mode.go new file mode 100644 index 0000000..745ef2c --- /dev/null +++ b/internal/tui/mode.go @@ -0,0 +1,43 @@ +package tui + +import ( + "ai-agent/internal/config" +) + +type Mode int + +const ( + ModeAsk Mode = iota + ModePlan + ModeBuild +) + +type ModeConfig struct { + Label string + SystemPromptPrefix string + AllowTools bool + PreferredCapability config.ModelCapability +} + +func DefaultModeConfigs() [3]ModeConfig { + return [3]ModeConfig{ + { + Label: "ASK", + SystemPromptPrefix: "Provide direct, concise answers. Use tools when the user asks about files or the codebase.", + AllowTools: true, + PreferredCapability: config.CapabilitySimple, + }, + { + Label: "PLAN", + SystemPromptPrefix: "Help the user plan and design. Break down tasks into steps. Use tools to read and explore, but do not modify files.", + AllowTools: true, + PreferredCapability: config.CapabilityComplex, + }, + { + Label: "BUILD", + SystemPromptPrefix: "Execute tasks using all available tools.", + AllowTools: true, + PreferredCapability: config.CapabilityAdvanced, + }, + } +} diff --git a/internal/tui/mode_test.go b/internal/tui/mode_test.go new file mode 100644 index 0000000..a803b40 --- /dev/null +++ b/internal/tui/mode_test.go @@ -0,0 +1,133 @@ +package tui + +import ( + "strings" + "testing" +) + +func TestCycleMode(t *testing.T) { + t.Run("cycles_ask_to_build", func(t *testing.T) { + m := newTestModel(t) + // Default mode is ASK. + if m.mode != ModeAsk { + t.Fatalf("expected initial mode ModeAsk, got %d", m.mode) + } + + updated, _ := m.Update(shiftTabKey()) + m = updated.(*Model) + + if m.mode != ModePlan { + t.Errorf("expected ModePlan after cycling from ASK, got %d", m.mode) + } + }) + + t.Run("cycles_ask_to_plan", func(t *testing.T) { + m := newTestModel(t) + m.mode = ModeAsk + + updated, _ := m.Update(shiftTabKey()) + m = updated.(*Model) + + if m.mode != ModePlan { + t.Errorf("expected ModePlan after cycling from ASK, got %d", m.mode) + } + }) + + t.Run("cycles_plan_to_build", func(t *testing.T) { + m := newTestModel(t) + m.mode = ModePlan + + updated, _ := m.Update(shiftTabKey()) + m = updated.(*Model) + + if m.mode != ModeBuild { + t.Errorf("expected ModeBuild after cycling from PLAN, got %d", m.mode) + } + }) + + t.Run("adds_system_message", func(t *testing.T) { + m := newTestModel(t) + before := len(m.entries) + + updated, _ := m.Update(shiftTabKey()) + m = updated.(*Model) + + if len(m.entries) <= before { + t.Fatal("expected system message entry after mode switch") + } + last := m.entries[len(m.entries)-1] + if last.Kind != "system" { + t.Errorf("expected 'system' kind, got %q", last.Kind) + } + if !strings.Contains(last.Content, "Mode switched to") { + t.Errorf("expected mode switch info in content, got %q", last.Content) + } + }) + + t.Run("no_cycle_when_not_idle", func(t *testing.T) { + m := newTestModel(t) + m.state = StateStreaming + before := m.mode + + updated, _ := m.Update(shiftTabKey()) + m = updated.(*Model) + + if m.mode != before { + t.Error("should not cycle mode when not idle") + } + }) +} + +func TestModeStatusLine(t *testing.T) { + m := newTestModel(t) + m.state = StateIdle + + t.Run("build_mode_badge", func(t *testing.T) { + m.mode = ModeBuild + status := m.renderStatusLine() + if !strings.Contains(status, "BUILD") { + t.Errorf("status line should contain BUILD badge, got %q", status) + } + }) + + t.Run("ask_mode_badge", func(t *testing.T) { + m.mode = ModeAsk + status := m.renderStatusLine() + if !strings.Contains(status, "ASK") { + t.Errorf("status line should contain ASK badge, got %q", status) + } + }) + + t.Run("plan_mode_badge", func(t *testing.T) { + m.mode = ModePlan + status := m.renderStatusLine() + if !strings.Contains(status, "PLAN") { + t.Errorf("status line should contain PLAN badge, got %q", status) + } + }) +} + +func TestDefaultModeConfigs(t *testing.T) { + configs := DefaultModeConfigs() + + if configs[ModeAsk].Label != "ASK" { + t.Errorf("ModeAsk label should be ASK, got %q", configs[ModeAsk].Label) + } + if !configs[ModeAsk].AllowTools { + t.Error("ModeAsk should allow tools") + } + + if configs[ModePlan].Label != "PLAN" { + t.Errorf("ModePlan label should be PLAN, got %q", configs[ModePlan].Label) + } + if !configs[ModePlan].AllowTools { + t.Error("ModePlan should allow tools") + } + + if configs[ModeBuild].Label != "BUILD" { + t.Errorf("ModeBuild label should be BUILD, got %q", configs[ModeBuild].Label) + } + if !configs[ModeBuild].AllowTools { + t.Error("ModeBuild should allow tools") + } +} diff --git a/internal/tui/model.go b/internal/tui/model.go new file mode 100644 index 0000000..0cb981a --- /dev/null +++ b/internal/tui/model.go @@ -0,0 +1,2034 @@ +package tui + +import ( + "context" + "fmt" + "os" + "os/exec" + "strings" + "time" + + "ai-agent/internal/agent" + "ai-agent/internal/command" + "ai-agent/internal/config" + "ai-agent/internal/llm" + "ai-agent/internal/permission" + "ai-agent/internal/skill" + + "charm.land/bubbles/v2/key" + "charm.land/bubbles/v2/list" + "charm.land/bubbles/v2/spinner" + "charm.land/bubbles/v2/textinput" + "charm.land/bubbles/v2/textarea" + "charm.land/bubbles/v2/viewport" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "github.com/atotto/clipboard" + "github.com/charmbracelet/log" +) + +type State int + +const ( + StateIdle State = iota + StateWaiting + StateStreaming +) + +type OverlayKind int + +const ( + OverlayNone OverlayKind = iota + OverlayHelp + OverlayCompletion + OverlayModelPicker + OverlayPlanForm + OverlaySessionsPicker +) + +type CompletionState struct { + Kind string + CurrentPath string + Filter textinput.Model + AllItems []Completion + FilteredItems []Completion + Selected map[int]bool + SearchResults []Completion + Index int + DebounceTag int + Searching bool +} + +type ToolStatus int + +const ( + ToolStatusRunning ToolStatus = iota + ToolStatusDone + ToolStatusError +) + +type ToolEntry struct { + Name string + Args string + RawArgs map[string]any + Result string + IsError bool + Status ToolStatus + StartTime time.Time + Duration time.Duration + Collapsed bool + BeforeContent string + DiffLines []DiffLine +} + +type ChatEntry struct { + Kind string + Content string + RenderedContent string + Name string + IsError bool + ToolIndex int + ThinkingContent string + ThinkingCollapsed bool +} + +type startupItem struct { + ID string + Label string + Status string + Detail string +} + +type Model struct { + viewport viewport.Model + input textarea.Model + spin spinner.Model + scramble ScrambleModel + styles Styles + md *MarkdownRenderer + keys KeyMap + state State + overlay OverlayKind + entries []ChatEntry + streamBuf strings.Builder + width int + height int + ready bool + isDark bool + evalCount int + promptTokens int + toolsPending int + inputLines int + userScrolledUp bool + scrollAnchor int + anchorActive bool + lastContentHeight int + initializing bool + startupItems []startupItem + initCancel context.CancelFunc + completionState *CompletionState + attachments []string + toolEntries []ToolEntry + toolsCollapsed bool + toolEntryRows map[int]int + toolCardMgr ToolCardManager + cachedEntriesRender string + cachedEntryCount int + cachedToolEntryRows map[int]int + entryCacheValid bool + thinkBuf strings.Builder + inThinking bool + thinkSearchBuf string + doneFlash bool + sessionNoteID int + sessionsPickerState *SessionsPickerState + pendingPaste string + isCompact bool + isWide bool + forceCompact bool + mode Mode + modeConfigs [3]ModeConfig + modelManager *llm.ModelManager + router *config.Router + modelPickerState *ModelPickerState + planFormState *PlanFormState + logger *log.Logger + agent *agent.Agent + cmdRegistry *command.Registry + skillMgr *skill.Manager + completer *Completer + loadedFile string + program *tea.Program + cancel context.CancelFunc + model string + modelList []string + agentProfile string + agentList []string + toolCount int + serverCount int + numCtx int + toastMgr *ToastManager + toastStyles ToastStyles + failedServers []FailedServer + iceEnabled bool + iceConversations int + iceSessionID string + sessionEvalTotal int + sessionPromptTotal int + sessionTurnCount int + fileChanges map[string]int + pendingApproval *ToolApprovalMsg + promptHistory []string + promptHistoryPath string + historyIndex int + historySaved string + welcomeModel WelcomeModel + sidePanel SidePanelModel + logoModel LogoModel + helpViewport viewport.Model + searchState *SearchState + progressTracker *ProgressTracker + resizer *PanelResizer + contextMenu *ContextMenuState + timestampConfig TimestampConfig + timestampHelper *TimestampHelper + keyHints *KeyHints + accessibility *AccessibilityHelper + tableHelper *TableHelper + lang Lang +} + +func New(ag *agent.Agent, cmdReg *command.Registry, skillMgr *skill.Manager, completer *Completer, modelManager *llm.ModelManager, router *config.Router, logger *log.Logger) *Model { + initialLang := LoadLang() + loc := Locale(initialLang) + ta := textarea.New() + ta.Placeholder = loc.Placeholder + ta.Focus() + ta.CharLimit = 4096 + ta.SetHeight(1) + ta.ShowLineNumbers = false + ta.Prompt = "❯ " + styles := textarea.DefaultDarkStyles() + styles.Focused.Base = lipgloss.NewStyle() + styles.Focused.CursorLine = lipgloss.NewStyle() + styles.Blurred.Base = lipgloss.NewStyle() + ta.SetStyles(styles) + styles.Focused.Prompt = lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0")) + ta.SetStyles(styles) + s := spinner.New( + spinner.WithSpinner(spinner.MiniDot), + spinner.WithStyle(lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0"))), + ) + return &Model{ + input: ta, + spin: s, + scramble: NewScrambleModel(true), + welcomeModel: NewWelcomeModel(true), + sidePanel: NewSidePanelModel(true), + logoModel: NewLogoModel(true), + styles: NewStyles(true), + keys: DefaultKeyMap(), + state: StateIdle, + isDark: true, + inputLines: 1, + toolsCollapsed: true, + initializing: true, + mode: ModeAsk, + modeConfigs: DefaultModeConfigs(), + modelManager: modelManager, + router: router, + logger: logger, + agent: ag, + cmdRegistry: cmdReg, + skillMgr: skillMgr, + completer: completer, + historyIndex: -1, + promptHistory: loadPromptHistoryForModel(DefaultPromptHistoryPath()), + promptHistoryPath: DefaultPromptHistoryPath(), + toastMgr: NewToastManager(), + toastStyles: DefaultToastStyles(true), + toolCardMgr: NewToolCardManager(true), + searchState: NewSearchState(), + progressTracker: NewProgressTracker(true), + resizer: NewPanelResizer(20, 60, true), + contextMenu: &ContextMenuState{Active: false}, + timestampConfig: DefaultTimestampConfig(), + timestampHelper: NewTimestampHelper(DefaultTimestampConfig(), true), + keyHints: DefaultKeyHints(initialLang, true), + accessibility: NewAccessibilityHelper(true), + tableHelper: NewTableHelper(true), + lang: initialLang, + } +} + +func (m *Model) tr() L { return Locale(m.lang) } + +func (m *Model) SetProgram(p *tea.Program) { + m.program = p +} + +func (m *Model) SetInitCancel(cancel context.CancelFunc) { + m.initCancel = cancel +} + +func (m *Model) renderStartup(b *strings.Builder) { + m.renderWelcome(b) +} + +func (m *Model) Init() tea.Cmd { + return tea.Batch( + textarea.Blink, + tea.RequestBackgroundColor, + m.spin.Tick, + func() tea.Msg { + return spinnerTickMsg{} + }, + ) +} + +func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd + + switch msg := msg.(type) { + case tea.BackgroundColorMsg: + m.isDark = msg.IsDark() + m.styles = NewStyles(m.isDark) + m.spin.Style = m.styles.StatusDot + m.scramble.SetDark(msg.IsDark()) + m.toastStyles = DefaultToastStyles(m.isDark) + m.toastMgr.SetStyles(m.toastStyles) + m.toolCardMgr.SetDark(msg.IsDark()) + m.progressTracker.SetDark(msg.IsDark()) + m.resizer.SetDark(msg.IsDark()) + m.timestampHelper.SetDark(msg.IsDark()) + m.keyHints.SetDark(msg.IsDark()) + m.accessibility.SetDark(msg.IsDark()) + m.tableHelper.SetDark(msg.IsDark()) + if m.width > 0 { + m.md = NewMarkdownRenderer(m.width-2, m.isDark) + m.invalidateRenderedCache() + } + if m.ready { + m.viewport.SetContent(m.renderEntries()) + } + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + m.isCompact = msg.Width < 80 || msg.Height < 24 + m.isWide = msg.Width > 120 + panelWidth := 30 + if msg.Width < 100 { + panelWidth = 25 + } else if msg.Width > 160 { + panelWidth = 40 + } + viewportWidth := msg.Width - 1 + if m.sidePanel.IsVisible() { + viewportWidth = msg.Width - panelWidth - 2 + } + if viewportWidth < 20 { + viewportWidth = 20 + } + contentWidth := viewportWidth - 6 + if contentWidth < 14 { + contentWidth = 14 + } + m.md = NewMarkdownRenderer(contentWidth, m.isDark) + m.sidePanel.SetWidth(panelWidth) + m.sidePanel.SetHeight(msg.Height - 2) + contentH := msg.Height - 1 - m.footerHeight() + if contentH < 1 { + contentH = 1 + } + if !m.ready { + m.viewport = viewport.New( + viewport.WithWidth(viewportWidth), + viewport.WithHeight(contentH), + ) + m.viewport.KeyMap.PageDown = key.NewBinding(key.WithKeys("pgdown")) + m.viewport.KeyMap.PageUp = key.NewBinding(key.WithKeys("pgup")) + m.viewport.KeyMap.HalfPageUp = key.NewBinding(key.WithKeys("ctrl+u")) + m.viewport.KeyMap.HalfPageDown = key.NewBinding(key.WithKeys("ctrl+d")) + m.viewport.KeyMap.Up = key.NewBinding(key.WithDisabled()) + m.viewport.KeyMap.Down = key.NewBinding(key.WithDisabled()) + m.viewport.KeyMap.Left = key.NewBinding(key.WithDisabled()) + m.viewport.KeyMap.Right = key.NewBinding(key.WithDisabled()) + m.viewport.SetContent(m.renderEntries()) + m.ready = true + m.scrollAnchor = 0 + m.anchorActive = true + m.lastContentHeight = 0 + m.toolEntryRows = make(map[int]int, 8) + } else { + m.viewport.SetWidth(viewportWidth) + m.viewport.SetHeight(contentH) + widthDelta := abs(m.width - msg.Width) + if widthDelta > 5 { + m.invalidateRenderedCache() + } + m.viewport.SetContent(m.renderEntries()) + if m.anchorActive { + m.viewport.GotoBottom() + } + } + if m.overlay == OverlayHelp { + m.initHelpViewport() + } + m.input.SetWidth(viewportWidth) + m.syncInputHeight() + case tea.KeyPressMsg: + if m.initializing { + if key.Matches(msg, m.keys.Quit) { + if m.initCancel != nil { + m.initCancel() + } + return m, tea.Quit + } + return m, nil + } + if m.pendingApproval != nil { + switch msg.String() { + case "y": + m.pendingApproval.Response <- ToolApprovalResponse{Allowed: true} + m.pendingApproval = nil + case "n": + m.pendingApproval.Response <- ToolApprovalResponse{Allowed: false} + m.pendingApproval = nil + case "a": + m.pendingApproval.Response <- ToolApprovalResponse{Allowed: true, Always: true} + m.pendingApproval = nil + } + return m, nil + } + if m.pendingPaste != "" { + switch { + case msg.String() == "y": + m.input.InsertString("```\n" + m.pendingPaste + "\n```") + m.pendingPaste = "" + m.syncInputHeight() + case msg.String() == "n": + m.input.InsertString(m.pendingPaste) + m.pendingPaste = "" + m.syncInputHeight() + case key.Matches(msg, m.keys.Cancel): + m.pendingPaste = "" + } + return m, nil + } + if m.overlay != OverlayNone { + if key.Matches(msg, m.keys.Cancel) { + switch m.overlay { + case OverlayCompletion: + m.input.SetValue("") + m.closeCompletion() + case OverlayModelPicker: + m.closeModelPicker() + case OverlayPlanForm: + m.closePlanForm() + case OverlaySessionsPicker: + if m.sessionsPickerState != nil && m.sessionsPickerState.List.FilterState() == list.Filtering { + var cmd tea.Cmd + m.sessionsPickerState.List, cmd = m.sessionsPickerState.List.Update(msg) + cmds = append(cmds, cmd) + return m, tea.Batch(cmds...) + } + m.closeSessionsPicker() + default: + m.overlay = OverlayNone + m.input.Focus() + } + return m, nil + } + if m.overlay == OverlayHelp { + switch msg.String() { + case "?", "q": + m.overlay = OverlayNone + m.input.Focus() + case "j", "down": + m.helpViewport.ScrollDown(1) + case "k", "up": + m.helpViewport.ScrollUp(1) + case "pgdown": + m.helpViewport.PageDown() + case "pgup": + m.helpViewport.PageUp() + case "d": + m.helpViewport.HalfPageDown() + case "u": + m.helpViewport.HalfPageUp() + case "g": + m.helpViewport.GotoTop() + case "G": + m.helpViewport.GotoBottom() + } + return m, nil + } + if m.overlay == OverlayModelPicker && m.modelPickerState != nil { + if key.Matches(msg, m.keys.CompleteSelect) { + if item := m.modelPickerState.List.SelectedItem(); item != nil { + mi := item.(modelItem) + m.selectModel(mi.name) + } + } else { + var cmd tea.Cmd + m.modelPickerState.List, cmd = m.modelPickerState.List.Update(msg) + cmds = append(cmds, cmd) + } + return m, tea.Batch(cmds...) + } + if m.overlay == OverlayPlanForm && m.planFormState != nil { + submitted, cancelled := m.updatePlanForm(msg) + if cancelled { + m.closePlanForm() + return m, nil + } + if submitted { + prompt := m.planFormState.AssemblePrompt() + m.closePlanForm() + return m, m.submitPlanFormPrompt(prompt) + } + return m, nil + } + if m.overlay == OverlaySessionsPicker && m.sessionsPickerState != nil { + if key.Matches(msg, m.keys.CompleteSelect) { + if item := m.sessionsPickerState.List.SelectedItem(); item != nil { + si := item.(sessionItem) + sessionID := si.id + sessionTitle := si.title + m.closeSessionsPicker() + return m, func() tea.Msg { + note, err := loadSession(sessionID) + if err != nil { + return SessionLoadedMsg{Err: err} + } + entries := deserializeEntries(note.Content) + return SessionLoadedMsg{Entries: entries, Title: sessionTitle} + } + } + } else { + var cmd tea.Cmd + m.sessionsPickerState.List, cmd = m.sessionsPickerState.List.Update(msg) + cmds = append(cmds, cmd) + } + return m, tea.Batch(cmds...) + } + if m.overlay == OverlayCompletion && m.isCompletionActive() { + cs := m.completionState + switch { + case key.Matches(msg, m.keys.CompleteUp): + if cs.Index > 0 { + cs.Index-- + } + case key.Matches(msg, m.keys.CompleteDown): + if cs.Index < len(cs.FilteredItems)-1 { + cs.Index++ + } + case key.Matches(msg, m.keys.CompleteSelect): + if cs.Index < len(cs.FilteredItems) && cs.Kind == "attachments" && cs.FilteredItems[cs.Index].Category == "folder" { + m.drillIntoFolder() + } else { + m.acceptCompletion() + } + case key.Matches(msg, m.keys.CompleteToggle): + m.toggleCompletionSelection() + default: + if msg.Code == tea.KeyBackspace && cs.Filter.Value() == "" && cs.Kind == "attachments" && cs.CurrentPath != "" { + m.drillUpFolder() + return m, nil + } + oldFilter := cs.Filter.Value() + var cmd tea.Cmd + cs.Filter, cmd = cs.Filter.Update(msg) + if cs.Filter.Value() != oldFilter { + cs.FilteredItems = FilterCompletions(cs.AllItems, cs.Filter.Value()) + cs.Index = 0 + if cs.Kind == "attachments" && cs.Filter.Value() != "" { + cs.DebounceTag++ + tag := cs.DebounceTag + query := cs.Filter.Value() + return m, tea.Batch(cmd, tea.Tick(300*time.Millisecond, func(time.Time) tea.Msg { + return CompletionDebounceTickMsg{Tag: tag, Query: query} + })) + } + } + return m, cmd + } + return m, nil + } + return m, nil + } + switch { + case key.Matches(msg, m.keys.Quit): + if m.cancel != nil { + m.cancel() + } + return m, tea.Quit + case key.Matches(msg, m.keys.Cancel): + if (m.state == StateStreaming || m.state == StateWaiting) && m.cancel != nil { + m.cancel() + } + case key.Matches(msg, m.keys.Help): + if m.state == StateIdle && strings.TrimSpace(m.input.Value()) == "" { + m.overlay = OverlayHelp + m.initHelpViewport() + m.input.Blur() + return m, nil + } + case key.Matches(msg, m.keys.ToggleTools): + if m.state == StateIdle && strings.TrimSpace(m.input.Value()) == "" { + m.toolsCollapsed = !m.toolsCollapsed + for i := range m.toolEntries { + m.toolEntries[i].Collapsed = m.toolsCollapsed + } + m.invalidateEntryCache() + m.viewport.SetContent(m.renderEntries()) + return m, nil + } + case key.Matches(msg, m.keys.ToggleFocusedTool): + if m.state == StateIdle && strings.TrimSpace(m.input.Value()) == "" { + if len(m.toolEntries) > 0 { + last := len(m.toolEntries) - 1 + m.toolEntries[last].Collapsed = !m.toolEntries[last].Collapsed + m.invalidateEntryCache() + m.viewport.SetContent(m.renderEntries()) + } + return m, nil + } + case key.Matches(msg, m.keys.CompactToggle): + if m.state == StateIdle { + m.forceCompact = !m.forceCompact + m.invalidateEntryCache() + m.viewport.SetContent(m.renderEntries()) + return m, nil + } + case key.Matches(msg, m.keys.ToggleThinking): + if m.state == StateIdle && strings.TrimSpace(m.input.Value()) == "" { + for i := len(m.entries) - 1; i >= 0; i-- { + if m.entries[i].Kind == "assistant" && m.entries[i].ThinkingContent != "" { + m.entries[i].ThinkingCollapsed = !m.entries[i].ThinkingCollapsed + m.invalidateEntryCache() + m.viewport.SetContent(m.renderEntries()) + break + } + } + return m, nil + } + case key.Matches(msg, m.keys.ExternalEditor): + if m.state == StateIdle { + return m, m.openExternalEditor() + } + case key.Matches(msg, m.keys.CopyLast): + if m.state == StateIdle && strings.TrimSpace(m.input.Value()) == "" { + if content := m.lastAssistantContent(); content != "" { + return m, m.copyToClipboard(content) + } + } + case key.Matches(msg, m.keys.ClearView): + if m.state == StateIdle { + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return m, nil + } + case key.Matches(msg, m.keys.NewConvo): + if m.state == StateIdle { + m.agent.ClearHistory() + m.entries = nil + m.toolEntries = nil + m.sessionEvalTotal = 0 + m.sessionPromptTotal = 0 + m.sessionTurnCount = 0 + m.fileChanges = nil + m.invalidateEntryCache() + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: "New conversation started.", + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return m, nil + } + case key.Matches(msg, m.keys.CycleMode): + if m.state == StateIdle { + m.cycleMode() + return m, nil + } + case key.Matches(msg, m.keys.LanguageCycle): + if m.state == StateIdle { + m.lang = NextLang(m.lang) + _ = SaveLang(m.lang) + m.input.Placeholder = m.tr().Placeholder + m.keyHints.SetHints(defaultHintsForLang(m.lang)) + if m.toastMgr != nil { + m.toastMgr.AddToast(Toast{Message: fmt.Sprintf(m.tr().LanguageSet, LangName(m.lang)), Kind: ToastKindInfo}) + } + m.sidePanel.UpdateSections(m.lang, m.model, m.modelList, m.serverCount, m.toolCount, m.iceEnabled, m.iceConversations) + return m, nil + } + case key.Matches(msg, m.keys.ModelPicker): + if m.state == StateIdle { + m.openModelPicker() + return m, nil + } + case key.Matches(msg, m.keys.NewLine): + if m.state == StateIdle { + m.input.InsertString("\n") + m.syncInputHeight() + return m, nil + } + case key.Matches(msg, m.keys.Send): + if m.state == StateIdle { + return m, m.submitInput() + } + case key.Matches(msg, m.keys.Complete): + if m.state == StateIdle && m.completer != nil && !m.isCompletionActive() { + m.triggerCompletion(m.input.Value()) + } + case key.Matches(msg, m.keys.HistoryUp): + if m.state == StateIdle && m.overlay == OverlayNone { + if strings.TrimSpace(m.input.Value()) == "" || m.historyIndex != -1 { + if m.navigateHistory(-1) { + return m, nil + } + } + } + case key.Matches(msg, m.keys.HistoryDown): + if m.state == StateIdle && m.overlay == OverlayNone { + if m.historyIndex != -1 { + if m.navigateHistory(1) { + return m, nil + } + } + } + case key.Matches(msg, m.keys.ToggleSidePanel): + if m.state == StateIdle { + m.sidePanel.Toggle() + panelWidth := 30 + if m.width < 100 { + panelWidth = 25 + } + contentWidth := m.width - 1 + if m.sidePanel.IsVisible() { + m.sidePanel.SetWidth(panelWidth) + m.sidePanel.SetHeight(m.height - 2) + contentWidth = m.width - panelWidth - 2 + } + if contentWidth < 20 { + contentWidth = 20 + } + m.viewport.SetWidth(contentWidth) + m.input.SetWidth(contentWidth) + m.invalidateRenderedCache() + m.viewport.SetContent(m.renderEntries()) + return m, nil + } + } + case StreamTextMsg: + if m.state == StateWaiting { + m.state = StateStreaming + } + mainText, thinkText, outInThinking, outSearchBuf := processStreamChunk( + msg.Text, m.inThinking, m.thinkSearchBuf, + ) + m.inThinking = outInThinking + m.thinkSearchBuf = outSearchBuf + if mainText != "" { + m.streamBuf.WriteString(mainText) + } + if thinkText != "" { + m.thinkBuf.WriteString(thinkText) + } + m.viewport.SetContent(m.renderEntries()) + if m.anchorActive { + m.viewport.GotoBottom() + } + case StreamDoneMsg: + m.evalCount = msg.EvalCount + m.promptTokens = msg.PromptTokens + m.sessionEvalTotal += msg.EvalCount + m.sessionPromptTotal += msg.PromptTokens + m.sessionTurnCount++ + case ToolCallStartMsg: + te := ToolEntry{ + Name: msg.Name, + Args: FormatToolArgs(msg.Args), + RawArgs: msg.Args, + Status: ToolStatusRunning, + StartTime: msg.StartTime, + Collapsed: m.toolsCollapsed, + } + if classifyTool(msg.Name) == ToolTypeFileWrite { + te.BeforeContent = readFileForDiff(msg.Args) + } + m.toolEntries = append(m.toolEntries, te) + m.toolsPending++ + kind := ToolCardGeneric + switch classifyTool(msg.Name) { + case ToolTypeFileRead, ToolTypeFileWrite: + kind = ToolCardFile + case ToolTypeBash: + kind = ToolCardBash + default: + kind = ToolCardGeneric + } + m.toolCardMgr.AddCard(msg.Name, kind, msg.StartTime) + m.entries = append(m.entries, ChatEntry{ + Kind: "tool_group", + ToolIndex: len(m.toolEntries) - 1, + }) + m.flushStream() + m.viewport.SetContent(m.renderEntries()) + if m.anchorActive { + m.viewport.GotoBottom() + } + case PlanFormCompletedMsg: + return m, m.submitPlanFormPrompt(msg.Prompt) + case ToolCallResultMsg: + m.invalidateEntryCache() + if m.logger != nil { + m.logger.Info("tool call", "name", msg.Name, "duration", msg.Duration, "error", msg.IsError) + } + for i := len(m.toolEntries) - 1; i >= 0; i-- { + if m.toolEntries[i].Name == msg.Name && m.toolEntries[i].Status == ToolStatusRunning { + result := msg.Result + if len(result) > 2000 { + result = result[:1997] + "..." + } + m.toolEntries[i].Result = result + m.toolEntries[i].IsError = msg.IsError + m.toolEntries[i].Duration = msg.Duration + if msg.IsError { + m.toolEntries[i].Status = ToolStatusError + } else { + m.toolEntries[i].Status = ToolStatusDone + } + if classifyTool(m.toolEntries[i].Name) == ToolTypeFileWrite && !msg.IsError { + afterContent := readFileForDiff(m.toolEntries[i].RawArgs) + m.toolEntries[i].DiffLines = computeDiff(m.toolEntries[i].BeforeContent, afterContent) + if path := toolSummary(ToolTypeFileWrite, m.toolEntries[i]); path != "" { + if m.fileChanges == nil { + m.fileChanges = make(map[string]int) + } + m.fileChanges[path]++ + } + } + break + } + } + cardState := ToolCardSuccess + if msg.IsError { + cardState = ToolCardError + } + m.toolCardMgr.UpdateCard(msg.Name, cardState, msg.Result, msg.Duration) + if m.toolsPending > 0 { + m.toolsPending-- + } + m.viewport.SetContent(m.renderEntries()) + if m.anchorActive { + m.viewport.GotoBottom() + } + case SystemMessageMsg: + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: msg.Msg, + }) + m.viewport.SetContent(m.renderEntries()) + if m.anchorActive { + m.viewport.GotoBottom() + } + case ErrorMsg: + if m.logger != nil { + m.logger.Error("error", "msg", msg.Msg) + } + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: msg.Msg, + }) + m.viewport.SetContent(m.renderEntries()) + if m.anchorActive { + m.viewport.GotoBottom() + } + case AgentDoneMsg: + if m.logger != nil { + m.logger.Info("agent done", "eval_tokens", m.evalCount) + } + m.flushStream() + m.state = StateIdle + m.userScrolledUp = false + m.anchorActive = true + m.scrollAnchor = 0 + m.input.Focus() + m.input.SetHeight(1) + m.inputLines = 1 + m.recalcViewportHeight() + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + m.doneFlash = true + cmds = append(cmds, tea.Tick(2*time.Second, func(time.Time) tea.Msg { + return DoneFlashExpiredMsg{} + })) + if m.sessionNoteID > 0 { + id := m.sessionNoteID + content := serializeEntries(m.entries) + cmds = append(cmds, func() tea.Msg { + _ = updateSessionNote(id, content) + return nil + }) + } + case StartupStatusMsg: + found := false + for i, item := range m.startupItems { + if item.ID == msg.ID { + m.startupItems[i].Status = msg.Status + m.startupItems[i].Detail = msg.Detail + found = true + break + } + } + if !found { + m.startupItems = append(m.startupItems, startupItem{ + ID: msg.ID, Label: msg.Label, Status: msg.Status, Detail: msg.Detail, + }) + } + sidePanelItems := make([]StartupItem, len(m.startupItems)) + for i, item := range m.startupItems { + sidePanelItems[i] = StartupItem{ + Label: item.Label, + Status: item.Status, + Detail: item.Detail, + } + } + m.sidePanel.SetStartupItems(sidePanelItems) + m.sidePanel.SetSpinnerTick() + if m.ready { + m.viewport.SetContent(m.renderEntries()) + } + return m, tea.Tick(100*time.Millisecond, func(time.Time) tea.Msg { + return spinnerTickMsg{} + }) + case spinnerTickMsg: + if m.initializing { + m.sidePanel.Tick() + return m, tea.Tick(80*time.Millisecond, func(time.Time) tea.Msg { + return spinnerTickMsg{} + }) + } + if m.toolsPending > 0 { + m.toolCardMgr.Tick() + return m, tea.Tick(80*time.Millisecond, func(time.Time) tea.Msg { + return spinnerTickMsg{} + }) + } + case InitCompleteMsg: + m.model = msg.Model + m.modelList = msg.ModelList + m.agentProfile = msg.AgentProfile + m.agentList = msg.AgentList + m.toolCount = msg.ToolCount + m.serverCount = msg.ServerCount + m.numCtx = msg.NumCtx + m.failedServers = msg.FailedServers + m.iceEnabled = msg.ICEEnabled + m.iceConversations = msg.ICEConversations + m.iceSessionID = msg.ICESessionID + if m.completer != nil { + m.completer.UpdateModels(msg.ModelList) + m.completer.UpdateAgents(msg.AgentList) + } + if len(msg.FailedServers) > 0 { + var parts []string + for _, fs := range msg.FailedServers { + parts = append(parts, fs.Name+" ("+fs.Reason+")") + } + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: "Failed to connect: " + strings.Join(parts, ", "), + }) + } + m.initializing = false + m.startupItems = nil + m.sidePanel.UpdateSections( + m.lang, + m.model, + m.modelList, + m.serverCount, + m.toolCount, + m.iceEnabled, + m.iceConversations, + ) + m.logoModel.Start() + m.viewport.SetContent(m.renderEntries()) + case CommandResultMsg: + if msg.Text != "" { + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: msg.Text, + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + } + case CompletionDebounceTickMsg: + if m.isCompletionActive() && m.completionState.DebounceTag == msg.Tag { + cs := m.completionState + cs.Searching = true + query := msg.Query + tag := msg.Tag + return m, func() tea.Msg { + results := m.completer.SearchFiles(context.Background(), query) + return CompletionSearchResultMsg{Tag: tag, Results: results} + } + } + case CompletionSearchResultMsg: + if m.isCompletionActive() && m.completionState.DebounceTag == msg.Tag { + cs := m.completionState + cs.Searching = false + cs.SearchResults = msg.Results + existing := make(map[string]bool) + for _, item := range cs.AllItems { + existing[item.Insert] = true + } + for _, result := range msg.Results { + if !existing[result.Insert] { + cs.AllItems = append(cs.AllItems, result) + } + } + cs.FilteredItems = FilterCompletions(cs.AllItems, cs.Filter.Value()) + } + case ToolApprovalMsg: + m.pendingApproval = &msg + case CommitResultMsg: + if msg.Err != nil { + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: fmt.Sprintf("Commit failed: %v", msg.Err), + }) + } else { + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: fmt.Sprintf("Committed with message:\n%s", msg.Message), + }) + } + m.invalidateEntryCache() + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + case editorReturnMsg: + m.input.SetValue(msg.Content) + m.input.CursorEnd() + m.syncInputHeight() + m.input.Focus() + case DoneFlashExpiredMsg: + m.doneFlash = false + case SessionCreatedMsg: + if msg.Err == nil && msg.NoteID > 0 { + m.sessionNoteID = msg.NoteID + } + case SessionListMsg: + if msg.Err != nil { + m.entries = append(m.entries, ChatEntry{Kind: "error", Content: fmt.Sprintf("Sessions: %v", msg.Err)}) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + } else if len(msg.Sessions) == 0 { + m.entries = append(m.entries, ChatEntry{Kind: "system", Content: "No saved sessions found."}) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + } else { + m.sessionsPickerState = newSessionsPickerState(msg.Sessions, m.width, m.isDark) + m.overlay = OverlaySessionsPicker + m.input.Blur() + } + case SessionLoadedMsg: + m.invalidateEntryCache() + if msg.Err != nil { + m.entries = append(m.entries, ChatEntry{Kind: "error", Content: fmt.Sprintf("Load session: %v", msg.Err)}) + } else { + m.entries = msg.Entries + m.entries = append([]ChatEntry{{Kind: "system", Content: fmt.Sprintf("Restored session: %s", msg.Title)}}, m.entries...) + } + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + + case tea.MouseWheelMsg: + wasAtBottom := m.viewport.AtBottom() + m.viewport, _ = m.viewport.Update(msg) + if msg.Button == tea.MouseWheelUp && wasAtBottom { + m.anchorActive = false + m.userScrolledUp = true + m.scrollAnchor = 5 + } else if m.viewport.AtBottom() { + m.anchorActive = true + m.userScrolledUp = false + m.scrollAnchor = 0 + } + case tea.MouseClickMsg: + if msg.Button == tea.MouseLeft { + m.handleMouseClick(msg.X, msg.Y) + } + case tea.PasteMsg: + lines := strings.Count(msg.Content, "\n") + 1 + if lines > 10 && m.state == StateIdle { + m.pendingPaste = msg.Content + } else if m.state == StateIdle { + m.input.InsertString(msg.Content) + m.syncInputHeight() + } + } + if _, ok := msg.(ScrambleTickMsg); ok { + var cmd tea.Cmd + m.scramble, cmd = m.scramble.Update(msg) + cmds = append(cmds, cmd) + } + if !m.logoModel.IsDone() { + var cmd tea.Cmd + m.logoModel, cmd = m.logoModel.Update(msg) + cmds = append(cmds, cmd) + } + { + var cmd tea.Cmd + m.spin, cmd = m.spin.Update(msg) + cmds = append(cmds, cmd) + } + if m.state == StateIdle && m.overlay == OverlayNone && !m.initializing { + var cmd tea.Cmd + m.input, cmd = m.input.Update(msg) + cmds = append(cmds, cmd) + m.syncInputHeight() + newInput := m.input.Value() + if m.completer != nil && len(newInput) > 0 { + first := newInput[0] + if (first == '/' || first == '@' || first == '#') && !m.isCompletionActive() { + m.triggerCompletion(newInput) + } + } + if m.isCompletionActive() && (len(newInput) == 0 || (newInput[0] != '/' && newInput[0] != '@' && newInput[0] != '#')) { + m.closeCompletion() + } + } + wasAtBottom := m.viewport.AtBottom() + var cmd tea.Cmd + m.viewport, cmd = m.viewport.Update(msg) + cmds = append(cmds, cmd) + if m.state == StateStreaming && wasAtBottom && !m.viewport.AtBottom() { + m.userScrolledUp = true + } + m.checkAutoScroll() + return m, tea.Batch(cmds...) +} + +func loadPromptHistoryForModel(path string) []string { + if path == "" { + return nil + } + list, err := LoadPromptHistory(path) + if err != nil || len(list) == 0 { + return nil + } + return list +} + +func (m *Model) pushHistory(text string) { + if text == "" { + return + } + if len(m.promptHistory) > 0 && m.promptHistory[len(m.promptHistory)-1] == text { + return + } + m.promptHistory = append(m.promptHistory, text) + if len(m.promptHistory) > promptHistoryMax { + m.promptHistory = m.promptHistory[len(m.promptHistory)-promptHistoryMax:] + } + m.historyIndex = -1 + if m.promptHistoryPath != "" { + _ = SavePromptHistory(m.promptHistoryPath, m.promptHistory) + } +} + +func (m *Model) navigateHistory(dir int) bool { + if len(m.promptHistory) == 0 { + return false + } + if dir == -1 { + if m.historyIndex == -1 { + m.historySaved = m.input.Value() + m.historyIndex = len(m.promptHistory) - 1 + } else if m.historyIndex > 0 { + m.historyIndex-- + } else { + return false + } + m.input.SetValue(m.promptHistory[m.historyIndex]) + m.input.CursorEnd() + return true + } + if dir == 1 { + if m.historyIndex == -1 { + return false + } + if m.historyIndex < len(m.promptHistory)-1 { + m.historyIndex++ + m.input.SetValue(m.promptHistory[m.historyIndex]) + m.input.CursorEnd() + } else { + m.historyIndex = -1 + m.input.SetValue(m.historySaved) + m.input.CursorEnd() + } + return true + } + return false +} + +func (m *Model) submitInput() tea.Cmd { + text := strings.TrimSpace(m.input.Value()) + if text == "" { + return nil + } + m.pushHistory(text) + m.input.Reset() + m.input.SetHeight(1) + if strings.HasPrefix(text, "/") { + parts := strings.Fields(text) + name := strings.TrimPrefix(parts[0], "/") + args := parts[1:] + ctx := m.buildCommandContext() + result := m.cmdRegistry.Execute(ctx, name, args) + if result.Error != "" { + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: result.Error, + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + } + return m.handleCommandAction(result) + } + if m.mode == ModePlan { + m.openPlanForm(text) + return nil + } + return m.sendToAgent(text) +} + +func (m *Model) buildCommandContext() *command.Context { + ctx := &command.Context{ + Model: m.model, + ModelList: m.modelList, + AgentProfile: m.agentProfile, + AgentList: m.agentList, + ToolCount: m.toolCount, + ServerCount: m.serverCount, + ServerNames: m.agent.ServerNames(), + LoadedFile: m.loadedFile, + ICEEnabled: m.iceEnabled, + ICEConversations: m.iceConversations, + ICESessionID: m.iceSessionID, + SessionEvalTotal: m.sessionEvalTotal, + SessionPromptTotal: m.sessionPromptTotal, + SessionTurnCount: m.sessionTurnCount, + NumCtx: m.numCtx, + CurrentModel: m.model, + FileChanges: m.fileChanges, + } + if m.skillMgr != nil { + for _, s := range m.skillMgr.All() { + ctx.Skills = append(ctx.Skills, command.SkillInfo{ + Name: s.Name, + Description: s.Description, + Active: s.Active, + }) + } + } + return ctx +} + +func (m *Model) handleCommandAction(result command.Result) tea.Cmd { + switch result.Action { + case command.ActionShowHelp: + m.overlay = OverlayHelp + m.initHelpViewport() + return nil + case command.ActionClear: + m.agent.ClearHistory() + m.entries = nil + m.toolEntries = nil + m.invalidateEntryCache() + if result.Text != "" { + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: result.Text, + }) + } + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + case command.ActionQuit: + if m.cancel != nil { + m.cancel() + } + return tea.Quit + case command.ActionLoadContext: + parts := strings.SplitN(result.Data, "\x00", 2) + if len(parts) == 2 { + m.loadedFile = parts[0] + m.agent.SetLoadedContext(parts[1]) + } + if result.Text != "" { + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: result.Text, + }) + } + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + case command.ActionUnloadContext: + m.loadedFile = "" + m.agent.SetLoadedContext("") + if result.Text != "" { + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: result.Text, + }) + } + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + case command.ActionActivateSkill: + if m.skillMgr != nil { + if err := m.skillMgr.Activate(result.Data); err != nil { + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: err.Error(), + }) + } else { + m.agent.SetSkillContent(m.skillMgr.ActiveContent()) + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: result.Text, + }) + } + } + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + case command.ActionDeactivateSkill: + if m.skillMgr != nil { + if err := m.skillMgr.Deactivate(result.Data); err != nil { + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: err.Error(), + }) + } else { + m.agent.SetSkillContent(m.skillMgr.ActiveContent()) + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: result.Text, + }) + } + } + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + case command.ActionSwitchModel: + query := "" + currentInput := strings.TrimSpace(m.input.Value()) + if currentInput != "" && !strings.HasPrefix(currentInput, "/") { + query = currentInput + } else { + for i := len(m.entries) - 1; i >= 0; i-- { + if m.entries[i].Kind == "user" { + query = m.entries[i].Content + break + } + } + } + if m.router != nil && query != "" { + m.router.RecordOverride(query, result.Data) + } + if m.modelManager != nil { + if err := m.modelManager.SetCurrentModel(result.Data); err != nil { + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: fmt.Sprintf("Failed to switch model: %v", err), + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + } + } + if m.logger != nil { + m.logger.Info("model switched", "from", m.model, "to", result.Data) + } + m.model = result.Data + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: result.Text, + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + case command.ActionShowModelPicker: + m.openModelPicker() + return nil + case command.ActionSendPrompt: + if result.Text != "" { + m.entries = append(m.entries, ChatEntry{Kind: "system", Content: result.Text}) + } + return m.sendToAgent(result.Data) + case command.ActionCommit: + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: "Generating commit message from staged changes...", + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return runCommit(m.agent.LLMClient(), m.model, result.Data) + case command.ActionShowSessions: + return func() tea.Msg { + sessions, err := listSessions(20) + return SessionListMsg{Sessions: sessions, Err: err} + } + case command.ActionSwitchAgent: + m.agentProfile = result.Data + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: result.Text, + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + case command.ActionExport: + path := result.Data + if path == "" { + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: "export: no path specified", + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + } + content := m.formatConversationForExport() + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: fmt.Sprintf("export failed: %v", err), + }) + } else { + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: fmt.Sprintf("Exported conversation to: %s", path), + }) + } + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + case command.ActionImport: + path := result.Data + if path == "" { + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: "import: no path specified", + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + } + data, err := os.ReadFile(path) + if err != nil { + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: fmt.Sprintf("import failed: %v", err), + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + } + entries, err := m.parseImportedConversation(string(data)) + if err != nil { + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: fmt.Sprintf("import parse error: %v", err), + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + } + m.entries = entries + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + return nil + default: + if result.Text != "" { + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: result.Text, + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + } + return nil + } +} + +func (m *Model) flushStream() { + m.invalidateEntryCache() + if m.streamBuf.Len() > 0 || m.thinkBuf.Len() > 0 { + content := m.streamBuf.String() + var rendered string + if m.md != nil && content != "" { + rendered = m.md.RenderFull(content) + } + entry := ChatEntry{ + Kind: "assistant", + Content: content, + RenderedContent: rendered, + } + if m.thinkBuf.Len() > 0 { + entry.ThinkingContent = m.thinkBuf.String() + entry.ThinkingCollapsed = true + } + m.entries = append(m.entries, entry) + m.streamBuf.Reset() + m.thinkBuf.Reset() + m.inThinking = false + m.thinkSearchBuf = "" + } +} + +func (m *Model) invalidateRenderedCache() { + for i := range m.entries { + if m.entries[i].Kind == "assistant" && m.entries[i].RenderedContent != "" { + if m.md != nil { + m.entries[i].RenderedContent = m.md.RenderFull(m.entries[i].Content) + } + } + } + m.invalidateEntryCache() +} + +func (m *Model) footerHeight() int { + if m.state == StateIdle { + return 2 + m.inputLines + } + return 3 +} + +func (m *Model) syncInputHeight() { + lines := m.input.LineCount() + if lines < 1 { + lines = 1 + } + if lines > 5 { + lines = 5 + } + if lines != m.inputLines { + m.inputLines = lines + m.input.SetHeight(lines) + m.recalcViewportHeight() + } +} + +func (m *Model) invalidateEntryCache() { + m.entryCacheValid = false + m.cachedEntriesRender = "" + m.cachedEntryCount = 0 + m.cachedToolEntryRows = nil +} + +func (m *Model) checkAutoScroll() { + if m.viewport.AtBottom() { + m.anchorActive = true + m.userScrolledUp = false + m.scrollAnchor = 0 + } +} + +func (m *Model) getVisibleEntryRange() (start, end int) { + if m.viewport.YOffset() == 0 { + end = len(m.entries) + start = max(0, end-100) + return start, end + } + avgEntryHeight := 5 + viewportH := m.viewport.Height() + visibleEntries := viewportH / avgEntryHeight + buffer := visibleEntries / 2 + scrollPos := m.viewport.YOffset() + estimatedStart := max(0, scrollPos/avgEntryHeight - buffer) + estimatedEnd := min(len(m.entries), estimatedStart + visibleEntries + buffer*2) + return estimatedStart, estimatedEnd +} + +func (m *Model) openExternalEditor() tea.Cmd { + editor := os.Getenv("EDITOR") + if editor == "" { + editor = "vi" + } + tmpFile, err := os.CreateTemp("", "ai-agent-*.md") + if err != nil { + return func() tea.Msg { + return ErrorMsg{Msg: fmt.Sprintf("editor: %v", err)} + } + } + tmpPath := tmpFile.Name() + if current := m.input.Value(); current != "" { + tmpFile.WriteString(current) + } + tmpFile.Close() + c := exec.Command(editor, tmpPath) + return tea.ExecProcess(c, func(err error) tea.Msg { + defer os.Remove(tmpPath) + if err != nil { + return ErrorMsg{Msg: fmt.Sprintf("editor: %v", err)} + } + data, err := os.ReadFile(tmpPath) + if err != nil { + return ErrorMsg{Msg: fmt.Sprintf("editor: %v", err)} + } + content := strings.TrimRight(string(data), "\n") + if content == "" { + return nil + } + return editorReturnMsg{Content: content} + }) +} + +type editorReturnMsg struct { + Content string +} + +func (m *Model) recalcViewportHeight() { + if !m.ready || m.height == 0 { + return + } + headerH := 3 + contentH := m.height - headerH - m.footerHeight() + if contentH < 1 { + contentH = 1 + } + m.viewport.SetHeight(contentH) +} + +func FormatToolArgs(args map[string]any) string { + return agent.FormatToolArgs(args) +} + +func (m *Model) isCompletionActive() bool { + return m.completionState != nil +} + +func newCompletionState(kind string, items []Completion, multiSelect bool) *CompletionState { + ti := textinput.New() + ti.Placeholder = "type to filter..." + ti.Focus() + ti.CharLimit = 128 + var sel map[int]bool + if multiSelect { + sel = make(map[int]bool) + } + return &CompletionState{ + Kind: kind, + Filter: ti, + AllItems: items, + FilteredItems: items, + Index: 0, + Selected: sel, + } +} + +func (m *Model) triggerCompletion(input string) { + var kind string + var items []Completion + var multiSelect bool + if strings.HasPrefix(input, "/") { + kind = "command" + items = m.completer.Complete(input) + } else if strings.HasPrefix(input, "@") { + kind = "attachments" + items = m.completer.Complete(input) + multiSelect = true + } else if strings.HasPrefix(input, "#") { + kind = "skills" + items = m.completer.Complete(input) + multiSelect = true + } + if len(items) == 0 { + return + } + m.completionState = newCompletionState(kind, items, multiSelect) + m.overlay = OverlayCompletion + m.input.Blur() +} + +func (m *Model) acceptCompletion() { + cs := m.completionState + if cs == nil || len(cs.FilteredItems) == 0 { + return + } + isMultiSelect := cs.Kind == "attachments" || cs.Kind == "skills" + if isMultiSelect { + var selectedItems []string + for idx := range cs.Selected { + if idx < len(cs.AllItems) { + selectedItems = append(selectedItems, cs.AllItems[idx].Insert) + } + } + if len(selectedItems) == 0 && cs.Index < len(cs.FilteredItems) { + selectedItems = append(selectedItems, cs.FilteredItems[cs.Index].Insert) + } + m.input.SetValue(strings.Join(selectedItems, " ")) + m.input.CursorEnd() + } else { + item := cs.FilteredItems[cs.Index] + m.input.SetValue(item.Insert) + m.input.CursorEnd() + } + m.closeCompletion() +} + +func (m *Model) toggleCompletionSelection() { + cs := m.completionState + if cs == nil || cs.Selected == nil || len(cs.FilteredItems) == 0 { + return + } + filteredItem := cs.FilteredItems[cs.Index] + for i, item := range cs.AllItems { + if item.Label == filteredItem.Label && item.Insert == filteredItem.Insert { + if cs.Selected[i] { + delete(cs.Selected, i) + } else { + cs.Selected[i] = true + } + break + } + } +} + +func (m *Model) drillIntoFolder() { + cs := m.completionState + if cs == nil || cs.Index >= len(cs.FilteredItems) { + return + } + item := cs.FilteredItems[cs.Index] + folderName := strings.TrimSuffix(item.Label, "/") + if cs.CurrentPath != "" { + cs.CurrentPath += "/" + folderName + } else { + cs.CurrentPath = folderName + } + fileItems := m.completer.CompleteFilePath(cs.CurrentPath) + cs.AllItems = fileItems + cs.Filter.SetValue("") + cs.FilteredItems = fileItems + cs.Index = 0 + cs.SearchResults = nil +} + +func (m *Model) drillUpFolder() { + cs := m.completionState + if cs == nil || cs.CurrentPath == "" { + return + } + if idx := strings.LastIndex(cs.CurrentPath, "/"); idx >= 0 { + cs.CurrentPath = cs.CurrentPath[:idx] + } else { + cs.CurrentPath = "" + } + var items []Completion + if cs.CurrentPath == "" { + items = m.completer.Complete("@") + } else { + items = m.completer.CompleteFilePath(cs.CurrentPath) + } + cs.AllItems = items + cs.Filter.SetValue("") + cs.FilteredItems = items + cs.Index = 0 + cs.SearchResults = nil +} + +func (m *Model) closeCompletion() { + m.completionState = nil + m.overlay = OverlayNone + m.input.Focus() +} + +func (m *Model) sendToAgent(text string) tea.Cmd { + if m.logger != nil { + cfg := m.modeConfigs[m.mode] + m.logger.Info("user message", "mode", cfg.Label, "length", len(text)) + } + m.input.Blur() + m.state = StateWaiting + m.recalcViewportHeight() + m.streamBuf.Reset() + m.evalCount = 0 + m.promptTokens = 0 + m.entries = append(m.entries, ChatEntry{ + Kind: "user", + Content: text, + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() + m.agent.AddUserMessage(text) + cfg := m.modeConfigs[m.mode] + m.agent.SetModeContext(cfg.SystemPromptPrefix, cfg.AllowTools) + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = cancel + p := m.program + m.agent.SetApprovalCallback(func(req permission.ApprovalRequest) { + respCh := make(chan ToolApprovalResponse, 1) + p.Send(ToolApprovalMsg{ + ToolName: req.ToolName, + Args: req.Args, + Response: respCh, + }) + resp := <-respCh + req.Response <- permission.ApprovalResponse{ + Allowed: resp.Allowed, + Always: resp.Always, + } + }) + runAgent := func() tea.Msg { + adapter := NewAdapter(p) + m.agent.Run(ctx, adapter) + return AgentDoneMsg{} + } + m.scramble.Reset() + batchCmds := []tea.Cmd{m.spin.Tick, m.scramble.Tick(), runAgent} + if m.sessionNoteID == 0 && notedAvailable() { + batchCmds = append(batchCmds, func() tea.Msg { + ts := time.Now().Format("2006-01-02 15:04") + id, err := createSessionNote(ts) + return SessionCreatedMsg{NoteID: id, Err: err} + }) + } + return tea.Batch(batchCmds...) +} + +func (m *Model) cycleMode() { + m.mode = (m.mode + 1) % 3 + cfg := m.modeConfigs[m.mode] + if m.router != nil { + newModel := m.router.GetModelForCapability(cfg.PreferredCapability) + if newModel != "" && newModel != m.model { + if m.modelManager != nil { + if err := m.modelManager.SetCurrentModel(newModel); err == nil { + m.model = newModel + } + } + } + } + if m.logger != nil { + m.logger.Info("mode switched", "mode", cfg.Label, "model", m.model) + } + modeColors := map[Mode]string{ + ModeAsk: "#81a1c1", + ModePlan: "#ebcb8b", + ModeBuild: "#a3be8c", + } + _ = modeColors + toastMsg := fmt.Sprintf("⚡ Mode: %s • Model: %s", cfg.Label, m.model) + if m.toastMgr != nil { + m.toastMgr.AddToast(Toast{ + Message: toastMsg, + Kind: ToastKindInfo, + }) + } + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: fmt.Sprintf("Mode switched to %s (%s)", cfg.Label, m.model), + }) + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() +} + +func (m *Model) openModelPicker() { + if len(m.modelList) == 0 { + if m.toastMgr != nil { + m.toastMgr.AddToast(Toast{Message: m.tr().NoModelsAvailable, Kind: ToastKindWarning}) + } + return + } + models := make([]config.Model, len(m.modelList)) + for i, name := range m.modelList { + models[i] = config.Model{Name: name, DisplayName: name} + } + m.modelPickerState = newModelPickerState(models, m.model, m.isDark, m.tr().SelectModel) + m.overlay = OverlayModelPicker + m.input.Blur() +} + +func (m *Model) selectModel(name string) { + old := m.model + if m.modelManager != nil { + if err := m.modelManager.SetCurrentModel(name); err != nil { + m.entries = append(m.entries, ChatEntry{ + Kind: "error", + Content: fmt.Sprintf("Failed to switch model: %v", err), + }) + m.closeModelPicker() + return + } + } + m.model = name + if m.logger != nil { + m.logger.Info("model switched", "from", old, "to", name) + } + m.entries = append(m.entries, ChatEntry{ + Kind: "system", + Content: fmt.Sprintf("Model: %s", name), + }) + m.closeModelPicker() + m.viewport.SetContent(m.renderEntries()) + m.viewport.GotoBottom() +} + +func (m *Model) closeModelPicker() { + m.modelPickerState = nil + m.overlay = OverlayNone + m.input.Focus() +} + +func (m *Model) openPlanForm(task string) { + m.planFormState = NewPlanFormState(task) + m.overlay = OverlayPlanForm + m.input.Blur() +} + +func (m *Model) closePlanForm() { + m.planFormState = nil + m.overlay = OverlayNone + m.input.Focus() +} + +func (m *Model) submitPlanFormPrompt(prompt string) tea.Cmd { + return m.sendToAgent(prompt) +} + +func (m *Model) lastAssistantContent() string { + for i := len(m.entries) - 1; i >= 0; i-- { + if m.entries[i].Kind == "assistant" { + return m.entries[i].Content + } + } + return "" +} + +func (m *Model) copyToClipboard(text string) tea.Cmd { + return func() tea.Msg { + if err := clipboard.WriteAll(text); err != nil { + m.toastMgr.Error("Clipboard error: " + err.Error()) + return SystemMessageMsg{Msg: "Clipboard error: " + err.Error()} + } + m.toastMgr.Success("Copied to clipboard") + return SystemMessageMsg{Msg: "Copied to clipboard."} + } +} + +func (m *Model) handleMouseClick(x, y int) { + vpY := y - 3 + m.viewport.YOffset() + if m.toolEntryRows == nil { + return + } + + for toolIdx, startRow := range m.toolEntryRows { + if vpY >= startRow && vpY < startRow+3 { + if toolIdx >= 0 && toolIdx < len(m.toolEntries) { + m.toolEntries[toolIdx].Collapsed = !m.toolEntries[toolIdx].Collapsed + m.invalidateEntryCache() + m.viewport.SetContent(m.renderEntries()) + } + return + } + } +} + +func (m *Model) formatConversationForExport() string { + var b strings.Builder + b.WriteString("# Conversation Export\n\n") + b.WriteString(fmt.Sprintf("**Date**: %s\n", time.Now().Format("2006-01-02 15:04"))) + b.WriteString(fmt.Sprintf("**Model**: %s\n", m.model)) + b.WriteString("---\n\n") + for _, entry := range m.entries { + switch entry.Kind { + case "user": + b.WriteString("## User\n\n") + b.WriteString(entry.Content) + b.WriteString("\n\n---\n\n") + case "assistant": + b.WriteString("## Assistant\n\n") + b.WriteString(entry.Content) + b.WriteString("\n\n---\n\n") + case "system": + b.WriteString("## System\n\n") + b.WriteString(entry.Content) + b.WriteString("\n\n---\n\n") + case "tool_group": + if entry.ToolIndex >= 0 && entry.ToolIndex < len(m.toolEntries) { + te := m.toolEntries[entry.ToolIndex] + b.WriteString(fmt.Sprintf("## Tool: %s\n\n", te.Name)) + b.WriteString("```\n") + b.WriteString(te.Args) + b.WriteString("\n```\n\n") + if te.Result != "" { + b.WriteString("**Result**:\n\n") + b.WriteString("```\n") + b.WriteString(te.Result) + b.WriteString("\n```\n\n") + } + b.WriteString("---\n\n") + } + } + } + return b.String() +} + +func (m *Model) parseImportedConversation(data string) ([]ChatEntry, error) { + var entries []ChatEntry + lines := strings.Split(data, "\n") + var currentSection string + var currentContent strings.Builder + flushContent := func() { + if currentContent.Len() > 0 { + content := strings.TrimSpace(currentContent.String()) + if content != "" { + entry := ChatEntry{Kind: currentSection, Content: content} + entries = append(entries, entry) + } + currentContent.Reset() + } + } + for _, line := range lines { + if strings.HasPrefix(line, "## ") { + flushContent() + section := strings.TrimPrefix(line, "## ") + switch section { + case "User": + currentSection = "user" + case "Assistant": + currentSection = "assistant" + case "System": + currentSection = "system" + default: + currentSection = "system" + } + } else if strings.HasPrefix(line, "---") { + // Skip separators + } else { + currentContent.WriteString(line) + currentContent.WriteString("\n") + } + } + flushContent() + m.toolEntries = nil + return entries, nil +} + +func (m *Model) Ready() bool { + return m.ready +} + +func (m *Model) AnchorActive() bool { + return m.anchorActive +} diff --git a/internal/tui/model_completion_test.go b/internal/tui/model_completion_test.go new file mode 100644 index 0000000..2eb8238 --- /dev/null +++ b/internal/tui/model_completion_test.go @@ -0,0 +1,203 @@ +package tui + +import "testing" + +func TestTriggerCompletion(t *testing.T) { + t.Run("slash_triggers_command", func(t *testing.T) { + m := newTestModel(t) + m.triggerCompletion("/") + + if !m.isCompletionActive() { + t.Error("/ should activate completion") + } + if m.completionState.Kind != "command" { + t.Errorf("expected kind 'command', got %q", m.completionState.Kind) + } + if m.overlay != OverlayCompletion { + t.Errorf("expected OverlayCompletion, got %d", m.overlay) + } + if len(m.completionState.AllItems) == 0 { + t.Error("should have completion items for /") + } + }) + + t.Run("at_triggers_attachments_with_multiselect", func(t *testing.T) { + m := newTestModel(t) + m.triggerCompletion("@") + + // @ triggers agent/file completion. + // It may or may not find matches depending on agents + cwd. + // If agents exist, it should activate. + if m.isCompletionActive() { + if m.completionState.Kind != "attachments" { + t.Errorf("expected kind 'attachments', got %q", m.completionState.Kind) + } + if m.completionState.Selected == nil { + t.Error("attachments should initialize Selected map") + } + } + }) + + t.Run("hash_triggers_skills", func(t *testing.T) { + m := newTestModel(t) + m.triggerCompletion("#") + + if !m.isCompletionActive() { + t.Error("# should activate completion for skills") + } + if m.completionState.Kind != "skills" { + t.Errorf("expected kind 'skills', got %q", m.completionState.Kind) + } + if m.completionState.Selected == nil { + t.Error("skills should initialize Selected map") + } + }) + + t.Run("no_matches_stays_inactive", func(t *testing.T) { + m := newTestModel(t) + m.triggerCompletion("/zzzznonexistent") + + if m.isCompletionActive() { + t.Error("should not activate with no matches") + } + }) + + t.Run("plain_text_no_trigger", func(t *testing.T) { + m := newTestModel(t) + m.triggerCompletion("hello") + + if m.isCompletionActive() { + t.Error("plain text should not trigger completion") + } + }) +} + +func TestAcceptCompletion(t *testing.T) { + t.Run("single_select", func(t *testing.T) { + m := newTestModel(t) + items := []Completion{ + {Label: "/help", Insert: "/help "}, + {Label: "/clear", Insert: "/clear "}, + } + m.completionState = newCompletionState("command", items, false) + m.overlay = OverlayCompletion + m.completionState.Index = 0 + + m.acceptCompletion() + + if m.input.Value() != "/help " { + t.Errorf("expected '/help ', got %q", m.input.Value()) + } + if m.isCompletionActive() { + t.Error("should be inactive after accept") + } + if m.overlay != OverlayNone { + t.Error("overlay should be OverlayNone") + } + }) + + t.Run("multi_select_with_selections", func(t *testing.T) { + m := newTestModel(t) + items := []Completion{ + {Label: "@a", Insert: "@a "}, + {Label: "@b", Insert: "@b "}, + {Label: "@c", Insert: "@c "}, + } + m.completionState = newCompletionState("attachments", items, true) + m.overlay = OverlayCompletion + m.completionState.Index = 0 + m.completionState.Selected[1] = true + + m.acceptCompletion() + + if m.input.Value() != "@b " { + t.Errorf("expected '@b ', got %q", m.input.Value()) + } + if m.isCompletionActive() { + t.Error("should be inactive after accept") + } + }) + + t.Run("multi_select_empty_fallback", func(t *testing.T) { + m := newTestModel(t) + items := []Completion{ + {Label: "@x", Insert: "@x "}, + {Label: "@y", Insert: "@y "}, + } + m.completionState = newCompletionState("attachments", items, true) + m.overlay = OverlayCompletion + m.completionState.Index = 1 + + m.acceptCompletion() + + if m.input.Value() != "@y " { + t.Errorf("expected '@y ' as fallback, got %q", m.input.Value()) + } + }) + + t.Run("inactive_noop", func(t *testing.T) { + m := newTestModel(t) + m.completionState = nil + m.input.SetValue("original") + + m.acceptCompletion() + + if m.input.Value() != "original" { + t.Errorf("inactive accept should be noop, got %q", m.input.Value()) + } + }) +} + +func TestCloseCompletion(t *testing.T) { + m := newTestModel(t) + items := []Completion{{Label: "test"}} + m.completionState = newCompletionState("command", items, true) + m.completionState.Index = 5 + m.completionState.Selected[0] = true + m.overlay = OverlayCompletion + + m.closeCompletion() + + if m.isCompletionActive() { + t.Error("completionState should be nil") + } + if m.overlay != OverlayNone { + t.Errorf("overlay should be OverlayNone, got %d", m.overlay) + } +} + +func TestFilterCompletions(t *testing.T) { + items := []Completion{ + {Label: "/help"}, + {Label: "/clear"}, + {Label: "/model"}, + } + + t.Run("empty_query_returns_all", func(t *testing.T) { + filtered := FilterCompletions(items, "") + if len(filtered) != 3 { + t.Errorf("expected 3, got %d", len(filtered)) + } + }) + + t.Run("filters_by_substring", func(t *testing.T) { + filtered := FilterCompletions(items, "el") + if len(filtered) != 2 { + t.Errorf("expected 2 (help, model), got %d", len(filtered)) + } + }) + + t.Run("case_insensitive", func(t *testing.T) { + filtered := FilterCompletions(items, "HELP") + if len(filtered) != 1 { + t.Errorf("expected 1, got %d", len(filtered)) + } + }) + + t.Run("no_match", func(t *testing.T) { + filtered := FilterCompletions(items, "zzz") + if len(filtered) != 0 { + t.Errorf("expected 0, got %d", len(filtered)) + } + }) +} diff --git a/internal/tui/model_overlay_test.go b/internal/tui/model_overlay_test.go new file mode 100644 index 0000000..048ee83 --- /dev/null +++ b/internal/tui/model_overlay_test.go @@ -0,0 +1,309 @@ +package tui + +import ( + "testing" +) + +func TestOverlay_ESC_ClosesCompletion(t *testing.T) { + m := newTestModel(t) + + // Set up active completion state. + items := []Completion{ + {Label: "/help", Insert: "/help ", Category: "command"}, + {Label: "/clear", Insert: "/clear ", Category: "command"}, + {Label: "/model", Insert: "/model ", Category: "command"}, + } + m.completionState = newCompletionState("command", items, true) + m.completionState.Index = 1 + m.completionState.Selected[0] = true + m.overlay = OverlayCompletion + + // Send ESC. + updated, _ := m.Update(escKey()) + m = updated.(*Model) + + // Verify completion state is nil. + if m.isCompletionActive() { + t.Error("completionState should be nil after ESC") + } + if m.overlay != OverlayNone { + t.Errorf("overlay should be OverlayNone, got %d", m.overlay) + } +} + +func TestOverlay_ESC_ClearsInputToPreventRetrigger(t *testing.T) { + m := newTestModel(t) + + // Simulate: user typed "/" which triggered completion, then presses ESC. + m.input.SetValue("/") + items := []Completion{ + {Label: "/help", Insert: "/help ", Category: "command"}, + {Label: "/clear", Insert: "/clear ", Category: "command"}, + } + m.completionState = newCompletionState("command", items, false) + m.overlay = OverlayCompletion + + // Press ESC to close. + updated, _ := m.Update(escKey()) + m = updated.(*Model) + + // Input must be cleared so auto-trigger doesn't reopen. + if m.input.Value() != "" { + t.Errorf("ESC should clear input, got %q", m.input.Value()) + } + if m.isCompletionActive() { + t.Error("completion should be closed after ESC") + } + if m.overlay != OverlayNone { + t.Errorf("overlay should be OverlayNone, got %d", m.overlay) + } +} + +func TestOverlay_ESC_NoRetriggerOnSubsequentUpdate(t *testing.T) { + m := newTestModel(t) + + // Simulate: user typed "/" which triggered completion, then presses ESC. + m.input.SetValue("/") + items := []Completion{ + {Label: "/help", Insert: "/help ", Category: "command"}, + } + m.completionState = newCompletionState("command", items, false) + m.overlay = OverlayCompletion + + // Press ESC. + updated, _ := m.Update(escKey()) + m = updated.(*Model) + + // Send another key event (e.g., a harmless key like 'a') to cycle through Update. + // This exercises the auto-trigger path at lines 968-972. + updated, _ = m.Update(charKey('a')) + m = updated.(*Model) + + // Completion must NOT have re-opened. + if m.isCompletionActive() { + t.Error("completion should not re-trigger after ESC close") + } + if m.overlay != OverlayNone { + t.Errorf("overlay should still be OverlayNone, got %d", m.overlay) + } +} + +func TestOverlay_ESC_ClosesHelp(t *testing.T) { + m := newTestModel(t) + m.overlay = OverlayHelp + + updated, _ := m.Update(escKey()) + m = updated.(*Model) + + if m.overlay != OverlayNone { + t.Errorf("overlay should be OverlayNone after ESC, got %d", m.overlay) + } +} + +func TestOverlay_HelpDismissal(t *testing.T) { + t.Run("question_mark_dismisses", func(t *testing.T) { + m := newTestModel(t) + m.overlay = OverlayHelp + + updated, _ := m.Update(charKey('?')) + m = updated.(*Model) + + if m.overlay != OverlayNone { + t.Errorf("? should dismiss help overlay, got %d", m.overlay) + } + }) + + t.Run("q_dismisses", func(t *testing.T) { + m := newTestModel(t) + m.overlay = OverlayHelp + + updated, _ := m.Update(charKey('q')) + m = updated.(*Model) + + if m.overlay != OverlayNone { + t.Errorf("q should dismiss help overlay, got %d", m.overlay) + } + }) + + t.Run("other_key_swallowed", func(t *testing.T) { + m := newTestModel(t) + m.overlay = OverlayHelp + + updated, _ := m.Update(charKey('a')) + m = updated.(*Model) + + if m.overlay != OverlayHelp { + t.Errorf("'a' should be swallowed, overlay should remain OverlayHelp, got %d", m.overlay) + } + }) +} + +func TestOverlay_CompletionNavigation(t *testing.T) { + setup := func(t *testing.T) *Model { + t.Helper() + m := newTestModel(t) + items := []Completion{ + {Label: "/help", Insert: "/help "}, + {Label: "/clear", Insert: "/clear "}, + {Label: "/model", Insert: "/model "}, + } + m.completionState = newCompletionState("command", items, false) + m.overlay = OverlayCompletion + return m + } + + t.Run("down_moves_index", func(t *testing.T) { + m := setup(t) + + updated, _ := m.Update(downKey()) + m = updated.(*Model) + + if m.completionState.Index != 1 { + t.Errorf("down from 0 should move to 1, got %d", m.completionState.Index) + } + }) + + t.Run("up_at_zero_stays", func(t *testing.T) { + m := setup(t) + + updated, _ := m.Update(upKey()) + m = updated.(*Model) + + if m.completionState.Index != 0 { + t.Errorf("up at 0 should stay at 0, got %d", m.completionState.Index) + } + }) + + t.Run("down_clamped_at_end", func(t *testing.T) { + m := setup(t) + m.completionState.Index = 2 + + updated, _ := m.Update(downKey()) + m = updated.(*Model) + + if m.completionState.Index != 2 { + t.Errorf("down at last item should stay at 2, got %d", m.completionState.Index) + } + }) +} + +func TestOverlay_CompletionToggle(t *testing.T) { + t.Run("tab_toggles_selection_on", func(t *testing.T) { + m := newTestModel(t) + items := []Completion{ + {Label: "/a", Insert: "/a "}, + {Label: "/b", Insert: "/b "}, + } + m.completionState = newCompletionState("attachments", items, true) + m.overlay = OverlayCompletion + + updated, _ := m.Update(tabKey()) + m = updated.(*Model) + + if !m.completionState.Selected[0] { + t.Error("tab should toggle selection on for index 0") + } + }) + + t.Run("tab_toggles_selection_off", func(t *testing.T) { + m := newTestModel(t) + items := []Completion{ + {Label: "/a", Insert: "/a "}, + {Label: "/b", Insert: "/b "}, + } + m.completionState = newCompletionState("attachments", items, true) + m.completionState.Selected[0] = true + m.overlay = OverlayCompletion + + updated, _ := m.Update(tabKey()) + m = updated.(*Model) + + if m.completionState.Selected[0] { + t.Error("tab should toggle selection off for index 0") + } + }) + + t.Run("nil_selected_no_panic", func(t *testing.T) { + m := newTestModel(t) + items := []Completion{ + {Label: "/a", Insert: "/a "}, + } + m.completionState = newCompletionState("command", items, false) + // Selected is nil for single-select mode + m.overlay = OverlayCompletion + + // Should not panic. + updated, _ := m.Update(tabKey()) + _ = updated.(*Model) + }) +} + +func TestOverlay_CompletionAccept(t *testing.T) { + t.Run("single_select", func(t *testing.T) { + m := newTestModel(t) + items := []Completion{ + {Label: "/help", Insert: "/help "}, + {Label: "/clear", Insert: "/clear "}, + } + m.completionState = newCompletionState("command", items, false) + m.completionState.Index = 1 + m.overlay = OverlayCompletion + + updated, _ := m.Update(enterKey()) + m = updated.(*Model) + + if m.input.Value() != "/clear " { + t.Errorf("input should be '/clear ', got %q", m.input.Value()) + } + if m.isCompletionActive() { + t.Error("completion should be closed after accept") + } + if m.overlay != OverlayNone { + t.Errorf("overlay should be OverlayNone, got %d", m.overlay) + } + }) + + t.Run("multi_select_with_selections", func(t *testing.T) { + m := newTestModel(t) + items := []Completion{ + {Label: "@file1", Insert: "@file1 "}, + {Label: "@file2", Insert: "@file2 "}, + {Label: "@file3", Insert: "@file3 "}, + } + m.completionState = newCompletionState("attachments", items, true) + m.completionState.Selected[0] = true + m.completionState.Selected[2] = true + m.overlay = OverlayCompletion + + updated, _ := m.Update(enterKey()) + m = updated.(*Model) + + val := m.input.Value() + // Selected items 0 and 2 should be joined. + if val == "" { + t.Error("input should not be empty with multi-select") + } + if m.isCompletionActive() { + t.Error("completion should be closed after accept") + } + }) + + t.Run("multi_select_empty_fallback", func(t *testing.T) { + m := newTestModel(t) + items := []Completion{ + {Label: "@file1", Insert: "@file1 "}, + {Label: "@file2", Insert: "@file2 "}, + } + m.completionState = newCompletionState("attachments", items, true) + m.completionState.Index = 1 + m.overlay = OverlayCompletion + + updated, _ := m.Update(enterKey()) + m = updated.(*Model) + + // Fallback to current item. + if m.input.Value() != "@file2 " { + t.Errorf("should fallback to current item, got %q", m.input.Value()) + } + }) +} diff --git a/internal/tui/model_test.go b/internal/tui/model_test.go new file mode 100644 index 0000000..e574643 --- /dev/null +++ b/internal/tui/model_test.go @@ -0,0 +1,406 @@ +package tui + +import ( + "strings" + "testing" + "time" + + "ai-agent/internal/command" + + tea "charm.land/bubbletea/v2" +) + +func TestSubmitInput_EmptyReturnsNil(t *testing.T) { + m := newTestModel(t) + cmd := m.submitInput() + if cmd != nil { + t.Error("submitInput with empty input should return nil") + } +} + +func TestHelp_OnlyWhenIdleAndEmpty(t *testing.T) { + t.Run("idle_empty_opens_help", func(t *testing.T) { + m := newTestModel(t) + m.state = StateIdle + updated, _ := m.Update(charKey('?')) + m = updated.(*Model) + if m.overlay != OverlayHelp { + t.Errorf("? with idle+empty should open help, got overlay=%d", m.overlay) + } + }) + t.Run("idle_nonempty_no_help", func(t *testing.T) { + m := newTestModel(t) + m.state = StateIdle + m.input.SetValue("hello") + updated, _ := m.Update(charKey('?')) + m = updated.(*Model) + if m.overlay == OverlayHelp { + t.Error("? with non-empty input should not open help") + } + }) + t.Run("waiting_no_help", func(t *testing.T) { + m := newTestModel(t) + m.state = StateWaiting + updated, _ := m.Update(charKey('?')) + m = updated.(*Model) + if m.overlay == OverlayHelp { + t.Error("? in StateWaiting should not open help") + } + }) +} + +func TestToggleTools_OnlyWhenIdleAndEmpty(t *testing.T) { + t.Run("idle_empty_toggles", func(t *testing.T) { + m := newTestModel(t) + m.state = StateIdle + before := m.toolsCollapsed + updated, _ := m.Update(charKey('t')) + m = updated.(*Model) + if m.toolsCollapsed == before { + t.Error("'t' with idle+empty should toggle toolsCollapsed") + } + }) + t.Run("idle_nonempty_no_toggle", func(t *testing.T) { + m := newTestModel(t) + m.state = StateIdle + m.input.SetValue("hello") + before := m.toolsCollapsed + updated, _ := m.Update(charKey('t')) + m = updated.(*Model) + if m.toolsCollapsed != before { + t.Error("'t' with non-empty input should not toggle tools") + } + }) +} + +func TestESC_CancelOnlyWhenStreamingOrWaiting(t *testing.T) { + t.Run("idle_no_cancel", func(t *testing.T) { + m := newTestModel(t) + m.state = StateIdle + cancelCalled := false + m.cancel = func() { cancelCalled = true } + updated, _ := m.Update(escKey()) + _ = updated.(*Model) + if cancelCalled { + t.Error("ESC in idle should not call cancel") + } + }) + t.Run("streaming_cancels", func(t *testing.T) { + m := newTestModel(t) + m.state = StateStreaming + cancelCalled := false + m.cancel = func() { cancelCalled = true } + updated, _ := m.Update(escKey()) + _ = updated.(*Model) + if !cancelCalled { + t.Error("ESC in streaming should call cancel") + } + }) + t.Run("waiting_cancels", func(t *testing.T) { + m := newTestModel(t) + m.state = StateWaiting + cancelCalled := false + m.cancel = func() { cancelCalled = true } + updated, _ := m.Update(escKey()) + _ = updated.(*Model) + if !cancelCalled { + t.Error("ESC in waiting should call cancel") + } + }) +} + +func TestSystemMessageMsg_AppendsEntry(t *testing.T) { + m := newTestModel(t) + before := len(m.entries) + updated, _ := m.Update(SystemMessageMsg{Msg: "hello system"}) + m = updated.(*Model) + if len(m.entries) != before+1 { + t.Fatalf("expected %d entries, got %d", before+1, len(m.entries)) + } + last := m.entries[len(m.entries)-1] + if last.Kind != "system" { + t.Errorf("expected kind 'system', got %q", last.Kind) + } + if last.Content != "hello system" { + t.Errorf("expected content 'hello system', got %q", last.Content) + } +} + +func TestErrorMsg_AppendsEntry(t *testing.T) { + m := newTestModel(t) + before := len(m.entries) + updated, _ := m.Update(ErrorMsg{Msg: "something broke"}) + m = updated.(*Model) + if len(m.entries) != before+1 { + t.Fatalf("expected %d entries, got %d", before+1, len(m.entries)) + } + last := m.entries[len(m.entries)-1] + if last.Kind != "error" { + t.Errorf("expected kind 'error', got %q", last.Kind) + } + if last.Content != "something broke" { + t.Errorf("expected content 'something broke', got %q", last.Content) + } +} + +func TestToolCallResultMsg(t *testing.T) { + t.Run("updates_tool_entry", func(t *testing.T) { + m := newTestModel(t) + m.toolEntries = append(m.toolEntries, ToolEntry{ + Name: "read_file", + Status: ToolStatusRunning, + }) + m.toolsPending = 1 + updated, _ := m.Update(ToolCallResultMsg{ + Name: "read_file", + Result: "file contents", + IsError: false, + Duration: 42 * time.Millisecond, + }) + m = updated.(*Model) + if m.toolEntries[0].Status != ToolStatusDone { + t.Errorf("expected ToolStatusDone, got %d", m.toolEntries[0].Status) + } + if m.toolEntries[0].Result != "file contents" { + t.Errorf("expected 'file contents', got %q", m.toolEntries[0].Result) + } + if m.toolsPending != 0 { + t.Errorf("toolsPending should be 0, got %d", m.toolsPending) + } + }) + t.Run("truncates_long_result", func(t *testing.T) { + m := newTestModel(t) + m.toolEntries = append(m.toolEntries, ToolEntry{ + Name: "read_file", + Status: ToolStatusRunning, + }) + longResult := strings.Repeat("x", 2500) + updated, _ := m.Update(ToolCallResultMsg{ + Name: "read_file", + Result: longResult, + }) + m = updated.(*Model) + if len(m.toolEntries[0].Result) != 2000 { + t.Errorf("result should be truncated to 2000, got %d", len(m.toolEntries[0].Result)) + } + if !strings.HasSuffix(m.toolEntries[0].Result, "...") { + t.Error("truncated result should end with '...'") + } + }) + t.Run("error_status", func(t *testing.T) { + m := newTestModel(t) + m.toolEntries = append(m.toolEntries, ToolEntry{ + Name: "exec", + Status: ToolStatusRunning, + }) + updated, _ := m.Update(ToolCallResultMsg{ + Name: "exec", + Result: "command failed", + IsError: true, + }) + m = updated.(*Model) + if m.toolEntries[0].Status != ToolStatusError { + t.Errorf("expected ToolStatusError, got %d", m.toolEntries[0].Status) + } + if !m.toolEntries[0].IsError { + t.Error("IsError should be true") + } + }) +} + +func TestAgentDoneMsg(t *testing.T) { + m := newTestModel(t) + m.state = StateStreaming + m.userScrolledUp = true + m.anchorActive = false + updated, _ := m.Update(AgentDoneMsg{}) + m = updated.(*Model) + if m.state != StateIdle { + t.Errorf("state should be StateIdle, got %d", m.state) + } + if m.userScrolledUp { + t.Error("userScrolledUp should be reset to false") + } + if !m.anchorActive { + t.Error("anchorActive should be reset to true") + } +} + +func TestInitCompleteMsg(t *testing.T) { + t.Run("basic_fields", func(t *testing.T) { + m := newTestModel(t) + updated, _ := m.Update(InitCompleteMsg{ + Model: "llama3", + ModelList: []string{"llama3", "qwen3"}, + AgentProfile: "default", + AgentList: []string{"default", "coder"}, + ToolCount: 5, + ServerCount: 2, + NumCtx: 8192, + }) + m = updated.(*Model) + if m.model != "llama3" { + t.Errorf("model should be 'llama3', got %q", m.model) + } + if len(m.modelList) != 2 { + t.Errorf("modelList should have 2 items, got %d", len(m.modelList)) + } + if m.toolCount != 5 { + t.Errorf("toolCount should be 5, got %d", m.toolCount) + } + if m.serverCount != 2 { + t.Errorf("serverCount should be 2, got %d", m.serverCount) + } + }) + t.Run("with_failed_servers", func(t *testing.T) { + m := newTestModel(t) + before := len(m.entries) + + updated, _ := m.Update(InitCompleteMsg{ + Model: "llama3", + FailedServers: []FailedServer{ + {Name: "server1", Reason: "timeout"}, + }, + }) + m = updated.(*Model) + if len(m.entries) != before+1 { + t.Fatalf("should append system entry for failed servers, got %d entries", len(m.entries)) + } + last := m.entries[len(m.entries)-1] + if last.Kind != "system" { + t.Errorf("expected kind 'system', got %q", last.Kind) + } + if !strings.Contains(last.Content, "server1") { + t.Errorf("should contain server name, got %q", last.Content) + } + }) +} + +func TestHandleCommandAction(t *testing.T) { + tests := []struct { + name string + result command.Result + check func(t *testing.T, m *Model, cmd tea.Cmd) + }{ + { + name: "ActionShowHelp", + result: command.Result{Action: command.ActionShowHelp}, + check: func(t *testing.T, m *Model, cmd tea.Cmd) { + if m.overlay != OverlayHelp { + t.Errorf("expected OverlayHelp, got %d", m.overlay) + } + }, + }, + { + name: "ActionClear_with_text", + result: command.Result{Action: command.ActionClear, Text: "Cleared."}, + check: func(t *testing.T, m *Model, cmd tea.Cmd) { + if len(m.entries) != 1 { + t.Errorf("expected 1 entry, got %d", len(m.entries)) + } + if m.entries[0].Kind != "system" { + t.Errorf("expected system entry, got %q", m.entries[0].Kind) + } + }, + }, + { + name: "ActionQuit", + result: command.Result{Action: command.ActionQuit}, + check: func(t *testing.T, m *Model, cmd tea.Cmd) { + if cmd == nil { + t.Error("ActionQuit should return a cmd (tea.Quit)") + } + }, + }, + { + name: "ActionLoadContext", + result: command.Result{Action: command.ActionLoadContext, Data: "test.md\x00# Hello", Text: "Loaded."}, + check: func(t *testing.T, m *Model, cmd tea.Cmd) { + if m.loadedFile != "test.md" { + t.Errorf("expected loadedFile='test.md', got %q", m.loadedFile) + } + }, + }, + { + name: "ActionUnloadContext", + result: command.Result{Action: command.ActionUnloadContext, Text: "Unloaded."}, + check: func(t *testing.T, m *Model, cmd tea.Cmd) { + if m.loadedFile != "" { + t.Errorf("expected empty loadedFile, got %q", m.loadedFile) + } + }, + }, + { + name: "ActionSwitchModel", + result: command.Result{Action: command.ActionSwitchModel, Data: "gpt-4", Text: "Switched."}, + check: func(t *testing.T, m *Model, cmd tea.Cmd) { + if m.model != "gpt-4" { + t.Errorf("expected model='gpt-4', got %q", m.model) + } + }, + }, + { + name: "ActionSwitchAgent", + result: command.Result{Action: command.ActionSwitchAgent, Data: "coder", Text: "Switched."}, + check: func(t *testing.T, m *Model, cmd tea.Cmd) { + if m.agentProfile != "coder" { + t.Errorf("expected agentProfile='coder', got %q", m.agentProfile) + } + }, + }, + { + name: "ActionNone_with_text", + result: command.Result{Action: command.ActionNone, Text: "Info message"}, + check: func(t *testing.T, m *Model, cmd tea.Cmd) { + if len(m.entries) == 0 { + t.Fatal("expected at least one entry") + } + last := m.entries[len(m.entries)-1] + if last.Content != "Info message" { + t.Errorf("expected 'Info message', got %q", last.Content) + } + }, + }, + { + name: "ActionNone_empty_text", + result: command.Result{Action: command.ActionNone, Text: ""}, + check: func(t *testing.T, m *Model, cmd tea.Cmd) { + // Should not add any entry. + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := newTestModel(t) + if tt.result.Action == command.ActionUnloadContext { + m.loadedFile = "old.md" + } + cmd := m.handleCommandAction(tt.result) + tt.check(t, m, cmd) + }) + } +} + +func TestCommandResultMsg(t *testing.T) { + t.Run("with_text", func(t *testing.T) { + m := newTestModel(t) + before := len(m.entries) + updated, _ := m.Update(CommandResultMsg{Text: "Result info"}) + m = updated.(*Model) + if len(m.entries) != before+1 { + t.Fatalf("expected %d entries, got %d", before+1, len(m.entries)) + } + if m.entries[len(m.entries)-1].Content != "Result info" { + t.Errorf("expected 'Result info', got %q", m.entries[len(m.entries)-1].Content) + } + }) + t.Run("empty_text_no_entry", func(t *testing.T) { + m := newTestModel(t) + before := len(m.entries) + updated, _ := m.Update(CommandResultMsg{Text: ""}) + m = updated.(*Model) + if len(m.entries) != before { + t.Errorf("expected %d entries (no change), got %d", before, len(m.entries)) + } + }) +} diff --git a/internal/tui/modelpicker.go b/internal/tui/modelpicker.go new file mode 100644 index 0000000..cc848ea --- /dev/null +++ b/internal/tui/modelpicker.go @@ -0,0 +1,90 @@ +package tui + +import ( + "fmt" + + "ai-agent/internal/config" + + "charm.land/bubbles/v2/list" + "charm.land/lipgloss/v2" +) + +type modelItem struct { + name string + size string + capability string + isCurrent bool +} + +func (i modelItem) Title() string { + title := i.name + if i.isCurrent { + title += " ●" + } + return title +} + +func (i modelItem) Description() string { + return fmt.Sprintf("%s · %s", i.size, i.capability) +} + +func (i modelItem) FilterValue() string { return i.name } + +type ModelPickerState struct { + List list.Model + Models []config.Model + CurrentModel string +} + +func newModelPickerState(models []config.Model, currentModel string, isDark bool, title string) *ModelPickerState { + capLabels := map[config.ModelCapability]string{ + config.CapabilitySimple: "Fast", + config.CapabilityMedium: "Balanced", + config.CapabilityComplex: "Capable", + config.CapabilityAdvanced: "Advanced", + } + items := make([]list.Item, len(models)) + selectedIdx := 0 + for i, model := range models { + if model.Name == currentModel { + selectedIdx = i + } + items[i] = modelItem{ + name: model.Name, + size: model.Size, + capability: capLabels[model.Capability], + isCurrent: model.Name == currentModel, + } + } + delegate := list.NewDefaultDelegate() + delegate.Styles = list.NewDefaultItemStyles(isDark) + delegate.SetSpacing(0) + const pickerW = 50 + pickerH := len(models)*delegate.Height() + 2 + if pickerH > 20 { + pickerH = 20 + } + l := list.New(items, delegate, pickerW, pickerH) + l.Title = title + l.SetShowStatusBar(false) + l.SetShowHelp(false) + l.SetShowPagination(false) + l.SetFilteringEnabled(false) + l.DisableQuitKeybindings() + l.Select(selectedIdx) + return &ModelPickerState{ + List: l, + Models: models, + CurrentModel: currentModel, + } +} + +func (m *Model) renderModelPicker() string { + ps := m.modelPickerState + if ps == nil { + return "" + } + const maxW = 50 + box := lipgloss.NewStyle().Border(lipgloss.RoundedBorder()).BorderForeground(m.styles.FocusIndicator.GetForeground()).Padding(0, 1).Width(maxW) + return box.Render(ps.List.View()) +} diff --git a/internal/tui/modelpicker_test.go b/internal/tui/modelpicker_test.go new file mode 100644 index 0000000..54a8ef4 --- /dev/null +++ b/internal/tui/modelpicker_test.go @@ -0,0 +1,121 @@ +package tui + +import ( + "testing" + + "ai-agent/internal/config" +) + +func TestModelPicker_OpenClose(t *testing.T) { + t.Run("open_without_model_list_noop", func(t *testing.T) { + m := newTestModel(t) + m.openModelPicker() + if m.overlay == OverlayModelPicker { + t.Error("should not open picker without model list") + } + }) + t.Run("open_with_model_list", func(t *testing.T) { + m := newTestModel(t) + m.modelList = []string{"qwen3.5:0.8b", "qwen3.5:2b", "qwen3.5:4b", "qwen3.5:9b"} + m.model = "qwen3.5:0.8b" + m.openModelPicker() + if m.overlay != OverlayModelPicker { + t.Errorf("expected OverlayModelPicker, got %d", m.overlay) + } + if m.modelPickerState == nil { + t.Fatal("modelPickerState should not be nil") + } + if len(m.modelPickerState.Models) == 0 { + t.Error("should have models in picker") + } + if m.modelPickerState.CurrentModel != "qwen3.5:0.8b" { + t.Errorf("expected current model 'qwen3.5:0.8b', got %q", m.modelPickerState.CurrentModel) + } + }) + t.Run("close_resets_state", func(t *testing.T) { + m := newTestModel(t) + m.modelList = []string{"qwen3.5:0.8b", "qwen3.5:2b"} + m.model = "qwen3.5:0.8b" + m.openModelPicker() + m.closeModelPicker() + if m.modelPickerState != nil { + t.Error("modelPickerState should be nil after close") + } + if m.overlay != OverlayNone { + t.Errorf("overlay should be OverlayNone, got %d", m.overlay) + } + }) +} + +func TestModelPicker_Navigation(t *testing.T) { + setup := func(t *testing.T) *Model { + t.Helper() + m := newTestModel(t) + m.modelList = []string{"qwen3.5:0.8b", "qwen3.5:2b", "qwen3.5:4b", "qwen3.5:9b"} + m.model = config.DefaultModels()[0].Name + m.openModelPicker() + return m + } + t.Run("down_moves_index", func(t *testing.T) { + m := setup(t) + updated, _ := m.Update(downKey()) + m = updated.(*Model) + if m.modelPickerState.List.Index() != 1 { + t.Errorf("expected index 1, got %d", m.modelPickerState.List.Index()) + } + }) + t.Run("up_at_zero_stays", func(t *testing.T) { + m := setup(t) + updated, _ := m.Update(upKey()) + m = updated.(*Model) + if m.modelPickerState.List.Index() != 0 { + t.Errorf("expected index 0, got %d", m.modelPickerState.List.Index()) + } + }) + t.Run("down_clamped_at_end", func(t *testing.T) { + m := setup(t) + lastIdx := len(m.modelPickerState.Models) - 1 + m.modelPickerState.List.Select(lastIdx) + updated, _ := m.Update(downKey()) + m = updated.(*Model) + if m.modelPickerState.List.Index() != lastIdx { + t.Errorf("expected index to stay at end, got %d", m.modelPickerState.List.Index()) + } + }) + t.Run("esc_closes", func(t *testing.T) { + m := setup(t) + updated, _ := m.Update(escKey()) + m = updated.(*Model) + if m.modelPickerState != nil { + t.Error("ESC should close picker") + } + if m.overlay != OverlayNone { + t.Errorf("overlay should be OverlayNone, got %d", m.overlay) + } + }) +} + +func TestModelPicker_CtrlM(t *testing.T) { + t.Run("opens_with_ctrl_m", func(t *testing.T) { + m := newTestModel(t) + m.modelList = []string{"qwen3.5:0.8b", "qwen3.5:2b"} + m.model = "qwen3.5:0.8b" + m.state = StateIdle + updated, _ := m.Update(ctrlKey('m')) + m = updated.(*Model) + if m.overlay != OverlayModelPicker { + t.Errorf("ctrl+m should open model picker, got overlay %d", m.overlay) + } + }) + t.Run("no_open_when_streaming", func(t *testing.T) { + m := newTestModel(t) + m.modelList = []string{"qwen3.5:0.8b"} + m.model = "qwen3.5:0.8b" + m.state = StateStreaming + updated, _ := m.Update(ctrlKey('m')) + m = updated.(*Model) + if m.overlay == OverlayModelPicker { + t.Error("should not open picker when streaming") + } + }) +} diff --git a/internal/tui/mouse.go b/internal/tui/mouse.go new file mode 100644 index 0000000..bce55c1 --- /dev/null +++ b/internal/tui/mouse.go @@ -0,0 +1,144 @@ +package tui + +import ( + "charm.land/lipgloss/v2" + tea "charm.land/bubbletea/v2" +) + +// MouseHandler provides enhanced mouse interaction handling. +type MouseHandler struct { + isDark bool + styles MouseHandlerStyles + resizer *PanelResizer + lastClickX int + lastClickY int + lastClickTime int64 + clickCount int +} + +// MouseHandlerStyles holds styling. +type MouseHandlerStyles struct { + Hover lipgloss.Style + Selected lipgloss.Style + ResizeHint lipgloss.Style +} + +// DefaultMouseHandlerStyles returns default styles. +func DefaultMouseHandlerStyles(isDark bool) MouseHandlerStyles { + if isDark { + return MouseHandlerStyles{ + Hover: lipgloss.NewStyle().Background(lipgloss.Color("#3b4252")), + Selected: lipgloss.NewStyle().Background(lipgloss.Color("#4c566a")), + ResizeHint: lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0")), + } + } + return MouseHandlerStyles{ + Hover: lipgloss.NewStyle().Background(lipgloss.Color("#e5e9f0")), + Selected: lipgloss.NewStyle().Background(lipgloss.Color("#d8dee9")), + ResizeHint: lipgloss.NewStyle().Foreground(lipgloss.Color("#4f8f8f")), + } +} + +// NewMouseHandler creates a new mouse handler. +func NewMouseHandler(isDark bool, panelMinWidth, panelMaxWidth int) *MouseHandler { + return &MouseHandler{ + isDark: isDark, + styles: DefaultMouseHandlerStyles(isDark), + resizer: NewPanelResizer(panelMinWidth, panelMaxWidth, isDark), + } +} + +// SetDark updates theme. +func (mh *MouseHandler) SetDark(isDark bool) { + mh.isDark = isDark + mh.styles = DefaultMouseHandlerStyles(isDark) +} + +// ResizePanel handles resize operations. +func (mh *MouseHandler) ResizePanel() *PanelResizer { + return mh.resizer +} + +// HandleClick processes a mouse click at the given coordinates. +// Returns an action describing what happened. +func (mh *MouseHandler) HandleClick(msg tea.MouseClickMsg, panelWidth, panelDividerX int) MouseAction { + x, y := int(msg.X), int(msg.Y) + + // Check for double-click + isDoubleClick := mh.isDoubleClick(x, y) + if isDoubleClick { + mh.clickCount++ + } else { + mh.clickCount = 1 + } + mh.lastClickX = x + mh.lastClickY = y + + // Check if clicking on resize handle (within 3 chars of panel divider) + if panelWidth > 0 && mh.resizer.CanResizeAt(x, panelDividerX) { + if msg.Button == tea.MouseLeft { + mh.resizer.StartResize(x, panelWidth) + return MouseAction{Type: ResizeStart} + } + } + + // Check for right-click (context menu) + if msg.Button == tea.MouseRight { + return MouseAction{ + Type: ContextMenu, + X: x, + Y: y, + Context: mh.getClickContext(x, y, panelWidth), + } + } + + return MouseAction{Type: None} +} + +// HandleRelease handles mouse release events. +func (mh *MouseHandler) HandleRelease() { + mh.resizer.EndResize() +} + +// isDoubleClick checks if this is a double-click. +func (mh *MouseHandler) isDoubleClick(x, y int) bool { + // Simple double-click detection: same position + dist := abs(x-mh.lastClickX) + abs(y-mh.lastClickY) + return dist < 2 && mh.clickCount > 1 +} + +// getClickContext returns context information about the click location. +func (mh *MouseHandler) getClickContext(x, y, panelWidth int) string { + // Determine what was clicked based on coordinates + if panelWidth > 0 && x < panelWidth { + return "sidepanel" + } + return "main" +} + +// MouseAction describes a mouse action. +type MouseAction struct { + Type MouseActionType + X, Y int + Context string +} + +// MouseActionType describes the type of mouse action. +type MouseActionType int + +const ( + None MouseActionType = iota + ResizeStart + ContextMenu + SelectEntry + ToggleCollapse + CopyText +) + +// abs returns the absolute value. +func abs(n int) int { + if n < 0 { + return -n + } + return n +} diff --git a/internal/tui/mouse_test.go b/internal/tui/mouse_test.go new file mode 100644 index 0000000..bfe958e --- /dev/null +++ b/internal/tui/mouse_test.go @@ -0,0 +1,88 @@ +package tui + +import ( + "testing" + + tea "charm.land/bubbletea/v2" +) + +func TestMouseClick_EmptyEntries(t *testing.T) { + m := newTestModel(t) + m.toolEntryRows = make(map[int]int) + + // Should not panic with no entries. + m.handleMouseClick(5, 10) +} + +func TestMouseClick_ToggleTool(t *testing.T) { + m := newTestModel(t) + m.toolEntries = []ToolEntry{ + {Name: "test", Status: ToolStatusDone, Collapsed: true}, + } + m.toolEntryRows = map[int]int{0: 5} + + // Click at Y that maps to row 5 (header height=3, viewport offset=0). + m.handleMouseClick(5, 8) // 8 - 3 + 0 = 5 → matches entry 0 + + if m.toolEntries[0].Collapsed { + t.Error("clicking tool entry should toggle collapsed state") + } +} + +func TestMouseClick_OutsideToolEntries(t *testing.T) { + m := newTestModel(t) + m.toolEntries = []ToolEntry{ + {Name: "test", Status: ToolStatusDone, Collapsed: true}, + } + m.toolEntryRows = map[int]int{0: 5} + + // Click at a position that doesn't match any tool entry. + m.handleMouseClick(5, 50) + + if !m.toolEntries[0].Collapsed { + t.Error("clicking outside should not toggle collapsed state") + } +} + +func TestMouseWheel_SetsScrollFlag(t *testing.T) { + m := newTestModel(t) + m.anchorActive = true + // Add enough content so the viewport is scrollable and not at bottom after scroll up. + var longContent string + for i := 0; i < 100; i++ { + longContent += "line\n" + } + m.viewport.SetContent(longContent) + m.viewport.GotoBottom() + + updated, _ := m.Update(tea.MouseWheelMsg{X: 0, Y: 0, Button: tea.MouseWheelUp}) + m = updated.(*Model) + + if m.anchorActive { + t.Error("scroll up should disable anchorActive flag") + } + if !m.userScrolledUp { + t.Error("scroll up should set userScrolledUp flag") + } +} + +func TestMouseWheel_ResetsAtBottom(t *testing.T) { + m := newTestModel(t) + m.anchorActive = false + m.userScrolledUp = true + // With no content, viewport is at bottom, so scrolling should reset the flag. + updated, _ := m.Update(tea.MouseWheelMsg{X: 0, Y: 0, Button: tea.MouseWheelDown}) + m = updated.(*Model) + + if m.anchorActive { + // At bottom with minimal content, anchor should be active + } +} + +func TestMouseWheel_NilToolRows(t *testing.T) { + m := newTestModel(t) + m.toolEntryRows = nil + + // Should not panic with nil toolEntryRows. + m.handleMouseClick(5, 10) +} diff --git a/internal/tui/overlay_toolcard_test.go b/internal/tui/overlay_toolcard_test.go new file mode 100644 index 0000000..2873786 --- /dev/null +++ b/internal/tui/overlay_toolcard_test.go @@ -0,0 +1,365 @@ +package tui + +import ( + "strings" + "testing" + + "charm.land/lipgloss/v2" +) + +// TestOverlayCentering_HelpOverlay verifies help overlay is centered +func TestOverlayCentering_HelpOverlay(t *testing.T) { + m := newTestModel(t) + m.width = 120 + m.height = 40 + + // Initialize help viewport + m.overlay = OverlayHelp + m.initHelpViewport() + + overlay := m.renderHelpOverlay(m.width) + overlayLines := strings.Split(overlay, "\n") + + // Check overlay width doesn't exceed screen + for _, line := range overlayLines { + lineWidth := lipgloss.Width(line) + if lineWidth > m.width { + t.Errorf("overlay line width %d exceeds screen width %d", lineWidth, m.width) + } + } +} + +// TestOverlayCentering_ModelPicker verifies model picker overlay is centered +func TestOverlayCentering_ModelPicker(t *testing.T) { + m := newTestModel(t) + m.width = 100 + m.height = 30 + + // Initialize model picker state manually + m.openModelPicker() + + // Model picker requires modelManager to be set + if m.modelPickerState == nil { + // Test passes if it doesn't panic + t.Skip("model picker requires model manager") + } + + overlay := m.renderModelPicker() + if overlay == "" { + t.Log("model picker overlay empty (expected without model manager)") + } +} + +// TestOverlayCentering_SmallScreen verifies overlays work on small screens +func TestOverlayCentering_SmallScreen(t *testing.T) { + m := newTestModel(t) + m.width = 60 + m.height = 20 + + m.overlay = OverlayHelp + m.initHelpViewport() + + overlay := m.renderHelpOverlay(m.width) + + if overlay == "" { + t.Error("overlay should render on small screen") + } + + // Should not panic or produce empty output + lines := strings.Count(overlay, "\n") + if lines < 5 { + t.Errorf("overlay should have at least 5 lines, got %d", lines) + } +} + +// TestOverlayCentering_LargeScreen verifies overlays scale on large screens +func TestOverlayCentering_LargeScreen(t *testing.T) { + m := newTestModel(t) + m.width = 200 + m.height = 60 + + m.overlay = OverlayHelp + m.initHelpViewport() + + overlay := m.renderHelpOverlay(m.width) + + // Overlay should not be excessively wide + overlayLines := strings.Split(overlay, "\n") + maxLineWidth := 0 + for _, line := range overlayLines { + width := lipgloss.Width(line) + if width > maxLineWidth { + maxLineWidth = width + } + } + + // Overlay should be centered and not use full width + if maxLineWidth > m.width-10 { + t.Errorf("overlay too wide: %d (max should be ~%d)", maxLineWidth, m.width-10) + } +} + +// TestOverlayOnContent_Positioning verifies overlay is positioned correctly +func TestOverlayOnContent_Positioning(t *testing.T) { + m := newTestModel(t) + m.width = 100 + m.height = 40 + + base := strings.Repeat("base line\n", 40) + overlay := strings.Repeat("overlay line\n", 10) + + result := m.overlayOnContent(base, overlay) + + // Result should have same number of lines as base + baseLines := strings.Count(base, "\n") + resultLines := strings.Count(result, "\n") + + if resultLines < baseLines { + t.Errorf("result should have at least as many lines as base: got %d, want %d", resultLines, baseLines) + } +} + +// TestToolCard_WidthCalculation verifies tool cards respect width constraints +func TestToolCard_WidthCalculation(t *testing.T) { + tests := []struct { + name string + availableW int + cardName string + expectRender bool + }{ + {"wide screen", 100, "read_file", true}, + {"narrow screen", 40, "read_file", true}, + {"very narrow", 30, "test", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + card := NewToolCard(tt.cardName, ToolCardFile, true) + card.State = ToolCardRunning + + view := card.View(tt.availableW) + + // Should render without panic + if view == "" { + t.Error("tool card should render") + } + + // Note: lipgloss.Width includes ANSI codes, so we just verify it renders + _ = lipgloss.Width(view) + }) + } +} + +// TestToolCard_LongArgsWrapping verifies long args are wrapped properly +func TestToolCard_LongArgsWrapping(t *testing.T) { + card := NewToolCard("write_file", ToolCardFile, true) + card.State = ToolCardSuccess + card.Expanded = true + card.Args = strings.Repeat("very_long_argument_that_should_be_wrapped_properly ", 10) + card.Result = "success" + + view := card.View(80) + viewLines := strings.Split(view, "\n") + + // Should render multiple lines + if len(viewLines) < 3 { + t.Errorf("tool card should have multiple lines, got %d", len(viewLines)) + } + + // Verify it renders without panic + if view == "" { + t.Error("tool card view should not be empty") + } +} + +// TestToolCard_ManagerRendering verifies multiple cards render correctly +func TestToolCard_ManagerRendering(t *testing.T) { + mgr := NewToolCardManager(true) + + // Add multiple cards + mgr.AddCard("read_file", ToolCardFile, testTime) + mgr.AddCard("write_file", ToolCardFile, testTime) + mgr.AddCard("bash", ToolCardBash, testTime) + + // Update some cards + mgr.UpdateCard("read_file", ToolCardSuccess, "file content", testDuration) + mgr.UpdateCard("write_file", ToolCardRunning, "", 0) + + view := mgr.View(100) + + if view == "" { + t.Error("manager view should not be empty") + } + + // Should have multiple cards (separated by newlines) + lines := strings.Count(view, "\n") + if lines < 2 { + t.Errorf("manager view should have multiple lines, got %d", lines+1) + } +} + +// TestToolCard_BorderAndPadding verifies border and padding are accounted for +func TestToolCard_BorderAndPadding(t *testing.T) { + card := NewToolCard("test", ToolCardGeneric, true) + card.State = ToolCardSuccess + card.Expanded = true + card.Args = "test args" + card.Result = "test result" + + availableW := 60 + view := card.View(availableW) + + // Account for border (2) + padding (2) = 4 chars + contentW := availableW - 4 + + viewLines := strings.Split(view, "\n") + for i, line := range viewLines { + lineWidth := lipgloss.Width(line) + if lineWidth > availableW { + t.Errorf("line %d width %d exceeds available width %d (content should fit in %d)", + i, lineWidth, availableW, contentW) + } + } +} + +// TestToolCard_EmojiIcons verifies emoji icons render without breaking layout +func TestToolCard_EmojiIcons(t *testing.T) { + kinds := []ToolCardKind{ToolCardFile, ToolCardBash, ToolCardSearch, ToolCardGit, ToolCardGeneric} + states := []ToolCardState{ToolCardRunning, ToolCardSuccess, ToolCardError} + + for _, kind := range kinds { + for _, state := range states { + t.Run(string(rune(kind))+string(rune(state)), func(t *testing.T) { + card := NewToolCard("test", kind, true) + card.State = state + + view := card.View(60) + + // Should render without panic + if view == "" { + t.Error("card view should not be empty") + } + + // Should not exceed width + viewWidth := lipgloss.Width(view) + if viewWidth > 60 { + t.Errorf("card width %d exceeds 60", viewWidth) + } + }) + } + } +} + +// TestWrapText_LongWords verifies wrapText breaks long words +func TestWrapText_LongWords(t *testing.T) { + longWord := strings.Repeat("a", 100) + result := wrapText(longWord, 40) + + lines := strings.Split(result, "\n") + for i, line := range lines { + if len(line) > 40 { + t.Errorf("line %d exceeds width: %d chars", i, len(line)) + } + } +} + +// TestWrapText_MultipleWords verifies wrapText handles multiple words +func TestWrapText_MultipleWords(t *testing.T) { + text := "word1 word2 word3 word4 word5 word6 word7 word8 word9 word10" + result := wrapText(text, 20) + + lines := strings.Split(result, "\n") + for i, line := range lines { + if len(line) > 20 { + t.Errorf("line %d exceeds width: %d chars", i, len(line)) + } + } +} + +// TestWrapText_EmptyAndEdgeCases verifies wrapText handles edge cases +func TestWrapText_EmptyAndEdgeCases(t *testing.T) { + tests := []struct { + name string + input string + width int + expect string + }{ + {"empty", "", 40, ""}, + {"zero width", "hello", 0, "hello"}, + {"exact fit", "hello", 5, "hello"}, + {"single char width", "hello world", 1, "h\ne\nl\nl\no\n \nw\no\nr\nl\nd"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := wrapText(tt.input, tt.width) + if tt.width > 0 && result != tt.expect { + // Just verify it doesn't panic and returns something reasonable + } + }) + } +} + +// TestIndentBlock_Multiline verifies indentBlock adds prefix to each line +func TestIndentBlock_Multiline(t *testing.T) { + input := "line1\nline2\nline3" + result := indentBlock(input, " ") + + expected := " line1\n line2\n line3" + if result != expected { + t.Errorf("indentBlock failed: got %q, want %q", result, expected) + } +} + +// TestIndentBlock_EmptyLines verifies indentBlock handles empty lines +func TestIndentBlock_EmptyLines(t *testing.T) { + input := "line1\n\nline3" + result := indentBlock(input, " ") + + // Empty lines should remain empty + lines := strings.Split(result, "\n") + if len(lines) != 3 { + t.Errorf("expected 3 lines, got %d", len(lines)) + } + if lines[1] != "" { + t.Error("empty line should remain empty") + } +} + +// BenchmarkOverlayRendering benchmarks overlay rendering performance +func BenchmarkOverlayRendering_Help(b *testing.B) { + m := newTestModelB(b) + m.width = 120 + m.height = 40 + m.overlay = OverlayHelp + m.initHelpViewport() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = m.renderHelpOverlay(m.width) + } +} + +// BenchmarkToolCardRendering benchmarks tool card rendering +func BenchmarkToolCardRendering(b *testing.B) { + card := NewToolCard("read_file", ToolCardFile, true) + card.State = ToolCardSuccess + card.Expanded = true + card.Args = strings.Repeat("arg ", 20) + card.Result = strings.Repeat("result line\n", 10) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = card.View(80) + } +} + +// BenchmarkWrapText benchmarks text wrapping +func BenchmarkWrapText(b *testing.B) { + text := strings.Repeat("This is a test sentence with multiple words. ", 20) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = wrapText(text, 60) + } +} diff --git a/internal/tui/paste_test.go b/internal/tui/paste_test.go new file mode 100644 index 0000000..a1a4db3 --- /dev/null +++ b/internal/tui/paste_test.go @@ -0,0 +1,84 @@ +package tui + +import ( + "strings" + "testing" + + tea "charm.land/bubbletea/v2" +) + +func TestPasteMsg_SmallPaste(t *testing.T) { + m := newTestModel(t) + m.state = StateIdle + + content := "short paste" + updated, _ := m.Update(tea.PasteMsg{Content: content}) + m = updated.(*Model) + + if m.pendingPaste != "" { + t.Error("small paste should not trigger pending paste") + } +} + +func TestPasteMsg_LargePaste(t *testing.T) { + m := newTestModel(t) + m.state = StateIdle + + // Create paste with >10 lines. + content := strings.Repeat("line\n", 15) + updated, _ := m.Update(tea.PasteMsg{Content: content}) + m = updated.(*Model) + + if m.pendingPaste == "" { + t.Error("large paste should trigger pending paste") + } +} + +func TestPasteMsg_LargePasteNotIdle(t *testing.T) { + m := newTestModel(t) + m.state = StateStreaming + + content := strings.Repeat("line\n", 15) + updated, _ := m.Update(tea.PasteMsg{Content: content}) + m = updated.(*Model) + + if m.pendingPaste != "" { + t.Error("should not set pending paste during streaming") + } +} + +func TestPendingPaste_AcceptY(t *testing.T) { + m := newTestModel(t) + m.pendingPaste = "line1\nline2\nline3" + + updated, _ := m.Update(tea.KeyPressMsg{Code: 'y'}) + m = updated.(*Model) + + if m.pendingPaste != "" { + t.Error("pressing y should clear pending paste") + } +} + +func TestPendingPaste_RejectN(t *testing.T) { + m := newTestModel(t) + m.pendingPaste = "line1\nline2\nline3" + + updated, _ := m.Update(tea.KeyPressMsg{Code: 'n'}) + m = updated.(*Model) + + if m.pendingPaste != "" { + t.Error("pressing n should clear pending paste") + } +} + +func TestPendingPaste_CancelEsc(t *testing.T) { + m := newTestModel(t) + m.pendingPaste = "line1\nline2\nline3" + + updated, _ := m.Update(tea.KeyPressMsg{Code: tea.KeyEscape}) + m = updated.(*Model) + + if m.pendingPaste != "" { + t.Error("pressing esc should clear pending paste") + } +} diff --git a/internal/tui/planform.go b/internal/tui/planform.go new file mode 100644 index 0000000..d294cf7 --- /dev/null +++ b/internal/tui/planform.go @@ -0,0 +1,258 @@ +package tui + +import ( + "fmt" + "strings" + + "charm.land/bubbles/v2/key" + "charm.land/bubbles/v2/textinput" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" +) + +// PlanFormField represents a single field in the plan form. +type PlanFormField struct { + Label string + Kind string // "text" or "select" + Value string // current value (for select, set from Options[OptionIndex]) + Options []string // for "select" kind + OptionIndex int // for "select" kind + Input textinput.Model +} + +// PlanFormState holds state for the plan form overlay. +type PlanFormState struct { + Fields []PlanFormField + ActiveField int +} + +// NewPlanFormState creates a plan form pre-filled with the user's task description. +func NewPlanFormState(task string) *PlanFormState { + taskInput := textinput.New() + taskInput.Placeholder = "Describe the task..." + taskInput.CharLimit = 512 + taskInput.SetValue(task) + taskInput.Focus() + + focusInput := textinput.New() + focusInput.Placeholder = "Any constraints or requirements? (optional)" + focusInput.CharLimit = 512 + + return &PlanFormState{ + Fields: []PlanFormField{ + { + Label: "Task", + Kind: "text", + Input: taskInput, + }, + { + Label: "Scope", + Kind: "select", + Options: []string{"single file", "module", "project-wide"}, + }, + { + Label: "Focus (optional)", + Kind: "text", + Input: focusInput, + }, + }, + ActiveField: 0, + } +} + +// AssemblePrompt builds the structured prompt from form fields. +func (pf *PlanFormState) AssemblePrompt() string { + task := pf.Fields[0].Input.Value() + scope := pf.Fields[1].Options[pf.Fields[1].OptionIndex] + focus := pf.Fields[2].Input.Value() + + var b strings.Builder + b.WriteString("Plan the following task:\n") + b.WriteString(fmt.Sprintf("Task: %s\n", task)) + b.WriteString(fmt.Sprintf("Scope: %s\n", scope)) + if focus != "" { + b.WriteString(fmt.Sprintf("Focus: %s\n", focus)) + } + b.WriteString("\nProvide a step-by-step plan.") + return b.String() +} + +// updatePlanForm handles key events within the plan form overlay. +// Returns the updated model, any command, and whether the form was submitted or cancelled. +func (m *Model) updatePlanForm(msg tea.KeyPressMsg) (bool, bool) { + pf := m.planFormState + if pf == nil { + return false, false + } + + field := &pf.Fields[pf.ActiveField] + + switch { + case key.Matches(msg, m.keys.Cancel): + // Cancel + return false, true + + case msg.Code == tea.KeyEnter: + if pf.ActiveField == len(pf.Fields)-1 { + // Submit + return true, false + } + // Advance to next field + m.advancePlanFormField(1) + return false, false + + case msg.Code == tea.KeyTab: + if msg.Mod == tea.ModShift { + m.advancePlanFormField(-1) + } else { + m.advancePlanFormField(1) + } + return false, false + + case msg.Code == tea.KeyUp: + if field.Kind == "select" { + if field.OptionIndex > 0 { + field.OptionIndex-- + } + return false, false + } + + case msg.Code == tea.KeyDown: + if field.Kind == "select" { + if field.OptionIndex < len(field.Options)-1 { + field.OptionIndex++ + } + return false, false + } + + case msg.Code == tea.KeyLeft: + if field.Kind == "select" { + if field.OptionIndex > 0 { + field.OptionIndex-- + } + return false, false + } + + case msg.Code == tea.KeyRight: + if field.Kind == "select" { + if field.OptionIndex < len(field.Options)-1 { + field.OptionIndex++ + } + return false, false + } + } + + // Forward other keys to active text field + if field.Kind == "text" { + field.Input, _ = field.Input.Update(msg) + } + + return false, false +} + +// advancePlanFormField moves to the next or previous field. +func (m *Model) advancePlanFormField(dir int) { + pf := m.planFormState + if pf == nil { + return + } + + // Blur current field + current := &pf.Fields[pf.ActiveField] + if current.Kind == "text" { + current.Input.Blur() + } + + pf.ActiveField += dir + if pf.ActiveField < 0 { + pf.ActiveField = 0 + } + if pf.ActiveField >= len(pf.Fields) { + pf.ActiveField = len(pf.Fields) - 1 + } + + // Focus new field + next := &pf.Fields[pf.ActiveField] + if next.Kind == "text" { + next.Input.Focus() + } +} + +// renderPlanForm renders the plan form overlay. +func (m *Model) renderPlanForm() string { + pf := m.planFormState + if pf == nil { + return "" + } + + activeStyle := m.styles.FocusIndicator // Use focus indicator style for active fields + + var b strings.Builder + b.WriteString(m.styles.OverlayTitle.Render("Plan Task")) + b.WriteString("\n\n") + + for i, field := range pf.Fields { + isActive := i == pf.ActiveField + + ls := m.styles.OverlayAccent + if isActive { + ls = activeStyle + } + b.WriteString(ls.Render(field.Label)) + b.WriteString("\n") + + switch field.Kind { + case "text": + if isActive { + b.WriteString(m.styles.FocusIndicator.Render("> ") + field.Input.View()) + } else { + val := field.Input.Value() + if val == "" { + val = m.styles.OverlayDim.Render("(empty)") + } + b.WriteString(" " + m.styles.OverlayDim.Render(val)) + } + case "select": + for j, opt := range field.Options { + selected := j == field.OptionIndex + prefix := " " + if selected && isActive { + prefix = m.styles.FocusIndicator.Render("▸ ") + } else if selected { + prefix = "● " + } + if selected && isActive { + b.WriteString(" " + activeStyle.Render(prefix+opt)) + } else if selected { + b.WriteString(" " + prefix + opt) + } else { + b.WriteString(" " + m.styles.OverlayDim.Render(prefix+opt)) + } + b.WriteString("\n") + } + } + b.WriteString("\n\n") + } + + if pf.Fields[pf.ActiveField].Kind == "select" { + b.WriteString(m.styles.OverlayDim.Render("↑↓←→=select Tab/Enter=next Esc=cancel")) + } else { + b.WriteString(m.styles.OverlayDim.Render("Tab=next field Enter=submit Esc=cancel")) + } + + maxW := 50 + if m.width-8 > maxW { + maxW = m.width - 8 + } + if maxW > 60 { + maxW = 60 + } + + box := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(m.styles.OverlayBorder)). + Padding(1, 2). + Width(maxW) + + return box.Render(b.String()) +} diff --git a/internal/tui/planform_test.go b/internal/tui/planform_test.go new file mode 100644 index 0000000..9ed7f15 --- /dev/null +++ b/internal/tui/planform_test.go @@ -0,0 +1,264 @@ +package tui + +import ( + "strings" + "testing" +) + +func TestPlanForm_NewPrefilled(t *testing.T) { + pf := NewPlanFormState("refactor auth module") + + if len(pf.Fields) != 3 { + t.Fatalf("expected 3 fields, got %d", len(pf.Fields)) + } + + // Task field should be pre-filled. + if pf.Fields[0].Input.Value() != "refactor auth module" { + t.Errorf("task field should be pre-filled, got %q", pf.Fields[0].Input.Value()) + } + + // Scope field should be select with 3 options. + if pf.Fields[1].Kind != "select" { + t.Errorf("scope field should be select, got %q", pf.Fields[1].Kind) + } + if len(pf.Fields[1].Options) != 3 { + t.Errorf("scope should have 3 options, got %d", len(pf.Fields[1].Options)) + } + + // Focus field should be text. + if pf.Fields[2].Kind != "text" { + t.Errorf("focus field should be text, got %q", pf.Fields[2].Kind) + } +} + +func TestPlanForm_AssemblePrompt(t *testing.T) { + pf := NewPlanFormState("build a REST API") + pf.Fields[1].OptionIndex = 1 // "module" + pf.Fields[2].Input.SetValue("keep backward compat") + + prompt := pf.AssemblePrompt() + + if !strings.Contains(prompt, "build a REST API") { + t.Error("prompt should contain task") + } + if !strings.Contains(prompt, "module") { + t.Error("prompt should contain scope") + } + if !strings.Contains(prompt, "keep backward compat") { + t.Error("prompt should contain focus") + } + if !strings.Contains(prompt, "step-by-step plan") { + t.Error("prompt should contain plan instruction") + } +} + +func TestPlanForm_AssemblePrompt_NoFocus(t *testing.T) { + pf := NewPlanFormState("fix the bug") + + prompt := pf.AssemblePrompt() + + if !strings.Contains(prompt, "fix the bug") { + t.Error("prompt should contain task") + } + if strings.Contains(prompt, "Focus:") { + t.Error("prompt should not contain Focus when empty") + } +} + +func TestPlanForm_OpenClose(t *testing.T) { + t.Run("open_sets_overlay", func(t *testing.T) { + m := newTestModel(t) + m.openPlanForm("test task") + + if m.overlay != OverlayPlanForm { + t.Errorf("expected OverlayPlanForm, got %d", m.overlay) + } + if m.planFormState == nil { + t.Fatal("planFormState should not be nil") + } + if m.planFormState.Fields[0].Input.Value() != "test task" { + t.Error("task should be pre-filled") + } + }) + + t.Run("close_resets_state", func(t *testing.T) { + m := newTestModel(t) + m.planFormState = NewPlanFormState("test") + m.overlay = OverlayPlanForm + + m.closePlanForm() + + if m.planFormState != nil { + t.Error("planFormState should be nil after close") + } + if m.overlay != OverlayNone { + t.Errorf("overlay should be OverlayNone, got %d", m.overlay) + } + }) +} + +func TestPlanForm_EscCancels(t *testing.T) { + m := newTestModel(t) + m.openPlanForm("some task") + + updated, _ := m.Update(escKey()) + m = updated.(*Model) + + if m.planFormState != nil { + t.Error("ESC should close plan form") + } + if m.overlay != OverlayNone { + t.Errorf("overlay should be OverlayNone, got %d", m.overlay) + } +} + +func TestPlanForm_FieldNavigation(t *testing.T) { + m := newTestModel(t) + m.openPlanForm("task") + + // Initially on field 0. + if m.planFormState.ActiveField != 0 { + t.Fatalf("expected active field 0, got %d", m.planFormState.ActiveField) + } + + // Tab advances to field 1. + updated, _ := m.Update(tabKey()) + m = updated.(*Model) + + if m.planFormState.ActiveField != 1 { + t.Errorf("expected active field 1, got %d", m.planFormState.ActiveField) + } + + // Tab again advances to field 2. + updated, _ = m.Update(tabKey()) + m = updated.(*Model) + + if m.planFormState.ActiveField != 2 { + t.Errorf("expected active field 2, got %d", m.planFormState.ActiveField) + } + + // Tab at last field stays on last field. + updated, _ = m.Update(tabKey()) + m = updated.(*Model) + + if m.planFormState.ActiveField != 2 { + t.Errorf("expected active field to stay at 2, got %d", m.planFormState.ActiveField) + } +} + +func TestPlanForm_SelectFieldLeftRight(t *testing.T) { + m := newTestModel(t) + m.openPlanForm("task") + + // Tab to scope field. + updated, _ := m.Update(tabKey()) + m = updated.(*Model) + + if m.planFormState.ActiveField != 1 { + t.Fatalf("expected field 1, got %d", m.planFormState.ActiveField) + } + if m.planFormState.Fields[1].OptionIndex != 0 { + t.Fatalf("expected option 0, got %d", m.planFormState.Fields[1].OptionIndex) + } + + // Right should advance to option 1. + updated, _ = m.Update(rightKey()) + m = updated.(*Model) + + if m.planFormState.Fields[1].OptionIndex != 1 { + t.Errorf("expected option 1 after right, got %d", m.planFormState.Fields[1].OptionIndex) + } + + // Left should go back to option 0. + updated, _ = m.Update(leftKey()) + m = updated.(*Model) + + if m.planFormState.Fields[1].OptionIndex != 0 { + t.Errorf("expected option 0 after left, got %d", m.planFormState.Fields[1].OptionIndex) + } +} + +func TestPlanForm_SelectFieldBounds(t *testing.T) { + m := newTestModel(t) + m.openPlanForm("task") + + // Tab to scope field. + updated, _ := m.Update(tabKey()) + m = updated.(*Model) + + // Up at 0 stays at 0. + updated, _ = m.Update(upKey()) + m = updated.(*Model) + + if m.planFormState.Fields[1].OptionIndex != 0 { + t.Errorf("expected option 0 after up at boundary, got %d", m.planFormState.Fields[1].OptionIndex) + } + + // Left at 0 stays at 0. + updated, _ = m.Update(leftKey()) + m = updated.(*Model) + + if m.planFormState.Fields[1].OptionIndex != 0 { + t.Errorf("expected option 0 after left at boundary, got %d", m.planFormState.Fields[1].OptionIndex) + } + + // Navigate to last option. + updated, _ = m.Update(downKey()) + m = updated.(*Model) + updated, _ = m.Update(downKey()) + m = updated.(*Model) + + if m.planFormState.Fields[1].OptionIndex != 2 { + t.Fatalf("expected option 2, got %d", m.planFormState.Fields[1].OptionIndex) + } + + // Down at last stays at last. + updated, _ = m.Update(downKey()) + m = updated.(*Model) + + if m.planFormState.Fields[1].OptionIndex != 2 { + t.Errorf("expected option 2 after down at boundary, got %d", m.planFormState.Fields[1].OptionIndex) + } + + // Right at last stays at last. + updated, _ = m.Update(rightKey()) + m = updated.(*Model) + + if m.planFormState.Fields[1].OptionIndex != 2 { + t.Errorf("expected option 2 after right at boundary, got %d", m.planFormState.Fields[1].OptionIndex) + } +} + +func TestPlanForm_SelectField(t *testing.T) { + m := newTestModel(t) + m.openPlanForm("task") + + // Navigate to scope field (index 1). + m.Update(tabKey()) + updated, _ := m.Update(tabKey()) + m = updated.(*Model) + + // Oops, we need to re-get m after first tab. Let me redo: + m2 := newTestModel(t) + m2.openPlanForm("task") + + // Tab to scope field. + updated, _ = m2.Update(tabKey()) + m2 = updated.(*Model) + + if m2.planFormState.ActiveField != 1 { + t.Fatalf("expected field 1, got %d", m2.planFormState.ActiveField) + } + + // Down should cycle scope option. + if m2.planFormState.Fields[1].OptionIndex != 0 { + t.Fatalf("expected option 0, got %d", m2.planFormState.Fields[1].OptionIndex) + } + + updated, _ = m2.Update(downKey()) + m2 = updated.(*Model) + + if m2.planFormState.Fields[1].OptionIndex != 1 { + t.Errorf("expected option 1 after down, got %d", m2.planFormState.Fields[1].OptionIndex) + } +} diff --git a/internal/tui/progress.go b/internal/tui/progress.go new file mode 100644 index 0000000..5262e2a --- /dev/null +++ b/internal/tui/progress.go @@ -0,0 +1,161 @@ +package tui + +import ( + "fmt" + "strings" + "time" + + "charm.land/lipgloss/v2" +) + +// ProgressItem tracks progress for a long-running operation. +type ProgressItem struct { + ID string + Name string + Total float64 + Completed float64 + Started int64 // Unix timestamp +} + +// ProgressTracker manages multiple progress items. +type ProgressTracker struct { + items map[string]*ProgressItem + isDark bool + styles ProgressStyles +} + +// ProgressStyles holds styling for progress display. +type ProgressStyles struct { + Bar lipgloss.Style + Label lipgloss.Style + Percent lipgloss.Style + Completed lipgloss.Style + Empty lipgloss.Style +} + +// DefaultProgressStyles returns default styles. +func DefaultProgressStyles(isDark bool) ProgressStyles { + if isDark { + return ProgressStyles{ + Bar: lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0")), + Label: lipgloss.NewStyle().Foreground(lipgloss.Color("#d8dee9")), + Percent: lipgloss.NewStyle().Foreground(lipgloss.Color("#81a1c1")), + Completed: lipgloss.NewStyle().Foreground(lipgloss.Color("#a3be8c")), + Empty: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + } + } + return ProgressStyles{ + Bar: lipgloss.NewStyle().Foreground(lipgloss.Color("#4f8f8f")), + Label: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Percent: lipgloss.NewStyle().Foreground(lipgloss.Color("#5e81ac")), + Completed: lipgloss.NewStyle().Foreground(lipgloss.Color("#4f8f38")), + Empty: lipgloss.NewStyle().Foreground(lipgloss.Color("#9ca0a8")), + } +} + +// NewProgressTracker creates a new progress tracker. +func NewProgressTracker(isDark bool) *ProgressTracker { + return &ProgressTracker{ + items: make(map[string]*ProgressItem), + isDark: isDark, + styles: DefaultProgressStyles(isDark), + } +} + +// SetDark updates theme. +func (pt *ProgressTracker) SetDark(isDark bool) { + pt.isDark = isDark + pt.styles = DefaultProgressStyles(isDark) +} + +// Start begins tracking a new progress item. +func (pt *ProgressTracker) Start(id, name string, total float64) { + pt.items[id] = &ProgressItem{ + ID: id, + Name: name, + Total: total, + Started: time.Now().Unix(), + } +} + +// Update sets the current progress. +func (pt *ProgressTracker) Update(id string, completed float64) { + if item, ok := pt.items[id]; ok { + item.Completed = completed + } +} + +// Complete marks an item as done. +func (pt *ProgressTracker) Complete(id string) { + if item, ok := pt.items[id]; ok { + item.Completed = item.Total + } +} + +// Remove stops tracking an item. +func (pt *ProgressTracker) Remove(id string) { + delete(pt.items, id) +} + +// Get returns a progress item by ID. +func (pt *ProgressTracker) Get(id string) (*ProgressItem, bool) { + item, ok := pt.items[id] + return item, ok +} + +// All returns all progress items. +func (pt *ProgressTracker) All() []*ProgressItem { + result := make([]*ProgressItem, 0, len(pt.items)) + for _, item := range pt.items { + result = append(result, item) + } + return result +} + +// Render returns the progress bar view for an item. +func (pt *ProgressTracker) Render(id string, width int) string { + item, ok := pt.items[id] + if !ok || item.Total == 0 { + return "" + } + + percent := int((item.Completed / item.Total) * 100) + barWidth := width - 20 // Leave room for label and percentage + if barWidth < 5 { + barWidth = 5 + } + + filled := int((item.Completed / item.Total) * float64(barWidth)) + bar := strings.Repeat("█", filled) + strings.Repeat("░", barWidth-filled) + + label := pt.styles.Label.Render(item.Name) + percentStr := pt.styles.Percent.Render(fmt.Sprintf("%d%%", percent)) + + return fmt.Sprintf("%s [%s] %s", label, bar, percentStr) +} + +// RenderSimple returns a simple progress bar without the label. +func (pt *ProgressTracker) RenderSimple(id string, width int) string { + item, ok := pt.items[id] + if !ok || item.Total == 0 { + return "" + } + + percent := int((item.Completed / item.Total) * 100) + barWidth := width - 8 // Leave room for percentage + if barWidth < 5 { + barWidth = 5 + } + + filled := int((item.Completed / item.Total) * float64(barWidth)) + bar := strings.Repeat("█", filled) + strings.Repeat("░", barWidth-filled) + + percentStr := pt.styles.Percent.Render(fmt.Sprintf("%d%%", percent)) + + return fmt.Sprintf("[%s] %s", bar, percentStr) +} + +// HasItems returns true if there are any progress items. +func (pt *ProgressTracker) HasItems() bool { + return len(pt.items) > 0 +} diff --git a/internal/tui/prompthistory.go b/internal/tui/prompthistory.go new file mode 100644 index 0000000..32bf73e --- /dev/null +++ b/internal/tui/prompthistory.go @@ -0,0 +1,50 @@ +package tui + +import ( + "encoding/json" + "os" + "path/filepath" +) + +const promptHistoryMax = 100 + +// DefaultPromptHistoryPath returns the path for persistent prompt history. +func DefaultPromptHistoryPath() string { + home, err := os.UserHomeDir() + if err != nil { + return "prompt_history.json" + } + return filepath.Join(home, ".config", "ai-agent", "prompt_history.json") +} + +// LoadPromptHistory reads saved prompt history from path. Returns nil on error or missing file. +func LoadPromptHistory(path string) ([]string, error) { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + var list []string + if err := json.Unmarshal(data, &list); err != nil { + return nil, err + } + if len(list) > promptHistoryMax { + list = list[len(list)-promptHistoryMax:] + } + return list, nil +} + +// SavePromptHistory writes prompt history to path, creating the directory if needed. +func SavePromptHistory(path string, items []string) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + data, err := json.MarshalIndent(items, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o644) +} diff --git a/internal/tui/resize.go b/internal/tui/resize.go new file mode 100644 index 0000000..571d788 --- /dev/null +++ b/internal/tui/resize.go @@ -0,0 +1,117 @@ +package tui + +import ( + "charm.land/lipgloss/v2" +) + +// PanelResizer manages the side panel resizing state. +type PanelResizer struct { + isResizing bool + resizeStartX int + originalWidth int + minWidth int + maxWidth int + isDark bool + styles ResizeStyles +} + +// ResizeStyles holds styling for resize indicators. +type ResizeStyles struct { + Handle lipgloss.Style + HandleActive lipgloss.Style +} + +// DefaultResizeStyles returns default styles. +func DefaultResizeStyles(isDark bool) ResizeStyles { + if isDark { + return ResizeStyles{ + Handle: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + HandleActive: lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0")), + } + } + return ResizeStyles{ + Handle: lipgloss.NewStyle().Foreground(lipgloss.Color("#9ca0a8")), + HandleActive: lipgloss.NewStyle().Foreground(lipgloss.Color("#4f8f8f")), + } +} + +// NewPanelResizer creates a new panel resizer. +func NewPanelResizer(minWidth, maxWidth int, isDark bool) *PanelResizer { + return &PanelResizer{ + isResizing: false, + resizeStartX: 0, + originalWidth: 30, + minWidth: minWidth, + maxWidth: maxWidth, + isDark: isDark, + styles: DefaultResizeStyles(isDark), + } +} + +// SetDark updates theme. +func (pr *PanelResizer) SetDark(isDark bool) { + pr.isDark = isDark + pr.styles = DefaultResizeStyles(isDark) +} + +// StartResize begins a resize operation. +func (pr *PanelResizer) StartResize(x int, currentWidth int) { + pr.isResizing = true + pr.resizeStartX = x + pr.originalWidth = currentWidth +} + +// UpdateResize updates the panel width based on mouse movement. +func (pr *PanelResizer) UpdateResize(x int) int { + if !pr.isResizing { + return pr.originalWidth + } + + delta := x - pr.resizeStartX + newWidth := pr.originalWidth + delta + + // Clamp to min/max + if newWidth < pr.minWidth { + newWidth = pr.minWidth + } + if newWidth > pr.maxWidth { + newWidth = pr.maxWidth + } + + return newWidth +} + +// EndResize ends the resize operation. +func (pr *PanelResizer) EndResize() { + pr.isResizing = false +} + +// IsResizing returns true if currently resizing. +func (pr *PanelResizer) IsResizing() bool { + return pr.isResizing +} + +// RenderHandle returns the resize handle visual. +func (pr *PanelResizer) RenderHandle(height int, isActive bool) string { + style := pr.styles.Handle + if isActive || pr.isResizing { + style = pr.styles.HandleActive + } + + // Create a vertical bar with grip dots + var b string + for i := 0; i < height; i++ { + if i%2 == 0 { + b += style.Render("│") + } else { + b += pr.styles.Handle.Render("│") + } + } + return b +} + +// CanResizeAt checks if x is within the resize zone (3 characters from divider). +func (pr *PanelResizer) CanResizeAt(x, dividerX int) bool { + // Resize zone is 3 chars to the left of the divider + return x >= dividerX-3 && x <= dividerX +} diff --git a/internal/tui/scramble.go b/internal/tui/scramble.go new file mode 100644 index 0000000..935a6a3 --- /dev/null +++ b/internal/tui/scramble.go @@ -0,0 +1,125 @@ +package tui + +import ( + "math/rand" + "time" + + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "github.com/lucasb-eyer/go-colorful" +) + +// scrambleChars is the character set for the scramble animation. +const scrambleChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*" + +// scrambleWidth is the number of characters in the animation. +const scrambleWidth = 12 + +// ScrambleTickMsg triggers the next animation frame. +type ScrambleTickMsg struct { + ID int +} + +// ScrambleModel is a custom BubbleTea component that renders a gradient +// character scramble animation, inspired by Charmbracelet's Crush CLI. +type ScrambleModel struct { + id int + chars []rune + visible int + colorFrom colorful.Color + colorTo colorful.Color + isDark bool + rng *rand.Rand +} + +// NewScrambleModel creates a new scramble animation with theme-appropriate colors. +func NewScrambleModel(isDark bool) ScrambleModel { + s := ScrambleModel{ + id: 1, + chars: make([]rune, scrambleWidth), + rng: rand.New(rand.NewSource(time.Now().UnixNano())), + } + s.SetDark(isDark) + s.randomizeChars() + return s +} + +// SetDark updates the gradient colors for the current theme. +func (s *ScrambleModel) SetDark(isDark bool) { + s.isDark = isDark + if isDark { + // Dark theme: cool blue → warm purple gradient + s.colorFrom, _ = colorful.Hex("#88c0d0") // Nord frost + s.colorTo, _ = colorful.Hex("#b48ead") // Nord purple + } else { + // Light theme: teal → indigo + s.colorFrom, _ = colorful.Hex("#0088bb") + s.colorTo, _ = colorful.Hex("#6644aa") + } +} + +// Reset resets the animation (new ID + zero visible). Call when agent starts. +func (s *ScrambleModel) Reset() { + s.id++ + s.visible = 0 + s.randomizeChars() +} + +// Tick schedules the next animation frame (~15 FPS = 66ms). +func (s ScrambleModel) Tick() tea.Cmd { + id := s.id + return tea.Tick(66*time.Millisecond, func(time.Time) tea.Msg { + return ScrambleTickMsg{ID: id} + }) +} + +// Update processes tick messages and advances the animation. +func (s ScrambleModel) Update(msg tea.Msg) (ScrambleModel, tea.Cmd) { + if tick, ok := msg.(ScrambleTickMsg); ok { + if tick.ID != s.id { + return s, nil // stale tick, ignore + } + s.randomizeChars() + if s.visible < scrambleWidth { + s.visible++ + } + return s, s.Tick() + } + return s, nil +} + +// View renders the visible characters with an HCL gradient. +func (s ScrambleModel) View() string { + if s.visible == 0 { + return "" + } + + // NO_COLOR fallback + if noColor { + dots := "" + for i := 0; i < s.visible && i < scrambleWidth; i++ { + dots += "." + } + return dots + } + + result := "" + for i := 0; i < s.visible && i < len(s.chars); i++ { + // Calculate gradient position + t := float64(i) / float64(scrambleWidth-1) + c := s.colorFrom.BlendHcl(s.colorTo, t).Clamped() + hex := c.Hex() + + style := lipgloss.NewStyle().Foreground(lipgloss.Color(hex)) + result += style.Render(string(s.chars[i])) + } + return result +} + +// randomizeChars fills the chars slice with random characters. +func (s *ScrambleModel) randomizeChars() { + runes := []rune(scrambleChars) + for i := range s.chars { + s.chars[i] = runes[s.rng.Intn(len(runes))] + } +} diff --git a/internal/tui/scramble_test.go b/internal/tui/scramble_test.go new file mode 100644 index 0000000..4797ddc --- /dev/null +++ b/internal/tui/scramble_test.go @@ -0,0 +1,114 @@ +package tui + +import ( + "testing" +) + +func TestNewScrambleModel(t *testing.T) { + s := NewScrambleModel(true) + + if s.visible != 0 { + t.Errorf("expected visible=0, got %d", s.visible) + } + if len(s.chars) != scrambleWidth { + t.Errorf("expected %d chars, got %d", scrambleWidth, len(s.chars)) + } + if s.id != 1 { + t.Errorf("expected id=1, got %d", s.id) + } + if s.rng == nil { + t.Error("expected rng to be initialized") + } +} + +func TestScrambleUpdate(t *testing.T) { + s := NewScrambleModel(true) + + // Matching tick should advance visible + tick := ScrambleTickMsg{ID: s.id} + s2, cmd := s.Update(tick) + if s2.visible != 1 { + t.Errorf("expected visible=1 after tick, got %d", s2.visible) + } + if cmd == nil { + t.Error("expected non-nil cmd after matching tick") + } + + // Stale tick (wrong ID) should be ignored + staleTick := ScrambleTickMsg{ID: s.id + 999} + s3, cmd := s2.Update(staleTick) + if s3.visible != s2.visible { + t.Errorf("expected visible unchanged after stale tick, got %d", s3.visible) + } + if cmd != nil { + t.Error("expected nil cmd after stale tick") + } +} + +func TestScrambleView(t *testing.T) { + s := NewScrambleModel(true) + + // Empty at visible=0 + if v := s.View(); v != "" { + t.Errorf("expected empty view at visible=0, got %q", v) + } + + // After ticks, should produce non-empty output + tick := ScrambleTickMsg{ID: s.id} + s, _ = s.Update(tick) + s, _ = s.Update(ScrambleTickMsg{ID: s.id}) + if v := s.View(); v == "" { + t.Error("expected non-empty view after ticks") + } +} + +func TestScrambleReset(t *testing.T) { + s := NewScrambleModel(true) + + // Advance some ticks + tick := ScrambleTickMsg{ID: s.id} + s, _ = s.Update(tick) + s, _ = s.Update(ScrambleTickMsg{ID: s.id}) + + oldID := s.id + s.Reset() + + if s.visible != 0 { + t.Errorf("expected visible=0 after reset, got %d", s.visible) + } + if s.id <= oldID { + t.Errorf("expected id to increment after reset, got %d (was %d)", s.id, oldID) + } +} + +func TestScrambleSetDark(t *testing.T) { + s := NewScrambleModel(true) + + // Store dark colors + darkFrom := s.colorFrom + darkTo := s.colorTo + + // Switch to light + s.SetDark(false) + if s.colorFrom == darkFrom { + t.Error("expected colorFrom to change for light theme") + } + if s.colorTo == darkTo { + t.Error("expected colorTo to change for light theme") + } + if s.isDark { + t.Error("expected isDark=false after SetDark(false)") + } + + // Switch back to dark + s.SetDark(true) + if s.colorFrom != darkFrom { + t.Error("expected colorFrom to match original dark theme") + } + if s.colorTo != darkTo { + t.Error("expected colorTo to match original dark theme") + } + if !s.isDark { + t.Error("expected isDark=true after SetDark(true)") + } +} diff --git a/internal/tui/scroll_anchor_test.go b/internal/tui/scroll_anchor_test.go new file mode 100644 index 0000000..e700b34 --- /dev/null +++ b/internal/tui/scroll_anchor_test.go @@ -0,0 +1,225 @@ +package tui + +import ( + "testing" + + tea "charm.land/bubbletea/v2" + + "ai-agent/internal/agent" + "ai-agent/internal/command" +) + +func TestScrollAnchor_Initialization(t *testing.T) { + m := newTestModel(t) + updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40}) + m = updated.(*Model) + if !m.ready { + t.Fatal("viewport should be ready after WindowSizeMsg") + } + if !m.anchorActive { + t.Error("anchorActive should be true after initialization") + } + if m.scrollAnchor != 0 { + t.Errorf("scrollAnchor should be 0, got %d", m.scrollAnchor) + } + if m.lastContentHeight != 0 { + t.Errorf("lastContentHeight should be 0, got %d", m.lastContentHeight) + } +} + +func TestScrollAnchor_MouseWheelUp(t *testing.T) { + m := newTestModel(t) + m.anchorActive = true + m.userScrolledUp = false + var longContent string + for i := 0; i < 100; i++ { + longContent += "line " + string(rune(i)) + "\n" + } + m.viewport.SetContent(longContent) + m.viewport.GotoBottom() + if !m.viewport.AtBottom() { + t.Fatal("viewport should be at bottom before scroll") + } + updated, _ := m.Update(tea.MouseWheelMsg{X: 0, Y: 0, Button: tea.MouseWheelUp}) + m = updated.(*Model) + if m.anchorActive { + t.Error("anchorActive should be false after scrolling up") + } + if !m.userScrolledUp { + t.Error("userScrolledUp should be true after scrolling up") + } + if m.scrollAnchor <= 0 { + t.Error("scrollAnchor should be positive after scrolling up") + } +} + +func TestScrollAnchor_MouseWheelDown(t *testing.T) { + m := newTestModel(t) + m.anchorActive = false + m.userScrolledUp = true + m.scrollAnchor = 10 + m.viewport.SetContent("short content") + updated, _ := m.Update(tea.MouseWheelMsg{X: 0, Y: 0, Button: tea.MouseWheelDown}) + m = updated.(*Model) + if !m.anchorActive { + t.Error("anchorActive should be true when at bottom") + } +} + +func TestScrollAnchor_StreamTextMsg(t *testing.T) { + m := newTestModel(t) + m.state = StateStreaming + m.entries = []ChatEntry{ + {Kind: "assistant", Content: "Initial response"}, + } + m.viewport.SetContent(m.renderEntries()) + m.anchorActive = true + updated, _ := m.Update(StreamTextMsg{Text: "more"}) + m = updated.(*Model) + if !m.viewport.AtBottom() { + t.Error("viewport should be at bottom when anchor is active") + } + m.anchorActive = false + m.viewport.GotoTop() + updated, _ = m.Update(StreamTextMsg{Text: "even more"}) + m = updated.(*Model) + if m.viewport.AtBottom() { + t.Log("Note: viewport scrolled to bottom even with anchor inactive") + } +} + +func TestScrollAnchor_AgentDoneMsg(t *testing.T) { + m := newTestModel(t) + m.state = StateStreaming + m.anchorActive = false + m.userScrolledUp = true + m.scrollAnchor = 10 + updated, _ := m.Update(AgentDoneMsg{}) + m = updated.(*Model) + if m.state != StateIdle { + t.Errorf("state should be StateIdle, got %d", m.state) + } + if !m.anchorActive { + t.Error("anchorActive should be reset to true after AgentDoneMsg") + } + if m.scrollAnchor != 0 { + t.Errorf("scrollAnchor should be reset to 0, got %d", m.scrollAnchor) + } + if m.userScrolledUp { + t.Error("userScrolledUp should be reset to false") + } +} + +func TestScrollAnchor_ToolMessages(t *testing.T) { + m := newTestModel(t) + m.state = StateStreaming + m.anchorActive = true + updated, _ := m.Update(ToolCallStartMsg{ + Name: "read_file", + Args: map[string]any{"path": "test.go"}, + StartTime: testTime, + }) + m = updated.(*Model) + if !m.anchorActive { + t.Error("anchorActive should remain true after ToolCallStartMsg") + } + updated, _ = m.Update(ToolCallResultMsg{ + Name: "read_file", + Result: "file content", + IsError: false, + Duration: testDuration, + }) + m = updated.(*Model) + if !m.anchorActive { + t.Error("anchorActive should remain true after ToolCallResultMsg") + } +} + +func TestScrollAnchor_SystemMessages(t *testing.T) { + m := newTestModel(t) + m.anchorActive = true + updated, _ := m.Update(SystemMessageMsg{Msg: "system message"}) + m = updated.(*Model) + if !m.anchorActive { + t.Error("anchorActive should remain true after SystemMessageMsg") + } + updated, _ = m.Update(ErrorMsg{Msg: "error message"}) + m = updated.(*Model) + if !m.anchorActive { + t.Error("anchorActive should remain true after ErrorMsg") + } +} + +func TestScrollAnchor_WindowResize(t *testing.T) { + m := newTestModel(t) + updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40}) + m = updated.(*Model) + if !m.anchorActive { + t.Fatal("anchorActive should be true after initial sizing") + } +} + +func TestCheckAutoScroll_ReenablesAnchorAtBottom(t *testing.T) { + m := newTestModel(t) + m.anchorActive = false + m.userScrolledUp = true + m.scrollAnchor = 10 + m.viewport.SetContent("short content") + m.viewport.GotoBottom() + m.checkAutoScroll() + if !m.anchorActive { + t.Error("checkAutoScroll should set anchorActive to true when at bottom") + } + if m.userScrolledUp { + t.Error("checkAutoScroll should set userScrolledUp to false when at bottom") + } + if m.scrollAnchor != 0 { + t.Error("checkAutoScroll should reset scrollAnchor to 0 when at bottom") + } +} + +func TestScrollAnchor_ViewportAtBottom(t *testing.T) { + m := newTestModel(t) + m.viewport.SetContent("line1\nline2\nline3") + if !m.viewport.AtBottom() { + t.Error("viewport should be at bottom with short content") + } + var longContent string + for i := 0; i < 100; i++ { + longContent += "line " + string(rune(i)) + "\n" + } + m.viewport.SetContent(longContent) + m.viewport.GotoBottom() + if !m.viewport.AtBottom() { + t.Error("viewport should be at bottom after GotoBottom()") + } + m.viewport.GotoTop() + if m.viewport.AtBottom() { + t.Error("viewport should not be at bottom after scrolling to top") + } +} + +func BenchmarkScrollAnchor_Performance(b *testing.B) { + m := newTestModelB(b) + m.anchorActive = true + var longContent string + for i := 0; i < 100; i++ { + longContent += "line " + string(rune(i)) + "\n" + } + m.viewport.SetContent(longContent) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = m.viewport.AtBottom() + } +} + +func newTestModelB(b *testing.B) *Model { + reg := command.NewRegistry() + command.RegisterBuiltins(reg) + completer := NewCompleter(reg, []string{"model-a", "model-b"}, []string{"skill-a"}, []string{"agent-x"}, nil) + ag := agent.New(nil, nil, 0) + m := New(ag, reg, nil, completer, nil, nil, nil) + m.initializing = false + updated, _ := m.Update(tea.WindowSizeMsg{Width: 80, Height: 24}) + return updated.(*Model) +} diff --git a/internal/tui/scroll_test.go b/internal/tui/scroll_test.go new file mode 100644 index 0000000..7719413 --- /dev/null +++ b/internal/tui/scroll_test.go @@ -0,0 +1,182 @@ +package tui + +import ( + "strings" + "testing" + + "charm.land/lipgloss/v2" +) + +// TestNoHorizontalScroll verifies that rendered content never exceeds viewport width +func TestNoHorizontalScroll(t *testing.T) { + tests := []struct { + name string + screenWidth int + panelVisible bool + content string + }{ + { + name: "long word with panel", + screenWidth: 120, + panelVisible: true, + content: strings.Repeat("x", 150), + }, + { + name: "multiple long words without panel", + screenWidth: 100, + panelVisible: false, + content: strings.Repeat("superlongword ", 10), + }, + { + name: "code block with panel", + screenWidth: 120, + panelVisible: true, + content: "```\n" + strings.Repeat("x", 100) + "\n```", + }, + { + name: "URL without panel", + screenWidth: 80, + panelVisible: false, + content: "https://example.com/" + strings.Repeat("verylongpathsegment/", 5), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Calculate panel width + panelWidth := 0 + if tt.panelVisible { + panelWidth = 30 + if tt.screenWidth < 100 { + panelWidth = 25 + } else if tt.screenWidth > 160 { + panelWidth = 40 + } + } + + // Calculate viewport width (from model.go) + viewportWidth := tt.screenWidth - 1 + if tt.panelVisible { + viewportWidth = tt.screenWidth - panelWidth - 2 + } + if viewportWidth < 20 { + viewportWidth = 20 + } + + // Calculate content width (from view.go) + contentW := tt.screenWidth - 4 + if tt.panelVisible { + contentW = tt.screenWidth - panelWidth - 5 + } + if contentW < 20 { + contentW = 20 + } + + // Wrap the content + wrapped := wrapText(tt.content, contentW) + + // Check each line + lines := strings.Split(wrapped, "\n") + for i, line := range lines { + // Measure visible width (lipgloss.Width handles styling) + lineWidth := lipgloss.Width(line) + if lineWidth > viewportWidth { + t.Errorf("line %d width %d exceeds viewport width %d: %q", + i, lineWidth, viewportWidth, line[:min(50, len(line))]) + } + if lineWidth > contentW { + t.Errorf("line %d width %d exceeds content width %d", + i, lineWidth, contentW) + } + } + }) + } +} + +// TestResponsivePanelToggle verifies no scroll when toggling panel +func TestResponsivePanelToggle(t *testing.T) { + screenWidth := 120 + content := strings.Repeat("longword ", 20) + + // Calculate widths with panel + panelWidth := 30 + viewportWithPanel := screenWidth - panelWidth - 2 + contentWithPanel := screenWidth - panelWidth - 5 + + // Calculate widths without panel + viewportWithoutPanel := screenWidth - 1 + contentWithoutPanel := screenWidth - 4 + + // Wrap content for both scenarios + wrappedWithPanel := wrapText(content, contentWithPanel) + wrappedWithoutPanel := wrapText(content, contentWithoutPanel) + + // Verify both fit within their respective viewports + for i, line := range strings.Split(wrappedWithPanel, "\n") { + if lipgloss.Width(line) > viewportWithPanel { + t.Errorf("with panel: line %d exceeds viewport", i) + } + } + + for i, line := range strings.Split(wrappedWithoutPanel, "\n") { + if lipgloss.Width(line) > viewportWithoutPanel { + t.Errorf("without panel: line %d exceeds viewport", i) + } + } + + // Verify that content without panel is wider (better use of space) + if contentWithoutPanel <= contentWithPanel { + t.Error("content width should increase when panel is hidden") + } +} + +// TestEdgeCases verifies width handling at boundary conditions +func TestEdgeCases(t *testing.T) { + tests := []struct { + name string + screenWidth int + expectMin bool // whether minimum width constraint should kick in + }{ + {"minimum viable", 46, true}, // 25 (panel) + 1 + 20 (min viewport) + {"just above min", 50, false}, + {"exactly 100", 100, false}, + {"exactly 160", 160, false}, + {"very large", 300, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + panelWidth := 30 + if tt.screenWidth < 100 { + panelWidth = 25 + } else if tt.screenWidth > 160 { + panelWidth = 40 + } + + viewportWidth := tt.screenWidth - panelWidth - 2 + if viewportWidth < 20 { + viewportWidth = 20 + } + + if tt.expectMin && viewportWidth == 20 { + // Expected minimum enforcement + if tt.screenWidth-panelWidth-2 >= 20 { + t.Error("expected minimum width enforcement but calculation would allow larger") + } + } + + // Verify viewport never exceeds available space (unless minimum enforced) + maxAllowed := tt.screenWidth - panelWidth - 1 + if viewportWidth > maxAllowed && !tt.expectMin { + t.Errorf("viewport %d exceeds max allowed %d", viewportWidth, maxAllowed) + } + }) + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/internal/tui/search.go b/internal/tui/search.go new file mode 100644 index 0000000..65fd6de --- /dev/null +++ b/internal/tui/search.go @@ -0,0 +1,213 @@ +package tui + +import ( + "charm.land/bubbles/v2/textinput" + "charm.land/lipgloss/v2" +) + +// SearchState holds the state for conversation search. +type SearchState struct { + Input textinput.Model + Results []SearchResult + Index int + Active bool + CaseSensitive bool +} + +// SearchResult represents a single search match. +type SearchResult struct { + EntryIndex int + LineNum int + Content string + Start int + End int +} + +// SearchStyles holds styling for search UI. +type SearchStyles struct { + Input lipgloss.Style + Match lipgloss.Style + Result lipgloss.Style + Selected lipgloss.Style + Label lipgloss.Style + Hint lipgloss.Style +} + +// DefaultSearchStyles returns default styles. +func DefaultSearchStyles(isDark bool) SearchStyles { + if isDark { + return SearchStyles{ + Input: lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0")), + Match: lipgloss.NewStyle().Background(lipgloss.Color("#4c566a")).Foreground(lipgloss.Color("#eceff4")), + Result: lipgloss.NewStyle().Foreground(lipgloss.Color("#d8dee9")), + Selected: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#88c0d0")), + Label: lipgloss.NewStyle().Foreground(lipgloss.Color("#81a1c1")), + Hint: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + } + } + return SearchStyles{ + Input: lipgloss.NewStyle().Foreground(lipgloss.Color("#4f8f8f")), + Match: lipgloss.NewStyle().Background(lipgloss.Color("#d8dee9")).Foreground(lipgloss.Color("#2e3440")), + Result: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Selected: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#4f8f8f")), + Label: lipgloss.NewStyle().Foreground(lipgloss.Color("#5e81ac")), + Hint: lipgloss.NewStyle().Foreground(lipgloss.Color("#9ca0a8")), + } +} + +// NewSearchState creates a new search state. +func NewSearchState() *SearchState { + ti := textinput.New() + ti.Placeholder = "Search conversation..." + ti.Focus() + ti.CharLimit = 256 + + return &SearchState{ + Input: ti, + Results: nil, + Index: 0, + Active: false, + } +} + +// Activate enables search mode. +func (s *SearchState) Activate() { + s.Active = true + s.Input.Focus() +} + +// Deactivate disables search mode. +func (s *SearchState) Deactivate() { + s.Active = false + s.Input.Blur() + s.Results = nil + s.Index = 0 +} + +// Search performs a search across chat entries. +func (s *SearchState) Search(entries []ChatEntry, query string) { + s.Results = nil + s.Index = 0 + + if query == "" { + return + } + + for entryIdx, entry := range entries { + content := entry.Content + if content == "" { + continue + } + + // Simple case-insensitive search + searchQuery := query + if !s.CaseSensitive { + searchQuery = toLower(query) + content = toLower(content) + } + + start := 0 + for { + idx := indexOf(content, searchQuery, start) + if idx == -1 { + break + } + + // Get surrounding context (40 chars before and after) + entryContent := entries[entryIdx].Content + ctxStart := idx - 40 + if ctxStart < 0 { + ctxStart = 0 + } + ctxEnd := idx + len(query) + 40 + if ctxEnd > len(entryContent) { + ctxEnd = len(entryContent) + } + + context := entryContent[ctxStart:ctxEnd] + if ctxStart > 0 { + context = "..." + context + } + if ctxEnd < len(entryContent) { + context = context + "..." + } + + s.Results = append(s.Results, SearchResult{ + EntryIndex: entryIdx, + LineNum: countNewlines(entryContent[:idx]), + Content: context, + Start: idx, + End: idx + len(query), + }) + + start = idx + len(query) + } + } +} + +// NextResult moves to the next search result. +func (s *SearchState) NextResult() { + if len(s.Results) == 0 { + return + } + s.Index = (s.Index + 1) % len(s.Results) +} + +// PrevResult moves to the previous search result. +func (s *SearchState) PrevResult() { + if len(s.Results) == 0 { + return + } + s.Index-- + if s.Index < 0 { + s.Index = len(s.Results) - 1 + } +} + +// CurrentResult returns the currently selected result. +func (s *SearchState) CurrentResult() *SearchResult { + if len(s.Results) == 0 || s.Index >= len(s.Results) { + return nil + } + return &s.Results[s.Index] +} + +// HasResults returns true if there are search results. +func (s *SearchState) HasResults() bool { + return len(s.Results) > 0 +} + +// Helper functions to avoid import conflicts. +func toLower(s string) string { + result := make([]byte, len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + if c >= 'A' && c <= 'Z' { + c += 'a' - 'A' + } + result[i] = c + } + return string(result) +} + +func indexOf(s, substr string, start int) int { + if start >= len(s) { + return -1 + } + for i := start; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return i + } + } + return -1 +} + +func countNewlines(s string) int { + count := 0 + for _, c := range s { + if c == '\n' { + count++ + } + } + return count +} diff --git a/internal/tui/session.go b/internal/tui/session.go new file mode 100644 index 0000000..d386fab --- /dev/null +++ b/internal/tui/session.go @@ -0,0 +1,137 @@ +package tui + +import ( + "encoding/json" + "fmt" + "os/exec" + "strconv" + "strings" +) + +type SessionListItem struct { + ID int `json:"id"` + Title string `json:"title"` + CreatedAt string `json:"created_at"` +} + +type SessionNote struct { + ID int `json:"id"` + Title string `json:"title"` + Content string `json:"content"` + Tags []string `json:"tags"` +} + +func notedAvailable() bool { + _, err := exec.LookPath("noted") + return err == nil +} + +func createSessionNote(timestamp string) (int, error) { + title := fmt.Sprintf("ai-agent session %s", timestamp) + cmd := exec.Command("noted", "add", "-t", title, "-c", "(session in progress)", "--tags", "ai-agent,session", "--json") + out, err := cmd.Output() + if err != nil { + return 0, fmt.Errorf("noted add: %w", err) + } + var result struct { + ID int `json:"id"` + } + if err := json.Unmarshal(out, &result); err != nil { + return 0, fmt.Errorf("parse noted output: %w", err) + } + return result.ID, nil +} + +func updateSessionNote(id int, content string) error { + cmd := exec.Command("noted", "edit", strconv.Itoa(id), "-c", content) + return cmd.Run() +} + +func listSessions(limit int) ([]SessionListItem, error) { + cmd := exec.Command("noted", "list", "--tag", "session", "--json", "-n", strconv.Itoa(limit)) + out, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("noted list: %w", err) + } + var sessions []SessionListItem + if err := json.Unmarshal(out, &sessions); err != nil { + return nil, fmt.Errorf("parse noted output: %w", err) + } + return sessions, nil +} + +func loadSession(id int) (*SessionNote, error) { + cmd := exec.Command("noted", "show", strconv.Itoa(id), "--json") + out, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("noted show: %w", err) + } + var note SessionNote + if err := json.Unmarshal(out, ¬e); err != nil { + return nil, fmt.Errorf("parse noted output: %w", err) + } + return ¬e, nil +} + +func serializeEntries(entries []ChatEntry) string { + var b strings.Builder + for _, e := range entries { + switch e.Kind { + case "user": + b.WriteString("## User\n\n") + b.WriteString(e.Content) + b.WriteString("\n\n") + case "assistant": + b.WriteString("## Assistant\n\n") + b.WriteString(e.Content) + b.WriteString("\n\n") + case "system": + b.WriteString("## System\n\n") + b.WriteString(e.Content) + b.WriteString("\n\n") + case "error": + b.WriteString("## Error\n\n") + b.WriteString(e.Content) + b.WriteString("\n\n") + } + } + return strings.TrimRight(b.String(), "\n") +} + +func deserializeEntries(content string) []ChatEntry { + if content == "" { + return nil + } + var entries []ChatEntry + sections := strings.Split(content, "## ") + for _, section := range sections { + section = strings.TrimSpace(section) + if section == "" { + continue + } + nlIdx := strings.Index(section, "\n") + if nlIdx == -1 { + continue + } + header := strings.TrimSpace(section[:nlIdx]) + body := strings.TrimSpace(section[nlIdx+1:]) + var kind string + switch header { + case "User": + kind = "user" + case "Assistant": + kind = "assistant" + case "System": + kind = "system" + case "Error": + kind = "error" + default: + continue + } + entries = append(entries, ChatEntry{ + Kind: kind, + Content: body, + }) + } + return entries +} diff --git a/internal/tui/session_test.go b/internal/tui/session_test.go new file mode 100644 index 0000000..9a6cff0 --- /dev/null +++ b/internal/tui/session_test.go @@ -0,0 +1,85 @@ +package tui + +import "testing" + +func TestSerializeDeserialize_Roundtrip(t *testing.T) { + entries := []ChatEntry{ + {Kind: "user", Content: "Hello there"}, + {Kind: "assistant", Content: "Hi! How can I help?"}, + {Kind: "system", Content: "Model switched to qwen3"}, + } + + serialized := serializeEntries(entries) + deserialized := deserializeEntries(serialized) + + if len(deserialized) != len(entries) { + t.Fatalf("roundtrip length: got %d, want %d", len(deserialized), len(entries)) + } + + for i, e := range deserialized { + if e.Kind != entries[i].Kind { + t.Errorf("entry[%d] kind: got %q, want %q", i, e.Kind, entries[i].Kind) + } + if e.Content != entries[i].Content { + t.Errorf("entry[%d] content: got %q, want %q", i, e.Content, entries[i].Content) + } + } +} + +func TestSerializeEntries_Empty(t *testing.T) { + result := serializeEntries(nil) + if result != "" { + t.Errorf("nil entries should serialize to empty, got %q", result) + } +} + +func TestDeserializeEntries_Empty(t *testing.T) { + result := deserializeEntries("") + if result != nil { + t.Errorf("empty content should deserialize to nil, got %v", result) + } +} + +func TestDeserializeEntries_UnknownHeader(t *testing.T) { + content := "## Unknown\n\nSome content\n\n## User\n\nValid content" + result := deserializeEntries(content) + if len(result) != 1 { + t.Fatalf("should skip unknown headers, got %d entries", len(result)) + } + if result[0].Kind != "user" { + t.Errorf("should parse valid entry, got kind %q", result[0].Kind) + } +} + +func TestSerializeEntries_ErrorKind(t *testing.T) { + entries := []ChatEntry{ + {Kind: "error", Content: "Something went wrong"}, + } + serialized := serializeEntries(entries) + if serialized == "" { + t.Error("error entries should serialize") + } + + deserialized := deserializeEntries(serialized) + if len(deserialized) != 1 || deserialized[0].Kind != "error" { + t.Errorf("error entry should roundtrip, got %v", deserialized) + } +} + +func TestSerializeEntries_MultilineContent(t *testing.T) { + entries := []ChatEntry{ + {Kind: "user", Content: "line1\nline2\nline3"}, + } + serialized := serializeEntries(entries) + deserialized := deserializeEntries(serialized) + if len(deserialized) != 1 { + t.Fatalf("expected 1 entry, got %d", len(deserialized)) + } + if deserialized[0].Content != "line1\nline2\nline3" { + t.Errorf("multiline content should roundtrip, got %q", deserialized[0].Content) + } +} + +func TestNotedAvailable(t *testing.T) { + _ = notedAvailable() +} diff --git a/internal/tui/sessionspicker.go b/internal/tui/sessionspicker.go new file mode 100644 index 0000000..b1f2d31 --- /dev/null +++ b/internal/tui/sessionspicker.go @@ -0,0 +1,107 @@ +package tui + +import ( + "charm.land/bubbles/v2/list" + "charm.land/lipgloss/v2" +) + +// sessionItem implements list.DefaultItem for the sessions picker. +type sessionItem struct { + id int + title string + createdAt string +} + +func (i sessionItem) Title() string { + title := i.title + if len(title) > 40 { + title = title[:37] + "..." + } + return title +} + +func (i sessionItem) Description() string { + return i.createdAt +} + +func (i sessionItem) FilterValue() string { return i.title } + +// SessionsPickerState holds state for the sessions picker overlay. +type SessionsPickerState struct { + List list.Model + Sessions []SessionListItem +} + +// newSessionsPickerState creates a new SessionsPickerState with a bubbles list. +func newSessionsPickerState(sessions []SessionListItem, width int, isDark bool) *SessionsPickerState { + items := make([]list.Item, len(sessions)) + for i, s := range sessions { + items[i] = sessionItem{ + id: s.ID, + title: s.Title, + createdAt: s.CreatedAt, + } + } + + delegate := list.NewDefaultDelegate() + delegate.Styles = list.NewDefaultItemStyles(isDark) + delegate.SetSpacing(0) + + maxW := 54 + if width-8 > maxW { + maxW = width - 8 + } + if maxW > 64 { + maxW = 64 + } + + // Height: items fit, max 20 lines + pickerH := len(sessions)*delegate.Height() + 4 // +4 for title + filter + if pickerH > 20 { + pickerH = 20 + } + + l := list.New(items, delegate, maxW-4, pickerH) + l.Title = "Sessions" + l.SetShowStatusBar(false) + l.SetShowHelp(false) + l.SetShowPagination(true) + l.SetFilteringEnabled(true) + l.DisableQuitKeybindings() + + return &SessionsPickerState{ + List: l, + Sessions: sessions, + } +} + +// renderSessionsPicker renders the sessions picker overlay. +func (m *Model) renderSessionsPicker() string { + ps := m.sessionsPickerState + if ps == nil { + return "" + } + + maxW := 54 + if m.width-8 > maxW { + maxW = m.width - 8 + } + if maxW > 64 { + maxW = 64 + } + + box := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(m.styles.FocusIndicator.GetForeground()). + Padding(0, 1). + Width(maxW) + + return box.Render(ps.List.View()) +} + +// closeSessionsPicker dismisses the sessions picker overlay. +func (m *Model) closeSessionsPicker() { + m.sessionsPickerState = nil + m.overlay = OverlayNone + m.input.Focus() +} diff --git a/internal/tui/sidepanel.go b/internal/tui/sidepanel.go new file mode 100644 index 0000000..77655ca --- /dev/null +++ b/internal/tui/sidepanel.go @@ -0,0 +1,402 @@ +package tui + +import ( + "fmt" + "strings" + + "charm.land/bubbles/v2/spinner" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" +) + +func sanitizeDetail(detail string) string { + detail = strings.TrimSpace(detail) + if len(detail) > 0 && (detail[0] == '{' || detail[0] == '[') { + return "error details available in logs" + } + detail = strings.ReplaceAll(detail, "\n", " ") + detail = strings.ReplaceAll(detail, "\r", " ") + for strings.Contains(detail, " ") { + detail = strings.ReplaceAll(detail, " ", " ") + } + return strings.TrimSpace(detail) +} + +type SidePanelSectionKind int + +const ( + SidePanelLogo SidePanelSectionKind = iota + SidePanelModels + SidePanelServers + SidePanelICE + SidePanelQuickActions + SidePanelStartup +) + +type SidePanelItem struct { + Title string + Subtitle string + Kind SidePanelSectionKind + Icon string + Status string + ID string + Selectable bool + IsCurrent bool +} + +func (i SidePanelItem) TitleText() string { + prefix := "" + if i.Icon != "" { + prefix = i.Icon + " " + } + if i.IsCurrent { + prefix = "→ " + } + return prefix + i.Title +} + +func (i SidePanelItem) Description() string { + return i.Subtitle +} + +func (i SidePanelItem) FilterValue() string { + return i.Title +} + +type SidePanelSection struct { + Title string + Kind SidePanelSectionKind + Items []SidePanelItem + Expanded bool +} + +type SidePanelModel struct { + sections []SidePanelSection + startupItems []StartupItem + width int + height int + cursor int + selected int + styles SidePanelStyles + spinner spinner.Model + visible bool + isDark bool +} + +type StartupItem struct { + Label string + Status string + Detail string +} + +type SidePanelStyles struct { + Border lipgloss.Style + Title lipgloss.Style + Section lipgloss.Style + Item lipgloss.Style + Selected lipgloss.Style + Current lipgloss.Style + Connected lipgloss.Style + Failed lipgloss.Style + Dimmed lipgloss.Style + Logo lipgloss.Style + LogoTagline lipgloss.Style +} + +func DefaultSidePanelStyles(isDark bool) SidePanelStyles { + return SidePanelStyles{ + Border: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Title: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#88c0d0")), + Section: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#81a1c1")), + Item: lipgloss.NewStyle().Foreground(lipgloss.Color("#d8dee9")), + Selected: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#88c0d0")), + Current: lipgloss.NewStyle().Foreground(lipgloss.Color("#a3be8c")), + Connected: lipgloss.NewStyle().Foreground(lipgloss.Color("#a3be8c")), + Failed: lipgloss.NewStyle().Foreground(lipgloss.Color("#bf616a")), + Dimmed: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Logo: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#88c0d0")), + LogoTagline: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + } +} + +func NewSidePanelModel(isDark bool) SidePanelModel { + s := spinner.New( + spinner.WithSpinner(spinner.MiniDot), + spinner.WithStyle(lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0"))), + ) + return SidePanelModel{ + visible: true, + isDark: isDark, + styles: DefaultSidePanelStyles(isDark), + cursor: 0, + selected: 0, + spinner: s, + } +} + +func (m *SidePanelModel) SetDark(isDark bool) { + m.isDark = isDark + m.styles = DefaultSidePanelStyles(isDark) +} + +func (m *SidePanelModel) SetSpinnerTick() { + // No-op - spinner advances via Update with TickMsg +} + +func (m *SidePanelModel) TickSpinner() tea.Cmd { + return m.spinner.Tick +} + +func (m *SidePanelModel) Tick() { + m.spinner.Tick() +} + +func (m *SidePanelModel) SetWidth(w int) { + m.width = w +} + +func (m *SidePanelModel) SetHeight(h int) { + m.height = h +} + +func (m *SidePanelModel) SetStartupItems(items []StartupItem) { + m.startupItems = items +} + +func (m *SidePanelModel) Toggle() { + m.visible = !m.visible +} + +func (m *SidePanelModel) Show() { + m.visible = true +} + +func (m *SidePanelModel) Hide() { + m.visible = false +} + +func (m *SidePanelModel) IsVisible() bool { + return m.visible +} + +func (m *SidePanelModel) UpdateSections(lang Lang, modelName string, modelList []string, serverCount int, toolCount int, iceEnabled bool, iceConversations int) { + loc := Locale(lang) + m.sections = []SidePanelSection{ + { + Title: loc.SidePanelAIAgent, + Kind: SidePanelLogo, + Expanded: true, + Items: []SidePanelItem{ + { + Title: loc.SidePanelAIAgent, + Subtitle: loc.SidePanelTagline, + Kind: SidePanelLogo, + Icon: "⬡", + }, + }, + }, + { + Title: loc.SidePanelModels, + Kind: SidePanelModels, + Expanded: true, + Items: []SidePanelItem{}, + }, + { + Title: loc.SidePanelServers, + Kind: SidePanelServers, + Expanded: true, + Items: []SidePanelItem{}, + }, + { + Title: loc.SidePanelICE, + Kind: SidePanelICE, + Expanded: true, + Items: []SidePanelItem{}, + }, + { + Title: loc.SidePanelQuickActions, + Kind: SidePanelQuickActions, + Expanded: true, + Items: []SidePanelItem{ + {Title: loc.SidePanelHelp, Subtitle: loc.SidePanelHelpDesc, Kind: SidePanelQuickActions, Icon: "?", Selectable: true, ID: "help"}, + {Title: loc.SidePanelServers, Subtitle: loc.SidePanelServersDesc, Kind: SidePanelQuickActions, Icon: "◈", Selectable: true, ID: "servers"}, + {Title: loc.SidePanelModels, Subtitle: loc.SidePanelModelDesc, Kind: SidePanelQuickActions, Icon: "◈", Selectable: true, ID: "model"}, + {Title: loc.SidePanelLoad, Subtitle: loc.SidePanelLoadDesc, Kind: SidePanelQuickActions, Icon: "◈", Selectable: true, ID: "load"}, + {Title: loc.Language, Subtitle: loc.LanguageF2, Kind: SidePanelQuickActions, Icon: "◈", Selectable: true, ID: "language"}, + }, + }, + } + for _, model := range modelList { + item := SidePanelItem{ + Title: model, + Kind: SidePanelModels, + Icon: "◦", + Selectable: true, + ID: model, + IsCurrent: model == modelName, + } + if model == modelName { + item.Icon = "→" + } + m.sections[1].Items = append(m.sections[1].Items, item) + } + if serverCount > 0 { + m.sections[2].Items = append(m.sections[2].Items, SidePanelItem{ + Title: fmt.Sprintf(loc.ToolsConnected, toolCount), + Kind: SidePanelServers, + Icon: "✓", + Selectable: false, + }) + } else { + m.sections[2].Items = append(m.sections[2].Items, SidePanelItem{ + Title: loc.NoServersConnected, + Kind: SidePanelServers, + Icon: "○", + Selectable: false, + }) + } + if iceEnabled { + m.sections[3].Items = append(m.sections[3].Items, SidePanelItem{ + Title: fmt.Sprintf(loc.ICEConversations, iceConversations), + Subtitle: loc.ICECrossSessionActive, + Kind: SidePanelICE, + Icon: "✓", + Selectable: false, + }) + } else { + m.sections[3].Items = append(m.sections[3].Items, SidePanelItem{ + Title: loc.ICEDisabled, + Subtitle: loc.ICECrossSessionInactive, + Kind: SidePanelICE, + Icon: "○", + Selectable: false, + }) + } +} +func (m *SidePanelModel) ToggleSection(index int) { + if index >= 0 && index < len(m.sections) { + m.sections[index].Expanded = !m.sections[index].Expanded + } +} + +func (m SidePanelModel) Init() tea.Cmd { + return nil +} + +func (m SidePanelModel) Update(msg tea.Msg) (SidePanelModel, tea.Cmd) { + return m, nil +} + +func (m SidePanelModel) View() string { + if !m.visible { + return "" + } + width := m.width + if width < 25 { + width = 25 + } + var b strings.Builder + b.WriteString("\n") + b.WriteString(m.styles.Logo.Render(" AI AGENT")) + b.WriteString("\n") + b.WriteString(m.styles.LogoTagline.Render(" 100% local")) + b.WriteString("\n\n") + if len(m.startupItems) > 0 { + var hasPending bool + for _, item := range m.startupItems { + if item.Status == "connecting" || item.Status == "pending" { + hasPending = true + break + } + } + if hasPending { + b.WriteString(m.styles.Section.Render(" " + m.spinner.View() + " Connecting...")) + b.WriteString("\n\n") + } else { + b.WriteString(m.styles.Section.Render(" Initializing...")) + b.WriteString("\n\n") + } + for _, item := range m.startupItems { + icon := "○" + iconStyle := m.styles.Item + switch item.Status { + case "connecting": + icon = "◌" + iconStyle = m.styles.Section + case "connected": + icon = "✓" + iconStyle = m.styles.Connected + case "failed": + icon = "✗" + iconStyle = m.styles.Failed + } + line := fmt.Sprintf(" %s %s", icon, item.Label) + if item.Detail != "" { + detail := sanitizeDetail(item.Detail) + maxDetail := m.width - 15 + if len(detail) > maxDetail && maxDetail > 5 { + detail = detail[:maxDetail-3] + "..." + } + line += m.styles.Dimmed.Render(" · " + detail) + } + b.WriteString(iconStyle.Render(line)) + b.WriteString("\n") + } + b.WriteString("\n") + } + for sectionIdx := 1; sectionIdx < len(m.sections); sectionIdx++ { + section := m.sections[sectionIdx] + icon := "▶" + if section.Expanded { + icon = "▼" + } + header := fmt.Sprintf(" %s %s", icon, section.Title) + b.WriteString(m.styles.Section.Render(header)) + b.WriteString("\n") + if section.Expanded { + for itemIdx, item := range section.Items { + prefix := " " + if item.Icon != "" { + prefix = fmt.Sprintf(" %s ", item.Icon) + } + itemStyle := m.styles.Item + if item.IsCurrent { + itemStyle = m.styles.Current + } + line := prefix + item.Title + if item.Subtitle != "" && section.Kind != SidePanelLogo { + subtitle := item.Subtitle + maxSub := m.width - len(prefix) - len(item.Title) - 3 + if len(subtitle) > maxSub && maxSub > 5 { + subtitle = subtitle[:maxSub-3] + "..." + } + line += m.styles.Dimmed.Render(" · " + subtitle) + } + if section.Kind == SidePanelLogo && itemIdx == 0 { + b.WriteString(m.styles.LogoTagline.Render(" " + item.Subtitle)) + } else { + b.WriteString(itemStyle.Render(line)) + } + b.WriteString("\n") + } + } + b.WriteString("\n") + } + b.WriteString("\n") + b.WriteString(m.styles.Dimmed.Render(" ────────────────────────")) + b.WriteString("\n") + b.WriteString(m.styles.Dimmed.Render(" ctrl+b: toggle")) + return b.String() +} + +func (s SidePanelSection) TitleText() string { + return s.Title +} + +func (s SidePanelSection) Description() string { + return "" +} + +func (s SidePanelSection) FilterValue() string { + return s.Title +} diff --git a/internal/tui/styles.go b/internal/tui/styles.go new file mode 100644 index 0000000..1fc9914 --- /dev/null +++ b/internal/tui/styles.go @@ -0,0 +1,488 @@ +package tui + +import ( + "os" + "strings" + + "charm.land/lipgloss/v2" +) + +// noColor detects NO_COLOR environment variable. +var noColor = os.Getenv("NO_COLOR") != "" + +// Nord Color Palette (https://www.nordtheme.com/) +// Nord Dark (Polar Night + Frost) +var ( + // Polar Night (dark theme background/text) + nord0 = "#2E3440" // base background + nord1 = "#3B4252" // lighter background + nord2 = "#434C5E" // selection/background elements + nord3 = "#4C566A" // comments/borders + + // Frost (dark theme foreground/text) + nord4 = "#D8DEE9" // primary text + nord5 = "#E5E9F0" // secondary text + nord6 = "#ECEFF4" // emphasized text + + // Aurora (dark theme accents) + nord7 = "#BF616A" // red (errors/warnings) + nord8 = "#D08770" // orange (warnings) + nord9 = "#EBCB8B" // yellow (warnings/highlights) + nord10 = "#A3BE8C" // green (success) + nord11 = "#B48EAD" // purple (special) + nord12 = "#88C0D0" // cyan (primary accent) + nord13 = "#81A1C1" // blue (secondary accent) + nord14 = "#5E81AC" // dark blue (links/details) +) + +// Nord Light (Aurora variant for light theme) +var ( + // Light background + nordLight0 = "#FFFFFF" // base background + nordLight1 = "#ECEFF4" // lighter background + nordLight2 = "#E5E9F0" // selection + nordLight3 = "#D8DEE9" // borders + + // Light text + nordLight4 = "#4C566A" // primary text + nordLight5 = "#3B4252" // secondary text + nordLight6 = "#2E3440" // emphasized text + + // Aurora accents (same as dark, work well on light) + nordLight7 = "#BF616A" // red + nordLight8 = "#D08770" // orange + nordLight9 = "#EBCB8B" // yellow + nordLight10 = "#A3BE8C" // green + nordLight11 = "#B48EAD" // purple + nordLight12 = "#88C0D0" // cyan + nordLight13 = "#81A1C1" // blue + nordLight14 = "#5E81AC" // dark blue +) + +// Styles holds all pre-built lipgloss styles. +type Styles struct { + // Header + HeaderTitle lipgloss.Style + HeaderInfo lipgloss.Style + HeaderRule lipgloss.Style + + // Messages + UserLabel lipgloss.Style + UserContent lipgloss.Style + AsstLabel lipgloss.Style + AsstContent lipgloss.Style + RoleRule lipgloss.Style + StreamCursor lipgloss.Style + + // Tools + ToolCallIcon lipgloss.Style + ToolCallText lipgloss.Style + ToolResultIcon lipgloss.Style + ToolResultText lipgloss.Style + ToolErrorIcon lipgloss.Style + ToolErrorText lipgloss.Style + ToolDoneIcon lipgloss.Style + ToolDoneText lipgloss.Style + ToolRunningText lipgloss.Style + ToolDetailText lipgloss.Style + + // Footer + Divider lipgloss.Style + StatusDot lipgloss.Style + StatusText lipgloss.Style + StatusCheck lipgloss.Style + StatusError lipgloss.Style + ApprovalPrompt lipgloss.Style + StreamHint lipgloss.Style + ErrorText lipgloss.Style + Dimmed lipgloss.Style + + // System messages + SystemText lipgloss.Style + WelcomeHint lipgloss.Style + + // Completion popup + CompletionBorder lipgloss.Style + CompletionSelected lipgloss.Style + + // Completion modal + CompletionFilter lipgloss.Style + CompletionCursor lipgloss.Style + CompletionCategory lipgloss.Style + CompletionFooter lipgloss.Style + CompletionSearching lipgloss.Style + + // Startup progress + StartupCheck lipgloss.Style + StartupFail lipgloss.Style + StartupLabel lipgloss.Style + StartupDetail lipgloss.Style + StartupSpin lipgloss.Style + + // Mode badges + ModeAsk lipgloss.Style + ModePlan lipgloss.Style + ModeBuild lipgloss.Style + + // Context percentage fuel gauge + ContextPctLow lipgloss.Style + ContextPctMid lipgloss.Style + ContextPctHigh lipgloss.Style + + // Tool type rendering + ToolBashCmd lipgloss.Style + + // Diff view + DiffAdded lipgloss.Style + DiffRemoved lipgloss.Style + DiffContext lipgloss.Style + DiffHeader lipgloss.Style + + // Thinking display + ThinkingHeader lipgloss.Style + ThinkingContent lipgloss.Style + ThinkingBorder lipgloss.Style + + // Shared overlay styles (used by help, model picker, sessions, plan form, completion) + OverlayTitle lipgloss.Style + OverlayBorder string + OverlayAccent lipgloss.Style + OverlayDim lipgloss.Style + + // Focus indicators + FocusIndicator lipgloss.Style +} + +// NewStyles creates a Styles set based on the background color. +func NewStyles(isDark bool) Styles { + if noColor { + return plainStyles() + } + return adaptiveStyles(isDark) +} + +func adaptiveStyles(isDark bool) Styles { + // Select Nord palette based on theme + var ( + colorDim string + colorMuted string + colorText string + colorAccent string + colorAccent2 string + colorError string + colorSuccess string + colorSpecial string + colorBorder string + ) + + if isDark { + // Nord Dark Theme (Polar Night + Frost + Aurora) + colorDim = nord3 // #4C566A - comments/borders + colorMuted = nord4 // #D8DEE9 - primary text (muted) + colorText = nord5 // #E5E9F0 - secondary text + colorAccent = nord12 // #88C0D0 - cyan (primary accent) + colorAccent2 = nord13 // #81A1C1 - blue (secondary accent) + colorError = nord7 // #BF616A - red + colorSuccess = nord10 // #A3BE8C - green + colorSpecial = nord11 // #B48EAD - purple + colorBorder = nord3 + } else { + // Nord Light Theme (Aurora) + colorDim = nordLight3 // #D8DEE9 - borders + colorMuted = nordLight4 // #4C566A - primary text (muted) + colorText = nordLight5 // #3B4252 - secondary text + colorAccent = nordLight12 // #88C0D0 - cyan + colorAccent2 = nordLight13 // #81A1C1 - blue + colorError = nordLight7 // #BF616A - red + colorSuccess = nordLight10 // #A3BE8C - green + colorSpecial = nordLight11 // #B48EAD - purple + colorBorder = nordLight3 + } + + // Helper for theme-specific colors + nordColor := func(dark, light string) string { + if isDark { + return dark + } + return light + } + + return Styles{ + HeaderTitle: lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color(colorAccent)). + PaddingLeft(1), + HeaderInfo: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)). + PaddingRight(1), + HeaderRule: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + + UserLabel: lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color(colorAccent2)). + PaddingLeft(2), + UserContent: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorText)). + PaddingLeft(2), + AsstLabel: lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color(colorSuccess)). + PaddingLeft(2), + AsstContent: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorText)). + PaddingLeft(4), + RoleRule: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + StreamCursor: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorAccent)). + Bold(true), + + ToolCallIcon: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorSpecial)). + PaddingLeft(4), + ToolCallText: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorSpecial)), + ToolResultIcon: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)). + PaddingLeft(4), + ToolResultText: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + ToolErrorIcon: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorError)). + PaddingLeft(4), + ToolErrorText: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorError)), + ToolDoneIcon: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorSuccess)). + PaddingLeft(4), + ToolDoneText: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + ToolRunningText: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorAccent)), + ToolDetailText: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorMuted)), + + Divider: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + StatusDot: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorAccent)). + PaddingLeft(1), + StatusText: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + StatusCheck: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorSuccess)). + PaddingLeft(1), + StatusError: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorError)). + PaddingLeft(1), + ApprovalPrompt: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorAccent)). + Bold(true), + StreamHint: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)). + Italic(true), + ErrorText: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorError)). + Bold(true). + PaddingLeft(2), + Dimmed: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + + SystemText: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorText)). + Italic(true). + PaddingLeft(2), + WelcomeHint: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorAccent2)). + Bold(true), + + CompletionBorder: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + CompletionSelected: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorAccent)). + Bold(true), + + CompletionFilter: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorText)), + CompletionCursor: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorAccent)). + Bold(true), + CompletionCategory: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + CompletionFooter: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)). + Italic(true), + CompletionSearching: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorSpecial)). + Italic(true), + + StartupCheck: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorSuccess)), + StartupFail: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorError)), + StartupLabel: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorText)), + StartupDetail: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + StartupSpin: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorAccent)), + + ModeAsk: lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color(colorAccent2)), + ModePlan: lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color(nordColor(nord9, nordLight9))), // yellow + ModeBuild: lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color(colorSuccess)), + + ContextPctLow: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorSuccess)), + ContextPctMid: lipgloss.NewStyle(). + Foreground(lipgloss.Color(nordColor(nord9, nordLight9))), + ContextPctHigh: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorError)), + + ToolBashCmd: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)). + Italic(true), + + DiffAdded: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorSuccess)). + PaddingLeft(6), + DiffRemoved: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorError)). + PaddingLeft(6), + DiffContext: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)). + PaddingLeft(6), + DiffHeader: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorAccent)). + PaddingLeft(6), + + ThinkingHeader: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorSpecial)). + Italic(true), + ThinkingContent: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)). + PaddingLeft(4), + ThinkingBorder: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + + OverlayTitle: lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color(colorAccent)), + OverlayBorder: colorBorder, + OverlayAccent: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorAccent2)). + Bold(true), + OverlayDim: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorDim)), + + FocusIndicator: lipgloss.NewStyle(). + Foreground(lipgloss.Color(colorAccent)). + Bold(true), + } +} + +func plainStyles() Styles { + p := lipgloss.NewStyle() + b := lipgloss.NewStyle().Bold(true) + pl2 := lipgloss.NewStyle().PaddingLeft(2) + pl4 := lipgloss.NewStyle().PaddingLeft(4) + return Styles{ + HeaderTitle: b.PaddingLeft(1), + HeaderInfo: p.PaddingRight(1), + HeaderRule: p, + + UserLabel: b.PaddingLeft(2), + UserContent: pl2, + AsstLabel: b.PaddingLeft(2), + AsstContent: pl2, + RoleRule: p, + StreamCursor: b, + + ToolCallIcon: pl4, + ToolCallText: p, + ToolResultIcon: pl4, + ToolResultText: p, + ToolErrorIcon: pl4, + ToolErrorText: b, + ToolDoneIcon: pl4, + ToolDoneText: p, + ToolRunningText: p, + ToolDetailText: p, + + Divider: p, + StatusDot: p.PaddingLeft(1), + StatusText: p, + StatusCheck: p.PaddingLeft(1), + StatusError: p.PaddingLeft(1), + ApprovalPrompt: b, + StreamHint: p.Italic(true), + ErrorText: b.PaddingLeft(2), + Dimmed: p, + + SystemText: p.PaddingLeft(2).Italic(true), + WelcomeHint: b, + + CompletionBorder: p, + CompletionSelected: b, + + CompletionFilter: p, + CompletionCursor: b, + CompletionCategory: p, + CompletionFooter: p.Italic(true), + CompletionSearching: p.Italic(true), + + StartupCheck: p, + StartupFail: b, + StartupLabel: p, + StartupDetail: p, + StartupSpin: p, + + ModeAsk: b, + ModePlan: b, + ModeBuild: b, + + ContextPctLow: p, + ContextPctMid: p, + ContextPctHigh: p, + + ToolBashCmd: p.Italic(true), + + DiffAdded: pl4, + DiffRemoved: pl4, + DiffContext: pl4, + DiffHeader: pl4, + + ThinkingHeader: p.Italic(true), + ThinkingContent: pl4, + ThinkingBorder: p, + + OverlayTitle: b, + OverlayBorder: "", + OverlayAccent: b, + OverlayDim: p, + + FocusIndicator: b, + } +} + +// rule generates a horizontal line of the given width using a thin character. +func rule(width int) string { + if width < 1 { + return "" + } + return strings.Repeat("─", width) +} + +// thickRule generates a horizontal line using a thick character. +func thickRule(width int) string { + if width < 1 { + return "" + } + return strings.Repeat("━", width) +} diff --git a/internal/tui/table.go b/internal/tui/table.go new file mode 100644 index 0000000..96a5958 --- /dev/null +++ b/internal/tui/table.go @@ -0,0 +1,164 @@ +package tui + +import ( + "strings" + + "charm.land/bubbles/v2/table" + "charm.land/lipgloss/v2" +) + +// TableHelper provides utilities for rendering structured data as tables. +type TableHelper struct { + isDark bool + styles TableStyles +} + +// TableStyles holds styling for tables. +type TableStyles struct { + Header lipgloss.Style + Row lipgloss.Style + RowAlt lipgloss.Style + Selected lipgloss.Style + Border lipgloss.Style + Focused lipgloss.Style +} + +// DefaultTableStyles returns default styles. +func DefaultTableStyles(isDark bool) TableStyles { + if isDark { + return TableStyles{ + Header: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#88c0d0")), + Row: lipgloss.NewStyle().Foreground(lipgloss.Color("#d8dee9")), + RowAlt: lipgloss.NewStyle().Foreground(lipgloss.Color("#d8dee9")), + Selected: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#88c0d0")), + Border: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Focused: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#81a1c1")), + } + } + return TableStyles{ + Header: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#4f8f8f")), + Row: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + RowAlt: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Selected: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#4f8f8f")), + Border: lipgloss.NewStyle().Foreground(lipgloss.Color("#9ca0a8")), + Focused: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#5e81ac")), + } +} + +// NewTableHelper creates a new table helper. +func NewTableHelper(isDark bool) *TableHelper { + return &TableHelper{ + isDark: isDark, + styles: DefaultTableStyles(isDark), + } +} + +// SetDark updates theme. +func (th *TableHelper) SetDark(isDark bool) { + th.isDark = isDark + th.styles = DefaultTableStyles(isDark) +} + +// ParseMarkdownTable attempts to extract a table from markdown text. +// Returns nil if no valid table found. +func (th *TableHelper) ParseMarkdownTable(text string) [][]string { + lines := strings.Split(text, "\n") + var rows [][]string + inTable := false + + for _, line := range lines { + line = strings.TrimSpace(line) + // Check for table start (contains |) + if strings.Contains(line, "|") { + // Skip separator line (contains only -, |, :) + if strings.Contains(line, "---") { + inTable = true + continue + } + // Parse row + row := th.parseTableRow(line) + if len(row) > 0 { + rows = append(rows, row) + } + } else if inTable && len(rows) > 0 { + // End of table + break + } + } + + if len(rows) < 2 { + return nil // Need at least header + 1 row + } + return rows +} + +// parseTableRow parses a single table row. +func (th *TableHelper) parseTableRow(line string) []string { + // Remove leading/trailing | + line = strings.Trim(line, "|") + parts := strings.Split(line, "|") + + var row []string + for _, part := range parts { + cell := strings.TrimSpace(part) + if cell != "" || len(row) > 0 { + row = append(row, cell) + } + } + return row +} + +// RenderTable creates a Bubble Tea table from parsed data. +func (th *TableHelper) RenderTable(rows [][]string, width int) string { + if len(rows) < 2 { + return "" + } + + headers := rows[0] + cols := make([]table.Column, len(headers)) + for i, h := range headers { + w := len(h) + // Calculate max width for this column + for _, row := range rows[1:] { + if i < len(row) && len(row[i]) > w { + w = len(row[i]) + } + } + // Distribute remaining width + if w < 10 { + w = 10 + } + cols[i] = table.Column{Width: w} + } + + t := table.New( + table.WithColumns(cols), + table.WithRows(parseRows(rows[1:])), + table.WithFocused(true), + table.WithHeight(len(rows)-1), + ) + + return t.View() +} + +// parseRows converts string rows to table.Row type. +func parseRows(rows [][]string) []table.Row { + result := make([]table.Row, len(rows)) + for i, row := range rows { + result[i] = table.Row(row) + } + return result +} + +// DetectJSONArray attempts to parse and render JSON arrays as tables. +func (th *TableHelper) DetectJSONArray(text string) (string, bool) { + // Simple JSON array detection - looks for [ at start and ] at end + text = strings.TrimSpace(text) + if !strings.HasPrefix(text, "[") || !strings.HasSuffix(text, "]") { + return "", false + } + + // For now, return empty - full JSON parsing would require the json package + // This is a placeholder for future enhancement + return "", false +} diff --git a/internal/tui/thinking.go b/internal/tui/thinking.go new file mode 100644 index 0000000..654951f --- /dev/null +++ b/internal/tui/thinking.go @@ -0,0 +1,124 @@ +package tui + +import ( + "fmt" + "strings" + + "charm.land/lipgloss/v2" +) + +// processStreamChunk processes a streaming chunk, extracting ... tags. +// It handles tag boundaries that may be split across chunks. +func processStreamChunk(chunk string, inThinking bool, searchBuf string) (mainText, thinkText string, outInThinking bool, outSearchBuf string) { + combined := searchBuf + chunk + outInThinking = inThinking + + var mainBuf, thinkBuf strings.Builder + + for len(combined) > 0 { + if outInThinking { + idx := strings.Index(combined, "") + if idx >= 0 { + thinkBuf.WriteString(combined[:idx]) + combined = combined[idx+len(""):] + outInThinking = false + continue + } + partial := hasPartialTagSuffix(combined, "") + if partial > 0 { + thinkBuf.WriteString(combined[:len(combined)-partial]) + outSearchBuf = combined[len(combined)-partial:] + return mainBuf.String(), thinkBuf.String(), outInThinking, outSearchBuf + } + thinkBuf.WriteString(combined) + combined = "" + } else { + idx := strings.Index(combined, "") + if idx >= 0 { + mainBuf.WriteString(combined[:idx]) + combined = combined[idx+len(""):] + outInThinking = true + continue + } + partial := hasPartialTagSuffix(combined, "") + if partial > 0 { + mainBuf.WriteString(combined[:len(combined)-partial]) + outSearchBuf = combined[len(combined)-partial:] + return mainBuf.String(), thinkBuf.String(), outInThinking, outSearchBuf + } + mainBuf.WriteString(combined) + combined = "" + } + } + + return mainBuf.String(), thinkBuf.String(), outInThinking, outSearchBuf +} + +// hasPartialTagSuffix returns the length of the longest suffix of s +// that is a proper prefix of tag (not the full tag). +func hasPartialTagSuffix(s, tag string) int { + maxCheck := len(tag) - 1 + if maxCheck > len(s) { + maxCheck = len(s) + } + for i := maxCheck; i > 0; i-- { + if strings.HasSuffix(s, tag[:i]) { + return i + } + } + return 0 +} + +// renderThinkingBox renders a collapsible thinking content box. +func (m *Model) renderThinkingBox(content string, collapsed bool) string { + if content == "" { + return "" + } + + lines := strings.Split(strings.TrimRight(content, "\n"), "\n") + + var b strings.Builder + + if collapsed { + hidden := len(lines) - 3 + if hidden < 0 { + hidden = 0 + } + header := fmt.Sprintf("▸ thinking (%d lines)", len(lines)) + if hidden > 0 { + header += fmt.Sprintf(" — %d hidden, ctrl+t to expand", hidden) + } + b.WriteString(m.styles.ThinkingHeader.Render(header)) + b.WriteString("\n") + + start := len(lines) - 3 + if start < 0 { + start = 0 + } + for _, line := range lines[start:] { + b.WriteString(m.styles.ThinkingContent.Render(line)) + b.WriteString("\n") + } + } else { + header := fmt.Sprintf("▾ thinking (%d lines) — ctrl+t to collapse", len(lines)) + b.WriteString(m.styles.ThinkingHeader.Render(header)) + b.WriteString("\n") + for _, line := range lines { + b.WriteString(m.styles.ThinkingContent.Render(line)) + b.WriteString("\n") + } + } + + boxWidth := m.width - 8 + if boxWidth < 20 { + boxWidth = 20 + } + + box := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(m.styles.OverlayBorder)). + Padding(0, 2). + Width(boxWidth) + + return box.Render(strings.TrimRight(b.String(), "\n")) +} diff --git a/internal/tui/thinking_test.go b/internal/tui/thinking_test.go new file mode 100644 index 0000000..e424c0a --- /dev/null +++ b/internal/tui/thinking_test.go @@ -0,0 +1,125 @@ +package tui + +import "testing" + +func TestProcessStreamChunk_PlainText(t *testing.T) { + main, think, inThinking, buf := processStreamChunk("hello world", false, "") + if main != "hello world" { + t.Errorf("main text = %q, want %q", main, "hello world") + } + if think != "" { + t.Errorf("think text = %q, want empty", think) + } + if inThinking { + t.Error("should not be in thinking mode") + } + if buf != "" { + t.Errorf("search buf = %q, want empty", buf) + } +} + +func TestProcessStreamChunk_OpenTag(t *testing.T) { + main, think, inThinking, _ := processStreamChunk("reasoning here", false, "") + if main != "" { + t.Errorf("main text = %q, want empty", main) + } + if think != "reasoning here" { + t.Errorf("think text = %q, want %q", think, "reasoning here") + } + if !inThinking { + t.Error("should be in thinking mode") + } +} + +func TestProcessStreamChunk_CloseTag(t *testing.T) { + main, think, inThinking, _ := processStreamChunk("end of thoughtvisible text", true, "") + if think != "end of thought" { + t.Errorf("think text = %q, want %q", think, "end of thought") + } + if main != "visible text" { + t.Errorf("main text = %q, want %q", main, "visible text") + } + if inThinking { + t.Error("should not be in thinking mode after close tag") + } +} + +func TestProcessStreamChunk_FullCycle(t *testing.T) { + main, think, inThinking, _ := processStreamChunk("thoughtresponse", false, "") + if think != "thought" { + t.Errorf("think text = %q, want %q", think, "thought") + } + if main != "response" { + t.Errorf("main text = %q, want %q", main, "response") + } + if inThinking { + t.Error("should not be in thinking mode") + } +} + +func TestProcessStreamChunk_SplitAcrossChunks(t *testing.T) { + // First chunk ends with partial tag "reasoning", false, buf1) + if main2 != "" { + t.Errorf("chunk2 main = %q, want empty", main2) + } + if think2 != "reasoning" { + t.Errorf("chunk2 think = %q, want %q", think2, "reasoning") + } + if !inThinking2 { + t.Error("chunk2 should be in thinking mode") + } +} + +func TestProcessStreamChunk_NestedTags(t *testing.T) { + // Nested should be treated as text inside thinking. + main, think, _, _ := processStreamChunk("outerinnerafter", false, "") + // The inner should be literal text inside thinking. + // When we encounter the first , thinking ends. + if main != "after" { + t.Errorf("main = %q, want %q", main, "after") + } + // The think content should include "outerinner" + if think != "outerinner" { + t.Errorf("think = %q, want %q", think, "outerinner") + } +} + +func TestHasPartialTagSuffix(t *testing.T) { + tests := []struct { + s, tag string + want int + }{ + {"hello<", "", 1}, + {"hello", 2}, + {"hello", 3}, + {"hello", 4}, + {"hello", 5}, + {"hello", 6}, + {"hello", "", 0}, // full match, not partial + {"hello", "", 0}, + {"", 5}, + {"<", "", 1}, + } + for _, tt := range tests { + got := hasPartialTagSuffix(tt.s, tt.tag) + if got != tt.want { + t.Errorf("hasPartialTagSuffix(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want) + } + } +} diff --git a/internal/tui/timestamp.go b/internal/tui/timestamp.go new file mode 100644 index 0000000..7fdbd1d --- /dev/null +++ b/internal/tui/timestamp.go @@ -0,0 +1,171 @@ +package tui + +import ( + "time" + + "charm.land/lipgloss/v2" +) + +// TimestampConfig holds configuration for message timestamps. +type TimestampConfig struct { + Enabled bool + Format string // "time", "relative", "both" + Position string // "left", "right" + MaxAge time.Duration // For relative timestamps +} + +// DefaultTimestampConfig returns default configuration. +func DefaultTimestampConfig() TimestampConfig { + return TimestampConfig{ + Enabled: false, + Format: "time", + Position: "left", + MaxAge: 24 * time.Hour, + } +} + +// TimestampStyles holds styling for timestamps. +type TimestampStyles struct { + Time lipgloss.Style + Relative lipgloss.Style + Divider lipgloss.Style +} + +// DefaultTimestampStyles returns default styles. +func DefaultTimestampStyles(isDark bool) TimestampStyles { + if isDark { + return TimestampStyles{ + Time: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Relative: lipgloss.NewStyle().Foreground(lipgloss.Color("#5e81ac")), + Divider: lipgloss.NewStyle().Foreground(lipgloss.Color("#3b4252")), + } + } + return TimestampStyles{ + Time: lipgloss.NewStyle().Foreground(lipgloss.Color("#9ca0a8")), + Relative: lipgloss.NewStyle().Foreground(lipgloss.Color("#5e81ac")), + Divider: lipgloss.NewStyle().Foreground(lipgloss.Color("#d8dee9")), + } +} + +// TimestampHelper provides utilities for rendering timestamps. +type TimestampHelper struct { + config TimestampConfig + styles TimestampStyles + nowFunc func() time.Time +} + +// NewTimestampHelper creates a new timestamp helper. +func NewTimestampHelper(config TimestampConfig, isDark bool) *TimestampHelper { + return &TimestampHelper{ + config: config, + styles: DefaultTimestampStyles(isDark), + nowFunc: time.Now, + } +} + +// SetDark updates theme. +func (th *TimestampHelper) SetDark(isDark bool) { + th.styles = DefaultTimestampStyles(isDark) +} + +// SetConfig updates the timestamp configuration. +func (th *TimestampHelper) SetConfig(config TimestampConfig) { + th.config = config +} + +// FormatTime formats a timestamp based on config. +func (th *TimestampHelper) FormatTime(t time.Time) string { + if !th.config.Enabled { + return "" + } + + switch th.config.Format { + case "time": + return t.Format("15:04") + case "relative": + return th.relativeTime(t) + case "both": + return t.Format("15:04") + " " + th.relativeTime(t) + default: + return t.Format("15:04") + } +} + +// relativeTime returns a human-readable relative time string. +func (th *TimestampHelper) relativeTime(t time.Time) string { + now := th.nowFunc() + diff := now.Sub(t) + + if diff < time.Minute { + return "just now" + } + if diff < time.Hour { + mins := int(diff.Minutes()) + if mins == 1 { + return "1m ago" + } + return formatInt(mins) + "m ago" + } + if diff < 24*time.Hour { + hours := int(diff.Hours()) + if hours == 1 { + return "1h ago" + } + return formatInt(hours) + "h ago" + } + if diff < 7*24*time.Hour { + days := int(diff.Hours() / 24) + if days == 1 { + return "1d ago" + } + return formatInt(days) + "d ago" + } + + // Older dates - show date + return t.Format("Jan 2") +} + +// formatInt formats an integer without allocation. +func formatInt(n int) string { + if n < 10 { + return string(rune('0' + n)) + } + // Simple implementation for common cases + switch n { + case 10: + return "10" + case 11: + return "11" + case 12: + return "12" + case 13: + return "13" + case 14: + return "14" + case 15: + return "15" + case 16: + return "16" + case 17: + return "17" + case 18: + return "18" + case 19: + return "19" + case 20: + return "20" + default: + // Fallback for larger numbers + if n < 100 { + tens := n / 10 + ones := n % 10 + return string(rune('0'+tens)) + string(rune('0'+ones)) + } + return string(rune('0' + n/100)) + } +} + +// MessageTime stores the timestamp for a chat message. +type MessageTime struct { + Time time.Time +} diff --git a/internal/tui/toast.go b/internal/tui/toast.go new file mode 100644 index 0000000..1769c86 --- /dev/null +++ b/internal/tui/toast.go @@ -0,0 +1,194 @@ +package tui + +import ( + "strings" + "time" + + "charm.land/lipgloss/v2" +) + +// ToastKind represents the type of toast notification. +type ToastKind int + +const ( + ToastKindInfo ToastKind = iota + ToastKindSuccess + ToastKindWarning + ToastKindError +) + +// Toast represents a transient notification message. +type Toast struct { + ID int + Kind ToastKind + Message string + CreatedAt time.Time + Duration time.Duration + ExpiresAt time.Time +} + +// ToastManager manages the lifecycle of toast notifications. +type ToastManager struct { + toasts []Toast + nextID int + styles ToastStyles + maxToasts int +} + +// ToastStyles holds styles for toast rendering. +type ToastStyles struct { + Info lipgloss.Style + Success lipgloss.Style + Warning lipgloss.Style + Error lipgloss.Style + Border lipgloss.Style +} + +// NewToastManager creates a new toast manager. +func NewToastManager() *ToastManager { + return &ToastManager{ + toasts: make([]Toast, 0), + nextID: 1, + maxToasts: 3, + } +} + +// SetStyles applies styles to the manager. +func (tm *ToastManager) SetStyles(styles ToastStyles) { + tm.styles = styles +} + +// Add creates a new toast with the given message and duration. +func (tm *ToastManager) Add(kind ToastKind, message string, duration time.Duration) int { + id := tm.nextID + tm.nextID++ + + toast := Toast{ + ID: id, + Kind: kind, + Message: message, + CreatedAt: time.Now(), + Duration: duration, + ExpiresAt: time.Now().Add(duration), + } + + tm.toasts = append(tm.toasts, toast) + + // Limit number of toasts + if len(tm.toasts) > tm.maxToasts { + tm.toasts = tm.toasts[1:] + } + + return id +} + +// Info adds an info toast. +func (tm *ToastManager) Info(message string) int { + return tm.Add(ToastKindInfo, message, 3*time.Second) +} + +// Success adds a success toast. +func (tm *ToastManager) Success(message string) int { + return tm.Add(ToastKindSuccess, message, 3*time.Second) +} + +// Warning adds a warning toast. +func (tm *ToastManager) Warning(message string) int { + return tm.Add(ToastKindWarning, message, 5*time.Second) +} + +// Error adds an error toast. +func (tm *ToastManager) Error(message string) int { + return tm.Add(ToastKindError, message, 5*time.Second) +} + +// AddToast adds a toast with default duration based on kind. +func (tm *ToastManager) AddToast(toast Toast) int { + duration := 3 * time.Second + if toast.Kind == ToastKindWarning || toast.Kind == ToastKindError { + duration = 5 * time.Second + } + return tm.Add(toast.Kind, toast.Message, duration) +} + +// Update removes expired toasts. +func (tm *ToastManager) Update() { + now := time.Now() + var active []Toast + for _, t := range tm.toasts { + if now.Before(t.ExpiresAt) { + active = append(active, t) + } + } + tm.toasts = active +} + +// HasToasts returns true if there are active toasts. +func (tm *ToastManager) HasToasts() bool { + return len(tm.toasts) > 0 +} + +// Render renders all active toasts as a single string. +func (tm *ToastManager) Render(width int) string { + if len(tm.toasts) == 0 { + return "" + } + + var b strings.Builder + for _, toast := range tm.toasts { + b.WriteString(tm.renderToast(toast, width)) + b.WriteString("\n") + } + + return strings.TrimRight(b.String(), "\n") +} + +// renderToast renders a single toast. +func (tm *ToastManager) renderToast(toast Toast, width int) string { + icon := "○" + style := tm.styles.Info + + switch toast.Kind { + case ToastKindSuccess: + icon = "✓" + style = tm.styles.Success + case ToastKindWarning: + icon = "⚠" + style = tm.styles.Warning + case ToastKindError: + icon = "✗" + style = tm.styles.Error + } + + content := icon + " " + toast.Message + + // Apply style and truncate if needed + maxW := width - 4 + if maxW < 20 { + maxW = 20 + } + + rendered := style.Render(content) + if lipgloss.Width(rendered) > maxW { + rendered = style.Render(truncate(toast.Message, maxW-3)) + } + + return rendered +} + +// DefaultToastStyles returns default styles for toasts based on theme. +func DefaultToastStyles(isDark bool) ToastStyles { + ld := lipgloss.LightDark(isDark) + + colorInfo := ld(lipgloss.Color("#88c0d0"), lipgloss.Color("#5e81ac")) + colorSuccess := ld(lipgloss.Color("#a3be8c"), lipgloss.Color("#8fbc8f")) + colorWarning := ld(lipgloss.Color("#ebcb8b"), lipgloss.Color("#d08770")) + colorError := ld(lipgloss.Color("#bf616a"), lipgloss.Color("#bf616a")) + + return ToastStyles{ + Info: lipgloss.NewStyle().Foreground(colorInfo), + Success: lipgloss.NewStyle().Foreground(colorSuccess), + Warning: lipgloss.NewStyle().Foreground(colorWarning), + Error: lipgloss.NewStyle().Foreground(colorError), + } +} diff --git a/internal/tui/tool_expansion_test.go b/internal/tui/tool_expansion_test.go new file mode 100644 index 0000000..8a34d86 --- /dev/null +++ b/internal/tui/tool_expansion_test.go @@ -0,0 +1,100 @@ +package tui + +import "testing" + +func TestPerEntryCollapse_Default(t *testing.T) { + m := newTestModel(t) + m.toolsCollapsed = true + + // Simulate tool call start — new entry should inherit collapse state. + updated, _ := m.Update(ToolCallStartMsg{ + Name: "read_file", + Args: map[string]any{"path": "test.go"}, + }) + m = updated.(*Model) + + if len(m.toolEntries) != 1 { + t.Fatalf("expected 1 tool entry, got %d", len(m.toolEntries)) + } + if !m.toolEntries[0].Collapsed { + t.Error("new tool entry should inherit toolsCollapsed=true") + } +} + +func TestPerEntryCollapse_InheritsFalse(t *testing.T) { + m := newTestModel(t) + m.toolsCollapsed = false + + updated, _ := m.Update(ToolCallStartMsg{ + Name: "bash", + Args: map[string]any{"command": "ls"}, + }) + m = updated.(*Model) + + if m.toolEntries[0].Collapsed { + t.Error("new tool entry should inherit toolsCollapsed=false") + } +} + +func TestBatchToggleAll(t *testing.T) { + m := newTestModel(t) + m.toolsCollapsed = true + + // Add multiple tool entries. + m.toolEntries = []ToolEntry{ + {Name: "a", Status: ToolStatusDone, Collapsed: true}, + {Name: "b", Status: ToolStatusDone, Collapsed: true}, + {Name: "c", Status: ToolStatusDone, Collapsed: false}, + } + + // Toggle all (t key) should flip toolsCollapsed and apply to all. + m.toolsCollapsed = !m.toolsCollapsed // false now + for i := range m.toolEntries { + m.toolEntries[i].Collapsed = m.toolsCollapsed + } + + for i, te := range m.toolEntries { + if te.Collapsed { + t.Errorf("entry[%d] should be expanded after batch toggle", i) + } + } +} + +func TestToggleLastTool(t *testing.T) { + m := newTestModel(t) + + m.toolEntries = []ToolEntry{ + {Name: "a", Status: ToolStatusDone, Collapsed: true}, + {Name: "b", Status: ToolStatusDone, Collapsed: true}, + } + + // Toggle last only. + last := len(m.toolEntries) - 1 + m.toolEntries[last].Collapsed = !m.toolEntries[last].Collapsed + + if m.toolEntries[0].Collapsed != true { + t.Error("first entry should remain collapsed") + } + if m.toolEntries[1].Collapsed != false { + t.Error("last entry should be expanded") + } +} + +func TestFileWriteSnapshotBefore(t *testing.T) { + m := newTestModel(t) + + // Tool name containing "write" triggers snapshot. + updated, _ := m.Update(ToolCallStartMsg{ + Name: "file_write", + Args: map[string]any{"path": "/nonexistent/path"}, + }) + m = updated.(*Model) + + if len(m.toolEntries) != 1 { + t.Fatalf("expected 1 tool entry, got %d", len(m.toolEntries)) + } + // BeforeContent should be empty since file doesn't exist, but it should not panic. + if m.toolEntries[0].BeforeContent != "" { + t.Error("nonexistent file should give empty before content") + } +} diff --git a/internal/tui/toolcard.go b/internal/tui/toolcard.go new file mode 100644 index 0000000..837388d --- /dev/null +++ b/internal/tui/toolcard.go @@ -0,0 +1,346 @@ +package tui + +import ( + "fmt" + "strings" + "time" + + "charm.land/bubbles/v2/spinner" + "charm.land/lipgloss/v2" +) + +// ToolCardKind represents the type of tool operation. +type ToolCardKind int + +const ( + ToolCardFile ToolCardKind = iota + ToolCardBash + ToolCardSearch + ToolCardGit + ToolCardGeneric +) + +// ToolCardState represents the execution state. +type ToolCardState int + +const ( + ToolCardRunning ToolCardState = iota + ToolCardSuccess + ToolCardError +) + +// ToolCard is a fancy tool execution display component. +type ToolCard struct { + Name string + Kind ToolCardKind + State ToolCardState + Args string + Result string + StartTime time.Time + Duration time.Duration + Expanded bool + Spinner spinner.Model + ElapsedTimer *time.Timer + Elapsed time.Duration + IsDark bool + Styles ToolCardStyles +} + +// ToolCardStyles holds styles for the tool card. +type ToolCardStyles struct { + BorderRunning lipgloss.Style + BorderSuccess lipgloss.Style + BorderError lipgloss.Style + TitleRunning lipgloss.Style + TitleSuccess lipgloss.Style + TitleError lipgloss.Style + Args lipgloss.Style + Result lipgloss.Style + Error lipgloss.Style + Dimmed lipgloss.Style + Elapsed lipgloss.Style +} + +// NewToolCardStyles creates styles based on theme. +func NewToolCardStyles(isDark bool) ToolCardStyles { + if isDark { + return ToolCardStyles{ + BorderRunning: lipgloss.NewStyle().Foreground(lipgloss.Color("#81a1c1")), + BorderSuccess: lipgloss.NewStyle().Foreground(lipgloss.Color("#a3be8c")), + BorderError: lipgloss.NewStyle().Foreground(lipgloss.Color("#bf616a")), + TitleRunning: lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0")).Bold(true), + TitleSuccess: lipgloss.NewStyle().Foreground(lipgloss.Color("#a3be8c")).Bold(true), + TitleError: lipgloss.NewStyle().Foreground(lipgloss.Color("#bf616a")).Bold(true), + Args: lipgloss.NewStyle().Foreground(lipgloss.Color("#d8dee9")), + Result: lipgloss.NewStyle().Foreground(lipgloss.Color("#d8dee9")), + Error: lipgloss.NewStyle().Foreground(lipgloss.Color("#bf616a")), + Dimmed: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Elapsed: lipgloss.NewStyle().Foreground(lipgloss.Color("#81a1c1")), + } + } + return ToolCardStyles{ + BorderRunning: lipgloss.NewStyle().Foreground(lipgloss.Color("#5e81ac")), + BorderSuccess: lipgloss.NewStyle().Foreground(lipgloss.Color("#4f8f38")), + BorderError: lipgloss.NewStyle().Foreground(lipgloss.Color("#c94f4f")), + TitleRunning: lipgloss.NewStyle().Foreground(lipgloss.Color("#4f8f8f")).Bold(true), + TitleSuccess: lipgloss.NewStyle().Foreground(lipgloss.Color("#4f8f38")).Bold(true), + TitleError: lipgloss.NewStyle().Foreground(lipgloss.Color("#c94f4f")).Bold(true), + Args: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Result: lipgloss.NewStyle().Foreground(lipgloss.Color("#4c566a")), + Error: lipgloss.NewStyle().Foreground(lipgloss.Color("#c94f4f")), + Dimmed: lipgloss.NewStyle().Foreground(lipgloss.Color("#9ca0a8")), + Elapsed: lipgloss.NewStyle().Foreground(lipgloss.Color("#5e81ac")), + } +} + +// NewToolCard creates a new tool card. +func NewToolCard(name string, kind ToolCardKind, isDark bool) ToolCard { + s := spinner.New( + spinner.WithSpinner(spinner.MiniDot), + spinner.WithStyle(lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0"))), + ) + return ToolCard{ + Name: name, + Kind: kind, + State: ToolCardRunning, + Spinner: s, + IsDark: isDark, + Styles: NewToolCardStyles(isDark), + } +} + +// SetDark updates the theme. +func (c *ToolCard) SetDark(isDark bool) { + c.IsDark = isDark + c.Styles = NewToolCardStyles(isDark) +} + +// Tick advances the spinner animation. +func (c *ToolCard) Tick() { + c.Spinner.Tick() +} + +// UpdateElapsed updates the elapsed time counter. +func (c *ToolCard) UpdateElapsed() { + if c.State == ToolCardRunning { + c.Elapsed = time.Since(c.StartTime) + } +} + +// getIcon returns the appropriate icon for the tool kind and state. +func (c *ToolCard) getIcon() string { + switch c.Kind { + case ToolCardFile: + if c.State == ToolCardRunning { + return "📄" + } + if c.State == ToolCardSuccess { + return "✓" + } + return "✗" + case ToolCardBash: + if c.State == ToolCardRunning { + return "💻" + } + if c.State == ToolCardSuccess { + return "✓" + } + return "✗" + case ToolCardSearch: + if c.State == ToolCardRunning { + return "🔍" + } + if c.State == ToolCardSuccess { + return "✓" + } + return "✗" + case ToolCardGit: + if c.State == ToolCardRunning { + return "🌿" + } + if c.State == ToolCardSuccess { + return "✓" + } + return "✗" + default: + if c.State == ToolCardRunning { + return "◌" + } + if c.State == ToolCardSuccess { + return "✓" + } + return "✗" + } +} + +// getStatusText returns the status text based on state. +func (c *ToolCard) getStatusText() string { + switch c.State { + case ToolCardRunning: + return "running..." + case ToolCardSuccess: + return fmt.Sprintf("(%s)", formatDuration(c.Duration)) + case ToolCardError: + return fmt.Sprintf("error (%s)", formatDuration(c.Duration)) + default: + return "" + } +} + +// getBorderStyle returns the appropriate border style. +func (c *ToolCard) getBorderStyle() lipgloss.Style { + switch c.State { + case ToolCardRunning: + return c.Styles.BorderRunning + case ToolCardSuccess: + return c.Styles.BorderSuccess + case ToolCardError: + return c.Styles.BorderError + default: + return c.Styles.BorderRunning + } +} + +// getTitleStyle returns the appropriate title style. +func (c *ToolCard) getTitleStyle() lipgloss.Style { + switch c.State { + case ToolCardRunning: + return c.Styles.TitleRunning + case ToolCardSuccess: + return c.Styles.TitleSuccess + case ToolCardError: + return c.Styles.TitleError + default: + return c.Styles.TitleRunning + } +} + +// View renders the tool card. +func (c *ToolCard) View(width int) string { + // Update elapsed time for running tools + c.UpdateElapsed() + + // Build title line + icon := c.getIcon() + statusText := c.getStatusText() + + var titleParts []string + titleParts = append(titleParts, icon) + titleParts = append(titleParts, c.Name) + + if c.State == ToolCardRunning { + titleParts = append(titleParts, c.Spinner.View()) + titleParts = append(titleParts, statusText) + // Show elapsed time for running tools + elapsedStr := fmt.Sprintf("%.1fs", c.Elapsed.Seconds()) + titleParts = append(titleParts, c.Styles.Elapsed.Render(elapsedStr)) + } else { + titleParts = append(titleParts, statusText) + } + + title := strings.Join(titleParts, " ") + titleStyle := c.getTitleStyle() + + // Create bordered box + content := titleStyle.Render(title) + + if c.Expanded && c.State != ToolCardRunning { + // Show args and result when expanded + var details strings.Builder + + if c.Args != "" { + args := truncate(c.Args, 80) + details.WriteString(c.Styles.Args.Render(" args: " + args)) + details.WriteString("\n") + } + + if c.Result != "" { + if c.State == ToolCardError { + details.WriteString(c.Styles.Error.Render(" " + truncate(c.Result, 200))) + } else { + details.WriteString(c.Styles.Result.Render(" " + truncate(c.Result, 200))) + } + details.WriteString("\n") + } + + content = lipgloss.JoinVertical(lipgloss.Left, content, details.String()) + } + + // Apply border + borderStyle := c.getBorderStyle() + box := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(borderStyle.GetForeground()). + Padding(0, 1) + + return box.Render(content) +} + +// ToolCardManager manages multiple tool cards with synchronized animations. +type ToolCardManager struct { + Cards []ToolCard + IsDark bool +} + +// NewToolCardManager creates a new manager. +func NewToolCardManager(isDark bool) ToolCardManager { + return ToolCardManager{ + Cards: []ToolCard{}, + IsDark: isDark, + } +} + +// AddCard adds a new tool card. +func (m *ToolCardManager) AddCard(name string, kind ToolCardKind, startTime time.Time) { + card := NewToolCard(name, kind, m.IsDark) + card.StartTime = startTime + m.Cards = append(m.Cards, card) +} + +// UpdateCard updates an existing card by name. +func (m *ToolCardManager) UpdateCard(name string, state ToolCardState, result string, duration time.Duration) { + for i := range m.Cards { + if m.Cards[i].Name == name && m.Cards[i].State == ToolCardRunning { + m.Cards[i].State = state + m.Cards[i].Result = result + m.Cards[i].Duration = duration + break + } + } +} + +// SetExpanded sets a card's expanded state. +func (m *ToolCardManager) SetExpanded(name string, expanded bool) { + for i := range m.Cards { + if m.Cards[i].Name == name { + m.Cards[i].Expanded = expanded + break + } + } +} + +// Tick advances all running card spinners. +func (m *ToolCardManager) Tick() { + for i := range m.Cards { + if m.Cards[i].State == ToolCardRunning { + m.Cards[i].Tick() + } + } +} + +// SetDark updates theme for all cards. +func (m *ToolCardManager) SetDark(isDark bool) { + m.IsDark = isDark + for i := range m.Cards { + m.Cards[i].SetDark(isDark) + } +} + +// View renders all cards. +func (m *ToolCardManager) View(width int) string { + var lines []string + for i := range m.Cards { + lines = append(lines, m.Cards[i].View(width)) + } + return strings.Join(lines, "\n") +} diff --git a/internal/tui/toolrender.go b/internal/tui/toolrender.go new file mode 100644 index 0000000..3bbc8c7 --- /dev/null +++ b/internal/tui/toolrender.go @@ -0,0 +1,240 @@ +package tui + +import ( + "fmt" + "regexp" + "strings" +) + +// ToolType represents the category of a tool for rendering. +type ToolType int + +const ( + ToolTypeDefault ToolType = iota + ToolTypeBash + ToolTypeFileRead + ToolTypeFileWrite + ToolTypeWeb + ToolTypeMemory +) + +// classifyTool returns the ToolType based on the tool name. +func classifyTool(name string) ToolType { + lower := strings.ToLower(name) + switch { + case strings.Contains(lower, "bash") || strings.Contains(lower, "exec") || strings.Contains(lower, "shell") || strings.Contains(lower, "command"): + return ToolTypeBash + case strings.Contains(lower, "read") || strings.Contains(lower, "view") || strings.Contains(lower, "cat"): + return ToolTypeFileRead + case strings.Contains(lower, "write") || strings.Contains(lower, "edit") || strings.Contains(lower, "create_file") || strings.Contains(lower, "patch"): + return ToolTypeFileWrite + case strings.Contains(lower, "web") || strings.Contains(lower, "fetch") || strings.Contains(lower, "http") || strings.Contains(lower, "curl") || strings.Contains(lower, "browse"): + return ToolTypeWeb + case strings.Contains(lower, "memory") || strings.Contains(lower, "remember") || strings.Contains(lower, "forget"): + return ToolTypeMemory + default: + return ToolTypeDefault + } +} + +// toolIcon returns a type-specific icon for the tool. +func toolIcon(tt ToolType, status ToolStatus) string { + if status == ToolStatusError { + return "✗" + } + if status == ToolStatusDone { + switch tt { + case ToolTypeBash: + return "$" + case ToolTypeFileRead: + return "◎" + case ToolTypeFileWrite: + return "✎" + case ToolTypeWeb: + return "◆" + case ToolTypeMemory: + return "◈" + default: + return "✓" + } + } + // Running + return "⚙" +} + +// toolSummary extracts a key argument for display based on tool type. +func toolSummary(tt ToolType, te ToolEntry) string { + if te.RawArgs == nil { + return "" + } + switch tt { + case ToolTypeBash: + if cmd, ok := te.RawArgs["command"].(string); ok { + if len(cmd) > 60 { + cmd = cmd[:57] + "..." + } + return cmd + } + case ToolTypeFileRead, ToolTypeFileWrite: + for _, key := range []string{"path", "file_path", "filename", "file"} { + if p, ok := te.RawArgs[key].(string); ok { + return p + } + } + case ToolTypeWeb: + for _, key := range []string{"url", "uri", "href"} { + if u, ok := te.RawArgs[key].(string); ok { + if len(u) > 60 { + u = u[:57] + "..." + } + return u + } + } + case ToolTypeMemory: + if k, ok := te.RawArgs["key"].(string); ok { + return k + } + } + return "" +} + +// codeBlockRegex matches markdown code blocks. +var codeBlockRegex = regexp.MustCompile(`(?s)~~~(\w*)\n(.*?)~~~|` + "```(\\w*)\\n(.*?)```") + +// detectCodeBlocks checks if the result contains markdown code blocks. +func detectCodeBlocks(text string) bool { + return strings.Contains(text, "```") || strings.Contains(text, "~~~") +} + +// extractCodeBlocks extracts code blocks from text and returns them with their language. +func extractCodeBlocks(text string) []struct { + Language string + Code string +} { + var blocks []struct { + Language string + Code string + } + + lines := strings.Split(text, "\n") + var inBlock bool + var currentLang string + var currentCode strings.Builder + + for _, line := range lines { + if strings.HasPrefix(line, "```") || strings.HasPrefix(line, "~~~") { + if !inBlock { + // Start of code block + inBlock = true + currentLang = strings.TrimPrefix(line, "```") + currentLang = strings.TrimPrefix(currentLang, "~~~") + currentLang = strings.TrimSpace(currentLang) + currentCode.Reset() + } else { + // End of code block + inBlock = false + blocks = append(blocks, struct { + Language string + Code string + }{ + Language: currentLang, + Code: strings.TrimRight(currentCode.String(), "\n"), + }) + } + } else if inBlock { + currentCode.WriteString(line) + currentCode.WriteString("\n") + } + } + + return blocks +} + +// formatToolResult formats a tool result for display with smart truncation. +// It preserves code blocks and adds expand/collapse hints. +func formatToolResult(result string, maxLines int, maxWidth int) string { + if result == "" { + return "(no output)" + } + + lines := strings.Split(result, "\n") + + // Detect if result contains code blocks + hasCodeBlocks := detectCodeBlocks(result) + + // Truncate by lines if too long + if len(lines) > maxLines { + var b strings.Builder + for i := 0; i < maxLines; i++ { + line := lines[i] + // Truncate long lines + if len(line) > maxWidth { + line = line[:maxWidth-3] + "..." + } + b.WriteString(line) + b.WriteString("\n") + } + remaining := len(lines) - maxLines + b.WriteString("... ") + if hasCodeBlocks { + b.WriteString("(code blocks truncated)") + } else { + b.WriteString(fmt.Sprintf("%d", remaining)) + b.WriteString(" more lines") + } + return b.String() + } + + // Truncate long lines + var b strings.Builder + for i, line := range lines { + if len(line) > maxWidth { + line = line[:maxWidth-3] + "..." + } + b.WriteString(line) + if i < len(lines)-1 { + b.WriteString("\n") + } + } + + return b.String() +} + +// isLikelyJSON checks if a string looks like JSON. +func isLikelyJSON(s string) bool { + s = strings.TrimSpace(s) + return strings.HasPrefix(s, "{") || strings.HasPrefix(s, "[") +} + +// isLikelyXML checks if a string looks like XML. +func isLikelyXML(s string) bool { + s = strings.TrimSpace(s) + return strings.HasPrefix(s, "<") +} + +// detectLanguage tries to detect the language of a code snippet. +func detectLanguage(code string) string { + if isLikelyJSON(code) { + return "json" + } + if isLikelyXML(code) { + return "xml" + } + // Check for common patterns + if strings.Contains(code, "func ") && strings.Contains(code, "{") { + return "go" + } + if strings.Contains(code, "import ") && strings.Contains(code, ";") { + return "java" + } + if strings.Contains(code, "def ") || strings.Contains(code, "import ") { + return "python" + } + if strings.Contains(code, "const ") || strings.Contains(code, "function") { + return "javascript" + } + if strings.Contains(code, " 60 { + t.Errorf("expected truncated to 60 chars, got %d", len(got)) + } + if got[len(got)-3:] != "..." { + t.Error("truncated command should end with ...") + } + }) + + t.Run("file_read_path", func(t *testing.T) { + te := ToolEntry{ + RawArgs: map[string]any{"file_path": "/home/user/test.go"}, + } + got := toolSummary(ToolTypeFileRead, te) + if got != "/home/user/test.go" { + t.Errorf("expected path, got %q", got) + } + }) + + t.Run("file_write_path", func(t *testing.T) { + te := ToolEntry{ + RawArgs: map[string]any{"path": "/tmp/output.txt"}, + } + got := toolSummary(ToolTypeFileWrite, te) + if got != "/tmp/output.txt" { + t.Errorf("expected path, got %q", got) + } + }) + + t.Run("web_url", func(t *testing.T) { + te := ToolEntry{ + RawArgs: map[string]any{"url": "https://example.com"}, + } + got := toolSummary(ToolTypeWeb, te) + if got != "https://example.com" { + t.Errorf("expected url, got %q", got) + } + }) + + t.Run("memory_key", func(t *testing.T) { + te := ToolEntry{ + RawArgs: map[string]any{"key": "user_pref"}, + } + got := toolSummary(ToolTypeMemory, te) + if got != "user_pref" { + t.Errorf("expected 'user_pref', got %q", got) + } + }) + + t.Run("nil_args", func(t *testing.T) { + te := ToolEntry{RawArgs: nil} + got := toolSummary(ToolTypeBash, te) + if got != "" { + t.Errorf("nil args should return empty, got %q", got) + } + }) + + t.Run("default_type_returns_empty", func(t *testing.T) { + te := ToolEntry{ + RawArgs: map[string]any{"foo": "bar"}, + } + got := toolSummary(ToolTypeDefault, te) + if got != "" { + t.Errorf("default type should return empty, got %q", got) + } + }) +} diff --git a/internal/tui/view.go b/internal/tui/view.go new file mode 100644 index 0000000..3c146c3 --- /dev/null +++ b/internal/tui/view.go @@ -0,0 +1,754 @@ +package tui + +import ( + "fmt" + "strings" + "time" + + "ai-agent/internal/agent" + + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" +) + +func (m *Model) View() tea.View { + if !m.ready { + return tea.NewView(" initializing...") + } + var content string + rightWidth := m.width - 1 + if m.sidePanel.IsVisible() { + rightWidth = m.width - m.sidePanel.width - 1 + } + var rightSide strings.Builder + rightSide.WriteString(m.viewport.View()) + rightSide.WriteString("\n") + rightSide.WriteString(m.styles.Divider.Render(rule(rightWidth))) + rightSide.WriteString("\n") + rightSide.WriteString(m.renderStatusLine()) + rightSide.WriteString("\n") + if m.state == StateIdle { + rightSide.WriteString(m.input.View()) + } else if m.state == StateWaiting { + rightSide.WriteString(m.styles.StreamHint.Render(" " + m.scramble.View() + " thinking... press Esc to cancel")) + } else { + rightSide.WriteString(m.styles.StreamHint.Render(" " + m.spin.View() + " streaming... press Esc to cancel")) + } + if m.sidePanel.IsVisible() { + panelView := m.sidePanel.View() + rightContent := rightSide.String() + panelW := m.sidePanel.width + rightW := rightWidth + leftStyle := lipgloss.NewStyle().Width(panelW).Height(m.height) + left := leftStyle.Render(panelView) + rightStyle := lipgloss.NewStyle().Width(rightW).Height(m.height) + right := rightStyle.Render(rightContent) + dividerChars := "" + for i := 0; i < m.height; i++ { + dividerChars += "│\n" + } + divider := lipgloss.NewStyle().Foreground(lipgloss.Color("#6c7a89")).Render(dividerChars) + content = lipgloss.JoinHorizontal(lipgloss.Top, left, divider, right) + } else { + content = rightSide.String() + } + if m.overlay != OverlayNone { + var overlay string + switch m.overlay { + case OverlayHelp: + overlay = m.renderHelpOverlay(m.width) + case OverlayCompletion: + if m.isCompletionActive() { + overlay = m.renderCompletionModal() + } + case OverlayModelPicker: + if m.modelPickerState != nil { + overlay = m.renderModelPicker() + } + case OverlayPlanForm: + if m.planFormState != nil { + overlay = m.renderPlanForm() + } + case OverlaySessionsPicker: + if m.sessionsPickerState != nil { + overlay = m.renderSessionsPicker() + } + } + if overlay != "" { + content = m.overlayOnContent(content, overlay) + } + } + var b strings.Builder + b.WriteString(content) + b.WriteString("\n") + if m.toastMgr != nil && m.toastMgr.HasToasts() { + m.toastMgr.Update() + toastStr := m.toastMgr.Render(m.width) + if toastStr != "" { + b.WriteString("\n") + b.WriteString(toastStr) + } + } + v := tea.NewView(b.String()) + v.AltScreen = true + v.MouseMode = tea.MouseModeCellMotion + loc := m.tr() + switch m.state { + case StateWaiting: + v.WindowTitle = loc.WindowTitleThink + case StateStreaming: + v.WindowTitle = loc.WindowTitleStream + default: + if m.doneFlash { + v.WindowTitle = loc.WindowTitleDone + } else { + v.WindowTitle = loc.WindowTitle + } + } + return v +} + +func (m *Model) renderCompletionModal() string { + cs := m.completionState + if cs == nil { + return "" + } + var b strings.Builder + var title string + switch cs.Kind { + case "command": + title = "Commands" + case "attachments": + title = "Attach Files & Agents" + case "skills": + title = "Skills" + default: + title = "Complete" + } + b.WriteString(m.styles.OverlayTitle.Render(title)) + b.WriteString("\n") + b.WriteString(m.styles.CompletionFilter.Render("> " + cs.Filter.View())) + b.WriteString("\n") + if cs.Kind == "attachments" && cs.CurrentPath != "" { + b.WriteString(m.styles.CompletionCategory.Render(cs.CurrentPath + "/")) + b.WriteString("\n") + } + maxW := 40 + if m.width-8 > maxW { + maxW = m.width - 8 + } + if maxW > 60 { + maxW = 60 + } + b.WriteString(m.styles.FocusIndicator.Render(strings.Repeat("─", maxW))) + b.WriteString("\n") + maxVisible := 10 + items := cs.FilteredItems + if len(items) == 0 { + b.WriteString(m.styles.CompletionCategory.Render(" (no matches)")) + b.WriteString("\n") + } else { + start := 0 + if cs.Index >= maxVisible { + start = cs.Index - maxVisible + 1 + } + end := start + maxVisible + if end > len(items) { + end = len(items) + } + for i := start; i < end; i++ { + item := items[i] + prefix := " " + if i == cs.Index { + prefix = m.styles.FocusIndicator.Render("▸ ") + } + selectedMark := "" + if cs.Selected != nil { + for oi, orig := range cs.AllItems { + if orig.Label == item.Label && orig.Insert == item.Insert { + if cs.Selected[oi] { + selectedMark = m.styles.FocusIndicator.Render(" ✓") + } + break + } + } + } + label := item.Label + cat := m.styles.CompletionCategory.Render(" " + item.Category) + if i == cs.Index { + b.WriteString(prefix + m.styles.FocusIndicator.Render(label) + cat + selectedMark) + } else { + b.WriteString(prefix + label + cat + selectedMark) + } + b.WriteString("\n") + } + } + if cs.Searching { + b.WriteString(m.styles.CompletionSearching.Render(" searching...")) + b.WriteString("\n") + } + hints := "Enter=select Esc=cancel" + if cs.Kind == "attachments" && cs.CurrentPath != "" { + hints += " ←=back" + } + if cs.Selected != nil { + hints += " Tab=toggle" + } + b.WriteString(m.styles.CompletionFooter.Render(hints)) + box := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(m.styles.OverlayBorder)). + Padding(1, 2). + Width(maxW + 4) + + return box.Render(b.String()) +} + +// renderHeader builds: +// +// ai-agent qwen3:8b · 5 tools +// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +func (m *Model) renderHeader() string { + title := m.styles.HeaderTitle.Render("AI AGENT") + + var infoStr string + if m.model != "" { + parts := []string{m.model} + if m.toolCount > 0 { + parts = append(parts, fmt.Sprintf("%d tools", m.toolCount)) + } + if m.serverCount > 0 { + parts = append(parts, fmt.Sprintf("%d servers", m.serverCount)) + } + if m.loadedFile != "" { + parts = append(parts, "ctx") + } + if m.iceEnabled { + parts = append(parts, "ICE") + } + if m.promptTokens > 0 && m.numCtx > 0 { + pct := m.promptTokens * 100 / m.numCtx + var pctStyle lipgloss.Style + switch { + case pct > 85: + pctStyle = m.styles.ContextPctHigh + case pct > 60: + pctStyle = m.styles.ContextPctMid + default: + pctStyle = m.styles.ContextPctLow + } + parts = append(parts, pctStyle.Render(contextProgressBar(pct))) + } + infoStr = m.styles.HeaderInfo.Render(strings.Join(parts, " · ")) + } + titleW := lipgloss.Width(title) + infoW := lipgloss.Width(infoStr) + gap := m.width - titleW - infoW + if gap < 1 { + gap = 1 + } + line := title + strings.Repeat(" ", gap) + infoStr + ruler := m.styles.HeaderRule.Render(rule(m.width)) + + return line + "\n" + ruler +} + +func (m *Model) renderFooter() string { + var b strings.Builder + b.WriteString(m.styles.Divider.Render(rule(m.width))) + b.WriteString("\n") + b.WriteString(m.renderStatusLine()) + b.WriteString("\n") + if m.state == StateIdle { + b.WriteString(m.input.View()) + } else if m.state == StateWaiting { + b.WriteString(m.styles.StreamHint.Render(" " + m.scramble.View() + " thinking... press Esc to cancel")) + } else { + b.WriteString(m.styles.StreamHint.Render(" " + m.spin.View() + " streaming... press Esc to cancel")) + } + return b.String() +} + +func (m *Model) renderStatusLine() string { + if m.pendingApproval != nil { + args := agent.FormatToolArgs(m.pendingApproval.Args) + promptText := m.pendingApproval.ToolName + if args != "" { + promptText += " " + args + } + if len(promptText) > 60 { + promptText = promptText[:57] + "..." + } + return m.styles.ApprovalPrompt.Render( + fmt.Sprintf(" ⚡ Allow %s? [y]es / [n]o / [a]lways", promptText), + ) + } + if m.pendingPaste != "" { + lines := strings.Count(m.pendingPaste, "\n") + 1 + return m.styles.StatusText.Render( + fmt.Sprintf(" Large paste (%d lines). Wrap as code block? [y/n/esc]", lines), + ) + } + var parts []string + switch m.state { + case StateWaiting: + // No status line content — the hint line below shows "thinking..." + case StateStreaming: + if m.streamBuf.Len() > 0 { + parts = append(parts, m.styles.StatusText.Render( + fmt.Sprintf("%d chars", m.streamBuf.Len()), + )) + } + if m.toolsPending > 0 { + parts = append(parts, m.styles.StatusText.Render( + fmt.Sprintf("%d tool(s) pending", m.toolsPending), + )) + } + case StateIdle: + cfg := m.modeConfigs[m.mode] + var modeStyle lipgloss.Style + switch m.mode { + case ModeAsk: + modeStyle = m.styles.ModeAsk + case ModePlan: + modeStyle = m.styles.ModePlan + case ModeBuild: + modeStyle = m.styles.ModeBuild + } + parts = append(parts, modeStyle.Render("[ "+cfg.Label+" ]")) + dot := m.styles.StatusDot.Render("○") + label := m.styles.StatusText.Render(" ready") + parts = append(parts, dot+label) + if m.promptTokens > 0 && m.numCtx > 0 { + parts = append(parts, m.styles.StatusText.Render( + fmt.Sprintf("~%s / %s ctx", formatTokens(m.promptTokens), formatTokens(m.numCtx)), + )) + } + if m.sessionEvalTotal > 0 { + parts = append(parts, m.styles.StatusText.Render( + fmt.Sprintf("%s out (%d turns)", formatTokens(m.sessionEvalTotal), m.sessionTurnCount), + )) + } + } + if len(parts) == 0 { + return "" + } + return " " + strings.Join(parts, m.styles.StatusText.Render(" · ")) +} + +func formatTokens(n int) string { + if n >= 1000 { + return fmt.Sprintf("%.1fk", float64(n)/1000) + } + return fmt.Sprintf("%d", n) +} + +func (m *Model) renderEntries() string { + viewportW := m.width - 1 + if m.sidePanel.IsVisible() { + viewportW = m.width - m.sidePanel.width - 2 + } + if viewportW < 20 { + viewportW = 20 + } + contentW := viewportW - 6 + if contentW < 14 { + contentW = 14 + } + if m.initializing { + var b strings.Builder + m.renderStartup(&b) + return b.String() + } + hasUserMsg := false + for _, e := range m.entries { + if e.Kind == "user" || e.Kind == "assistant" { + hasUserMsg = true + break + } + } + if !hasUserMsg && m.streamBuf.Len() == 0 { + var b strings.Builder + m.renderWelcome(&b) + for _, e := range m.entries { + if e.Kind == "system" { + b.WriteString(m.styles.SystemText.Render(e.Content)) + b.WriteString("\n\n") + } else if e.Kind == "error" { + b.WriteString(m.styles.ErrorText.Render("error: " + e.Content)) + b.WriteString("\n\n") + } + } + return b.String() + } + if m.entryCacheValid && len(m.entries) == m.cachedEntryCount { + m.toolEntryRows = m.cachedToolEntryRows + if m.streamBuf.Len() > 0 { + var b strings.Builder + b.WriteString(m.cachedEntriesRender) + if len(m.entries) > 0 { + last := m.entries[len(m.entries)-1] + if last.Kind != "tool_group" { + b.WriteString("\n") + } + } + m.renderStreamingMsg(&b, m.streamBuf.String(), contentW) + return b.String() + } + return m.cachedEntriesRender + } + var b strings.Builder + m.toolEntryRows = make(map[int]int) + for i, entry := range m.entries { + switch entry.Kind { + case "user": + m.renderUserMsg(&b, entry.Content, contentW) + case "assistant": + m.renderAssistantMsg(&b, entry, contentW) + case "tool_group": + m.toolEntryRows[entry.ToolIndex] = strings.Count(b.String(), "\n") + m.renderToolGroup(&b, entry.ToolIndex, i) + case "error": + b.WriteString(m.styles.ErrorText.Render("error: " + entry.Content)) + b.WriteString("\n\n") + case "system": + b.WriteString(m.styles.SystemText.Render(entry.Content)) + b.WriteString("\n\n") + } + if i < len(m.entries)-1 { + next := m.entries[i+1] + curr := entry.Kind + nextK := next.Kind + if curr == "tool_group" { + continue + } else if curr != nextK { + b.WriteString("\n") + } + } + } + m.cachedEntriesRender = b.String() + m.cachedEntryCount = len(m.entries) + if m.cachedToolEntryRows == nil { + m.cachedToolEntryRows = make(map[int]int, 8) + } else { + clear(m.cachedToolEntryRows) + } + for k, v := range m.toolEntryRows { + m.cachedToolEntryRows[k] = v + } + m.entryCacheValid = true + if m.streamBuf.Len() > 0 { + if len(m.entries) > 0 { + last := m.entries[len(m.entries)-1] + if last.Kind != "tool_group" { + b.WriteString("\n") + } + } + m.renderStreamingMsg(&b, m.streamBuf.String(), contentW) + } + return b.String() +} + +func (m *Model) renderWelcome(b *strings.Builder) { + var wb strings.Builder + for _, line := range logoLines() { + if line == "" { + wb.WriteString("\n") + } else { + wb.WriteString(lipgloss.NewStyle(). + Foreground(lipgloss.Color("#88c0d0")). + Bold(true). + Render(line)) + wb.WriteString("\n") + } + } + title := gradientText("Welcome to AI AGENT", []string{"#88c0d0", "#81a1c1", "#b48ead"}) + wb.WriteString(" " + m.styles.OverlayTitle.Render(title)) + wb.WriteString("\n") + var infoParts []string + if m.model != "" { + infoParts = append(infoParts, m.model) + } + if m.toolCount > 0 { + infoParts = append(infoParts, fmt.Sprintf("%d tools", m.toolCount)) + } + if m.serverCount > 0 { + infoParts = append(infoParts, fmt.Sprintf("%d servers", m.serverCount)) + } + if len(infoParts) > 0 { + wb.WriteString(m.styles.StatusText.Render(" " + strings.Join(infoParts, " · "))) + wb.WriteString("\n") + } + wb.WriteString("\n") + modes := []struct { + key string + desc string + color string + }{ + {"ASK", "Quick answers", "#81a1c1"}, + {"PLAN", "Design & reasoning", "#ebcb8b"}, + {"BUILD", "Full execution", "#a3be8c"}, + } + for _, mode := range modes { + modeStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color(mode.color)). + Bold(true) + wb.WriteString(" ") + wb.WriteString(modeStyle.Render(mode.key)) + wb.WriteString(m.styles.StatusText.Render(" — " + mode.desc)) + wb.WriteString("\n") + } + wb.WriteString("\n") + wb.WriteString(m.styles.SystemText.Render(" Type a message to start · Press ? for help")) + wb.WriteString("\n") + contentWidth := m.width + if m.sidePanel.IsVisible() { + contentWidth = m.width - m.sidePanel.width - 1 + } + centered := lipgloss.PlaceHorizontal(contentWidth, lipgloss.Center, wb.String()) + b.WriteString(centered) +} + +func (m *Model) renderUserMsg(b *strings.Builder, content string, contentW int) { + label := m.styles.UserLabel.Render("you") + labelW := lipgloss.Width(label) + ruleW := contentW - labelW - 3 + if ruleW < 4 { + ruleW = 4 + } + b.WriteString(label + " " + m.styles.RoleRule.Render(rule(ruleW))) + b.WriteString("\n") + b.WriteString(m.styles.UserContent.Render(wrapText(content, contentW))) + b.WriteString("\n") +} + +func (m *Model) renderAssistantMsg(b *strings.Builder, entry ChatEntry, contentW int) { + if entry.ThinkingContent != "" { + thinkBox := m.renderThinkingBox(entry.ThinkingContent, entry.ThinkingCollapsed) + b.WriteString(indentBlock(thinkBox, " ")) + b.WriteString("\n") + } + label := m.styles.AsstLabel.Render("assistant") + labelW := lipgloss.Width(label) + ruleW := contentW - labelW - 3 + if ruleW < 4 { + ruleW = 4 + } + b.WriteString(label + " " + m.styles.RoleRule.Render(rule(ruleW))) + b.WriteString("\n") + rendered := entry.RenderedContent + if rendered == "" { + rendered = entry.Content + if m.md != nil { + rendered = m.md.RenderFull(rendered) + } + } + rendered = strings.TrimRight(rendered, " \t\n") + rendered = indentBlock(rendered, " ") + b.WriteString(rendered) + b.WriteString("\n") +} + +func (m *Model) renderStreamingMsg(b *strings.Builder, content string, contentW int) { + if m.thinkBuf.Len() > 0 { + thinkHint := m.styles.ThinkingHeader.Render( + fmt.Sprintf(" thinking: %d chars...", m.thinkBuf.Len()), + ) + b.WriteString(thinkHint) + b.WriteString("\n") + } + label := m.styles.AsstLabel.Render("assistant") + cursor := m.styles.StreamCursor.Render(" " + m.spin.View()) + labelW := lipgloss.Width(label) + lipgloss.Width(cursor) + ruleW := contentW - labelW - 3 + if ruleW < 4 { + ruleW = 4 + } + b.WriteString(label + cursor + " " + m.styles.RoleRule.Render(rule(ruleW))) + b.WriteString("\n") + wrapWidth := contentW - 2 + if wrapWidth < 10 { + wrapWidth = 10 + } + wrapped := wrapText(content, wrapWidth) + rendered := indentBlock(wrapped, " ") + b.WriteString(rendered) + b.WriteString("\n") +} + +func (m *Model) renderToolGroup(b *strings.Builder, toolIdx, entryIdx int) { + if toolIdx < 0 || toolIdx >= len(m.toolEntries) { + return + } + te := m.toolEntries[toolIdx] + layout := m.currentLayout() + if entryIdx > 0 && m.entries[entryIdx-1].Kind != "tool_group" { + b.WriteString("\n") + } + var card *ToolCard + for i := range m.toolCardMgr.Cards { + if m.toolCardMgr.Cards[i].Name == te.Name { + card = &m.toolCardMgr.Cards[i] + break + } + } + if card != nil { + card.Expanded = !te.Collapsed + availableWidth := m.width - 8 + if m.sidePanel.IsVisible() { + availableWidth = m.width - m.sidePanel.width - 10 + } + if availableWidth < 30 { + availableWidth = 30 + } + cardView := card.View(availableWidth) + cardView = indentBlock(cardView, " ") + b.WriteString(cardView) + b.WriteString("\n\n") + } else { + tt := classifyTool(te.Name) + switch te.Status { + case ToolStatusRunning: + icon := m.styles.ToolCallIcon.Render(toolIcon(tt, te.Status)) + spinView := m.spin.View() + text := m.styles.ToolCallText.Render(fmt.Sprintf(" %s ", te.Name)) + hint := m.styles.ToolRunningText.Render(spinView + " running...") + b.WriteString(icon + text + hint) + if tt == ToolTypeBash { + if summary := toolSummary(tt, te); summary != "" { + b.WriteString("\n") + b.WriteString(m.styles.ToolBashCmd.Render(layout.ToolIndent + "$ " + summary)) + } + } + b.WriteString("\n") + case ToolStatusDone: + dur := formatDuration(te.Duration) + icon := m.styles.ToolDoneIcon.Render(toolIcon(tt, te.Status)) + if te.Collapsed { + // Collapsed: single line with type-specific summary + text := m.styles.ToolDoneText.Render(fmt.Sprintf(" %s (%s)", te.Name, dur)) + b.WriteString(icon + text) + if summary := toolSummary(tt, te); summary != "" { + summ := truncate(summary, layout.ToolSummaryMax) + b.WriteString(m.styles.ToolBashCmd.Render(" " + summ)) + } + b.WriteString("\n") + } else { + // Expanded: show args + result (or diff for file writes) + text := m.styles.ToolDoneText.Render(fmt.Sprintf(" %s (%s)", te.Name, dur)) + b.WriteString(icon + text) + b.WriteString("\n") + args := truncate(te.Args, layout.ArgsTruncMax) + b.WriteString(m.styles.ToolDetailText.Render(layout.ToolIndent + "args: " + args)) + b.WriteString("\n") + if te.DiffLines != nil { + b.WriteString(renderDiff(te.DiffLines, m.styles, 30)) + } else { + result := formatToolResult(te.Result, 20, layout.ResultTruncMax) + resultLines := strings.Count(result, "\n") + 1 + if resultLines > 20 { + b.WriteString(m.styles.ToolDetailText.Render(layout.ToolIndent + "result (truncated, expand to see more):\n")) + b.WriteString(m.styles.ToolDetailText.Render(indentBlock(truncate(result, layout.ResultTruncMax), layout.ToolIndent))) + } else { + b.WriteString(m.styles.ToolDetailText.Render(layout.ToolIndent + "result:\n")) + b.WriteString(m.styles.ToolDetailText.Render(indentBlock(result, layout.ToolIndent))) + } + b.WriteString("\n") + } + } + case ToolStatusError: + // Error: always expanded regardless of collapse state + dur := formatDuration(te.Duration) + icon := m.styles.ToolErrorIcon.Render(toolIcon(tt, te.Status)) + text := m.styles.ToolErrorText.Render(fmt.Sprintf(" %s (%s)", te.Name, dur)) + b.WriteString(icon + text) + b.WriteString("\n") + result := truncate(te.Result, layout.ResultTruncMax) + b.WriteString(m.styles.ToolErrorText.Render(layout.ToolIndent + result)) + b.WriteString("\n") + } + } + if entryIdx < len(m.entries)-1 && m.entries[entryIdx+1].Kind != "tool_group" { + b.WriteString("\n") + } +} + +func formatDuration(d time.Duration) string { + if d < time.Second { + return fmt.Sprintf("%dms", d.Milliseconds()) + } + return fmt.Sprintf("%.1fs", d.Seconds()) +} + +func truncate(s string, max int) string { + if len(s) <= max { + return s + } + return s[:max-3] + "..." +} + +func wrapText(s string, width int) string { + if width <= 0 { + return s + } + if len(s) <= width { + return s + } + var result strings.Builder + for _, line := range strings.Split(s, "\n") { + result.WriteString(wrapLine(line, width)) + result.WriteString("\n") + } + return strings.TrimSuffix(result.String(), "\n") +} + +func wrapLine(line string, width int) string { + if len(line) <= width { + return line + } + var result strings.Builder + words := strings.Fields(line) + current := "" + for _, w := range words { + if current == "" { + current = w + } else if len(current)+1+len(w) <= width { + current += " " + w + } else { + if result.Len() > 0 { + result.WriteString("\n") + } + result.WriteString(current) + current = w + } + } + if current != "" { + if result.Len() > 0 { + result.WriteString("\n") + } + for len(current) > width { + if result.Len() > 0 { + result.WriteString("\n") + } + result.WriteString(current[:width]) + current = current[width:] + } + if len(current) > 0 { + if result.Len() > 0 { + result.WriteString("\n") + } + result.WriteString(current) + } + } + return result.String() +} + +func indentBlock(s, prefix string) string { + lines := strings.Split(s, "\n") + for i, line := range lines { + if line != "" { + lines[i] = prefix + line + } + } + return strings.Join(lines, "\n") +} diff --git a/internal/tui/view_test.go b/internal/tui/view_test.go new file mode 100644 index 0000000..6665ed0 --- /dev/null +++ b/internal/tui/view_test.go @@ -0,0 +1,200 @@ +package tui + +import ( + "strings" + "testing" + "time" +) + +func TestFormatTokens(t *testing.T) { + tests := []struct { + input int + want string + }{ + {999, "999"}, + {1000, "1.0k"}, + {1234, "1.2k"}, + {8192, "8.2k"}, + {0, "0"}, + {500, "500"}, + {10000, "10.0k"}, + } + + for _, tt := range tests { + got := formatTokens(tt.input) + if got != tt.want { + t.Errorf("formatTokens(%d) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestFormatDuration(t *testing.T) { + tests := []struct { + input time.Duration + want string + }{ + {42 * time.Millisecond, "42ms"}, + {1300 * time.Millisecond, "1.3s"}, + {0, "0ms"}, + {999 * time.Millisecond, "999ms"}, + {time.Second, "1.0s"}, + {2500 * time.Millisecond, "2.5s"}, + } + + for _, tt := range tests { + got := formatDuration(tt.input) + if got != tt.want { + t.Errorf("formatDuration(%v) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestTruncate(t *testing.T) { + tests := []struct { + name string + input string + max int + want string + }{ + {"within_limit", "hello", 10, "hello"}, + {"exact_limit", "hello", 5, "hello"}, + {"over_limit", "hello world", 8, "hello..."}, + {"much_over", "this is a long string", 10, "this is..."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := truncate(tt.input, tt.max) + if got != tt.want { + t.Errorf("truncate(%q, %d) = %q, want %q", tt.input, tt.max, got, tt.want) + } + }) + } +} + +func TestWrapText(t *testing.T) { + t.Run("no_wrap_needed", func(t *testing.T) { + got := wrapText("short", 20) + if got != "short" { + t.Errorf("expected 'short', got %q", got) + } + }) + + t.Run("word_wrap", func(t *testing.T) { + got := wrapText("hello world foo bar", 11) + lines := strings.Split(got, "\n") + for _, line := range lines { + if len(line) > 11 { + t.Errorf("line %q exceeds width 11", line) + } + } + }) + + t.Run("preserves_newlines", func(t *testing.T) { + got := wrapText("line1\nline2", 20) + if !strings.Contains(got, "\n") { + t.Error("should preserve existing newlines") + } + lines := strings.Split(got, "\n") + if len(lines) < 2 { + t.Errorf("expected at least 2 lines, got %d", len(lines)) + } + }) + + t.Run("width_zero_guard", func(t *testing.T) { + got := wrapText("hello", 0) + if got != "hello" { + t.Errorf("width<=0 should return original, got %q", got) + } + }) + + t.Run("width_negative", func(t *testing.T) { + got := wrapText("hello", -1) + if got != "hello" { + t.Errorf("negative width should return original, got %q", got) + } + }) +} + +func TestIndentBlock(t *testing.T) { + t.Run("prefix_added", func(t *testing.T) { + got := indentBlock("hello\nworld", " ") + lines := strings.Split(got, "\n") + if lines[0] != " hello" { + t.Errorf("first line should be ' hello', got %q", lines[0]) + } + if lines[1] != " world" { + t.Errorf("second line should be ' world', got %q", lines[1]) + } + }) + + t.Run("empty_lines_preserved", func(t *testing.T) { + got := indentBlock("hello\n\nworld", ">> ") + lines := strings.Split(got, "\n") + if len(lines) != 3 { + t.Fatalf("expected 3 lines, got %d", len(lines)) + } + if lines[0] != ">> hello" { + t.Errorf("first line should be '>> hello', got %q", lines[0]) + } + if lines[1] != "" { + t.Errorf("empty line should stay empty, got %q", lines[1]) + } + if lines[2] != ">> world" { + t.Errorf("third line should be '>> world', got %q", lines[2]) + } + }) + + t.Run("single_line", func(t *testing.T) { + got := indentBlock("hello", "* ") + if got != "* hello" { + t.Errorf("expected '* hello', got %q", got) + } + }) +} + +func TestContextPctInHeader(t *testing.T) { + t.Run("no_pct_when_zero_tokens", func(t *testing.T) { + m := newTestModel(t) + m.model = "test-model" + m.promptTokens = 0 + m.numCtx = 8192 + header := m.renderHeader() + if strings.Contains(header, "%") { + t.Error("should not show percentage when promptTokens is 0") + } + }) + + t.Run("no_pct_when_zero_numCtx", func(t *testing.T) { + m := newTestModel(t) + m.model = "test-model" + m.promptTokens = 1000 + m.numCtx = 0 + header := m.renderHeader() + if strings.Contains(header, "%") { + t.Error("should not show percentage when numCtx is 0") + } + }) + + t.Run("correct_percentage", func(t *testing.T) { + m := newTestModel(t) + m.model = "test-model" + m.promptTokens = 4096 + m.numCtx = 8192 + header := m.renderHeader() + if !strings.Contains(header, "50%") { + t.Errorf("expected header to contain '50%%', got %q", header) + } + }) + + t.Run("high_percentage", func(t *testing.T) { + m := newTestModel(t) + m.model = "test-model" + m.promptTokens = 7500 + m.numCtx = 8192 + header := m.renderHeader() + if !strings.Contains(header, "91%") { + t.Errorf("expected header to contain '91%%', got %q", header) + } + }) +} diff --git a/internal/tui/view_width_test.go b/internal/tui/view_width_test.go new file mode 100644 index 0000000..c241de1 --- /dev/null +++ b/internal/tui/view_width_test.go @@ -0,0 +1,124 @@ +package tui + +import ( + "strings" + "testing" +) + +// TestWrapTextWideChars tests wrapping with Unicode wide characters +func TestWrapTextWideChars(t *testing.T) { + // Test with content that would exceed width if not wrapped properly + longURL := "https://example.com/very/long/path/that/should/be/wrapped/but/wont/be/with/standard/wrapping" + + got := wrapText(longURL, 40) + lines := strings.Split(got, "\n") + + for _, line := range lines { + if len([]rune(line)) > 40 { + t.Errorf("line %q exceeds width 40 (runes: %d)", line, len([]rune(line))) + } + } +} + +// TestIndentBlockWideChars tests indenting with wide characters +func TestIndentBlockWideChars(t *testing.T) { + longLine := "https://example.com/very/long/path/that/should/be/wrapped" + got := indentBlock(longLine, " ") + + // The issue: indentBlock doesn't wrap, so this will exceed any reasonable width + lines := strings.Split(got, "\n") + for _, line := range lines { + if len([]rune(line)) > 100 { + t.Logf("WARNING: line exceeds expected width: %d runes", len([]rune(line))) + } + } +} + +// TestWrapLineWrapping tests that wrapLine properly wraps content +func TestWrapLineWrapping(t *testing.T) { + t.Run("long_url", func(t *testing.T) { + longURL := "https://github.com/very/long/path/that/definitely/needs/to/be/wrapped/properly" + got := wrapLine(longURL, 40) + lines := strings.Split(got, "\n") + for _, line := range lines { + if len(line) > 40 { + t.Errorf("wrapLine failed: line %q has %d chars, exceeds 40", line, len(line)) + } + } + }) + + t.Run("long_identifier", func(t *testing.T) { + code := "this_is_a_very_long_identifier_without_any_spaces_that_needs_to_be_wrapped" + got := wrapLine(code, 30) + lines := strings.Split(got, "\n") + for _, line := range lines { + if len(line) > 30 { + t.Errorf("wrapLine failed: line %q has %d chars, exceeds 30", line, len(line)) + } + } + }) + + t.Run("normal_text", func(t *testing.T) { + text := "hello world foo bar baz" + got := wrapLine(text, 12) + lines := strings.Split(got, "\n") + for _, line := range lines { + if len(line) > 12 { + t.Errorf("wrapLine failed: line %q has %d chars, exceeds 12", line, len(line)) + } + } + }) +} + +// TestWrapTextWithCodeBlocks tests wrapping of content that looks like code +func TestWrapTextWithCodeBlocks(t *testing.T) { + // Code blocks with no spaces should still be wrapped + codeLine := "this_is_a_very_long_identifier_without_any_spaces_that_needs_to_be_wrapped" + + got := wrapText(codeLine, 30) + lines := strings.Split(got, "\n") + + for _, line := range lines { + if len(line) > 30 { + t.Errorf("code-like content not wrapped: line %q has %d chars, exceeds 30", line, len(line)) + } + } +} + +// TestRenderAssistantMsgWidth tests that assistant messages respect content width +// Note: This tests the raw markdown renderer - the Glamour library itself has issues +// with wrapping long words, but this is a known limitation of the library. +func TestRenderAssistantMsgWidth(t *testing.T) { + // Create a minimal model for testing + m := &Model{ + width: 80, + isDark: true, + } + + // Create markdown renderer + m.md = NewMarkdownRenderer(m.width-2, m.isDark) + + longURL := "Check out this URL https://github.com/very/long/path/that/definitely/needs/to/be/wrapped/properly" + + rendered := m.md.RenderFull(longURL) + lines := strings.Split(rendered, "\n") + + // Note: Glamour itself doesn't wrap long words well - this is a known limitation + // The fix for streaming messages handles this case, but the markdown renderer + // relies on Glamour's built-in word wrapping which has this bug. + // We document this as an expected limitation. + t.Logf("Glamour rendered %d lines, max line length: %d", len(lines), maxLineLen(lines)) + + // This test documents the Glamour limitation - we don't fail on this + // because it's a third-party library issue, not our code +} + +func maxLineLen(lines []string) int { + max := 0 + for _, line := range lines { + if len(line) > max { + max = len(line) + } + } + return max +} diff --git a/internal/tui/welcome.go b/internal/tui/welcome.go new file mode 100644 index 0000000..934954c --- /dev/null +++ b/internal/tui/welcome.go @@ -0,0 +1,423 @@ +package tui + +import ( + "fmt" + "math" + "strings" + "time" + + "charm.land/bubbles/v2/spinner" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "github.com/charmbracelet/harmonica" +) + +// Welcome animation phases +const ( + WelcomePhaseLogo = iota + WelcomePhaseTagline + WelcomePhaseFeatures + WelcomePhaseReady +) + +// WelcomeTickMsg triggers animation frame +type WelcomeTickMsg struct{} + +// WelcomeModel holds the state for the welcome animation +type WelcomeModel struct { + phase int + logoAlpha float64 + logoVel float64 + taglineAlpha float64 + taglineVel float64 + featureIndex int + featureAlpha float64 + featureVel float64 + spring harmonica.Spring + spinner spinner.Model + isDark bool + ready bool + frame int +} + +// taglines for rotation +var taglines = []string{ + `ASK → PLAN → BUILD`, + `0.8B 4B 9B`, + `Small models · Big results`, +} + +// featureList shows key features +var featureList = []struct { + icon string + label string + desc string +}{ + {"◈", "Model Routing", "Auto-selects 0.8B → 9B based on task"}, + {"◈", "MCP Native", "Connect any tool via Model Context Protocol"}, + {"◈", "ICE Engine", "Cross-session memory & context"}, + {"◈", "Auto-Memory", "Extracts facts, decisions, TODOs"}, + {"◈", "Thinking/CoT", "Chain-of-thought reasoning display"}, + {"◈", "Skills System", "Domain-specific knowledge injection"}, +} + +// NewWelcomeModel creates a new welcome animation model +func NewWelcomeModel(isDark bool) WelcomeModel { + s := spinner.New( + spinner.WithSpinner(spinner.MiniDot), + spinner.WithStyle(lipgloss.NewStyle().Foreground(lipgloss.Color("#88c0d0"))), + ) + + return WelcomeModel{ + spring: harmonica.NewSpring(harmonica.FPS(60), 6.0, 0.8), + spinner: s, + isDark: isDark, + } +} + +// Init starts the welcome animation +func (m WelcomeModel) Init() tea.Cmd { + return tea.Batch( + tea.Tick(16*time.Millisecond, func(time.Time) tea.Msg { + return WelcomeTickMsg{} + }), + m.spinner.Tick, + ) +} + +// Update processes animation frames +func (m WelcomeModel) Update(msg tea.Msg) (WelcomeModel, tea.Cmd) { + switch msg := msg.(type) { + case WelcomeTickMsg: + m.frame++ + + // Phase transitions - slowed down for better viewing + if m.phase == WelcomePhaseLogo && m.logoAlpha >= 0.95 && m.frame > 60 { + // Wait at least 60 frames (~1 second) on logo + m.phase = WelcomePhaseTagline + } + if m.phase == WelcomePhaseTagline && m.taglineAlpha >= 0.95 && m.frame > 180 { + // Wait at least 180 frames (~3 seconds) on taglines + m.phase = WelcomePhaseFeatures + m.featureIndex = 0 + } + if m.phase == WelcomePhaseFeatures && m.featureIndex >= len(featureList) { + m.phase = WelcomePhaseReady + m.ready = true + } + + // Animate based on phase + switch m.phase { + case WelcomePhaseLogo: + target := 1.0 + m.logoAlpha, m.logoVel = m.spring.Update(m.logoAlpha, m.logoVel, target) + + case WelcomePhaseTagline: + target := 1.0 + m.taglineAlpha, m.taglineVel = m.spring.Update(m.taglineAlpha, m.taglineVel, target) + + case WelcomePhaseFeatures: + // Animate current feature in + target := 1.0 + m.featureAlpha, m.featureVel = m.spring.Update(m.featureAlpha, m.featureVel, target) + + // Move to next feature after delay (slower - every 90 frames) + if m.featureAlpha >= 0.95 && m.frame%90 == 0 { + m.featureIndex++ + if m.featureIndex < len(featureList) { + m.featureAlpha = 0 + m.featureVel = 0 + } + } + } + + // Update spinner + var cmd tea.Cmd + m.spinner, cmd = m.spinner.Update(msg) + + return m, tea.Batch( + tea.Tick(16*time.Millisecond, func(time.Time) tea.Msg { + return WelcomeTickMsg{} + }), + cmd, + ) + + case spinner.TickMsg: + var cmd tea.Cmd + m.spinner, cmd = m.spinner.Update(msg) + return m, cmd + } + + return m, nil +} + +// View renders the welcome animation +func (m WelcomeModel) View() string { + var b strings.Builder + + // Render logo with fade-in + if m.phase >= WelcomePhaseLogo { + logo := logoLines() + for i, line := range logo { + if m.phase == WelcomePhaseLogo && i < len(logo)-2 { + // Apply gradient fade-in effect during logo phase + alpha := m.logoAlpha + if alpha < 0.1 { + alpha = 0.1 + } + b.WriteString(m.applyFade(line, alpha)) + } else { + b.WriteString(m.renderLogoLine(line)) + } + b.WriteString("\n") + } + } + + // Render animated tagline + if m.phase >= WelcomePhaseTagline { + b.WriteString("\n") + taglineIdx := (m.frame / 120) % len(taglines) + tagline := taglines[taglineIdx] + b.WriteString(m.renderTagline(tagline)) + b.WriteString("\n") + } + + // Render feature list with animation + if m.phase >= WelcomePhaseFeatures { + b.WriteString("\n") + featureTitle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#4c566a")). + Bold(true). + Render(" Features") + b.WriteString(featureTitle) + b.WriteString("\n\n") + + // Show all features, highlight current one + for i, feat := range featureList { + if i <= m.featureIndex { + line := fmt.Sprintf(" %s %s — %s", feat.icon, feat.label, feat.desc) + + if i == m.featureIndex && m.featureAlpha < 0.95 { + // Currently animating in + alpha := m.featureAlpha + if alpha < 0.2 { + alpha = 0.2 + } + b.WriteString(m.applyFade(line, alpha)) + } else if i == m.featureIndex { + // Current feature with accent + indicator := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#88c0d0")). + Render("▸ ") + b.WriteString(indicator + line[2:]) + } else { + // Previous features in dim + dimStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#4c566a")) + b.WriteString(" " + dimStyle.Render(line[2:])) + } + b.WriteString("\n") + } + } + } + + // Render ready state + if m.phase >= WelcomePhaseReady { + b.WriteString("\n") + checkStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("#a3be8c")) + readyLine := fmt.Sprintf(" %s Ready to go! Type a message or press ? for help", checkStyle.Render("✓")) + b.WriteString(readyLine) + b.WriteString("\n") + } + + return b.String() +} + +// renderLogoLine applies gradient colors to logo line +func (m WelcomeModel) renderLogoLine(line string) string { + if noColor { + return line + } + + // Apply gradient to box drawing characters + colors := []string{"#88c0d0", "#81a1c1", "#5e81ac", "#b48ead"} + + result := "" + for i, r := range line { + if r == '╭' || r == '─' || r == '╮' || r == '│' || r == '╰' || r == '╯' { + colorIdx := i % len(colors) + style := lipgloss.NewStyle().Foreground(lipgloss.Color(colors[colorIdx])) + result += style.Render(string(r)) + } else if r == '╔' || r == '╗' || r == '║' || r == '═' || r == '╚' || r == '╝' { + colorIdx := (i + 1) % len(colors) + style := lipgloss.NewStyle().Foreground(lipgloss.Color(colors[colorIdx])) + result += style.Render(string(r)) + } else { + result += string(r) + } + } + + return result +} + +// renderTagline renders the tagline with animation +func (m WelcomeModel) renderTagline(tagline string) string { + if noColor { + return " " + tagline + } + + // Split tagline into parts and apply gradient + parts := strings.Split(tagline, " ") + result := " " + + for i, part := range parts { + colorIdx := i % 3 + var color string + switch colorIdx { + case 0: + color = "#88c0d0" + case 1: + color = "#81a1c1" + case 2: + color = "#b48ead" + } + + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(color)). + Bold(true) + result += style.Render(part) + " " + } + + return result +} + +// applyFade applies alpha blending to simulate fade +func (m WelcomeModel) applyFade(line string, alpha float64) string { + if noColor { + return line + } + + // Use dimmer color based on alpha + baseColor := "#4c566a" // dim + if alpha > 0.7 { + baseColor = "#88c0d0" // bright + } else if alpha > 0.4 { + baseColor = "#5e81ac" // medium + } + + style := lipgloss.NewStyle().Foreground(lipgloss.Color(baseColor)) + return style.Render(line) +} + +// IsReady returns true when welcome animation is complete +func (m WelcomeModel) IsReady() bool { + return m.ready +} + +// pulseEffect creates a subtle pulse animation for status indicators +type PulseModel struct { + alpha float64 + vel float64 + spring harmonica.Spring + target float64 +} + +func NewPulseModel() PulseModel { + return PulseModel{ + spring: harmonica.NewSpring(harmonica.FPS(60), 3.0, 0.6), + target: 1.0, + } +} + +type PulseTickMsg struct{} + +func (m PulseModel) Init() tea.Cmd { + return tea.Tick(50*time.Millisecond, func(time.Time) tea.Msg { + return PulseTickMsg{} + }) +} + +func (m PulseModel) Update(msg tea.Msg) (PulseModel, tea.Cmd) { + if _, ok := msg.(PulseTickMsg); ok { + // Oscillate between 0.7 and 1.0 + if m.target == 1.0 && m.alpha >= 0.95 { + m.target = 0.7 + } else if m.target == 0.7 && m.alpha <= 0.75 { + m.target = 1.0 + } + + m.alpha, m.vel = m.spring.Update(m.alpha, m.vel, m.target) + return m, tea.Tick(50*time.Millisecond, func(time.Time) tea.Msg { + return PulseTickMsg{} + }) + } + return m, nil +} + +func (m PulseModel) Alpha() float64 { + return m.alpha +} + +// gradientText applies a horizontal gradient to text +func gradientText(text string, colors []string) string { + if noColor || len(colors) == 0 { + return text + } + + result := "" + runes := []rune(text) + colorCount := len(colors) + + for i, r := range runes { + colorIdx := int(float64(i) / float64(len(runes)) * float64(colorCount)) + if colorIdx >= colorCount { + colorIdx = colorCount - 1 + } + style := lipgloss.NewStyle().Foreground(lipgloss.Color(colors[colorIdx])) + result += style.Render(string(r)) + } + + return result +} + +// slideInEffect creates a slide-in animation from left +type SlideInModel struct { + offset float64 + vel float64 + spring harmonica.Spring + target float64 +} + +func NewSlideInModel() SlideInModel { + return SlideInModel{ + spring: harmonica.NewSpring(harmonica.FPS(60), 5.0, 0.7), + target: 0, + offset: -50, // Start off-screen left + } +} + +type SlideInTickMsg struct{} + +func (m SlideInModel) Init() tea.Cmd { + return tea.Tick(16*time.Millisecond, func(time.Time) tea.Msg { + return SlideInTickMsg{} + }) +} + +func (m SlideInModel) Update(msg tea.Msg) (SlideInModel, tea.Cmd) { + if _, ok := msg.(SlideInTickMsg); ok { + m.offset, m.vel = m.spring.Update(m.offset, m.vel, m.target) + return m, tea.Tick(16*time.Millisecond, func(time.Time) tea.Msg { + return SlideInTickMsg{} + }) + } + return m, nil +} + +func (m SlideInModel) Offset() int { + return int(math.Max(0, m.offset)) +} + +func (m SlideInModel) IsComplete() bool { + return m.offset <= 0.5 +} diff --git a/internal/tui/width_test.go b/internal/tui/width_test.go new file mode 100644 index 0000000..d564a97 --- /dev/null +++ b/internal/tui/width_test.go @@ -0,0 +1,369 @@ +package tui + +import ( + "strings" + "testing" +) + +// TestViewportWidthCalculation tests that viewport width calculations are consistent +// across all components to prevent horizontal scrolling. +func TestViewportWidthCalculation(t *testing.T) { + tests := []struct { + name string + screenWidth int + panelVisible bool + panelWidth int + wantViewWidth int + wantContentW int + wantMarkdownW int + }{ + { + name: "small screen with panel", + screenWidth: 80, + panelVisible: true, + panelWidth: 25, + wantViewWidth: 80 - 25 - 2, // screen - panel - separator + wantContentW: 80 - 25 - 5, // screen - panel - separator - padding + wantMarkdownW: 80 - 25 - 5, + }, + { + name: "medium screen with panel", + screenWidth: 120, + panelVisible: true, + panelWidth: 30, + wantViewWidth: 120 - 30 - 2, + wantContentW: 120 - 30 - 5, + wantMarkdownW: 120 - 30 - 5, + }, + { + name: "large screen with panel", + screenWidth: 160, + panelVisible: true, + panelWidth: 40, + wantViewWidth: 160 - 40 - 2, + wantContentW: 160 - 40 - 5, + wantMarkdownW: 160 - 40 - 5, + }, + { + name: "small screen without panel", + screenWidth: 80, + panelVisible: false, + panelWidth: 0, + wantViewWidth: 80 - 1, // just separator + wantContentW: 80 - 1 - 3, // viewport - padding for markdown + wantMarkdownW: 80 - 1 - 3, + }, + { + name: "large screen without panel", + screenWidth: 200, + panelVisible: false, + panelWidth: 0, + wantViewWidth: 200 - 1, + wantContentW: 200 - 1 - 3, + wantMarkdownW: 200 - 1 - 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Simulate the width calculation from model.go:WindowSizeMsg handler + panelWidth := tt.panelWidth + if panelWidth == 0 && tt.panelVisible { + // Calculate panel width based on screen width (from model.go:365-371) + panelWidth = 30 + if tt.screenWidth < 100 { + panelWidth = 25 + } else if tt.screenWidth > 160 { + panelWidth = 40 + } + } + + // Viewport width (from model.go:373-380) + viewportWidth := tt.screenWidth - 1 + if tt.panelVisible { + viewportWidth = tt.screenWidth - panelWidth - 2 + } + if viewportWidth < 20 { + viewportWidth = 20 + } + + // Content/markdown width (from model.go:382-386) + markdownWidth := viewportWidth - 3 + if markdownWidth < 20 { + markdownWidth = 20 + } + + // Content width for rendering (from view.go:422-429) + contentW := tt.screenWidth - 4 + if tt.panelVisible { + contentW = tt.screenWidth - panelWidth - 5 + } + if contentW < 20 { + contentW = 20 + } + + // Verify consistency + if markdownWidth != tt.wantMarkdownW { + t.Errorf("markdown width = %d, want %d", markdownWidth, tt.wantMarkdownW) + } + + if viewportWidth != tt.wantViewWidth { + t.Errorf("viewport width = %d, want %d", viewportWidth, tt.wantViewWidth) + } + + if contentW != tt.wantContentW { + t.Errorf("content width = %d, want %d", contentW, tt.wantContentW) + } + + // CRITICAL: viewport width should never exceed screen width minus panel + maxAllowedWidth := tt.screenWidth - 1 + if tt.panelVisible { + maxAllowedWidth = tt.screenWidth - panelWidth - 1 + } + if viewportWidth > maxAllowedWidth { + t.Errorf("viewport width %d exceeds max allowed %d - will cause horizontal scroll", + viewportWidth, maxAllowedWidth) + } + + // Content width should be <= viewport width + if contentW > viewportWidth { + t.Errorf("content width %d > viewport width %d - will cause horizontal scroll", + contentW, viewportWidth) + } + + // Markdown width should be <= viewport width + if markdownWidth > viewportWidth { + t.Errorf("markdown width %d > viewport width %d - will cause horizontal scroll", + markdownWidth, viewportWidth) + } + }) + } +} + +// TestResponsiveWidthToggle tests that toggling the side panel maintains proper widths +func TestResponsiveWidthToggle(t *testing.T) { + screenWidth := 120 + panelWidth := 30 + + // Panel visible + viewportWithPanel := screenWidth - panelWidth - 2 + contentWithPanel := screenWidth - panelWidth - 5 + + // Panel hidden + viewportWithoutPanel := screenWidth - 1 + contentWithoutPanel := screenWidth - 4 + + // Widths should increase when panel is hidden + if viewportWithoutPanel <= viewportWithPanel { + t.Errorf("viewport should be wider when panel is hidden: %d <= %d", + viewportWithoutPanel, viewportWithPanel) + } + + if contentWithoutPanel <= contentWithPanel { + t.Errorf("content should be wider when panel is hidden: %d <= %d", + contentWithoutPanel, contentWithPanel) + } + + // Neither should exceed screen width + if viewportWithoutPanel > screenWidth { + t.Errorf("viewport without panel %d exceeds screen width %d", + viewportWithoutPanel, screenWidth) + } + + if viewportWithPanel > screenWidth-panelWidth { + t.Errorf("viewport with panel %d exceeds available space %d", + viewportWithPanel, screenWidth-panelWidth) + } +} + +// TestMinimumWidthConstraints tests that minimum width constraints prevent negative layouts +func TestMinimumWidthConstraints(t *testing.T) { + tests := []struct { + name string + screenWidth int + }{ + {"tiny screen", 40}, + {"very small screen", 60}, + {"small screen", 80}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + panelWidth := 25 // minimum panel width + minWidth := 20 // minimum content width + + // Calculate viewport width with panel + viewportWidth := tt.screenWidth - panelWidth - 2 + if viewportWidth < minWidth { + viewportWidth = minWidth + } + + // Calculate content width with panel + contentWidth := tt.screenWidth - panelWidth - 5 + if contentWidth < minWidth { + contentWidth = minWidth + } + + // Verify minimums are respected + if viewportWidth < minWidth { + t.Errorf("viewport width %d below minimum %d", viewportWidth, minWidth) + } + + if contentWidth < minWidth { + t.Errorf("content width %d below minimum %d", contentWidth, minWidth) + } + + // Even with minimum constraints, total shouldn't exceed screen + totalWidth := panelWidth + 1 + viewportWidth + if totalWidth > tt.screenWidth && tt.screenWidth >= minWidth+panelWidth+1 { + t.Errorf("total width %d exceeds screen width %d", totalWidth, tt.screenWidth) + } + }) + } +} + +// TestRenderedTextWidth simulates actual rendered text to ensure it fits within viewport +func TestRenderedTextWidth(t *testing.T) { + tests := []struct { + name string + screenWidth int + panelWidth int + text string + }{ + { + name: "long line with panel", + screenWidth: 120, + panelWidth: 30, + text: strings.Repeat("x", 100), + }, + { + name: "long line without panel", + screenWidth: 120, + panelWidth: 0, + text: strings.Repeat("x", 120), + }, + { + name: "short line", + screenWidth: 80, + panelWidth: 25, + text: "short text", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Calculate available content width + availableWidth := tt.screenWidth - 4 + if tt.panelWidth > 0 { + availableWidth = tt.screenWidth - tt.panelWidth - 5 + } + if availableWidth < 20 { + availableWidth = 20 + } + + // Simulate wrapText behavior + wrapped := wrapText(tt.text, availableWidth) + + // Check each line fits + lines := strings.Split(wrapped, "\n") + for i, line := range lines { + if len(line) > availableWidth { + t.Errorf("line %d length %d exceeds available width %d", + i, len(line), availableWidth) + } + } + }) + } +} + +// TestLayoutConsistency verifies that all width calculations are consistent +func TestLayoutConsistency(t *testing.T) { + // Test various screen sizes + for screenWidth := 40; screenWidth <= 200; screenWidth += 10 { + t.Run("screen_width", func(t *testing.T) { + // Determine panel width (from model.go logic) + panelWidth := 30 + if screenWidth < 100 { + panelWidth = 25 + } else if screenWidth > 160 { + panelWidth = 40 + } + + // Test with panel visible + t.Run("with_panel", func(t *testing.T) { + // Viewport width calculation (from model.go:373-380) + viewportWidth := screenWidth - panelWidth - 2 + if viewportWidth < 20 { + viewportWidth = 20 + } + + // Content width calculation (from model.go:382-386) + markdownWidth := viewportWidth - 3 + if markdownWidth < 20 { + markdownWidth = 20 + } + + contentWidth := screenWidth - panelWidth - 5 + if contentWidth < 20 { + contentWidth = 20 + } + + // All widths should be consistent + if contentWidth > viewportWidth { + t.Errorf("content %d > viewport %d", contentWidth, viewportWidth) + } + + if markdownWidth > viewportWidth { + t.Errorf("markdown %d > viewport %d", markdownWidth, viewportWidth) + } + + // For very small screens, the layout might exceed screen width + // This is expected - the minimum viewport width takes precedence + minRequiredWidth := panelWidth + 1 + 20 // panel + separator + min viewport + if screenWidth >= minRequiredWidth { + // Only check total width if screen is large enough + totalWidth := panelWidth + 1 + viewportWidth + if totalWidth > screenWidth { + t.Errorf("total layout %d exceeds screen %d", totalWidth, screenWidth) + } + } + }) + + // Test without panel + t.Run("without_panel", func(t *testing.T) { + // Viewport width calculation + viewportWidth := screenWidth - 1 + if viewportWidth < 20 { + viewportWidth = 20 + } + + // Content width calculation + markdownWidth := viewportWidth - 3 + if markdownWidth < 20 { + markdownWidth = 20 + } + + contentWidth := screenWidth - 4 + if contentWidth < 20 { + contentWidth = 20 + } + + // All widths should be consistent + if contentWidth > viewportWidth { + t.Errorf("content %d > viewport %d", contentWidth, viewportWidth) + } + + if markdownWidth > viewportWidth { + t.Errorf("markdown %d > viewport %d", markdownWidth, viewportWidth) + } + + // For very small screens, minimum width takes precedence + if screenWidth >= 21 { // 1 + min viewport (20) + if viewportWidth > screenWidth { + t.Errorf("viewport %d exceeds screen %d", viewportWidth, screenWidth) + } + } + }) + }) + } +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..0e4f30e --- /dev/null +++ b/main.go @@ -0,0 +1,372 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "sync" + "syscall" + + "ai-agent/internal/agent" + "ai-agent/internal/command" + "ai-agent/internal/config" + "ai-agent/internal/db" + "ai-agent/internal/ice" + "ai-agent/internal/initcmd" + "ai-agent/internal/llm" + "ai-agent/internal/logging" + "ai-agent/internal/mcp" + "ai-agent/internal/memory" + "ai-agent/internal/permission" + "ai-agent/internal/skill" + "ai-agent/internal/tui" + + tea "charm.land/bubbletea/v2" +) + +var version = "mydev" + +func main() { + for _, arg := range os.Args[1:] { + if arg == "--version" || arg == "-version" { + fmt.Println(version) + return + } + } + if len(os.Args) > 1 { + switch os.Args[1] { + case "init": + force := false + for _, arg := range os.Args[2:] { + if arg == "--force" || arg == "-force" { + force = true + } + } + if err := initcmd.Run(".", initcmd.Options{Force: force}); err != nil { + fmt.Fprintf(os.Stderr, "init: %v\n", err) + os.Exit(1) + } + fmt.Println("AGENT.md created successfully.") + return + case "logs": + handleLogs(os.Args[2:]) + return + } + } + qwenRouterFlag := flag.Bool("qwen-router", false, "use optimized Qwen model router (experimental)") + modelFlag := flag.String("model", "", "override Ollama model") + agentProfileFlag := flag.String("agent", "", "override agent profile") + promptFlag := flag.String("p", "", "run in non-interactive mode: send prompt, print response, exit") + yoloFlag := flag.Bool("yolo", false, "auto-approve all tool calls (skip permission prompts)") + flag.Parse() + cfg, agentsDir, err := config.LoadWithAgentsDir() + if err != nil { + log.Fatalf("config: %v", err) + } + if *modelFlag != "" { + cfg.Ollama.Model = *modelFlag + } + if *agentProfileFlag != "" { + cfg.AgentProfile = *agentProfileFlag + } + var router *config.Router + if *qwenRouterFlag { + fmt.Fprintf(os.Stderr, "Using Qwen-optimized model router (experimental)\n") + } + router = config.NewRouter(&cfg.Model) + modelName := cfg.Ollama.Model + if cfg.AgentProfile != "" && agentsDir != nil { + if profile := agentsDir.GetAgent(cfg.AgentProfile); profile != nil { + if profile.Model != "" { + modelName = profile.Model + } + } + } + modelManager := llm.NewModelManager(cfg.Ollama.BaseURL, cfg.Ollama.NumCtx) + modelManager.SetCurrentModel(modelName) + var servers []config.ServerConfig + if len(cfg.Servers) > 0 { + servers = cfg.Servers + } else if agentsDir != nil && agentsDir.HasMCP() { + servers = agentsDir.GetMCPServers() + } + registry := mcp.NewRegistry() + defer registry.Close() + ag := agent.New(modelManager, registry, cfg.Ollama.NumCtx) + ag.SetToolsConfig(cfg.Tools) + ag.SetRouter(router) + if wd, err := os.Getwd(); err == nil { + ag.SetWorkDir(wd) + } + defer ag.Close() + dbStore, err := db.Open() + if err != nil { + log.Printf("warning: database: %v (permissions disabled)", err) + } + if dbStore != nil { + defer dbStore.Close() + } + permChecker := permission.NewChecker(dbStore, *yoloFlag) + ag.SetPermissionChecker(permChecker) + memStore := memory.NewStore("") + ag.SetMemoryStore(memStore) + skillDirs := []string{cfg.SkillsDir} + if agentsDir != nil && len(agentsDir.Skills) > 0 { + for _, s := range agentsDir.Skills { + if s.Path != "" { + skillDir := filepath.Dir(s.Path) + if skillDir != "" { + skillDirs = append(skillDirs, skillDir) + } + } + } + } + skillMgr := skill.NewManager("") + for _, dir := range skillDirs { + if dir != "" { + skillMgr.AddSearchPath(dir) + } + } + _ = skillMgr.LoadAll() + if *promptFlag != "" { + ctx := context.Background() + fmt.Fprintf(os.Stderr, "connecting to Ollama (%s)...\n", modelName) + if err := modelManager.Ping(); err != nil { + fmt.Fprintf(os.Stderr, "ollama: %v\nhint: is `ollama serve` running? is %q pulled?\n", err, modelName) + os.Exit(1) + } + var wg sync.WaitGroup + for _, srv := range servers { + wg.Add(1) + go func(s config.ServerConfig) { + defer wg.Done() + fmt.Fprintf(os.Stderr, "connecting MCP server %s...\n", s.Name) + if _, err := registry.ConnectServer(ctx, s); err != nil { + fmt.Fprintf(os.Stderr, "MCP server %s failed: %v\n", s.Name, err) + } + }(srv) + } + wg.Wait() + if cfg.ICE.Enabled { + embedModel := cfg.ICE.EmbedModel + if embedModel == "" { + embedModel = cfg.Model.EmbedModel + } + iceEngine, err := ice.NewEngine(modelManager, memStore, ice.EngineConfig{ + EmbedModel: embedModel, + StorePath: cfg.ICE.StorePath, + NumCtx: cfg.Ollama.NumCtx, + }) + if err != nil { + fmt.Fprintf(os.Stderr, "ICE: %v\n", err) + } else { + ag.SetICEEngine(iceEngine) + } + } + if agentsDir != nil && agentsDir.GetGlobalInstructions() != "" { + ag.SetLoadedContext(agentsDir.GetGlobalInstructions()) + } + if data, err := os.ReadFile("AGENT.md"); err == nil { + ag.AppendLoadedContext("\n\n" + string(data)) + } + if cfg.AgentProfile != "" && agentsDir != nil { + if profile := agentsDir.GetAgent(cfg.AgentProfile); profile != nil { + if profile.SystemPrompt != "" { + ag.AppendLoadedContext("\n\n" + profile.SystemPrompt) + } + for _, skillName := range profile.Skills { + skillMgr.Activate(skillName) + } + ag.SetSkillContent(skillMgr.ActiveContent()) + } + } + modes := tui.DefaultModeConfigs() + buildMode := modes[tui.ModeBuild] + ag.SetModeContext(buildMode.SystemPromptPrefix, buildMode.AllowTools) + out := agent.NewHeadlessOutput() + ag.AddUserMessage(*promptFlag) + ag.Run(ctx, out) + return + } + cmdReg := command.NewRegistry() + command.RegisterBuiltins(cmdReg) + if home, err := os.UserHomeDir(); err == nil { + customDir := filepath.Join(home, ".config", "ai-agent", "commands") + command.RegisterCustomCommands(cmdReg, customDir) + } + modelList := []string{modelName} + var agentList []string + if agentsDir != nil { + for _, a := range agentsDir.ListAgents() { + agentList = append(agentList, a.Name) + } + } + completer := tui.NewCompleter(cmdReg, modelList, skillMgr.Names(), agentList, registry) + logger, logFile, err := logging.NewSessionLogger() + if err != nil { + // Non-fatal; logging disabled. + } + if logFile != nil { + defer logFile.Close() + } + m := tui.New(ag, cmdReg, skillMgr, completer, modelManager, router, logger) + p := tea.NewProgram(m) + m.SetProgram(p) + initCtx, initCancel := context.WithCancel(context.Background()) + m.SetInitCancel(initCancel) + initDone := make(chan struct{}) + go func() { + defer close(initDone) + p.Send(tui.StartupStatusMsg{ID: "ollama", Label: "Ollama (" + modelName + ")", Status: "connecting"}) + if err := modelManager.Ping(); err != nil { + p.Send(tui.StartupStatusMsg{ID: "ollama", Label: "Ollama (" + modelName + ")", Status: "failed", Detail: err.Error()}) + p.Send(tui.ErrorMsg{Msg: fmt.Sprintf("ollama: %v\nhint: is `ollama serve` running? is %q pulled?", err, modelName)}) + } else { + p.Send(tui.StartupStatusMsg{ID: "ollama", Label: "Ollama (" + modelName + ")", Status: "connected"}) + } + if initCtx.Err() != nil { + return + } + if list, err := modelManager.ListModels(initCtx); err == nil && len(list) > 0 { + modelList = list + } + if initCtx.Err() != nil { + return + } + var wg sync.WaitGroup + for _, srv := range servers { + wg.Add(1) + go func(s config.ServerConfig) { + defer wg.Done() + p.Send(tui.StartupStatusMsg{ID: "mcp:" + s.Name, Label: s.Name, Status: "connecting"}) + if initCtx.Err() != nil { + return + } + toolCount, err := registry.ConnectServer(initCtx, s) + if err != nil { + p.Send(tui.StartupStatusMsg{ID: "mcp:" + s.Name, Label: s.Name, Status: "failed", Detail: err.Error()}) + } else { + p.Send(tui.StartupStatusMsg{ID: "mcp:" + s.Name, Label: s.Name, Status: "connected", Detail: fmt.Sprintf("%d tools", toolCount)}) + } + }(srv) + } + wg.Wait() + if initCtx.Err() != nil { + return + } + var iceEnabled bool + var iceConversations int + var iceSessionID string + if cfg.ICE.Enabled { + p.Send(tui.StartupStatusMsg{ID: "ice", Label: "ICE", Status: "connecting"}) + embedModel := cfg.ICE.EmbedModel + if embedModel == "" { + embedModel = cfg.Model.EmbedModel + } + iceEngine, err := ice.NewEngine(modelManager, memStore, ice.EngineConfig{ + EmbedModel: embedModel, + StorePath: cfg.ICE.StorePath, + NumCtx: cfg.Ollama.NumCtx, + }) + if err != nil { + p.Send(tui.StartupStatusMsg{ID: "ice", Label: "ICE", Status: "failed", Detail: err.Error()}) + } else { + ag.SetICEEngine(iceEngine) + iceEnabled = true + iceConversations = iceEngine.Store().Count() + iceSessionID = iceEngine.SessionID() + p.Send(tui.StartupStatusMsg{ID: "ice", Label: "ICE", Status: "connected", Detail: fmt.Sprintf("%d conversations", iceConversations)}) + } + } + if agentsDir != nil && agentsDir.GetGlobalInstructions() != "" { + ag.SetLoadedContext(agentsDir.GetGlobalInstructions()) + } + if data, err := os.ReadFile("AGENT.md"); err == nil { + ag.AppendLoadedContext("\n\n" + string(data)) + } + if cfg.AgentProfile != "" && agentsDir != nil { + if profile := agentsDir.GetAgent(cfg.AgentProfile); profile != nil { + if profile.SystemPrompt != "" { + ag.AppendLoadedContext("\n\n" + profile.SystemPrompt) + } + for _, skillName := range profile.Skills { + skillMgr.Activate(skillName) + } + ag.SetSkillContent(skillMgr.ActiveContent()) + } + } + var failedServers []tui.FailedServer + for _, fs := range registry.FailedServers() { + failedServers = append(failedServers, tui.FailedServer{ + Name: fs.Name, + Reason: fs.Reason, + }) + } + p.Send(tui.InitCompleteMsg{ + Model: modelName, + ModelList: modelList, + AgentProfile: cfg.AgentProfile, + AgentList: agentList, + ToolCount: ag.ToolCount(), + ServerCount: registry.ServerCount(), + NumCtx: cfg.Ollama.NumCtx, + FailedServers: failedServers, + ICEEnabled: iceEnabled, + ICEConversations: iceConversations, + ICESessionID: iceSessionID, + }) + }() + if _, err := p.Run(); err != nil { + log.Fatalf("tui: %v", err) + } + + initCancel() + <-initDone +} + +func handleLogs(args []string) { + follow := false + for _, arg := range args { + if arg == "-f" { + follow = true + } + } + if follow { + latest, err := logging.LatestLogPath() + if err != nil { + fmt.Fprintf(os.Stderr, "logs: %v\n", err) + os.Exit(1) + } + fmt.Fprintf(os.Stderr, "following %s\n", latest) + tailBin, err := exec.LookPath("tail") + if err != nil { + fmt.Fprintf(os.Stderr, "logs: tail not found: %v\n", err) + os.Exit(1) + } + if err := syscall.Exec(tailBin, []string{"tail", "-f", latest}, os.Environ()); err != nil { + fmt.Fprintf(os.Stderr, "logs: exec tail: %v\n", err) + os.Exit(1) + } + return + } + entries, err := logging.ListLogs(20) + if err != nil { + fmt.Fprintf(os.Stderr, "logs: %v\n", err) + os.Exit(1) + } + if len(entries) == 0 { + fmt.Println("No log files found in", logging.LogDir()) + return + } + fmt.Printf("Recent sessions (%s):\n\n", logging.LogDir()) + for _, e := range entries { + name := filepath.Base(e.Path) + sizeKB := float64(e.Size) / 1024 + fmt.Printf(" %-30s %s %6.1f KB\n", name, e.ModTime.Format("2006-01-02 15:04:05"), sizeKB) + } + fmt.Printf("\nTip: run `ai-agent logs -f` to follow the latest log.\n") +}