mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Compare commits
34 Commits
test-scree
...
spare/test
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aba4a2b548 | ||
|
|
9f36e197aa | ||
|
|
2e7b674625 | ||
|
|
f4fed71e3d | ||
|
|
e516c9ce3a | ||
|
|
86898ff0d8 | ||
|
|
37de838652 | ||
|
|
c5eff58bf8 | ||
|
|
2ba0082e78 | ||
|
|
7ef10b26c0 | ||
|
|
1dfc75520d | ||
|
|
642b9c29c6 | ||
|
|
e7457983a1 | ||
|
|
799201bbe9 | ||
|
|
7ee0b0aeab | ||
|
|
35e92e00ca | ||
|
|
3bc28ac691 | ||
|
|
1316e16f04 | ||
|
|
0591804272 | ||
|
|
0d8a27fb7a | ||
|
|
c9a86e8339 | ||
|
|
e48144b356 | ||
|
|
54d6d4a3e6 | ||
|
|
7dc3b880a6 | ||
|
|
1848810b32 | ||
|
|
2f8d2e10da | ||
|
|
4dc3d0c34c | ||
|
|
9cfaaba3b6 | ||
|
|
f5d3a6e606 | ||
|
|
627b52048b | ||
|
|
da5420fa07 | ||
|
|
fce7a59713 | ||
|
|
95d3679e14 | ||
|
|
89f8060c5d |
@@ -75,8 +75,6 @@ from backend.copilot.tools.models import (
|
||||
NoResultsResponse,
|
||||
SetupRequirementsResponse,
|
||||
SuggestedGoalResponse,
|
||||
TaskResponse,
|
||||
TodoWriteResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
@@ -1421,8 +1419,6 @@ ToolResponseUnion = (
|
||||
| MemorySearchResponse
|
||||
| MemoryForgetCandidatesResponse
|
||||
| MemoryForgetConfirmResponse
|
||||
| TodoWriteResponse
|
||||
| TaskResponse
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.blocks.llm import extract_openrouter_cost
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
@@ -240,24 +239,12 @@ class PerplexityBlock(Block):
|
||||
if "message" in choice and "annotations" in choice["message"]:
|
||||
annotations = choice["message"]["annotations"]
|
||||
|
||||
# Update execution stats. ``execution_stats`` is instance state,
|
||||
# so always reset token counters — a response without ``usage``
|
||||
# must not leak a previous run's tokens into ``PlatformCostLog``.
|
||||
self.execution_stats.input_token_count = 0
|
||||
self.execution_stats.output_token_count = 0
|
||||
# Update execution stats
|
||||
if response.usage:
|
||||
self.execution_stats.input_token_count = response.usage.prompt_tokens
|
||||
self.execution_stats.output_token_count = (
|
||||
response.usage.completion_tokens
|
||||
)
|
||||
# OpenRouter's ``x-total-cost`` response header carries the real
|
||||
# per-request USD cost. Piping it into ``provider_cost`` lets the
|
||||
# direct-run ``PlatformCostLog`` flow
|
||||
# (``executor.cost_tracking::log_system_credential_cost``) record
|
||||
# the actual operator-side spend instead of inferring from tokens.
|
||||
# Always overwrite — ``execution_stats`` is instance state, so a
|
||||
# response without the header must not reuse a previous run's cost.
|
||||
self.execution_stats.provider_cost = extract_openrouter_cost(response)
|
||||
|
||||
return {"response": response_content, "annotations": annotations or []}
|
||||
|
||||
|
||||
@@ -50,13 +50,13 @@ _VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"})
|
||||
# (~4,700 deltas per turn in one observed session, vs ~28 for Sonnet); without
|
||||
# coalescing, every chunk is one Redis ``xadd`` + one SSE frame + one React
|
||||
# re-render of the non-virtualised chat list, which paint-storms the browser
|
||||
# main thread and freezes the UI. Batching into ~64-char / ~50 ms windows
|
||||
# cuts the event rate ~150x while staying snappy enough that the Reasoning
|
||||
# main thread and freezes the UI. Batching into ~32-char / ~40 ms windows
|
||||
# cuts the event rate ~100x while staying snappy enough that the Reasoning
|
||||
# collapse still feels live (well under the ~100 ms perceptual threshold).
|
||||
# Per-delta persistence to ``session.messages`` stays granular — we only
|
||||
# coalesce the *wire* emission.
|
||||
_COALESCE_MIN_CHARS = 64
|
||||
_COALESCE_MAX_INTERVAL_MS = 50.0
|
||||
_COALESCE_MIN_CHARS = 32
|
||||
_COALESCE_MAX_INTERVAL_MS = 40.0
|
||||
|
||||
|
||||
class ReasoningDetail(BaseModel):
|
||||
@@ -243,11 +243,8 @@ class BaselineReasoningEmitter:
|
||||
in-place as further deltas arrive; :meth:`close` drops the reference
|
||||
but leaves the appended row intact.
|
||||
|
||||
``render_in_ui=False`` suppresses only the live wire events
|
||||
(``StreamReasoning*``); the ``role='reasoning'`` persistence row is
|
||||
still appended so ``convertChatSessionToUiMessages.ts`` can hydrate
|
||||
the reasoning bubble on reload. The state machine advances
|
||||
identically either way.
|
||||
``render_in_ui=False`` suppresses wire events + persistence row;
|
||||
state machine still advances.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -262,8 +259,9 @@ class BaselineReasoningEmitter:
|
||||
self._open: bool = False
|
||||
self._session_messages = session_messages
|
||||
self._current_row: ChatMessage | None = None
|
||||
# Coalescing state — tests can disable (``=0``) for deterministic
|
||||
# event assertions.
|
||||
# Coalescing state — ``_pending_delta`` accumulates reasoning text
|
||||
# between wire flushes. Tuning knobs are kwargs so tests can
|
||||
# disable coalescing (``=0``) for deterministic event assertions.
|
||||
self._coalesce_min_chars = coalesce_min_chars
|
||||
self._coalesce_max_interval_ms = coalesce_max_interval_ms
|
||||
self._pending_delta: str = ""
|
||||
@@ -305,7 +303,7 @@ class BaselineReasoningEmitter:
|
||||
events.append(StreamReasoningDelta(id=self._block_id, delta=text))
|
||||
self._open = True
|
||||
self._last_flush_monotonic = now
|
||||
if self._session_messages is not None:
|
||||
if self._render_in_ui and self._session_messages is not None:
|
||||
self._current_row = ChatMessage(role="reasoning", content=text)
|
||||
self._session_messages.append(self._current_row)
|
||||
return events
|
||||
|
||||
@@ -478,11 +478,10 @@ class TestBaselineReasoningEmitterRenderFlag:
|
||||
assert events == []
|
||||
assert emitter.is_open is False
|
||||
|
||||
def test_render_off_still_persists(self):
|
||||
"""Persistence is decoupled from the render flag — session
|
||||
transcript always keeps the ``role="reasoning"`` row so audit
|
||||
and ``--resume``-equivalent replay never lose thinking text.
|
||||
The frontend gates rendering separately."""
|
||||
def test_render_off_skips_persistence(self):
|
||||
"""When render is off the emitter must NOT append a ``role="reasoning"``
|
||||
row to ``session_messages`` — hydration would re-render it, undoing
|
||||
the operator's intent."""
|
||||
session: list[ChatMessage] = []
|
||||
emitter = BaselineReasoningEmitter(session, render_in_ui=False)
|
||||
|
||||
@@ -490,9 +489,7 @@ class TestBaselineReasoningEmitterRenderFlag:
|
||||
emitter.on_delta(_delta(reasoning="part two"))
|
||||
emitter.close()
|
||||
|
||||
assert len(session) == 1
|
||||
assert session[0].role == "reasoning"
|
||||
assert session[0].content == "part one part two"
|
||||
assert session == []
|
||||
|
||||
def test_render_off_rotates_block_id_between_sessions(self):
|
||||
"""Even with wire events silenced the block id must rotate on close,
|
||||
|
||||
@@ -15,11 +15,10 @@ import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
|
||||
from contextvars import ContextVar
|
||||
from collections.abc import AsyncGenerator, Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
@@ -84,7 +83,6 @@ from backend.copilot.session_cleanup import prune_orphan_tool_calls
|
||||
from backend.copilot.thinking_stripper import ThinkingStripper as _ThinkingStripper
|
||||
from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tools.models import TaskResponse
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
@@ -124,29 +122,6 @@ _background_tasks: set[asyncio.Task[Any]] = set()
|
||||
# Maximum number of tool-call rounds before forcing a text response.
|
||||
_MAX_TOOL_ROUNDS = 30
|
||||
|
||||
# Task (in-process sub-agent) safeguards.
|
||||
# Depth cap keeps runaway recursion from exhausting tokens or the OpenRouter
|
||||
# credit budget; the iteration cap bounds a single sub-agent's own tool-call
|
||||
# loop (separate from the parent's _MAX_TOOL_ROUNDS).
|
||||
_MAX_TASK_DEPTH = 2
|
||||
_MAX_TASK_ITERATIONS = 15
|
||||
|
||||
# Tracks Task nesting for the current request context — inspected by
|
||||
# ``_run_task_subagent`` to refuse deeper spawns. ContextVars survive across
|
||||
# ``await`` points which lets the depth check work inside the async
|
||||
# ``tool_call_loop``.
|
||||
_TASK_DEPTH_VAR: ContextVar[int] = ContextVar("copilot_baseline_task_depth", default=0)
|
||||
|
||||
_TASK_INNER_SYSTEM_PROMPT = (
|
||||
"You are an in-process sub-agent invoked via the `Task` tool by a "
|
||||
"parent copilot. Execute the task described in the user message using "
|
||||
"the tools available to you, then return a concise final summary. "
|
||||
"Intermediate tool calls and reasoning stay hidden from the parent — "
|
||||
"only your final message is surfaced back. Do not invoke the `Task` "
|
||||
"tool yourself; keep focus on the requested unit of work and stop as "
|
||||
"soon as you have a usable answer."
|
||||
)
|
||||
|
||||
# Max seconds to wait for transcript upload in the finally block before
|
||||
# letting it continue as a background task (tracked in _background_tasks).
|
||||
_TRANSCRIPT_UPLOAD_TIMEOUT_S = 5
|
||||
@@ -368,21 +343,7 @@ class _BaselineStreamState:
|
||||
"""
|
||||
|
||||
model: str = ""
|
||||
# Live delivery channel drained concurrently by ``stream_chat_completion_baseline``
|
||||
# so reasoning / text / tool events reach the SSE wire **during** the upstream
|
||||
# LLM stream, not after ``_baseline_llm_caller`` returns. Before this was a
|
||||
# ``list`` drained per ``tool_call_loop`` iteration, so any model with
|
||||
# extended thinking (Anthropic via OpenRouter, Moonshot, future reasoning
|
||||
# routes) froze the UI for the entire duration of each LLM round before
|
||||
# flushing the backlog in one burst. The queue is single-producer (the
|
||||
# streaming loop) / single-consumer (the outer async-gen yield loop);
|
||||
# ``None`` is the close sentinel.
|
||||
pending_events: asyncio.Queue[StreamBaseResponse | None] = field(
|
||||
default_factory=asyncio.Queue
|
||||
)
|
||||
# Mirror of every event put on ``pending_events`` — kept for unit tests that
|
||||
# inspect post-hoc what was emitted. Not consumed by production code.
|
||||
emitted_events: list[StreamBaseResponse] = field(default_factory=list)
|
||||
pending_events: list[StreamBaseResponse] = field(default_factory=list)
|
||||
assistant_text: str = ""
|
||||
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
text_started: bool = False
|
||||
@@ -430,26 +391,6 @@ class _BaselineStreamState:
|
||||
)
|
||||
|
||||
|
||||
def _emit(state: "_BaselineStreamState", event: StreamBaseResponse) -> None:
|
||||
"""Queue *event* for the live SSE wire AND mirror into ``emitted_events``.
|
||||
|
||||
Single helper so every streaming producer (LLM stream loop, tool executor,
|
||||
conversation updater) posts to the same single-consumer queue. The mirror
|
||||
list is read-only from production code — it exists so unit tests can assert
|
||||
on the full sequence emitted during one call.
|
||||
"""
|
||||
state.pending_events.put_nowait(event)
|
||||
state.emitted_events.append(event)
|
||||
|
||||
|
||||
def _emit_all(
|
||||
state: "_BaselineStreamState", events: Iterable[StreamBaseResponse]
|
||||
) -> None:
|
||||
"""Queue *events* in order — convenience for emitter batches."""
|
||||
for event in events:
|
||||
_emit(state, event)
|
||||
|
||||
|
||||
def _is_anthropic_model(model: str) -> bool:
|
||||
"""Return True if *model* routes to Anthropic (native or via OpenRouter).
|
||||
|
||||
@@ -584,7 +525,7 @@ async def _baseline_llm_caller(
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
"""
|
||||
_emit(state, StreamStartStep())
|
||||
state.pending_events.append(StreamStartStep())
|
||||
# Fresh thinking-strip state per round so a malformed unclosed
|
||||
# block in one LLM call cannot silently drop content in the next.
|
||||
state.thinking_stripper = _ThinkingStripper()
|
||||
@@ -686,30 +627,31 @@ async def _baseline_llm_caller(
|
||||
if not delta:
|
||||
continue
|
||||
|
||||
_emit_all(state, state.reasoning_emitter.on_delta(delta))
|
||||
state.pending_events.extend(state.reasoning_emitter.on_delta(delta))
|
||||
|
||||
if delta.content:
|
||||
# Text and reasoning must not interleave on the wire — the
|
||||
# AI SDK maps distinct start/end pairs to distinct UI
|
||||
# parts. Close any open reasoning block before emitting
|
||||
# the first text delta of this run.
|
||||
_emit_all(state, state.reasoning_emitter.close())
|
||||
state.pending_events.extend(state.reasoning_emitter.close())
|
||||
emit = state.thinking_stripper.process(delta.content)
|
||||
if emit:
|
||||
if not state.text_started:
|
||||
_emit(state, StreamTextStart(id=state.text_block_id))
|
||||
state.pending_events.append(
|
||||
StreamTextStart(id=state.text_block_id)
|
||||
)
|
||||
state.text_started = True
|
||||
round_text += emit
|
||||
_emit(
|
||||
state,
|
||||
StreamTextDelta(id=state.text_block_id, delta=emit),
|
||||
state.pending_events.append(
|
||||
StreamTextDelta(id=state.text_block_id, delta=emit)
|
||||
)
|
||||
|
||||
if delta.tool_calls:
|
||||
# Same rule as the text branch: close any open reasoning
|
||||
# block before a tool_use starts so the AI SDK treats
|
||||
# reasoning and tool-use as distinct parts.
|
||||
_emit_all(state, state.reasoning_emitter.close())
|
||||
state.pending_events.extend(state.reasoning_emitter.close())
|
||||
for tc in delta.tool_calls:
|
||||
idx = tc.index
|
||||
if idx not in tool_calls_by_index:
|
||||
@@ -740,17 +682,19 @@ async def _baseline_llm_caller(
|
||||
# ``async for chunk in response`` would otherwise leave reasoning
|
||||
# and/or text unterminated and only ``StreamFinishStep`` emitted —
|
||||
# the Reasoning / Text collapses would never finalise.
|
||||
_emit_all(state, state.reasoning_emitter.close())
|
||||
state.pending_events.extend(state.reasoning_emitter.close())
|
||||
# Flush any buffered text held back by the thinking stripper.
|
||||
tail = state.thinking_stripper.flush()
|
||||
if tail:
|
||||
if not state.text_started:
|
||||
_emit(state, StreamTextStart(id=state.text_block_id))
|
||||
state.pending_events.append(StreamTextStart(id=state.text_block_id))
|
||||
state.text_started = True
|
||||
round_text += tail
|
||||
_emit(state, StreamTextDelta(id=state.text_block_id, delta=tail))
|
||||
state.pending_events.append(
|
||||
StreamTextDelta(id=state.text_block_id, delta=tail)
|
||||
)
|
||||
if state.text_started:
|
||||
_emit(state, StreamTextEnd(id=state.text_block_id))
|
||||
state.pending_events.append(StreamTextEnd(id=state.text_block_id))
|
||||
state.text_started = False
|
||||
state.text_block_id = str(uuid.uuid4())
|
||||
# Always persist partial text so the session history stays consistent,
|
||||
@@ -758,7 +702,7 @@ async def _baseline_llm_caller(
|
||||
state.assistant_text += round_text
|
||||
# Always emit StreamFinishStep to match the StreamStartStep,
|
||||
# even if an exception occurred during streaming.
|
||||
_emit(state, StreamFinishStep())
|
||||
state.pending_events.append(StreamFinishStep())
|
||||
|
||||
# Convert to shared format
|
||||
llm_tool_calls = [
|
||||
@@ -800,14 +744,13 @@ async def _baseline_tool_executor(
|
||||
except orjson.JSONDecodeError as parse_err:
|
||||
parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
|
||||
logger.warning("[Baseline] %s", parse_error)
|
||||
_emit(
|
||||
state,
|
||||
state.pending_events.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=parse_error,
|
||||
success=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
@@ -816,17 +759,15 @@ async def _baseline_tool_executor(
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
_emit(
|
||||
state,
|
||||
StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name),
|
||||
state.pending_events.append(
|
||||
StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
|
||||
)
|
||||
_emit(
|
||||
state,
|
||||
state.pending_events.append(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
input=tool_args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Announce the tool call to the session so in-turn guards like
|
||||
@@ -843,28 +784,14 @@ async def _baseline_tool_executor(
|
||||
session.announce_inflight_tool_call(tool_name)
|
||||
|
||||
try:
|
||||
if tool_name == "Task":
|
||||
# In-process sub-agent: baseline's answer to the SDK's CLI-native
|
||||
# Task tool. The outer ``execute_tool`` dispatch would hit the
|
||||
# TaskTool stub and return an error; we spin up a nested
|
||||
# ``tool_call_loop`` here so the parent context stays clean.
|
||||
result: StreamToolOutputAvailable = await _run_task_subagent(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_args=tool_args,
|
||||
tools=tools,
|
||||
parent_state=state,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
)
|
||||
else:
|
||||
result = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
_emit(state, result)
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=tool_args,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
state.pending_events.append(result)
|
||||
tool_output = (
|
||||
result.output if isinstance(result.output, str) else str(result.output)
|
||||
)
|
||||
@@ -881,14 +808,13 @@ async def _baseline_tool_executor(
|
||||
error_output,
|
||||
exc_info=True,
|
||||
)
|
||||
_emit(
|
||||
state,
|
||||
state.pending_events.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName=tool_name,
|
||||
output=error_output,
|
||||
success=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call_id,
|
||||
@@ -898,243 +824,6 @@ async def _baseline_tool_executor(
|
||||
)
|
||||
|
||||
|
||||
def _task_error_output(
|
||||
tool_call_id: str,
|
||||
message: str,
|
||||
*,
|
||||
description: str = "",
|
||||
) -> StreamToolOutputAvailable:
|
||||
"""Build a ``StreamToolOutputAvailable`` for a Task that failed pre-flight.
|
||||
|
||||
Error cases (parse failure, missing prompt, depth cap) short-circuit
|
||||
before the nested loop starts so parent-side usage accounting stays
|
||||
untouched.
|
||||
"""
|
||||
body = TaskResponse(
|
||||
message=message,
|
||||
description=description,
|
||||
response="",
|
||||
status="error",
|
||||
error=message,
|
||||
)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName="Task",
|
||||
output=body.model_dump_json(exclude_none=True),
|
||||
success=False,
|
||||
)
|
||||
|
||||
|
||||
def _inner_task_conversation_updater(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
) -> None:
|
||||
"""Stripped-down conversation updater used by in-process Task sub-agents.
|
||||
|
||||
The sub-agent's message list needs to grow so the tool_call_loop can
|
||||
follow assistant/tool turns to a natural finish, but we deliberately
|
||||
skip the parent's transcript-builder and session-message writes so the
|
||||
sub-agent's step-by-step trace doesn't pollute the user-visible
|
||||
conversation history.
|
||||
"""
|
||||
_mutate_openai_messages(messages, response, tool_results)
|
||||
|
||||
|
||||
async def _run_task_subagent(
|
||||
*,
|
||||
tool_call_id: str,
|
||||
tool_args: dict[str, Any],
|
||||
tools: Sequence[Any],
|
||||
parent_state: _BaselineStreamState,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> StreamToolOutputAvailable:
|
||||
"""Execute a `Task` tool call as a nested, context-isolated loop.
|
||||
|
||||
The sub-agent runs with a fresh ``_BaselineStreamState`` so its streaming
|
||||
events (text deltas, StreamStartStep/StreamFinishStep envelopes, nested
|
||||
tool_use/tool_result pairs) stay out of the parent stream. When it
|
||||
finishes we roll up token/cost counters to the parent state and return
|
||||
only the final assistant text — that text becomes the tool_result the
|
||||
parent LLM sees, so the parent context is never polluted by the
|
||||
sub-agent's intermediate reasoning.
|
||||
|
||||
Safeguards:
|
||||
- Depth cap via ``_TASK_DEPTH_VAR`` prevents runaway recursion.
|
||||
- Iteration cap (``_MAX_TASK_ITERATIONS``) bounds one sub-agent's loop
|
||||
independently of the parent's ``_MAX_TOOL_ROUNDS``.
|
||||
- The inner tool list excludes ``Task`` so a sub-agent can't re-enter
|
||||
this path; combined with the depth cap this gives defence in depth.
|
||||
"""
|
||||
description = str(tool_args.get("description") or "").strip()
|
||||
prompt = str(tool_args.get("prompt") or "").strip()
|
||||
|
||||
if not prompt:
|
||||
return _task_error_output(
|
||||
tool_call_id,
|
||||
"Task requires a non-empty `prompt`.",
|
||||
description=description,
|
||||
)
|
||||
|
||||
depth = _TASK_DEPTH_VAR.get()
|
||||
if depth >= _MAX_TASK_DEPTH:
|
||||
return _task_error_output(
|
||||
tool_call_id,
|
||||
(
|
||||
f"Task nesting depth limit reached (max {_MAX_TASK_DEPTH}). "
|
||||
"Collapse the outer Task or perform the step inline."
|
||||
),
|
||||
description=description,
|
||||
)
|
||||
|
||||
inner_state = _BaselineStreamState(model=parent_state.model)
|
||||
inner_messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": _TASK_INNER_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
inner_tools: list[Any] = [
|
||||
t for t in tools if (t.get("function") or {}).get("name") != "Task"
|
||||
]
|
||||
# The parent pre-marks ``cache_control`` on the last tool schema once per
|
||||
# session for Anthropic routes (see ``_mark_tools_with_cache_control``
|
||||
# in ``stream_chat_completion_baseline``). Filtering ``Task`` off the
|
||||
# end can drop that marker, which would make every sub-agent LLM round
|
||||
# re-send the ~8 KB tool schema uncached. Re-apply on Anthropic routes
|
||||
# so each inner round hits the cache from round 2 onward; no-op for
|
||||
# OpenAI / Grok / other providers where ``cache_control`` is unknown.
|
||||
if _is_anthropic_model(parent_state.model) and inner_tools:
|
||||
inner_tools = cast(list[Any], _mark_tools_with_cache_control(inner_tools))
|
||||
|
||||
tool_names_seen: list[str] = []
|
||||
iterations = 0
|
||||
final_response_text = ""
|
||||
finished_naturally = True
|
||||
|
||||
inner_executor = partial(
|
||||
_baseline_tool_executor,
|
||||
state=inner_state,
|
||||
user_id=user_id,
|
||||
# NOTE: the sub-agent deliberately shares the parent's ``session``
|
||||
# object — ``_baseline_tool_executor`` calls
|
||||
# ``session.announce_inflight_tool_call(tool_name)`` which feeds
|
||||
# in-turn guards like ``require_guide_read``. Cross-contaminating
|
||||
# the announce-set is the intended behaviour: if the parent calls
|
||||
# ``get_agent_building_guide`` and the sub-agent then calls
|
||||
# ``create_agent``, the guard should recognise the prereq was met.
|
||||
# Message history and streaming events ARE isolated (fresh
|
||||
# ``_BaselineStreamState`` above) — only the announce-set leaks.
|
||||
session=session,
|
||||
)
|
||||
inner_llm = partial(_baseline_llm_caller, state=inner_state)
|
||||
|
||||
task_exc: Exception | None = None
|
||||
token = _TASK_DEPTH_VAR.set(depth + 1)
|
||||
try:
|
||||
try:
|
||||
async for loop_result in tool_call_loop(
|
||||
messages=inner_messages,
|
||||
tools=inner_tools,
|
||||
llm_call=inner_llm,
|
||||
execute_tool=inner_executor,
|
||||
update_conversation=_inner_task_conversation_updater,
|
||||
max_iterations=_MAX_TASK_ITERATIONS,
|
||||
last_iteration_message=(
|
||||
"This is the final iteration — produce your summary now."
|
||||
),
|
||||
):
|
||||
# Discard inner streaming events so only the Task envelope
|
||||
# and its final summary reach the parent client. The inner
|
||||
# state's queue must be drained (not cleared — it's an
|
||||
# ``asyncio.Queue``) each iteration so the next round's
|
||||
# producers don't block on a full buffer. Token accounting
|
||||
# still happens via ``inner_state`` and rolls up after the
|
||||
# loop exits.
|
||||
while not inner_state.pending_events.empty():
|
||||
inner_state.pending_events.get_nowait()
|
||||
inner_state.emitted_events.clear()
|
||||
for tc in loop_result.last_tool_calls:
|
||||
tool_names_seen.append(tc.name)
|
||||
iterations = loop_result.iterations
|
||||
finished_naturally = loop_result.finished_naturally
|
||||
if loop_result.finished_naturally:
|
||||
final_response_text = loop_result.response_text or ""
|
||||
except Exception as exc:
|
||||
task_exc = exc
|
||||
# ``CancelledError`` / ``KeyboardInterrupt`` / ``SystemExit``
|
||||
# derive from ``BaseException`` and are intentionally NOT caught
|
||||
# here — they propagate through the outer ``finally`` below, which
|
||||
# still resets the depth counter and rolls up usage before the
|
||||
# exception reaches the caller. Letting them bubble naturally
|
||||
# avoids the ``except BaseException`` suppressor pattern.
|
||||
finally:
|
||||
_TASK_DEPTH_VAR.reset(token)
|
||||
# Usage rolls up on every path (success, caught Exception, or
|
||||
# propagating BaseException) so partial work is still billed.
|
||||
_absorb_inner_usage(parent_state, inner_state)
|
||||
|
||||
if task_exc is not None:
|
||||
logger.error(
|
||||
"[Baseline] Task sub-agent failed: %s", task_exc, exc_info=task_exc
|
||||
)
|
||||
body = TaskResponse(
|
||||
message=f"Task failed: {task_exc}",
|
||||
description=description,
|
||||
response="",
|
||||
iterations=iterations,
|
||||
tool_calls=tool_names_seen,
|
||||
status="error",
|
||||
error=str(task_exc),
|
||||
)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName="Task",
|
||||
output=body.model_dump_json(exclude_none=True),
|
||||
success=False,
|
||||
)
|
||||
|
||||
status: Literal["completed", "max_iterations"] = (
|
||||
"completed" if finished_naturally else "max_iterations"
|
||||
)
|
||||
body = TaskResponse(
|
||||
message=(
|
||||
"Task completed."
|
||||
if status == "completed"
|
||||
else f"Task hit iteration limit ({_MAX_TASK_ITERATIONS})."
|
||||
),
|
||||
session_id=session.session_id,
|
||||
description=description,
|
||||
response=final_response_text,
|
||||
iterations=iterations,
|
||||
tool_calls=tool_names_seen,
|
||||
status=status,
|
||||
)
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName="Task",
|
||||
output=body.model_dump_json(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def _absorb_inner_usage(
|
||||
parent_state: _BaselineStreamState,
|
||||
inner_state: _BaselineStreamState,
|
||||
) -> None:
|
||||
"""Fold a sub-agent's token/cost counters back into the parent state.
|
||||
|
||||
Usage accounting happens once per turn via
|
||||
``persist_and_record_usage`` (see ``stream_chat_completion_baseline``);
|
||||
the sub-agent runs under the same turn, so the user must still be
|
||||
billed for its LLM calls.
|
||||
"""
|
||||
parent_state.turn_prompt_tokens += inner_state.turn_prompt_tokens
|
||||
parent_state.turn_completion_tokens += inner_state.turn_completion_tokens
|
||||
parent_state.turn_cache_read_tokens += inner_state.turn_cache_read_tokens
|
||||
parent_state.turn_cache_creation_tokens += inner_state.turn_cache_creation_tokens
|
||||
if inner_state.cost_usd is not None:
|
||||
parent_state.cost_usd = (parent_state.cost_usd or 0.0) + inner_state.cost_usd
|
||||
|
||||
|
||||
def _mutate_openai_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
@@ -1977,172 +1666,139 @@ async def stream_chat_completion_baseline(
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Run the tool-call loop concurrently with the event consumer so
|
||||
# ``StreamReasoning*`` / ``StreamText*`` deltas emitted inside
|
||||
# ``_baseline_llm_caller`` reach the SSE wire DURING the upstream LLM
|
||||
# stream instead of only at iteration boundaries. Any reasoning route
|
||||
# that streams for several minutes per round (extended thinking on
|
||||
# Anthropic / Moonshot / future providers) would otherwise freeze the
|
||||
# UI for the whole window before flushing the backlog in one burst.
|
||||
loop_result_holder: list[Any] = [None]
|
||||
loop_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def _run_tool_call_loop() -> None:
|
||||
# Read/write the current session via ``_session_holder`` so this
|
||||
# closure doesn't need to ``nonlocal session`` — pyright can't narrow
|
||||
# the outer ``session: ChatSession | None`` through a nested scope,
|
||||
# but the holder is typed non-optional after the preflight guard
|
||||
# above.
|
||||
try:
|
||||
async for loop_result in tool_call_loop(
|
||||
messages=openai_messages,
|
||||
tools=tools,
|
||||
llm_call=_bound_llm_caller,
|
||||
execute_tool=_bound_tool_executor,
|
||||
update_conversation=_bound_conversation_updater,
|
||||
max_iterations=_MAX_TOOL_ROUNDS,
|
||||
):
|
||||
loop_result_holder[0] = loop_result
|
||||
# Inject any messages the user queued while the turn was
|
||||
# running. ``tool_call_loop`` mutates ``openai_messages``
|
||||
# in-place, so appending here means the model sees the new
|
||||
# messages on its next LLM call.
|
||||
#
|
||||
# IMPORTANT: skip when the loop has already finished (no
|
||||
# more LLM calls are coming). ``tool_call_loop`` yields
|
||||
# a final ``ToolCallLoopResult`` on both paths:
|
||||
# - natural finish: ``finished_naturally=True``
|
||||
# - hit max_iterations: ``finished_naturally=False``
|
||||
# and ``iterations >= max_iterations``
|
||||
# In either case the loop is about to return on the next
|
||||
# ``async for`` step, so draining here would silently
|
||||
# lose the message (the user sees 202 but the model never
|
||||
# reads the text). Those messages stay in the buffer and
|
||||
# get picked up at the start of the next turn.
|
||||
is_final_yield = (
|
||||
loop_result.finished_naturally
|
||||
or loop_result.iterations >= _MAX_TOOL_ROUNDS
|
||||
)
|
||||
if is_final_yield:
|
||||
continue
|
||||
try:
|
||||
pending = await drain_pending_messages(session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[Baseline] mid-loop drain_pending_messages failed for "
|
||||
"session %s",
|
||||
session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
pending = []
|
||||
if pending:
|
||||
# Flush any buffered assistant/tool messages from completed
|
||||
# rounds into session.messages BEFORE appending the pending
|
||||
# user message. ``_baseline_conversation_updater`` only
|
||||
# records assistant+tool rounds into ``state.session_messages``
|
||||
# — they are normally batch-flushed in the finally block.
|
||||
# Without this in-order flush, the mid-loop pending user
|
||||
# message lands before the preceding round's assistant/tool
|
||||
# entries, producing chronologically-wrong session.messages
|
||||
# on persist (user interposed between an assistant tool_call
|
||||
# and its tool-result), which breaks OpenAI tool-call ordering
|
||||
# invariants on the next turn's replay.
|
||||
#
|
||||
# Also persist any assistant text from text-only rounds (rounds
|
||||
# with no tool calls, which ``_baseline_conversation_updater``
|
||||
# does NOT record in session_messages). If we only update
|
||||
# ``_flushed_assistant_text_len`` without persisting the text,
|
||||
# that text is silently lost: the finally block only appends
|
||||
# assistant_text[_flushed_assistant_text_len:], so text generated
|
||||
# before this drain never reaches session.messages.
|
||||
recorded_text = "".join(
|
||||
m.content or ""
|
||||
for m in state.session_messages
|
||||
if m.role == "assistant"
|
||||
)
|
||||
unflushed_text = state.assistant_text[
|
||||
state._flushed_assistant_text_len :
|
||||
]
|
||||
text_only_text = (
|
||||
unflushed_text[len(recorded_text) :]
|
||||
if unflushed_text.startswith(recorded_text)
|
||||
else unflushed_text
|
||||
)
|
||||
current_session = _session_holder[0]
|
||||
if text_only_text.strip():
|
||||
current_session.messages.append(
|
||||
ChatMessage(role="assistant", content=text_only_text)
|
||||
)
|
||||
for _buffered in state.session_messages:
|
||||
current_session.messages.append(_buffered)
|
||||
state.session_messages.clear()
|
||||
# Record how much assistant_text has been covered by the
|
||||
# structured entries just flushed, so the finally block's
|
||||
# final-text dedup doesn't re-append rounds already persisted.
|
||||
state._flushed_assistant_text_len = len(state.assistant_text)
|
||||
|
||||
# Persist the assistant/tool flush BEFORE the pending append
|
||||
# so a later pending-persist failure can roll back the
|
||||
# pending rows without also discarding LLM output.
|
||||
current_session = await persist_session_safe(
|
||||
current_session, "[Baseline]"
|
||||
)
|
||||
# ``upsert_chat_session`` may return a *new* ``ChatSession``
|
||||
# instance (e.g. when a concurrent title update has written a
|
||||
# newer title to Redis, it returns ``session.model_copy``).
|
||||
# Keep ``_session_holder`` in sync so subsequent tool rounds
|
||||
# executed via ``_bound_tool_executor`` see the fresh session
|
||||
# — any tool-side mutations on the stale object would be
|
||||
# discarded when the new one is persisted in the ``finally``.
|
||||
_session_holder[0] = current_session
|
||||
|
||||
# ``format_pending_as_user_message`` embeds file attachments
|
||||
# and context URL/page content into the content string so
|
||||
# the in-session transcript is a faithful copy of what the
|
||||
# model actually saw. We also mirror each push into
|
||||
# ``openai_messages`` so the model's next LLM round sees it.
|
||||
#
|
||||
# Pre-compute the formatted dicts once so both the openai
|
||||
# messages append and the content_of lookup inside the
|
||||
# shared helper use the same string — and so ``on_rollback``
|
||||
# can trim ``openai_messages`` to the recorded anchor.
|
||||
formatted_by_pm = {
|
||||
id(pm): format_pending_as_user_message(pm) for pm in pending
|
||||
}
|
||||
_openai_anchor = len(openai_messages)
|
||||
for pm in pending:
|
||||
openai_messages.append(formatted_by_pm[id(pm)])
|
||||
|
||||
def _trim_openai_on_rollback(_session_anchor: int) -> None:
|
||||
del openai_messages[_openai_anchor:]
|
||||
|
||||
await persist_pending_as_user_rows(
|
||||
current_session,
|
||||
transcript_builder,
|
||||
pending,
|
||||
log_prefix="[Baseline]",
|
||||
content_of=lambda pm: formatted_by_pm[id(pm)]["content"],
|
||||
on_rollback=_trim_openai_on_rollback,
|
||||
)
|
||||
finally:
|
||||
# Always post the sentinel so the outer consumer exits — even if
|
||||
# ``tool_call_loop`` raised. ``_baseline_llm_caller``'s own
|
||||
# finally block has already pushed ``StreamReasoningEnd`` /
|
||||
# ``StreamTextEnd`` / ``StreamFinishStep`` at this point, so the
|
||||
# sentinel only terminates the consumer; it does not suppress
|
||||
# any still-unflushed events.
|
||||
state.pending_events.put_nowait(None)
|
||||
|
||||
loop_task = asyncio.create_task(_run_tool_call_loop())
|
||||
try:
|
||||
while True:
|
||||
evt = await state.pending_events.get()
|
||||
if evt is None:
|
||||
break
|
||||
yield evt
|
||||
# Sentinel received — surface any exception the inner task hit.
|
||||
await loop_task
|
||||
loop_result = loop_result_holder[0]
|
||||
loop_result = None
|
||||
async for loop_result in tool_call_loop(
|
||||
messages=openai_messages,
|
||||
tools=tools,
|
||||
llm_call=_bound_llm_caller,
|
||||
execute_tool=_bound_tool_executor,
|
||||
update_conversation=_bound_conversation_updater,
|
||||
max_iterations=_MAX_TOOL_ROUNDS,
|
||||
):
|
||||
# Drain buffered events after each iteration (real-time streaming)
|
||||
for evt in state.pending_events:
|
||||
yield evt
|
||||
state.pending_events.clear()
|
||||
|
||||
# Inject any messages the user queued while the turn was
|
||||
# running. ``tool_call_loop`` mutates ``openai_messages``
|
||||
# in-place, so appending here means the model sees the new
|
||||
# messages on its next LLM call.
|
||||
#
|
||||
# IMPORTANT: skip when the loop has already finished (no
|
||||
# more LLM calls are coming). ``tool_call_loop`` yields
|
||||
# a final ``ToolCallLoopResult`` on both paths:
|
||||
# - natural finish: ``finished_naturally=True``
|
||||
# - hit max_iterations: ``finished_naturally=False``
|
||||
# and ``iterations >= max_iterations``
|
||||
# In either case the loop is about to return on the next
|
||||
# ``async for`` step, so draining here would silently
|
||||
# lose the message (the user sees 202 but the model never
|
||||
# reads the text). Those messages stay in the buffer and
|
||||
# get picked up at the start of the next turn.
|
||||
is_final_yield = (
|
||||
loop_result.finished_naturally
|
||||
or loop_result.iterations >= _MAX_TOOL_ROUNDS
|
||||
)
|
||||
if is_final_yield:
|
||||
continue
|
||||
try:
|
||||
pending = await drain_pending_messages(session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[Baseline] mid-loop drain_pending_messages failed for session %s",
|
||||
session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
pending = []
|
||||
if pending:
|
||||
# Flush any buffered assistant/tool messages from completed
|
||||
# rounds into session.messages BEFORE appending the pending
|
||||
# user message. ``_baseline_conversation_updater`` only
|
||||
# records assistant+tool rounds into ``state.session_messages``
|
||||
# — they are normally batch-flushed in the finally block.
|
||||
# Without this in-order flush, the mid-loop pending user
|
||||
# message lands before the preceding round's assistant/tool
|
||||
# entries, producing chronologically-wrong session.messages
|
||||
# on persist (user interposed between an assistant tool_call
|
||||
# and its tool-result), which breaks OpenAI tool-call ordering
|
||||
# invariants on the next turn's replay.
|
||||
#
|
||||
# Also persist any assistant text from text-only rounds (rounds
|
||||
# with no tool calls, which ``_baseline_conversation_updater``
|
||||
# does NOT record in session_messages). If we only update
|
||||
# ``_flushed_assistant_text_len`` without persisting the text,
|
||||
# that text is silently lost: the finally block only appends
|
||||
# assistant_text[_flushed_assistant_text_len:], so text generated
|
||||
# before this drain never reaches session.messages.
|
||||
recorded_text = "".join(
|
||||
m.content or ""
|
||||
for m in state.session_messages
|
||||
if m.role == "assistant"
|
||||
)
|
||||
unflushed_text = state.assistant_text[
|
||||
state._flushed_assistant_text_len :
|
||||
]
|
||||
text_only_text = (
|
||||
unflushed_text[len(recorded_text) :]
|
||||
if unflushed_text.startswith(recorded_text)
|
||||
else unflushed_text
|
||||
)
|
||||
if text_only_text.strip():
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=text_only_text)
|
||||
)
|
||||
for _buffered in state.session_messages:
|
||||
session.messages.append(_buffered)
|
||||
state.session_messages.clear()
|
||||
# Record how much assistant_text has been covered by the
|
||||
# structured entries just flushed, so the finally block's
|
||||
# final-text dedup doesn't re-append rounds already persisted.
|
||||
state._flushed_assistant_text_len = len(state.assistant_text)
|
||||
|
||||
# Persist the assistant/tool flush BEFORE the pending append
|
||||
# so a later pending-persist failure can roll back the
|
||||
# pending rows without also discarding LLM output.
|
||||
session = await persist_session_safe(session, "[Baseline]")
|
||||
# ``upsert_chat_session`` may return a *new* ``ChatSession``
|
||||
# instance (e.g. when a concurrent title update has written a
|
||||
# newer title to Redis, it returns ``session.model_copy``).
|
||||
# Keep ``_session_holder`` in sync so subsequent tool rounds
|
||||
# executed via ``_bound_tool_executor`` see the fresh session
|
||||
# — any tool-side mutations on the stale object would be
|
||||
# discarded when the new one is persisted in the ``finally``.
|
||||
_session_holder[0] = session
|
||||
|
||||
# ``format_pending_as_user_message`` embeds file attachments
|
||||
# and context URL/page content into the content string so
|
||||
# the in-session transcript is a faithful copy of what the
|
||||
# model actually saw. We also mirror each push into
|
||||
# ``openai_messages`` so the model's next LLM round sees it.
|
||||
#
|
||||
# Pre-compute the formatted dicts once so both the openai
|
||||
# messages append and the content_of lookup inside the
|
||||
# shared helper use the same string — and so ``on_rollback``
|
||||
# can trim ``openai_messages`` to the recorded anchor.
|
||||
formatted_by_pm = {
|
||||
id(pm): format_pending_as_user_message(pm) for pm in pending
|
||||
}
|
||||
_openai_anchor = len(openai_messages)
|
||||
for pm in pending:
|
||||
openai_messages.append(formatted_by_pm[id(pm)])
|
||||
|
||||
def _trim_openai_on_rollback(_session_anchor: int) -> None:
|
||||
del openai_messages[_openai_anchor:]
|
||||
|
||||
await persist_pending_as_user_rows(
|
||||
session,
|
||||
transcript_builder,
|
||||
pending,
|
||||
log_prefix="[Baseline]",
|
||||
content_of=lambda pm: formatted_by_pm[id(pm)]["content"],
|
||||
on_rollback=_trim_openai_on_rollback,
|
||||
)
|
||||
|
||||
if loop_result and not loop_result.finished_naturally:
|
||||
limit_msg = (
|
||||
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
|
||||
@@ -2153,34 +1809,25 @@ async def stream_chat_completion_baseline(
|
||||
errorText=limit_msg,
|
||||
code="baseline_tool_round_limit",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_stream_error = True
|
||||
error_msg = str(e) or type(e).__name__
|
||||
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
|
||||
# Drain any queued tail events (reasoning/text close + finish step)
|
||||
# that ``_baseline_llm_caller``'s finally block pushed before the
|
||||
# sentinel arrived — without this the frontend would be missing the
|
||||
# matching end / finish parts for the partial round.
|
||||
while not state.pending_events.empty():
|
||||
evt = state.pending_events.get_nowait()
|
||||
if evt is not None:
|
||||
yield evt
|
||||
# ``_baseline_llm_caller``'s finally block closes any open
|
||||
# reasoning / text blocks and appends ``StreamFinishStep`` on
|
||||
# both normal and exception paths, so pending_events already has
|
||||
# the correct protocol ordering:
|
||||
# StreamStartStep -> StreamReasoningStart -> ...deltas... ->
|
||||
# StreamReasoningEnd -> StreamTextStart -> ...deltas... ->
|
||||
# StreamTextEnd -> StreamFinishStep
|
||||
# Just drain what's buffered, then yield the error.
|
||||
for evt in state.pending_events:
|
||||
yield evt
|
||||
state.pending_events.clear()
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
# Cancel the inner task if we're unwinding early (client disconnect,
|
||||
# unexpected error in the consumer) so it doesn't keep streaming
|
||||
# tokens into a dead queue.
|
||||
if loop_task is not None and not loop_task.done():
|
||||
loop_task.cancel()
|
||||
try:
|
||||
await loop_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
# Re-sync the outer ``session`` binding in case the inner task
|
||||
# reassigned it via a mid-loop ``persist_session_safe`` call.
|
||||
session = _session_holder[0]
|
||||
|
||||
# In-flight tool-call announcements are only meaningful for the
|
||||
# current turn; clear at the top of the outer finally so the next
|
||||
# turn starts with a clean scratch buffer even if one of the
|
||||
|
||||
@@ -10,8 +10,6 @@ import pytest
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_MAX_TASK_DEPTH,
|
||||
_TASK_DEPTH_VAR,
|
||||
_baseline_conversation_updater,
|
||||
_baseline_llm_caller,
|
||||
_BaselineStreamState,
|
||||
@@ -23,7 +21,6 @@ from backend.copilot.baseline.service import (
|
||||
_is_anthropic_model,
|
||||
_mark_system_message_with_cache_control,
|
||||
_mark_tools_with_cache_control,
|
||||
_run_task_subagent,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.response_model import (
|
||||
@@ -42,10 +39,7 @@ from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallRe
|
||||
class TestBaselineStreamState:
|
||||
def test_defaults(self):
|
||||
state = _BaselineStreamState()
|
||||
# ``pending_events`` is an asyncio.Queue now (live SSE channel).
|
||||
# The durable inspection view is ``emitted_events``.
|
||||
assert state.pending_events.empty()
|
||||
assert state.emitted_events == []
|
||||
assert state.pending_events == []
|
||||
assert state.assistant_text == ""
|
||||
assert state.text_started is False
|
||||
assert state.turn_prompt_tokens == 0
|
||||
@@ -1693,7 +1687,7 @@ class TestBaselineReasoningStreaming:
|
||||
state=state,
|
||||
)
|
||||
|
||||
types = [type(e).__name__ for e in state.emitted_events]
|
||||
types = [type(e).__name__ for e in state.pending_events]
|
||||
assert "StreamReasoningStart" in types
|
||||
assert "StreamReasoningDelta" in types
|
||||
assert "StreamReasoningEnd" in types
|
||||
@@ -1708,14 +1702,14 @@ class TestBaselineReasoningStreaming:
|
||||
# a fresh id after the reasoning-end rotation.
|
||||
reasoning_ids = {
|
||||
e.id
|
||||
for e in state.emitted_events
|
||||
for e in state.pending_events
|
||||
if isinstance(
|
||||
e, (StreamReasoningStart, StreamReasoningDelta, StreamReasoningEnd)
|
||||
)
|
||||
}
|
||||
text_ids = {
|
||||
e.id
|
||||
for e in state.emitted_events
|
||||
for e in state.pending_events
|
||||
if isinstance(e, (StreamTextStart, StreamTextDelta, StreamTextEnd))
|
||||
}
|
||||
assert len(reasoning_ids) == 1
|
||||
@@ -1723,7 +1717,7 @@ class TestBaselineReasoningStreaming:
|
||||
assert reasoning_ids.isdisjoint(text_ids)
|
||||
|
||||
combined = "".join(
|
||||
e.delta for e in state.emitted_events if isinstance(e, StreamReasoningDelta)
|
||||
e.delta for e in state.pending_events if isinstance(e, StreamReasoningDelta)
|
||||
)
|
||||
assert combined == "thinking... more"
|
||||
|
||||
@@ -1765,7 +1759,7 @@ class TestBaselineReasoningStreaming:
|
||||
|
||||
# A reasoning-end must have been emitted — this is the tool_calls
|
||||
# branch's responsibility, not the stream-end cleanup.
|
||||
types = [type(e).__name__ for e in state.emitted_events]
|
||||
types = [type(e).__name__ for e in state.pending_events]
|
||||
assert "StreamReasoningStart" in types
|
||||
assert "StreamReasoningEnd" in types
|
||||
|
||||
@@ -1808,7 +1802,7 @@ class TestBaselineReasoningStreaming:
|
||||
state=state,
|
||||
)
|
||||
|
||||
types = [type(e).__name__ for e in state.emitted_events]
|
||||
types = [type(e).__name__ for e in state.pending_events]
|
||||
# The reasoning block was opened, the exception fired, and the
|
||||
# finally block must have closed it before emitting the finish
|
||||
# step.
|
||||
@@ -1941,7 +1935,7 @@ class TestBaselineReasoningStreaming:
|
||||
state=state,
|
||||
)
|
||||
|
||||
types = [type(e).__name__ for e in state.emitted_events]
|
||||
types = [type(e).__name__ for e in state.pending_events]
|
||||
assert "StreamReasoningStart" in types
|
||||
assert "StreamReasoningEnd" in types
|
||||
# No text was produced — no text events should be emitted.
|
||||
@@ -2012,415 +2006,3 @@ class TestBaselineReasoningStreaming:
|
||||
reasoning_rows = [m for m in state.session_messages if m.role == "reasoning"]
|
||||
assert len(reasoning_rows) == 1
|
||||
assert reasoning_rows[0].content == "first thought"
|
||||
|
||||
|
||||
# ── In-process Task sub-agent ──────────────────────────────────────────
|
||||
|
||||
|
||||
def _task_session():
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
return ChatSession.new(user_id="task-user", dry_run=False)
|
||||
|
||||
|
||||
def _final_text_chunk(text: str):
|
||||
"""Build a single streaming chunk that finishes a turn with plain text."""
|
||||
return _make_stream_mock(_make_delta_chunk(content=text))
|
||||
|
||||
|
||||
class TestBaselineTaskSubagent:
|
||||
"""Tests for ``_run_task_subagent`` — the in-process sub-agent runner
|
||||
that powers the baseline ``Task`` tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_happy_path_returns_completed_summary(self):
|
||||
"""Sub-agent finishes naturally with text and no further tool calls."""
|
||||
parent_state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6")
|
||||
session = _task_session()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=_final_text_chunk("All done — result is 42.")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
result = await _run_task_subagent(
|
||||
tool_call_id="tc-1",
|
||||
tool_args={"description": "Compute answer", "prompt": "Compute 6*7"},
|
||||
tools=[],
|
||||
parent_state=parent_state,
|
||||
user_id="u",
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert result.toolName == "Task"
|
||||
assert result.success is True
|
||||
import json
|
||||
|
||||
payload = json.loads(result.output)
|
||||
assert payload["status"] == "completed"
|
||||
assert payload["description"] == "Compute answer"
|
||||
assert payload["response"] == "All done — result is 42."
|
||||
assert payload["iterations"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_prompt_returns_error(self):
|
||||
parent_state = _BaselineStreamState()
|
||||
session = _task_session()
|
||||
|
||||
result = await _run_task_subagent(
|
||||
tool_call_id="tc-err",
|
||||
tool_args={"description": "no work", "prompt": " "},
|
||||
tools=[],
|
||||
parent_state=parent_state,
|
||||
user_id="u",
|
||||
session=session,
|
||||
)
|
||||
|
||||
import json
|
||||
|
||||
payload = json.loads(result.output)
|
||||
assert result.success is False
|
||||
assert payload["status"] == "error"
|
||||
assert "prompt" in payload["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_depth_cap_refuses_deeper_spawn(self):
|
||||
"""At the depth cap no LLM call should be issued."""
|
||||
parent_state = _BaselineStreamState()
|
||||
session = _task_session()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=_final_text_chunk("should not be called")
|
||||
)
|
||||
|
||||
token = _TASK_DEPTH_VAR.set(_MAX_TASK_DEPTH)
|
||||
try:
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
result = await _run_task_subagent(
|
||||
tool_call_id="tc-deep",
|
||||
tool_args={"description": "nested", "prompt": "run"},
|
||||
tools=[],
|
||||
parent_state=parent_state,
|
||||
user_id="u",
|
||||
session=session,
|
||||
)
|
||||
finally:
|
||||
_TASK_DEPTH_VAR.reset(token)
|
||||
|
||||
import json
|
||||
|
||||
payload = json.loads(result.output)
|
||||
assert result.success is False
|
||||
assert payload["status"] == "error"
|
||||
assert "depth" in payload["error"].lower()
|
||||
mock_client.chat.completions.create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inner_stream_events_are_suppressed(self):
|
||||
"""Only the parent sees the Task tool envelope — inner text deltas,
|
||||
StreamStartStep/StreamFinishStep envelopes, and nested tool_input
|
||||
events must stay off the parent's pending_events."""
|
||||
parent_state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6")
|
||||
session = _task_session()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=_final_text_chunk("wrap")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
await _run_task_subagent(
|
||||
tool_call_id="tc-stream",
|
||||
tool_args={"description": "short", "prompt": "do it"},
|
||||
tools=[],
|
||||
parent_state=parent_state,
|
||||
user_id="u",
|
||||
session=session,
|
||||
)
|
||||
|
||||
# ``_run_task_subagent`` returns a ``StreamToolOutputAvailable`` but
|
||||
# does NOT emit to ``parent_state`` itself — the outer
|
||||
# ``_baseline_tool_executor`` handles that. What we care about here
|
||||
# is the negative invariant: the inner loop's deltas, step
|
||||
# envelopes, and sub-tool events never leaked in. Both the live
|
||||
# queue AND the test-mirror list must be untouched.
|
||||
assert parent_state.pending_events.empty()
|
||||
assert parent_state.emitted_events == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_rolls_up_to_parent(self):
|
||||
"""Inner sub-agent LLM usage counts toward the parent session so
|
||||
the user is still billed for the delegated work."""
|
||||
parent_state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6")
|
||||
session = _task_session()
|
||||
|
||||
usage_chunk = _make_usage_chunk(
|
||||
prompt_tokens=42, completion_tokens=7, cost=0.012
|
||||
)
|
||||
text_chunk = _make_delta_chunk(content="Did the thing.")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=_make_stream_mock(text_chunk, usage_chunk)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
await _run_task_subagent(
|
||||
tool_call_id="tc-usage",
|
||||
tool_args={"description": "short", "prompt": "do it"},
|
||||
tools=[],
|
||||
parent_state=parent_state,
|
||||
user_id="u",
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert parent_state.turn_prompt_tokens == 42
|
||||
assert parent_state.turn_completion_tokens == 7
|
||||
assert parent_state.cost_usd == pytest.approx(0.012)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_depth_var_reset_on_exception(self):
|
||||
"""ContextVar must be reset via ``finally`` even when the inner
|
||||
loop raises. A leaked depth would either refuse a later sibling
|
||||
Task at the cap or run one level shallower than intended."""
|
||||
parent_state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6")
|
||||
session = _task_session()
|
||||
depth_before = _TASK_DEPTH_VAR.get()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
side_effect=RuntimeError("boom")
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
result = await _run_task_subagent(
|
||||
tool_call_id="tc-exc",
|
||||
tool_args={"description": "kaboom", "prompt": "run"},
|
||||
tools=[],
|
||||
parent_state=parent_state,
|
||||
user_id="u",
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert _TASK_DEPTH_VAR.get() == depth_before
|
||||
import json
|
||||
|
||||
payload = json.loads(result.output)
|
||||
assert payload["status"] == "error"
|
||||
assert "boom" in payload["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_depth_var_reset_on_cancellation(self):
|
||||
"""CancelledError propagates out of ``_run_task_subagent`` but the
|
||||
depth counter must be restored first so the cancelled asyncio task
|
||||
doesn't poison the next Task call on the same context."""
|
||||
import asyncio
|
||||
|
||||
parent_state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6")
|
||||
session = _task_session()
|
||||
depth_before = _TASK_DEPTH_VAR.get()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
side_effect=asyncio.CancelledError()
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await _run_task_subagent(
|
||||
tool_call_id="tc-cancel",
|
||||
tool_args={"description": "cancel", "prompt": "run"},
|
||||
tools=[],
|
||||
parent_state=parent_state,
|
||||
user_id="u",
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert _TASK_DEPTH_VAR.get() == depth_before
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inner_tools_cache_marked_on_anthropic(self):
|
||||
"""After stripping ``Task`` the inner tool list still needs a
|
||||
``cache_control`` marker on its final entry so long sub-agent
|
||||
loops don't re-send the tool schema uncached on Anthropic
|
||||
routes (~8 KB × iterations of wasted tokens otherwise)."""
|
||||
parent_state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6")
|
||||
session = _task_session()
|
||||
|
||||
captured_tools: list = []
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_tools.append(kwargs.get("tools"))
|
||||
return _final_text_chunk("ok")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(side_effect=fake_create)
|
||||
|
||||
sample_tools: list[ChatCompletionToolParam] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "Task",
|
||||
"description": "recurse",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "find_block",
|
||||
"description": "search",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "run_block",
|
||||
"description": "execute",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
await _run_task_subagent(
|
||||
tool_call_id="tc-cache",
|
||||
tool_args={"description": "short", "prompt": "do it"},
|
||||
tools=sample_tools,
|
||||
parent_state=parent_state,
|
||||
user_id="u",
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert captured_tools
|
||||
inner = captured_tools[0] or []
|
||||
assert inner, "inner loop must receive at least one tool"
|
||||
assert "cache_control" in inner[-1]
|
||||
assert inner[-1]["cache_control"]["type"] == "ephemeral"
|
||||
names = [(t.get("function") or {}).get("name") for t in inner]
|
||||
assert "Task" not in names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inner_tools_unmarked_on_non_anthropic(self):
|
||||
"""Non-Anthropic providers reject ``cache_control``; the marker
|
||||
must NOT be applied when the model isn't Anthropic."""
|
||||
parent_state = _BaselineStreamState(model="openai/gpt-4o")
|
||||
session = _task_session()
|
||||
|
||||
captured_tools: list = []
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_tools.append(kwargs.get("tools"))
|
||||
return _final_text_chunk("ok")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(side_effect=fake_create)
|
||||
|
||||
sample_tools: list[ChatCompletionToolParam] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "find_block",
|
||||
"description": "search",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
await _run_task_subagent(
|
||||
tool_call_id="tc-openai",
|
||||
tool_args={"description": "short", "prompt": "do it"},
|
||||
tools=sample_tools,
|
||||
parent_state=parent_state,
|
||||
user_id="u",
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert captured_tools
|
||||
inner = captured_tools[0] or []
|
||||
assert inner
|
||||
assert "cache_control" not in inner[-1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_not_in_inner_tools(self):
|
||||
"""The inner tool list must strip ``Task`` so a sub-agent can't
|
||||
re-enter this path. Depth cap is belt-and-braces; the strip is the
|
||||
primary guard."""
|
||||
parent_state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6")
|
||||
session = _task_session()
|
||||
|
||||
captured_tools: list = []
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_tools.append(kwargs.get("tools"))
|
||||
return _final_text_chunk("ok")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(side_effect=fake_create)
|
||||
|
||||
sample_tools: list[ChatCompletionToolParam] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "Task",
|
||||
"description": "recurse",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "find_block",
|
||||
"description": "search",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
await _run_task_subagent(
|
||||
tool_call_id="tc-filter",
|
||||
tool_args={"description": "short", "prompt": "do it"},
|
||||
tools=sample_tools,
|
||||
parent_state=parent_state,
|
||||
user_id="u",
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert captured_tools, "inner loop must make at least one LLM call"
|
||||
inner_names = [
|
||||
(t.get("function") or {}).get("name") for t in (captured_tools[0] or [])
|
||||
]
|
||||
assert "Task" not in inner_names
|
||||
assert "find_block" in inner_names
|
||||
|
||||
@@ -92,10 +92,8 @@ class ChatConfig(BaseSettings):
|
||||
description="Model to use for generating session titles (should be fast/cheap)",
|
||||
)
|
||||
simulation_model: str = Field(
|
||||
default="google/gemini-2.5-flash-lite",
|
||||
description="Model for dry-run block simulation (should be fast/cheap with good JSON output). "
|
||||
"Gemini 2.5 Flash-Lite is ~3x cheaper than Flash ($0.10/$0.40 vs $0.30/$1.20 per MTok) "
|
||||
"with JSON-mode reliability adequate for shape-matching block outputs.",
|
||||
default="google/gemini-2.5-flash",
|
||||
description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
@@ -196,7 +194,7 @@ class ChatConfig(BaseSettings):
|
||||
claude_agent_model: str | None = Field(
|
||||
default=None,
|
||||
description="Model for the Claude Agent SDK path. If None, derives from "
|
||||
"`thinking_standard_model` by stripping the OpenRouter provider prefix.",
|
||||
"the `model` field by stripping the OpenRouter provider prefix.",
|
||||
)
|
||||
claude_agent_max_buffer_size: int = Field(
|
||||
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
|
||||
@@ -253,11 +251,9 @@ class ChatConfig(BaseSettings):
|
||||
)
|
||||
render_reasoning_in_ui: bool = Field(
|
||||
default=True,
|
||||
description="Render reasoning as live UI parts "
|
||||
"(``StreamReasoning*`` wire events). False suppresses the live "
|
||||
"wire events only; ``role='reasoning'`` rows are always persisted "
|
||||
"so the reasoning bubble hydrates on reload. Tokens are billed "
|
||||
"upstream regardless.",
|
||||
description="Render reasoning as live UI parts + persist "
|
||||
"``role='reasoning'`` rows. False suppresses both; tokens are still "
|
||||
"billed upstream.",
|
||||
)
|
||||
stream_replay_count: int = Field(
|
||||
default=200,
|
||||
|
||||
@@ -124,14 +124,9 @@ ToolName = Literal[
|
||||
# Frozen set of all valid tool names — derived from the Literal.
|
||||
ALL_TOOL_NAMES: frozenset[str] = frozenset(get_args(ToolName))
|
||||
|
||||
# SDK built-in tool names — tools provided by the Claude Code CLI that our
|
||||
# code does not implement directly. ``Task`` and ``TodoWrite`` are
|
||||
# DELIBERATELY excluded: baseline mode ships MCP-wrapped platform versions
|
||||
# of both (see ``tools/task.py`` / ``tools/todo_write.py``), while SDK mode
|
||||
# still uses the CLI-native originals via ``_SDK_BUILTIN_ALWAYS`` in
|
||||
# ``sdk/tool_adapter.py`` — the MCP copies are filtered out there.
|
||||
# SDK built-in tool names — uppercase-initial names are SDK built-ins.
|
||||
SDK_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
|
||||
{"Agent", "Edit", "Glob", "Grep", "Read", "WebSearch", "Write"}
|
||||
n for n in ALL_TOOL_NAMES if n[0].isupper()
|
||||
)
|
||||
|
||||
# Platform tool names — everything that isn't an SDK built-in.
|
||||
@@ -384,7 +379,6 @@ def apply_tool_permissions(
|
||||
"""
|
||||
from backend.copilot.sdk.tool_adapter import (
|
||||
_READ_TOOL_NAME,
|
||||
BASELINE_ONLY_MCP_TOOLS,
|
||||
MCP_TOOL_PREFIX,
|
||||
get_copilot_tool_names,
|
||||
get_sdk_disallowed_tools,
|
||||
@@ -425,14 +419,7 @@ def apply_tool_permissions(
|
||||
# keeping only those present in the original base_allowed list.
|
||||
def to_sdk_names(short: str) -> list[str]:
|
||||
names: list[str] = []
|
||||
if short in BASELINE_ONLY_MCP_TOOLS:
|
||||
# Baseline ships MCP versions of these (Task/TodoWrite) for
|
||||
# model-flexibility parity, but SDK mode uses the CLI-native
|
||||
# originals. Permissions target the CLI built-in here so
|
||||
# ``base_allowed`` (which excludes the MCP wrappers) still
|
||||
# matches.
|
||||
names.append(short)
|
||||
elif short in TOOL_REGISTRY:
|
||||
if short in TOOL_REGISTRY:
|
||||
names.append(f"{MCP_TOOL_PREFIX}{short}")
|
||||
elif short in _SDK_TO_MCP:
|
||||
# Map SDK built-in file tool to its MCP equivalent.
|
||||
|
||||
@@ -582,10 +582,6 @@ class TestApplyToolPermissions:
|
||||
|
||||
class TestSdkBuiltinToolNames:
|
||||
def test_expected_builtins_present(self):
|
||||
# Task and TodoWrite are NOT in SDK_BUILTIN_TOOL_NAMES: baseline ships
|
||||
# MCP-wrapped platform versions for model-flexibility parity, and SDK
|
||||
# mode sources them from the CLI-native originals outside the
|
||||
# PLATFORM vs SDK_BUILTIN classification used by permissions.
|
||||
expected = {
|
||||
"Agent",
|
||||
"Read",
|
||||
@@ -593,11 +589,11 @@ class TestSdkBuiltinToolNames:
|
||||
"Edit",
|
||||
"Glob",
|
||||
"Grep",
|
||||
"Task",
|
||||
"WebSearch",
|
||||
"TodoWrite",
|
||||
}
|
||||
assert expected.issubset(SDK_BUILTIN_TOOL_NAMES)
|
||||
assert "Task" not in SDK_BUILTIN_TOOL_NAMES
|
||||
assert "TodoWrite" not in SDK_BUILTIN_TOOL_NAMES
|
||||
|
||||
def test_platform_names_match_tool_registry(self):
|
||||
"""PLATFORM_TOOL_NAMES (derived from ToolName Literal) must match TOOL_REGISTRY keys."""
|
||||
|
||||
@@ -145,38 +145,9 @@ When the user asks to interact with a service or API, follow this order:
|
||||
|
||||
**Never skip step 1.** Built-in blocks are more reliable, tested, and user-friendly than MCP or raw API calls.
|
||||
|
||||
### Planning checklist — `TodoWrite`
|
||||
Use the `TodoWrite` tool to surface a step-by-step plan whenever the work
|
||||
needs 3+ distinct actions, or when the user explicitly asks to track
|
||||
progress. Guidelines:
|
||||
|
||||
- Send the **full** updated list every call (not a delta) so the rendered
|
||||
checklist stays in sync.
|
||||
- Each item needs both `content` (imperative, e.g. "Run the test suite")
|
||||
and `activeForm` (present-continuous, e.g. "Running the test suite").
|
||||
- Keep exactly one item `in_progress` at a time; mark it `completed`
|
||||
before flipping the next one to `in_progress`.
|
||||
- Skip this tool for trivial / single-step requests — it's only useful
|
||||
when a checklist makes progress easier to follow.
|
||||
|
||||
### Sub-agents — `Task`
|
||||
The `Task` tool runs an **in-process sub-agent** inside the current
|
||||
conversation. The sub-agent inherits the parent's tool set **except
|
||||
`Task` itself** (nested delegation is refused — plan at most one level
|
||||
deep), and it gets its own message history so its intermediate tool
|
||||
calls stay out of the parent context — you only see the sub-agent's
|
||||
final summary as the tool result. Use it for self-contained work that
|
||||
would otherwise generate a lot of intermediate chatter (focused
|
||||
research, bounded refactors, multi-step exploration where only the
|
||||
conclusion matters).
|
||||
|
||||
- Provide a short `description` (3-5 words, shown in the accordion) and a
|
||||
complete `prompt` — the sub-agent does NOT inherit the parent
|
||||
conversation, so include every bit of context it needs.
|
||||
- NEVER set `run_in_background` — the SDK honours this flag only on the
|
||||
CLI-native Task, and baseline doesn't support it; leave it off.
|
||||
- For long-running work that must survive tab-close or worker restarts,
|
||||
use `run_sub_session` instead (queue-backed durable Sub-AutoPilot).
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
|
||||
### Delegating to another autopilot (sub-autopilot pattern)
|
||||
Use the **`run_sub_session`** tool to delegate a task to a fresh
|
||||
|
||||
@@ -9,32 +9,8 @@ when Redis is unavailable to avoid blocking users.
|
||||
Storing microdollars rather than tokens means the counter already reflects
|
||||
real model pricing (including cache discounts and provider surcharges), so
|
||||
this module carries no pricing table — the cost comes from OpenRouter's
|
||||
``usage.cost`` field (baseline), the Claude Agent SDK's reported total
|
||||
cost (SDK path), web_search tool calls, and the prompt-simulation harness.
|
||||
|
||||
Boundary with the credit wallet
|
||||
===============================
|
||||
|
||||
Microdollars (this module) and credits (``backend.data.block_cost_config``)
|
||||
are intentionally separate budgets:
|
||||
|
||||
* **Credits** are the user-facing prepaid wallet. Every block invocation
|
||||
that has a ``BlockCost`` entry decrements credits — this is what the
|
||||
user buys, tops up, and sees on the billing page. Marketplace blocks
|
||||
may also charge credits to block creators. The credit charge is a flat
|
||||
per-run amount sourced from ``BLOCK_COSTS``. Copilot ``run_block``
|
||||
calls go through this path too: block execution bills the user's
|
||||
credit wallet, not this counter.
|
||||
* **Microdollars** meter AutoGPT's **operator-side infrastructure cost**
|
||||
for the copilot **LLM turn itself** — the real USD we spend on the
|
||||
baseline model, Claude Agent SDK runs, the web_search tool, and the
|
||||
prompt simulator. They gate the chat loop so a single user can't burn
|
||||
the daily / weekly infra budget driving the chat regardless of their
|
||||
credit balance. BYOK runs (user supplied their own API key) do **not**
|
||||
decrement this counter — the user is paying the provider, not us.
|
||||
|
||||
A future option is to unify these into one wallet; until then the
|
||||
boundary above is the contract.
|
||||
``usage.cost`` field (baseline) or the Claude Agent SDK's reported total
|
||||
cost (SDK path).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
@@ -9,8 +9,8 @@ persistence, and the ``CompactionTracker`` state machine.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections import Counter, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -25,6 +25,8 @@ from ..response_model import (
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompactionResult:
|
||||
@@ -71,14 +73,6 @@ def _new_tool_call_id() -> str:
|
||||
return f"compaction-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
def _summarize_sources(sources: list[str]) -> str:
|
||||
counts = Counter(sources)
|
||||
parts: list[str] = []
|
||||
for source, count in counts.items():
|
||||
parts.append(f"{source}:{count}" if count > 1 else source)
|
||||
return ",".join(parts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public event builder
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -191,54 +185,26 @@ class CompactionTracker:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._compact_start = asyncio.Event()
|
||||
self._start_emitted = False
|
||||
self._done = False
|
||||
self._tool_call_id = ""
|
||||
self._active_transcript_path: str = ""
|
||||
self._pending_transcript_paths: deque[str] = deque()
|
||||
self._attempted_sources: list[str] = []
|
||||
self._completed_sources: list[str] = []
|
||||
|
||||
@property
|
||||
def attempt_count(self) -> int:
|
||||
return len(self._attempted_sources)
|
||||
|
||||
@property
|
||||
def attempt_sources(self) -> tuple[str, ...]:
|
||||
return tuple(self._attempted_sources)
|
||||
|
||||
@property
|
||||
def completed_count(self) -> int:
|
||||
return len(self._completed_sources)
|
||||
|
||||
@property
|
||||
def completed_sources(self) -> tuple[str, ...]:
|
||||
return tuple(self._completed_sources)
|
||||
|
||||
def get_observability_metadata(self) -> dict[str, Any]:
|
||||
if not self._attempted_sources and not self._completed_sources:
|
||||
return {}
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"compaction_attempt_count": self.attempt_count,
|
||||
"compaction_attempt_sources": _summarize_sources(self._attempted_sources),
|
||||
}
|
||||
if self._completed_sources:
|
||||
metadata["compaction_count"] = self.completed_count
|
||||
metadata["compaction_sources"] = _summarize_sources(self._completed_sources)
|
||||
return metadata
|
||||
|
||||
def get_log_summary(self) -> dict[str, Any]:
|
||||
return {
|
||||
"attempt_count": self.attempt_count,
|
||||
"attempt_sources": _summarize_sources(self._attempted_sources),
|
||||
"completed_count": self.completed_count,
|
||||
"completed_sources": _summarize_sources(self._completed_sources),
|
||||
}
|
||||
self._transcript_path: str = ""
|
||||
|
||||
def on_compact(self, transcript_path: str = "") -> None:
|
||||
"""Callback for the PreCompact hook. Queues an SDK compaction attempt."""
|
||||
self._attempted_sources.append("sdk_internal")
|
||||
self._pending_transcript_paths.append(transcript_path)
|
||||
"""Callback for the PreCompact hook. Stores transcript_path."""
|
||||
if (
|
||||
self._transcript_path
|
||||
and transcript_path
|
||||
and self._transcript_path != transcript_path
|
||||
):
|
||||
logger.warning(
|
||||
"[Compaction] Overwriting transcript_path %s -> %s",
|
||||
self._transcript_path,
|
||||
transcript_path,
|
||||
)
|
||||
self._transcript_path = transcript_path
|
||||
self._compact_start.set()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pre-query compaction
|
||||
@@ -246,8 +212,7 @@ class CompactionTracker:
|
||||
|
||||
def emit_pre_query(self, session: ChatSession) -> list[StreamBaseResponse]:
|
||||
"""Emit + persist a self-contained compaction tool call."""
|
||||
self._attempted_sources.append("pre_query")
|
||||
self._completed_sources.append("pre_query")
|
||||
self._done = True
|
||||
return emit_compaction(session)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -256,17 +221,18 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._compact_start.clear()
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
self._active_transcript_path = ""
|
||||
self._pending_transcript_paths.clear()
|
||||
self._transcript_path = ""
|
||||
|
||||
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
|
||||
"""If the PreCompact hook fired, emit start events (spinning tool)."""
|
||||
if self._pending_transcript_paths and not self._start_emitted:
|
||||
if self._compact_start.is_set() and not self._start_emitted and not self._done:
|
||||
self._compact_start.clear()
|
||||
self._start_emitted = True
|
||||
self._tool_call_id = _new_tool_call_id()
|
||||
self._active_transcript_path = self._pending_transcript_paths.popleft()
|
||||
return _start_events(self._tool_call_id)
|
||||
return []
|
||||
|
||||
@@ -280,30 +246,27 @@ class CompactionTracker:
|
||||
# Yield so pending hook tasks can set compact_start
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if not self._start_emitted and not self._pending_transcript_paths:
|
||||
if self._done:
|
||||
return CompactionResult()
|
||||
if not self._start_emitted and not self._compact_start.is_set():
|
||||
return CompactionResult()
|
||||
|
||||
if self._start_emitted:
|
||||
# Close the open spinner
|
||||
done_events = _end_events(self._tool_call_id, COMPACTION_DONE_MSG)
|
||||
persist_id = self._tool_call_id
|
||||
transcript_path = self._active_transcript_path
|
||||
else:
|
||||
# PreCompact fired but start never emitted — self-contained
|
||||
persist_id = _new_tool_call_id()
|
||||
done_events = compaction_events(
|
||||
COMPACTION_DONE_MSG, tool_call_id=persist_id
|
||||
)
|
||||
transcript_path = (
|
||||
self._pending_transcript_paths.popleft()
|
||||
if self._pending_transcript_paths
|
||||
else ""
|
||||
)
|
||||
|
||||
transcript_path = self._transcript_path
|
||||
self._compact_start.clear()
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
self._active_transcript_path = ""
|
||||
self._completed_sources.append("sdk_internal")
|
||||
self._done = True
|
||||
self._transcript_path = ""
|
||||
_persist(session, persist_id, COMPACTION_DONE_MSG)
|
||||
return CompactionResult(
|
||||
events=done_events, just_ended=True, transcript_path=transcript_path
|
||||
|
||||
@@ -162,11 +162,10 @@ class TestFilterCompactionMessages:
|
||||
|
||||
|
||||
class TestCompactionTracker:
|
||||
def test_on_compact_registers_pending_attempt(self):
|
||||
def test_on_compact_sets_event(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact()
|
||||
assert tracker.attempt_count == 1
|
||||
assert list(tracker._pending_transcript_paths) == [""]
|
||||
assert tracker._compact_start.is_set()
|
||||
|
||||
def test_emit_start_if_ready_no_event(self):
|
||||
tracker = CompactionTracker()
|
||||
@@ -245,39 +244,36 @@ class TestCompactionTracker:
|
||||
evts = tracker.emit_pre_query(session)
|
||||
assert len(evts) == 5
|
||||
assert len(session.messages) == 2
|
||||
assert tracker.attempt_count == 1
|
||||
assert tracker.completed_count == 1
|
||||
assert tracker.get_observability_metadata() == {
|
||||
"compaction_attempt_count": 1,
|
||||
"compaction_attempt_sources": "pre_query",
|
||||
"compaction_count": 1,
|
||||
"compaction_sources": "pre_query",
|
||||
}
|
||||
assert tracker._done is True
|
||||
|
||||
def test_reset_for_query(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact("/some/path")
|
||||
tracker._done = True
|
||||
tracker._start_emitted = True
|
||||
tracker._tool_call_id = "old"
|
||||
tracker._active_transcript_path = "/active/path"
|
||||
tracker._transcript_path = "/some/path"
|
||||
tracker.reset_for_query()
|
||||
assert tracker._done is False
|
||||
assert tracker._start_emitted is False
|
||||
assert tracker._tool_call_id == ""
|
||||
assert tracker._active_transcript_path == ""
|
||||
assert list(tracker._pending_transcript_paths) == []
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_query_does_not_block_sdk_compaction_within_query(self):
|
||||
"""SDK auto-compaction can still fire after a pre-query compaction."""
|
||||
async def test_pre_query_blocks_sdk_compaction_until_reset(self):
|
||||
"""After pre-query compaction, SDK compaction is blocked until
|
||||
reset_for_query is called."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.emit_pre_query(session)
|
||||
tracker.on_compact()
|
||||
# _done is True so emit_start_if_ready is blocked
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert evts == []
|
||||
# Reset clears _done, allowing subsequent compaction
|
||||
tracker.reset_for_query()
|
||||
tracker.on_compact()
|
||||
evts = tracker.emit_start_if_ready()
|
||||
assert len(evts) == 3
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert tracker.completed_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_allows_new_compaction(self):
|
||||
@@ -322,18 +318,43 @@ class TestCompactionTracker:
|
||||
assert len(result1.events) == 2
|
||||
assert result1.transcript_path == "/path/1"
|
||||
|
||||
# Second compaction cycle in the same query
|
||||
# Second compaction cycle (should NOT be blocked — _done resets
|
||||
# because emit_end_if_ready sets it True, but the next on_compact
|
||||
# + emit_start_if_ready checks !_done which IS True now.
|
||||
# So we need reset_for_query between queries, but within a single
|
||||
# query multiple compactions work because _done blocks emit_start
|
||||
# until the next message arrives, at which point emit_end detects it)
|
||||
#
|
||||
# Actually: _done=True blocks emit_start_if_ready, so we need
|
||||
# the stream loop to reset. In practice service.py doesn't call
|
||||
# reset between compactions within the same query — let's verify
|
||||
# the actual behavior.
|
||||
tracker.on_compact("/path/2")
|
||||
# _done is True from first compaction, so start is blocked
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
assert len(start_evts) == 3
|
||||
assert start_evts == []
|
||||
# But emit_end returns no-op because _done is True
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is True
|
||||
assert result2.transcript_path == "/path/2"
|
||||
assert tracker.completed_count == 2
|
||||
assert result2.just_ended is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compactions_with_intervening_message(self):
|
||||
"""Multiple compactions remain supported across query boundaries."""
|
||||
"""Multiple compactions work when the stream loop processes messages between them.
|
||||
|
||||
In the real service.py flow:
|
||||
1. PreCompact fires → on_compact()
|
||||
2. emit_start shows spinner
|
||||
3. Next message arrives → emit_end completes compaction (_done=True)
|
||||
4. Stream continues processing messages...
|
||||
5. If a second PreCompact fires, _done=True blocks emit_start
|
||||
6. But the next message triggers emit_end, which sees _done=True → no-op
|
||||
7. The stream loop needs to detect this and handle accordingly
|
||||
|
||||
The actual flow for multiple compactions within a query requires
|
||||
_done to be cleared between them. The service.py code uses
|
||||
CompactionResult.just_ended to trigger replace_entries, and _done
|
||||
stays True until reset_for_query.
|
||||
"""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
|
||||
@@ -355,10 +376,10 @@ class TestCompactionTracker:
|
||||
assert result2.just_ended is True
|
||||
assert result2.transcript_path == "/path/2"
|
||||
|
||||
def test_on_compact_queues_transcript_path(self):
|
||||
def test_on_compact_stores_transcript_path(self):
|
||||
tracker = CompactionTracker()
|
||||
tracker.on_compact("/some/path.jsonl")
|
||||
assert list(tracker._pending_transcript_paths) == ["/some/path.jsonl"]
|
||||
assert tracker._transcript_path == "/some/path.jsonl"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_returns_transcript_path(self):
|
||||
@@ -370,71 +391,17 @@ class TestCompactionTracker:
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert result.transcript_path == "/my/session.jsonl"
|
||||
assert tracker._active_transcript_path == ""
|
||||
# transcript_path is cleared after emit_end
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_clears_active_transcript_path(self):
|
||||
"""After emit_end, the active transcript path is reset."""
|
||||
async def test_emit_end_clears_transcript_path(self):
|
||||
"""After emit_end, _transcript_path is reset so it doesn't leak to
|
||||
subsequent non-compaction emit_end calls."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
tracker.on_compact("/first/path.jsonl")
|
||||
tracker.emit_start_if_ready()
|
||||
await tracker.emit_end_if_ready(session)
|
||||
assert tracker._active_transcript_path == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_pending_hooks_are_counted_even_before_completion(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
|
||||
tracker.on_compact("/path/1")
|
||||
tracker.emit_start_if_ready()
|
||||
tracker.on_compact("/path/2")
|
||||
tracker.on_compact("/path/3")
|
||||
|
||||
result1 = await tracker.emit_end_if_ready(session)
|
||||
assert result1.just_ended is True
|
||||
assert result1.transcript_path == "/path/1"
|
||||
assert tracker.attempt_count == 3
|
||||
assert tracker.completed_count == 1
|
||||
|
||||
tracker.emit_start_if_ready()
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is True
|
||||
assert result2.transcript_path == "/path/2"
|
||||
|
||||
tracker.emit_start_if_ready()
|
||||
result3 = await tracker.emit_end_if_ready(session)
|
||||
assert result3.just_ended is True
|
||||
assert result3.transcript_path == "/path/3"
|
||||
assert tracker.completed_count == 3
|
||||
|
||||
def test_get_observability_metadata_includes_attempts_and_completions(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
|
||||
tracker.emit_pre_query(session)
|
||||
tracker.on_compact("/path/1")
|
||||
tracker.on_compact("/path/2")
|
||||
|
||||
assert tracker.get_observability_metadata() == {
|
||||
"compaction_attempt_count": 3,
|
||||
"compaction_attempt_sources": "pre_query,sdk_internal:2",
|
||||
"compaction_count": 1,
|
||||
"compaction_sources": "pre_query",
|
||||
}
|
||||
|
||||
def test_get_log_summary_includes_attempts_and_completions(self):
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
|
||||
tracker.emit_pre_query(session)
|
||||
tracker.on_compact("/path/1")
|
||||
tracker.on_compact("/path/2")
|
||||
|
||||
assert tracker.get_log_summary() == {
|
||||
"attempt_count": 3,
|
||||
"attempt_sources": "pre_query,sdk_internal:2",
|
||||
"completed_count": 1,
|
||||
"completed_sources": "pre_query",
|
||||
}
|
||||
# After compaction, _transcript_path is cleared
|
||||
assert tracker._transcript_path == ""
|
||||
|
||||
@@ -720,9 +720,7 @@ class TestDoTransientBackoff:
|
||||
flips the reconstruction consistently with the rest of the path."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from backend.copilot.sdk.service import _do_transient_backoff
|
||||
|
||||
cfg = _make_config(render_reasoning_in_ui=False)
|
||||
from backend.copilot.sdk.service import _do_transient_backoff, config
|
||||
|
||||
original_adapter = MagicMock()
|
||||
state = MagicMock()
|
||||
@@ -731,7 +729,6 @@ class TestDoTransientBackoff:
|
||||
|
||||
with (
|
||||
patch("asyncio.sleep", new=AsyncMock()),
|
||||
patch(f"{_SVC}.config", cfg),
|
||||
patch("backend.copilot.sdk.service.SDKResponseAdapter") as mock_cls,
|
||||
):
|
||||
new_adapter = MagicMock()
|
||||
@@ -742,7 +739,7 @@ class TestDoTransientBackoff:
|
||||
mock_cls.assert_called_once_with(
|
||||
message_id="msg-1",
|
||||
session_id="sess-1",
|
||||
render_reasoning_in_ui=False,
|
||||
render_reasoning_in_ui=config.render_reasoning_in_ui,
|
||||
)
|
||||
assert state.adapter is new_adapter
|
||||
|
||||
|
||||
@@ -68,7 +68,9 @@ class SDKResponseAdapter:
|
||||
self.reasoning_block_id = str(uuid.uuid4())
|
||||
self.has_started_reasoning = False
|
||||
self.has_ended_reasoning = True
|
||||
self.render_reasoning_in_ui = render_reasoning_in_ui
|
||||
# When False, reasoning wire events + persisted reasoning rows are
|
||||
# suppressed; transcript continuity is unaffected.
|
||||
self._render_reasoning_in_ui = render_reasoning_in_ui
|
||||
self.current_tool_calls: dict[str, dict[str, str]] = {}
|
||||
self.resolved_tool_calls: set[str] = set()
|
||||
self.step_open = False
|
||||
@@ -163,12 +165,13 @@ class SDKResponseAdapter:
|
||||
if block.thinking:
|
||||
self._end_text_if_open(responses)
|
||||
self._ensure_reasoning_started(responses)
|
||||
responses.append(
|
||||
StreamReasoningDelta(
|
||||
id=self.reasoning_block_id,
|
||||
delta=block.thinking,
|
||||
if self._render_reasoning_in_ui:
|
||||
responses.append(
|
||||
StreamReasoningDelta(
|
||||
id=self.reasoning_block_id,
|
||||
delta=block.thinking,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
self._end_text_if_open(responses)
|
||||
@@ -365,13 +368,15 @@ class SDKResponseAdapter:
|
||||
"""Start (or restart) a reasoning block if needed.
|
||||
|
||||
Each ``ThinkingBlock`` the SDK emits gets its own streaming block
|
||||
so the frontend can render a new ``Reasoning`` part per LLM turn
|
||||
(rather than concatenating across the whole session). Events
|
||||
are emitted unconditionally — the caller filters them out of the
|
||||
SSE wire when ``render_reasoning_in_ui=False`` but still feeds
|
||||
them through ``_dispatch_response`` so the session transcript
|
||||
keeps a ``role='reasoning'`` row.
|
||||
on the wire so the frontend can render a new ``Reasoning`` part
|
||||
per LLM turn (rather than concatenating across the whole session).
|
||||
|
||||
No-op when ``render_reasoning_in_ui=False`` — callers still drive
|
||||
the method on every ``ThinkingBlock`` so persistence stays in
|
||||
lockstep, but nothing reaches the wire.
|
||||
"""
|
||||
if not self._render_reasoning_in_ui:
|
||||
return
|
||||
if not self.has_started_reasoning or self.has_ended_reasoning:
|
||||
if self.has_ended_reasoning:
|
||||
self.reasoning_block_id = str(uuid.uuid4())
|
||||
@@ -380,7 +385,13 @@ class SDKResponseAdapter:
|
||||
self.has_started_reasoning = True
|
||||
|
||||
def _end_reasoning_if_open(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""End the current reasoning block if one is open."""
|
||||
"""End the current reasoning block if one is open.
|
||||
|
||||
No-op when ``render_reasoning_in_ui=False`` — no start was emitted,
|
||||
so no end is needed.
|
||||
"""
|
||||
if not self._render_reasoning_in_ui:
|
||||
return
|
||||
if self.has_started_reasoning and not self.has_ended_reasoning:
|
||||
responses.append(StreamReasoningEnd(id=self.reasoning_block_id))
|
||||
self.has_ended_reasoning = True
|
||||
|
||||
@@ -331,13 +331,11 @@ def test_empty_thinking_block_is_ignored():
|
||||
assert [type(r).__name__ for r in results] == ["StreamStartStep"]
|
||||
|
||||
|
||||
def test_render_reasoning_in_ui_false_still_emits_adapter_events():
|
||||
"""With the persist/render decoupling the adapter is flag-agnostic:
|
||||
it always emits ``StreamReasoning*`` so the session transcript keeps a
|
||||
durable reasoning record. Wire-level suppression when
|
||||
``render_reasoning_in_ui=False`` happens at the SDK service yield
|
||||
boundary, not here — see
|
||||
``backend/copilot/sdk/service.py::_filter_reasoning_events``.
|
||||
def test_render_reasoning_in_ui_false_suppresses_thinking_events():
|
||||
"""``render_reasoning_in_ui=False`` silences ``StreamReasoning*`` on
|
||||
the wire — the frontend sees a text-only stream. Persistence via
|
||||
``_format_sdk_content_blocks`` is handled elsewhere; this test only
|
||||
pins the wire contract.
|
||||
"""
|
||||
adapter = SDKResponseAdapter(
|
||||
message_id="m",
|
||||
@@ -350,17 +348,14 @@ def test_render_reasoning_in_ui_false_still_emits_adapter_events():
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
types = [type(r).__name__ for r in results]
|
||||
assert "StreamReasoningStart" in types
|
||||
assert "StreamReasoningDelta" in types
|
||||
assert "StreamReasoningStart" not in types
|
||||
assert "StreamReasoningDelta" not in types
|
||||
assert "StreamReasoningEnd" not in types
|
||||
|
||||
|
||||
def test_render_reasoning_off_text_after_thinking_still_closes_reasoning():
|
||||
"""Adapter still emits a ``StreamReasoningEnd`` when text follows a
|
||||
thinking block — decoupled from the render flag. The service layer
|
||||
drops the reasoning events at yield time; the adapter's structural
|
||||
open/close pairing must not depend on the flag or downstream filters
|
||||
would see orphan reasoning starts on the persisted transcript.
|
||||
"""
|
||||
def test_render_reasoning_off_text_after_thinking_emits_no_reasoning_end():
|
||||
"""With rendering off the ReasoningEnd is never synthesized when text
|
||||
follows — no ReasoningStart ever hit the wire, so no close is due."""
|
||||
adapter = SDKResponseAdapter(
|
||||
message_id="m",
|
||||
session_id="s",
|
||||
@@ -376,7 +371,7 @@ def test_render_reasoning_off_text_after_thinking_still_closes_reasoning():
|
||||
AssistantMessage(content=[TextBlock(text="hello")], model="test")
|
||||
)
|
||||
types = [type(r).__name__ for r in results]
|
||||
assert "StreamReasoningEnd" in types
|
||||
assert "StreamReasoningEnd" not in types
|
||||
assert "StreamTextStart" in types
|
||||
assert "StreamTextDelta" in types
|
||||
|
||||
|
||||
@@ -2378,18 +2378,6 @@ async def _run_stream_attempt(
|
||||
skip_strip=response is tail_delta,
|
||||
)
|
||||
if dispatched is not None:
|
||||
# Persistence (via _dispatch_response) always runs so the
|
||||
# session transcript keeps role='reasoning' rows; the
|
||||
# wire is gated so UI can suppress rendering.
|
||||
if not state.adapter.render_reasoning_in_ui and isinstance(
|
||||
dispatched,
|
||||
(
|
||||
StreamReasoningStart,
|
||||
StreamReasoningDelta,
|
||||
StreamReasoningEnd,
|
||||
),
|
||||
):
|
||||
continue
|
||||
yield dispatched
|
||||
|
||||
# Mid-turn follow-up persistence: the MCP tool wrapper drains
|
||||
@@ -3772,17 +3760,15 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
if ended_with_stream_error:
|
||||
logger.warning(
|
||||
"%s Stream ended with SDK error after %d messages (compaction=%s)",
|
||||
"%s Stream ended with SDK error after %d messages",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
compaction.get_log_summary(),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"%s Stream completed successfully with %d messages (compaction=%s)",
|
||||
"%s Stream completed successfully with %d messages",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
compaction.get_log_summary(),
|
||||
)
|
||||
except GeneratorExit:
|
||||
# GeneratorExit is raised when the async generator is closed by the
|
||||
|
||||
@@ -853,28 +853,9 @@ DANGEROUS_PATTERNS = [
|
||||
r"subprocess",
|
||||
]
|
||||
|
||||
# Platform-tool names whose MCP wrappers must NOT be exposed to SDK mode.
|
||||
# Baseline implements MCP versions of these for model-flexibility parity;
|
||||
# SDK mode keeps using the CLI-native built-ins listed in
|
||||
# ``_SDK_BUILTIN_ALWAYS`` so there is no double exposure. This is a
|
||||
# deliberate cross-module constant — ``permissions.apply_tool_permissions``
|
||||
# consults it to map short tool names back to the CLI built-in form in SDK
|
||||
# mode. Public (no leading underscore) so a future refactor renaming it is
|
||||
# visible to both call sites, not silently broken.
|
||||
BASELINE_ONLY_MCP_TOOLS: frozenset[str] = frozenset({"Task", "TodoWrite"})
|
||||
|
||||
|
||||
def _registry_mcp_tools() -> list[str]:
    """Return the MCP-prefixed names of the registry tools.

    Every key of ``TOOL_REGISTRY`` is prefixed with ``MCP_TOOL_PREFIX``,
    except the names listed in ``BASELINE_ONLY_MCP_TOOLS``, which are
    left out. Registry order is preserved.
    """
    prefixed_names: list[str] = []
    for tool_name in TOOL_REGISTRY:
        if tool_name in BASELINE_ONLY_MCP_TOOLS:
            continue
        prefixed_names.append(f"{MCP_TOOL_PREFIX}{tool_name}")
    return prefixed_names
|
||||
|
||||
|
||||
# Static tool name list for the non-E2B case (backward compatibility).
|
||||
COPILOT_TOOL_NAMES = [
|
||||
*_registry_mcp_tools(),
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{WRITE_TOOL_NAME}",
|
||||
f"{MCP_TOOL_PREFIX}{READ_TOOL_NAME}",
|
||||
f"{MCP_TOOL_PREFIX}{EDIT_TOOL_NAME}",
|
||||
@@ -896,7 +877,7 @@ def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
|
||||
# from E2B_FILE_TOOLS instead), so don't include them here.
|
||||
# _READ_TOOL_NAME is still needed for SDK tool-result reads.
|
||||
return [
|
||||
*_registry_mcp_tools(),
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
|
||||
*_SDK_BUILTIN_ALWAYS,
|
||||
|
||||
@@ -43,8 +43,6 @@ from .run_block import RunBlockTool
|
||||
from .run_mcp_tool import RunMCPToolTool
|
||||
from .run_sub_session import RunSubSessionTool
|
||||
from .search_docs import SearchDocsTool
|
||||
from .task import TaskTool
|
||||
from .todo_write import TodoWriteTool
|
||||
from .validate_agent import ValidateAgentGraphTool
|
||||
from .web_fetch import WebFetchTool
|
||||
from .web_search import WebSearchTool
|
||||
@@ -88,11 +86,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"continue_run_block": ContinueRunBlockTool(),
|
||||
"run_sub_session": RunSubSessionTool(),
|
||||
"get_sub_session_result": GetSubSessionResultTool(),
|
||||
# Planning + delegation (baseline parity with Claude Code built-ins).
|
||||
# SDK mode uses the CLI's native Task/TodoWrite and these MCP versions
|
||||
# are filtered out of SDK's allowed_tools in ``sdk/tool_adapter.py``.
|
||||
"Task": TaskTool(),
|
||||
"TodoWrite": TodoWriteTool(),
|
||||
"run_mcp_tool": RunMCPToolTool(),
|
||||
"get_mcp_guide": GetMCPGuideTool(),
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
|
||||
@@ -181,9 +181,7 @@ async def execute_block(
|
||||
# (e.g., "42" → 42, string booleans → bool, enum defaults applied).
|
||||
coerce_inputs_to_schema(input_data, block.input_schema)
|
||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||
async for output_name, output_data in simulate_block(
|
||||
block, input_data, user_id=user_id
|
||||
):
|
||||
async for output_name, output_data in simulate_block(block, input_data):
|
||||
outputs[output_name].append(output_data)
|
||||
# simulator signals internal failure via ("error", "[SIMULATOR ERROR …]")
|
||||
sim_error = outputs.get("error", [])
|
||||
|
||||
@@ -26,10 +26,7 @@ _USER = "test-user-helpers"
|
||||
_SESSION = "test-session-helpers"
|
||||
|
||||
|
||||
def _make_block(
|
||||
block_id: str = "block-1",
|
||||
name: str = "TestBlock",
|
||||
):
|
||||
def _make_block(block_id: str = "block-1", name: str = "TestBlock"):
|
||||
"""Create a minimal mock block for execute_block()."""
|
||||
mock = MagicMock()
|
||||
mock.id = block_id
|
||||
@@ -208,154 +205,6 @@ class TestExecuteBlockCreditCharging:
|
||||
assert result.success is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unregistered block regression: blocks without BLOCK_COSTS entry still run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
class TestUnregisteredBlockRunsFree:
    """Pin the zero-cost path for blocks that have no BLOCK_COSTS entry.

    ``block_usage_cost`` already returns ``(0, {})`` for unregistered
    blocks. This test locks that contract in at the copilot execution
    boundary, so a refactor that starts charging such blocks — or crashes
    when the BLOCK_COSTS lookup finds nothing — cannot silently bill free
    blocks.
    """

    async def test_unregistered_block_runs_without_charge(self):
        unregistered_id = "unregistered-block"
        free_block = _make_block(block_id=unregistered_id, name="UnregisteredBlock")
        credit_patch, credit_db = _patch_credit_db()

        with _patch_workspace(), credit_patch:
            outcome = await execute_block(
                block=free_block,
                block_id=unregistered_id,
                input_data={},
                user_id=_USER,
                session_id=_SESSION,
                node_exec_id="exec-unreg",
                matched_credentials={},
                dry_run=False,
            )

        assert isinstance(outcome, BlockOutputResponse)
        assert outcome.success is True
        # Zero-cost lookup must not touch either credit-wallet endpoint.
        for wallet_call in (credit_db.get_credits, credit_db.spend_credits):
            wallet_call.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BLOCK_COSTS regression: newly-registered paid-API blocks must decrement credits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNewlyRegisteredBlockCosts:
    """Regression coverage for the cost-tracking leak closure.

    Each block checked here was missing from BLOCK_COSTS before this PR,
    which made ``spend_credits`` a silent no-op when the block was invoked
    via copilot ``run_block``. Pinning the entries here keeps a future
    refactor from quietly dropping one.
    """

    def test_perplexity_block_registered(self):
        from backend.blocks.perplexity import PerplexityBlock, PerplexityModel
        from backend.data.block_cost_config import BLOCK_COSTS

        assert PerplexityBlock in BLOCK_COSTS
        # Pin model->cost mapping so swapped prices fail the regression test.
        expected_prices = {
            PerplexityModel.SONAR: 1,
            PerplexityModel.SONAR_PRO: 5,
            PerplexityModel.SONAR_DEEP_RESEARCH: 10,
        }
        actual_prices = {}
        for cost_entry in BLOCK_COSTS[PerplexityBlock]:
            actual_prices[cost_entry.cost_filter["model"]] = cost_entry.cost_amount
        assert actual_prices == expected_prices

    def test_fact_checker_block_registered(self):
        from backend.blocks.jina.fact_checker import FactCheckerBlock
        from backend.data.block_cost_config import BLOCK_COSTS

        assert FactCheckerBlock in BLOCK_COSTS
        first_entry = BLOCK_COSTS[FactCheckerBlock][0]
        assert first_entry.cost_amount == 1

    def test_mem0_blocks_registered(self):
        from backend.blocks.mem0 import (
            AddMemoryBlock,
            GetAllMemoriesBlock,
            GetLatestMemoryBlock,
            SearchMemoryBlock,
        )
        from backend.data.block_cost_config import BLOCK_COSTS

        memory_blocks = [
            AddMemoryBlock,
            SearchMemoryBlock,
            GetAllMemoriesBlock,
            GetLatestMemoryBlock,
        ]
        for block_cls in memory_blocks:
            assert block_cls in BLOCK_COSTS, f"{block_cls.__name__} missing"
            first_entry = BLOCK_COSTS[block_cls][0]
            assert first_entry.cost_amount == 1

    def test_screenshotone_block_registered(self):
        from backend.blocks.screenshotone import ScreenshotWebPageBlock
        from backend.data.block_cost_config import BLOCK_COSTS

        assert ScreenshotWebPageBlock in BLOCK_COSTS
        first_entry = BLOCK_COSTS[ScreenshotWebPageBlock][0]
        assert first_entry.cost_amount == 2

    def test_nvidia_deepfake_block_registered(self):
        from backend.blocks.nvidia.deepfake import NvidiaDeepfakeDetectBlock
        from backend.data.block_cost_config import BLOCK_COSTS

        assert NvidiaDeepfakeDetectBlock in BLOCK_COSTS
        first_entry = BLOCK_COSTS[NvidiaDeepfakeDetectBlock][0]
        assert first_entry.cost_amount == 2

    def test_smartlead_blocks_registered(self):
        from backend.blocks.smartlead.campaign import (
            AddLeadToCampaignBlock,
            CreateCampaignBlock,
            SaveCampaignSequencesBlock,
        )
        from backend.data.block_cost_config import BLOCK_COSTS

        # (block class, expected credits), checked in this order.
        expected_costs = (
            (CreateCampaignBlock, 2),
            (AddLeadToCampaignBlock, 1),
            (SaveCampaignSequencesBlock, 1),
        )
        for block_cls, credits in expected_costs:
            assert BLOCK_COSTS[block_cls][0].cost_amount == credits

    def test_zerobounce_validate_block_registered(self):
        from backend.blocks.zerobounce.validate_emails import ValidateEmailsBlock
        from backend.data.block_cost_config import BLOCK_COSTS

        assert ValidateEmailsBlock in BLOCK_COSTS
        first_entry = BLOCK_COSTS[ValidateEmailsBlock][0]
        assert first_entry.cost_amount == 2

    def test_claude_code_block_registered(self):
        """ClaudeCodeBlock spawns an E2B sandbox and runs Claude inside it.

        The in-sandbox LLM spend ($0.50-$2/run typical) dominates the cost,
        not the sandbox compute. A flat 100 credits ($1.00) is the
        conservative estimate until the in-sandbox x-total-cost is wired
        back into NodeExecutionStats.provider_cost.
        """
        from backend.blocks.claude_code import ClaudeCodeBlock
        from backend.data.block_cost_config import BLOCK_COSTS

        assert ClaudeCodeBlock in BLOCK_COSTS
        first_entry = BLOCK_COSTS[ClaudeCodeBlock][0]
        assert first_entry.cost_amount == 100
        # Filter keys on `e2b_credentials` (not `credentials`) — verifies the
        # cost gate matches the block's actual input field name.
        assert "e2b_credentials" in BLOCK_COSTS[ClaudeCodeBlock][0].cost_filter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type coercion tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -88,10 +88,6 @@ class ResponseType(str, Enum):
|
||||
MEMORY_FORGET_CANDIDATES = "memory_forget_candidates"
|
||||
MEMORY_FORGET_CONFIRM = "memory_forget_confirm"
|
||||
|
||||
# Planning / delegation
|
||||
TODO_WRITE = "todo_write"
|
||||
TASK = "task"
|
||||
|
||||
|
||||
# Base response model
|
||||
class ToolResponseBase(BaseModel):
|
||||
@@ -606,13 +602,11 @@ class WebSearchResponse(ToolResponseBase):
|
||||
|
||||
type: ResponseType = ResponseType.WEB_SEARCH
|
||||
query: str
|
||||
# Web-grounded synthesised answer the search provider wrote from
|
||||
# fresh page content. The LLM caller should read this directly
|
||||
# instead of re-fetching each citation URL — many sites are
|
||||
# bot-protected and ``web_fetch`` won't get through. Empty string
|
||||
# when the provider returned only citations.
|
||||
answer: str = ""
|
||||
results: list[WebSearchResult] = Field(default_factory=list)
|
||||
# Backend-reported usage for this call (copied from Anthropic's
|
||||
# ``usage.server_tool_use``). Surfaces as metadata for frontend
|
||||
# debug panels but is also what drives rate-limit / cost tracking
|
||||
# via ``persist_and_record_usage(provider="anthropic")``.
|
||||
search_requests: int = 0
|
||||
|
||||
|
||||
@@ -845,60 +839,3 @@ class MemoryForgetConfirmResponse(ToolResponseBase):
|
||||
type: ResponseType = ResponseType.MEMORY_FORGET_CONFIRM
|
||||
deleted_uuids: list[str] = Field(default_factory=list)
|
||||
failed_uuids: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
# --- Planning / delegation ---
|
||||
|
||||
|
||||
class TodoItem(BaseModel):
    """A single row of a ``TodoWrite`` checklist.

    The field names follow the schema of Claude Code's built-in
    ``TodoWrite`` tool, so the frontend's ``GenericTool`` accordion renders
    baseline-emitted todos the same way as SDK-emitted ones.
    """

    # What has to be done, phrased as an instruction.
    content: str = Field(description="Imperative description of the task.")
    # Label displayed while the item is being worked on.
    activeForm: str = Field(
        description="Present-continuous form shown while the task is running.",
    )
    # Lifecycle state; items start out as "pending" when none is given.
    status: Literal["pending", "in_progress", "completed"] = Field(default="pending")
|
||||
|
||||
|
||||
class TodoWriteResponse(ToolResponseBase):
    """Acknowledgement returned by ``TodoWrite``.

    The tool is effectively stateless: the authoritative task list lives in
    the assistant's latest tool-call arguments, which are replayed from the
    transcript on each turn. This response only confirms the update was
    accepted so the model can proceed.
    """

    # Discriminator for this response type.
    type: ResponseType = ResponseType.TODO_WRITE
    # The accepted checklist; empty list by default.
    todos: list[TodoItem] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TaskResponse(ToolResponseBase):
    """Outcome of a delegated ``Task`` in-process sub-agent run.

    The sub-agent runs a fresh tool-call loop with an isolated message
    history and hands back only its final assistant text. Intermediate tool
    calls and thinking stay inside the sub-agent's loop, so the parent
    context is not polluted.
    """

    # Discriminator for this response type.
    type: ResponseType = ResponseType.TASK
    description: str = ""
    response: str = Field(
        default="",
        description="Final assistant text the sub-agent produced.",
    )
    iterations: int = 0
    tool_calls: list[str] = Field(
        default_factory=list,
        description="Names of tools the sub-agent invoked (for observability).",
    )
    # Terminal state of the run; ``error`` stays ``None`` unless set by the caller.
    status: Literal["completed", "max_iterations", "error"] = "completed"
    error: str | None = None
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
"""In-process sub-agent tool for baseline copilot mode.
|
||||
|
||||
The ``Task`` tool delegates a focused, context-isolated unit of work to a
|
||||
fresh tool-call loop that runs **inside the current session** — same user,
|
||||
same tools, same workspace — but with its own message history. The parent
|
||||
LLM never sees the sub-agent's intermediate tool calls or reasoning; it
|
||||
only sees the sub-agent's final summary as the tool result.
|
||||
|
||||
Why baseline needs its own: the Claude Agent SDK ships a built-in
|
||||
``Task`` / ``Agent`` tool that does this natively. Baseline routes through
|
||||
OpenAI-compatible providers (Kimi, GPT, Grok, Gemini) where no such
|
||||
built-in exists. This platform-tool rebuild gives baseline feature parity
|
||||
without giving up the model-flexibility advantage.
|
||||
|
||||
**Execution note.** Baseline's service loop short-circuits ``Task`` *before*
|
||||
dispatching through ``execute_tool`` because the nested loop needs direct
|
||||
access to the parent's ``_BaselineStreamState`` primitives (LLM caller,
|
||||
tool executor, reasoning emitter). Calls that reach ``_execute`` here are
|
||||
an unsupported path — they get a clear error so a misconfigured caller
|
||||
fails loudly rather than silently producing a no-op response.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskTool(BaseTool):
    """Schema stub for delegating a focused task to an in-process sub-agent.

    The class only advertises the tool's name, description and parameter
    schema. Its ``_execute`` is a fall-back that reports an error, because
    the baseline service loop is expected to intercept ``Task`` calls
    before they are dispatched here.
    """

    @property
    def name(self) -> str:
        # Capitalised to match the frontend's switch on ``"Task"`` / ``"Agent"``
        # (see ``copilot/tools/GenericTool/helpers.ts``). Sharing the SDK
        # built-in's name makes the chat UI render baseline and SDK
        # sub-agent runs the same way.
        return "Task"

    @property
    def description(self) -> str:
        summary = (
            "Run a focused task in an in-process sub-agent with isolated "
            "history; only its final summary returns. For durable/background "
            "work use `run_sub_session` instead."
        )
        return summary

    @property
    def parameters(self) -> dict[str, Any]:
        # One JSON-schema fragment per argument, assembled below.
        label_schema = {
            "type": "string",
            "description": "Short (3-5 word) accordion label.",
        }
        prompt_schema = {
            "type": "string",
            "description": (
                "Full instructions — sub-agent does NOT inherit "
                "parent conversation."
            ),
        }
        profile_schema = {
            "type": "string",
            "description": "Optional profile name (SDK parity; ignored).",
        }
        return {
            "type": "object",
            "properties": {
                "description": label_schema,
                "prompt": prompt_schema,
                "subagent_type": profile_schema,
            },
            "required": ["description", "prompt"],
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Report that ``Task`` was dispatched through the generic path.

        Always logs a warning and returns an ``ErrorResponse``; ``user_id``
        and the tool arguments are deliberately unused.
        """
        del user_id, kwargs
        session_id = session.session_id
        # The baseline service loop is supposed to intercept ``Task`` before
        # this point. Getting here means either the SDK path dispatched
        # through MCP (a misconfiguration — SDK already has a CLI-native
        # Task tool) or the baseline short-circuit was bypassed. Fail loudly
        # so the misconfig shows up in the trace instead of a silent no-op.
        logger.warning(
            "Task tool reached the generic execute path — expected baseline "
            "service to intercept. session=%s",
            session_id,
        )
        failure_text = (
            "Task is a baseline-only in-process sub-agent tool and must "
            "be dispatched by the baseline service loop. In SDK mode use "
            "the CLI-native Task tool; for durable/background work use "
            "run_sub_session instead."
        )
        return ErrorResponse(message=failure_text, session_id=session_id)
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Tests for TaskTool (the schema stub).
|
||||
|
||||
The actual sub-agent execution is unit-tested alongside the baseline
|
||||
service loop in ``baseline/service_unit_test.py`` because it requires the
|
||||
baseline's LLM caller and tool executor closures. This file just verifies
|
||||
the tool schema and that the fall-back path surfaces a loud error if the
|
||||
service loop short-circuit ever gets bypassed.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
from backend.copilot.tools.task import TaskTool
|
||||
|
||||
|
||||
@pytest.fixture()
def tool() -> TaskTool:
    """Provide a fresh ``TaskTool`` instance for each test."""
    task_tool = TaskTool()
    return task_tool
|
||||
|
||||
|
||||
@pytest.fixture()
def session() -> ChatSession:
    """Provide a new non-dry-run chat session owned by ``test-user``."""
    chat_session = ChatSession.new(user_id="test-user", dry_run=False)
    return chat_session
|
||||
|
||||
|
||||
def test_openai_schema_shape(tool: TaskTool):
    """The OpenAI tool schema exposes ``Task`` with the expected arguments."""
    schema = tool.as_openai_tool()
    function_spec = schema["function"]
    assert schema["type"] == "function"
    assert function_spec["name"] == "Task"

    params = function_spec["parameters"]
    required_args = params["required"]
    assert sorted(required_args) == ["description", "prompt"]
    # ``subagent_type`` must remain optional (SDK parity) so models that
    # don't know about it don't break schema validation.
    assert "subagent_type" in params["properties"]
    assert "subagent_type" not in required_args
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generic_dispatch_returns_error(tool: TaskTool, session: ChatSession):
    """Dispatching Task through the generic execute path must fail loudly.

    If anything calls the tool directly instead of going through the
    baseline short-circuit, the error has to be visible in logs and
    transcripts.
    """
    call_args = {"description": "demo", "prompt": "do a thing"}
    outcome = await tool._execute(user_id="user", session=session, **call_args)

    assert isinstance(outcome, ErrorResponse)
    assert "baseline service loop" in outcome.message
|
||||
@@ -237,7 +237,7 @@ async def test_execute_block_dry_run_skips_real_execution():
|
||||
mock_block = make_mock_block()
|
||||
mock_block.execute = AsyncMock() # should NOT be called
|
||||
|
||||
async def fake_simulate(block, input_data, **_kwargs):
|
||||
async def fake_simulate(block, input_data):
|
||||
yield "result", "simulated"
|
||||
|
||||
# Patching at helpers.simulate_block works because helpers.py imports
|
||||
@@ -267,7 +267,7 @@ async def test_execute_block_dry_run_response_format():
|
||||
"""Dry-run response should look like a normal success (no dry-run signal to LLM)."""
|
||||
mock_block = make_mock_block()
|
||||
|
||||
async def fake_simulate(block, input_data, **_kwargs):
|
||||
async def fake_simulate(block, input_data):
|
||||
yield "result", "simulated"
|
||||
|
||||
with patch(
|
||||
@@ -331,7 +331,7 @@ async def test_execute_block_real_execution_unchanged():
|
||||
# Just verify simulate_block is NOT called.
|
||||
simulate_called = False
|
||||
|
||||
async def fake_simulate(block, input_data, **_kwargs):
|
||||
async def fake_simulate(block, input_data):
|
||||
nonlocal simulate_called
|
||||
simulate_called = True
|
||||
yield "result", "should not happen"
|
||||
@@ -455,7 +455,7 @@ async def test_execute_block_dry_run_no_empty_error_from_simulator():
|
||||
"""
|
||||
mock_block = make_mock_block()
|
||||
|
||||
async def fake_simulate(block, input_data, **_kwargs):
|
||||
async def fake_simulate(block, input_data):
|
||||
# Simulator now omits empty error pins at source
|
||||
yield "result", "simulated output"
|
||||
|
||||
@@ -485,7 +485,7 @@ async def test_execute_block_dry_run_keeps_nonempty_error_pin():
|
||||
"""Dry-run should keep the 'error' pin when it contains a real error message."""
|
||||
mock_block = make_mock_block()
|
||||
|
||||
async def fake_simulate(block, input_data, **_kwargs):
|
||||
async def fake_simulate(block, input_data):
|
||||
yield "result", ""
|
||||
yield "error", "API rate limit exceeded"
|
||||
|
||||
@@ -515,7 +515,7 @@ async def test_execute_block_dry_run_message_includes_completed_status():
|
||||
"""Dry-run message should clearly indicate COMPLETED status."""
|
||||
mock_block = make_mock_block()
|
||||
|
||||
async def fake_simulate(block, input_data, **_kwargs):
|
||||
async def fake_simulate(block, input_data):
|
||||
yield "result", "simulated"
|
||||
|
||||
with patch(
|
||||
@@ -541,7 +541,7 @@ async def test_execute_block_dry_run_simulator_error_returns_error_response():
|
||||
"""When simulate_block yields a SIMULATOR ERROR tuple, execute_block returns ErrorResponse."""
|
||||
mock_block = make_mock_block()
|
||||
|
||||
async def fake_simulate_error(block, input_data, **_kwargs):
|
||||
async def fake_simulate_error(block, input_data):
|
||||
yield (
|
||||
"error",
|
||||
"[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key).",
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
"""Task-list tool for baseline copilot mode.
|
||||
|
||||
Mirrors the schema and UX of Claude Code's built-in ``TodoWrite`` tool so
|
||||
the frontend's generic tool renderer draws baseline-emitted checklists the
|
||||
same way it draws SDK-emitted ones. The tool is stateless: the model's
|
||||
latest ``todos`` argument IS the canonical list, replayed from transcript
|
||||
on subsequent turns.
|
||||
|
||||
Baseline needs this as a platform tool because OpenAI-compatible providers
|
||||
(Kimi, GPT, Grok, Gemini) do not ship a built-in equivalent. The SDK path
|
||||
continues to use the CLI's native ``TodoWrite`` — the MCP-wrapped version
|
||||
of this tool is filtered out of SDK's allowed_tools list (see
|
||||
``sdk/tool_adapter.py``) to avoid name shadowing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, TodoItem, TodoWriteResponse, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TodoWriteTool(BaseTool):
    """Keep a step-by-step task checklist that is visible to the user."""

    @property
    def name(self) -> str:
        # Capitalised to match the frontend's switch on ``"TodoWrite"``
        # (see ``copilot/tools/GenericTool/helpers.ts``).
        return "TodoWrite"

    @property
    def description(self) -> str:
        summary = (
            "Plan and track multi-step work as a visible checklist. Send "
            "the full list every call; exactly one item in_progress at a time."
        )
        return summary

    @property
    def parameters(self) -> dict[str, Any]:
        # Schema of one checklist entry; ``status`` is optional.
        item_schema = {
            "type": "object",
            "properties": {
                "content": {
                    "type": "string",
                    "description": "Imperative (e.g. 'Run tests').",
                },
                "activeForm": {
                    "type": "string",
                    "description": (
                        "Present-continuous (e.g. 'Running tests')."
                    ),
                },
                "status": {
                    "type": "string",
                    "enum": ["pending", "in_progress", "completed"],
                    "default": "pending",
                },
            },
            "required": ["content", "activeForm"],
        }
        todos_schema = {
            "type": "array",
            "description": "Full updated task list (not a delta).",
            "items": item_schema,
        }
        return {
            "type": "object",
            "properties": {"todos": todos_schema},
            "required": ["todos"],
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Validate the submitted checklist and acknowledge it.

        Returns an ``ErrorResponse`` when ``todos`` is missing, is not a
        list, contains an entry that fails ``TodoItem`` validation, or has
        more than one entry with status ``in_progress``. Otherwise returns
        a ``TodoWriteResponse`` echoing the parsed entries. ``user_id`` is
        unused.
        """
        del user_id
        submitted = kwargs.get("todos")

        if submitted is None:
            return ErrorResponse(
                message="`todos` is required.",
                session_id=session.session_id,
            )
        if not isinstance(submitted, list):
            return ErrorResponse(
                message="`todos` must be an array.",
                session_id=session.session_id,
            )

        checklist = []
        try:
            for entry in submitted:
                checklist.append(TodoItem.model_validate(entry))
        except Exception as exc:
            return ErrorResponse(
                message=f"Invalid todo entry: {exc}",
                session_id=session.session_id,
            )

        # Zero or one active entries are accepted; two or more are rejected.
        in_progress = [entry.status for entry in checklist].count("in_progress")
        if in_progress > 1:
            return ErrorResponse(
                message=(
                    "Only one todo may be 'in_progress' at a time "
                    f"(found {in_progress})."
                ),
                session_id=session.session_id,
            )

        return TodoWriteResponse(
            message="Task list updated.",
            session_id=session.session_id,
            todos=checklist,
        )
|
||||
@@ -1,125 +0,0 @@
|
||||
"""Tests for TodoWriteTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.models import ErrorResponse, TodoItem, TodoWriteResponse
|
||||
from backend.copilot.tools.todo_write import TodoWriteTool
|
||||
|
||||
|
||||
@pytest.fixture()
def tool() -> TodoWriteTool:
    """Provide a fresh ``TodoWriteTool`` instance for each test."""
    todo_tool = TodoWriteTool()
    return todo_tool
|
||||
|
||||
|
||||
@pytest.fixture()
def session() -> ChatSession:
    """Provide a new non-dry-run chat session owned by ``test-user``."""
    chat_session = ChatSession.new(user_id="test-user", dry_run=False)
    return chat_session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_valid_todo_list(tool: TodoWriteTool, session: ChatSession):
    """A well-formed two-item list is accepted and echoed back parsed."""
    pending_item = {
        "content": "Write tests",
        "activeForm": "Writing tests",
        "status": "pending",
    }
    active_item = {
        "content": "Ship PR",
        "activeForm": "Shipping PR",
        "status": "in_progress",
    }

    outcome = await tool._execute(
        user_id=None,
        session=session,
        todos=[pending_item, active_item],
    )

    assert isinstance(outcome, TodoWriteResponse)
    assert outcome.session_id == session.session_id
    assert len(outcome.todos) == 2
    first, second = outcome.todos[0], outcome.todos[1]
    assert first == TodoItem(
        content="Write tests",
        activeForm="Writing tests",
        status="pending",
    )
    assert second.status == "in_progress"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_default_status_is_pending(tool: TodoWriteTool, session: ChatSession):
    """An entry submitted without ``status`` comes back as ``pending``."""
    entry_without_status = {"content": "Write tests", "activeForm": "Writing tests"}

    outcome = await tool._execute(
        user_id=None,
        session=session,
        todos=[entry_without_status],
    )

    assert isinstance(outcome, TodoWriteResponse)
    assert outcome.todos[0].status == "pending"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_missing_todos_returns_error(tool: TodoWriteTool, session: ChatSession):
    """Calling the tool with no ``todos`` argument yields an error naming it."""
    outcome = await tool._execute(user_id=None, session=session)

    assert isinstance(outcome, ErrorResponse)
    lowered_message = outcome.message.lower()
    assert "todos" in lowered_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_list_todos_returns_error(tool: TodoWriteTool, session: ChatSession):
    """A ``todos`` value that is not a list is rejected."""
    bad_payload = "not a list"
    outcome = await tool._execute(user_id=None, session=session, todos=bad_payload)

    assert isinstance(outcome, ErrorResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_item_returns_error(tool: TodoWriteTool, session: ChatSession):
    """An item lacking ``activeForm`` yields an ``ErrorResponse``."""
    # The required `activeForm` field is deliberately left out.
    incomplete_item = {"content": "Missing active form"}

    outcome = await tool._execute(
        user_id=None,
        session=session,
        todos=[incomplete_item],
    )

    assert isinstance(outcome, ErrorResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_multiple_in_progress_rejected(tool: TodoWriteTool, session: ChatSession):
    """Two simultaneous ``in_progress`` items are rejected — SDK parity rule.

    The error message is expected to mention ``in_progress``.
    """
    first_active = {
        "content": "A",
        "activeForm": "Doing A",
        "status": "in_progress",
    }
    second_active = {
        "content": "B",
        "activeForm": "Doing B",
        "status": "in_progress",
    }

    outcome = await tool._execute(
        user_id=None,
        session=session,
        todos=[first_active, second_active],
    )

    assert isinstance(outcome, ErrorResponse)
    assert "in_progress" in outcome.message
|
||||
|
||||
|
||||
def test_openai_schema_shape(tool: TodoWriteTool):
    """Pin the OpenAI tool schema emitted by ``as_openai_tool``.

    Covers the tool type and name, the required top-level parameter,
    the required per-item fields, and the allowed ``status`` values.
    """
    tool_schema = tool.as_openai_tool()
    assert tool_schema["type"] == "function"
    assert tool_schema["function"]["name"] == "TodoWrite"

    parameters = tool_schema["function"]["parameters"]
    assert parameters["required"] == ["todos"]

    item_schema = parameters["properties"]["todos"]["items"]
    assert item_schema["required"] == ["content", "activeForm"]

    allowed_statuses = ["pending", "in_progress", "completed"]
    assert item_schema["properties"]["status"]["enum"] == allowed_statuses
|
||||
@@ -25,19 +25,7 @@ from backend.copilot.tools import TOOL_REGISTRY
|
||||
# (server-side Anthropic beta). Description already trimmed to the
|
||||
# minimum viable copy; the bump absorbs the schema skeleton cost
|
||||
# (~300 chars / ~75 tokens) for a new LLM-facing primitive.
|
||||
# Bumped 32800 -> 33200 on PR #12873 for the web_search Perplexity
|
||||
# Sonar refactor — adds a load-bearing `deep` boolean with explicit
|
||||
# "~100x more expensive" cost warning the model must see to avoid
|
||||
# accidentally triggering sonar-reasoning on ordinary lookups, plus
|
||||
# synthesised-answer wording in the top-level description so the LLM
|
||||
# reads the answer before reaching for `web_fetch`. Both are
|
||||
# LLM-decision-critical copy, not bloat.
|
||||
# Bumped 33200 -> 34600 when baseline gained MCP `TodoWrite` and `Task`
|
||||
# platform tools for parity with the Claude Code SDK's built-ins
|
||||
# (PR: feat/copilot-baseline-todowrite-task). The two new schemas add
|
||||
# ~1200 chars / ~300 tokens combined; descriptions are already trimmed
|
||||
# to the minimum viable copy.
|
||||
_CHAR_BUDGET = 34_600
|
||||
_CHAR_BUDGET = 32_800
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -127,10 +115,9 @@ def test_total_schema_char_budget() -> None:
|
||||
|
||||
This locks in the 34% token reduction from #12398 and prevents future
|
||||
description bloat from eroding the gains. Uses character count with a
|
||||
~4 chars/token heuristic; see ``_CHAR_BUDGET`` above for the current
|
||||
value and its change history. Character count is tokenizer-agnostic
|
||||
— no dependency on GPT or Claude tokenizers — while still providing a
|
||||
stable regression gate.
|
||||
~4 chars/token heuristic (budget of 32800 chars ≈ 8200 tokens).
|
||||
Character count is tokenizer-agnostic — no dependency on GPT or Claude
|
||||
tokenizers — while still providing a stable regression gate.
|
||||
"""
|
||||
schemas = [tool.as_openai_tool() for tool in TOOL_REGISTRY.values()]
|
||||
serialized = json.dumps(schemas)
|
||||
|
||||
@@ -1,66 +1,29 @@
|
||||
"""Web search tool — Perplexity Sonar via OpenRouter.
|
||||
"""Web search tool — wraps Anthropic's server-side ``web_search`` beta.
|
||||
|
||||
One provider, two tiers, one billing path:
|
||||
|
||||
* ``deep=False`` (default) — ``perplexity/sonar``. Searches the web
|
||||
natively and returns citation annotations in a single inference pass.
|
||||
* ``deep=True`` — ``perplexity/sonar-deep-research``. Multi-step
|
||||
agentic research; slower and costlier.
|
||||
|
||||
Why Sonar and not the ``openrouter:web_search`` server tool + dispatch
|
||||
model? The server tool feeds all search-result page content back into
|
||||
the dispatch model for a second inference pass — one observed call was
|
||||
74K input tokens at Gemini Flash rates, billing $0.072. Sonar
|
||||
searches natively in one pass, returns annotations typed as
|
||||
``ChatCompletionMessage.annotations`` in ``openai.types``, and at
|
||||
$1 / MTok base pricing lands ~$0.01 / call at our default shape.
|
||||
|
||||
``resp.usage.cost`` carries the real billed value via OpenRouter's
|
||||
``include: true`` extension; the value flows through
|
||||
``persist_and_record_usage(provider='open_router')`` into the daily /
|
||||
weekly microdollar rate-limit counter on the same rails as every other
|
||||
OpenRouter turn — no separate provider ledger line, no estimation
|
||||
drift. ``_extract_cost_usd`` mirrors the baseline service's
|
||||
``_extract_usage_cost`` logic; keep the two in sync if one changes.
|
||||
Single entry point for web search on both SDK and baseline paths. The
|
||||
``web_search_20250305`` tool is server-side on Anthropic, so we call
|
||||
the Messages API directly regardless of which LLM invoked the copilot
|
||||
tool — OpenRouter can't proxy server-side tool execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, WebSearchResponse, WebSearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_chat_config = ChatConfig()
|
||||
|
||||
_QUICK_MODEL = "perplexity/sonar"
|
||||
# Sonar base can emit up to ~4K output; cap at the provider ceiling so the
|
||||
# model stops when the answer is complete rather than when our budget trips.
|
||||
_QUICK_MAX_TOKENS = 4096
|
||||
|
||||
_DEEP_MODEL = "perplexity/sonar-deep-research"
|
||||
# Deep runs can produce long structured writeups — ~4x the quick ceiling
|
||||
# is enough headroom for multi-source comparisons without uncapping.
|
||||
_DEEP_MAX_TOKENS = _QUICK_MAX_TOKENS * 4
|
||||
|
||||
_WEB_SEARCH_DISPATCH_MODEL = "claude-haiku-4-5"
|
||||
_MAX_DISPATCH_TOKENS = 512
|
||||
_DEFAULT_MAX_RESULTS = 5
|
||||
_HARD_MAX_RESULTS = 20
|
||||
_SNIPPET_MAX_CHARS = 500
|
||||
|
||||
# OpenRouter-specific extra_body flag that embeds the real generation
|
||||
# cost into the response usage object. Same dict shape the baseline
|
||||
# service uses — keep the two aligned.
|
||||
_OPENROUTER_INCLUDE_USAGE_COST: dict[str, Any] = {"usage": {"include": True}}
|
||||
|
||||
|
||||
class WebSearchTool(BaseTool):
|
||||
@@ -73,13 +36,9 @@ class WebSearchTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search the web for live info (news, recent docs). Returns a "
|
||||
"synthesised answer grounded in fresh page content plus "
|
||||
"{title, url, snippet} citations — read the answer first "
|
||||
"before reaching for web_fetch. Set deep=true when the user "
|
||||
"asks for research / comparison / in-depth analysis; leave "
|
||||
"deep=false for quick fact lookups. Prefer one targeted "
|
||||
"query over many reformulations."
|
||||
"Search the web for live info (news, recent docs). Returns "
|
||||
"{title, url, snippet}; use web_fetch to deep-dive a URL. "
|
||||
"Prefer one targeted query over many reformulations."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -99,18 +58,6 @@ class WebSearchTool(BaseTool):
|
||||
),
|
||||
"default": _DEFAULT_MAX_RESULTS,
|
||||
},
|
||||
"deep": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Only set true when the user EXPLICITLY asks for "
|
||||
"research, comparison, or in-depth investigation "
|
||||
"across many sources — it is ~100x more expensive "
|
||||
"and much slower than a normal search. Default "
|
||||
"false; do not flip it for ordinary fact lookups "
|
||||
"or fresh-news questions."
|
||||
),
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
@@ -121,7 +68,7 @@ class WebSearchTool(BaseTool):
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
return bool(_chat_config.api_key and _chat_config.base_url)
|
||||
return bool(Settings().secrets.anthropic_api_key)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
@@ -129,7 +76,6 @@ class WebSearchTool(BaseTool):
|
||||
session: ChatSession,
|
||||
query: str = "",
|
||||
max_results: int = _DEFAULT_MAX_RESULTS,
|
||||
deep: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
query = (query or "").strip()
|
||||
@@ -147,35 +93,44 @@ class WebSearchTool(BaseTool):
|
||||
max_results = _DEFAULT_MAX_RESULTS
|
||||
max_results = max(1, min(max_results, _HARD_MAX_RESULTS))
|
||||
|
||||
if not _chat_config.api_key or not _chat_config.base_url:
|
||||
api_key = Settings().secrets.anthropic_api_key
|
||||
if not api_key:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Web search is unavailable — the deployment has no "
|
||||
"OpenRouter credentials configured."
|
||||
"Anthropic API key configured."
|
||||
),
|
||||
error="web_search_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
client = AsyncOpenAI(
|
||||
api_key=_chat_config.api_key, base_url=_chat_config.base_url
|
||||
)
|
||||
model_used = _DEEP_MODEL if deep else _QUICK_MODEL
|
||||
max_tokens = _DEEP_MAX_TOKENS if deep else _QUICK_MAX_TOKENS
|
||||
|
||||
client = AsyncAnthropic(api_key=api_key)
|
||||
try:
|
||||
resp = await client.chat.completions.create(
|
||||
model=model_used,
|
||||
max_tokens=max_tokens,
|
||||
messages=[{"role": "user", "content": query}],
|
||||
extra_body=_OPENROUTER_INCLUDE_USAGE_COST,
|
||||
resp = await client.messages.create(
|
||||
model=_WEB_SEARCH_DISPATCH_MODEL,
|
||||
max_tokens=_MAX_DISPATCH_TOKENS,
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_20250305",
|
||||
"name": "web_search",
|
||||
"max_uses": 1,
|
||||
}
|
||||
],
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Use the web_search tool exactly once with the "
|
||||
f"query {query!r} and then stop. Do not "
|
||||
f"summarise — the caller parses the raw "
|
||||
f"tool_result."
|
||||
),
|
||||
}
|
||||
],
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[web_search] OpenRouter call failed (deep=%s) for query=%r: %s",
|
||||
deep,
|
||||
query,
|
||||
exc,
|
||||
"[web_search] Anthropic call failed for query=%r: %s", query, exc
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=f"Web search failed: {exc}",
|
||||
@@ -183,20 +138,20 @@ class WebSearchTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
answer = _extract_answer(resp)
|
||||
results = _extract_results(resp, limit=max_results)
|
||||
cost_usd = _extract_cost_usd(resp.usage)
|
||||
results, search_requests = _extract_results(resp, limit=max_results)
|
||||
|
||||
cost_usd = _estimate_cost_usd(resp, search_requests=search_requests)
|
||||
try:
|
||||
usage = getattr(resp, "usage", None)
|
||||
await persist_and_record_usage(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
prompt_tokens=resp.usage.prompt_tokens if resp.usage else 0,
|
||||
completion_tokens=resp.usage.completion_tokens if resp.usage else 0,
|
||||
prompt_tokens=getattr(usage, "input_tokens", 0) or 0,
|
||||
completion_tokens=getattr(usage, "output_tokens", 0) or 0,
|
||||
log_prefix="[web_search]",
|
||||
cost_usd=cost_usd,
|
||||
model=model_used,
|
||||
provider="open_router",
|
||||
model=_WEB_SEARCH_DISPATCH_MODEL,
|
||||
provider="anthropic",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("[web_search] usage tracking failed: %s", exc)
|
||||
@@ -204,92 +159,66 @@ class WebSearchTool(BaseTool):
|
||||
return WebSearchResponse(
|
||||
message=f"Found {len(results)} result(s) for {query!r}.",
|
||||
query=query,
|
||||
answer=answer,
|
||||
results=results,
|
||||
search_requests=1 if results else 0,
|
||||
search_requests=search_requests,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
def _extract_answer(resp: ChatCompletion) -> str:
|
||||
"""Return the synthesised answer text from Sonar's response.
|
||||
def _extract_results(resp: Any, *, limit: int) -> tuple[list[WebSearchResult], int]:
|
||||
"""Pull results + server-side request count from an Anthropic response."""
|
||||
results: list[WebSearchResult] = []
|
||||
search_requests = 0
|
||||
|
||||
Sonar reads every page it cites and writes a web-grounded synthesis
|
||||
into ``choices[0].message.content`` on the same call we pay for.
|
||||
Surfacing it saves the agent from re-fetching citation URLs — many
|
||||
are bot-protected and ``web_fetch`` can't reach them.
|
||||
"""
|
||||
if not resp.choices:
|
||||
return ""
|
||||
content = resp.choices[0].message.content
|
||||
return content or ""
|
||||
for block in getattr(resp, "content", []) or []:
|
||||
btype = getattr(block, "type", None)
|
||||
if btype == "web_search_tool_result":
|
||||
content = getattr(block, "content", []) or []
|
||||
for item in content:
|
||||
if getattr(item, "type", None) != "web_search_result":
|
||||
continue
|
||||
if len(results) >= limit:
|
||||
break
|
||||
# Anthropic's ``web_search_result`` exposes only
|
||||
# ``title``/``url``/``page_age`` plus an opaque
|
||||
# ``encrypted_content`` blob that is meant for citation
|
||||
# round-tripping, not for display — it is base64-ish
|
||||
# binary and would show as gibberish if surfaced to the
|
||||
# model or the frontend. There is no plain-text snippet
|
||||
# field in the current beta; callers get the readable
|
||||
# text via the model's ``text`` blocks with citations,
|
||||
# not via this list. Leave ``snippet`` empty.
|
||||
results.append(
|
||||
WebSearchResult(
|
||||
title=getattr(item, "title", "") or "",
|
||||
url=getattr(item, "url", "") or "",
|
||||
snippet="",
|
||||
page_age=getattr(item, "page_age", None),
|
||||
)
|
||||
)
|
||||
|
||||
usage = getattr(resp, "usage", None)
|
||||
server_tool_use = getattr(usage, "server_tool_use", None) if usage else None
|
||||
if server_tool_use is not None:
|
||||
search_requests = getattr(server_tool_use, "web_search_requests", 0) or 0
|
||||
|
||||
return results, search_requests
|
||||
|
||||
|
||||
def _extract_results(resp: ChatCompletion, *, limit: int) -> list[WebSearchResult]:
|
||||
"""Pull ``url_citation`` annotations from the response.
|
||||
|
||||
Shared across both tiers — OpenRouter normalises the annotation
|
||||
schema across Perplexity's sonar models into
|
||||
``Annotation.url_citation`` (typed in ``openai.types.chat``). The
|
||||
``content`` snippet is an OpenRouter extension on the otherwise-
|
||||
typed ``AnnotationURLCitation``; pydantic stashes unknown fields in
|
||||
``model_extra``, which we read there rather than via ``getattr``.
|
||||
"""
|
||||
if not resp.choices:
|
||||
return []
|
||||
annotations = resp.choices[0].message.annotations or []
|
||||
out: list[WebSearchResult] = []
|
||||
for ann in annotations:
|
||||
if len(out) >= limit:
|
||||
break
|
||||
if ann.type != "url_citation":
|
||||
continue
|
||||
citation = ann.url_citation
|
||||
extras = citation.model_extra or {}
|
||||
snippet_raw = extras.get("content")
|
||||
snippet = (snippet_raw or "")[:_SNIPPET_MAX_CHARS] if snippet_raw else ""
|
||||
out.append(
|
||||
WebSearchResult(
|
||||
title=citation.title,
|
||||
url=citation.url,
|
||||
snippet=snippet,
|
||||
page_age=None,
|
||||
)
|
||||
)
|
||||
return out
|
||||
# Update when Anthropic revises pricing.
|
||||
_COST_PER_SEARCH_USD = 0.010 # $10 per 1,000 web_search requests
|
||||
_HAIKU_INPUT_USD_PER_MTOK = 1.0
|
||||
_HAIKU_OUTPUT_USD_PER_MTOK = 5.0
|
||||
|
||||
|
||||
def _extract_cost_usd(usage: CompletionUsage | None) -> float | None:
|
||||
"""Return the provider-reported USD cost off the response usage.
|
||||
def _estimate_cost_usd(resp: Any, *, search_requests: int) -> float:
|
||||
"""Per-search fee × count + Haiku dispatch tokens."""
|
||||
usage = getattr(resp, "usage", None)
|
||||
input_tokens = getattr(usage, "input_tokens", 0) if usage else 0
|
||||
output_tokens = getattr(usage, "output_tokens", 0) if usage else 0
|
||||
|
||||
OpenRouter piggybacks a ``cost`` field on the OpenAI-compatible
|
||||
usage object when the request body includes
|
||||
``usage: {"include": True}``. The OpenAI SDK's typed
|
||||
``CompletionUsage`` does not declare it, so we read it off
|
||||
``model_extra`` (the pydantic v2 container for extras) to keep
|
||||
access fully typed — no ``getattr``. Mirrors the baseline service
|
||||
``_extract_usage_cost``; keep the two in sync.
|
||||
|
||||
Returns ``None`` when the field is absent, null, non-numeric,
|
||||
non-finite, or negative. Invalid values log at error level because
|
||||
they indicate a provider bug worth chasing; plain absences are
|
||||
silent so the caller can dedupe the "missing cost" warning.
|
||||
"""
|
||||
if usage is None:
|
||||
return None
|
||||
extras = usage.model_extra or {}
|
||||
if "cost" not in extras:
|
||||
return None
|
||||
raw = extras["cost"]
|
||||
if raw is None:
|
||||
logger.error("[web_search] usage.cost is present but null")
|
||||
return None
|
||||
try:
|
||||
val = float(raw)
|
||||
except (TypeError, ValueError):
|
||||
logger.error("[web_search] usage.cost is not numeric: %r", raw)
|
||||
return None
|
||||
if not math.isfinite(val) or val < 0:
|
||||
logger.error("[web_search] usage.cost is non-finite or negative: %r", val)
|
||||
return None
|
||||
return val
|
||||
search_cost = search_requests * _COST_PER_SEARCH_USD
|
||||
inference_cost = (input_tokens / 1_000_000) * _HAIKU_INPUT_USD_PER_MTOK + (
|
||||
output_tokens / 1_000_000
|
||||
) * _HAIKU_OUTPUT_USD_PER_MTOK
|
||||
return round(search_cost + inference_cost, 6)
|
||||
|
||||
@@ -1,289 +1,212 @@
|
||||
"""Tests for the ``web_search`` copilot tool.
|
||||
|
||||
Covers the annotation extractor + cost extractor as pure units (fed
|
||||
with real ``openai`` SDK types — no duck-typed ``SimpleNamespace``
|
||||
stand-ins), plus integration tests exercising both the quick
|
||||
(``perplexity/sonar``) and deep (``perplexity/sonar-deep-research``)
|
||||
paths — mocking ``AsyncOpenAI.chat.completions.create`` and confirming
|
||||
the handler plumbs through to ``persist_and_record_usage`` with
|
||||
``provider='open_router'`` and the real ``usage.cost`` value.
|
||||
Covers the result extractor + cost estimator as pure units (fed with
|
||||
synthetic Anthropic response objects), plus light integration tests that
|
||||
mock ``AsyncAnthropic.messages.create`` and confirm the handler plumbs
|
||||
through to ``persist_and_record_usage`` with the right provider tag.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_message import (
|
||||
Annotation,
|
||||
AnnotationURLCitation,
|
||||
ChatCompletionMessage,
|
||||
)
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .models import ErrorResponse, WebSearchResponse
|
||||
from .models import ErrorResponse, WebSearchResponse, WebSearchResult
|
||||
from .web_search import (
|
||||
_COST_PER_SEARCH_USD,
|
||||
WebSearchTool,
|
||||
_extract_answer,
|
||||
_extract_cost_usd,
|
||||
_estimate_cost_usd,
|
||||
_extract_results,
|
||||
)
|
||||
|
||||
|
||||
def _usage(
|
||||
def _fake_anthropic_response(
|
||||
*,
|
||||
prompt_tokens: int = 120,
|
||||
completion_tokens: int = 40,
|
||||
cost: object = 0.01,
|
||||
) -> CompletionUsage:
|
||||
"""Typed ``CompletionUsage`` with OpenRouter's ``cost`` extension
|
||||
parked in ``model_extra`` — the same channel the production code
|
||||
reads it from. ``model_construct`` preserves unknown fields;
|
||||
``model_validate`` would drop them because ``CompletionUsage``
|
||||
treats the schema as strict."""
|
||||
payload: dict[str, Any] = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
}
|
||||
if cost is not None:
|
||||
payload["cost"] = cost
|
||||
return CompletionUsage.model_construct(None, **payload)
|
||||
results: list[dict] | None = None,
|
||||
search_requests: int = 1,
|
||||
input_tokens: int = 120,
|
||||
output_tokens: int = 40,
|
||||
) -> SimpleNamespace:
|
||||
"""Build a synthetic Anthropic Messages response.
|
||||
|
||||
|
||||
def _citation(*, url: str, title: str, content: str | None = None) -> Annotation:
|
||||
"""Typed ``Annotation`` for a URL citation. ``content`` is an
|
||||
OpenRouter extension on the otherwise-typed schema — goes into
|
||||
``url_citation.model_extra`` when model_construct preserves it."""
|
||||
payload: dict[str, Any] = {
|
||||
"url": url,
|
||||
"title": title,
|
||||
"start_index": 0,
|
||||
"end_index": len(title),
|
||||
}
|
||||
if content is not None:
|
||||
payload["content"] = content
|
||||
url_citation = AnnotationURLCitation.model_construct(None, **payload)
|
||||
return Annotation(type="url_citation", url_citation=url_citation)
|
||||
|
||||
|
||||
def _fake_response(
|
||||
*,
|
||||
citations: list[dict] | None = None,
|
||||
answer: str = "ok",
|
||||
prompt_tokens: int = 120,
|
||||
completion_tokens: int = 40,
|
||||
cost: object = 0.01,
|
||||
) -> ChatCompletion:
|
||||
"""Build a typed ``ChatCompletion`` shaped like an OpenRouter
|
||||
response — typed end-to-end so the production code's attribute
|
||||
access runs under the real SDK types in tests."""
|
||||
annotations = [
|
||||
_citation(
|
||||
url=c.get("url", ""),
|
||||
title=c.get("title", "untitled"),
|
||||
content=c.get("content"),
|
||||
Matches the shape produced by ``client.messages.create`` when the
|
||||
response includes a ``web_search_tool_result`` content block and
|
||||
``usage.server_tool_use.web_search_requests`` on the turn meter.
|
||||
"""
|
||||
content = []
|
||||
if results is not None:
|
||||
content.append(
|
||||
SimpleNamespace(
|
||||
type="web_search_tool_result",
|
||||
content=[
|
||||
SimpleNamespace(
|
||||
type="web_search_result",
|
||||
title=r.get("title", "untitled"),
|
||||
url=r.get("url", ""),
|
||||
encrypted_content=r.get("snippet", ""),
|
||||
page_age=r.get("page_age"),
|
||||
)
|
||||
for r in results
|
||||
],
|
||||
)
|
||||
)
|
||||
for c in citations or []
|
||||
]
|
||||
message = ChatCompletionMessage.model_construct(
|
||||
None,
|
||||
role="assistant",
|
||||
content=answer,
|
||||
annotations=annotations,
|
||||
)
|
||||
choice = Choice.model_construct(
|
||||
None,
|
||||
index=0,
|
||||
finish_reason="stop",
|
||||
message=message,
|
||||
)
|
||||
return ChatCompletion.model_construct(
|
||||
None,
|
||||
id="cmpl-test",
|
||||
object="chat.completion",
|
||||
created=0,
|
||||
model="perplexity/sonar",
|
||||
choices=[choice],
|
||||
usage=_usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cost=cost,
|
||||
),
|
||||
usage = SimpleNamespace(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
server_tool_use=SimpleNamespace(web_search_requests=search_requests),
|
||||
)
|
||||
return SimpleNamespace(content=content, usage=usage)
|
||||
|
||||
|
||||
class TestExtractResults:
|
||||
"""Pin the annotation shape — a schema bump in the OpenAI SDK or
|
||||
OpenRouter surfaces here first. Same extractor serves both tiers
|
||||
because OpenRouter normalises annotations across models."""
|
||||
"""The extractor is the only Anthropic-response-shape contact point;
|
||||
pin its behaviour so an API shape change surfaces here first."""
|
||||
|
||||
def test_extracts_title_url_and_content_snippet(self):
|
||||
resp = _fake_response(
|
||||
citations=[
|
||||
def test_extracts_title_url_page_age_and_drops_encrypted_snippet(self):
|
||||
# Anthropic's ``web_search_result`` ships an opaque
|
||||
# ``encrypted_content`` blob that is not safe to surface —
|
||||
# the extractor must drop it (snippet=="") regardless of
|
||||
# whether the blob is non-empty.
|
||||
resp = _fake_anthropic_response(
|
||||
results=[
|
||||
{
|
||||
"title": "Kimi K2.6 launch",
|
||||
"url": "https://example.com/kimi",
|
||||
"content": "Moonshot released K2.6 on 2026-04-20.",
|
||||
"snippet": "EiJjbGF1ZGUtZW5jcnlwdGVkLWJsb2I=",
|
||||
"page_age": "1 day",
|
||||
},
|
||||
{
|
||||
"title": "OpenRouter pricing",
|
||||
"url": "https://openrouter.ai/moonshotai/kimi-k2.6",
|
||||
"snippet": "",
|
||||
},
|
||||
]
|
||||
)
|
||||
out = _extract_results(resp, limit=10)
|
||||
out, requests = _extract_results(resp, limit=10)
|
||||
assert requests == 1
|
||||
assert len(out) == 2
|
||||
assert out[0].title == "Kimi K2.6 launch"
|
||||
assert out[0].url == "https://example.com/kimi"
|
||||
assert out[0].snippet.startswith("Moonshot released")
|
||||
# Missing ``content`` extension → empty snippet rather than crash.
|
||||
assert out[0].snippet == ""
|
||||
assert out[0].page_age == "1 day"
|
||||
assert out[1].snippet == ""
|
||||
|
||||
def test_limit_caps_returned_results(self):
|
||||
resp = _fake_response(
|
||||
citations=[{"title": f"r{i}", "url": f"https://e/{i}"} for i in range(10)]
|
||||
resp = _fake_anthropic_response(
|
||||
results=[{"title": f"r{i}", "url": f"https://e/{i}"} for i in range(10)]
|
||||
)
|
||||
out = _extract_results(resp, limit=3)
|
||||
out, _ = _extract_results(resp, limit=3)
|
||||
assert len(out) == 3
|
||||
assert [r.title for r in out] == ["r0", "r1", "r2"]
|
||||
|
||||
def test_missing_choices_returns_empty(self):
|
||||
resp = ChatCompletion.model_construct(
|
||||
None,
|
||||
id="cmpl-test",
|
||||
object="chat.completion",
|
||||
created=0,
|
||||
model="perplexity/sonar",
|
||||
choices=[],
|
||||
usage=_usage(),
|
||||
def test_missing_content_returns_empty(self):
|
||||
resp = SimpleNamespace(content=[], usage=None)
|
||||
out, requests = _extract_results(resp, limit=10)
|
||||
assert out == []
|
||||
assert requests == 0
|
||||
|
||||
def test_non_search_blocks_are_ignored(self):
|
||||
resp = SimpleNamespace(
|
||||
content=[
|
||||
SimpleNamespace(type="text", text="Here's what I found..."),
|
||||
SimpleNamespace(
|
||||
type="web_search_tool_result",
|
||||
content=[
|
||||
SimpleNamespace(
|
||||
type="web_search_result",
|
||||
title="real",
|
||||
url="https://real.example",
|
||||
encrypted_content="body",
|
||||
page_age=None,
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
usage=None,
|
||||
)
|
||||
assert _extract_results(resp, limit=10) == []
|
||||
out, _ = _extract_results(resp, limit=10)
|
||||
assert len(out) == 1 and out[0].title == "real"
|
||||
|
||||
def test_extract_answer_returns_message_content(self):
|
||||
resp = _fake_response(
|
||||
answer="Sonar's synthesised, web-grounded answer text.",
|
||||
citations=[{"title": "t", "url": "https://e"}],
|
||||
|
||||
class TestEstimateCostUsd:
|
||||
"""Pin the per-search fee + Haiku inference math — the pricing
|
||||
constants in ``web_search.py`` are hard-coded (no live lookup) so a
|
||||
drift between Anthropic's schedule and our constants must surface
|
||||
in this test for the next reader to notice."""
|
||||
|
||||
def test_zero_searches_still_charges_inference(self):
|
||||
resp = _fake_anthropic_response(results=[], search_requests=0)
|
||||
cost = _estimate_cost_usd(resp, search_requests=0)
|
||||
# Haiku rates ($1 in / $5 out per MTok) on the default 120 input / 40 output tokens = tiny but non-zero.
|
||||
assert 0 < cost < 0.001
|
||||
|
||||
def test_single_search_fee_dominates(self):
|
||||
resp = _fake_anthropic_response(
|
||||
results=[{"title": "x", "url": "https://e"}],
|
||||
search_requests=1,
|
||||
input_tokens=100,
|
||||
output_tokens=20,
|
||||
)
|
||||
assert _extract_answer(resp) == "Sonar's synthesised, web-grounded answer text."
|
||||
cost = _estimate_cost_usd(resp, search_requests=1)
|
||||
# ~$0.010 search + trivial inference — total still ~1 cent.
|
||||
assert cost >= _COST_PER_SEARCH_USD
|
||||
assert cost < _COST_PER_SEARCH_USD + 0.001
|
||||
|
||||
def test_extract_answer_returns_empty_when_no_choices(self):
|
||||
resp = ChatCompletion.model_construct(
|
||||
None,
|
||||
id="cmpl-test",
|
||||
object="chat.completion",
|
||||
created=0,
|
||||
model="perplexity/sonar",
|
||||
choices=[],
|
||||
usage=_usage(),
|
||||
def test_three_searches_linear_in_count(self):
|
||||
resp = _fake_anthropic_response(
|
||||
results=[], search_requests=3, input_tokens=0, output_tokens=0
|
||||
)
|
||||
assert _extract_answer(resp) == ""
|
||||
|
||||
def test_snippet_clamped_to_max_chars(self):
|
||||
long_body = "x" * 5000
|
||||
resp = _fake_response(
|
||||
citations=[{"title": "t", "url": "https://e", "content": long_body}]
|
||||
)
|
||||
out = _extract_results(resp, limit=1)
|
||||
assert len(out) == 1
|
||||
assert len(out[0].snippet) == 500
|
||||
|
||||
|
||||
class TestExtractCostUsd:
|
||||
"""Read real ``usage.cost`` via typed ``model_extra`` — no
|
||||
hard-coded rates, so a future provider price change is reflected
|
||||
automatically. Error handling mirrors the baseline service's
|
||||
``_extract_usage_cost``."""
|
||||
|
||||
def test_returns_cost_value(self):
|
||||
assert _extract_cost_usd(_usage(cost=0.023456)) == pytest.approx(0.023456)
|
||||
|
||||
def test_returns_none_when_usage_missing(self):
|
||||
assert _extract_cost_usd(None) is None
|
||||
|
||||
def test_returns_none_when_cost_field_missing(self):
|
||||
assert _extract_cost_usd(_usage(cost=None)) is None
|
||||
|
||||
def test_returns_none_when_cost_is_explicit_null(self):
|
||||
usage = CompletionUsage.model_construct(
|
||||
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=None
|
||||
)
|
||||
assert _extract_cost_usd(usage) is None
|
||||
|
||||
def test_returns_none_when_cost_is_negative(self):
|
||||
usage = CompletionUsage.model_construct(
|
||||
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=-1.0
|
||||
)
|
||||
assert _extract_cost_usd(usage) is None
|
||||
|
||||
def test_accepts_numeric_string(self):
|
||||
usage = CompletionUsage.model_construct(
|
||||
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost="0.017"
|
||||
)
|
||||
assert _extract_cost_usd(usage) == pytest.approx(0.017)
|
||||
cost = _estimate_cost_usd(resp, search_requests=3)
|
||||
assert cost == pytest.approx(3 * _COST_PER_SEARCH_USD)
|
||||
|
||||
|
||||
class TestWebSearchToolDispatch:
|
||||
"""Integration test: mock the OpenAI client, confirm both paths
|
||||
dispatch the right Sonar model + track cost."""
|
||||
"""Lightweight integration test: mock the Anthropic client, confirm
|
||||
the handler returns a ``WebSearchResponse`` and the usage tracker is
|
||||
called with ``provider='anthropic'`` (not 'open_router', even on the
|
||||
baseline path — server-side web_search bills Anthropic regardless of
|
||||
the calling LLM's route)."""
|
||||
|
||||
def _session(self) -> ChatSession:
|
||||
s = ChatSession.new("test-user", dry_run=False)
|
||||
s.session_id = "sess-1"
|
||||
return s
|
||||
|
||||
def _mock_client(self, fake_resp: ChatCompletion) -> Any:
|
||||
return type(
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_response_with_results_and_tracks_cost(self, monkeypatch):
|
||||
fake_resp = _fake_anthropic_response(
|
||||
results=[
|
||||
{
|
||||
"title": "hello",
|
||||
"url": "https://example.com",
|
||||
"snippet": "greeting",
|
||||
}
|
||||
],
|
||||
search_requests=1,
|
||||
)
|
||||
mock_client = type(
|
||||
"MC",
|
||||
(),
|
||||
{
|
||||
"chat": type(
|
||||
"C",
|
||||
(),
|
||||
{
|
||||
"completions": type(
|
||||
"CC",
|
||||
(),
|
||||
{"create": AsyncMock(return_value=fake_resp)},
|
||||
)()
|
||||
},
|
||||
"messages": type(
|
||||
"M", (), {"create": AsyncMock(return_value=fake_resp)}
|
||||
)()
|
||||
},
|
||||
)()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quick_path_uses_sonar_base(self, monkeypatch):
|
||||
fake_resp = _fake_response(
|
||||
citations=[
|
||||
{
|
||||
"title": "hello",
|
||||
"url": "https://example.com",
|
||||
"content": "greeting",
|
||||
}
|
||||
],
|
||||
answer="Kimi K2.6 launched 2026-04-20 [1].",
|
||||
cost=0.01,
|
||||
)
|
||||
mock_client = self._mock_client(fake_resp)
|
||||
|
||||
# Stub the Anthropic API key so ``is_available`` is True.
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.web_search._chat_config",
|
||||
type(
|
||||
"C",
|
||||
(),
|
||||
{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
},
|
||||
)(),
|
||||
"backend.copilot.tools.web_search.Settings",
|
||||
lambda: SimpleNamespace(
|
||||
secrets=SimpleNamespace(anthropic_api_key="sk-test")
|
||||
),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.web_search.AsyncOpenAI",
|
||||
"backend.copilot.tools.web_search.AsyncAnthropic",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
@@ -297,88 +220,35 @@ class TestWebSearchToolDispatch:
|
||||
session=self._session(),
|
||||
query="kimi k2.6 launch",
|
||||
max_results=5,
|
||||
deep=False,
|
||||
)
|
||||
|
||||
assert isinstance(result, WebSearchResponse)
|
||||
assert result.answer == "Kimi K2.6 launched 2026-04-20 [1]."
|
||||
assert result.query == "kimi k2.6 launch"
|
||||
assert len(result.results) == 1
|
||||
assert result.results[0].snippet == "greeting"
|
||||
|
||||
create_call = mock_client.chat.completions.create.call_args
|
||||
assert create_call.kwargs["model"] == "perplexity/sonar"
|
||||
# Sonar searches natively — no server-tool extras.
|
||||
assert create_call.kwargs["extra_body"] == {"usage": {"include": True}}
|
||||
assert isinstance(result.results[0], WebSearchResult)
|
||||
assert result.search_requests == 1
|
||||
|
||||
# Cost tracker must have been called with provider="anthropic".
|
||||
assert mock_track.await_count == 1
|
||||
kwargs = mock_track.await_args.kwargs
|
||||
assert kwargs["provider"] == "open_router"
|
||||
assert kwargs["model"] == "perplexity/sonar"
|
||||
assert kwargs["cost_usd"] == pytest.approx(0.01)
|
||||
assert kwargs["provider"] == "anthropic"
|
||||
assert kwargs["model"] == "claude-haiku-4-5"
|
||||
assert kwargs["user_id"] == "u1"
|
||||
assert kwargs["cost_usd"] >= _COST_PER_SEARCH_USD
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deep_path_uses_sonar_deep_research(self, monkeypatch):
|
||||
fake_resp = _fake_response(
|
||||
citations=[
|
||||
{
|
||||
"title": "deep find",
|
||||
"url": "https://example.com/deep",
|
||||
"content": "research body",
|
||||
}
|
||||
],
|
||||
cost=0.087,
|
||||
)
|
||||
mock_client = self._mock_client(fake_resp)
|
||||
|
||||
async def test_missing_api_key_returns_error_without_calling_anthropic(
|
||||
self, monkeypatch
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.web_search._chat_config",
|
||||
type(
|
||||
"C",
|
||||
(),
|
||||
{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
},
|
||||
)(),
|
||||
"backend.copilot.tools.web_search.Settings",
|
||||
lambda: SimpleNamespace(secrets=SimpleNamespace(anthropic_api_key="")),
|
||||
)
|
||||
|
||||
anthropic_stub = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.web_search.AsyncOpenAI",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.web_search.persist_and_record_usage",
|
||||
new=AsyncMock(return_value=160),
|
||||
) as mock_track,
|
||||
):
|
||||
tool = WebSearchTool()
|
||||
result = await tool._execute(
|
||||
user_id="u1",
|
||||
session=self._session(),
|
||||
query="research question",
|
||||
deep=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, WebSearchResponse)
|
||||
create_call = mock_client.chat.completions.create.call_args
|
||||
assert create_call.kwargs["model"] == "perplexity/sonar-deep-research"
|
||||
|
||||
kwargs = mock_track.await_args.kwargs
|
||||
assert kwargs["provider"] == "open_router"
|
||||
assert kwargs["model"] == "perplexity/sonar-deep-research"
|
||||
assert kwargs["cost_usd"] == pytest.approx(0.087)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_credentials_returns_error(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.web_search._chat_config",
|
||||
type("C", (), {"api_key": "", "base_url": ""})(),
|
||||
)
|
||||
openai_stub = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.web_search.AsyncOpenAI",
|
||||
return_value=openai_stub,
|
||||
"backend.copilot.tools.web_search.AsyncAnthropic",
|
||||
return_value=anthropic_stub,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.web_search.persist_and_record_usage",
|
||||
@@ -394,26 +264,21 @@ class TestWebSearchToolDispatch:
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "web_search_not_configured"
|
||||
openai_stub.chat.completions.create.assert_not_called()
|
||||
anthropic_stub.messages.create.assert_not_called()
|
||||
mock_track.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_query_rejected_without_api_call(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.web_search._chat_config",
|
||||
type(
|
||||
"C",
|
||||
(),
|
||||
{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
},
|
||||
)(),
|
||||
"backend.copilot.tools.web_search.Settings",
|
||||
lambda: SimpleNamespace(
|
||||
secrets=SimpleNamespace(anthropic_api_key="sk-test")
|
||||
),
|
||||
)
|
||||
openai_stub = AsyncMock()
|
||||
anthropic_stub = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.tools.web_search.AsyncOpenAI",
|
||||
return_value=openai_stub,
|
||||
"backend.copilot.tools.web_search.AsyncAnthropic",
|
||||
return_value=anthropic_stub,
|
||||
):
|
||||
tool = WebSearchTool()
|
||||
result = await tool._execute(
|
||||
@@ -421,13 +286,13 @@ class TestWebSearchToolDispatch:
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "missing_query"
|
||||
openai_stub.chat.completions.create.assert_not_called()
|
||||
anthropic_stub.messages.create.assert_not_called()
|
||||
|
||||
|
||||
class TestToolRegistryIntegration:
|
||||
"""The tool must be registered under the ``web_search`` name so the
|
||||
MCP layer exposes it as ``mcp__copilot__web_search`` — which is
|
||||
what the SDK path dispatches to (see
|
||||
what the SDK path now dispatches to (see
|
||||
``sdk/tool_adapter.py::SDK_DISALLOWED_TOOLS`` which blocks the CLI's
|
||||
native ``WebSearch`` in favour of the MCP route)."""
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from backend.blocks.ai_shortform_video_block import (
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
from backend.blocks.apollo.person import GetPersonDetailBlock
|
||||
from backend.blocks.claude_code import ClaudeCodeBlock
|
||||
from backend.blocks.codex import CodeGenerationBlock, CodexModel
|
||||
from backend.blocks.enrichlayer.linkedin import (
|
||||
GetLinkedinProfileBlock,
|
||||
@@ -23,7 +22,6 @@ from backend.blocks.enrichlayer.linkedin import (
|
||||
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
|
||||
from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
from backend.blocks.jina.fact_checker import FactCheckerBlock
|
||||
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
|
||||
from backend.blocks.llm import (
|
||||
MODEL_METADATA,
|
||||
@@ -34,50 +32,29 @@ from backend.blocks.llm import (
|
||||
AITextSummarizerBlock,
|
||||
LlmModel,
|
||||
)
|
||||
from backend.blocks.mem0 import (
|
||||
AddMemoryBlock,
|
||||
GetAllMemoriesBlock,
|
||||
GetLatestMemoryBlock,
|
||||
SearchMemoryBlock,
|
||||
)
|
||||
from backend.blocks.nvidia.deepfake import NvidiaDeepfakeDetectBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.blocks.perplexity import PerplexityBlock, PerplexityModel
|
||||
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.screenshotone import ScreenshotWebPageBlock
|
||||
from backend.blocks.smartlead.campaign import (
|
||||
AddLeadToCampaignBlock,
|
||||
CreateCampaignBlock,
|
||||
SaveCampaignSequencesBlock,
|
||||
)
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.blocks.video.narration import VideoNarrationBlock
|
||||
from backend.blocks.zerobounce.validate_emails import ValidateEmailsBlock
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
anthropic_credentials,
|
||||
apollo_credentials,
|
||||
did_credentials,
|
||||
e2b_credentials,
|
||||
elevenlabs_credentials,
|
||||
enrichlayer_credentials,
|
||||
groq_credentials,
|
||||
ideogram_credentials,
|
||||
jina_credentials,
|
||||
llama_api_credentials,
|
||||
mem0_credentials,
|
||||
nvidia_credentials,
|
||||
open_router_credentials,
|
||||
openai_credentials,
|
||||
replicate_credentials,
|
||||
revid_credentials,
|
||||
screenshotone_credentials,
|
||||
smartlead_credentials,
|
||||
unreal_credentials,
|
||||
v0_credentials,
|
||||
zerobounce_credentials,
|
||||
)
|
||||
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
@@ -315,23 +292,6 @@ LLM_COST = (
|
||||
)
|
||||
|
||||
# =============== This is the exhaustive list of cost for each Block =============== #
|
||||
#
|
||||
# BLOCK_COSTS drives the **credit wallet** — the user-facing balance that funds
|
||||
# block executions regardless of where they run (builder, graph execution,
|
||||
# copilot ``run_block`` tool). A missing entry here makes the block run for
|
||||
# free from the wallet's perspective, even when the upstream provider charges
|
||||
# real USD. See ``backend.executor.utils::block_usage_cost`` for the lookup
|
||||
# and ``backend.copilot.tools.helpers::execute_block`` for the copilot-side
|
||||
# charge path.
|
||||
#
|
||||
# Credits are **not** the same as copilot microdollar rate-limit counters
|
||||
# (``backend.copilot.rate_limit``). Microdollars track AutoGPT's infra cost
|
||||
# (OpenRouter / Anthropic inference spend) and gate the chat loop; credits
|
||||
# track the user's prepaid balance. A block running inside copilot ``run_block``
|
||||
# decrements only the credit wallet via this table — microdollars stay scoped
|
||||
# to copilot LLM turns and are not double-charged from block execution.
|
||||
# See the module docstring on ``backend.copilot.rate_limit`` for the full
|
||||
# boundary.
|
||||
|
||||
BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
AIConversationBlock: LLM_COST,
|
||||
@@ -754,62 +714,6 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
},
|
||||
),
|
||||
],
|
||||
PerplexityBlock: [
|
||||
# Sonar Deep Research: up to $5/1K searches + $8/1M reasoning tokens.
|
||||
# Flat-charge 10 credits mirrors the LLM table's SONAR_DEEP_RESEARCH
|
||||
# entry. Block execution decrements only the user credit wallet via
|
||||
# spend_credits(); the microdollar rate-limit counter is not touched
|
||||
# for run_block invocations. The actual per-run provider spend is
|
||||
# recorded separately as provider_cost on PlatformCostLog when
|
||||
# OpenRouter reports usage.
|
||||
BlockCost(
|
||||
cost_amount=10,
|
||||
cost_filter={
|
||||
"model": PerplexityModel.SONAR_DEEP_RESEARCH,
|
||||
"credentials": {
|
||||
"id": open_router_credentials.id,
|
||||
"provider": open_router_credentials.provider,
|
||||
"type": open_router_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
# Sonar Pro: $1/1M input + $1/1M output + $0.005/search.
|
||||
BlockCost(
|
||||
cost_amount=5,
|
||||
cost_filter={
|
||||
"model": PerplexityModel.SONAR_PRO,
|
||||
"credentials": {
|
||||
"id": open_router_credentials.id,
|
||||
"provider": open_router_credentials.provider,
|
||||
"type": open_router_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
# Sonar (default): $0.2/1M input + $0.2/1M output + $0.005/search.
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"model": PerplexityModel.SONAR,
|
||||
"credentials": {
|
||||
"id": open_router_credentials.id,
|
||||
"provider": open_router_credentials.provider,
|
||||
"type": open_router_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
FactCheckerBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": jina_credentials.id,
|
||||
"provider": jina_credentials.provider,
|
||||
"type": jina_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
OrchestratorBlock: LLM_COST,
|
||||
VideoNarrationBlock: [
|
||||
BlockCost(
|
||||
@@ -823,151 +727,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
},
|
||||
)
|
||||
],
|
||||
# Mem0: Starter $19/mo for 50K adds + 5K retrievals → $0.0004/add,
|
||||
# $0.004/retrieval. Floor at 1 credit covers raw cost with margin.
|
||||
AddMemoryBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": mem0_credentials.id,
|
||||
"provider": mem0_credentials.provider,
|
||||
"type": mem0_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
SearchMemoryBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": mem0_credentials.id,
|
||||
"provider": mem0_credentials.provider,
|
||||
"type": mem0_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
GetAllMemoriesBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": mem0_credentials.id,
|
||||
"provider": mem0_credentials.provider,
|
||||
"type": mem0_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
GetLatestMemoryBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": mem0_credentials.id,
|
||||
"provider": mem0_credentials.provider,
|
||||
"type": mem0_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
# ScreenshotOne: $17 / 2K screenshots = $0.0085/call (Basic tier).
|
||||
ScreenshotWebPageBlock: [
|
||||
BlockCost(
|
||||
cost_amount=2,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": screenshotone_credentials.id,
|
||||
"provider": screenshotone_credentials.provider,
|
||||
"type": screenshotone_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
# NVIDIA NIM hosted endpoints: no public per-call SKU; estimate based on
|
||||
# peer deepfake APIs (Hive/Sightengine ~$0.005-0.01/call).
|
||||
NvidiaDeepfakeDetectBlock: [
|
||||
BlockCost(
|
||||
cost_amount=2,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": nvidia_credentials.id,
|
||||
"provider": nvidia_credentials.provider,
|
||||
"type": nvidia_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
# Smartlead: $39/mo Basic = $0.0065 per email-equivalent. Campaign
|
||||
# creation touches multiple records → 2 credits; per-lead and config
|
||||
# writes are lighter → 1 credit.
|
||||
CreateCampaignBlock: [
|
||||
BlockCost(
|
||||
cost_amount=2,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": smartlead_credentials.id,
|
||||
"provider": smartlead_credentials.provider,
|
||||
"type": smartlead_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
AddLeadToCampaignBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": smartlead_credentials.id,
|
||||
"provider": smartlead_credentials.provider,
|
||||
"type": smartlead_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
SaveCampaignSequencesBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": smartlead_credentials.id,
|
||||
"provider": smartlead_credentials.provider,
|
||||
"type": smartlead_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
# ZeroBounce: $16 / 2K validations = $0.008 per email. One email per call.
|
||||
ValidateEmailsBlock: [
|
||||
BlockCost(
|
||||
cost_amount=2,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": zerobounce_credentials.id,
|
||||
"provider": zerobounce_credentials.provider,
|
||||
"type": zerobounce_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
# ClaudeCodeBlock runs an E2B sandbox (~$0.00003/sec compute) AND
|
||||
# executes Claude Sonnet inside it. Real session cost is dominated by
|
||||
# the LLM and varies $0.50–$2 per typical run. Flat 100 credits ($1.00)
|
||||
# is a conservative-but-fair estimate; revisit once we expose the
|
||||
# x-total-cost header from the in-sandbox Claude calls back to
|
||||
# NodeExecutionStats.provider_cost.
|
||||
ClaudeCodeBlock: [
|
||||
BlockCost(
|
||||
cost_amount=100,
|
||||
cost_filter={
|
||||
"e2b_credentials": {
|
||||
"id": e2b_credentials.id,
|
||||
"provider": e2b_credentials.provider,
|
||||
"type": e2b_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
@@ -366,7 +366,7 @@ async def execute_node(
|
||||
|
||||
try:
|
||||
if execution_context.dry_run and _dry_run_input is None:
|
||||
block_iter = simulate_block(node_block, input_data, user_id=user_id)
|
||||
block_iter = simulate_block(node_block, input_data)
|
||||
else:
|
||||
block_iter = node_block.execute(input_data, **extra_exec_kwargs)
|
||||
|
||||
|
||||
@@ -31,31 +31,21 @@ Inspired by https://github.com/Significant-Gravitas/agent-simulator
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from openai.types import CompletionUsage
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.util.clients import get_openai_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Default simulator model — Gemini 2.5 Flash-Lite via OpenRouter. Same provider
|
||||
# as Flash ($0.10 / $0.40 per MTok vs $0.30 / $1.20 — ~3× cheaper) with JSON-mode
|
||||
# reliability that's more than enough for dry-run shape-matching. Configurable
|
||||
# via ChatConfig.simulation_model (CHAT_SIMULATION_MODEL env var).
|
||||
_DEFAULT_SIMULATOR_MODEL = "google/gemini-2.5-flash-lite"
|
||||
|
||||
# OpenRouter-specific extra_body flag that embeds the real generation cost on
|
||||
# the response usage object. Same shape used by the baseline copilot service
|
||||
# and web_search tool — keep the three aligned.
|
||||
_OPENROUTER_INCLUDE_USAGE_COST: dict[str, Any] = {"usage": {"include": True}}
|
||||
# Default simulator model — Gemini 2.5 Flash via OpenRouter (fast, cheap, good at
|
||||
# JSON generation). Configurable via ChatConfig.simulation_model
|
||||
# (CHAT_SIMULATION_MODEL env var).
|
||||
_DEFAULT_SIMULATOR_MODEL = "google/gemini-2.5-flash"
|
||||
|
||||
|
||||
def _simulator_model() -> str:
|
||||
@@ -115,15 +105,10 @@ async def _call_llm_for_simulation(
|
||||
user_prompt: str,
|
||||
*,
|
||||
label: str = "simulate",
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a simulation prompt to the LLM and return the parsed JSON dict.
|
||||
|
||||
Handles client acquisition, retries on invalid JSON, logging, and platform
|
||||
cost tracking. The dry-run simulator calls OpenRouter on the platform's
|
||||
key rather than a user's own API credentials, so every successful call is
|
||||
recorded against the triggering ``user_id``'s rate-limit counter via
|
||||
``persist_and_record_usage`` (same rails as every copilot turn).
|
||||
Handles client acquisition, retries on invalid JSON, and logging.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no LLM client is available.
|
||||
@@ -148,7 +133,6 @@ async def _call_llm_for_simulation(
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
extra_body=_OPENROUTER_INCLUDE_USAGE_COST,
|
||||
)
|
||||
if not response.choices:
|
||||
raise ValueError("LLM returned empty choices array")
|
||||
@@ -157,21 +141,13 @@ async def _call_llm_for_simulation(
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError(f"LLM returned non-object JSON: {raw[:200]}")
|
||||
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
logger.debug(
|
||||
"simulate(%s): attempt=%d tokens=%d/%d",
|
||||
label,
|
||||
attempt + 1,
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"simulate(%s): attempt=%d usage unavailable", label, attempt + 1
|
||||
)
|
||||
|
||||
await _track_simulator_cost(usage=usage, user_id=user_id, model=model)
|
||||
logger.debug(
|
||||
"simulate(%s): attempt=%d tokens=%s/%s",
|
||||
label,
|
||||
attempt + 1,
|
||||
getattr(getattr(response, "usage", None), "prompt_tokens", "?"),
|
||||
getattr(getattr(response, "usage", None), "completion_tokens", "?"),
|
||||
)
|
||||
return parsed
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
@@ -198,69 +174,6 @@ async def _call_llm_for_simulation(
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def _extract_cost_usd(usage: CompletionUsage | None) -> float | None:
|
||||
"""Return the provider-reported USD cost on the response usage object.
|
||||
|
||||
OpenRouter attaches a ``cost`` field to the OpenAI-compatible usage object
|
||||
when the request body includes ``usage: {"include": True}``. The typed
|
||||
``CompletionUsage`` does not declare it, so we read it off ``model_extra``
|
||||
(pydantic v2's container for extras) to keep access fully typed — no
|
||||
``getattr``. Mirrors ``backend.copilot.tools.web_search._extract_cost_usd``
|
||||
and ``backend.copilot.baseline.service._extract_usage_cost``; keep the
|
||||
three in sync.
|
||||
"""
|
||||
if usage is None:
|
||||
return None
|
||||
extras = usage.model_extra or {}
|
||||
if "cost" not in extras:
|
||||
return None
|
||||
raw = extras["cost"]
|
||||
if raw is None:
|
||||
logger.error("[simulator] usage.cost is present but null")
|
||||
return None
|
||||
try:
|
||||
val = float(raw)
|
||||
except (TypeError, ValueError):
|
||||
logger.error("[simulator] usage.cost is not numeric: %r", raw)
|
||||
return None
|
||||
if not math.isfinite(val) or val < 0:
|
||||
logger.error("[simulator] usage.cost is non-finite or negative: %r", val)
|
||||
return None
|
||||
return val
|
||||
|
||||
|
||||
async def _track_simulator_cost(
|
||||
*,
|
||||
usage: CompletionUsage | None,
|
||||
user_id: str | None,
|
||||
model: str,
|
||||
) -> None:
|
||||
"""Record platform cost for a single simulator LLM call.
|
||||
|
||||
The simulator runs outside a copilot ``ChatSession`` — pass ``session=None``
|
||||
so ``persist_and_record_usage`` skips the session append but still charges
|
||||
the user's rate-limit counter and writes a ``PlatformCostLog`` entry. No
|
||||
user_id means no tracking (e.g. in-process tests that don't plumb one
|
||||
through); rate-limit accounting silently no-ops in that case.
|
||||
"""
|
||||
if usage is None:
|
||||
return
|
||||
cost_usd = _extract_cost_usd(usage)
|
||||
try:
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id=user_id,
|
||||
prompt_tokens=usage.prompt_tokens,
|
||||
completion_tokens=usage.completion_tokens,
|
||||
log_prefix="[simulator]",
|
||||
cost_usd=cost_usd,
|
||||
model=model,
|
||||
provider="open_router",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("[simulator] usage tracking failed: %s", exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt builders
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -480,18 +393,12 @@ def _default_for_input_result(result_schema: dict[str, Any], name: str | None) -
|
||||
async def simulate_block(
|
||||
block: Any,
|
||||
input_data: dict[str, Any],
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Simulate block execution using an LLM.
|
||||
|
||||
All block types (including MCPToolBlock) use the same generic LLM prompt
|
||||
which includes the block's run() source code for accurate simulation.
|
||||
|
||||
``user_id`` is threaded through to platform cost tracking — every dry-run
|
||||
LLM call hits the platform's OpenRouter key and is charged against the
|
||||
triggering user's rate-limit counter, same rails as copilot turns.
|
||||
|
||||
Note: callers should check ``prepare_dry_run(block, input_data)`` first.
|
||||
OrchestratorBlock and AgentExecutorBlock execute for real in dry-run mode
|
||||
(see manager.py).
|
||||
@@ -555,9 +462,7 @@ async def simulate_block(
|
||||
label = getattr(block, "name", "?")
|
||||
|
||||
try:
|
||||
parsed = await _call_llm_for_simulation(
|
||||
system_prompt, user_prompt, label=label, user_id=user_id
|
||||
)
|
||||
parsed = await _call_llm_for_simulation(system_prompt, user_prompt, label=label)
|
||||
|
||||
# Track which pins were yielded so we can fill in missing required
|
||||
# ones afterwards — downstream nodes connected to unyielded pins
|
||||
|
||||
@@ -5,7 +5,6 @@ Covers:
|
||||
- Input/output block passthrough
|
||||
- prepare_dry_run routing
|
||||
- simulate_block output-pin filling
|
||||
- Default simulator model + OpenRouter cost tracking
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -14,14 +13,8 @@ from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
|
||||
from backend.executor.simulator import (
|
||||
_DEFAULT_SIMULATOR_MODEL,
|
||||
_extract_cost_usd,
|
||||
_truncate_input_values,
|
||||
_truncate_value,
|
||||
build_simulation_prompt,
|
||||
@@ -518,217 +511,3 @@ class TestSimulateBlockPassthrough:
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "error"
|
||||
assert "No client" in outputs[0][1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default model + OpenRouter cost tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sim_usage(
|
||||
*,
|
||||
prompt_tokens: int = 1200,
|
||||
completion_tokens: int = 300,
|
||||
cost: object = 0.000157,
|
||||
) -> CompletionUsage:
|
||||
"""Typed ``CompletionUsage`` carrying OpenRouter's ``cost`` extension
|
||||
via ``model_extra`` — same pattern as
|
||||
``copilot/tools/web_search_test.py::_usage``. ``model_construct``
|
||||
preserves unknown fields; ``model_validate`` would drop them."""
|
||||
payload: dict[str, Any] = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
}
|
||||
if cost is not None:
|
||||
payload["cost"] = cost
|
||||
return CompletionUsage.model_construct(None, **payload)
|
||||
|
||||
|
||||
def _sim_completion(*, content: str, usage: CompletionUsage) -> ChatCompletion:
|
||||
"""Typed ``ChatCompletion`` shaped like an OpenRouter simulator
|
||||
response so the production code runs under real SDK types."""
|
||||
message = ChatCompletionMessage.model_construct(
|
||||
None, role="assistant", content=content
|
||||
)
|
||||
choice = Choice.model_construct(
|
||||
None, index=0, finish_reason="stop", message=message
|
||||
)
|
||||
return ChatCompletion.model_construct(
|
||||
None,
|
||||
id="cmpl-sim",
|
||||
object="chat.completion",
|
||||
created=0,
|
||||
model=_DEFAULT_SIMULATOR_MODEL,
|
||||
choices=[choice],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
class TestDefaultSimulatorModel:
|
||||
"""Pin the default model — anyone flipping this without a cost review
|
||||
trips the test."""
|
||||
|
||||
def test_default_is_flash_lite(self) -> None:
|
||||
assert _DEFAULT_SIMULATOR_MODEL == "google/gemini-2.5-flash-lite"
|
||||
|
||||
|
||||
class TestExtractCostUsd:
|
||||
"""Provider-reported USD cost via typed ``model_extra`` — mirrors
|
||||
``copilot.tools.web_search._extract_cost_usd`` and
|
||||
``copilot.baseline.service._extract_usage_cost``."""
|
||||
|
||||
def test_returns_cost_value(self) -> None:
|
||||
assert _extract_cost_usd(_sim_usage(cost=0.000157)) == pytest.approx(0.000157)
|
||||
|
||||
def test_returns_none_when_usage_missing(self) -> None:
|
||||
assert _extract_cost_usd(None) is None
|
||||
|
||||
def test_returns_none_when_cost_field_missing(self) -> None:
|
||||
assert _extract_cost_usd(_sim_usage(cost=None)) is None
|
||||
|
||||
def test_returns_none_when_cost_is_explicit_null(self) -> None:
|
||||
usage = CompletionUsage.model_construct(
|
||||
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=None
|
||||
)
|
||||
assert _extract_cost_usd(usage) is None
|
||||
|
||||
def test_returns_none_when_cost_is_negative(self) -> None:
|
||||
usage = CompletionUsage.model_construct(
|
||||
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=-0.5
|
||||
)
|
||||
assert _extract_cost_usd(usage) is None
|
||||
|
||||
def test_accepts_numeric_string(self) -> None:
|
||||
usage = CompletionUsage.model_construct(
|
||||
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost="0.017"
|
||||
)
|
||||
assert _extract_cost_usd(usage) == pytest.approx(0.017)
|
||||
|
||||
|
||||
class TestSimulatorCostTracking:
    """Integration: mock the OpenAI client, confirm the simulator sends
    the flash-lite default + extra_body, then plumbs through to
    ``persist_and_record_usage`` with ``provider='open_router'`` and the
    real ``usage.cost`` pulled off ``model_extra``."""

    def _mock_client(self, fake_resp: ChatCompletion) -> tuple[Any, AsyncMock]:
        """Build a fake ``AsyncOpenAI`` client. Same nested-type pattern as
        ``copilot/tools/web_search_test.py::_mock_client`` — avoids
        MagicMock's auto-child-attr behaviour so the exact ``create`` call
        surface is what gets invoked.

        Returns the fake client together with the ``AsyncMock`` standing in
        for ``client.chat.completions.create`` so callers can inspect the
        arguments it was awaited with.
        """
        create_mock = AsyncMock(return_value=fake_resp)
        # Build the chain innermost-first: completions -> chat -> client.
        completions = type("CC", (), {"create": create_mock})()
        chat = type("C", (), {"completions": completions})()
        client = type("MC", (), {"chat": chat})()
        return client, create_mock

    @pytest.mark.asyncio
    async def test_passes_default_model_and_tracks_cost(self) -> None:
        """The simulator requests the default model with usage reporting
        enabled and forwards tokens and cost to the usage tracker."""
        block = _make_block()
        fake_resp = _sim_completion(
            content='{"result": "simulated"}',
            usage=_sim_usage(prompt_tokens=1100, completion_tokens=220, cost=0.000189),
        )
        client, create_mock = self._mock_client(fake_resp)

        client_patch = patch(
            "backend.executor.simulator.get_openai_client",
            return_value=client,
        )
        tracker_patch = patch(
            "backend.executor.simulator.persist_and_record_usage",
            new=AsyncMock(return_value=1320),
        )
        with client_patch, tracker_patch as mock_track:
            outputs = [
                (name, data)
                async for name, data in simulate_block(
                    block, {"query": "hello"}, user_id="user-42"
                )
            ]

        assert ("result", "simulated") in outputs

        # Fail with a clear assertion (not an AttributeError on ``None``)
        # if the completion endpoint was never awaited.
        create_mock.assert_awaited()
        create_kwargs = create_mock.await_args.kwargs
        assert create_kwargs["model"] == _DEFAULT_SIMULATOR_MODEL
        assert create_kwargs["extra_body"] == {"usage": {"include": True}}

        # Same guard for the usage tracker before inspecting its kwargs.
        mock_track.assert_awaited()
        track_kwargs = mock_track.await_args.kwargs
        assert track_kwargs["provider"] == "open_router"
        assert track_kwargs["model"] == _DEFAULT_SIMULATOR_MODEL
        assert track_kwargs["user_id"] == "user-42"
        assert track_kwargs["prompt_tokens"] == 1100
        assert track_kwargs["completion_tokens"] == 220
        assert track_kwargs["cost_usd"] == pytest.approx(0.000189)
        assert track_kwargs["session"] is None
        assert track_kwargs["log_prefix"] == "[simulator]"

    @pytest.mark.asyncio
    async def test_tracks_even_when_cost_absent(self) -> None:
        """Provider may omit ``cost`` (e.g. non-OpenRouter proxies). We
        still record token counts — ``persist_and_record_usage`` logs the
        turn and skips the rate-limit ledger when cost is ``None``."""
        block = _make_block()
        fake_resp = _sim_completion(
            content='{"result": "ok"}',
            usage=_sim_usage(prompt_tokens=100, completion_tokens=20, cost=None),
        )
        client, _ = self._mock_client(fake_resp)

        client_patch = patch(
            "backend.executor.simulator.get_openai_client",
            return_value=client,
        )
        tracker_patch = patch(
            "backend.executor.simulator.persist_and_record_usage",
            new=AsyncMock(return_value=120),
        )
        with client_patch, tracker_patch as mock_track:
            # Drain the generator; only the tracker call matters here.
            async for _name, _data in simulate_block(
                block, {"query": "x"}, user_id="user-7"
            ):
                pass

        # Make a missing tracker call fail as an assertion rather than as
        # an AttributeError on ``await_args`` being ``None``.
        mock_track.assert_awaited()
        track_kwargs = mock_track.await_args.kwargs
        assert track_kwargs["cost_usd"] is None
        assert track_kwargs["user_id"] == "user-7"
        assert track_kwargs["provider"] == "open_router"

    @pytest.mark.asyncio
    async def test_tracking_failure_does_not_break_simulation(self) -> None:
        """Cost-tracking failures are warnings, not simulation failures —
        the block output must still flow to the caller."""
        block = _make_block()
        fake_resp = _sim_completion(
            content='{"result": "simulated"}',
            usage=_sim_usage(),
        )
        client, _ = self._mock_client(fake_resp)

        client_patch = patch(
            "backend.executor.simulator.get_openai_client",
            return_value=client,
        )
        # The tracker raises on every await to simulate a backend outage.
        # NOTE(review): this test does not assert that the tracker was
        # actually awaited, so it would also pass if tracking were skipped
        # entirely — consider binding the mock and asserting on it.
        failing_tracker_patch = patch(
            "backend.executor.simulator.persist_and_record_usage",
            new=AsyncMock(side_effect=RuntimeError("redis down")),
        )
        with client_patch, failing_tracker_patch:
            outputs = [
                (name, data)
                async for name, data in simulate_block(
                    block, {"query": "hello"}, user_id="user-42"
                )
            ]

        assert ("result", "simulated") in outputs
|
||||
|
||||
@@ -313,19 +313,11 @@ function getWebAccordionData(
|
||||
: null;
|
||||
|
||||
if (results) {
|
||||
const deep = inp.deep === true;
|
||||
const noun = deep ? "research source" : "search result";
|
||||
const answer = getStringField(output, "answer");
|
||||
return {
|
||||
title: `${results.length} ${noun}${results.length === 1 ? "" : "s"}`,
|
||||
title: `${results.length} search result${results.length === 1 ? "" : "s"}`,
|
||||
description: query ? truncate(query, 80) : undefined,
|
||||
content: (
|
||||
<div className="space-y-3">
|
||||
{answer && (
|
||||
<div className="whitespace-pre-wrap rounded-md bg-slate-50 p-3 text-sm text-slate-800">
|
||||
{answer}
|
||||
</div>
|
||||
)}
|
||||
{results.map((r, i) => {
|
||||
const title = getStringField(r, "title") ?? "(untitled)";
|
||||
const href = getStringField(r, "url") ?? "";
|
||||
|
||||
@@ -141,7 +141,6 @@ describe("GenericTool", () => {
|
||||
function makeWebSearchPart(
|
||||
results: Array<Record<string, unknown>>,
|
||||
query = "kimi k2.6",
|
||||
answer = "",
|
||||
): ToolUIPart {
|
||||
return {
|
||||
type: "tool-web_search",
|
||||
@@ -150,7 +149,6 @@ describe("GenericTool", () => {
|
||||
input: { query },
|
||||
output: {
|
||||
type: "web_search_response",
|
||||
answer,
|
||||
results,
|
||||
query,
|
||||
search_requests: 1,
|
||||
@@ -256,25 +254,6 @@ describe("GenericTool", () => {
|
||||
expect(normalized).toContain('Searched "kimi k2.6"');
|
||||
});
|
||||
|
||||
it("renders the synthesised answer above the citations when present", () => {
|
||||
render(
|
||||
<GenericTool
|
||||
part={makeWebSearchPart(
|
||||
[
|
||||
{ title: "Citation 1", url: "https://example.com/one" },
|
||||
{ title: "Citation 2", url: "https://example.com/two" },
|
||||
],
|
||||
"kimi k2.6 launch",
|
||||
"Kimi K2.6 launched on 2026-04-20 with SWE-Bench parity to Opus.",
|
||||
)}
|
||||
/>,
|
||||
);
|
||||
fireEvent.click(screen.getByRole("button", { expanded: false }));
|
||||
expect(
|
||||
screen.getByText(/Kimi K2\.6 launched on 2026-04-20/),
|
||||
).not.toBeNull();
|
||||
});
|
||||
|
||||
it("uses '(untitled)' when a search result has no title", () => {
|
||||
render(
|
||||
<GenericTool
|
||||
|
||||
@@ -205,14 +205,6 @@ export function humanizeFileName(filePath: string): string {
|
||||
/* Animation text */
|
||||
/* ------------------------------------------------------------------ */
|
||||
|
||||
// web_search accepts a ``deep`` arg that dispatches to a multi-step
// research model; render a distinct verb ("Researching"/"Researched"/
// "Research failed") so users know the call takes longer.
// Only a literal boolean ``true`` counts as deep; a missing input or any
// other value yields false.
function _isDeepWebSearch(part: ToolUIPart): boolean {
  const deepFlag = (part.input as Record<string, unknown> | undefined)?.deep;
  return deepFlag === true;
}
|
||||
|
||||
export function getAnimationText(
|
||||
part: ToolUIPart,
|
||||
category: ToolCategory,
|
||||
@@ -231,11 +223,9 @@ export function getAnimationText(
|
||||
: "Running command\u2026";
|
||||
case "web":
|
||||
if (toolName === "WebSearch" || toolName === "web_search") {
|
||||
const deep = _isDeepWebSearch(part);
|
||||
const verb = deep ? "Researching" : "Searching";
|
||||
return shortSummary
|
||||
? `${verb} "${shortSummary}"`
|
||||
: `${verb} the web\u2026`;
|
||||
? `Searching "${shortSummary}"`
|
||||
: "Searching the web\u2026";
|
||||
}
|
||||
return shortSummary
|
||||
? `Fetching ${shortSummary}`
|
||||
@@ -295,12 +285,9 @@ export function getAnimationText(
|
||||
return shortSummary ? `Ran: ${shortSummary}` : "Command completed";
|
||||
case "web":
|
||||
if (toolName === "WebSearch" || toolName === "web_search") {
|
||||
const deep = _isDeepWebSearch(part);
|
||||
const verb = deep ? "Researched" : "Searched";
|
||||
const completed = deep
|
||||
? "Web research completed"
|
||||
return shortSummary
|
||||
? `Searched "${shortSummary}"`
|
||||
: "Web search completed";
|
||||
return shortSummary ? `${verb} "${shortSummary}"` : completed;
|
||||
}
|
||||
return shortSummary
|
||||
? `Fetched ${shortSummary}`
|
||||
@@ -367,10 +354,9 @@ export function getAnimationText(
|
||||
case "bash":
|
||||
return "Command failed";
|
||||
case "web":
|
||||
if (toolName === "WebSearch" || toolName === "web_search") {
|
||||
return _isDeepWebSearch(part) ? "Research failed" : "Search failed";
|
||||
}
|
||||
return "Fetch failed";
|
||||
return toolName === "WebSearch" || toolName === "web_search"
|
||||
? "Search failed"
|
||||
: "Fetch failed";
|
||||
case "browser":
|
||||
return "Browser action failed";
|
||||
default:
|
||||
|
||||
@@ -2116,9 +2116,7 @@
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/MemoryForgetConfirmResponse"
|
||||
},
|
||||
{ "$ref": "#/components/schemas/TodoWriteResponse" },
|
||||
{ "$ref": "#/components/schemas/TaskResponse" }
|
||||
}
|
||||
],
|
||||
"title": "Response Getv2[Dummy] Tool Response Type Export For Codegen"
|
||||
}
|
||||
@@ -14572,9 +14570,7 @@
|
||||
"memory_store",
|
||||
"memory_search",
|
||||
"memory_forget_candidates",
|
||||
"memory_forget_confirm",
|
||||
"todo_write",
|
||||
"task"
|
||||
"memory_forget_confirm"
|
||||
],
|
||||
"title": "ResponseType",
|
||||
"description": "Types of tool responses."
|
||||
@@ -16070,55 +16066,6 @@
|
||||
"required": ["recent_searches", "providers", "top_blocks"],
|
||||
"title": "SuggestionsResponse"
|
||||
},
|
||||
"TaskResponse": {
|
||||
"properties": {
|
||||
"type": {
|
||||
"$ref": "#/components/schemas/ResponseType",
|
||||
"default": "task"
|
||||
},
|
||||
"message": { "type": "string", "title": "Message" },
|
||||
"session_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Session Id"
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"title": "Description",
|
||||
"default": ""
|
||||
},
|
||||
"response": {
|
||||
"type": "string",
|
||||
"title": "Response",
|
||||
"description": "Final assistant text the sub-agent produced.",
|
||||
"default": ""
|
||||
},
|
||||
"iterations": {
|
||||
"type": "integer",
|
||||
"title": "Iterations",
|
||||
"default": 0
|
||||
},
|
||||
"tool_calls": {
|
||||
"items": { "type": "string" },
|
||||
"type": "array",
|
||||
"title": "Tool Calls",
|
||||
"description": "Names of tools the sub-agent invoked (for observability)."
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["completed", "max_iterations", "error"],
|
||||
"title": "Status",
|
||||
"default": "completed"
|
||||
},
|
||||
"error": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Error"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["message"],
|
||||
"title": "TaskResponse",
|
||||
"description": "Result of a delegated ``Task`` in-process sub-agent run.\n\nThe sub-agent runs a fresh tool-call loop with an isolated message\nhistory, then returns only its final assistant text. Intermediate tool\ncalls and thinking stay inside the sub-agent's loop so the parent\ncontext is not polluted."
|
||||
},
|
||||
"TimezoneResponse": {
|
||||
"properties": {
|
||||
"timezone": {
|
||||
@@ -16735,52 +16682,6 @@
|
||||
"required": ["timezone"],
|
||||
"title": "TimezoneResponse"
|
||||
},
|
||||
"TodoItem": {
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"title": "Content",
|
||||
"description": "Imperative description of the task."
|
||||
},
|
||||
"activeForm": {
|
||||
"type": "string",
|
||||
"title": "Activeform",
|
||||
"description": "Present-continuous form shown while the task is running."
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "in_progress", "completed"],
|
||||
"title": "Status",
|
||||
"default": "pending"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["content", "activeForm"],
|
||||
"title": "TodoItem",
|
||||
"description": "One entry in a ``TodoWrite`` checklist.\n\nMirrors the schema used by Claude Code's built-in ``TodoWrite`` tool so\nthe frontend's ``GenericTool`` accordion renders baseline-emitted todos\nidentically to SDK-emitted ones."
|
||||
},
|
||||
"TodoWriteResponse": {
|
||||
"properties": {
|
||||
"type": {
|
||||
"$ref": "#/components/schemas/ResponseType",
|
||||
"default": "todo_write"
|
||||
},
|
||||
"message": { "type": "string", "title": "Message" },
|
||||
"session_id": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Session Id"
|
||||
},
|
||||
"todos": {
|
||||
"items": { "$ref": "#/components/schemas/TodoItem" },
|
||||
"type": "array",
|
||||
"title": "Todos"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["message"],
|
||||
"title": "TodoWriteResponse",
|
||||
"description": "Ack returned by ``TodoWrite``.\n\nThe tool is effectively stateless — the authoritative task list lives in\nthe assistant's latest tool-call arguments, which are replayed from the\ntranscript on each turn. The tool output only needs to confirm that the\nupdate was accepted so the model can proceed."
|
||||
},
|
||||
"TokenIntrospectionResult": {
|
||||
"properties": {
|
||||
"active": { "type": "boolean", "title": "Active" },
|
||||
|
||||
@@ -58,7 +58,7 @@ Tool and block identifiers provided in `tools` and `blocks` are validated at run
|
||||
| system_context | Optional additional context prepended to the prompt. Use this to constrain autopilot behavior, provide domain context, or set output format requirements. | str | No |
|
||||
| session_id | Session ID to continue an existing autopilot conversation. Leave empty to start a new session. Use the session_id output from a previous run to continue. | str | No |
|
||||
| max_recursion_depth | Maximum nesting depth when the autopilot calls this block recursively (sub-agent pattern). Prevents infinite loops. | int | No |
|
||||
| tools | Tool names to filter. Works with tools_exclude to form an allow-list or deny-list. Leave empty to apply no tool filter. | List["add_understanding" \| "ask_question" \| "bash_exec" \| "browser_act" \| "browser_navigate" \| "browser_screenshot" \| "connect_integration" \| "continue_run_block" \| "create_agent" \| "create_feature_request" \| "create_folder" \| "customize_agent" \| "delete_folder" \| "delete_workspace_file" \| "edit_agent" \| "find_agent" \| "find_block" \| "find_library_agent" \| "fix_agent_graph" \| "get_agent_building_guide" \| "get_doc_page" \| "get_mcp_guide" \| "get_sub_session_result" \| "list_folders" \| "list_workspace_files" \| "memory_forget_confirm" \| "memory_forget_search" \| "memory_search" \| "memory_store" \| "move_agents_to_folder" \| "move_folder" \| "read_workspace_file" \| "run_agent" \| "run_block" \| "run_mcp_tool" \| "run_sub_session" \| "search_docs" \| "search_feature_requests" \| "update_folder" \| "validate_agent_graph" \| "view_agent_output" \| "web_fetch" \| "web_search" \| "write_workspace_file" \| "Agent" \| "Edit" \| "Glob" \| "Grep" \| "Read" \| "Task" \| "TodoWrite" \| "WebSearch" \| "Write"] | No |
|
||||
| tools | Tool names to filter. Works with tools_exclude to form an allow-list or deny-list. Leave empty to apply no tool filter. | List["add_understanding" \| "ask_question" \| "bash_exec" \| "browser_act" \| "browser_navigate" \| "browser_screenshot" \| "connect_integration" \| "continue_run_block" \| "create_agent" \| "create_feature_request" \| "create_folder" \| "customize_agent" \| "delete_folder" \| "delete_workspace_file" \| "edit_agent" \| "find_agent" \| "find_block" \| "find_library_agent" \| "fix_agent_graph" \| "get_agent_building_guide" \| "get_doc_page" \| "get_mcp_guide" \| "get_sub_session_result" \| "list_folders" \| "list_workspace_files" \| "memory_forget_confirm" \| "memory_forget_search" \| "memory_search" \| "memory_store" \| "move_agents_to_folder" \| "move_folder" \| "read_workspace_file" \| "run_agent" \| "run_block" \| "run_mcp_tool" \| "run_sub_session" \| "search_docs" \| "search_feature_requests" \| "update_folder" \| "validate_agent_graph" \| "view_agent_output" \| "web_fetch" \| "write_workspace_file" \| "Agent" \| "Edit" \| "Glob" \| "Grep" \| "Read" \| "Task" \| "TodoWrite" \| "WebSearch" \| "Write"] | No |
|
||||
| tools_exclude | Controls how the 'tools' list is interpreted. True (default): 'tools' is a deny-list — listed tools are blocked, all others are allowed. An empty 'tools' list means allow everything. False: 'tools' is an allow-list — only listed tools are permitted. | bool | No |
|
||||
| blocks | Block identifiers to filter when the copilot uses run_block. Each entry can be: a block name (e.g. 'HTTP Request'), a full block UUID, or the first 8 hex characters of the UUID (e.g. 'c069dc6b'). Works with blocks_exclude. Leave empty to apply no block filter. | List[str] | No |
|
||||
| blocks_exclude | Controls how the 'blocks' list is interpreted. True (default): 'blocks' is a deny-list — listed blocks are blocked, all others are allowed. An empty 'blocks' list means allow everything. False: 'blocks' is an allow-list — only listed blocks are permitted. | bool | No |
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 110 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 116 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 100 KiB |
Reference in New Issue
Block a user