fix: sanitize local embeddings (#5332) (thanks @akramcodez)

This commit is contained in:
Gustavo Madeira Santana
2026-02-01 22:51:38 -05:00
parent fc0fdebd8b
commit 171b133f64
2 changed files with 45 additions and 12 deletions

View File

@@ -327,9 +327,10 @@ describe("embedding provider local fallback", () => {
});
});
describe("local embedding L2 normalization", () => {
describe("local embedding normalization", () => {
afterEach(() => {
vi.resetAllMocks();
vi.resetModules();
vi.unstubAllGlobals();
vi.doUnmock("./node-llama.js");
});
@@ -353,7 +354,6 @@ describe("local embedding L2 normalization", () => {
}),
}));
vi.resetModules();
const { createEmbeddingProvider } = await import("./embeddings.js");
const result = await createEmbeddingProvider({
@@ -389,7 +389,6 @@ describe("local embedding L2 normalization", () => {
}),
}));
vi.resetModules();
const { createEmbeddingProvider } = await import("./embeddings.js");
const result = await createEmbeddingProvider({
@@ -402,7 +401,41 @@ describe("local embedding L2 normalization", () => {
const embedding = await result.provider.embedQuery("test");
expect(embedding).toEqual([0, 0, 0, 0]);
expect(embedding.every((x) => !Number.isNaN(x))).toBe(true);
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
});
it("sanitizes non-finite values before normalization", async () => {
const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY];
vi.doMock("./node-llama.js", () => ({
importNodeLlamaCpp: async () => ({
getLlama: async () => ({
loadModel: vi.fn().mockResolvedValue({
createEmbeddingContext: vi.fn().mockResolvedValue({
getEmbeddingFor: vi.fn().mockResolvedValue({
vector: new Float32Array(nonFiniteVector),
}),
}),
}),
}),
resolveModelFile: async () => "/fake/model.gguf",
LlamaLogLevel: { error: 0 },
}),
}));
const { createEmbeddingProvider } = await import("./embeddings.js");
const result = await createEmbeddingProvider({
config: {} as never,
provider: "local",
model: "",
fallback: "none",
});
const embedding = await result.provider.embedQuery("test");
expect(embedding).toEqual([1, 0, 0, 0]);
expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
});
it("normalizes batch embeddings to magnitude ~1.0", async () => {
@@ -430,7 +463,6 @@ describe("local embedding L2 normalization", () => {
}),
}));
vi.resetModules();
const { createEmbeddingProvider } = await import("./embeddings.js");
const result = await createEmbeddingProvider({

View File

@@ -6,12 +6,13 @@ import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./emb
import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
import { importNodeLlamaCpp } from "./node-llama.js";
function l2Normalize(vec: number[]): number[] {
const magnitude = Math.sqrt(vec.reduce((sum, x) => sum + x * x, 0));
if (!Number.isFinite(magnitude) || magnitude < 1e-10) {
return vec;
function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
const sanitized = vec.map((value) => (Number.isFinite(value) ? value : 0));
const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0));
if (magnitude < 1e-10) {
return sanitized;
}
return vec.map((x) => x / magnitude);
return sanitized.map((value) => value / magnitude);
}
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
@@ -106,14 +107,14 @@ async function createLocalEmbeddingProvider(
embedQuery: async (text) => {
const ctx = await ensureContext();
const embedding = await ctx.getEmbeddingFor(text);
return l2Normalize(Array.from(embedding.vector));
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
},
embedBatch: async (texts) => {
const ctx = await ensureContext();
const embeddings = await Promise.all(
texts.map(async (text) => {
const embedding = await ctx.getEmbeddingFor(text);
return l2Normalize(Array.from(embedding.vector));
return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
}),
);
return embeddings;