127 lines
3.3 KiB
Go
127 lines
3.3 KiB
Go
package ice
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
|
|
"ai-agent/internal/memory"
|
|
)
|
|
|
|
type Assembler struct {
|
|
embedder *Embedder
|
|
convStore *Store
|
|
memStore *memory.Store
|
|
budgetCfg BudgetConfig
|
|
sessionID string
|
|
}
|
|
|
|
func (a *Assembler) Assemble(ctx context.Context, query string) (string, error) {
|
|
budget := a.budgetCfg.Calculate(0)
|
|
type convResult struct {
|
|
chunks []ContextChunk
|
|
err error
|
|
}
|
|
type memResult struct {
|
|
chunks []ContextChunk
|
|
}
|
|
convCh := make(chan convResult, 1)
|
|
memCh := make(chan memResult, 1)
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
chunks, err := a.retrieveConversations(ctx, query, budget.Conversation)
|
|
convCh <- convResult{chunks: chunks, err: err}
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
chunks := a.retrieveMemories(query, budget.Memory)
|
|
memCh <- memResult{chunks: chunks}
|
|
}()
|
|
wg.Wait()
|
|
close(convCh)
|
|
close(memCh)
|
|
cr := <-convCh
|
|
mr := <-memCh
|
|
if cr.err != nil {
|
|
return "", fmt.Errorf("conversation retrieval: %w", cr.err)
|
|
}
|
|
return formatContext(cr.chunks, mr.chunks), nil
|
|
}
|
|
|
|
func (a *Assembler) retrieveConversations(ctx context.Context, query string, tokenBudget int) ([]ContextChunk, error) {
|
|
if tokenBudget <= 0 {
|
|
return nil, nil
|
|
}
|
|
queryEmb, err := a.embedder.Embed(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
results := a.convStore.Search(queryEmb, a.sessionID, 20)
|
|
var chunks []ContextChunk
|
|
usedTokens := 0
|
|
for _, r := range results {
|
|
tokens := estimateTokens(r.Entry.Content)
|
|
if usedTokens+tokens > tokenBudget {
|
|
continue
|
|
}
|
|
chunks = append(chunks, ContextChunk{
|
|
Source: SourceConversation,
|
|
Content: r.Entry.Content,
|
|
Score: r.Score,
|
|
Tokens: tokens,
|
|
})
|
|
usedTokens += tokens
|
|
}
|
|
return chunks, nil
|
|
}
|
|
|
|
func (a *Assembler) retrieveMemories(query string, tokenBudget int) []ContextChunk {
|
|
if a.memStore == nil || tokenBudget <= 0 {
|
|
return nil
|
|
}
|
|
memories := a.memStore.Recall(query, 10)
|
|
var chunks []ContextChunk
|
|
usedTokens := 0
|
|
for _, m := range memories {
|
|
tokens := estimateTokens(m.Content)
|
|
if usedTokens+tokens > tokenBudget {
|
|
continue
|
|
}
|
|
content := m.Content
|
|
if len(m.Tags) > 0 {
|
|
content += " [" + strings.Join(m.Tags, ", ") + "]"
|
|
}
|
|
chunks = append(chunks, ContextChunk{
|
|
Source: SourceMemory,
|
|
Content: content,
|
|
Tokens: tokens,
|
|
})
|
|
usedTokens += tokens
|
|
}
|
|
return chunks
|
|
}
|
|
|
|
func formatContext(convChunks, memChunks []ContextChunk) string {
|
|
var sb strings.Builder
|
|
if len(convChunks) > 0 {
|
|
sb.WriteString("\n## Relevant Past Conversations\n\n")
|
|
for _, c := range convChunks {
|
|
sb.WriteString("- ")
|
|
sb.WriteString(c.Content)
|
|
sb.WriteString("\n")
|
|
}
|
|
}
|
|
if len(memChunks) > 0 {
|
|
sb.WriteString("\n## Remembered Facts\n\n")
|
|
for _, c := range memChunks {
|
|
sb.WriteString("- ")
|
|
sb.WriteString(c.Content)
|
|
sb.WriteString("\n")
|
|
}
|
|
}
|
|
return sb.String()
|
|
}
|