diff --git a/src/memory/embeddings.test.ts b/src/memory/embeddings.test.ts index b45a58672e..7ef828f194 100644 --- a/src/memory/embeddings.test.ts +++ b/src/memory/embeddings.test.ts @@ -327,9 +327,10 @@ describe("embedding provider local fallback", () => { }); }); -describe("local embedding L2 normalization", () => { +describe("local embedding normalization", () => { afterEach(() => { vi.resetAllMocks(); + vi.resetModules(); vi.unstubAllGlobals(); vi.doUnmock("./node-llama.js"); }); @@ -353,7 +354,6 @@ describe("local embedding L2 normalization", () => { }), })); - vi.resetModules(); const { createEmbeddingProvider } = await import("./embeddings.js"); const result = await createEmbeddingProvider({ @@ -389,7 +389,6 @@ describe("local embedding L2 normalization", () => { }), })); - vi.resetModules(); const { createEmbeddingProvider } = await import("./embeddings.js"); const result = await createEmbeddingProvider({ @@ -402,7 +401,41 @@ describe("local embedding L2 normalization", () => { const embedding = await result.provider.embedQuery("test"); expect(embedding).toEqual([0, 0, 0, 0]); - expect(embedding.every((x) => !Number.isNaN(x))).toBe(true); + expect(embedding.every((value) => Number.isFinite(value))).toBe(true); + }); + + it("sanitizes non-finite values before normalization", async () => { + const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY]; + + vi.doMock("./node-llama.js", () => ({ + importNodeLlamaCpp: async () => ({ + getLlama: async () => ({ + loadModel: vi.fn().mockResolvedValue({ + createEmbeddingContext: vi.fn().mockResolvedValue({ + getEmbeddingFor: vi.fn().mockResolvedValue({ + vector: new Float32Array(nonFiniteVector), + }), + }), + }), + }), + resolveModelFile: async () => "/fake/model.gguf", + LlamaLogLevel: { error: 0 }, + }), + })); + + const { createEmbeddingProvider } = await import("./embeddings.js"); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "local", + model: "", + fallback: "none", + }); + + const embedding = await result.provider.embedQuery("test"); + + expect(embedding).toEqual([1, 0, 0, 0]); + expect(embedding.every((value) => Number.isFinite(value))).toBe(true); }); it("normalizes batch embeddings to magnitude ~1.0", async () => { @@ -430,7 +463,6 @@ describe("local embedding L2 normalization", () => { }), })); - vi.resetModules(); const { createEmbeddingProvider } = await import("./embeddings.js"); const result = await createEmbeddingProvider({ diff --git a/src/memory/embeddings.ts b/src/memory/embeddings.ts index 02ba73ab1b..a2783a1349 100644 --- a/src/memory/embeddings.ts +++ b/src/memory/embeddings.ts @@ -6,12 +6,13 @@ import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./emb import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js"; import { importNodeLlamaCpp } from "./node-llama.js"; -function l2Normalize(vec: number[]): number[] { - const magnitude = Math.sqrt(vec.reduce((sum, x) => sum + x * x, 0)); - if (!Number.isFinite(magnitude) || magnitude < 1e-10) { - return vec; +function sanitizeAndNormalizeEmbedding(vec: number[]): number[] { + const sanitized = vec.map((value) => (Number.isFinite(value) ? value : 0)); + const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0)); + if (magnitude < 1e-10) { + return sanitized; } - return vec.map((x) => x / magnitude); + return sanitized.map((value) => value / magnitude); } export type { GeminiEmbeddingClient } from "./embeddings-gemini.js"; @@ -106,14 +107,14 @@ async function createLocalEmbeddingProvider( embedQuery: async (text) => { const ctx = await ensureContext(); const embedding = await ctx.getEmbeddingFor(text); - return l2Normalize(Array.from(embedding.vector)); + return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector)); }, embedBatch: async (texts) => { const ctx = await ensureContext(); const embeddings = await Promise.all( texts.map(async (text) => { const embedding = await ctx.getEmbeddingFor(text); - return l2Normalize(Array.from(embedding.vector)); + return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector)); }), ); return embeddings;