memory-neo4j: add context-length-aware embedding truncation

Tarun Sukhani
2026-02-05 13:25:10 +00:00
parent e65b052d27
commit 8e5fe5fc14
5 changed files with 236 additions and 9 deletions

View File

@@ -5,7 +5,13 @@
*/
import { describe, it, expect, afterEach } from "vitest";
-import { memoryNeo4jConfigSchema, vectorDimsForModel, resolveExtractionConfig } from "./config.js";
+import {
+  memoryNeo4jConfigSchema,
+  vectorDimsForModel,
+  contextLengthForModel,
+  DEFAULT_EMBEDDING_CONTEXT_LENGTH,
+  resolveExtractionConfig,
+} from "./config.js";
// ============================================================================
// memoryNeo4jConfigSchema.parse()
@@ -547,3 +553,57 @@ describe("resolveExtractionConfig", () => {
expect(config.maxRetries).toBe(2);
});
});
// ============================================================================
// contextLengthForModel()
// ============================================================================
describe("contextLengthForModel", () => {
describe("exact match", () => {
it("should return 512 for mxbai-embed-large", () => {
expect(contextLengthForModel("mxbai-embed-large")).toBe(512);
});
it("should return 8191 for text-embedding-3-small (OpenAI)", () => {
expect(contextLengthForModel("text-embedding-3-small")).toBe(8191);
});
it("should return 8191 for text-embedding-3-large (OpenAI)", () => {
expect(contextLengthForModel("text-embedding-3-large")).toBe(8191);
});
it("should return 8192 for nomic-embed-text", () => {
expect(contextLengthForModel("nomic-embed-text")).toBe(8192);
});
it("should return 256 for all-minilm", () => {
expect(contextLengthForModel("all-minilm")).toBe(256);
});
});
describe("prefix match", () => {
it("should match mxbai-embed-large-8k:latest via prefix to 8192", () => {
expect(contextLengthForModel("mxbai-embed-large-8k:latest")).toBe(8192);
});
it("should match nomic-embed-text:v1.5 via prefix to 8192", () => {
expect(contextLengthForModel("nomic-embed-text:v1.5")).toBe(8192);
});
});
describe("unknown model fallback", () => {
it("should return DEFAULT_EMBEDDING_CONTEXT_LENGTH for unknown model", () => {
expect(contextLengthForModel("some-unknown-model")).toBe(DEFAULT_EMBEDDING_CONTEXT_LENGTH);
});
it("should return 512 as the default context length", () => {
// Verify the default value itself is 512
expect(DEFAULT_EMBEDDING_CONTEXT_LENGTH).toBe(512);
expect(contextLengthForModel("some-unknown-model")).toBe(512);
});
it("should return default for empty string", () => {
expect(contextLengthForModel("")).toBe(DEFAULT_EMBEDDING_CONTEXT_LENGTH);
});
});
});

View File

@@ -93,6 +93,36 @@ export function vectorDimsForModel(model: string): number {
return DEFAULT_EMBEDDING_DIMS;
}
/** Max input token lengths for known embedding models. */
export const EMBEDDING_CONTEXT_LENGTHS: Record<string, number> = {
// OpenAI models
"text-embedding-3-small": 8191,
"text-embedding-3-large": 8191,
// Ollama models
"mxbai-embed-large": 512,
"mxbai-embed-large-2k": 2048,
"mxbai-embed-large-8k": 8192,
"nomic-embed-text": 8192,
"all-minilm": 256,
};
/** Conservative default for unknown models. */
export const DEFAULT_EMBEDDING_CONTEXT_LENGTH = 512;
export function contextLengthForModel(model: string): number {
if (EMBEDDING_CONTEXT_LENGTHS[model]) {
return EMBEDDING_CONTEXT_LENGTHS[model];
}
// Prefer longest matching prefix (e.g. "mxbai-embed-large-8k" over "mxbai-embed-large")
let best: { len: number; keyLen: number } | undefined;
for (const [known, len] of Object.entries(EMBEDDING_CONTEXT_LENGTHS)) {
if (model.startsWith(known) && (!best || known.length > best.keyLen)) {
best = { len, keyLen: known.length };
}
}
return best?.len ?? DEFAULT_EMBEDDING_CONTEXT_LENGTH;
}
/**
* Resolve ${ENV_VAR} references in string values.
*/
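
Taken together, the resolution order is: exact key, then longest matching prefix, then the 512-token default. A minimal standalone sketch of that order (illustration only, not part of the commit; SKETCH_TABLE and sketchLookup are hypothetical names, and the table is abbreviated):

// Sketch: same lookup order as contextLengthForModel, with a trimmed table.
const SKETCH_TABLE: Record<string, number> = {
  "mxbai-embed-large": 512,
  "mxbai-embed-large-8k": 8192,
};

function sketchLookup(model: string): number {
  if (SKETCH_TABLE[model]) return SKETCH_TABLE[model]; // 1. exact match
  let best: { len: number; keyLen: number } | undefined;
  for (const [known, len] of Object.entries(SKETCH_TABLE)) {
    // 2. longest matching prefix wins, so "-8k:latest" maps to 8192, not 512
    if (model.startsWith(known) && (!best || known.length > best.keyLen)) {
      best = { len, keyLen: known.length };
    }
  }
  return best?.len ?? 512; // 3. conservative default
}

sketchLookup("mxbai-embed-large");           // 512  (exact)
sketchLookup("mxbai-embed-large-8k:latest"); // 8192 (longest prefix)
sketchLookup("some-unknown-model");          // 512  (default)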

View File

@@ -4,7 +4,7 @@
* Tests the Embeddings class with mocked OpenAI client and mocked fetch for Ollama.
*/
-import { describe, it, expect, vi, afterEach } from "vitest";
+import { describe, it, expect, vi, afterEach, beforeEach } from "vitest";
// ============================================================================
// Constructor
@@ -190,3 +190,111 @@ describe("Embeddings - embedBatch", () => {
}
});
});
// ============================================================================
// Ollama context-length truncation
// ============================================================================
describe("Embeddings - Ollama context-length truncation", () => {
const originalFetch = globalThis.fetch;
beforeEach(() => {
globalThis.fetch = vi.fn().mockResolvedValue({
ok: true,
json: () => Promise.resolve({ embeddings: [[0.1, 0.2, 0.3]] }),
});
});
afterEach(() => {
globalThis.fetch = originalFetch;
});
it("should truncate long input before calling Ollama embed", async () => {
const { Embeddings } = await import("./embeddings.js");
const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
// mxbai-embed-large context length is 512, so maxChars = 512 * 3 = 1536
// Create input that exceeds the limit
const longText = "word ".repeat(500); // ~2500 chars, well above 1536
await emb.embed(longText);
// Verify the text sent to Ollama was truncated
const call = (globalThis.fetch as ReturnType<typeof vi.fn>).mock.calls[0];
const body = JSON.parse(call[1].body as string);
expect(body.input.length).toBeLessThanOrEqual(512 * 3);
});
it("should truncate at word boundary (not mid-word)", async () => {
const { Embeddings } = await import("./embeddings.js");
const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
// maxChars for mxbai-embed-large = 512 * 3 = 1536
// Each "abcdefghij " is 11 chars; 200 repeats = 2200 chars total (exceeds 1536)
const longText = "abcdefghij ".repeat(200);
await emb.embed(longText);
const call = (globalThis.fetch as ReturnType<typeof vi.fn>).mock.calls[0];
const body = JSON.parse(call[1].body as string);
const sentText = body.input as string;
expect(sentText.length).toBeLessThanOrEqual(512 * 3);
// The truncation should land on a word boundary: the sent text should
// be a prefix of the original that ends at a complete word (i.e. the
// character after the sent text in the original should be a space).
// Since the pattern is "abcdefghij " repeated, a word-boundary cut
// means sentText ends with "abcdefghij" (no trailing partial word).
expect(sentText).toMatch(/abcdefghij$/);
// Verify it's a proper prefix of the original
expect(longText.startsWith(sentText)).toBe(true);
});
it("should pass short input through unchanged", async () => {
const { Embeddings } = await import("./embeddings.js");
const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
const shortText = "This is a short text that fits within context length.";
await emb.embed(shortText);
const call = (globalThis.fetch as ReturnType<typeof vi.fn>).mock.calls[0];
const body = JSON.parse(call[1].body as string);
expect(body.input).toBe(shortText);
});
it("should use model-specific context length for truncation", async () => {
const { Embeddings } = await import("./embeddings.js");
// nomic-embed-text has context length 8192, maxChars = 8192 * 3 = 24576
const emb = new Embeddings(undefined, "nomic-embed-text", "ollama");
// Create text that exceeds mxbai limit (1536) but fits nomic limit (24576)
const mediumText = "hello ".repeat(400); // ~2400 chars
await emb.embed(mediumText);
const call = (globalThis.fetch as ReturnType<typeof vi.fn>).mock.calls[0];
const body = JSON.parse(call[1].body as string);
// Should NOT be truncated since 2400 < 24576
expect(body.input).toBe(mediumText);
});
it("should truncate each item individually in embedBatch", async () => {
const { Embeddings } = await import("./embeddings.js");
const emb = new Embeddings(undefined, "mxbai-embed-large", "ollama");
// maxChars for mxbai-embed-large = 512 * 3 = 1536
const longText = "word ".repeat(500); // ~2500 chars, exceeds limit
const shortText = "short text"; // well under limit
await emb.embedBatch([longText, shortText]);
const calls = (globalThis.fetch as ReturnType<typeof vi.fn>).mock.calls;
expect(calls).toHaveLength(2);
// First call: long text should be truncated
const body1 = JSON.parse(calls[0][1].body as string);
expect(body1.input.length).toBeLessThanOrEqual(512 * 3);
expect(body1.input.length).toBeLessThan(longText.length);
// Second call: short text should pass through unchanged
const body2 = JSON.parse(calls[1][1].body as string);
expect(body2.input).toBe(shortText);
});
});

View File

@@ -6,6 +6,7 @@
import OpenAI from "openai";
import type { EmbeddingProvider } from "./config.js";
import { contextLengthForModel } from "./config.js";
type Logger = {
info: (msg: string) => void;
@@ -19,6 +20,7 @@ export class Embeddings {
private readonly provider: EmbeddingProvider;
private readonly baseUrl: string;
private readonly logger: Logger | undefined;
private readonly contextLength: number;
constructor(
private readonly apiKey: string | undefined,
@@ -30,6 +32,7 @@ export class Embeddings {
this.provider = provider;
this.baseUrl = baseUrl ?? (provider === "ollama" ? "http://localhost:11434" : "");
this.logger = logger;
this.contextLength = contextLengthForModel(model);
if (provider === "openai") {
if (!apiKey) {
@@ -39,14 +42,37 @@ export class Embeddings {
}
}
/**
* Truncate text to fit within the model's context length.
* Uses a conservative ~3 chars/token estimate to leave headroom.
* Truncates at a word boundary when possible.
*/
private truncateToContext(text: string): string {
const maxChars = this.contextLength * 3;
if (text.length <= maxChars) return text;
// Try to truncate at a word boundary
let truncated = text.slice(0, maxChars);
const lastSpace = truncated.lastIndexOf(" ");
if (lastSpace > maxChars * 0.8) {
truncated = truncated.slice(0, lastSpace);
}
this.logger?.debug?.(
`memory-neo4j: truncated embedding input from ${text.length} to ${truncated.length} chars (model context: ${this.contextLength} tokens)`,
);
return truncated;
}
/**
* Generate an embedding vector for a single text.
*/
async embed(text: string): Promise<number[]> {
const input = this.truncateToContext(text);
if (this.provider === "ollama") {
-return this.embedOllama(text);
+return this.embedOllama(input);
}
-return this.embedOpenAI(text);
+return this.embedOpenAI(input);
}
/**
@@ -61,9 +87,11 @@ export class Embeddings {
return [];
}
const truncated = texts.map((t) => this.truncateToContext(t));
if (this.provider === "ollama") {
// Ollama doesn't support batch embedding, so we embed sequentially with resilient error handling
-const results = await Promise.allSettled(texts.map((t) => this.embedOllama(t)));
+const results = await Promise.allSettled(truncated.map((t) => this.embedOllama(t)));
const embeddings: number[][] = [];
let failures = 0;
@@ -90,7 +118,7 @@ export class Embeddings {
return embeddings;
}
-return this.embedBatchOpenAI(texts);
+return this.embedBatchOpenAI(truncated);
}
private async embedOpenAI(text: string): Promise<number[]> {
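
For intuition on the numbers the tests above rely on: a standalone sketch of the same word-boundary truncation (illustration only; sketchTruncate is a hypothetical name). The ~3 chars/token heuristic means a 512-token window allows 1536 chars:

// Sketch: mirrors truncateToContext's heuristic outside the class.
function sketchTruncate(text: string, contextTokens: number): string {
  const maxChars = contextTokens * 3; // conservative ~3 chars/token estimate
  if (text.length <= maxChars) return text;
  let out = text.slice(0, maxChars);
  const lastSpace = out.lastIndexOf(" ");
  // Back up to the last space only if that keeps at least 80% of maxChars.
  if (lastSpace > maxChars * 0.8) out = out.slice(0, lastSpace);
  return out;
}

const sample = "abcdefghij ".repeat(200); // 2200 chars
const sent = sketchTruncate(sample, 512); // maxChars = 512 * 3 = 1536
sent.length;                 // 1528: pulled back from 1536 to the last space
/abcdefghij$/.test(sent);    // true: ends on a whole word, never mid-word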

View File

@@ -805,9 +805,10 @@ const memoryNeo4jPlugin = {
const agentId = ctx.agentId || "default";
// Truncate prompt to avoid exceeding embedding model context length
-// ~6000 chars is safe for most embedding models (leaves headroom for 2k tokens)
-const MAX_QUERY_CHARS = 6000;
+// ~1500 chars is a safe ceiling for most embedding models (~500 tokens).
+// Models with larger context (8k+) can handle more, but recall queries
+// don't benefit from very long inputs; embedding quality plateaus.
+const MAX_QUERY_CHARS = 1500;
const query =
event.prompt.length > MAX_QUERY_CHARS
? event.prompt.slice(0, MAX_QUERY_CHARS)
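
For scale, under the same ~3 chars/token estimate used in the embeddings code above, this ceiling fits even the smallest tracked model window (a back-of-the-envelope check, not code from the commit):

// 1500 chars / 3 chars-per-token ≈ 500 tokens
// smallest tracked window: mxbai-embed-large at 512 tokens → 512 * 3 = 1536 chars
// 1500 ≤ 1536, so a truncated recall query always fits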