260 lines
6.1 KiB
Go
260 lines
6.1 KiB
Go
package memory
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Memory is one remembered note plus its bookkeeping timestamps. The JSON
// tags below define the on-disk format of the store file.
type Memory struct {
	// ID identifies the memory; Store.Save assigns it.
	ID int `json:"id"`
	// Content is the free-form text of the memory.
	Content string `json:"content"`
	// Tags are optional labels; the key is omitted from JSON when empty.
	Tags []string `json:"tags,omitempty"`
	// CreatedAt is when the memory was first saved.
	CreatedAt time.Time `json:"created_at"`
	// LastUsed is refreshed by the store when the memory is recalled or updated.
	LastUsed time.Time `json:"last_used"`
}
|
|
|
|
type Store struct {
|
|
mu sync.Mutex
|
|
path string
|
|
memories []Memory
|
|
nextID int
|
|
}
|
|
|
|
func NewStore(path string) *Store {
|
|
if path == "" {
|
|
home, err := os.UserHomeDir()
|
|
if err != nil {
|
|
home = "."
|
|
}
|
|
path = filepath.Join(home, ".config", "ai-agent", "memories.json")
|
|
}
|
|
s := &Store{path: path}
|
|
s.load()
|
|
return s
|
|
}
|
|
|
|
func (s *Store) Save(content string, tags []string) (int, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.nextID++
|
|
mem := Memory{
|
|
ID: s.nextID,
|
|
Content: content,
|
|
Tags: tags,
|
|
CreatedAt: time.Now(),
|
|
LastUsed: time.Now(),
|
|
}
|
|
s.memories = append(s.memories, mem)
|
|
if err := s.persist(); err != nil {
|
|
return 0, err
|
|
}
|
|
return mem.ID, nil
|
|
}
|
|
|
|
func (s *Store) Recall(query string, maxResults int) []Memory {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if maxResults <= 0 {
|
|
maxResults = 5
|
|
}
|
|
queryLower := strings.ToLower(query)
|
|
words := strings.Fields(queryLower)
|
|
type scored struct {
|
|
mem Memory
|
|
score int
|
|
}
|
|
var results []scored
|
|
for i := range s.memories {
|
|
mem := s.memories[i]
|
|
score := 0
|
|
contentLower := strings.ToLower(mem.Content)
|
|
for _, w := range words {
|
|
if strings.Contains(contentLower, w) {
|
|
score += 2
|
|
}
|
|
}
|
|
for _, tag := range mem.Tags {
|
|
tagLower := strings.ToLower(tag)
|
|
for _, w := range words {
|
|
if strings.Contains(tagLower, w) {
|
|
score += 3
|
|
}
|
|
}
|
|
}
|
|
if score > 0 {
|
|
results = append(results, scored{mem: mem, score: score})
|
|
}
|
|
}
|
|
sort.Slice(results, func(i, j int) bool {
|
|
if results[i].score != results[j].score {
|
|
return results[i].score > results[j].score
|
|
}
|
|
return results[i].mem.LastUsed.After(results[j].mem.LastUsed)
|
|
})
|
|
if len(results) > maxResults {
|
|
results = results[:maxResults]
|
|
}
|
|
now := time.Now()
|
|
out := make([]Memory, len(results))
|
|
for i, r := range results {
|
|
out[i] = r.mem
|
|
for j := range s.memories {
|
|
if s.memories[j].ID == r.mem.ID {
|
|
s.memories[j].LastUsed = now
|
|
break
|
|
}
|
|
}
|
|
}
|
|
_ = s.persist()
|
|
return out
|
|
}
|
|
|
|
func (s *Store) Recent(n int) []Memory {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if len(s.memories) == 0 {
|
|
return nil
|
|
}
|
|
sorted := make([]Memory, len(s.memories))
|
|
copy(sorted, s.memories)
|
|
sort.Slice(sorted, func(i, j int) bool {
|
|
return sorted[i].LastUsed.After(sorted[j].LastUsed)
|
|
})
|
|
if n > len(sorted) {
|
|
n = len(sorted)
|
|
}
|
|
return sorted[:n]
|
|
}
|
|
|
|
func (s *Store) Count() int {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
return len(s.memories)
|
|
}
|
|
|
|
func (s *Store) Delete(id int) (bool, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
for i, mem := range s.memories {
|
|
if mem.ID == id {
|
|
s.memories = append(s.memories[:i], s.memories[i+1:]...)
|
|
return true, s.persist()
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (s *Store) DeleteByTag(tag string) (int, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
tagLower := strings.ToLower(tag)
|
|
var remaining []Memory
|
|
deleted := 0
|
|
for _, mem := range s.memories {
|
|
found := false
|
|
for _, t := range mem.Tags {
|
|
if strings.ToLower(t) == tagLower {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if found {
|
|
deleted++
|
|
} else {
|
|
remaining = append(remaining, mem)
|
|
}
|
|
}
|
|
s.memories = remaining
|
|
if deleted > 0 {
|
|
return deleted, s.persist()
|
|
}
|
|
return deleted, nil
|
|
}
|
|
|
|
func (s *Store) Update(id int, content string, tags []string) (bool, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
for i, mem := range s.memories {
|
|
if mem.ID == id {
|
|
if content != "" {
|
|
s.memories[i].Content = content
|
|
}
|
|
if tags != nil {
|
|
s.memories[i].Tags = tags
|
|
}
|
|
s.memories[i].LastUsed = time.Now()
|
|
return true, s.persist()
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (s *Store) Prune(olderThan time.Duration) (int, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
cutoff := time.Now().Add(-olderThan)
|
|
var remaining []Memory
|
|
deleted := 0
|
|
for _, mem := range s.memories {
|
|
if mem.CreatedAt.Before(cutoff) {
|
|
deleted++
|
|
} else {
|
|
remaining = append(remaining, mem)
|
|
}
|
|
}
|
|
s.memories = remaining
|
|
if deleted > 0 {
|
|
return deleted, s.persist()
|
|
}
|
|
return deleted, nil
|
|
}
|
|
|
|
func (s *Store) Get(id int) (Memory, bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
for _, mem := range s.memories {
|
|
if mem.ID == id {
|
|
return mem, true
|
|
}
|
|
}
|
|
return Memory{}, false
|
|
}
|
|
|
|
func (s *Store) load() {
|
|
data, err := os.ReadFile(s.path)
|
|
if err != nil {
|
|
return
|
|
}
|
|
var memories []Memory
|
|
if err := json.Unmarshal(data, &memories); err != nil {
|
|
return
|
|
}
|
|
s.memories = memories
|
|
for _, m := range s.memories {
|
|
if m.ID > s.nextID {
|
|
s.nextID = m.ID
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Store) persist() error {
|
|
dir := filepath.Dir(s.path)
|
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
|
return fmt.Errorf("create memory dir: %w", err)
|
|
}
|
|
data, err := json.MarshalIndent(s.memories, "", " ")
|
|
if err != nil {
|
|
return fmt.Errorf("marshal memories: %w", err)
|
|
}
|
|
if err := os.WriteFile(s.path, data, 0o644); err != nil {
|
|
return fmt.Errorf("write memories: %w", err)
|
|
}
|
|
return nil
|
|
}
|