package ice
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"ai-agent/internal/llm"
|
|
)
|
|
|
|
const (
	// defaultEmbedModel is the embedding model used when the caller
	// passes an empty model name to NewEmbedder.
	defaultEmbedModel = "nomic-embed-text"
	// maxBatchSize caps how many texts are sent to the LLM client in a
	// single Embed call; larger inputs are chunked by EmbedBatch.
	maxBatchSize = 32
)
|
|
|
|
// Embedder computes text embeddings by delegating to an llm.Client,
// batching large inputs into chunks of maxBatchSize.
type Embedder struct {
	client llm.Client // backend that produces the embedding vectors
	model  string     // embedding model name passed to the client
}
|
|
|
|
func NewEmbedder(client llm.Client, model string) *Embedder {
|
|
if model == "" {
|
|
model = defaultEmbedModel
|
|
}
|
|
return &Embedder{client: client, model: model}
|
|
}
|
|
|
|
func (e *Embedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
|
vecs, err := e.EmbedBatch(ctx, []string{text})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(vecs) == 0 {
|
|
return nil, fmt.Errorf("empty embedding response")
|
|
}
|
|
return vecs[0], nil
|
|
}
|
|
|
|
func (e *Embedder) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
|
|
if len(texts) == 0 {
|
|
return nil, nil
|
|
}
|
|
var all [][]float32
|
|
for i := 0; i < len(texts); i += maxBatchSize {
|
|
end := i + maxBatchSize
|
|
if end > len(texts) {
|
|
end = len(texts)
|
|
}
|
|
batch := texts[i:end]
|
|
vecs, err := e.client.Embed(ctx, e.model, batch)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("embed batch [%d:%d]: %w", i, end, err)
|
|
}
|
|
all = append(all, vecs...)
|
|
}
|
|
return all, nil
|
|
}
|