319 lines
8.7 KiB
Go
319 lines
8.7 KiB
Go
package config
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// TaskComplexity is a coarse label describing how demanding a user query is.
// ClassifyTask yields one of the four constants below; SelectModel hands the
// label (as a string) to the model configuration to pick a model.
type TaskComplexity string

// Complexity levels, listed from least to most demanding. ClassifyTask maps
// increasing heuristic scores onto them in this order.
const (
	// ComplexitySimple is the lowest-score bucket.
	ComplexitySimple TaskComplexity = "simple"
	// ComplexityMedium is the bucket just above simple.
	ComplexityMedium TaskComplexity = "medium"
	// ComplexityComplex is the bucket just below advanced.
	ComplexityComplex TaskComplexity = "complex"
	// ComplexityAdvanced is the highest-score bucket.
	ComplexityAdvanced TaskComplexity = "advanced"
)
|
|
|
|
// simpleIndicators are lower-case phrases hinting at a quick lookup or a
// trivial command. ClassifyTask subtracts 2 from its score for each one it
// finds in the lower-cased query.
var simpleIndicators = []string{
	// Questions and requests for explanation.
	"what is", "how do i", "explain", "what does",
	// Lookup verbs.
	"find", "search", "list", "show", "get",
	// Read/print style shell commands.
	"print", "echo", "read", "cat", "ls",
	// Explicit requests for brevity.
	"simple", "quick", "fast",
}
|
|
|
|
// mediumIndicators are lower-case phrases hinting at routine creation or
// modification work. ClassifyTask adds 1 to its score for each one it finds
// in the lower-cased query.
var mediumIndicators = []string{
	// Authoring verbs.
	"create", "write", "generate", "add", "modify",
	// Editing verbs.
	"change", "update", "fix", "refactor",
	// Code-level nouns.
	"function", "class", "variable", "test",
	// Filesystem and tooling nouns.
	"script", "command", "file", "directory",
}
|
|
|
|
// complexIndicators are lower-case phrases hinting at debugging, design, or
// analysis work. ClassifyTask adds 2 to its score for each one it finds in
// the lower-cased query.
var complexIndicators = []string{
	// Troubleshooting.
	"debug", "error", "bug", "issue", "problem",
	// Structural work.
	"refactor", "architecture", "design", "review",
	// Breadth of change.
	"multiple", "several", "across", "migrate",
	// Non-functional concerns.
	"optimize", "performance", "security",
	// Reasoning requests.
	"explain why", "analyze", "compare",
}
|
|
|
|
// advancedIndicators are lower-case phrases hinting at large, multi-step, or
// system-level work. ClassifyTask adds 3 to its score for each one it finds
// in the lower-cased query.
var advancedIndicators = []string{
	// Build-from-scratch requests.
	"build a", "create a", "implement", "develop",
	// System scope.
	"full stack", "system", "infrastructure",
	// Explicitly large tasks.
	"multi-step", "complex", "comprehensive",
	// Specialised high-effort work.
	"security audit", "architecture design",
}
|
|
|
|
// ModelPinger probes whether a model can currently be used.
//
// The router treats a nil error from PingModel as "model is available" and
// any non-nil error as "unavailable, try the next candidate".
type ModelPinger interface {
	// PingModel checks model, honouring ctx as the implementation sees fit.
	PingModel(ctx context.Context, model string) error
}
|
|
|
|
// ModelOverride is one entry in the router's override log: a query for which
// the user explicitly picked a model, together with the model the router
// would have chosen on its own.
type ModelOverride struct {
	// Query is the user's query text.
	Query string
	// UserModel is the model the user chose; RouterModel is the router's own pick.
	UserModel, RouterModel string
	// Timestamp is when the override was recorded.
	Timestamp time.Time
}
|
|
|
|
type Router struct {
|
|
config *ModelConfig
|
|
overrideLog []ModelOverride
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
func NewRouter(cfg *ModelConfig) *Router {
|
|
return &Router{
|
|
config: cfg,
|
|
overrideLog: make([]ModelOverride, 0),
|
|
}
|
|
}
|
|
|
|
func (r *Router) ClassifyTaskComplexity(query string) TaskComplexity {
|
|
return ClassifyTask(query)
|
|
}
|
|
|
|
func (r *Router) SelectModel(query string) string {
|
|
complexity := r.ClassifyTaskComplexity(query)
|
|
|
|
// Check learned patterns if we have enough data
|
|
wordComplexity := r.getLearnedPatterns()
|
|
if len(wordComplexity) > 0 {
|
|
words := strings.Fields(strings.ToLower(query))
|
|
|
|
// Count votes from learned patterns
|
|
complexityVotes := make(map[TaskComplexity]int)
|
|
for _, w := range words {
|
|
if len(w) >= 3 { // Skip short words
|
|
if c, ok := wordComplexity[w]; ok {
|
|
complexityVotes[c]++
|
|
}
|
|
}
|
|
}
|
|
|
|
// If strong learned signal (>30% words match a pattern), use it
|
|
if len(words) > 0 {
|
|
matchRatio := float64(complexityVotes[ComplexitySimple]+complexityVotes[ComplexityAdvanced]) / float64(len(words))
|
|
if matchRatio > 0.3 {
|
|
if complexityVotes[ComplexityAdvanced] > complexityVotes[ComplexitySimple] {
|
|
complexity = ComplexityAdvanced
|
|
} else if complexityVotes[ComplexitySimple] > complexityVotes[ComplexityAdvanced] {
|
|
complexity = ComplexitySimple
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return r.config.SelectModelForTask(string(complexity))
|
|
}
|
|
|
|
// RecordOverride logs when a user explicitly selects a model.
|
|
// This helps the router learn from user preferences.
|
|
func (r *Router) RecordOverride(query, userModel string) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
routerModel := r.SelectModel(query)
|
|
|
|
r.overrideLog = append(r.overrideLog, ModelOverride{
|
|
Query: query,
|
|
UserModel: userModel,
|
|
RouterModel: routerModel,
|
|
Timestamp: time.Now(),
|
|
})
|
|
|
|
// Keep last 100 overrides
|
|
if len(r.overrideLog) > 100 {
|
|
r.overrideLog = r.overrideLog[len(r.overrideLog)-100:]
|
|
}
|
|
}
|
|
|
|
// getLearnedPatterns analyzes override history to find word->complexity mappings.
|
|
func (r *Router) getLearnedPatterns() map[string]TaskComplexity {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
|
|
if len(r.overrideLog) < 3 {
|
|
return nil // Not enough data
|
|
}
|
|
|
|
wordCounts := make(map[string]map[TaskComplexity]int)
|
|
|
|
for _, o := range r.overrideLog {
|
|
if o.Query == "" || o.UserModel == "" {
|
|
continue
|
|
}
|
|
|
|
// Determine complexity from user-selected model
|
|
var complexity TaskComplexity
|
|
switch {
|
|
case strings.Contains(o.UserModel, "0.8") || strings.Contains(o.UserModel, "2b"):
|
|
complexity = ComplexitySimple
|
|
case strings.Contains(o.UserModel, "4b"):
|
|
complexity = ComplexityMedium
|
|
case strings.Contains(o.UserModel, "9b"):
|
|
complexity = ComplexityAdvanced
|
|
default:
|
|
continue
|
|
}
|
|
|
|
words := strings.Fields(strings.ToLower(o.Query))
|
|
for _, w := range words {
|
|
if len(w) < 3 {
|
|
continue // Skip short words
|
|
}
|
|
if _, ok := wordCounts[w]; !ok {
|
|
wordCounts[w] = make(map[TaskComplexity]int)
|
|
}
|
|
wordCounts[w][complexity]++
|
|
}
|
|
}
|
|
|
|
// For each word, find dominant complexity
|
|
wordComplexity := make(map[string]TaskComplexity)
|
|
for word, counts := range wordCounts {
|
|
var maxCount int
|
|
var dominant TaskComplexity
|
|
for c, cnt := range counts {
|
|
if cnt > maxCount {
|
|
maxCount = cnt
|
|
dominant = c
|
|
}
|
|
}
|
|
// Only use if we have enough samples (at least 2 overrides)
|
|
if maxCount >= 2 {
|
|
wordComplexity[word] = dominant
|
|
}
|
|
}
|
|
|
|
return wordComplexity
|
|
}
|
|
|
|
func (r *Router) GetFallbackChain(currentModel string) []string {
|
|
chain := r.config.FallbackChain
|
|
|
|
for i, model := range chain {
|
|
if model == currentModel {
|
|
return chain[i:]
|
|
}
|
|
}
|
|
|
|
return chain
|
|
}
|
|
|
|
func (r *Router) GetModelForCapability(capability ModelCapability) string {
|
|
for _, m := range r.config.Models {
|
|
if m.Capability == capability {
|
|
return m.Name
|
|
}
|
|
}
|
|
return r.config.DefaultModel
|
|
}
|
|
|
|
// SelectAvailableModel returns the first available model from the fallback chain.
|
|
// It checks each model in order and returns the first one that responds to a ping.
|
|
// If no models are available, returns the default model.
|
|
func (r *Router) SelectAvailableModel(ctx context.Context, pinger ModelPinger) string {
|
|
chain := r.config.FallbackChain
|
|
|
|
for _, model := range chain {
|
|
if err := pinger.PingModel(ctx, model); err == nil {
|
|
return model
|
|
}
|
|
}
|
|
|
|
// Fallback to default if none available
|
|
return r.config.DefaultModel
|
|
}
|
|
|
|
// SelectAvailableModelForTask returns the first available model for the given task complexity.
|
|
// It prioritizes models appropriate for the task, then falls back to larger models if unavailable.
|
|
func (r *Router) SelectAvailableModelForTask(ctx context.Context, pinger ModelPinger, query string) string {
|
|
// First, get the preferred model for this task
|
|
preferred := r.SelectModel(query)
|
|
|
|
// Check if preferred model is available
|
|
if err := pinger.PingModel(ctx, preferred); err == nil {
|
|
return preferred
|
|
}
|
|
|
|
// Try fallback chain
|
|
chain := r.GetFallbackChain(preferred)
|
|
for _, model := range chain {
|
|
if err := pinger.PingModel(ctx, model); err == nil {
|
|
return model
|
|
}
|
|
}
|
|
|
|
// Last resort: default model
|
|
return r.config.DefaultModel
|
|
}
|
|
|
|
func (r *Router) ForceModel(name string) (*Model, error) {
|
|
return r.config.GetModel(name)
|
|
}
|
|
|
|
func (r *Router) ListModels() []Model {
|
|
return r.config.Models
|
|
}
|
|
|
|
func (r *Router) GetDefaultModel() string {
|
|
return r.config.DefaultModel
|
|
}
|
|
|
|
func ClassifyTask(query string) TaskComplexity {
|
|
lowerQuery := strings.ToLower(query)
|
|
wordCount := len(strings.Fields(query))
|
|
|
|
score := 0
|
|
|
|
for _, indicator := range simpleIndicators {
|
|
if strings.Contains(lowerQuery, indicator) {
|
|
score -= 2
|
|
}
|
|
}
|
|
|
|
for _, indicator := range mediumIndicators {
|
|
if strings.Contains(lowerQuery, indicator) {
|
|
score += 1
|
|
}
|
|
}
|
|
|
|
for _, indicator := range complexIndicators {
|
|
if strings.Contains(lowerQuery, indicator) {
|
|
score += 2
|
|
}
|
|
}
|
|
|
|
for _, indicator := range advancedIndicators {
|
|
if strings.Contains(lowerQuery, indicator) {
|
|
score += 3
|
|
}
|
|
}
|
|
|
|
if wordCount > 50 {
|
|
score += 2
|
|
}
|
|
|
|
if strings.Contains(lowerQuery, "why") || strings.Contains(lowerQuery, "reason") {
|
|
score += 1
|
|
}
|
|
|
|
if strings.Contains(lowerQuery, "how") && wordCount > 10 {
|
|
score += 1
|
|
}
|
|
|
|
switch {
|
|
case score <= -2:
|
|
return ComplexitySimple
|
|
case score <= 1:
|
|
return ComplexityMedium
|
|
case score <= 4:
|
|
return ComplexityComplex
|
|
default:
|
|
return ComplexityAdvanced
|
|
}
|
|
}
|