diff --git a/src/agents/pi-embedded-utils.ts b/src/agents/pi-embedded-utils.ts index 801e5c9faa..82ad3efc03 100644 --- a/src/agents/pi-embedded-utils.ts +++ b/src/agents/pi-embedded-utils.ts @@ -1,5 +1,6 @@ import type { AgentMessage } from "@mariozechner/pi-agent-core"; import type { AssistantMessage } from "@mariozechner/pi-ai"; +import { extractTextFromChatContent } from "../shared/chat-content.js"; import { stripReasoningTagsFromText } from "../shared/text/reasoning-tags.js"; import { sanitizeUserFacingText } from "./pi-embedded-helpers.js"; import { formatToolDetail, resolveToolDisplay } from "./tool-display.js"; @@ -207,25 +208,15 @@ export function stripThinkingTagsFromText(text: string): string { } export function extractAssistantText(msg: AssistantMessage): string { - const isTextBlock = (block: unknown): block is { type: "text"; text: string } => { - if (!block || typeof block !== "object") { - return false; - } - const rec = block as Record; - return rec.type === "text" && typeof rec.text === "string"; - }; - - const blocks = Array.isArray(msg.content) - ? msg.content - .filter(isTextBlock) - .map((c) => - stripThinkingTagsFromText( - stripDowngradedToolCallText(stripMinimaxToolCallXml(c.text)), - ).trim(), - ) - .filter(Boolean) - : []; - const extracted = blocks.join("\n").trim(); + const extracted = + extractTextFromChatContent(msg.content, { + sanitizeText: (text) => + stripThinkingTagsFromText( + stripDowngradedToolCallText(stripMinimaxToolCallXml(text)), + ).trim(), + joinWith: "\n", + normalizeText: (text) => text.trim(), + }) ?? ""; // Only apply keyword-based error rewrites when the assistant message is actually an error. // Otherwise normal prose that *mentions* errors (e.g. "context overflow") can get clobbered. const errorContext = msg.stopReason === "error" || Boolean(msg.errorMessage?.trim()); diff --git a/src/agents/tools/sessions-helpers.ts b/src/agents/tools/sessions-helpers.ts index 09c21e6999..f1a6b427e4 100644 --- a/src/agents/tools/sessions-helpers.ts +++ b/src/agents/tools/sessions-helpers.ts @@ -24,6 +24,7 @@ export { resolveSessionReference, shouldResolveSessionIdInput, } from "./sessions-resolution.js"; +import { extractTextFromChatContent } from "../../shared/chat-content.js"; import { sanitizeUserFacingText } from "../pi-embedded-helpers.js"; import { stripDowngradedToolCallText, @@ -152,23 +153,12 @@ export function extractAssistantText(message: unknown): string | undefined { if (!Array.isArray(content)) { return undefined; } - const chunks: string[] = []; - for (const block of content) { - if (!block || typeof block !== "object") { - continue; - } - if ((block as { type?: unknown }).type !== "text") { - continue; - } - const text = (block as { text?: unknown }).text; - if (typeof text === "string") { - const sanitized = sanitizeTextContent(text); - if (sanitized.trim()) { - chunks.push(sanitized); - } - } - } - const joined = chunks.join("").trim(); + const joined = + extractTextFromChatContent(content, { + sanitizeText: sanitizeTextContent, + joinWith: "", + normalizeText: (text) => text.trim(), + }) ?? ""; const stopReason = (message as { stopReason?: unknown }).stopReason; const errorMessage = (message as { errorMessage?: unknown }).errorMessage; const errorContext = diff --git a/src/shared/chat-content.ts b/src/shared/chat-content.ts index 5f8541d9c6..c052e457eb 100644 --- a/src/shared/chat-content.ts +++ b/src/shared/chat-content.ts @@ -1,8 +1,13 @@ export function extractTextFromChatContent( content: unknown, - opts?: { sanitizeText?: (text: string) => string }, + opts?: { + sanitizeText?: (text: string) => string; + joinWith?: string; + normalizeText?: (text: string) => string; + }, ): string | null { - const normalize = (text: string) => text.replace(/\s+/g, " ").trim(); + const normalize = opts?.normalizeText ?? ((text: string) => text.replace(/\s+/g, " ").trim()); + const joinWith = opts?.joinWith ?? " "; if (typeof content === "string") { const value = opts?.sanitizeText ? opts.sanitizeText(content) : content; @@ -32,6 +37,6 @@ export function extractTextFromChatContent( } } - const joined = normalize(chunks.join(" ")); + const joined = normalize(chunks.join(joinWith)); return joined ? joined : null; } diff --git a/src/shared/shared-misc.test.ts b/src/shared/shared-misc.test.ts index 298b3ff0a5..9ac04ca623 100644 --- a/src/shared/shared-misc.test.ts +++ b/src/shared/shared-misc.test.ts @@ -30,6 +30,22 @@ describe("extractTextFromChatContent", () => { }), ).toBe("Here ok"); }); + + it("supports custom join and normalization", () => { + expect( + extractTextFromChatContent( + [ + { type: "text", text: " hello " }, + { type: "text", text: "world " }, + ], + { + sanitizeText: (text) => text.trim(), + joinWith: "\n", + normalizeText: (text) => text.trim(), + }, + ), + ).toBe("hello\nworld"); + }); }); describe("shared/frontmatter", () => {