ai-agent/internal/ice/store_test.go
admin 8dc496b626
Some checks failed
CI / test (push) Has been cancelled
Release / release (push) Failing after 4m36s
first commit
2026-03-08 15:40:34 +07:00

233 lines
6.6 KiB
Go

package ice
import (
"math"
"os"
"path/filepath"
"testing"
)
func TestCosineSimilarity(t *testing.T) {
tests := []struct {
name string
a, b []float32
want float32
tol float32
}{
{
name: "identical vectors",
a: []float32{1, 2, 3},
b: []float32{1, 2, 3},
want: 1.0,
tol: 1e-6,
},
{
name: "orthogonal vectors",
a: []float32{1, 0},
b: []float32{0, 1},
want: 0.0,
tol: 1e-6,
},
{
name: "opposite vectors",
a: []float32{1, 0},
b: []float32{-1, 0},
want: -1.0,
tol: 1e-6,
},
{
name: "different lengths returns 0",
a: []float32{1, 0},
b: []float32{1, 0, 0},
want: 0,
tol: 0,
},
{
name: "zero vector returns 0",
a: []float32{0, 0},
b: []float32{1, 1},
want: 0,
tol: 0,
},
{
name: "known value with tolerance",
a: []float32{1, 1},
b: []float32{1, 0},
// 1/(sqrt(2)*1) ≈ 0.7071
want: float32(1.0 / math.Sqrt(2)),
tol: 1e-4,
},
{
name: "empty vectors",
a: []float32{},
b: []float32{},
want: 0,
tol: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := cosineSimilarity(tt.a, tt.b)
diff := got - tt.want
if diff < 0 {
diff = -diff
}
if diff > tt.tol {
t.Errorf("cosineSimilarity(%v, %v) = %f, want %f (±%f)",
tt.a, tt.b, got, tt.want, tt.tol)
}
})
}
}
func TestStore_Add_And_Count(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "store.json")
s := NewStore(path)
if s.Count() != 0 {
t.Fatalf("new store Count = %d, want 0", s.Count())
}
id1, err := s.Add("sess1", "user", "hello", []float32{1, 0, 0}, 0)
if err != nil {
t.Fatalf("Add returned error: %v", err)
}
if id1 != 1 {
t.Errorf("first Add returned id=%d, want 1", id1)
}
if s.Count() != 1 {
t.Errorf("Count after first Add = %d, want 1", s.Count())
}
id2, err := s.Add("sess1", "assistant", "world", []float32{0, 1, 0}, 1)
if err != nil {
t.Fatalf("Add returned error: %v", err)
}
if id2 != 2 {
t.Errorf("second Add returned id=%d, want 2", id2)
}
if s.Count() != 2 {
t.Errorf("Count after second Add = %d, want 2", s.Count())
}
}
func TestStore_Search(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "store.json")
s := NewStore(path)
// Add entries with known embeddings.
s.Add("sess1", "user", "entry A", []float32{1, 0, 0}, 0)
s.Add("sess1", "user", "entry B", []float32{0, 1, 0}, 1)
s.Add("sess2", "user", "entry C", []float32{0.9, 0.1, 0}, 0)
s.Add("sess2", "user", "entry D", []float32{0, 0, 1}, 1) // orthogonal to query
t.Run("similarity filtering and sorting", func(t *testing.T) {
// Query similar to entries A and C, exclude no session.
results := s.Search([]float32{1, 0, 0}, "", 10)
if len(results) == 0 {
t.Fatal("expected results, got 0")
}
// Entry A should be highest (identical to query).
if results[0].Entry.Content != "entry A" {
t.Errorf("top result = %q, want 'entry A'", results[0].Entry.Content)
}
})
t.Run("session exclusion", func(t *testing.T) {
results := s.Search([]float32{1, 0, 0}, "sess1", 10)
for _, r := range results {
if r.Entry.SessionID == "sess1" {
t.Errorf("excluded session sess1 should not appear in results")
}
}
})
t.Run("min threshold 0.3", func(t *testing.T) {
// Entry D: [0,0,1] is orthogonal to [1,0,0] → similarity 0.
results := s.Search([]float32{1, 0, 0}, "", 10)
for _, r := range results {
if r.Score < minSimilarityThreshold {
t.Errorf("result %q has score %f below threshold %f",
r.Entry.Content, r.Score, minSimilarityThreshold)
}
}
})
t.Run("topK limit", func(t *testing.T) {
results := s.Search([]float32{1, 0, 0}, "", 1)
if len(results) > 1 {
t.Errorf("topK=1 but got %d results", len(results))
}
})
t.Run("empty store returns nil", func(t *testing.T) {
emptyPath := filepath.Join(dir, "empty.json")
empty := NewStore(emptyPath)
results := empty.Search([]float32{1, 0}, "", 5)
if results != nil {
t.Errorf("empty store search should return nil, got %v", results)
}
})
t.Run("empty query embedding returns nil", func(t *testing.T) {
results := s.Search([]float32{}, "", 5)
if results != nil {
t.Errorf("empty query should return nil, got %v", results)
}
})
}
func TestStore_Flush_Persistence(t *testing.T) {
t.Run("round trip", func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "store.json")
s1 := NewStore(path)
s1.Add("sess1", "user", "hello", []float32{1, 0}, 0)
s1.Add("sess1", "assistant", "world", []float32{0, 1}, 1)
if err := s1.Flush(); err != nil {
t.Fatalf("Flush: %v", err)
}
// Reload from same path.
s2 := NewStore(path)
if s2.Count() != 2 {
t.Errorf("reloaded store Count = %d, want 2", s2.Count())
}
})
t.Run("corrupt JSON recovery", func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "store.json")
// Write corrupt JSON.
os.WriteFile(path, []byte("not valid json{{{"), 0o644)
s := NewStore(path)
if s.Count() != 0 {
t.Errorf("corrupt store Count = %d, want 0", s.Count())
}
})
t.Run("nextID restoration", func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "store.json")
s1 := NewStore(path)
s1.Add("sess1", "user", "first", []float32{1}, 0)
s1.Add("sess1", "user", "second", []float32{1}, 1)
s1.Flush()
s2 := NewStore(path)
id, _ := s2.Add("sess1", "user", "third", []float32{1}, 2)
if id != 3 {
t.Errorf("continued id = %d, want 3", id)
}
})
}