Mirror of https://github.com/openclaw/openclaw.git (synced 2026-02-19 18:39:20 -05:00)
fix: sanitize local embeddings (#5332) (thanks @akramcodez)
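For quick orientation, a minimal standalone sketch of the behavior this commit introduces; the helper body and the numeric example are taken from the diff and the new test below, and the snippet itself is illustrative rather than part of the repository:

// Non-finite components (NaN, ±Infinity) are zeroed before the L2 norm is computed,
// so local embeddings always come back finite.
function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
  const sanitized = vec.map((value) => (Number.isFinite(value) ? value : 0));
  const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0));
  // Near-zero vectors are returned as-is rather than divided by ~0.
  if (magnitude < 1e-10) {
    return sanitized;
  }
  return sanitized.map((value) => value / magnitude);
}

// Mirrors the new test: [1, NaN, Infinity, -Infinity] normalizes to [1, 0, 0, 0].
console.log(sanitizeAndNormalizeEmbedding([1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY]));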
@@ -327,9 +327,10 @@ describe("embedding provider local fallback", () => {
   });
 });

-describe("local embedding L2 normalization", () => {
+describe("local embedding normalization", () => {
   afterEach(() => {
     vi.resetAllMocks();
+    vi.resetModules();
     vi.unstubAllGlobals();
     vi.doUnmock("./node-llama.js");
   });
@@ -353,7 +354,6 @@ describe("local embedding L2 normalization", () => {
       }),
     }));

-    vi.resetModules();
     const { createEmbeddingProvider } = await import("./embeddings.js");

     const result = await createEmbeddingProvider({
@@ -389,7 +389,6 @@ describe("local embedding L2 normalization", () => {
       }),
     }));

-    vi.resetModules();
     const { createEmbeddingProvider } = await import("./embeddings.js");

     const result = await createEmbeddingProvider({
@@ -402,7 +401,41 @@ describe("local embedding L2 normalization", () => {
     const embedding = await result.provider.embedQuery("test");

     expect(embedding).toEqual([0, 0, 0, 0]);
-    expect(embedding.every((x) => !Number.isNaN(x))).toBe(true);
+    expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
   });

+  it("sanitizes non-finite values before normalization", async () => {
+    const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY];
+
+    vi.doMock("./node-llama.js", () => ({
+      importNodeLlamaCpp: async () => ({
+        getLlama: async () => ({
+          loadModel: vi.fn().mockResolvedValue({
+            createEmbeddingContext: vi.fn().mockResolvedValue({
+              getEmbeddingFor: vi.fn().mockResolvedValue({
+                vector: new Float32Array(nonFiniteVector),
+              }),
+            }),
+          }),
+        }),
+        resolveModelFile: async () => "/fake/model.gguf",
+        LlamaLogLevel: { error: 0 },
+      }),
+    }));
+
+    const { createEmbeddingProvider } = await import("./embeddings.js");
+
+    const result = await createEmbeddingProvider({
+      config: {} as never,
+      provider: "local",
+      model: "",
+      fallback: "none",
+    });
+
+    const embedding = await result.provider.embedQuery("test");
+
+    expect(embedding).toEqual([1, 0, 0, 0]);
+    expect(embedding.every((value) => Number.isFinite(value))).toBe(true);
+  });
+
   it("normalizes batch embeddings to magnitude ~1.0", async () => {
@@ -430,7 +463,6 @@ describe("local embedding L2 normalization", () => {
       }),
     }));

-    vi.resetModules();
     const { createEmbeddingProvider } = await import("./embeddings.js");

     const result = await createEmbeddingProvider({

@@ -6,12 +6,13 @@ import { createGeminiEmbeddingProvider, type GeminiEmbeddingClient } from "./emb
 import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js";
 import { importNodeLlamaCpp } from "./node-llama.js";

-function l2Normalize(vec: number[]): number[] {
-  const magnitude = Math.sqrt(vec.reduce((sum, x) => sum + x * x, 0));
-  if (!Number.isFinite(magnitude) || magnitude < 1e-10) {
-    return vec;
+function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
+  const sanitized = vec.map((value) => (Number.isFinite(value) ? value : 0));
+  const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0));
+  if (magnitude < 1e-10) {
+    return sanitized;
   }
-  return vec.map((x) => x / magnitude);
+  return sanitized.map((value) => value / magnitude);
 }

 export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
@@ -106,14 +107,14 @@ async function createLocalEmbeddingProvider(
     embedQuery: async (text) => {
       const ctx = await ensureContext();
       const embedding = await ctx.getEmbeddingFor(text);
-      return l2Normalize(Array.from(embedding.vector));
+      return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
     },
     embedBatch: async (texts) => {
       const ctx = await ensureContext();
       const embeddings = await Promise.all(
         texts.map(async (text) => {
           const embedding = await ctx.getEmbeddingFor(text);
-          return l2Normalize(Array.from(embedding.vector));
+          return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector));
         }),
       );
       return embeddings;
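As a side note on why the guarantee is useful: once every embedding is finite and L2-normalized, cosine similarity between two vectors reduces to a plain dot product, so NaN can no longer leak into similarity scores. A minimal sketch, assuming equal-length unit-norm inputs (dotProductSimilarity is a hypothetical helper, not part of this repository):

// Hypothetical helper: for unit-norm, finite vectors, cosine similarity
// is just the dot product, and the score is always a finite number in [-1, 1].
function dotProductSimilarity(a: number[], b: number[]): number {
  return a.reduce((sum, value, index) => sum + value * b[index], 0);
}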