diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index 6815a92e91..e8d5f8aa0d 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -1,7 +1,9 @@ import fs from "node:fs/promises"; import type { ThinkLevel } from "../../auto-reply/thinking.js"; +import type { PluginHookBeforeAgentStartResult } from "../../plugins/types.js"; import type { RunEmbeddedPiAgentParams } from "./run/params.js"; import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js"; +import { getGlobalHookRunner } from "../../plugins/hook-runner-global.js"; import { enqueueCommandInLane } from "../../process/command-queue.js"; import { isMarkdownCapableMessageChannel } from "../../utils/message-channel.js"; import { resolveOpenClawAgentDir } from "../agent-paths.js"; @@ -198,13 +200,43 @@ export async function runEmbeddedPiAgent( } const prevCwd = process.cwd(); - const provider = (params.provider ?? DEFAULT_PROVIDER).trim() || DEFAULT_PROVIDER; - const modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL; + let provider = (params.provider ?? DEFAULT_PROVIDER).trim() || DEFAULT_PROVIDER; + let modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL; const agentDir = params.agentDir ?? resolveOpenClawAgentDir(); const fallbackConfigured = (params.config?.agents?.defaults?.model?.fallbacks?.length ?? 0) > 0; await ensureOpenClawModelsJson(params.config, agentDir); + // Run before_agent_start hooks early so plugins can override the model + // before it gets resolved. The hook result is passed downstream to + // attempt.ts to avoid double-firing. + let earlyHookResult: PluginHookBeforeAgentStartResult | undefined; + const hookRunner = getGlobalHookRunner(); + if (hookRunner?.hasHooks("before_agent_start")) { + try { + earlyHookResult = await hookRunner.runBeforeAgentStart( + { prompt: params.prompt }, + { + agentId: params.agentId, + sessionKey: params.sessionKey, + sessionId: params.sessionId, + workspaceDir: params.workspaceDir, + messageProvider: params.messageProvider ?? undefined, + }, + ); + if (earlyHookResult?.providerOverride) { + provider = earlyHookResult.providerOverride; + log.info(`[hooks] provider overridden to ${provider}`); + } + if (earlyHookResult?.modelOverride) { + modelId = earlyHookResult.modelOverride; + log.info(`[hooks] model overridden to ${modelId}`); + } + } catch (hookErr) { + log.warn(`before_agent_start hook (early) failed: ${String(hookErr)}`); + } + } + const { model, error, authStorage, modelRegistry } = resolveModel( provider, modelId, @@ -479,6 +511,7 @@ export async function runEmbeddedPiAgent( streamParams: params.streamParams, ownerNumbers: params.ownerNumbers, enforceFinalTag: params.enforceFinalTag, + earlyHookResult, }); const { diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index 084b84a39a..6489aad876 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -850,31 +850,37 @@ export async function runEmbeddedAttempt( try { const promptStartedAt = Date.now(); - // Run before_agent_start hooks to allow plugins to inject context + // Run before_agent_start hooks to allow plugins to inject context. + // If run.ts already fired the hook (for model override), reuse its result. let effectivePrompt = params.prompt; - if (hookRunner?.hasHooks("before_agent_start")) { - try { - const hookResult = await hookRunner.runBeforeAgentStart( - { - prompt: params.prompt, - messages: activeSession.messages, - }, - { - agentId: hookAgentId, - sessionKey: params.sessionKey, - sessionId: params.sessionId, - workspaceDir: params.workspaceDir, - messageProvider: params.messageProvider ?? undefined, - }, + const hookResult = + params.earlyHookResult ?? + (hookRunner?.hasHooks("before_agent_start") + ? await hookRunner + .runBeforeAgentStart( + { + prompt: params.prompt, + messages: activeSession.messages, + }, + { + agentId: hookAgentId, + sessionKey: params.sessionKey, + sessionId: params.sessionId, + workspaceDir: params.workspaceDir, + messageProvider: params.messageProvider ?? undefined, + }, + ) + .catch((hookErr: unknown) => { + log.warn(`before_agent_start hook failed: ${String(hookErr)}`); + return undefined; + }) + : undefined); + { + if (hookResult?.prependContext) { + effectivePrompt = `${hookResult.prependContext}\n\n${params.prompt}`; + log.debug( + `hooks: prepended context to prompt (${hookResult.prependContext.length} chars)`, ); - if (hookResult?.prependContext) { - effectivePrompt = `${hookResult.prependContext}\n\n${params.prompt}`; - log.debug( - `hooks: prepended context to prompt (${hookResult.prependContext.length} chars)`, - ); - } - } catch (hookErr) { - log.warn(`before_agent_start hook failed: ${String(hookErr)}`); } } diff --git a/src/agents/pi-embedded-runner/run/types.ts b/src/agents/pi-embedded-runner/run/types.ts index 8436f0e2e5..db29b79455 100644 --- a/src/agents/pi-embedded-runner/run/types.ts +++ b/src/agents/pi-embedded-runner/run/types.ts @@ -2,6 +2,7 @@ import type { AgentMessage } from "@mariozechner/pi-agent-core"; import type { Api, AssistantMessage, Model } from "@mariozechner/pi-ai"; import type { ThinkLevel } from "../../../auto-reply/thinking.js"; import type { SessionSystemPromptReport } from "../../../config/sessions/types.js"; +import type { PluginHookBeforeAgentStartResult } from "../../../plugins/types.js"; import type { MessagingToolSend } from "../../pi-embedded-messaging.js"; import type { AuthStorage, ModelRegistry } from "../../pi-model-discovery.js"; import type { NormalizedUsage } from "../../usage.js"; @@ -19,6 +20,8 @@ export type EmbeddedRunAttemptParams = EmbeddedRunAttemptBase & { authStorage: AuthStorage; modelRegistry: ModelRegistry; thinkLevel: ThinkLevel; + /** Pre-computed hook result from run.ts to avoid double-firing before_agent_start. */ + earlyHookResult?: PluginHookBeforeAgentStartResult; }; export type EmbeddedRunAttemptResult = { diff --git a/src/plugins/hooks.ts b/src/plugins/hooks.ts index d05774089c..12f69eb848 100644 --- a/src/plugins/hooks.ts +++ b/src/plugins/hooks.ts @@ -200,6 +200,8 @@ export function createHookRunner(registry: PluginRegistry, options: HookRunnerOp acc?.prependContext && next.prependContext ? `${acc.prependContext}\n\n${next.prependContext}` : (next.prependContext ?? acc?.prependContext), + modelOverride: next.modelOverride ?? acc?.modelOverride, + providerOverride: next.providerOverride ?? acc?.providerOverride, }), ); } diff --git a/src/plugins/types.ts b/src/plugins/types.ts index ad9d283ccd..25ee3ced18 100644 --- a/src/plugins/types.ts +++ b/src/plugins/types.ts @@ -332,6 +332,10 @@ export type PluginHookBeforeAgentStartEvent = { export type PluginHookBeforeAgentStartResult = { systemPrompt?: string; prependContext?: string; + /** Override the model for this agent run. E.g. "llama3.3:8b" */ + modelOverride?: string; + /** Override the provider for this agent run. E.g. "ollama" */ + providerOverride?: string; }; // llm_input hook