diff --git a/src/agents/session-tool-result-guard.ts b/src/agents/session-tool-result-guard.ts index bbb2b0ff2d..8a2644dae4 100644 --- a/src/agents/session-tool-result-guard.ts +++ b/src/agents/session-tool-result-guard.ts @@ -4,6 +4,7 @@ import type { SessionManager } from "@mariozechner/pi-coding-agent"; import { emitSessionTranscriptUpdate } from "../sessions/transcript-events.js"; import { HARD_MAX_TOOL_RESULT_CHARS } from "./pi-embedded-runner/tool-result-truncation.js"; import { makeMissingToolResult, sanitizeToolCallInputs } from "./session-transcript-repair.js"; +import { extractToolCallsFromAssistant, extractToolResultId } from "./tool-call-id.js"; const GUARD_TRUNCATION_SUFFIX = "\n\n⚠️ [Content truncated during persistence — original exceeded size limit. " + @@ -71,45 +72,6 @@ function capToolResultSize(msg: AgentMessage): AgentMessage { return { ...msg, content: newContent } as AgentMessage; } -type ToolCall = { id: string; name?: string }; - -function extractAssistantToolCalls(msg: Extract): ToolCall[] { - const content = msg.content; - if (!Array.isArray(content)) { - return []; - } - - const toolCalls: ToolCall[] = []; - for (const block of content) { - if (!block || typeof block !== "object") { - continue; - } - const rec = block as { type?: unknown; id?: unknown; name?: unknown }; - if (typeof rec.id !== "string" || !rec.id) { - continue; - } - if (rec.type === "toolCall" || rec.type === "toolUse" || rec.type === "functionCall") { - toolCalls.push({ - id: rec.id, - name: typeof rec.name === "string" ? rec.name : undefined, - }); - } - } - return toolCalls; -} - -function extractToolResultId(msg: Extract): string | null { - const toolCallId = (msg as { toolCallId?: unknown }).toolCallId; - if (typeof toolCallId === "string" && toolCallId) { - return toolCallId; - } - const toolUseId = (msg as { toolUseId?: unknown }).toolUseId; - if (typeof toolUseId === "string" && toolUseId) { - return toolUseId; - } - return null; -} - export function installSessionToolResultGuard( sessionManager: SessionManager, opts?: { @@ -206,7 +168,7 @@ export function installSessionToolResultGuard( const toolCalls = nextRole === "assistant" - ? extractAssistantToolCalls(nextMessage as Extract) + ? extractToolCallsFromAssistant(nextMessage as Extract) : []; if (allowSyntheticToolResults) { diff --git a/src/agents/session-transcript-repair.ts b/src/agents/session-transcript-repair.ts index 1ac325b1d5..9eb2fe3581 100644 --- a/src/agents/session-transcript-repair.ts +++ b/src/agents/session-transcript-repair.ts @@ -1,11 +1,5 @@ import type { AgentMessage } from "@mariozechner/pi-agent-core"; - -type ToolCallLike = { - id: string; - name?: string; -}; - -const TOOL_CALL_TYPES = new Set(["toolCall", "toolUse", "functionCall"]); +import { extractToolCallsFromAssistant, extractToolResultId } from "./tool-call-id.js"; type ToolCallBlock = { type?: unknown; @@ -15,40 +9,15 @@ type ToolCallBlock = { arguments?: unknown; }; -function extractToolCallsFromAssistant( - msg: Extract, -): ToolCallLike[] { - const content = msg.content; - if (!Array.isArray(content)) { - return []; - } - - const toolCalls: ToolCallLike[] = []; - for (const block of content) { - if (!block || typeof block !== "object") { - continue; - } - const rec = block as { type?: unknown; id?: unknown; name?: unknown }; - if (typeof rec.id !== "string" || !rec.id) { - continue; - } - - if (rec.type === "toolCall" || rec.type === "toolUse" || rec.type === "functionCall") { - toolCalls.push({ - id: rec.id, - name: typeof rec.name === "string" ? rec.name : undefined, - }); - } - } - return toolCalls; -} - function isToolCallBlock(block: unknown): block is ToolCallBlock { if (!block || typeof block !== "object") { return false; } const type = (block as { type?: unknown }).type; - return typeof type === "string" && TOOL_CALL_TYPES.has(type); + return ( + typeof type === "string" && + (type === "toolCall" || type === "toolUse" || type === "functionCall") + ); } function hasToolCallInput(block: ToolCallBlock): boolean { @@ -70,18 +39,6 @@ function hasToolCallName(block: ToolCallBlock): boolean { return hasNonEmptyStringField(block.name); } -function extractToolResultId(msg: Extract): string | null { - const toolCallId = (msg as { toolCallId?: unknown }).toolCallId; - if (typeof toolCallId === "string" && toolCallId) { - return toolCallId; - } - const toolUseId = (msg as { toolUseId?: unknown }).toolUseId; - if (typeof toolUseId === "string" && toolUseId) { - return toolUseId; - } - return null; -} - function makeMissingToolResult(params: { toolCallId: string; toolName?: string; diff --git a/src/agents/tool-call-id.ts b/src/agents/tool-call-id.ts index 040a935bea..ed7476b941 100644 --- a/src/agents/tool-call-id.ts +++ b/src/agents/tool-call-id.ts @@ -4,6 +4,12 @@ import { createHash } from "node:crypto"; export type ToolCallIdMode = "strict" | "strict9"; const STRICT9_LEN = 9; +const TOOL_CALL_TYPES = new Set(["toolCall", "toolUse", "functionCall"]); + +export type ToolCallLike = { + id: string; + name?: string; +}; /** * Sanitize a tool call ID to be compatible with various providers. @@ -35,6 +41,47 @@ export function sanitizeToolCallId(id: string, mode: ToolCallIdMode = "strict"): return alphanumericOnly.length > 0 ? alphanumericOnly : "sanitizedtoolid"; } +export function extractToolCallsFromAssistant( + msg: Extract, +): ToolCallLike[] { + const content = msg.content; + if (!Array.isArray(content)) { + return []; + } + + const toolCalls: ToolCallLike[] = []; + for (const block of content) { + if (!block || typeof block !== "object") { + continue; + } + const rec = block as { type?: unknown; id?: unknown; name?: unknown }; + if (typeof rec.id !== "string" || !rec.id) { + continue; + } + if (typeof rec.type === "string" && TOOL_CALL_TYPES.has(rec.type)) { + toolCalls.push({ + id: rec.id, + name: typeof rec.name === "string" ? rec.name : undefined, + }); + } + } + return toolCalls; +} + +export function extractToolResultId( + msg: Extract, +): string | null { + const toolCallId = (msg as { toolCallId?: unknown }).toolCallId; + if (typeof toolCallId === "string" && toolCallId) { + return toolCallId; + } + const toolUseId = (msg as { toolUseId?: unknown }).toolUseId; + if (typeof toolUseId === "string" && toolUseId) { + return toolUseId; + } + return null; +} + export function isValidCloudCodeAssistToolId(id: string, mode: ToolCallIdMode = "strict"): boolean { if (!id || typeof id !== "string") { return false;