mirror of
https://github.com/openclaw/openclaw.git
synced 2026-04-03 03:03:24 -04:00
memory-neo4j: add context-length-aware embedding truncation
This commit is contained in:
@@ -5,7 +5,13 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, afterEach } from "vitest";
|
||||
import { memoryNeo4jConfigSchema, vectorDimsForModel, resolveExtractionConfig } from "./config.js";
|
||||
import {
|
||||
memoryNeo4jConfigSchema,
|
||||
vectorDimsForModel,
|
||||
contextLengthForModel,
|
||||
DEFAULT_EMBEDDING_CONTEXT_LENGTH,
|
||||
resolveExtractionConfig,
|
||||
} from "./config.js";
|
||||
|
||||
// ============================================================================
|
||||
// memoryNeo4jConfigSchema.parse()
|
||||
@@ -547,3 +553,57 @@ describe("resolveExtractionConfig", () => {
|
||||
expect(config.maxRetries).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// contextLengthForModel()
|
||||
// ============================================================================
|
||||
|
||||
describe("contextLengthForModel", () => {
|
||||
describe("exact match", () => {
|
||||
it("should return 512 for mxbai-embed-large", () => {
|
||||
expect(contextLengthForModel("mxbai-embed-large")).toBe(512);
|
||||
});
|
||||
|
||||
it("should return 8191 for text-embedding-3-small (OpenAI)", () => {
|
||||
expect(contextLengthForModel("text-embedding-3-small")).toBe(8191);
|
||||
});
|
||||
|
||||
it("should return 8191 for text-embedding-3-large (OpenAI)", () => {
|
||||
expect(contextLengthForModel("text-embedding-3-large")).toBe(8191);
|
||||
});
|
||||
|
||||
it("should return 8192 for nomic-embed-text", () => {
|
||||
expect(contextLengthForModel("nomic-embed-text")).toBe(8192);
|
||||
});
|
||||
|
||||
it("should return 256 for all-minilm", () => {
|
||||
expect(contextLengthForModel("all-minilm")).toBe(256);
|
||||
});
|
||||
});
|
||||
|
||||
describe("prefix match", () => {
|
||||
it("should match mxbai-embed-large-8k:latest via prefix to 8192", () => {
|
||||
expect(contextLengthForModel("mxbai-embed-large-8k:latest")).toBe(8192);
|
||||
});
|
||||
|
||||
it("should match nomic-embed-text:v1.5 via prefix to 8192", () => {
|
||||
expect(contextLengthForModel("nomic-embed-text:v1.5")).toBe(8192);
|
||||
});
|
||||
});
|
||||
|
||||
describe("unknown model fallback", () => {
|
||||
it("should return DEFAULT_EMBEDDING_CONTEXT_LENGTH for unknown model", () => {
|
||||
expect(contextLengthForModel("some-unknown-model")).toBe(DEFAULT_EMBEDDING_CONTEXT_LENGTH);
|
||||
});
|
||||
|
||||
it("should return 512 as the default context length", () => {
|
||||
// Verify the default value itself is 512
|
||||
expect(DEFAULT_EMBEDDING_CONTEXT_LENGTH).toBe(512);
|
||||
expect(contextLengthForModel("some-unknown-model")).toBe(512);
|
||||
});
|
||||
|
||||
it("should return default for empty string", () => {
|
||||
expect(contextLengthForModel("")).toBe(DEFAULT_EMBEDDING_CONTEXT_LENGTH);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -93,6 +93,36 @@ export function vectorDimsForModel(model: string): number {
|
||||
return DEFAULT_EMBEDDING_DIMS;
|
||||
}
|
||||
|
||||
/** Max input token lengths for known embedding models. */
|
||||
export const EMBEDDING_CONTEXT_LENGTHS: Record<string, number> = {
|
||||
// OpenAI models
|
||||
"text-embedding-3-small": 8191,
|
||||
"text-embedding-3-large": 8191,
|
||||
// Ollama models
|
||||
"mxbai-embed-large": 512,
|
||||
"mxbai-embed-large-2k": 2048,
|
||||
"mxbai-embed-large-8k": 8192,
|
||||
"nomic-embed-text": 8192,
|
||||
"all-minilm": 256,
|
||||
};
|
||||
|
||||
/** Conservative default for unknown models. */
|
||||
export const DEFAULT_EMBEDDING_CONTEXT_LENGTH = 512;
|
||||
|
||||
export function contextLengthForModel(model: string): number {
|
||||
if (EMBEDDING_CONTEXT_LENGTHS[model]) {
|
||||
return EMBEDDING_CONTEXT_LENGTHS[model];
|
||||
}
|
||||
// Prefer longest matching prefix (e.g. "mxbai-embed-large-8k" over "mxbai-embed-large")
|
||||
let best: { len: number; keyLen: number } | undefined;
|
||||
for (const [known, len] of Object.entries(EMBEDDING_CONTEXT_LENGTHS)) {
|
||||
if (model.startsWith(known) && (!best || known.length > best.keyLen)) {
|
||||
best = { len, keyLen: known.length };
|
||||
}
|
||||
}
|
||||
return best?.len ?? DEFAULT_EMBEDDING_CONTEXT_LENGTH;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve ${ENV_VAR} references in string values.
|
||||
*/
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* Tests the Embeddings class with mocked OpenAI client and mocked fetch for Ollama.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, afterEach } from "vitest";
|
||||
import { describe, it, expect, vi, afterEach, beforeEach } from "vitest";
|
||||
|
||||
// ============================================================================
|
||||
// Constructor
|
||||
@@ -190,3 +190,111 @@ describe("Embeddings - embedBatch", () => {
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// ============================================================================
|
||||
// Ollama context-length truncation
|
||||
// ============================================================================
|
||||
|
||||
describe("Embeddings - Ollama context-length truncation", () => {
|
||||
const originalFetch = globalThis.fetch;
|
||||
|
||||
beforeEach(() => {
|
||||
globalThis.fetch = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({ embeddings: [[0.1, 0.2, 0.3]] }),
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
globalThis.fetch = originalFetch;
|
||||
});
|
||||
|
||||
it("should truncate long input before calling Ollama embed", async () => {
|
||||
const { Embeddings } = await import("./embeddings.js");
|
||||
const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
|
||||
|
||||
// mxbai-embed-large context length is 512, so maxChars = 512 * 3 = 1536
|
||||
// Create input that exceeds the limit
|
||||
const longText = "word ".repeat(500); // ~2500 chars, well above 1536
|
||||
await emb.embed(longText);
|
||||
|
||||
// Verify the text sent to Ollama was truncated
|
||||
const call = (globalThis.fetch as ReturnType<typeof vi.fn>).mock.calls[0];
|
||||
const body = JSON.parse(call[1].body as string);
|
||||
expect(body.input.length).toBeLessThanOrEqual(512 * 3);
|
||||
});
|
||||
|
||||
it("should truncate at word boundary (not mid-word)", async () => {
|
||||
const { Embeddings } = await import("./embeddings.js");
|
||||
const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
|
||||
|
||||
// maxChars for mxbai-embed-large = 512 * 3 = 1536
|
||||
// Each "abcdefghij " is 11 chars; 200 repeats = 2200 chars total (exceeds 1536)
|
||||
const longText = "abcdefghij ".repeat(200);
|
||||
await emb.embed(longText);
|
||||
|
||||
const call = (globalThis.fetch as ReturnType<typeof vi.fn>).mock.calls[0];
|
||||
const body = JSON.parse(call[1].body as string);
|
||||
const sentText = body.input as string;
|
||||
|
||||
expect(sentText.length).toBeLessThanOrEqual(512 * 3);
|
||||
// The truncation should land on a word boundary: the sent text should
|
||||
// be a prefix of the original that ends at a complete word (i.e. the
|
||||
// character after the sent text in the original should be a space).
|
||||
// Since the pattern is "abcdefghij " repeated, a word-boundary cut
|
||||
// means sentText ends with "abcdefghij" (no trailing partial word).
|
||||
expect(sentText).toMatch(/abcdefghij$/);
|
||||
// Verify it's a proper prefix of the original
|
||||
expect(longText.startsWith(sentText)).toBe(true);
|
||||
});
|
||||
|
||||
it("should pass short input through unchanged", async () => {
|
||||
const { Embeddings } = await import("./embeddings.js");
|
||||
const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
|
||||
|
||||
const shortText = "This is a short text that fits within context length.";
|
||||
await emb.embed(shortText);
|
||||
|
||||
const call = (globalThis.fetch as ReturnType<typeof vi.fn>).mock.calls[0];
|
||||
const body = JSON.parse(call[1].body as string);
|
||||
expect(body.input).toBe(shortText);
|
||||
});
|
||||
|
||||
it("should use model-specific context length for truncation", async () => {
|
||||
const { Embeddings } = await import("./embeddings.js");
|
||||
// nomic-embed-text has context length 8192, maxChars = 8192 * 3 = 24576
|
||||
const emb = new Embeddings(undefined, "nomic-embed-text", "ollama");
|
||||
|
||||
// Create text that exceeds mxbai limit (1536) but fits nomic limit (24576)
|
||||
const mediumText = "hello ".repeat(400); // ~2400 chars
|
||||
await emb.embed(mediumText);
|
||||
|
||||
const call = (globalThis.fetch as ReturnType<typeof vi.fn>).mock.calls[0];
|
||||
const body = JSON.parse(call[1].body as string);
|
||||
// Should NOT be truncated since 2400 < 24576
|
||||
expect(body.input).toBe(mediumText);
|
||||
});
|
||||
|
||||
it("should truncate each item individually in embedBatch", async () => {
|
||||
const { Embeddings } = await import("./embeddings.js");
|
||||
const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
|
||||
|
||||
// maxChars for mxbai-embed-large = 512 * 3 = 1536
|
||||
const longText = "word ".repeat(500); // ~2500 chars, exceeds limit
|
||||
const shortText = "short text"; // well under limit
|
||||
|
||||
await emb.embedBatch([longText, shortText]);
|
||||
|
||||
const calls = (globalThis.fetch as ReturnType<typeof vi.fn>).mock.calls;
|
||||
expect(calls).toHaveLength(2);
|
||||
|
||||
// First call: long text should be truncated
|
||||
const body1 = JSON.parse(calls[0][1].body as string);
|
||||
expect(body1.input.length).toBeLessThanOrEqual(512 * 3);
|
||||
expect(body1.input.length).toBeLessThan(longText.length);
|
||||
|
||||
// Second call: short text should pass through unchanged
|
||||
const body2 = JSON.parse(calls[1][1].body as string);
|
||||
expect(body2.input).toBe(shortText);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import OpenAI from "openai";
|
||||
import type { EmbeddingProvider } from "./config.js";
|
||||
import { contextLengthForModel } from "./config.js";
|
||||
|
||||
type Logger = {
|
||||
info: (msg: string) => void;
|
||||
@@ -19,6 +20,7 @@ export class Embeddings {
|
||||
private readonly provider: EmbeddingProvider;
|
||||
private readonly baseUrl: string;
|
||||
private readonly logger: Logger | undefined;
|
||||
private readonly contextLength: number;
|
||||
|
||||
constructor(
|
||||
private readonly apiKey: string | undefined,
|
||||
@@ -30,6 +32,7 @@ export class Embeddings {
|
||||
this.provider = provider;
|
||||
this.baseUrl = baseUrl ?? (provider === "ollama" ? "http://localhost:11434" : "");
|
||||
this.logger = logger;
|
||||
this.contextLength = contextLengthForModel(model);
|
||||
|
||||
if (provider === "openai") {
|
||||
if (!apiKey) {
|
||||
@@ -39,14 +42,37 @@ export class Embeddings {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncate text to fit within the model's context length.
|
||||
* Uses a conservative ~3 chars/token estimate to leave headroom.
|
||||
* Truncates at a word boundary when possible.
|
||||
*/
|
||||
private truncateToContext(text: string): string {
|
||||
const maxChars = this.contextLength * 3;
|
||||
if (text.length <= maxChars) return text;
|
||||
|
||||
// Try to truncate at a word boundary
|
||||
let truncated = text.slice(0, maxChars);
|
||||
const lastSpace = truncated.lastIndexOf(" ");
|
||||
if (lastSpace > maxChars * 0.8) {
|
||||
truncated = truncated.slice(0, lastSpace);
|
||||
}
|
||||
|
||||
this.logger?.debug?.(
|
||||
`memory-neo4j: truncated embedding input from ${text.length} to ${truncated.length} chars (model context: ${this.contextLength} tokens)`,
|
||||
);
|
||||
return truncated;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an embedding vector for a single text.
|
||||
*/
|
||||
async embed(text: string): Promise<number[]> {
|
||||
const input = this.truncateToContext(text);
|
||||
if (this.provider === "ollama") {
|
||||
return this.embedOllama(text);
|
||||
return this.embedOllama(input);
|
||||
}
|
||||
return this.embedOpenAI(text);
|
||||
return this.embedOpenAI(input);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -61,9 +87,11 @@ export class Embeddings {
|
||||
return [];
|
||||
}
|
||||
|
||||
const truncated = texts.map((t) => this.truncateToContext(t));
|
||||
|
||||
if (this.provider === "ollama") {
|
||||
// Ollama doesn't support batch, so we do sequential with resilient error handling
|
||||
const results = await Promise.allSettled(texts.map((t) => this.embedOllama(t)));
|
||||
const results = await Promise.allSettled(truncated.map((t) => this.embedOllama(t)));
|
||||
const embeddings: number[][] = [];
|
||||
let failures = 0;
|
||||
|
||||
@@ -90,7 +118,7 @@ export class Embeddings {
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
return this.embedBatchOpenAI(texts);
|
||||
return this.embedBatchOpenAI(truncated);
|
||||
}
|
||||
|
||||
private async embedOpenAI(text: string): Promise<number[]> {
|
||||
|
||||
@@ -805,9 +805,10 @@ const memoryNeo4jPlugin = {
|
||||
|
||||
const agentId = ctx.agentId || "default";
|
||||
|
||||
// Truncate prompt to avoid exceeding embedding model context length
|
||||
// ~6000 chars is safe for most embedding models (leaves headroom for 2k tokens)
|
||||
const MAX_QUERY_CHARS = 6000;
|
||||
// ~1500 chars is a safe ceiling for most embedding models (~500 tokens).
|
||||
// Models with larger context (8k+) can handle more, but recall queries
|
||||
// don't benefit from very long inputs — the embedding quality plateaus.
|
||||
const MAX_QUERY_CHARS = 1500;
|
||||
const query =
|
||||
event.prompt.length > MAX_QUERY_CHARS
|
||||
? event.prompt.slice(0, MAX_QUERY_CHARS)
|
||||
|
||||
Reference in New Issue
Block a user