diff --git a/src/agents/pi-embedded-runner/run.overflow-compaction.fixture.ts b/src/agents/pi-embedded-runner/run.overflow-compaction.fixture.ts index 2ba720f2a6..7ba709c911 100644 --- a/src/agents/pi-embedded-runner/run.overflow-compaction.fixture.ts +++ b/src/agents/pi-embedded-runner/run.overflow-compaction.fixture.ts @@ -15,6 +15,7 @@ export function makeAttemptResult( messagesSnapshot: [], didSendViaMessagingTool: false, messagingToolSentTexts: [], + messagingToolSentMediaUrls: [], messagingToolSentTargets: [], cloudCodeAssistFormatError: false, ...overrides, diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index e8d5f8aa0d..b87d13503e 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -980,6 +980,7 @@ export async function runEmbeddedPiAgent( }, didSendViaMessagingTool: attempt.didSendViaMessagingTool, messagingToolSentTexts: attempt.messagingToolSentTexts, + messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls, messagingToolSentTargets: attempt.messagingToolSentTargets, successfulCronAdds: attempt.successfulCronAdds, }; @@ -1022,6 +1023,7 @@ export async function runEmbeddedPiAgent( }, didSendViaMessagingTool: attempt.didSendViaMessagingTool, messagingToolSentTexts: attempt.messagingToolSentTexts, + messagingToolSentMediaUrls: attempt.messagingToolSentMediaUrls, messagingToolSentTargets: attempt.messagingToolSentTargets, successfulCronAdds: attempt.successfulCronAdds, }; diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index 6489aad876..6fa1103957 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -758,6 +758,7 @@ export async function runEmbeddedAttempt( unsubscribe, waitForCompactionRetry, getMessagingToolSentTexts, + getMessagingToolSentMediaUrls, getMessagingToolSentTargets, getSuccessfulCronAdds, didSendViaMessagingTool, @@ -1178,6 +1179,7 @@ 
export async function runEmbeddedAttempt( lastToolError: getLastToolError?.(), didSendViaMessagingTool: didSendViaMessagingTool(), messagingToolSentTexts: getMessagingToolSentTexts(), + messagingToolSentMediaUrls: getMessagingToolSentMediaUrls(), messagingToolSentTargets: getMessagingToolSentTargets(), successfulCronAdds: getSuccessfulCronAdds(), cloudCodeAssistFormatError: Boolean( diff --git a/src/agents/pi-embedded-runner/run/types.ts b/src/agents/pi-embedded-runner/run/types.ts index db29b79455..d1371618bc 100644 --- a/src/agents/pi-embedded-runner/run/types.ts +++ b/src/agents/pi-embedded-runner/run/types.ts @@ -45,6 +45,7 @@ export type EmbeddedRunAttemptResult = { }; didSendViaMessagingTool: boolean; messagingToolSentTexts: string[]; + messagingToolSentMediaUrls: string[]; messagingToolSentTargets: MessagingToolSend[]; successfulCronAdds?: number; cloudCodeAssistFormatError: boolean; diff --git a/src/agents/pi-embedded-runner/types.ts b/src/agents/pi-embedded-runner/types.ts index 30a0a5b56e..ac7c723d24 100644 --- a/src/agents/pi-embedded-runner/types.ts +++ b/src/agents/pi-embedded-runner/types.ts @@ -63,6 +63,8 @@ export type EmbeddedPiRunResult = { didSendViaMessagingTool?: boolean; // Texts successfully sent via messaging tools during the run. messagingToolSentTexts?: string[]; + // Media URLs successfully sent via messaging tools during the run. + messagingToolSentMediaUrls?: string[]; // Messaging tool targets that successfully sent a message during the run. messagingToolSentTargets?: MessagingToolSend[]; // Count of successful cron.add tool calls in this run. 
diff --git a/src/agents/pi-embedded-subscribe.handlers.tools.media.test-helpers.ts b/src/agents/pi-embedded-subscribe.handlers.tools.media.test-helpers.ts new file mode 100644 index 0000000000..057bf55627 --- /dev/null +++ b/src/agents/pi-embedded-subscribe.handlers.tools.media.test-helpers.ts @@ -0,0 +1,68 @@ +import type { AgentEvent } from "@mariozechner/pi-agent-core"; +import type { Mock } from "vitest"; +import type { EmbeddedPiSubscribeContext } from "./pi-embedded-subscribe.handlers.types.js"; +import type { SubscribeEmbeddedPiSessionParams } from "./pi-embedded-subscribe.types.js"; +import { + handleToolExecutionEnd, + handleToolExecutionStart, +} from "./pi-embedded-subscribe.handlers.tools.js"; + +/** + * Narrowed params type that omits the `session` class instance (never accessed + * by the handler paths under test). + */ +type TestParams = Omit; + +/** + * The subset of {@link EmbeddedPiSubscribeContext} that the media-emission + * tests actually populate. Using this avoids the need for `as unknown as` + * double-assertion in every mock factory. + */ +export type MockEmbeddedContext = Omit & { + params: TestParams; +}; + +/** Type-safe bridge: narrows parameter type so callers avoid assertions. */ +function asFullContext(ctx: MockEmbeddedContext): EmbeddedPiSubscribeContext { + return ctx as unknown as EmbeddedPiSubscribeContext; +} + +/** Typed wrapper around {@link handleToolExecutionStart}. */ +export function callToolExecutionStart( + ctx: MockEmbeddedContext, + evt: AgentEvent & { toolName: string; toolCallId: string; args: unknown }, +): Promise { + return handleToolExecutionStart(asFullContext(ctx), evt); +} + +/** Typed wrapper around {@link handleToolExecutionEnd}. 
*/ +export function callToolExecutionEnd( + ctx: MockEmbeddedContext, + evt: AgentEvent & { + toolName: string; + toolCallId: string; + isError: boolean; + result?: unknown; + }, +): Promise { + return handleToolExecutionEnd(asFullContext(ctx), evt); +} + +/** + * Check whether a mock-call argument is an object containing `mediaUrls` + * but NOT `text` (i.e. a "direct media" emission). + */ +export function isDirectMediaCall(call: unknown[]): boolean { + const arg = call[0]; + if (!arg || typeof arg !== "object") { + return false; + } + return "mediaUrls" in arg && !("text" in arg); +} + +/** + * Filter a vi.fn() mock's call log to only direct-media emissions. + */ +export function filterDirectMediaCalls(mock: Mock): unknown[][] { + return mock.mock.calls.filter(isDirectMediaCall); +} diff --git a/src/agents/pi-embedded-subscribe.handlers.tools.ts b/src/agents/pi-embedded-subscribe.handlers.tools.ts index 86b4045994..21f7ec2d1a 100644 --- a/src/agents/pi-embedded-subscribe.handlers.tools.ts +++ b/src/agents/pi-embedded-subscribe.handlers.tools.ts @@ -145,6 +145,11 @@ export async function handleToolExecutionStart( ctx.state.pendingMessagingTexts.set(toolCallId, text); ctx.log.debug(`Tracking pending messaging text: tool=${toolName} len=${text.length}`); } + // Track media URL from messaging tool args (pending until tool_execution_end) + const mediaUrl = argsRecord.mediaUrl ?? argsRecord.path ?? 
argsRecord.filePath; + if (mediaUrl && typeof mediaUrl === "string") { + ctx.state.pendingMessagingMediaUrls.set(toolCallId, mediaUrl); + } } } } @@ -248,6 +253,14 @@ export async function handleToolExecutionEnd( ctx.trimMessagingToolSent(); } } + const pendingMediaUrl = ctx.state.pendingMessagingMediaUrls.get(toolCallId); + if (pendingMediaUrl) { + ctx.state.pendingMessagingMediaUrls.delete(toolCallId); + if (!isToolError) { + ctx.state.messagingToolSentMediaUrls.push(pendingMediaUrl); + ctx.trimMessagingToolSent(); + } + } // Track committed reminders only when cron.add completed successfully. if (!isToolError && toolName === "cron" && isCronAddAction(startData?.args)) { diff --git a/src/agents/pi-embedded-subscribe.handlers.types.ts b/src/agents/pi-embedded-subscribe.handlers.types.ts index 09ab1328f6..5de8cf8a92 100644 --- a/src/agents/pi-embedded-subscribe.handlers.types.ts +++ b/src/agents/pi-embedded-subscribe.handlers.types.ts @@ -70,9 +70,11 @@ export type EmbeddedPiSubscribeState = { messagingToolSentTexts: string[]; messagingToolSentTextsNormalized: string[]; messagingToolSentTargets: MessagingToolSend[]; + messagingToolSentMediaUrls: string[]; pendingMessagingTexts: Map; pendingMessagingTargets: Map; successfulCronAdds: number; + pendingMessagingMediaUrls: Map; lastAssistant?: AgentMessage; }; @@ -122,6 +124,44 @@ export type EmbeddedPiSubscribeContext = { getCompactionCount: () => number; }; +/** + * Minimal context type for tool execution handlers. Allows + * tests to provide only the fields they exercise + * without needing the full `EmbeddedPiSubscribeContext`. 
+ */ +export type ToolHandlerParams = Pick< + SubscribeEmbeddedPiSessionParams, + "runId" | "onBlockReplyFlush" | "onAgentEvent" | "onToolResult" +>; + +export type ToolHandlerState = Pick< + EmbeddedPiSubscribeState, + | "toolMetaById" + | "toolMetas" + | "toolSummaryById" + | "lastToolError" + | "pendingMessagingTargets" + | "pendingMessagingTexts" + | "pendingMessagingMediaUrls" + | "messagingToolSentTexts" + | "messagingToolSentTextsNormalized" + | "messagingToolSentMediaUrls" + | "messagingToolSentTargets" +>; + +export type ToolHandlerContext = { + params: ToolHandlerParams; + state: ToolHandlerState; + log: EmbeddedSubscribeLogger; + hookRunner?: HookRunner; + flushBlockReplyBuffer: () => void; + shouldEmitToolResult: () => boolean; + shouldEmitToolOutput: () => boolean; + emitToolSummary: (toolName?: string, meta?: string) => void; + emitToolOutput: (toolName?: string, meta?: string, output?: string) => void; + trimMessagingToolSent: () => void; +}; + export type EmbeddedPiSubscribeEvent = | AgentEvent | { type: string; [k: string]: unknown } diff --git a/src/agents/pi-embedded-subscribe.ts b/src/agents/pi-embedded-subscribe.ts index d782ed247a..1445eebe76 100644 --- a/src/agents/pi-embedded-subscribe.ts +++ b/src/agents/pi-embedded-subscribe.ts @@ -71,9 +71,11 @@ export function subscribeEmbeddedPiSession(params: SubscribeEmbeddedPiSessionPar messagingToolSentTexts: [], messagingToolSentTextsNormalized: [], messagingToolSentTargets: [], + messagingToolSentMediaUrls: [], pendingMessagingTexts: new Map(), pendingMessagingTargets: new Map(), successfulCronAdds: 0, + pendingMessagingMediaUrls: new Map(), }; const usageTotals = { input: 0, @@ -91,6 +93,7 @@ export function subscribeEmbeddedPiSession(params: SubscribeEmbeddedPiSessionPar const messagingToolSentTexts = state.messagingToolSentTexts; const messagingToolSentTextsNormalized = state.messagingToolSentTextsNormalized; const messagingToolSentTargets = state.messagingToolSentTargets; + const 
messagingToolSentMediaUrls = state.messagingToolSentMediaUrls; const pendingMessagingTexts = state.pendingMessagingTexts; const pendingMessagingTargets = state.pendingMessagingTargets; const replyDirectiveAccumulator = createStreamingDirectiveAccumulator(); @@ -192,6 +195,7 @@ export function subscribeEmbeddedPiSession(params: SubscribeEmbeddedPiSessionPar // These tools can send messages via sendMessage/threadReply actions (or sessions_send with message). const MAX_MESSAGING_SENT_TEXTS = 200; const MAX_MESSAGING_SENT_TARGETS = 200; + const MAX_MESSAGING_SENT_MEDIA_URLS = 200; const trimMessagingToolSent = () => { if (messagingToolSentTexts.length > MAX_MESSAGING_SENT_TEXTS) { const overflow = messagingToolSentTexts.length - MAX_MESSAGING_SENT_TEXTS; @@ -202,6 +206,10 @@ export function subscribeEmbeddedPiSession(params: SubscribeEmbeddedPiSessionPar const overflow = messagingToolSentTargets.length - MAX_MESSAGING_SENT_TARGETS; messagingToolSentTargets.splice(0, overflow); } + if (messagingToolSentMediaUrls.length > MAX_MESSAGING_SENT_MEDIA_URLS) { + const overflow = messagingToolSentMediaUrls.length - MAX_MESSAGING_SENT_MEDIA_URLS; + messagingToolSentMediaUrls.splice(0, overflow); + } }; const ensureCompactionPromise = () => { @@ -577,9 +585,11 @@ export function subscribeEmbeddedPiSession(params: SubscribeEmbeddedPiSessionPar messagingToolSentTexts.length = 0; messagingToolSentTextsNormalized.length = 0; messagingToolSentTargets.length = 0; + messagingToolSentMediaUrls.length = 0; pendingMessagingTexts.clear(); pendingMessagingTargets.clear(); state.successfulCronAdds = 0; + state.pendingMessagingMediaUrls.clear(); resetAssistantMessageState(0); }; @@ -663,6 +673,7 @@ export function subscribeEmbeddedPiSession(params: SubscribeEmbeddedPiSessionPar isCompacting: () => state.compactionInFlight || state.pendingCompactionRetry > 0, isCompactionInFlight: () => state.compactionInFlight, getMessagingToolSentTexts: () => messagingToolSentTexts.slice(), + 
getMessagingToolSentMediaUrls: () => messagingToolSentMediaUrls.slice(), getMessagingToolSentTargets: () => messagingToolSentTargets.slice(), getSuccessfulCronAdds: () => state.successfulCronAdds, // Returns true if any messaging tool successfully sent a message. diff --git a/src/auto-reply/reply/agent-runner-payloads.ts b/src/auto-reply/reply/agent-runner-payloads.ts index 3c2543e9cb..050da6f5ea 100644 --- a/src/auto-reply/reply/agent-runner-payloads.ts +++ b/src/auto-reply/reply/agent-runner-payloads.ts @@ -10,6 +10,7 @@ import { normalizeReplyPayloadDirectives } from "./reply-delivery.js"; import { applyReplyThreading, filterMessagingToolDuplicates, + filterMessagingToolMediaDuplicates, isRenderablePayload, shouldSuppressMessagingToolReplies, } from "./reply-payloads.js"; @@ -27,6 +28,7 @@ export function buildReplyPayloads(params: { currentMessageId?: string; messageProvider?: string; messagingToolSentTexts?: string[]; + messagingToolSentMediaUrls?: string[]; messagingToolSentTargets?: Parameters< typeof shouldSuppressMessagingToolReplies >[0]["messagingToolSentTargets"]; @@ -93,16 +95,22 @@ export function buildReplyPayloads(params: { payloads: replyTaggedPayloads, sentTexts: messagingToolSentTexts, }); + const mediaFilteredPayloads = filterMessagingToolMediaDuplicates({ + payloads: dedupedPayloads, + sentMediaUrls: params.messagingToolSentMediaUrls ?? [], + }); // Filter out payloads already sent via pipeline or directly during tool flush. const filteredPayloads = shouldDropFinalPayloads ? [] : params.blockStreamingEnabled - ? dedupedPayloads.filter((payload) => !params.blockReplyPipeline?.hasSentPayload(payload)) + ? mediaFilteredPayloads.filter( + (payload) => !params.blockReplyPipeline?.hasSentPayload(payload), + ) : params.directlySentBlockKeys?.size - ? dedupedPayloads.filter( + ? 
mediaFilteredPayloads.filter( (payload) => !params.directlySentBlockKeys!.has(createBlockReplyPayloadKey(payload)), ) - : dedupedPayloads; + : mediaFilteredPayloads; const replyPayloads = suppressMessagingToolReplies ? [] : filteredPayloads; return { diff --git a/src/auto-reply/reply/agent-runner.ts b/src/auto-reply/reply/agent-runner.ts index ea6d07951b..62f8d7b643 100644 --- a/src/auto-reply/reply/agent-runner.ts +++ b/src/auto-reply/reply/agent-runner.ts @@ -444,6 +444,7 @@ export async function runReplyAgent(params: { currentMessageId: sessionCtx.MessageSidFull ?? sessionCtx.MessageSid, messageProvider: followupRun.run.messageProvider, messagingToolSentTexts: runResult.messagingToolSentTexts, + messagingToolSentMediaUrls: runResult.messagingToolSentMediaUrls, messagingToolSentTargets: runResult.messagingToolSentTargets, originatingTo: sessionCtx.OriginatingTo ?? sessionCtx.To, accountId: sessionCtx.AccountId, diff --git a/src/auto-reply/reply/followup-runner.ts b/src/auto-reply/reply/followup-runner.ts index 9280a8fecf..b35d01f8e6 100644 --- a/src/auto-reply/reply/followup-runner.ts +++ b/src/auto-reply/reply/followup-runner.ts @@ -18,6 +18,7 @@ import { isSilentReplyText, SILENT_REPLY_TOKEN } from "../tokens.js"; import { applyReplyThreading, filterMessagingToolDuplicates, + filterMessagingToolMediaDuplicates, shouldSuppressMessagingToolReplies, } from "./reply-payloads.js"; import { resolveReplyToMode } from "./reply-threading.js"; @@ -252,13 +253,17 @@ export function createFollowupRunner(params: { payloads: replyTaggedPayloads, sentTexts: runResult.messagingToolSentTexts ?? [], }); + const mediaFilteredPayloads = filterMessagingToolMediaDuplicates({ + payloads: dedupedPayloads, + sentMediaUrls: runResult.messagingToolSentMediaUrls ?? 
[], + }); const suppressMessagingToolReplies = shouldSuppressMessagingToolReplies({ messageProvider: queued.run.messageProvider, messagingToolSentTargets: runResult.messagingToolSentTargets, originatingTo: queued.originatingTo, accountId: queued.run.agentAccountId, }); - const finalPayloads = suppressMessagingToolReplies ? [] : dedupedPayloads; + const finalPayloads = suppressMessagingToolReplies ? [] : mediaFilteredPayloads; if (finalPayloads.length === 0) { return; diff --git a/src/auto-reply/reply/reply-payloads.test.ts b/src/auto-reply/reply/reply-payloads.test.ts new file mode 100644 index 0000000000..160eed93aa --- /dev/null +++ b/src/auto-reply/reply/reply-payloads.test.ts @@ -0,0 +1,61 @@ +import { describe, expect, it } from "vitest"; +import { filterMessagingToolMediaDuplicates } from "./reply-payloads.js"; + +describe("filterMessagingToolMediaDuplicates", () => { + it("strips mediaUrl when it matches sentMediaUrls", () => { + const result = filterMessagingToolMediaDuplicates({ + payloads: [{ text: "hello", mediaUrl: "file:///tmp/photo.jpg" }], + sentMediaUrls: ["file:///tmp/photo.jpg"], + }); + expect(result).toEqual([{ text: "hello", mediaUrl: undefined, mediaUrls: undefined }]); + }); + + it("preserves mediaUrl when it is not in sentMediaUrls", () => { + const result = filterMessagingToolMediaDuplicates({ + payloads: [{ text: "hello", mediaUrl: "file:///tmp/photo.jpg" }], + sentMediaUrls: ["file:///tmp/other.jpg"], + }); + expect(result).toEqual([{ text: "hello", mediaUrl: "file:///tmp/photo.jpg" }]); + }); + + it("filters matching entries from mediaUrls array", () => { + const result = filterMessagingToolMediaDuplicates({ + payloads: [ + { + text: "gallery", + mediaUrls: ["file:///tmp/a.jpg", "file:///tmp/b.jpg", "file:///tmp/c.jpg"], + }, + ], + sentMediaUrls: ["file:///tmp/b.jpg"], + }); + expect(result).toEqual([ + { text: "gallery", mediaUrls: ["file:///tmp/a.jpg", "file:///tmp/c.jpg"] }, + ]); + }); + + it("clears mediaUrls when all entries 
match", () => { + const result = filterMessagingToolMediaDuplicates({ + payloads: [{ text: "gallery", mediaUrls: ["file:///tmp/a.jpg"] }], + sentMediaUrls: ["file:///tmp/a.jpg"], + }); + expect(result).toEqual([{ text: "gallery", mediaUrl: undefined, mediaUrls: undefined }]); + }); + + it("returns payloads unchanged when no media present", () => { + const payloads = [{ text: "plain text" }]; + const result = filterMessagingToolMediaDuplicates({ + payloads, + sentMediaUrls: ["file:///tmp/photo.jpg"], + }); + expect(result).toStrictEqual(payloads); + }); + + it("returns payloads unchanged when sentMediaUrls is empty", () => { + const payloads = [{ text: "hello", mediaUrl: "file:///tmp/photo.jpg" }]; + const result = filterMessagingToolMediaDuplicates({ + payloads, + sentMediaUrls: [], + }); + expect(result).toBe(payloads); + }); +}); diff --git a/src/auto-reply/reply/reply-payloads.ts b/src/auto-reply/reply/reply-payloads.ts index 9b879026c3..e6b26211a6 100644 --- a/src/auto-reply/reply/reply-payloads.ts +++ b/src/auto-reply/reply/reply-payloads.ts @@ -95,6 +95,31 @@ export function filterMessagingToolDuplicates(params: { return payloads.filter((payload) => !isMessagingToolDuplicate(payload.text ?? "", sentTexts)); } +export function filterMessagingToolMediaDuplicates(params: { + payloads: ReplyPayload[]; + sentMediaUrls: string[]; +}): ReplyPayload[] { + const { payloads, sentMediaUrls } = params; + if (sentMediaUrls.length === 0) { + return payloads; + } + const sentSet = new Set(sentMediaUrls); + return payloads.map((payload) => { + const mediaUrl = payload.mediaUrl; + const mediaUrls = payload.mediaUrls; + const stripSingle = mediaUrl && sentSet.has(mediaUrl); + const filteredUrls = mediaUrls?.filter((u) => !sentSet.has(u)); + if (!stripSingle && (!mediaUrls || filteredUrls?.length === mediaUrls.length)) { + return payload; // No change + } + return { + ...payload, + mediaUrl: stripSingle ? undefined : mediaUrl, + mediaUrls: filteredUrls?.length ? 
filteredUrls : undefined, + }; + }); +} + function normalizeAccountId(value?: string): string | undefined { const trimmed = value?.trim(); return trimmed ? trimmed.toLowerCase() : undefined;