mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Merge remote-tracking branch 'origin/dev' into feat/task-decomposition-copilot
# Conflicts: # autogpt_platform/backend/backend/copilot/tools/tool_schema_test.py
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
"""Extended-thinking wire support for the baseline (OpenRouter) path.
|
||||
|
||||
Anthropic routes on OpenRouter expose extended thinking through
|
||||
non-OpenAI extension fields that the OpenAI Python SDK doesn't model:
|
||||
OpenRouter routes that support extended thinking (Anthropic Claude and
|
||||
Moonshot Kimi today) expose reasoning through non-OpenAI extension fields
|
||||
that the OpenAI Python SDK doesn't model:
|
||||
|
||||
* ``reasoning`` (legacy string) — enabled by ``include_reasoning: true``.
|
||||
* ``reasoning_content`` — DeepSeek / some OpenRouter routes.
|
||||
@@ -17,12 +18,14 @@ This module keeps the wire-level concerns in one place:
|
||||
one streaming round and emits ``StreamReasoning*`` events so the caller
|
||||
only has to plumb the events into its pending queue.
|
||||
* :func:`reasoning_extra_body` builds the ``extra_body`` fragment for the
|
||||
OpenAI client call. Returns ``None`` on non-Anthropic routes.
|
||||
OpenAI client call. Returns ``None`` for routes without reasoning
|
||||
support (see :func:`_is_reasoning_route`).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
@@ -42,6 +45,19 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"})
|
||||
|
||||
# Coalescing thresholds for ``StreamReasoningDelta`` emission. OpenRouter's
|
||||
# Kimi K2.6 endpoint tokenises reasoning at a much finer grain than Anthropic
|
||||
# (~4,700 deltas per turn in one observed session, vs ~28 for Sonnet); without
|
||||
# coalescing, every chunk is one Redis ``xadd`` + one SSE frame + one React
|
||||
# re-render of the non-virtualised chat list, which paint-storms the browser
|
||||
# main thread and freezes the UI. Batching into ~32-char / ~40 ms windows
|
||||
# cuts the event rate ~100x while staying snappy enough that the Reasoning
|
||||
# collapse still feels live (well under the ~100 ms perceptual threshold).
|
||||
# Per-delta persistence to ``session.messages`` stays granular — we only
|
||||
# coalesce the *wire* emission.
|
||||
_COALESCE_MIN_CHARS = 32
|
||||
_COALESCE_MAX_INTERVAL_MS = 40.0
|
||||
|
||||
|
||||
class ReasoningDetail(BaseModel):
|
||||
"""One entry in OpenRouter's ``reasoning_details`` list.
|
||||
@@ -132,18 +148,72 @@ class OpenRouterDeltaExtension(BaseModel):
|
||||
return "".join(d.visible_text for d in self.reasoning_details)
|
||||
|
||||
|
||||
def _is_reasoning_route(model: str) -> bool:
|
||||
"""Return True when the route supports OpenRouter's ``reasoning`` extension.
|
||||
|
||||
OpenRouter exposes reasoning tokens via a unified ``reasoning`` request
|
||||
param that works on any provider that supports extended thinking —
|
||||
currently Anthropic (Claude Opus / Sonnet) and Moonshot (Kimi K2.6 +
|
||||
kimi-k2-thinking) advertise it in their ``supported_parameters``.
|
||||
Other providers silently drop the field, but we skip it anyway to keep
|
||||
the payload tight and avoid confusing cache diagnostics.
|
||||
|
||||
Kept separate from :func:`backend.copilot.baseline.service._is_anthropic_model`
|
||||
because ``cache_control`` is strictly Anthropic-specific (Moonshot does
|
||||
its own auto-caching), so the two gates must not conflate.
|
||||
|
||||
Both the Claude and Kimi matches are anchored to the provider
|
||||
prefix (or to a bare model id with no prefix at all) to avoid
|
||||
substring false positives — a custom ``some-other-provider/claude-mock``
|
||||
or ``provider/hakimi-large`` configured via
|
||||
``CHAT_FAST_STANDARD_MODEL`` must NOT inherit the reasoning
|
||||
extra_body and take a 400 from its upstream. Recognised shapes:
|
||||
|
||||
* Claude — ``anthropic/`` or ``anthropic.`` provider prefix, or a
|
||||
bare ``claude-`` model id with no provider prefix
|
||||
(``claude-opus-4.7``, ``anthropic/claude-sonnet-4-6``,
|
||||
``anthropic.claude-3-5-sonnet``). A non-Anthropic prefix like
|
||||
``someprovider/claude-mock`` is rejected on purpose.
|
||||
* Kimi — ``moonshotai/`` provider prefix, or a ``kimi-`` model id
|
||||
with no provider prefix (``kimi-k2.6``,
|
||||
``moonshotai/kimi-k2-thinking``). Like Claude, a non-Moonshot
|
||||
prefix is rejected — exception: ``openrouter/kimi-k2.6`` stays
|
||||
recognised because ``openrouter/`` is how we route to Moonshot
|
||||
today and changing that would be a behaviour regression for
|
||||
existing deployments.
|
||||
"""
|
||||
lowered = model.lower()
|
||||
if lowered.startswith(("anthropic/", "anthropic.")):
|
||||
return True
|
||||
if lowered.startswith("moonshotai/"):
|
||||
return True
|
||||
# ``openrouter/`` historically routes to whatever the default
|
||||
# upstream for the model is — for kimi that's Moonshot, so accept
|
||||
# ``openrouter/kimi-...`` here. Other ``openrouter/`` models
|
||||
# (e.g. ``openrouter/auto``) fall through to the no-prefix check
|
||||
# below and are rejected unless they start with ``claude-`` /
|
||||
# ``kimi-`` after the slash, which no real OpenRouter route does.
|
||||
if lowered.startswith("openrouter/kimi-"):
|
||||
return True
|
||||
if "/" in lowered:
|
||||
# Any other provider prefix is a custom / non-Anthropic /
|
||||
# non-Moonshot route and must not opt into reasoning. This
|
||||
# blocks substring false positives like
|
||||
# ``some-provider/claude-mock-v1`` or ``other/kimi-pro``.
|
||||
return False
|
||||
# No provider prefix — accept bare ``claude-*`` and ``kimi-*`` ids
|
||||
# so direct CLI configs (``claude-3-5-sonnet-20241022``,
|
||||
# ``kimi-k2-instruct``) keep working.
|
||||
return lowered.startswith("claude-") or lowered.startswith("kimi-")
|
||||
|
||||
|
||||
def reasoning_extra_body(model: str, max_thinking_tokens: int) -> dict[str, Any] | None:
|
||||
"""Build the ``extra_body["reasoning"]`` fragment for the OpenAI client.
|
||||
|
||||
Returns ``None`` for non-Anthropic routes (other OpenRouter providers
|
||||
ignore the field but we skip it anyway to keep the payload minimal)
|
||||
and for ``max_thinking_tokens <= 0`` (operator kill switch).
|
||||
Returns ``None`` for non-reasoning routes and for
|
||||
``max_thinking_tokens <= 0`` (operator kill switch).
|
||||
"""
|
||||
# Imported lazily to avoid pulling service.py at module load — service.py
|
||||
# imports this module, and the lazy import keeps the dependency one-way.
|
||||
from backend.copilot.baseline.service import _is_anthropic_model
|
||||
|
||||
if not _is_anthropic_model(model) or max_thinking_tokens <= 0:
|
||||
if not _is_reasoning_route(model) or max_thinking_tokens <= 0:
|
||||
return None
|
||||
return {"reasoning": {"max_tokens": max_thinking_tokens}}
|
||||
|
||||
@@ -177,11 +247,24 @@ class BaselineReasoningEmitter:
|
||||
def __init__(
    self,
    session_messages: list[ChatMessage] | None = None,
    *,
    coalesce_min_chars: int = _COALESCE_MIN_CHARS,
    coalesce_max_interval_ms: float = _COALESCE_MAX_INTERVAL_MS,
) -> None:
    """Set up block, persistence and coalescing state for one emitter.

    Args:
        session_messages: Optional list that reasoning rows are appended
            to; ``None`` leaves persistence off.
        coalesce_min_chars: Buffered-character count at which pending
            reasoning text is flushed to the wire.
        coalesce_max_interval_ms: Milliseconds since the last flush after
            which pending text is flushed regardless of size.
    """
    # Persistence target and the row currently being appended to.
    self._session_messages = session_messages
    self._current_row: ChatMessage | None = None

    # Wire-level block state: no block is open yet, but an id is
    # pre-allocated for the first one.
    self._open: bool = False
    self._block_id: str = str(uuid.uuid4())

    # Coalescing knobs. Fine-grained providers (Kimi K2.6) emit very small
    # chunks; batching them cuts Redis ``xadd`` + SSE + React re-render
    # load by ~100x for the same text. They are keyword arguments so tests
    # can pass ``0`` to disable coalescing and get deterministic events.
    self._coalesce_min_chars = coalesce_min_chars
    self._coalesce_max_interval_ms = coalesce_max_interval_ms

    # Text accumulated since the last wire flush, and when that flush
    # happened on the monotonic clock.
    self._pending_delta: str = ""
    self._last_flush_monotonic: float = 0.0
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
@@ -192,39 +275,86 @@ class BaselineReasoningEmitter:
|
||||
|
||||
Empty list when the chunk carries no reasoning payload, so this is
|
||||
safe to call on every chunk without guarding at the call site.
|
||||
Persistence (when a session message list is attached) happens in
|
||||
lockstep with emission so the row's content stays equal to the
|
||||
concatenated deltas at every delta boundary.
|
||||
|
||||
Persistence (when a session message list is attached) stays
|
||||
per-delta so the DB row's content always equals the concatenation
|
||||
of wire deltas at every chunk boundary, independent of the
|
||||
coalescing window. Only the wire emission is batched.
|
||||
"""
|
||||
ext = OpenRouterDeltaExtension.from_delta(delta)
|
||||
text = ext.visible_text()
|
||||
if not text:
|
||||
return []
|
||||
events: list[StreamBaseResponse] = []
|
||||
# First reasoning text in this block — emit Start + the first Delta
|
||||
# atomically so the frontend Reasoning collapse renders immediately
|
||||
# rather than waiting for the coalesce window to elapse. Subsequent
|
||||
# chunks buffer into ``_pending_delta`` and only flush when the
|
||||
# char/time thresholds trip.
|
||||
# Sample the monotonic clock exactly once per chunk — at ~4,700
|
||||
# chunks per turn, folding the two calls into one cuts ~4,700
|
||||
# syscalls off the hot path without changing semantics.
|
||||
now = time.monotonic()
|
||||
if not self._open:
|
||||
events.append(StreamReasoningStart(id=self._block_id))
|
||||
events.append(StreamReasoningDelta(id=self._block_id, delta=text))
|
||||
self._open = True
|
||||
self._last_flush_monotonic = now
|
||||
if self._session_messages is not None:
|
||||
self._current_row = ChatMessage(role="reasoning", content="")
|
||||
self._current_row = ChatMessage(role="reasoning", content=text)
|
||||
self._session_messages.append(self._current_row)
|
||||
events.append(StreamReasoningDelta(id=self._block_id, delta=text))
|
||||
return events
|
||||
|
||||
# Persist per-delta (no coalescing here — the session snapshot stays
|
||||
# consistent at every chunk boundary, independent of the wire
|
||||
# coalesce window).
|
||||
if self._current_row is not None:
|
||||
self._current_row.content = (self._current_row.content or "") + text
|
||||
|
||||
self._pending_delta += text
|
||||
if self._should_flush_pending(now):
|
||||
events.append(
|
||||
StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
|
||||
)
|
||||
self._pending_delta = ""
|
||||
self._last_flush_monotonic = now
|
||||
return events
|
||||
|
||||
def _should_flush_pending(self, now: float) -> bool:
|
||||
"""Return True when the accumulated delta should be emitted now.
|
||||
|
||||
*now* is the monotonic timestamp sampled by the caller so the
|
||||
clock is read at most once per chunk (the flush-timestamp update
|
||||
reuses the same value).
|
||||
"""
|
||||
if not self._pending_delta:
|
||||
return False
|
||||
if len(self._pending_delta) >= self._coalesce_min_chars:
|
||||
return True
|
||||
elapsed_ms = (now - self._last_flush_monotonic) * 1000.0
|
||||
return elapsed_ms >= self._coalesce_max_interval_ms
|
||||
|
||||
def close(self) -> list[StreamBaseResponse]:
    """Emit ``StreamReasoningEnd`` for the open block (if any) and rotate.

    Idempotent — returns ``[]`` when no block is open. Drains any
    still-buffered delta first so the frontend never loses tail text
    from the coalesce window. The id rotation guarantees the next
    reasoning block starts with a fresh id rather than reusing one
    already closed on the wire. The persisted row is not removed —
    it stays in ``session_messages`` as the durable record of what
    was reasoned.

    Returns:
        The events to put on the wire, in order: an optional final
        ``StreamReasoningDelta`` carrying the buffered tail, then the
        ``StreamReasoningEnd`` for the block.
    """
    if not self._open:
        return []

    events: list[StreamBaseResponse] = []
    # Flush the coalesce buffer before End; returning only the End event
    # would silently drop whatever text was still waiting in the window.
    if self._pending_delta:
        events.append(
            StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
        )
        self._pending_delta = ""
    events.append(StreamReasoningEnd(id=self._block_id))

    # Reset for the next block: closed, fresh id, detached from the row.
    self._open = False
    self._block_id = str(uuid.uuid4())
    self._current_row = None
    return events
|
||||
|
||||
@@ -12,6 +12,7 @@ from backend.copilot.baseline.reasoning import (
|
||||
BaselineReasoningEmitter,
|
||||
OpenRouterDeltaExtension,
|
||||
ReasoningDetail,
|
||||
_is_reasoning_route,
|
||||
reasoning_extra_body,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage
|
||||
@@ -135,6 +136,59 @@ class TestOpenRouterDeltaExtension:
|
||||
assert ext.visible_text() == "real"
|
||||
|
||||
|
||||
class TestIsReasoningRoute:
|
||||
def test_anthropic_routes(self):
|
||||
assert _is_reasoning_route("anthropic/claude-sonnet-4-6")
|
||||
assert _is_reasoning_route("claude-3-5-sonnet-20241022")
|
||||
assert _is_reasoning_route("anthropic.claude-3-5-sonnet")
|
||||
assert _is_reasoning_route("ANTHROPIC/Claude-Opus") # case-insensitive
|
||||
|
||||
def test_moonshot_kimi_routes(self):
|
||||
# OpenRouter advertises the ``reasoning`` extension on Moonshot
|
||||
# endpoints — both K2.6 (the new baseline default) and the
|
||||
# reasoning-native kimi-k2-thinking variant.
|
||||
assert _is_reasoning_route("moonshotai/kimi-k2.6")
|
||||
assert _is_reasoning_route("moonshotai/kimi-k2-thinking")
|
||||
assert _is_reasoning_route("moonshotai/kimi-k2.5")
|
||||
# Direct (non-OpenRouter) model ids also resolve via the ``kimi-``
|
||||
# prefix so a future bare ``kimi-k3`` id would still match.
|
||||
assert _is_reasoning_route("kimi-k2-instruct")
|
||||
# ``openrouter/kimi-*`` is the one provider-prefixed shape accepted
|
||||
# without ``moonshotai/`` — it is special-cased, not matched on the
|
||||
# final path segment (``other/kimi-pro`` is rejected below).
|
||||
assert _is_reasoning_route("openrouter/kimi-k2.6")
|
||||
|
||||
def test_other_providers_rejected(self):
|
||||
assert not _is_reasoning_route("openai/gpt-4o")
|
||||
assert not _is_reasoning_route("google/gemini-2.5-pro")
|
||||
assert not _is_reasoning_route("xai/grok-4")
|
||||
assert not _is_reasoning_route("meta-llama/llama-3.3-70b-instruct")
|
||||
assert not _is_reasoning_route("deepseek/deepseek-r1")
|
||||
|
||||
def test_kimi_substring_false_positives_rejected(self):
|
||||
# Regression: the previous implementation matched any model whose
|
||||
# name contained the substring ``kimi`` — including unrelated model
|
||||
# ids like ``hakimi``. The anchored match below rejects them.
|
||||
assert not _is_reasoning_route("some-provider/hakimi-large")
|
||||
assert not _is_reasoning_route("hakimi")
|
||||
assert not _is_reasoning_route("akimi-7b")
|
||||
|
||||
def test_claude_substring_false_positives_rejected(self):
|
||||
# Regression (Sentry review on #12871): ``'claude' in lowered``
|
||||
# matched any substring — a custom
|
||||
# ``someprovider/claude-mock-v1`` set via
|
||||
# ``CHAT_FAST_STANDARD_MODEL`` would inherit the reasoning
|
||||
# extra_body and take a 400 from its upstream. The anchored
|
||||
# match requires either an ``anthropic/`` or ``anthropic.``
|
||||
# provider prefix, or a bare ``claude-`` id with no
|
||||
# provider prefix.
|
||||
assert not _is_reasoning_route("someprovider/claude-mock-v1")
|
||||
assert not _is_reasoning_route("custom/claude-like-model")
|
||||
# Same principle for Kimi — a non-Moonshot provider prefix is
|
||||
# rejected even when the model id starts with ``kimi-``.
|
||||
assert not _is_reasoning_route("other/kimi-pro")
|
||||
|
||||
|
||||
class TestReasoningExtraBody:
|
||||
def test_anthropic_route_returns_fragment(self):
|
||||
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 4096) == {
|
||||
@@ -146,16 +200,30 @@ class TestReasoningExtraBody:
|
||||
"reasoning": {"max_tokens": 2048}
|
||||
}
|
||||
|
||||
def test_non_anthropic_route_returns_none(self):
|
||||
def test_kimi_routes_return_fragment(self):
|
||||
# Kimi K2.6 ships the same OpenRouter ``reasoning`` extension as
|
||||
# Anthropic, so the gate widened with this PR and the fragment
|
||||
# must now materialise on Moonshot routes too.
|
||||
assert reasoning_extra_body("moonshotai/kimi-k2.6", 8192) == {
|
||||
"reasoning": {"max_tokens": 8192}
|
||||
}
|
||||
assert reasoning_extra_body("moonshotai/kimi-k2-thinking", 4096) == {
|
||||
"reasoning": {"max_tokens": 4096}
|
||||
}
|
||||
|
||||
def test_non_reasoning_route_returns_none(self):
|
||||
assert reasoning_extra_body("openai/gpt-4o", 4096) is None
|
||||
assert reasoning_extra_body("google/gemini-2.5-pro", 4096) is None
|
||||
assert reasoning_extra_body("xai/grok-4", 4096) is None
|
||||
|
||||
def test_zero_max_tokens_kill_switch(self):
|
||||
# Operator kill switch: ``max_thinking_tokens <= 0`` disables the
|
||||
# ``reasoning`` extra_body fragment even on an Anthropic route.
|
||||
# Lets us silence reasoning without dropping the SDK path's budget.
|
||||
# ``reasoning`` extra_body fragment on ANY reasoning route (Anthropic
|
||||
# or Kimi). Lets us silence reasoning without dropping the SDK
|
||||
# path's budget.
|
||||
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 0) is None
|
||||
assert reasoning_extra_body("anthropic/claude-sonnet-4-6", -1) is None
|
||||
assert reasoning_extra_body("moonshotai/kimi-k2.6", 0) is None
|
||||
|
||||
|
||||
class TestBaselineReasoningEmitter:
|
||||
@@ -171,7 +239,12 @@ class TestBaselineReasoningEmitter:
|
||||
assert emitter.is_open is True
|
||||
|
||||
def test_subsequent_deltas_reuse_block_id_without_new_start(self):
|
||||
emitter = BaselineReasoningEmitter()
|
||||
# Disable coalescing so each chunk flushes immediately — this test
|
||||
# is about the Start/Delta/block-id state machine, not the coalesce
|
||||
# window. Coalescing behaviour is covered below.
|
||||
emitter = BaselineReasoningEmitter(
|
||||
coalesce_min_chars=0, coalesce_max_interval_ms=0
|
||||
)
|
||||
first = emitter.on_delta(_delta(reasoning="a"))
|
||||
second = emitter.on_delta(_delta(reasoning="b"))
|
||||
|
||||
@@ -226,6 +299,106 @@ class TestBaselineReasoningEmitter:
|
||||
assert deltas[0].delta == "plan: do the thing"
|
||||
|
||||
|
||||
class TestReasoningDeltaCoalescing:
|
||||
"""Coalescing batches fine-grained provider chunks into bigger wire
|
||||
frames. OpenRouter's Kimi K2.6 emits ~4,700 reasoning-delta chunks
|
||||
per turn vs ~28 for Sonnet; without batching, every chunk becomes one
|
||||
Redis ``xadd`` + one SSE event + one React re-render of the
|
||||
non-virtualised chat list, which paint-storms the browser. These
|
||||
tests pin the batching contract: small chunks buffer until the
|
||||
char-size or time threshold trips, large chunks still flush
|
||||
immediately, and ``close()`` never drops tail text."""
|
||||
|
||||
def test_small_chunks_after_first_buffer_until_threshold(self):
|
||||
# Generous time threshold so size alone controls flush timing.
|
||||
emitter = BaselineReasoningEmitter(
|
||||
coalesce_min_chars=32, coalesce_max_interval_ms=60_000
|
||||
)
|
||||
# First chunk always flushes immediately (so UI renders without
|
||||
# waiting).
|
||||
first = emitter.on_delta(_delta(reasoning="hi "))
|
||||
assert any(isinstance(e, StreamReasoningStart) for e in first)
|
||||
assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1
|
||||
|
||||
# Subsequent small chunks buffer silently — 5 × 4 chars = 20 chars,
|
||||
# still under the 32-char threshold.
|
||||
for _ in range(5):
|
||||
assert emitter.on_delta(_delta(reasoning="abcd")) == []
|
||||
|
||||
# Once the threshold is crossed, the accumulated buffer flushes
|
||||
# as a single StreamReasoningDelta carrying every buffered chunk.
|
||||
flush = emitter.on_delta(_delta(reasoning="efghijklmnop"))
|
||||
assert len(flush) == 1
|
||||
assert isinstance(flush[0], StreamReasoningDelta)
|
||||
assert flush[0].delta == "abcd" * 5 + "efghijklmnop"
|
||||
|
||||
def test_time_based_flush_when_chars_stay_below_threshold(self, monkeypatch):
|
||||
# Fake ``time.monotonic`` so we can drive the time-based branch
|
||||
# deterministically without real sleeps.
|
||||
from backend.copilot.baseline import reasoning as rmod
|
||||
|
||||
fake_now = [0.0]
|
||||
monkeypatch.setattr(rmod.time, "monotonic", lambda: fake_now[0])
|
||||
|
||||
emitter = BaselineReasoningEmitter(
|
||||
coalesce_min_chars=1000, coalesce_max_interval_ms=40
|
||||
)
|
||||
# t=0: first chunk flushes immediately.
|
||||
first = emitter.on_delta(_delta(reasoning="a"))
|
||||
assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1
|
||||
|
||||
# t=10 ms: still under 40 ms → buffer.
|
||||
fake_now[0] = 0.010
|
||||
assert emitter.on_delta(_delta(reasoning="b")) == []
|
||||
|
||||
# t=60 ms: 60 ms since the last flush (t=0) → time threshold trips, flush fires.
|
||||
fake_now[0] = 0.060
|
||||
flushed = emitter.on_delta(_delta(reasoning="c"))
|
||||
assert len(flushed) == 1
|
||||
assert isinstance(flushed[0], StreamReasoningDelta)
|
||||
assert flushed[0].delta == "bc"
|
||||
|
||||
def test_close_flushes_tail_buffer_before_end(self):
|
||||
emitter = BaselineReasoningEmitter(
|
||||
coalesce_min_chars=1000, coalesce_max_interval_ms=60_000
|
||||
)
|
||||
emitter.on_delta(_delta(reasoning="first")) # flushes (first chunk)
|
||||
emitter.on_delta(_delta(reasoning=" middle ")) # buffered
|
||||
emitter.on_delta(_delta(reasoning="tail")) # buffered
|
||||
|
||||
events = emitter.close()
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], StreamReasoningDelta)
|
||||
assert events[0].delta == " middle tail"
|
||||
assert isinstance(events[1], StreamReasoningEnd)
|
||||
|
||||
def test_coalesce_disabled_flushes_every_chunk(self):
|
||||
emitter = BaselineReasoningEmitter(
|
||||
coalesce_min_chars=0, coalesce_max_interval_ms=0
|
||||
)
|
||||
first = emitter.on_delta(_delta(reasoning="a"))
|
||||
second = emitter.on_delta(_delta(reasoning="b"))
|
||||
assert sum(isinstance(e, StreamReasoningDelta) for e in first) == 1
|
||||
assert sum(isinstance(e, StreamReasoningDelta) for e in second) == 1
|
||||
|
||||
def test_persistence_stays_per_delta_even_when_wire_coalesces(self):
|
||||
"""DB row content must track every chunk so a crash mid-turn
|
||||
persists the full reasoning-so-far, even if the coalesce window
|
||||
never flushed those chunks to the wire."""
|
||||
session: list[ChatMessage] = []
|
||||
emitter = BaselineReasoningEmitter(
|
||||
session,
|
||||
coalesce_min_chars=1000,
|
||||
coalesce_max_interval_ms=60_000,
|
||||
)
|
||||
emitter.on_delta(_delta(reasoning="first "))
|
||||
emitter.on_delta(_delta(reasoning="chunk "))
|
||||
emitter.on_delta(_delta(reasoning="three"))
|
||||
# No close; verify the persisted row already has everything.
|
||||
assert len(session) == 1
|
||||
assert session[0].content == "first chunk three"
|
||||
|
||||
|
||||
class TestReasoningPersistence:
|
||||
"""The persistence contract: without ``role="reasoning"`` rows in
|
||||
session.messages, useHydrateOnStreamEnd overwrites the live-streamed
|
||||
|
||||
@@ -321,14 +321,17 @@ def _filter_tools_by_permissions(
|
||||
def _resolve_baseline_model(tier: CopilotLlmModel | None) -> str:
    """Pick the model for the baseline path based on the per-request tier.

    Baseline resolves independently of SDK via the ``fast_*_model`` cells
    of the (path, tier) matrix. ``'standard'`` / ``None`` picks Kimi
    K2.6 by default (cheap + OpenRouter ``reasoning`` support);
    ``'advanced'`` picks Opus by default so the advanced tier is a clean
    A/B against the SDK advanced tier — same model, different path —
    isolating reasoning-wire + cache differences from model capability.
    Both defaults are overridable per ``CHAT_FAST_*_MODEL`` env vars.

    Args:
        tier: The per-request tier, or ``None`` when the caller set none.

    Returns:
        ``config.fast_advanced_model`` for ``"advanced"``; otherwise
        ``config.fast_standard_model``.
    """
    # Only "advanced" selects the advanced cell; "standard", ``None`` and
    # any other value fall back to the standard cell.
    if tier == "advanced":
        return config.fast_advanced_model
    return config.fast_standard_model
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -761,6 +764,19 @@ async def _baseline_tool_executor(
|
||||
)
|
||||
)
|
||||
|
||||
# Announce the tool call to the session so in-turn guards like
|
||||
# ``require_guide_read`` can see it *right now*, before the tool
|
||||
# actually runs. Without this, the tool_call row lives only in
|
||||
# ``state.session_messages`` until the ``finally`` block flushes it
|
||||
# into ``session.messages`` at turn end — so a second tool in the
|
||||
# same turn (e.g. ``create_agent`` after ``get_agent_building_guide``)
|
||||
# scans a stale ``session.messages`` and the guard re-fires despite
|
||||
# the guide having been called. The announce-set is cleared at turn
|
||||
# end; we deliberately don't touch ``session.messages`` here to avoid
|
||||
# duplicating the assistant row that ``_baseline_conversation_updater``
|
||||
# will append at round end.
|
||||
session.announce_inflight_tool_call(tool_name)
|
||||
|
||||
try:
|
||||
result: StreamToolOutputAvailable = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
@@ -1806,6 +1822,16 @@ async def stream_chat_completion_baseline(
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
# In-flight tool-call announcements are only meaningful for the
|
||||
# current turn; clear at the top of the outer finally so the next
|
||||
# turn starts with a clean scratch buffer even if one of the
|
||||
# awaited cleanup steps below (usage persistence, session upsert,
|
||||
# transcript upload) raises. The buffer is a process-local scratch
|
||||
# set — if we leak it into the next turn the guide-read guard would
|
||||
# observe a phantom in-flight call and skip its gate, so this must
|
||||
# run unconditionally.
|
||||
session.clear_inflight_tool_calls()
|
||||
|
||||
# Pending messages are drained atomically at turn start and
|
||||
# between tool rounds, so there's nothing to clear in finally.
|
||||
# Any message pushed after the final drain window stays in the
|
||||
|
||||
@@ -1404,6 +1404,16 @@ class TestApplyPromptCacheMarkers:
|
||||
assert not _is_anthropic_model("xai/grok-4")
|
||||
assert not _is_anthropic_model("meta-llama/llama-3.3-70b-instruct")
|
||||
|
||||
def test_is_anthropic_model_rejects_kimi_routes(self):
|
||||
"""Regression guard: Kimi K2.6 is a reasoning route (reasoning
|
||||
extra_body is sent) but NOT an Anthropic route — Moonshot does
|
||||
its own auto prompt caching, so ``cache_control`` markers must
|
||||
NOT be applied. OpenRouter silently drops them today, but if
|
||||
they ever start failing fast we'd want the gate tight."""
|
||||
assert not _is_anthropic_model("moonshotai/kimi-k2.6")
|
||||
assert not _is_anthropic_model("moonshotai/kimi-k2-thinking")
|
||||
assert not _is_anthropic_model("kimi-k2-instruct")
|
||||
|
||||
def test_cache_control_uses_configured_ttl(self, monkeypatch):
|
||||
"""TTL comes from ChatConfig.baseline_prompt_cache_ttl — defaults
|
||||
to 1h so the static prefix (system + tools) stays warm across
|
||||
@@ -1829,7 +1839,7 @@ class TestBaselineReasoningStreaming:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_param_absent_on_non_anthropic_routes(self):
|
||||
"""Non-Anthropic routes (e.g. OpenAI) must not receive ``reasoning``."""
|
||||
"""Non-reasoning routes (e.g. OpenAI) must not receive ``reasoning``."""
|
||||
state = _BaselineStreamState(model="openai/gpt-4o")
|
||||
|
||||
mock_client = MagicMock()
|
||||
@@ -1850,6 +1860,54 @@ class TestBaselineReasoningStreaming:
|
||||
extra_body = mock_client.chat.completions.create.call_args[1]["extra_body"]
|
||||
assert "reasoning" not in extra_body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kimi_route_sends_reasoning_but_no_cache_control(self):
|
||||
"""Kimi K2.6 is the default ``fast_standard_model`` and sends ``reasoning`` via
|
||||
OpenRouter's unified extension. It must NOT receive ``cache_control``
|
||||
markers or the ``anthropic-beta`` header — Moonshot uses its own
|
||||
auto-caching and those Anthropic-only fields would either get
|
||||
silently dropped or (worst case) 400 on a future provider change."""
|
||||
state = _BaselineStreamState(model="moonshotai/kimi-k2.6")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=_make_stream_mock()
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service._get_openai_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
await _baseline_llm_caller(
|
||||
messages=[
|
||||
{"role": "system", "content": "you are a helpful assistant"},
|
||||
{"role": "user", "content": "hi"},
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "echo", "parameters": {}},
|
||||
}
|
||||
],
|
||||
state=state,
|
||||
)
|
||||
|
||||
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||
extra_body = call_kwargs["extra_body"]
|
||||
# Reasoning param on — the whole point of picking Kimi is the
|
||||
# cheap-but-still-reasoning-capable path.
|
||||
assert "reasoning" in extra_body
|
||||
assert extra_body["reasoning"]["max_tokens"] > 0
|
||||
# Anthropic-only fields stay off.
|
||||
assert "extra_headers" not in call_kwargs
|
||||
sys_msg = call_kwargs["messages"][0]
|
||||
sys_content = sys_msg.get("content")
|
||||
if isinstance(sys_content, list):
|
||||
assert all("cache_control" not in block for block in sys_content)
|
||||
tools = call_kwargs.get("tools", [])
|
||||
for t in tools:
|
||||
assert "cache_control" not in t
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_only_stream_still_closes_block(self):
|
||||
"""Regression: a stream with only reasoning (no text, no tool_call)
|
||||
|
||||
@@ -63,21 +63,123 @@ def _make_session_messages(*roles: str) -> list[ChatMessage]:
|
||||
|
||||
|
||||
class TestResolveBaselineModel:
    """Baseline model resolution honours the per-request tier toggle.

    Baseline reads the ``fast_*_model`` cells of the (path, tier) matrix
    and never falls through to the SDK-side ``thinking_*_model`` cells.
    Default routing:

    - ``standard`` / ``None`` → ``config.fast_standard_model`` (Kimi K2.6)
    - ``advanced`` → ``config.fast_advanced_model`` (Opus — same as SDK's
      advanced tier, so the advanced A/B isolates path differences)
    """

    def test_advanced_tier_selects_fast_advanced_model(self):
        assert _resolve_baseline_model("advanced") == config.fast_advanced_model

    def test_standard_tier_selects_fast_standard_model(self):
        assert _resolve_baseline_model("standard") == config.fast_standard_model

    def test_none_tier_selects_fast_standard_model(self):
        """Baseline users without a tier get the cheap fast-standard default."""
        assert _resolve_baseline_model(None) == config.fast_standard_model

    def test_fast_standard_default_is_kimi(self):
        """Shipped default: Kimi K2.6 on the baseline standard cell.

        Asserts the declared ``Field`` default — env-independent — so a
        deploy-time ``CHAT_FAST_STANDARD_MODEL`` rollback override
        doesn't fail CI while still pinning the shipped default.
        """
        from backend.copilot.config import ChatConfig

        assert (
            ChatConfig.model_fields["fast_standard_model"].default
            == "moonshotai/kimi-k2.6"
        )

    def test_fast_advanced_default_is_opus(self):
        """Shipped default: Opus on the baseline advanced cell — mirrors
        the SDK advanced cell so the advanced-tier A/B stays clean
        (same model, different path)."""
        from backend.copilot.config import ChatConfig

        assert (
            ChatConfig.model_fields["fast_advanced_model"].default
            == "anthropic/claude-opus-4.7"
        )

    def test_standard_cells_diverge_across_paths(self):
        """The whole point of the split: baseline cheap (Kimi) vs SDK
        Anthropic-only (Sonnet).  If the shipped standard defaults ever
        collapse to the same value someone lost the cost savings.
        Checked against ``Field`` defaults, not the env-backed singleton."""
        from backend.copilot.config import ChatConfig

        assert (
            ChatConfig.model_fields["thinking_standard_model"].default
            != ChatConfig.model_fields["fast_standard_model"].default
        )

    def test_standard_and_advanced_cells_differ_on_fast(self):
        """Advanced tier defaults to a different model than standard on
        the baseline path.  Checked against declared ``Field`` defaults
        so operator env overrides don't flake the test."""
        from backend.copilot.config import ChatConfig

        assert (
            ChatConfig.model_fields["fast_standard_model"].default
            != ChatConfig.model_fields["fast_advanced_model"].default
        )

    def test_legacy_env_aliases_route_to_new_fields(self, monkeypatch):
        """Backward compat: the pre-split env var names must still bind.

        The four-field matrix was introduced with ``validation_alias``
        entries so that existing deployments setting ``CHAT_MODEL`` /
        ``CHAT_ADVANCED_MODEL`` / ``CHAT_FAST_MODEL`` continue to override
        the same effective cell without a rename.  Construct a fresh
        ``ChatConfig`` with each legacy name set and confirm it lands on
        the new field.
        """
        from backend.copilot.config import ChatConfig

        # The new-style names are listed first in each ``AliasChoices``;
        # remove them so a value inherited from the ambient environment
        # cannot shadow the legacy names under test.
        for explicit in (
            "CHAT_THINKING_STANDARD_MODEL",
            "CHAT_THINKING_ADVANCED_MODEL",
            "CHAT_FAST_STANDARD_MODEL",
        ):
            monkeypatch.delenv(explicit, raising=False)

        monkeypatch.setenv("CHAT_MODEL", "legacy/sonnet-via-chat-model")
        monkeypatch.setenv("CHAT_ADVANCED_MODEL", "legacy/opus-via-advanced")
        monkeypatch.setenv("CHAT_FAST_MODEL", "legacy/fast-via-fast-model")

        cfg = ChatConfig()

        assert cfg.thinking_standard_model == "legacy/sonnet-via-chat-model"
        assert cfg.thinking_advanced_model == "legacy/opus-via-advanced"
        assert cfg.fast_standard_model == "legacy/fast-via-fast-model"

    def test_all_four_new_env_vars_bind_to_their_cells(self, monkeypatch):
        """Each of the four (path, tier) cells must be overridable via
        its documented ``CHAT_*_*_MODEL`` env var — including
        ``CHAT_FAST_ADVANCED_MODEL`` which was missing a
        ``validation_alias`` in the original split and only bound
        implicitly through ``env_prefix``.  Pinning all four here so
        that whenever someone touches the config shape, an accidental
        unbinding fails CI instead of silently ignoring operator
        overrides.
        """
        from backend.copilot.config import ChatConfig

        monkeypatch.setenv("CHAT_FAST_STANDARD_MODEL", "explicit/fast-std")
        monkeypatch.setenv("CHAT_FAST_ADVANCED_MODEL", "explicit/fast-adv")
        monkeypatch.setenv("CHAT_THINKING_STANDARD_MODEL", "explicit/think-std")
        monkeypatch.setenv("CHAT_THINKING_ADVANCED_MODEL", "explicit/think-adv")
        # Drop the legacy aliases so only the explicit names are present;
        # the assertions then cannot be satisfied (or confused) by a
        # legacy variable left in the ambient environment.
        for legacy in ("CHAT_MODEL", "CHAT_ADVANCED_MODEL", "CHAT_FAST_MODEL"):
            monkeypatch.delenv(legacy, raising=False)

        cfg = ChatConfig()

        assert cfg.fast_standard_model == "explicit/fast-std"
        assert cfg.fast_advanced_model == "explicit/fast-adv"
        assert cfg.thinking_standard_model == "explicit/think-std"
        assert cfg.thinking_advanced_model == "explicit/think-adv"
|
||||
|
||||
|
||||
class TestLoadPriorTranscript:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic import AliasChoices, Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
@@ -17,8 +17,12 @@ from backend.util.clients import OPENROUTER_BASE_URL
|
||||
CopilotMode = Literal["fast", "extended_thinking"]

# Per-request model tier set by the frontend model toggle.
# 'standard' picks the cheaper everyday model for the active path —
# ``fast_standard_model`` on the baseline path, ``thinking_standard_model``
# on the SDK path.
# 'advanced' picks the premium model for the active path — ``fast_advanced_model``
# on the baseline path, ``thinking_advanced_model`` on the SDK path (both
# default to Opus today).
# None means no preference — falls through to LD per-user targeting, then config.
# Using tier names instead of model names keeps the contract model-agnostic.
CopilotLlmModel = Literal["standard", "advanced"]
|
||||
@@ -27,21 +31,61 @@ CopilotLlmModel = Literal["standard", "advanced"]
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
|
||||
# Chat model tiers — applied orthogonally to the path (fast=baseline vs
|
||||
# extended_thinking=SDK). The "fast" vs "extended_thinking" toggle picks
|
||||
# which code path runs (no reasoning / heavy SDK); "standard" vs
|
||||
# "advanced" picks the model inside that path.
|
||||
model: str = Field(
|
||||
default="anthropic/claude-sonnet-4-6",
|
||||
description="Model used for the 'standard' tier (Sonnet by default). "
|
||||
"Applies to both baseline (fast) and SDK (extended thinking) paths. "
|
||||
"Override via CHAT_MODEL env var.",
|
||||
# Chat model tiers — a 2×2 of (path, tier). ``path`` = ``CopilotMode``
|
||||
# (``"fast"`` → baseline OpenAI-compat / any OpenRouter model;
|
||||
# ``"extended_thinking"`` → Claude Agent SDK, Anthropic-only CLI).
|
||||
# ``tier`` = ``CopilotLlmModel`` (``"standard"`` / ``"advanced"``).
|
||||
# Each cell has its own config so the two paths can evolve
|
||||
# independently (cheap provider on baseline, Anthropic on SDK) at each
|
||||
# tier without conflating one path's needs with the other's constraint.
|
||||
#
|
||||
# Historical env var names (``CHAT_MODEL`` / ``CHAT_ADVANCED_MODEL`` /
|
||||
# ``CHAT_FAST_MODEL``) are preserved via ``validation_alias`` so
|
||||
# existing deployments continue to override the same effective cell.
|
||||
fast_standard_model: str = Field(
|
||||
default="moonshotai/kimi-k2.6",
|
||||
validation_alias=AliasChoices(
|
||||
"CHAT_FAST_STANDARD_MODEL",
|
||||
"CHAT_FAST_MODEL",
|
||||
),
|
||||
description="Baseline path, 'standard' / ``None`` tier. Kimi K2.6 "
|
||||
"by default: ~5x cheaper input and ~5.4x cheaper output than Sonnet, "
|
||||
"SWE-Bench Verified parity with Opus, and OpenRouter advertises the "
|
||||
"``reasoning`` + ``include_reasoning`` extension params on the "
|
||||
"Moonshot endpoints — so the baseline reasoning plumbing lights up "
|
||||
"without provider-specific code. Roll back to the Anthropic route "
|
||||
"via ``CHAT_FAST_STANDARD_MODEL=anthropic/claude-sonnet-4-6`` (then "
|
||||
"``cache_control`` breakpoints reactivate via "
|
||||
"``_is_anthropic_model``).",
|
||||
)
|
||||
advanced_model: str = Field(
|
||||
default="anthropic/claude-opus-4-7",
|
||||
description="Model used for the 'advanced' tier (Opus by default). "
|
||||
"Applies to both baseline (fast) and SDK (extended thinking) paths. "
|
||||
"Override via CHAT_ADVANCED_MODEL env var.",
|
||||
fast_advanced_model: str = Field(
|
||||
default="anthropic/claude-opus-4.7",
|
||||
validation_alias=AliasChoices("CHAT_FAST_ADVANCED_MODEL"),
|
||||
description="Baseline path, 'advanced' tier. Opus by default. "
|
||||
"Override via ``CHAT_FAST_ADVANCED_MODEL``.",
|
||||
)
|
||||
thinking_standard_model: str = Field(
|
||||
default="anthropic/claude-sonnet-4-6",
|
||||
validation_alias=AliasChoices(
|
||||
"CHAT_THINKING_STANDARD_MODEL",
|
||||
"CHAT_MODEL",
|
||||
),
|
||||
description="SDK (extended-thinking) path, 'standard' / ``None`` "
|
||||
"tier. Sonnet by default: the Claude Agent SDK CLI only speaks to "
|
||||
"Anthropic endpoints, so the standard SDK tier has to stay on an "
|
||||
"Anthropic model regardless of what the baseline path runs. "
|
||||
"Override via ``CHAT_THINKING_STANDARD_MODEL`` (legacy "
|
||||
"``CHAT_MODEL`` still honored).",
|
||||
)
|
||||
thinking_advanced_model: str = Field(
|
||||
default="anthropic/claude-opus-4.7",
|
||||
validation_alias=AliasChoices(
|
||||
"CHAT_THINKING_ADVANCED_MODEL",
|
||||
"CHAT_ADVANCED_MODEL",
|
||||
),
|
||||
description="SDK (extended-thinking) path, 'advanced' tier. Opus "
|
||||
"by default. Override via ``CHAT_THINKING_ADVANCED_MODEL`` "
|
||||
"(legacy ``CHAT_ADVANCED_MODEL`` still honored).",
|
||||
)
|
||||
title_model: str = Field(
|
||||
default="openai/gpt-4o-mini",
|
||||
@@ -150,7 +194,7 @@ class ChatConfig(BaseSettings):
|
||||
claude_agent_model: str | None = Field(
|
||||
default=None,
|
||||
description="Model for the Claude Agent SDK path. If None, derives from "
|
||||
"the `model` field by stripping the OpenRouter provider prefix.",
|
||||
"`thinking_standard_model` by stripping the OpenRouter provider prefix.",
|
||||
)
|
||||
claude_agent_max_buffer_size: int = Field(
|
||||
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
|
||||
@@ -426,3 +470,10 @@ class ChatConfig(BaseSettings):
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore" # Ignore extra environment variables
|
||||
# Accept both the Python attribute name and the validation_alias when
|
||||
# constructing a ``ChatConfig`` directly (e.g. in tests passing
|
||||
# ``thinking_standard_model=...``). Without this, pydantic only
|
||||
# accepts the alias names (``CHAT_THINKING_STANDARD_MODEL`` env) and
|
||||
# rejects field-name kwargs — breaking ``ChatConfig(field=...)`` in
|
||||
# every test that constructs a config.
|
||||
populate_by_name = True
|
||||
|
||||
@@ -20,7 +20,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
)
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from backend.data.db_accessors import chat_db, library_db
|
||||
from backend.data.graph import GraphSettings
|
||||
@@ -205,6 +205,15 @@ class ChatSessionInfo(BaseModel):
|
||||
|
||||
class ChatSession(ChatSessionInfo):
|
||||
messages: list[ChatMessage]
|
||||
# In-flight tool-call names for the CURRENT turn. Not persisted to
|
||||
# DB and not serialised on the wire — ``PrivateAttr`` keeps this a
|
||||
# process-local scratch buffer that's invisible to ``model_dump`` /
|
||||
# ``model_dump_json`` / the redis cache path. Populated by the
|
||||
# baseline tool executor the moment a tool is dispatched so in-turn
|
||||
# guards (e.g. ``require_guide_read``) can see the call before it
|
||||
# lands in ``messages`` at turn-end. Cleared when the turn
|
||||
# completes.
|
||||
_inflight_tool_calls: set[str] = PrivateAttr(default_factory=set)
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
@@ -242,6 +251,56 @@ class ChatSession(ChatSessionInfo):
|
||||
messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
|
||||
)
|
||||
|
||||
def announce_inflight_tool_call(self, tool_name: str) -> None:
    """Record that *tool_name* is being dispatched in the current turn.

    Called by the baseline tool executor **before** the tool actually
    runs (the announcement is about dispatch, not success).  If the
    tool raises, the name stays in the buffer for the rest of the
    turn — that matches the guide-read gate's contract ("was the tool
    called?") but means any future gate wanting *successful*
    dispatches would need its own tracking.

    Lets in-turn guards (see
    ``copilot/tools/helpers.py::require_guide_read``) see a tool
    call the moment it's issued, instead of waiting for the
    ``session.messages`` flush at turn end — fixing a loop where a
    second tool in the same turn re-fires a guard despite the
    guarding tool having already been called (seen on Kimi K2.6 in
    particular because its aggressive tool-call chaining exercises
    this path much more than Sonnet does).  The buffer is cleared by
    :meth:`clear_inflight_tool_calls` at turn end.
    """
    # A set: announcing the same tool twice in one turn is a no-op.
    self._inflight_tool_calls.add(tool_name)
|
||||
|
||||
def clear_inflight_tool_calls(self) -> None:
    """Reset the in-flight tool-call announcement buffer."""
    # Cleared in place, so the same set object stays attached to the session.
    self._inflight_tool_calls.clear()
|
||||
|
||||
def has_tool_been_called(self, tool_name: str) -> bool:
    """True when *tool_name* has been called in this session.

    Checks the in-flight announcement buffer (for calls dispatched
    in the *current* turn but not yet flushed into ``messages``) and
    the durable ``messages`` history (for past turns + prior rounds
    within this turn whose writes already landed).  The durable
    scan is session-wide, not turn-scoped: a matching tool call
    anywhere in ``messages`` counts.  This matches the guide-read
    contract — once the guide has been read in the session, the
    agent doesn't need to re-read it for later create/edit/fix
    tools.
    """
    if tool_name in self._inflight_tool_calls:
        return True
    # Walk newest-first; only assistant messages carry tool calls.
    for msg in reversed(self.messages):
        if msg.role != "assistant" or not msg.tool_calls:
            continue
        for tc in msg.tool_calls:
            # Two shapes are accepted: nested ``{"function": {"name": ...}}``
            # and flat ``{"name": ...}``.  ``function`` may be present but
            # null, so guard with ``or {}`` instead of a ``.get`` default.
            name = (tc.get("function") or {}).get("name") or tc.get("name")
            if name == tool_name:
                return True
    return False
|
||||
|
||||
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
|
||||
"""Attach a tool_call to the current turn's assistant message.
|
||||
|
||||
|
||||
@@ -108,6 +108,7 @@ ToolName = Literal[
|
||||
"validate_agent_graph",
|
||||
"view_agent_output",
|
||||
"web_fetch",
|
||||
"web_search",
|
||||
"write_workspace_file",
|
||||
# SDK built-ins
|
||||
"Agent",
|
||||
|
||||
@@ -9,8 +9,8 @@ persistence, and the ``CompactionTracker`` state machine.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections import Counter, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -25,8 +25,6 @@ from ..response_model import (
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompactionResult:
|
||||
@@ -73,6 +71,14 @@ def _new_tool_call_id() -> str:
|
||||
return f"compaction-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
def _summarize_sources(sources: list[str]) -> str:
    """Collapse a list of source labels into a compact comma-separated string.

    Each distinct label appears once, in first-seen order; labels that
    occur more than once get a ``:<count>`` suffix.  For example
    ``["pre_query", "sdk_internal", "sdk_internal"]`` becomes
    ``"pre_query,sdk_internal:2"``.  An empty list yields ``""``.
    """
    counts = Counter(sources)
    return ",".join(
        f"{source}:{count}" if count > 1 else source
        for source, count in counts.items()
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public event builder
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -185,26 +191,54 @@ class CompactionTracker:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
    """Create an idle tracker: nothing queued, nothing active, nothing recorded."""
    # True once start events were emitted for the active compaction.
    self._start_emitted = False
    self._tool_call_id = ""
    # Transcript path of the compaction whose start events were emitted.
    self._active_transcript_path: str = ""
    # FIFO of transcript paths queued by ``on_compact``, one per hook call.
    self._pending_transcript_paths: deque[str] = deque()
    # Source labels ("pre_query" / "sdk_internal") for observability.
    self._attempted_sources: list[str] = []
    self._completed_sources: list[str] = []
|
||||
|
||||
@property
def attempt_count(self) -> int:
    """Number of compaction attempts recorded so far (all sources)."""
    return len(self._attempted_sources)
|
||||
|
||||
@property
def attempt_sources(self) -> tuple[str, ...]:
    """Immutable snapshot of the attempted-source labels, in recording order."""
    return tuple(self._attempted_sources)
|
||||
|
||||
@property
def completed_count(self) -> int:
    """Number of compactions recorded as completed so far (all sources)."""
    return len(self._completed_sources)
|
||||
|
||||
@property
def completed_sources(self) -> tuple[str, ...]:
    """Immutable snapshot of the completed-source labels, in recording order."""
    return tuple(self._completed_sources)
|
||||
|
||||
def get_observability_metadata(self) -> dict[str, Any]:
    """Build the compaction metadata dict for this tracker.

    Returns an empty dict when no compaction was attempted or completed.
    Otherwise always includes ``compaction_attempt_count`` and
    ``compaction_attempt_sources``; ``compaction_count`` and
    ``compaction_sources`` are added only once at least one compaction
    has completed.
    """
    if not self._attempted_sources and not self._completed_sources:
        return {}

    metadata: dict[str, Any] = {
        "compaction_attempt_count": self.attempt_count,
        "compaction_attempt_sources": _summarize_sources(self._attempted_sources),
    }
    if self._completed_sources:
        metadata["compaction_count"] = self.completed_count
        metadata["compaction_sources"] = _summarize_sources(self._completed_sources)
    return metadata
|
||||
|
||||
def get_log_summary(self) -> dict[str, Any]:
    """Return attempt/completion counts and source summaries for log output.

    Unlike :meth:`get_observability_metadata`, all four keys are always
    present, even when nothing has been attempted.
    """
    return {
        "attempt_count": self.attempt_count,
        "attempt_sources": _summarize_sources(self._attempted_sources),
        "completed_count": self.completed_count,
        "completed_sources": _summarize_sources(self._completed_sources),
    }
|
||||
|
||||
def on_compact(self, transcript_path: str = "") -> None:
    """Callback for the PreCompact hook.  Queues an SDK compaction attempt.

    Records an ``"sdk_internal"`` attempt and appends *transcript_path*
    (possibly empty) to the pending queue, so several hook firings are
    kept in order instead of overwriting one another.
    """
    self._attempted_sources.append("sdk_internal")
    self._pending_transcript_paths.append(transcript_path)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pre-query compaction
|
||||
@@ -212,7 +246,8 @@ class CompactionTracker:
|
||||
|
||||
def emit_pre_query(self, session: ChatSession) -> list[StreamBaseResponse]:
    """Emit + persist a self-contained compaction tool call.

    A pre-query compaction is recorded as both attempted and completed
    under the ``"pre_query"`` source before delegating to
    ``emit_compaction``.
    """
    self._attempted_sources.append("pre_query")
    self._completed_sources.append("pre_query")
    return emit_compaction(session)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -221,18 +256,17 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
    """Reset per-query state before a new SDK query.

    Drops any open spinner state and un-started pending attempts.  The
    attempted/completed source lists are left untouched.
    """
    self._start_emitted = False
    self._tool_call_id = ""
    self._active_transcript_path = ""
    self._pending_transcript_paths.clear()
|
||||
|
||||
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
    """If the PreCompact hook fired, emit start events (spinning tool).

    Returns an empty list when no attempt is pending or when start
    events were already emitted for an attempt that has not ended yet.
    Otherwise takes the oldest pending transcript path as the active
    one, allocates a fresh tool-call id and returns the start events.
    """
    if not self._pending_transcript_paths or self._start_emitted:
        return []
    self._start_emitted = True
    self._tool_call_id = _new_tool_call_id()
    self._active_transcript_path = self._pending_transcript_paths.popleft()
    return _start_events(self._tool_call_id)
|
||||
|
||||
@@ -246,27 +280,30 @@ class CompactionTracker:
|
||||
# Yield so pending hook tasks can set compact_start
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if self._done:
|
||||
return CompactionResult()
|
||||
if not self._start_emitted and not self._compact_start.is_set():
|
||||
if not self._start_emitted and not self._pending_transcript_paths:
|
||||
return CompactionResult()
|
||||
|
||||
if self._start_emitted:
|
||||
# Close the open spinner
|
||||
done_events = _end_events(self._tool_call_id, COMPACTION_DONE_MSG)
|
||||
persist_id = self._tool_call_id
|
||||
transcript_path = self._active_transcript_path
|
||||
else:
|
||||
# PreCompact fired but start never emitted — self-contained
|
||||
persist_id = _new_tool_call_id()
|
||||
done_events = compaction_events(
|
||||
COMPACTION_DONE_MSG, tool_call_id=persist_id
|
||||
)
|
||||
transcript_path = (
|
||||
self._pending_transcript_paths.popleft()
|
||||
if self._pending_transcript_paths
|
||||
else ""
|
||||
)
|
||||
|
||||
transcript_path = self._transcript_path
|
||||
self._compact_start.clear()
|
||||
self._start_emitted = False
|
||||
self._done = True
|
||||
self._transcript_path = ""
|
||||
self._tool_call_id = ""
|
||||
self._active_transcript_path = ""
|
||||
self._completed_sources.append("sdk_internal")
|
||||
_persist(session, persist_id, COMPACTION_DONE_MSG)
|
||||
return CompactionResult(
|
||||
events=done_events, just_ended=True, transcript_path=transcript_path
|
||||
|
||||
@@ -162,10 +162,11 @@ class TestFilterCompactionMessages:
|
||||
|
||||
|
||||
class TestCompactionTracker:
|
||||
def test_on_compact_registers_pending_attempt(self):
    """A bare ``on_compact()`` counts one attempt and queues an empty path."""
    tracker = CompactionTracker()
    tracker.on_compact()
    assert tracker.attempt_count == 1
    assert list(tracker._pending_transcript_paths) == [""]
|
||||
|
||||
def test_emit_start_if_ready_no_event(self):
|
||||
tracker = CompactionTracker()
|
||||
@@ -244,36 +245,39 @@ class TestCompactionTracker:
|
||||
evts = tracker.emit_pre_query(session)
|
||||
assert len(evts) == 5
|
||||
assert len(session.messages) == 2
|
||||
assert tracker._done is True
|
||||
assert tracker.attempt_count == 1
|
||||
assert tracker.completed_count == 1
|
||||
assert tracker.get_observability_metadata() == {
|
||||
"compaction_attempt_count": 1,
|
||||
"compaction_attempt_sources": "pre_query",
|
||||
"compaction_count": 1,
|
||||
"compaction_sources": "pre_query",
|
||||
}
|
||||
|
||||
def test_reset_for_query(self):
    """``reset_for_query`` clears spinner state, active path and pending queue."""
    tracker = CompactionTracker()
    tracker.on_compact("/some/path")
    tracker._start_emitted = True
    tracker._tool_call_id = "old"
    tracker._active_transcript_path = "/active/path"
    tracker.reset_for_query()
    assert tracker._start_emitted is False
    assert tracker._tool_call_id == ""
    assert tracker._active_transcript_path == ""
    assert list(tracker._pending_transcript_paths) == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_pre_query_does_not_block_sdk_compaction_within_query(self):
    """SDK auto-compaction can still fire after a pre-query compaction."""
    tracker = CompactionTracker()
    session = _make_session()
    tracker.emit_pre_query(session)
    tracker.on_compact()
    evts = tracker.emit_start_if_ready()
    assert len(evts) == 3
    result = await tracker.emit_end_if_ready(session)
    assert result.just_ended is True
    # One pre-query completion plus one SDK-internal completion.
    assert tracker.completed_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_allows_new_compaction(self):
|
||||
@@ -318,43 +322,18 @@ class TestCompactionTracker:
|
||||
assert len(result1.events) == 2
|
||||
assert result1.transcript_path == "/path/1"
|
||||
|
||||
# Second compaction cycle (should NOT be blocked — _done resets
|
||||
# because emit_end_if_ready sets it True, but the next on_compact
|
||||
# + emit_start_if_ready checks !_done which IS True now.
|
||||
# So we need reset_for_query between queries, but within a single
|
||||
# query multiple compactions work because _done blocks emit_start
|
||||
# until the next message arrives, at which point emit_end detects it)
|
||||
#
|
||||
# Actually: _done=True blocks emit_start_if_ready, so we need
|
||||
# the stream loop to reset. In practice service.py doesn't call
|
||||
# reset between compactions within the same query — let's verify
|
||||
# the actual behavior.
|
||||
# Second compaction cycle in the same query
|
||||
tracker.on_compact("/path/2")
|
||||
# _done is True from first compaction, so start is blocked
|
||||
start_evts = tracker.emit_start_if_ready()
|
||||
assert start_evts == []
|
||||
# But emit_end returns no-op because _done is True
|
||||
assert len(start_evts) == 3
|
||||
result2 = await tracker.emit_end_if_ready(session)
|
||||
assert result2.just_ended is False
|
||||
assert result2.just_ended is True
|
||||
assert result2.transcript_path == "/path/2"
|
||||
assert tracker.completed_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compactions_with_intervening_message(self):
|
||||
"""Multiple compactions work when the stream loop processes messages between them.
|
||||
|
||||
In the real service.py flow:
|
||||
1. PreCompact fires → on_compact()
|
||||
2. emit_start shows spinner
|
||||
3. Next message arrives → emit_end completes compaction (_done=True)
|
||||
4. Stream continues processing messages...
|
||||
5. If a second PreCompact fires, _done=True blocks emit_start
|
||||
6. But the next message triggers emit_end, which sees _done=True → no-op
|
||||
7. The stream loop needs to detect this and handle accordingly
|
||||
|
||||
The actual flow for multiple compactions within a query requires
|
||||
_done to be cleared between them. The service.py code uses
|
||||
CompactionResult.just_ended to trigger replace_entries, and _done
|
||||
stays True until reset_for_query.
|
||||
"""
|
||||
"""Multiple compactions remain supported across query boundaries."""
|
||||
tracker = CompactionTracker()
|
||||
session = _make_session()
|
||||
|
||||
@@ -376,10 +355,10 @@ class TestCompactionTracker:
|
||||
assert result2.just_ended is True
|
||||
assert result2.transcript_path == "/path/2"
|
||||
|
||||
def test_on_compact_queues_transcript_path(self):
    """The path passed to ``on_compact`` lands in the pending queue."""
    tracker = CompactionTracker()
    tracker.on_compact("/some/path.jsonl")
    assert list(tracker._pending_transcript_paths) == ["/some/path.jsonl"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_end_returns_transcript_path(self):
|
||||
@@ -391,17 +370,71 @@ class TestCompactionTracker:
|
||||
result = await tracker.emit_end_if_ready(session)
|
||||
assert result.just_ended is True
|
||||
assert result.transcript_path == "/my/session.jsonl"
|
||||
# transcript_path is cleared after emit_end
|
||||
assert tracker._transcript_path == ""
|
||||
assert tracker._active_transcript_path == ""
|
||||
|
||||
@pytest.mark.asyncio
async def test_emit_end_clears_active_transcript_path(self):
    """After emit_end, the active transcript path is reset."""
    tracker = CompactionTracker()
    session = _make_session()
    tracker.on_compact("/first/path.jsonl")
    tracker.emit_start_if_ready()
    await tracker.emit_end_if_ready(session)
    assert tracker._active_transcript_path == ""
|
||||
|
||||
@pytest.mark.asyncio
async def test_multiple_pending_hooks_are_counted_even_before_completion(self):
    """Hooks that fire while a compaction is active are queued, counted as
    attempts immediately, and completed one by one in FIFO order."""
    tracker = CompactionTracker()
    session = _make_session()

    tracker.on_compact("/path/1")
    tracker.emit_start_if_ready()
    tracker.on_compact("/path/2")
    tracker.on_compact("/path/3")

    result1 = await tracker.emit_end_if_ready(session)
    assert result1.just_ended is True
    assert result1.transcript_path == "/path/1"
    assert tracker.attempt_count == 3
    assert tracker.completed_count == 1

    tracker.emit_start_if_ready()
    result2 = await tracker.emit_end_if_ready(session)
    assert result2.just_ended is True
    assert result2.transcript_path == "/path/2"

    tracker.emit_start_if_ready()
    result3 = await tracker.emit_end_if_ready(session)
    assert result3.just_ended is True
    assert result3.transcript_path == "/path/3"
    assert tracker.completed_count == 3
|
||||
|
||||
def test_get_observability_metadata_includes_attempts_and_completions(self):
    """Metadata reports queued SDK attempts separately from completions."""
    tracker = CompactionTracker()
    session = _make_session()

    tracker.emit_pre_query(session)
    tracker.on_compact("/path/1")
    tracker.on_compact("/path/2")

    # Two SDK hooks are queued but not yet completed; only the
    # pre-query compaction counts as done.
    assert tracker.get_observability_metadata() == {
        "compaction_attempt_count": 3,
        "compaction_attempt_sources": "pre_query,sdk_internal:2",
        "compaction_count": 1,
        "compaction_sources": "pre_query",
    }
|
||||
|
||||
def test_get_log_summary_includes_attempts_and_completions(self):
    """The log summary mirrors the metadata under shorter key names."""
    tracker = CompactionTracker()
    session = _make_session()

    tracker.emit_pre_query(session)
    tracker.on_compact("/path/1")
    tracker.on_compact("/path/2")

    assert tracker.get_log_summary() == {
        "attempt_count": 3,
        "attempt_sources": "pre_query,sdk_internal:2",
        "completed_count": 1,
        "completed_sources": "pre_query",
    }
|
||||
|
||||
@@ -450,7 +450,9 @@ async def _reduce_context(
|
||||
# useful for the eventual upload_transcript call that seeds future turns.
|
||||
if transcript_content and not tried_compaction:
|
||||
compacted = await compact_transcript(
|
||||
transcript_content, model=config.model, log_prefix=log_prefix
|
||||
transcript_content,
|
||||
model=config.thinking_standard_model,
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
if (
|
||||
compacted
|
||||
@@ -700,7 +702,7 @@ def _resolve_sdk_model() -> str | None:
|
||||
"""Resolve the model name for the Claude Agent SDK CLI.
|
||||
|
||||
Uses `config.claude_agent_model` if set, otherwise derives from
|
||||
`config.model` via :func:`_normalize_model_name`.
|
||||
`config.thinking_standard_model` via :func:`_normalize_model_name`.
|
||||
|
||||
When `use_claude_code_subscription` is enabled and no explicit
|
||||
`claude_agent_model` is set, returns `None` so the CLI uses the
|
||||
@@ -710,7 +712,7 @@ def _resolve_sdk_model() -> str | None:
|
||||
return config.claude_agent_model
|
||||
if config.use_claude_code_subscription:
|
||||
return None
|
||||
return _normalize_model_name(config.model)
|
||||
return _normalize_model_name(config.thinking_standard_model)
|
||||
|
||||
|
||||
def _resolve_fallback_model() -> str | None:
|
||||
@@ -739,7 +741,7 @@ async def _resolve_sdk_model_for_request(
|
||||
cost (reported by the SDK) already reflects model-pricing differences.
|
||||
"""
|
||||
if model == "advanced":
|
||||
sdk_model = _normalize_model_name(config.advanced_model)
|
||||
sdk_model = _normalize_model_name(config.thinking_advanced_model)
|
||||
logger.info(
|
||||
"[SDK] [%s] Per-request model override: advanced (%s)",
|
||||
session_id[:12] if session_id else "?",
|
||||
@@ -1191,7 +1193,10 @@ async def _compress_messages(
|
||||
|
||||
try:
|
||||
result = await _run_compression(
|
||||
messages_dict, config.model, "[SDK]", target_tokens=target_tokens
|
||||
messages_dict,
|
||||
config.thinking_standard_model,
|
||||
"[SDK]",
|
||||
target_tokens=target_tokens,
|
||||
)
|
||||
except Exception as exc:
|
||||
# Guard against timeouts or unexpected errors in compression —
|
||||
@@ -3745,15 +3750,17 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
if ended_with_stream_error:
|
||||
logger.warning(
|
||||
"%s Stream ended with SDK error after %d messages",
|
||||
"%s Stream ended with SDK error after %d messages (compaction=%s)",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
compaction.get_log_summary(),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"%s Stream completed successfully with %d messages",
|
||||
"%s Stream completed successfully with %d messages (compaction=%s)",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
compaction.get_log_summary(),
|
||||
)
|
||||
except GeneratorExit:
|
||||
# GeneratorExit is raised when the async generator is closed by the
|
||||
@@ -3856,7 +3863,7 @@ async def stream_chat_completion_sdk(
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
log_prefix=log_prefix,
|
||||
cost_usd=turn_cost_usd,
|
||||
model=sdk_model or config.model,
|
||||
model=sdk_model or config.thinking_standard_model,
|
||||
provider="anthropic",
|
||||
)
|
||||
|
||||
|
||||
@@ -364,9 +364,10 @@ class TestNormalizeModelName:
|
||||
"""Unit tests for the model-name normalisation helper.
|
||||
|
||||
The per-request model toggle calls _normalize_model_name with either
|
||||
``"anthropic/claude-opus-4-6"`` (for 'advanced') or ``config.model`` (for
|
||||
'standard'). These tests verify the OpenRouter/provider-prefix stripping
|
||||
that keeps the value compatible with the Claude CLI.
|
||||
``config.thinking_advanced_model`` (for 'advanced') or
|
||||
``config.thinking_standard_model`` (for 'standard'). These tests verify
|
||||
the OpenRouter/provider-prefix stripping that keeps the value compatible
|
||||
with the Claude CLI.
|
||||
"""
|
||||
|
||||
def test_strips_anthropic_prefix(self):
|
||||
|
||||
@@ -395,7 +395,7 @@ class TestResolveSdkModel:
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
thinking_standard_model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
@@ -412,7 +412,7 @@ class TestResolveSdkModel:
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
thinking_standard_model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
@@ -430,7 +430,7 @@ class TestResolveSdkModel:
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
thinking_standard_model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key=None,
|
||||
@@ -447,7 +447,7 @@ class TestResolveSdkModel:
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
thinking_standard_model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model="claude-sonnet-4-5-20250514",
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
@@ -462,7 +462,7 @@ class TestResolveSdkModel:
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="anthropic/claude-opus-4.6",
|
||||
thinking_standard_model="anthropic/claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
@@ -477,7 +477,7 @@ class TestResolveSdkModel:
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
model="claude-opus-4.6",
|
||||
thinking_standard_model="claude-opus-4.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
|
||||
@@ -779,7 +779,9 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
|
||||
# access. read_file also handles local tool-results and ephemeral reads.
|
||||
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
|
||||
_SDK_BUILTIN_ALWAYS = ["Task", "Agent", "WebSearch", "TodoWrite"]
|
||||
# WebSearch moved to ``SDK_DISALLOWED_TOOLS`` — routed through
|
||||
# ``mcp__copilot__web_search`` so cost tracking is unified across paths.
|
||||
_SDK_BUILTIN_ALWAYS = ["Task", "Agent", "TodoWrite"]
|
||||
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
|
||||
|
||||
# SDK built-in tools that must be explicitly blocked.
|
||||
@@ -805,6 +807,7 @@ _SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
|
||||
SDK_DISALLOWED_TOOLS = [
|
||||
"Bash",
|
||||
"WebFetch",
|
||||
"WebSearch",
|
||||
"AskUserQuestion",
|
||||
"Write",
|
||||
"Edit",
|
||||
|
||||
@@ -42,17 +42,18 @@ settings = Settings()
|
||||
|
||||
|
||||
def resolve_chat_model(tier: CopilotLlmModel | None) -> str:
|
||||
"""Return the configured OpenRouter model string for the given tier.
|
||||
"""Return the configured SDK model for the given tier.
|
||||
|
||||
Shared by the baseline (fast) and SDK (extended thinking) paths so
|
||||
both honor the same standard/advanced env-var configuration. ``None``
|
||||
and ``'standard'`` fall through to ``config.model``; ``'advanced'``
|
||||
uses ``config.advanced_model``. Keep this flat — if a third tier
|
||||
shows up later, extend here and both paths pick it up for free.
|
||||
The SDK (extended-thinking) path is Anthropic-only — the Claude Agent
|
||||
SDK CLI refuses non-Anthropic endpoints — so both SDK tiers resolve
|
||||
to the ``thinking_*_model`` cells. Baseline has its own resolver
|
||||
(``_resolve_baseline_model``) that reads the ``fast_*_model`` cells;
|
||||
the two paths diverge deliberately at the config layer so a cheaper
|
||||
baseline provider can't break SDK, or vice versa.
|
||||
"""
|
||||
if tier == "advanced":
|
||||
return config.advanced_model
|
||||
return config.model
|
||||
return config.thinking_advanced_model
|
||||
return config.thinking_standard_model
|
||||
|
||||
|
||||
_client: LangfuseAsyncOpenAI | None = None
|
||||
|
||||
@@ -46,6 +46,7 @@ from .run_sub_session import RunSubSessionTool
|
||||
from .search_docs import SearchDocsTool
|
||||
from .validate_agent import ValidateAgentGraphTool
|
||||
from .web_fetch import WebFetchTool
|
||||
from .web_search import WebSearchTool
|
||||
from .workspace_files import (
|
||||
DeleteWorkspaceFileTool,
|
||||
ListWorkspaceFilesTool,
|
||||
@@ -95,6 +96,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"get_agent_building_guide": GetAgentBuildingGuideTool(),
|
||||
# Web fetch for safe URL retrieval
|
||||
"web_fetch": WebFetchTool(),
|
||||
"web_search": WebSearchTool(),
|
||||
# Agent-browser multi-step automation (navigate, act, screenshot)
|
||||
"browser_navigate": BrowserNavigateTool(),
|
||||
"browser_act": BrowserActTool(),
|
||||
|
||||
@@ -7,8 +7,6 @@ tokens and then produce JSON that fails validation — wasting turns on
|
||||
auto-fix loops.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
@@ -21,12 +19,21 @@ def _session_with_messages(
|
||||
messages: list[ChatMessage],
|
||||
builder_graph_id: str | None = None,
|
||||
) -> ChatSession:
|
||||
"""Build a minimal ChatSession whose ``messages`` matches *messages*."""
|
||||
session = MagicMock(spec=ChatSession)
|
||||
"""Build a real ChatSession with the given messages.
|
||||
|
||||
Uses ``ChatSession.new`` + attribute reassignment rather than
|
||||
``MagicMock(spec=...)`` because the gate now calls
|
||||
``session.has_tool_been_called(...)`` and a ``spec`` mock
|
||||
returns a truthy ``MagicMock`` from that call, hiding real gate
|
||||
behaviour. A live ``ChatSession`` also correctly initialises the
|
||||
``_inflight_tool_calls`` PrivateAttr scratch buffer used by the
|
||||
in-turn announcement path.
|
||||
"""
|
||||
session = ChatSession.new(
|
||||
"test-user", dry_run=False, builder_graph_id=builder_graph_id
|
||||
)
|
||||
session.session_id = "test-session"
|
||||
session.messages = messages
|
||||
session.metadata = MagicMock()
|
||||
session.metadata.builder_graph_id = builder_graph_id
|
||||
return session
|
||||
|
||||
|
||||
@@ -124,6 +131,47 @@ def test_tool_name_surfaced_in_error(tool_name: str):
|
||||
assert tool_name in result.message
|
||||
|
||||
|
||||
def test_inflight_announcement_lets_gate_pass_within_same_turn():
    """Regression for the Kimi baseline loop.

    The guide call was dispatched earlier in the SAME turn and is only
    present in the in-flight announcement set — it has not been flushed
    into ``session.messages`` yet. The gate must still see it; otherwise
    a follow-up ``create_agent`` in the same turn re-fires the guard and
    the model loops retrying the guide.
    """
    user_message = ChatMessage(role="user", content="build something")
    session = _session_with_messages([user_message])

    # Simulate _baseline_tool_executor's announce.
    session.announce_inflight_tool_call("get_agent_building_guide")

    verdict = require_guide_read(session, "create_agent")
    assert verdict is None
|
||||
|
||||
|
||||
def test_inflight_clear_restores_gate_for_next_turn():
    """Clearing the in-flight buffer must re-arm the guard.

    End-of-turn cleanup drops the buffer so it can't leak into the
    *next* turn (e.g. a later turn that should legitimately require a
    fresh guide call because ``messages`` got compressed away).
    """
    user_message = ChatMessage(role="user", content="build")
    session = _session_with_messages([user_message])

    session.announce_inflight_tool_call("get_agent_building_guide")
    before_clear = require_guide_read(session, "create_agent")
    assert before_clear is None

    session.clear_inflight_tool_calls()

    # With the buffer cleared and no guide row in messages, the guard
    # fires again.
    after_clear = require_guide_read(session, "create_agent")
    assert isinstance(after_clear, ErrorResponse)
|
||||
|
||||
|
||||
def test_inflight_announcement_does_not_serialise_into_model_dump():
    """PrivateAttr invariant: the scratch buffer is process-local turn
    state and must never appear in ``model_dump()`` (and therefore never
    in the Redis cache payload or the DB)."""
    session = _session_with_messages([])
    session.announce_inflight_tool_call("get_agent_building_guide")

    dumped = session.model_dump()

    # Check both the private spelling and the public-looking one.
    for forbidden_key in ("_inflight_tool_calls", "inflight_tool_calls"):
        assert forbidden_key not in dumped
|
||||
|
||||
|
||||
def test_builder_bound_session_bypasses_gate():
|
||||
"""Builder-bound sessions receive the guide via <builder_context> on
|
||||
every turn, so the tool-call gate is unnecessary and only wastes a
|
||||
|
||||
@@ -787,22 +787,18 @@ def _resolve_discriminated_credentials(
|
||||
_AGENT_GUIDE_TOOL_NAME = "get_agent_building_guide"
|
||||
|
||||
|
||||
def _guide_read_in_session(session: ChatSession) -> bool:
|
||||
"""True if this session's assistant messages include a guide tool call."""
|
||||
for msg in reversed(session.messages):
|
||||
if msg.role != "assistant" or not msg.tool_calls:
|
||||
continue
|
||||
for tc in msg.tool_calls:
|
||||
name = tc.get("function", {}).get("name") or tc.get("name")
|
||||
if name == _AGENT_GUIDE_TOOL_NAME:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def require_guide_read(session: ChatSession, tool_name: str):
|
||||
"""Return an ErrorResponse if the guide hasn't been loaded this session.
|
||||
|
||||
Import inline to keep ``helpers.py`` free of tool-response imports.
|
||||
Uses :meth:`ChatSession.has_tool_been_called` which checks both the
|
||||
persisted ``messages`` list (session-wide) and the in-flight
|
||||
announcement buffer — so a guide call dispatched earlier in the
|
||||
*current* turn (before ``session.messages`` flushes at turn end) is
|
||||
recognised too. Otherwise a second tool in the same turn would
|
||||
re-fire this guard despite the guide having been called — seen on
|
||||
Kimi K2.6 in particular because its aggressive tool-call chaining
|
||||
exercises this path far more than Sonnet does.
|
||||
"""
|
||||
from .models import ErrorResponse # noqa: PLC0415 — avoid circular import
|
||||
|
||||
@@ -812,7 +808,7 @@ def require_guide_read(session: ChatSession, tool_name: str):
|
||||
# requiring one would waste a round-trip every turn.
|
||||
if session.metadata.builder_graph_id:
|
||||
return None
|
||||
if _guide_read_in_session(session):
|
||||
if session.has_tool_been_called(_AGENT_GUIDE_TOOL_NAME):
|
||||
return None
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
|
||||
@@ -79,6 +79,7 @@ class ResponseType(str, Enum):
|
||||
|
||||
# Web
|
||||
WEB_FETCH = "web_fetch"
|
||||
WEB_SEARCH = "web_search"
|
||||
|
||||
# Feature requests
|
||||
FEATURE_REQUEST_SEARCH = "feature_request_search"
|
||||
@@ -588,6 +589,30 @@ class WebFetchResponse(ToolResponseBase):
|
||||
truncated: bool = False
|
||||
|
||||
|
||||
class WebSearchResult(BaseModel):
    """A single hit inside a ``web_search`` tool response."""

    # Only ``title`` and ``url`` are required; the rest default to empty.
    title: str
    url: str
    snippet: str = ""
    page_age: str | None = None
|
||||
|
||||
|
||||
class WebSearchResponse(ToolResponseBase):
    """Payload returned by the ``web_search`` tool.

    Mirrors the shape of the SDK's native ``WebSearch`` tool so the LLM
    sees a consistent interface regardless of which path dispatched the
    call.
    """

    type: ResponseType = ResponseType.WEB_SEARCH
    query: str
    results: list[WebSearchResult] = Field(default_factory=list)
    # Number of searches the backend reports for this call (copied from
    # Anthropic's ``usage.server_tool_use``). Surfaced as metadata for
    # frontend debug panels; it is also what drives rate-limit / cost
    # tracking via ``persist_and_record_usage(provider="anthropic")``.
    search_requests: int = 0
|
||||
|
||||
|
||||
class BashExecResponse(ToolResponseBase):
|
||||
"""Response for bash_exec tool."""
|
||||
|
||||
|
||||
224
autogpt_platform/backend/backend/copilot/tools/web_search.py
Normal file
224
autogpt_platform/backend/backend/copilot/tools/web_search.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Web search tool — wraps Anthropic's server-side ``web_search`` beta.
|
||||
|
||||
Single entry point for web search on both SDK and baseline paths. The
|
||||
``web_search_20250305`` tool is server-side on Anthropic, so we call
|
||||
the Messages API directly regardless of which LLM invoked the copilot
|
||||
tool — OpenRouter can't proxy server-side tool execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, WebSearchResponse, WebSearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_WEB_SEARCH_DISPATCH_MODEL = "claude-haiku-4-5"
|
||||
_MAX_DISPATCH_TOKENS = 512
|
||||
_DEFAULT_MAX_RESULTS = 5
|
||||
_HARD_MAX_RESULTS = 20
|
||||
|
||||
|
||||
class WebSearchTool(BaseTool):
    """Search the public web and return cited results.

    ``_execute`` sends a single Anthropic Messages request that carries
    the server-side ``web_search_20250305`` tool, extracts the hits from
    the response, and records the call's token usage and estimated cost.
    """

    @property
    def name(self) -> str:
        """Name under which the tool is exposed."""
        return "web_search"

    @property
    def description(self) -> str:
        """Usage hint presented to the calling model."""
        return (
            "Search the web for live info (news, recent docs). Returns "
            "{title, url, snippet}; use web_fetch to deep-dive a URL. "
            "Prefer one targeted query over many reformulations."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON schema for the arguments accepted by ``_execute``."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query.",
                },
                "max_results": {
                    "type": "integer",
                    "description": (
                        f"Max results (default {_DEFAULT_MAX_RESULTS}, "
                        f"cap {_HARD_MAX_RESULTS})."
                    ),
                    "default": _DEFAULT_MAX_RESULTS,
                },
            },
            "required": ["query"],
        }

    @property
    def requires_auth(self) -> bool:
        """The tool does not require an authenticated user."""
        return False

    @property
    def is_available(self) -> bool:
        """True when an Anthropic API key is configured."""
        return bool(Settings().secrets.anthropic_api_key)

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        query: str = "",
        max_results: int = _DEFAULT_MAX_RESULTS,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Run one web search and return its results.

        Args:
            user_id: Forwarded to usage tracking.
            session: Current chat session; supplies the ``session_id``
                echoed on every response and is forwarded to usage
                tracking.
            query: Search query. Surrounding whitespace is stripped; an
                empty query is rejected without calling the API.
            max_results: Upper bound on returned hits. Values that cannot
                be converted to an integer fall back to the default; the
                result is clamped to ``1.._HARD_MAX_RESULTS``.
            **kwargs: Accepted and ignored.

        Returns:
            A ``WebSearchResponse`` on success, or an ``ErrorResponse``
            with ``error`` set to ``missing_query``,
            ``web_search_not_configured`` or ``web_search_failed``.
        """
        query = (query or "").strip()
        session_id = session.session_id if session else None
        if not query:
            return ErrorResponse(
                message="Please provide a non-empty search query.",
                error="missing_query",
                session_id=session_id,
            )

        try:
            max_results = int(max_results)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float("inf")) — reachable because
            # json.loads accepts a bare ``Infinity`` literal.
            max_results = _DEFAULT_MAX_RESULTS
        max_results = max(1, min(max_results, _HARD_MAX_RESULTS))

        api_key = Settings().secrets.anthropic_api_key
        if not api_key:
            return ErrorResponse(
                message=(
                    "Web search is unavailable — the deployment has no "
                    "Anthropic API key configured."
                ),
                error="web_search_not_configured",
                session_id=session_id,
            )

        client = AsyncAnthropic(api_key=api_key)
        try:
            resp = await client.messages.create(
                model=_WEB_SEARCH_DISPATCH_MODEL,
                max_tokens=_MAX_DISPATCH_TOKENS,
                tools=[
                    {
                        "type": "web_search_20250305",
                        "name": "web_search",
                        "max_uses": 1,
                    }
                ],
                messages=[
                    {
                        "role": "user",
                        "content": (
                            f"Use the web_search tool exactly once with the "
                            f"query {query!r} and then stop. Do not "
                            f"summarise — the caller parses the raw "
                            f"tool_result."
                        ),
                    }
                ],
            )
        except Exception as exc:
            # Tool boundary: any provider/network failure becomes a
            # structured error for the caller instead of propagating.
            logger.warning(
                "[web_search] Anthropic call failed for query=%r: %s", query, exc
            )
            return ErrorResponse(
                message=f"Web search failed: {exc}",
                error="web_search_failed",
                session_id=session_id,
            )

        results, search_requests = _extract_results(resp, limit=max_results)

        cost_usd = _estimate_cost_usd(resp, search_requests=search_requests)
        try:
            usage = getattr(resp, "usage", None)
            await persist_and_record_usage(
                session=session,
                user_id=user_id,
                prompt_tokens=getattr(usage, "input_tokens", 0) or 0,
                completion_tokens=getattr(usage, "output_tokens", 0) or 0,
                log_prefix="[web_search]",
                cost_usd=cost_usd,
                model=_WEB_SEARCH_DISPATCH_MODEL,
                provider="anthropic",
            )
        except Exception as exc:
            # Best-effort: a tracking failure must not discard results
            # that were already fetched.
            logger.warning("[web_search] usage tracking failed: %s", exc)

        return WebSearchResponse(
            message=f"Found {len(results)} result(s) for {query!r}.",
            query=query,
            results=results,
            search_requests=search_requests,
            session_id=session_id,
        )
|
||||
|
||||
|
||||
def _extract_results(resp: Any, *, limit: int) -> tuple[list[WebSearchResult], int]:
    """Pull results + server-side request count from an Anthropic response.

    Args:
        resp: Messages API response; read via ``getattr`` only, so any
            object with ``content`` / ``usage`` attributes works.
        limit: Maximum number of results to collect.

    Returns:
        ``(results, search_requests)`` where ``search_requests`` is
        ``usage.server_tool_use.web_search_requests`` or ``0`` when that
        counter is absent.
    """
    results: list[WebSearchResult] = []
    search_requests = 0

    for block in getattr(resp, "content", []) or []:
        if getattr(block, "type", None) != "web_search_tool_result":
            continue
        content = getattr(block, "content", None)
        # When the search itself fails, ``content`` is a single error
        # object rather than a list of hits — there is nothing to
        # extract, and it must not be iterated as if it were a list.
        if not isinstance(content, (list, tuple)):
            continue
        for item in content:
            if getattr(item, "type", None) != "web_search_result":
                continue
            if len(results) >= limit:
                break
            # Anthropic's ``web_search_result`` exposes only
            # ``title``/``url``/``page_age`` plus an opaque
            # ``encrypted_content`` blob that is meant for citation
            # round-tripping, not for display. There is no plain-text
            # snippet field in the current beta, so ``snippet`` is
            # deliberately left empty.
            results.append(
                WebSearchResult(
                    title=getattr(item, "title", "") or "",
                    url=getattr(item, "url", "") or "",
                    snippet="",
                    page_age=getattr(item, "page_age", None),
                )
            )

    usage = getattr(resp, "usage", None)
    server_tool_use = getattr(usage, "server_tool_use", None) if usage else None
    if server_tool_use is not None:
        search_requests = getattr(server_tool_use, "web_search_requests", 0) or 0

    return results, search_requests
|
||||
|
||||
|
||||
# Update when Anthropic revises pricing.
|
||||
_COST_PER_SEARCH_USD = 0.010 # $10 per 1,000 web_search requests
|
||||
_HAIKU_INPUT_USD_PER_MTOK = 1.0
|
||||
_HAIKU_OUTPUT_USD_PER_MTOK = 5.0
|
||||
|
||||
|
||||
def _estimate_cost_usd(resp: Any, *, search_requests: int) -> float:
|
||||
"""Per-search fee × count + Haiku dispatch tokens."""
|
||||
usage = getattr(resp, "usage", None)
|
||||
input_tokens = getattr(usage, "input_tokens", 0) if usage else 0
|
||||
output_tokens = getattr(usage, "output_tokens", 0) if usage else 0
|
||||
|
||||
search_cost = search_requests * _COST_PER_SEARCH_USD
|
||||
inference_cost = (input_tokens / 1_000_000) * _HAIKU_INPUT_USD_PER_MTOK + (
|
||||
output_tokens / 1_000_000
|
||||
) * _HAIKU_OUTPUT_USD_PER_MTOK
|
||||
return round(search_cost + inference_cost, 6)
|
||||
@@ -0,0 +1,308 @@
|
||||
"""Tests for the ``web_search`` copilot tool.
|
||||
|
||||
Covers the result extractor + cost estimator as pure units (fed with
|
||||
synthetic Anthropic response objects), plus light integration tests that
|
||||
mock ``AsyncAnthropic.messages.create`` and confirm the handler plumbs
|
||||
through to ``persist_and_record_usage`` with the right provider tag.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .models import ErrorResponse, WebSearchResponse, WebSearchResult
|
||||
from .web_search import (
|
||||
_COST_PER_SEARCH_USD,
|
||||
WebSearchTool,
|
||||
_estimate_cost_usd,
|
||||
_extract_results,
|
||||
)
|
||||
|
||||
|
||||
def _fake_anthropic_response(
|
||||
*,
|
||||
results: list[dict] | None = None,
|
||||
search_requests: int = 1,
|
||||
input_tokens: int = 120,
|
||||
output_tokens: int = 40,
|
||||
) -> SimpleNamespace:
|
||||
"""Build a synthetic Anthropic Messages response.
|
||||
|
||||
Matches the shape produced by ``client.messages.create`` when the
|
||||
response includes a ``web_search_tool_result`` content block and
|
||||
``usage.server_tool_use.web_search_requests`` on the turn meter.
|
||||
"""
|
||||
content = []
|
||||
if results is not None:
|
||||
content.append(
|
||||
SimpleNamespace(
|
||||
type="web_search_tool_result",
|
||||
content=[
|
||||
SimpleNamespace(
|
||||
type="web_search_result",
|
||||
title=r.get("title", "untitled"),
|
||||
url=r.get("url", ""),
|
||||
encrypted_content=r.get("snippet", ""),
|
||||
page_age=r.get("page_age"),
|
||||
)
|
||||
for r in results
|
||||
],
|
||||
)
|
||||
)
|
||||
usage = SimpleNamespace(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
server_tool_use=SimpleNamespace(web_search_requests=search_requests),
|
||||
)
|
||||
return SimpleNamespace(content=content, usage=usage)
|
||||
|
||||
|
||||
class TestExtractResults:
    """``_extract_results`` is the only contact point with the shape of
    an Anthropic response; these tests pin its behaviour so an API shape
    change surfaces here first."""

    def test_extracts_title_url_page_age_and_drops_encrypted_snippet(self):
        # ``encrypted_content`` is an opaque blob that is not safe to
        # surface — the extractor must report snippet == "" whether or
        # not the blob is empty.
        kimi_hit = {
            "title": "Kimi K2.6 launch",
            "url": "https://example.com/kimi",
            "snippet": "EiJjbGF1ZGUtZW5jcnlwdGVkLWJsb2I=",
            "page_age": "1 day",
        }
        pricing_hit = {
            "title": "OpenRouter pricing",
            "url": "https://openrouter.ai/moonshotai/kimi-k2.6",
            "snippet": "",
        }
        resp = _fake_anthropic_response(results=[kimi_hit, pricing_hit])

        out, requests = _extract_results(resp, limit=10)

        assert requests == 1
        assert len(out) == 2
        first, second = out[0], out[1]
        assert first.title == "Kimi K2.6 launch"
        assert first.url == "https://example.com/kimi"
        assert first.snippet == ""
        assert first.page_age == "1 day"
        assert second.snippet == ""

    def test_limit_caps_returned_results(self):
        many = [{"title": f"r{i}", "url": f"https://e/{i}"} for i in range(10)]
        resp = _fake_anthropic_response(results=many)

        out, _ = _extract_results(resp, limit=3)

        assert len(out) == 3
        assert [hit.title for hit in out] == ["r0", "r1", "r2"]

    def test_missing_content_returns_empty(self):
        resp = SimpleNamespace(content=[], usage=None)

        out, requests = _extract_results(resp, limit=10)

        assert out == []
        assert requests == 0

    def test_non_search_blocks_are_ignored(self):
        text_block = SimpleNamespace(type="text", text="Here's what I found...")
        search_block = SimpleNamespace(
            type="web_search_tool_result",
            content=[
                SimpleNamespace(
                    type="web_search_result",
                    title="real",
                    url="https://real.example",
                    encrypted_content="body",
                    page_age=None,
                )
            ],
        )
        resp = SimpleNamespace(content=[text_block, search_block], usage=None)

        out, _ = _extract_results(resp, limit=10)

        assert [hit.title for hit in out] == ["real"]
|
||||
|
||||
|
||||
class TestEstimateCostUsd:
    """Pin the per-search fee + Haiku inference math.

    The pricing constants in ``web_search.py`` are hard-coded (no live
    lookup). These tests reference ``_COST_PER_SEARCH_USD`` symbolically,
    so they guard the arithmetic built on the constants rather than the
    constants' absolute values.
    """

    def test_zero_searches_still_charges_inference(self):
        # Token counts are spelled out so the bound below is checkable
        # at a glance: 120 input + 40 output tokens at $1 / $5 per MTok
        # is $0.00032 — tiny but non-zero.
        resp = _fake_anthropic_response(
            results=[], search_requests=0, input_tokens=120, output_tokens=40
        )
        cost = _estimate_cost_usd(resp, search_requests=0)
        assert 0 < cost < 0.001

    def test_single_search_fee_dominates(self):
        resp = _fake_anthropic_response(
            results=[{"title": "x", "url": "https://e"}],
            search_requests=1,
            input_tokens=100,
            output_tokens=20,
        )
        cost = _estimate_cost_usd(resp, search_requests=1)
        # ~$0.010 search + trivial inference — total still ~1 cent.
        assert cost >= _COST_PER_SEARCH_USD
        assert cost < _COST_PER_SEARCH_USD + 0.001

    def test_three_searches_linear_in_count(self):
        # Zero tokens isolates the per-search fee.
        resp = _fake_anthropic_response(
            results=[], search_requests=3, input_tokens=0, output_tokens=0
        )
        cost = _estimate_cost_usd(resp, search_requests=3)
        assert cost == pytest.approx(3 * _COST_PER_SEARCH_USD)
|
||||
|
||||
|
||||
class TestWebSearchToolDispatch:
|
||||
"""Lightweight integration test: mock the Anthropic client, confirm
|
||||
the handler returns a ``WebSearchResponse`` and the usage tracker is
|
||||
called with ``provider='anthropic'`` (not 'open_router', even on the
|
||||
baseline path — server-side web_search bills Anthropic regardless of
|
||||
the calling LLM's route)."""
|
||||
|
||||
def _session(self) -> ChatSession:
|
||||
s = ChatSession.new("test-user", dry_run=False)
|
||||
s.session_id = "sess-1"
|
||||
return s
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_response_with_results_and_tracks_cost(self, monkeypatch):
|
||||
fake_resp = _fake_anthropic_response(
|
||||
results=[
|
||||
{
|
||||
"title": "hello",
|
||||
"url": "https://example.com",
|
||||
"snippet": "greeting",
|
||||
}
|
||||
],
|
||||
search_requests=1,
|
||||
)
|
||||
mock_client = type(
|
||||
"MC",
|
||||
(),
|
||||
{
|
||||
"messages": type(
|
||||
"M", (), {"create": AsyncMock(return_value=fake_resp)}
|
||||
)()
|
||||
},
|
||||
)()
|
||||
|
||||
# Stub the Anthropic API key so ``is_available`` is True.
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.web_search.Settings",
|
||||
lambda: SimpleNamespace(
|
||||
secrets=SimpleNamespace(anthropic_api_key="sk-test")
|
||||
),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.web_search.AsyncAnthropic",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.web_search.persist_and_record_usage",
|
||||
new=AsyncMock(return_value=160),
|
||||
) as mock_track,
|
||||
):
|
||||
tool = WebSearchTool()
|
||||
result = await tool._execute(
|
||||
user_id="u1",
|
||||
session=self._session(),
|
||||
query="kimi k2.6 launch",
|
||||
max_results=5,
|
||||
)
|
||||
|
||||
assert isinstance(result, WebSearchResponse)
|
||||
assert result.query == "kimi k2.6 launch"
|
||||
assert len(result.results) == 1
|
||||
assert isinstance(result.results[0], WebSearchResult)
|
||||
assert result.search_requests == 1
|
||||
|
||||
# Cost tracker must have been called with provider="anthropic".
|
||||
assert mock_track.await_count == 1
|
||||
kwargs = mock_track.await_args.kwargs
|
||||
assert kwargs["provider"] == "anthropic"
|
||||
assert kwargs["model"] == "claude-haiku-4-5"
|
||||
assert kwargs["user_id"] == "u1"
|
||||
assert kwargs["cost_usd"] >= _COST_PER_SEARCH_USD
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_api_key_returns_error_without_calling_anthropic(
|
||||
self, monkeypatch
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.web_search.Settings",
|
||||
lambda: SimpleNamespace(secrets=SimpleNamespace(anthropic_api_key="")),
|
||||
)
|
||||
anthropic_stub = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.web_search.AsyncAnthropic",
|
||||
return_value=anthropic_stub,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.web_search.persist_and_record_usage",
|
||||
new=AsyncMock(),
|
||||
) as mock_track,
|
||||
):
|
||||
tool = WebSearchTool()
|
||||
assert tool.is_available is False
|
||||
result = await tool._execute(
|
||||
user_id="u1",
|
||||
session=self._session(),
|
||||
query="anything",
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "web_search_not_configured"
|
||||
anthropic_stub.messages.create.assert_not_called()
|
||||
mock_track.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
async def test_empty_query_rejected_without_api_call(self, monkeypatch):
    """A whitespace-only query is rejected before any Anthropic request.

    The API key is configured here, so the ``missing_query`` error comes
    from the query check rather than from missing credentials.
    """

    def _settings_with_key():
        # Stand-in for ``Settings()`` carrying a dummy, non-empty API key.
        return SimpleNamespace(
            secrets=SimpleNamespace(anthropic_api_key="sk-test")
        )

    monkeypatch.setattr(
        "backend.copilot.tools.web_search.Settings",
        _settings_with_key,
    )

    fake_client = AsyncMock()
    client_patch = patch(
        "backend.copilot.tools.web_search.AsyncAnthropic",
        return_value=fake_client,
    )

    with client_patch:
        search_tool = WebSearchTool()
        outcome = await search_tool._execute(
            user_id="u1",
            session=self._session(),
            query="   ",
        )

        assert isinstance(outcome, ErrorResponse)
        assert outcome.error == "missing_query"
        fake_client.messages.create.assert_not_called()
|
||||
|
||||
|
||||
class TestToolRegistryIntegration:
    """Registration checks for the ``web_search`` tool.

    The tool has to live in the registry under the name ``web_search`` so
    that the MCP layer exposes it as ``mcp__copilot__web_search``. That is
    the route the SDK path now dispatches to:
    ``sdk/tool_adapter.py::SDK_DISALLOWED_TOOLS`` blocks the CLI's native
    ``WebSearch`` in favour of the MCP route.
    """

    def test_web_search_is_in_tool_registry(self):
        """``TOOL_REGISTRY`` maps ``web_search`` to a ``WebSearchTool``."""
        from backend.copilot.tools import TOOL_REGISTRY as registry

        assert "web_search" in registry
        registered_tool = registry["web_search"]
        assert isinstance(registered_tool, WebSearchTool)

    def test_sdk_native_websearch_is_disallowed(self):
        """The CLI-native ``WebSearch`` appears in the SDK disallow list."""
        from backend.copilot.sdk.tool_adapter import (
            SDK_DISALLOWED_TOOLS as disallowed,
        )

        assert "WebSearch" in disallowed
|
||||
@@ -305,15 +305,58 @@ function getWebAccordionData(
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const url =
|
||||
getStringField(inp as Record<string, unknown>, "url", "query") ??
|
||||
"Web content";
|
||||
const query = getStringField(inp, "query");
|
||||
const url = getStringField(inp, "url") ?? query ?? "Web content";
|
||||
|
||||
const results = Array.isArray(output.results)
|
||||
? (output.results as Array<Record<string, unknown>>)
|
||||
: null;
|
||||
|
||||
if (results) {
|
||||
return {
|
||||
title: `${results.length} search result${results.length === 1 ? "" : "s"}`,
|
||||
description: query ? truncate(query, 80) : undefined,
|
||||
content: (
|
||||
<div className="space-y-3">
|
||||
{results.map((r, i) => {
|
||||
const title = getStringField(r, "title") ?? "(untitled)";
|
||||
const href = getStringField(r, "url") ?? "";
|
||||
const snippet = getStringField(r, "snippet");
|
||||
const pageAge = getStringField(r, "page_age");
|
||||
return (
|
||||
<div key={i} className="text-sm">
|
||||
{href ? (
|
||||
<a
|
||||
href={href}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="font-medium text-blue-600 hover:underline"
|
||||
>
|
||||
{title}
|
||||
</a>
|
||||
) : (
|
||||
<span className="font-medium">{title}</span>
|
||||
)}
|
||||
{href && (
|
||||
<div className="text-xs text-slate-500">
|
||||
{truncate(href, 100)}
|
||||
</div>
|
||||
)}
|
||||
{snippet && <p className="mt-0.5 text-slate-700">{snippet}</p>}
|
||||
{pageAge && (
|
||||
<div className="mt-0.5 text-xs text-slate-400">{pageAge}</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
// Try direct string fields first, then MCP content blocks, then raw JSON
|
||||
let content = getStringField(output, "content", "text", "_raw");
|
||||
if (!content) content = extractMcpText(output);
|
||||
if (!content) {
|
||||
// Fallback: render the raw JSON so the accordion isn't empty
|
||||
try {
|
||||
const raw = JSON.stringify(output, null, 2);
|
||||
if (raw !== "{}") content = raw;
|
||||
@@ -327,11 +370,7 @@ function getWebAccordionData(
|
||||
const message = getStringField(output, "message");
|
||||
|
||||
return {
|
||||
title: statusCode
|
||||
? `Response (${statusCode})`
|
||||
: url
|
||||
? "Web fetch"
|
||||
: "Search results",
|
||||
title: statusCode ? `Response (${statusCode})` : "Web fetch",
|
||||
description: truncate(url, 80),
|
||||
content: content ? (
|
||||
<ContentCodeBlock>{content}</ContentCodeBlock>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { render, screen } from "@/tests/integrations/test-utils";
|
||||
import { fireEvent, render, screen } from "@/tests/integrations/test-utils";
|
||||
import { GenericTool } from "../GenericTool";
|
||||
|
||||
function makePart(overrides: Record<string, unknown> = {}): ToolUIPart {
|
||||
@@ -136,4 +136,181 @@ describe("GenericTool", () => {
|
||||
const trigger2 = screen.getByRole("button", { expanded: false });
|
||||
expect(trigger2.textContent).toContain("completed");
|
||||
});
|
||||
|
||||
describe("web_search results rendering", () => {
|
||||
/**
 * Build a completed `web_search` tool part for the rendering tests.
 *
 * @param results - Result rows placed in `output.results`.
 * @param query - Search query, mirrored into both `input` and `output`.
 * @returns An object cast to `ToolUIPart` in the `output-available` state.
 */
function makeWebSearchPart(
  results: Array<Record<string, unknown>>,
  query = "kimi k2.6",
): ToolUIPart {
  const output = {
    type: "web_search_response",
    results,
    query,
    search_requests: 1,
  };
  const part = {
    type: "tool-web_search",
    toolCallId: "call-web-1",
    state: "output-available",
    input: { query },
    output,
  };
  // The literal is looser than ToolUIPart, hence the double cast.
  return part as unknown as ToolUIPart;
}
|
||||
|
||||
it("renders an 'N search results' title and shows the query in the description", () => {
  // Two hits: the first carries every optional field, the second omits page_age.
  const hits: Array<Record<string, unknown>> = [
    {
      title: "Kimi K2.6 release notes",
      url: "https://example.com/kimi",
      snippet: "A fast model",
      page_age: "2 days ago",
    },
    {
      title: "Second result",
      url: "https://example.com/two",
      snippet: "Another snippet",
    },
  ];
  render(<GenericTool part={makeWebSearchPart(hits)} />);

  // The collapsed header shows the result count and the query.
  const header = screen.getByRole("button", { expanded: false });
  expect(header.textContent).toContain("2 search results");
  expect(header.textContent).toContain("kimi k2.6");

  // Expand the accordion so the individual results are rendered.
  fireEvent.click(header);

  const kimiLink = screen.getByRole("link", {
    name: "Kimi K2.6 release notes",
  }) as HTMLAnchorElement;
  expect(kimiLink.getAttribute("href")).toBe("https://example.com/kimi");
  expect(kimiLink.getAttribute("target")).toBe("_blank");
  expect(kimiLink.getAttribute("rel")).toBe("noopener noreferrer");
  expect(screen.queryByText("A fast model")).not.toBeNull();
  expect(screen.queryByText("2 days ago")).not.toBeNull();

  const otherLink = screen.getByRole("link", {
    name: "Second result",
  }) as HTMLAnchorElement;
  expect(otherLink.getAttribute("href")).toBe("https://example.com/two");
});
|
||||
|
||||
it("uses singular 'search result' when there is exactly one result", () => {
  const soleHit = {
    title: "Only result",
    url: "https://example.com/only",
    snippet: "Lone snippet",
  };
  render(<GenericTool part={makeWebSearchPart([soleHit])} />);

  // The header must read "1 search result", without the plural "s".
  const header = screen.getByRole("button", { expanded: false });
  const headerText = header.textContent;
  expect(headerText).toContain("1 search result");
  expect(headerText).not.toContain("1 search results");
});
|
||||
|
||||
it("handles an empty results array (0 search results)", () => {
  // An empty array still takes the search-results branch and is counted as zero.
  const emptyPart = makeWebSearchPart([]);
  render(<GenericTool part={emptyPart} />);

  const header = screen.getByRole("button", { expanded: false });
  expect(header.textContent).toContain("0 search results");
});
|
||||
|
||||
it("renders an untitled non-link when a result has no url", () => {
  // The result has a title and a snippet but deliberately no `url`.
  const urlLessHit = { title: "No URL entry", snippet: "Just text" };
  render(<GenericTool part={makeWebSearchPart([urlLessHit])} />);

  const header = screen.getByRole("button", { expanded: false });
  fireEvent.click(header);

  // Without a url nothing is rendered as an anchor, but the text still shows.
  expect(screen.queryByRole("link")).toBeNull();
  expect(screen.queryByText("No URL entry")).not.toBeNull();
  expect(screen.queryByText("Just text")).not.toBeNull();
});
|
||||
|
||||
it("shows subtitle 'Searched \"…\"' once web_search output is available", () => {
  const { container } = render(
    <GenericTool
      part={makeWebSearchPart(
        [
          {
            title: "Kimi K2.6 release notes",
            url: "https://example.com/kimi",
            snippet: "A fast model",
          },
        ],
        "kimi k2.6",
      )}
    />,
  );
  // MorphingTextAnimation splits each character into its own span and
  // substitutes spaces with non-breaking spaces (U+00A0), so assert on a
  // normalized textContent rather than the raw substring. The regex uses an
  // explicit \u00a0 escape so the character being replaced is visible in
  // source and cannot be silently turned into a plain space by an editor,
  // formatter, or copy/paste.
  const normalized = (container.textContent ?? "").replace(/\u00a0/g, " ");
  expect(normalized).toContain('Searched "kimi k2.6"');
});
|
||||
|
||||
it("uses '(untitled)' when a search result has no title", () => {
  // The hit has a url and a snippet only; the link text falls back to a placeholder.
  const titlelessHit = { url: "https://example.com/x", snippet: "No title here" };
  render(<GenericTool part={makeWebSearchPart([titlelessHit])} />);

  const header = screen.getByRole("button", { expanded: false });
  fireEvent.click(header);

  const placeholderLink = screen.getByRole("link", {
    name: "(untitled)",
  }) as HTMLAnchorElement;
  expect(placeholderLink.getAttribute("href")).toBe("https://example.com/x");
});
|
||||
});
|
||||
|
||||
describe("getWebAccordionData non-results fallback", () => {
|
||||
/**
 * Build a completed `web_fetch` tool part around an arbitrary output payload.
 *
 * @param output - Value stored verbatim as the part's `output`.
 * @returns An object cast to `ToolUIPart` in the `output-available` state,
 *   with a fixed input url of `https://example.com/page`.
 */
function makeWebFetchPart(output: Record<string, unknown>): ToolUIPart {
  const input = { url: "https://example.com/page" };
  const part = {
    type: "tool-web_fetch",
    toolCallId: "call-fetch-1",
    state: "output-available",
    input,
    output,
  };
  // The literal is looser than ToolUIPart, hence the double cast.
  return part as unknown as ToolUIPart;
}
|
||||
|
||||
it("renders 'Web fetch' title when output has content instead of results", () => {
  // A plain string `content` field and no `results` array.
  const fetchPart = makeWebFetchPart({ content: "fetched body" });
  render(<GenericTool part={fetchPart} />);

  const header = screen.getByRole("button", { expanded: false });
  expect(header.textContent).toContain("Web fetch");

  // Expanding the accordion reveals the fetched body.
  fireEvent.click(header);
  expect(screen.queryByText("fetched body")).not.toBeNull();
});
|
||||
|
||||
it("renders 'Response (N)' title when output has a status_code", () => {
  const failedFetch = makeWebFetchPart({
    status_code: 404,
    message: "not found",
  });
  render(<GenericTool part={failedFetch} />);

  // The status code takes over the accordion title.
  const header = screen.getByRole("button", { expanded: false });
  expect(header.textContent).toContain("Response (404)");
});
|
||||
|
||||
it("falls back to MCP text blocks when direct content is absent", () => {
  // `content` is an array of MCP-style blocks rather than a plain string.
  const mcpOutput = { content: [{ type: "text", text: "mcp body" }] };
  render(<GenericTool part={makeWebFetchPart(mcpOutput)} />);

  const header = screen.getByRole("button", { expanded: false });
  fireEvent.click(header);

  expect(screen.queryByText("mcp body")).not.toBeNull();
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -22,6 +22,11 @@ describe("extractToolName", () => {
|
||||
const part = { type: "Read" } as unknown as ToolUIPart;
|
||||
expect(extractToolName(part)).toBe("Read");
|
||||
});
|
||||
|
||||
it("strips the tool- prefix for web_search", () => {
  // Only `type` matters here, so the rest of the part is left out.
  const webSearchPart = { type: "tool-web_search" } as unknown as ToolUIPart;
  const toolName = extractToolName(webSearchPart);
  expect(toolName).toBe("web_search");
});
|
||||
});
|
||||
|
||||
describe("formatToolName", () => {
|
||||
@@ -60,8 +65,9 @@ describe("getToolCategory", () => {
|
||||
expect(getToolCategory("bash_exec")).toBe("bash");
|
||||
});
|
||||
|
||||
it("returns 'web' for web_fetch, WebSearch, WebFetch", () => {
|
||||
it("returns 'web' for web_fetch, web_search, WebSearch, WebFetch", () => {
|
||||
expect(getToolCategory("web_fetch")).toBe("web");
|
||||
expect(getToolCategory("web_search")).toBe("web");
|
||||
expect(getToolCategory("WebSearch")).toBe("web");
|
||||
expect(getToolCategory("WebFetch")).toBe("web");
|
||||
});
|
||||
@@ -229,6 +235,50 @@ describe("getAnimationText", () => {
|
||||
expect(getAnimationText(part, "web")).toBe('Searching "test query"');
|
||||
});
|
||||
|
||||
it("shows searching text for web_search with a query summary", () => {
  // While input is still streaming, the query is quoted in the progress text.
  const streamingPart = makePart({
    type: "tool-web_search",
    state: "input-streaming",
    input: { query: "kimi k2.6" },
  });
  const text = getAnimationText(streamingPart, "web");
  expect(text).toBe('Searching "kimi k2.6"');
});
|
||||
|
||||
it("falls back to generic searching text for web_search with no query", () => {
  // No `input` override is passed, so there is no query to summarise.
  const streamingPart = makePart({
    type: "tool-web_search",
    state: "input-streaming",
  });
  const text = getAnimationText(streamingPart, "web");
  expect(text).toBe("Searching the web…");
});
|
||||
|
||||
it("shows completed text for web_search with a query summary", () => {
  // Once output is available, the text switches to the past tense.
  const finishedPart = makePart({
    type: "tool-web_search",
    state: "output-available",
    input: { query: "kimi k2.6" },
    output: { results: [] },
  });
  const text = getAnimationText(finishedPart, "web");
  expect(text).toBe('Searched "kimi k2.6"');
});
|
||||
|
||||
it("falls back to generic completed text for web_search with no query", () => {
  // Output is present but no `input` override is passed.
  const finishedPart = makePart({
    type: "tool-web_search",
    state: "output-available",
    output: { results: [] },
  });
  const text = getAnimationText(finishedPart, "web");
  expect(text).toBe("Web search completed");
});
|
||||
|
||||
it("shows error text for web_search failure", () => {
  // In the error state web_search is reported as a search failure.
  const erroredPart = makePart({
    type: "tool-web_search",
    state: "output-error",
  });
  const text = getAnimationText(erroredPart, "web");
  expect(text).toBe("Search failed");
});
|
||||
|
||||
it("shows fetching text for web_fetch", () => {
|
||||
const part = makePart({
|
||||
type: "tool-web_fetch",
|
||||
|
||||
@@ -60,6 +60,7 @@ export function getToolCategory(toolName: string): ToolCategory {
|
||||
case "bash_exec":
|
||||
return "bash";
|
||||
case "web_fetch":
|
||||
case "web_search":
|
||||
case "WebSearch":
|
||||
case "WebFetch":
|
||||
return "web";
|
||||
@@ -114,6 +115,7 @@ function getInputSummary(toolName: string, input: unknown): string | null {
|
||||
case "web_fetch":
|
||||
case "WebFetch":
|
||||
return typeof inp.url === "string" ? inp.url : null;
|
||||
case "web_search":
|
||||
case "WebSearch":
|
||||
return typeof inp.query === "string" ? inp.query : null;
|
||||
case "browser_navigate":
|
||||
@@ -220,7 +222,7 @@ export function getAnimationText(
|
||||
? `Running: ${shortSummary}`
|
||||
: "Running command\u2026";
|
||||
case "web":
|
||||
if (toolName === "WebSearch") {
|
||||
if (toolName === "WebSearch" || toolName === "web_search") {
|
||||
return shortSummary
|
||||
? `Searching "${shortSummary}"`
|
||||
: "Searching the web\u2026";
|
||||
@@ -282,7 +284,7 @@ export function getAnimationText(
|
||||
// exit status here would just double up.
|
||||
return shortSummary ? `Ran: ${shortSummary}` : "Command completed";
|
||||
case "web":
|
||||
if (toolName === "WebSearch") {
|
||||
if (toolName === "WebSearch" || toolName === "web_search") {
|
||||
return shortSummary
|
||||
? `Searched "${shortSummary}"`
|
||||
: "Web search completed";
|
||||
@@ -352,7 +354,9 @@ export function getAnimationText(
|
||||
case "bash":
|
||||
return "Command failed";
|
||||
case "web":
|
||||
return toolName === "WebSearch" ? "Search failed" : "Fetch failed";
|
||||
return toolName === "WebSearch" || toolName === "web_search"
|
||||
? "Search failed"
|
||||
: "Fetch failed";
|
||||
case "browser":
|
||||
return "Browser action failed";
|
||||
default:
|
||||
|
||||
@@ -14642,6 +14642,7 @@
|
||||
"browser_screenshot",
|
||||
"bash_exec",
|
||||
"web_fetch",
|
||||
"web_search",
|
||||
"feature_request_search",
|
||||
"feature_request_created",
|
||||
"memory_store",
|
||||
|
||||
Reference in New Issue
Block a user