Merge branch 'dev' of https://github.com/Significant-Gravitas/AutoGPT into fix/copilot-static-system-prompt

Zamil Majdy
2026-04-15 19:01:57 +07:00
30 changed files with 1761 additions and 166 deletions

View File

@@ -15,9 +15,10 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
from backend.copilot.db import get_chat_messages_paginated
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.message_dedup import acquire_dedup_lock
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -140,6 +141,11 @@ class StreamChatRequest(BaseModel):
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
"If None, uses the server default (extended_thinking).",
)
model: CopilotLlmModel | None = Field(
default=None,
description="Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. "
"If None, the server applies per-user LD targeting then falls back to config.",
)
class CreateSessionRequest(BaseModel):
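
For reference, a hypothetical client call exercising the new field (route path and payload shape taken from the tests later in this diff; the base URL, httpx usage, and message text are illustrative only):

import httpx

# Sketch of a POST requesting the advanced tier; the base URL is a
# placeholder, and "model" carries a tier name, not a concrete model id.
response = httpx.post(
    "http://localhost:8000/sessions/sess-1/stream",
    json={
        "message": "Build me a weather agent",
        "is_user_message": True,
        "mode": "extended_thinking",
        "model": "advanced",
    },
)
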
@@ -377,6 +383,31 @@ async def delete_session(
return Response(status_code=204)
@router.delete(
"/sessions/{session_id}/stream",
dependencies=[Security(auth.requires_user)],
status_code=204,
)
async def disconnect_session_stream(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> Response:
"""Disconnect all active SSE listeners for a session.
Called by the frontend when the user switches away from a chat so the
backend releases XREAD listeners immediately rather than waiting for
the 5-10 s timeout.
"""
session = await get_chat_session(session_id, user_id)
if not session:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
await stream_registry.disconnect_all_listeners(session_id)
return Response(status_code=204)
@router.patch(
"/sessions/{session_id}/title",
summary="Update session title",
@@ -811,6 +842,9 @@ async def stream_chat_post(
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
# Capture the original message text BEFORE any mutation (attachment enrichment)
# so the idempotency hash is stable across retries.
original_message = request.message
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
@@ -839,60 +873,91 @@ async def stream_chat_post(
)
request.message += files_block
# ── Idempotency guard ────────────────────────────────────────────────────
# Blocks duplicate executor tasks from concurrent/retried POSTs.
# See backend/copilot/message_dedup.py for the full lifecycle description.
dedup_lock = None
if request.is_user_message:
dedup_lock = await acquire_dedup_lock(
session_id, original_message, sanitized_file_ids
)
if dedup_lock is None and (original_message or sanitized_file_ids):
async def _empty_sse() -> AsyncGenerator[str, None]:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return StreamingResponse(
_empty_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
"x-vercel-ai-ui-message-stream": "v1",
},
)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
#
# If any of these operations raises, release the dedup lock before propagating
# so subsequent retries are not blocked for 30 s.
try:
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
# Create a task in the stream registry for reconnection support
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
session_create_start = time.perf_counter()
await stream_registry.create_session(
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream",
tool_name="chat",
turn_id=turn_id,
)
logger.info(
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
}
},
)
session_create_start = time.perf_counter()
await stream_registry.create_session(
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream",
tool_name="chat",
turn_id=turn_id,
)
logger.info(
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
}
},
)
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
subscribe_from_id = "0-0"
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
message=request.message,
turn_id=turn_id,
is_user_message=request.is_user_message,
context=request.context,
file_ids=sanitized_file_ids,
mode=request.mode,
)
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
message=request.message,
turn_id=turn_id,
is_user_message=request.is_user_message,
context=request.context,
file_ids=sanitized_file_ids,
mode=request.mode,
model=request.model,
)
except Exception:
if dedup_lock:
await dedup_lock.release()
raise
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
@@ -900,6 +965,9 @@ async def stream_chat_post(
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
subscribe_from_id = "0-0"
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
import time as time_module
@@ -913,6 +981,12 @@ async def stream_chat_post(
subscriber_queue = None
first_chunk_yielded = False
chunks_yielded = 0
# True for every exit path except GeneratorExit (client disconnect).
# On disconnect the backend turn is still running — releasing the lock
# there would reopen the infra-retry duplicate window. The 30 s TTL
# is the fallback. All other exits (normal finish, early return, error)
# should release so the user can re-send the same message.
release_dedup_lock_on_exit = True
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
@@ -924,8 +998,7 @@ async def stream_chat_post(
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
return # finally releases dedup_lock
# Read from the subscriber queue and yield to SSE
logger.info(
@@ -954,7 +1027,6 @@ async def stream_chat_post(
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
total_time = time_module.perf_counter() - event_gen_start
logger.info(
@@ -968,7 +1040,8 @@ async def stream_chat_post(
}
},
)
break
break # finally releases dedup_lock
except asyncio.TimeoutError:
yield StreamHeartbeat().to_sse()
@@ -983,7 +1056,7 @@ async def stream_chat_post(
}
},
)
pass # Client disconnected - background task continues
release_dedup_lock_on_exit = False
except Exception as e:
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
logger.error(
@@ -998,7 +1071,10 @@ async def stream_chat_post(
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
# finally releases dedup_lock
finally:
if dedup_lock and release_dedup_lock_on_exit:
await dedup_lock.release()
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:

View File

@@ -133,14 +133,30 @@ def test_stream_chat_rejects_too_many_file_ids():
assert response.status_code == 422
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
def _mock_stream_internals(
mocker: pytest_mock.MockerFixture,
*,
redis_set_returns: object = True,
):
"""Mock the async internals of stream_chat_post so tests can exercise
validation and enrichment logic without needing Redis/RabbitMQ."""
validation and enrichment logic without needing Redis/RabbitMQ.
Args:
redis_set_returns: Value returned by the mocked Redis ``set`` call.
``True`` (default) simulates a fresh key (new message);
``None`` simulates a collision (duplicate blocked).
Returns:
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
callers can make additional assertions about side-effects.
"""
import types
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
mock_save = mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
@@ -150,7 +166,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mocker.patch(
mock_enqueue = mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
@@ -158,9 +174,18 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=redis_set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
return ns
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
_mock_stream_internals(mocker)
# Patch workspace lookup as imported by the routes module
@@ -189,7 +214,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
# ─── UUID format filtering ─────────────────────────────────────────────
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockerFixture):
"""Non-UUID strings in file_ids should be silently filtered out
and NOT passed to the database query."""
_mock_stream_internals(mocker)
@@ -228,7 +253,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
# ─── Cross-workspace file_ids ─────────────────────────────────────────
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockerFixture):
"""The batch query should scope to the user's workspace."""
_mock_stream_internals(mocker)
mocker.patch(
@@ -257,7 +282,7 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
# ─── Rate limit → 429 ─────────────────────────────────────────────────
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerFixture):
"""When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
@@ -278,7 +303,9 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFix
assert "daily" in response.json()["detail"].lower()
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
def test_stream_chat_returns_429_on_weekly_rate_limit(
mocker: pytest_mock.MockerFixture,
):
"""When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
@@ -301,7 +328,7 @@ def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFi
assert "resets in" in detail
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture):
"""The 429 response detail should include the human-readable reset time."""
from backend.copilot.rate_limit import RateLimitExceeded
@@ -677,3 +704,279 @@ class TestStripInjectedContext:
result = _strip_injected_context(msg)
# Without a role, the helper short-circuits without touching content.
assert result["content"] == "hello"
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
mocker: pytest_mock.MockerFixture,
) -> None:
"""A second POST with the same message within the 30-s window must return
an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the
turn complete without creating a ghost response."""
# redis_set_returns=None simulates a collision: the NX key already exists.
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-dup/stream",
json={"message": "duplicate message", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
# The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator.
assert '"finish"' in body
assert "[DONE]" in body
# The empty SSE response must include the AI SDK protocol header so the
# frontend treats it as a valid stream and marks the turn complete.
assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"
# The duplicate guard must prevent save/enqueue side effects.
ns.save.assert_not_called()
ns.enqueue.assert_not_called()
def test_stream_chat_first_post_proceeds_normally(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The first POST (Redis NX key set successfully) must proceed through the
normal streaming path — no early return."""
ns = _mock_stream_internals(mocker, redis_set_returns=True)
response = client.post(
"/sessions/sess-new/stream",
json={"message": "first message", "is_user_message": True},
)
assert response.status_code == 200
# Redis set must have been called once with the NX flag.
ns.redis.set.assert_called_once()
call_kwargs = ns.redis.set.call_args
assert call_kwargs.kwargs.get("nx") is True
def test_stream_chat_dedup_skipped_for_non_user_messages(
mocker: pytest_mock.MockerFixture,
) -> None:
"""System/assistant messages (is_user_message=False) bypass the dedup
guard — they are injected programmatically and must always be processed."""
ns = _mock_stream_internals(mocker, redis_set_returns=None)
response = client.post(
"/sessions/sess-sys/stream",
json={"message": "system context", "is_user_message": False},
)
# Even though redis_set_returns=None (would block a user message),
# the endpoint must proceed because is_user_message=False.
assert response.status_code == 200
ns.redis.set.assert_not_called()
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup hash must be computed from the original request message,
not the mutated version that has the [Attached files] block appended.
A file_id is sent so the route actually appends the [Attached files] block,
exercising the mutation path — the hash must still match the original text."""
import hashlib
ns = _mock_stream_internals(mocker, redis_set_returns=True)
file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
# Mock workspace + prisma so the attachment block is actually appended.
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
fake_file = type(
"F",
(),
{
"id": file_id,
"name": "doc.pdf",
"mimeType": "application/pdf",
"sizeBytes": 1024,
},
)()
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
response = client.post(
"/sessions/sess-hash/stream",
json={
"message": "plain message",
"is_user_message": True,
"file_ids": [file_id],
},
)
assert response.status_code == 200
ns.redis.set.assert_called_once()
call_args = ns.redis.set.call_args
dedup_key = call_args.args[0]
# Hash must use the original message + sorted file IDs, not the mutated text.
expected_hash = hashlib.sha256(
f"sess-hash:plain message:{file_id}".encode()
).hexdigest()[:16]
expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
assert dedup_key == expected_key, (
f"Dedup key {dedup_key!r} does not match expected {expected_key!r}"
"hash may be using mutated message or wrong inputs"
)
def test_stream_chat_dedup_key_released_after_stream_finish(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The dedup Redis key must be deleted after the turn completes (when
subscriber_queue is None the route yields StreamFinish immediately and
should release the key so the user can re-send the same message)."""
from unittest.mock import AsyncMock as _AsyncMock
# Set up all internals manually so we can control subscribe_to_session.
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
# None → early-finish path: StreamFinish yielded immediately, dedup key released.
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
response = client.post(
"/sessions/sess-finish/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
body = response.text
assert '"finish"' in body
# The dedup key must be released so intentional re-sends are allowed.
mock_redis.delete.assert_called_once()
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
mocker: pytest_mock.MockerFixture,
) -> None:
"""The route must not crash when the dedup Redis delete fails on the
subscriber_queue-is-None early-finish path (except Exception: pass)."""
from unittest.mock import AsyncMock as _AsyncMock
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = _AsyncMock(return_value=None)
mock_registry.subscribe_to_session = _AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mock_redis = mocker.AsyncMock()
mock_redis.set = _AsyncMock(return_value=True)
# Make the delete raise so the except-pass branch is exercised.
mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=_AsyncMock,
return_value=mock_redis,
)
# Should not raise even though delete fails.
response = client.post(
"/sessions/sess-finish-err/stream",
json={"message": "hello", "is_user_message": True},
)
assert response.status_code == 200
assert '"finish"' in response.text
# delete must have been attempted — the except-pass branch silenced the error.
mock_redis.delete.assert_called_once()
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
def test_disconnect_stream_returns_204_and_awaits_registry(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_session = MagicMock()
mocker.patch(
"backend.api.features.chat.routes.get_chat_session",
new_callable=AsyncMock,
return_value=mock_session,
)
mock_disconnect = mocker.patch(
"backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
new_callable=AsyncMock,
return_value=2,
)
response = client.delete("/sessions/sess-1/stream")
assert response.status_code == 204
mock_disconnect.assert_awaited_once_with("sess-1")
def test_disconnect_stream_returns_404_when_session_missing(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mocker.patch(
"backend.api.features.chat.routes.get_chat_session",
new_callable=AsyncMock,
return_value=None,
)
mock_disconnect = mocker.patch(
"backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
new_callable=AsyncMock,
)
response = client.delete("/sessions/unknown-session/stream")
assert response.status_code == 404
mock_disconnect.assert_not_awaited()

View File

@@ -16,6 +16,13 @@ from backend.util.clients import OPENROUTER_BASE_URL
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
CopilotMode = Literal["fast", "extended_thinking"]
# Per-request model tier set by the frontend model toggle.
# 'standard' uses the global config default (currently Sonnet).
# 'advanced' forces the highest-capability model (currently Opus).
# None means no preference — falls through to LD per-user targeting, then config.
# Using tier names instead of model names keeps the contract model-agnostic.
CopilotLlmModel = Literal["standard", "advanced"]
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
@@ -163,12 +170,12 @@ class ChatConfig(BaseSettings):
"CHAT_CLAUDE_AGENT_MAX_TURNS env var if your workflows need more.",
)
claude_agent_max_budget_usd: float = Field(
default=15.0,
default=10.0,
ge=0.01,
le=1000.0,
description="Maximum spend in USD per SDK query. The CLI attempts "
"to wrap up gracefully when this budget is reached. "
"Set to $15 to allow most tasks to complete (p50=$5.37, p75=$13.07). "
"Set to $10 to allow most tasks to complete (p50=$5.37, p75=$13.07). "
"Override via CHAT_CLAUDE_AGENT_MAX_BUDGET_USD env var.",
)
claude_agent_max_thinking_tokens: int = Field(

View File

@@ -351,6 +351,7 @@ class CoPilotProcessor:
context=entry.context,
file_ids=entry.file_ids,
mode=effective_mode,
model=entry.model,
)
async for chunk in stream_registry.stream_and_publish(
session_id=entry.session_id,

View File

@@ -9,7 +9,7 @@ import logging
from pydantic import BaseModel
from backend.copilot.config import CopilotMode
from backend.copilot.config import CopilotLlmModel, CopilotMode
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
@@ -160,6 +160,9 @@ class CoPilotExecutionEntry(BaseModel):
mode: CopilotMode | None = None
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
model: CopilotLlmModel | None = None
"""Per-request model tier: 'standard' or 'advanced'. None = server default."""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
@@ -180,6 +183,7 @@ async def enqueue_copilot_turn(
context: dict[str, str] | None = None,
file_ids: list[str] | None = None,
mode: CopilotMode | None = None,
model: CopilotLlmModel | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
@@ -192,6 +196,7 @@ async def enqueue_copilot_turn(
context: Optional context for the message (e.g., {url: str, content: str})
file_ids: Optional workspace file IDs attached to the user's message
mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
model: Per-request model tier ('standard' or 'advanced'). None = server default.
"""
from backend.util.clients import get_async_copilot_queue
@@ -204,6 +209,7 @@ async def enqueue_copilot_turn(
context=context,
file_ids=file_ids,
mode=mode,
model=model,
)
queue_client = await get_async_copilot_queue()

View File

@@ -0,0 +1,71 @@
"""Per-request idempotency lock for the /stream endpoint.
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
Lifecycle
---------
1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids)
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
``None`` when the key already exists (duplicate request).
2. ``release()`` — deletes the key. Must be called on turn completion or turn
error so the next legitimate send is never blocked.
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
the backend turn is still running, and releasing would reopen the duplicate
window for infra-level retries. The 30 s TTL is the safety net.
"""
import hashlib
import logging
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_KEY_PREFIX = "chat:msg_dedup"
_TTL_SECONDS = 30
class _DedupLock:
def __init__(self, key: str, redis) -> None:
self._key = key
self._redis = redis
async def release(self) -> None:
"""Best-effort key deletion. The TTL handles failures silently."""
try:
await self._redis.delete(self._key)
except Exception:
pass
async def acquire_dedup_lock(
session_id: str,
message: str | None,
file_ids: list[str] | None,
) -> _DedupLock | None:
"""Acquire the idempotency lock for this (session, message, files) tuple.
Returns a ``_DedupLock`` when the lock is freshly acquired (first request).
Returns ``None`` when a duplicate is detected (lock already held).
Returns ``None`` when there is nothing to deduplicate (no message, no files).
"""
if not message and not file_ids:
return None
sorted_ids = ":".join(sorted(file_ids or []))
content_hash = hashlib.sha256(
f"{session_id}:{message or ''}:{sorted_ids}".encode()
).hexdigest()[:16]
key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"
redis = await get_redis_async()
acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
if not acquired:
logger.warning(
f"[STREAM] Duplicate user message blocked for session {session_id}, "
f"hash={content_hash} — returning empty SSE",
)
return None
return _DedupLock(key, redis)
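
Taken together, the intended call pattern from the route (per the stream_chat_post changes earlier in this diff) is roughly the sketch below. The helpers `empty_sse_response` and `save_and_enqueue_turn` are hypothetical stand-ins for the route's actual logic:

# Usage sketch only; `empty_sse_response` and `save_and_enqueue_turn` are
# hypothetical placeholders, not functions from this PR.
dedup_lock = await acquire_dedup_lock(session_id, original_message, file_ids)
if dedup_lock is None and (original_message or file_ids):
    # Duplicate POST: finish immediately with an empty SSE stream,
    # without saving the message or enqueuing a second executor task.
    return empty_sse_response()

try:
    await save_and_enqueue_turn(session_id, original_message, file_ids)
except Exception:
    # Release before propagating so retries are not blocked for 30 s.
    if dedup_lock:
        await dedup_lock.release()
    raise
# In the SSE generator, release on normal finish and error paths, but not
# on GeneratorExit (client disconnect); there the 30 s TTL is the fallback.
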

View File

@@ -0,0 +1,94 @@
"""Unit tests for backend.copilot.message_dedup."""
from unittest.mock import AsyncMock
import pytest
import pytest_mock
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=set_returns)
mocker.patch(
"backend.copilot.message_dedup.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
return mock_redis
@pytest.mark.asyncio
async def test_acquire_returns_none_when_no_message_no_files(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Nothing to deduplicate — no Redis call made, None returned."""
mock_redis = _patch_redis(mocker, set_returns=True)
result = await acquire_dedup_lock("sess-1", None, None)
assert result is None
mock_redis.set.assert_not_called()
@pytest.mark.asyncio
async def test_acquire_returns_lock_on_first_request(
mocker: pytest_mock.MockerFixture,
) -> None:
"""First request acquires the lock and returns a _DedupLock."""
mock_redis = _patch_redis(mocker, set_returns=True)
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
mock_redis.set.assert_called_once()
key_arg = mock_redis.set.call_args.args[0]
assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:")
@pytest.mark.asyncio
async def test_acquire_returns_none_on_duplicate(
mocker: pytest_mock.MockerFixture,
) -> None:
"""Duplicate request (NX fails) returns None to signal the caller."""
_patch_redis(mocker, set_returns=None)
result = await acquire_dedup_lock("sess-1", "hello", None)
assert result is None
@pytest.mark.asyncio
async def test_acquire_key_stable_across_file_order(
mocker: pytest_mock.MockerFixture,
) -> None:
"""File IDs are sorted before hashing so order doesn't affect the key."""
mock_redis_1 = _patch_redis(mocker, set_returns=True)
await acquire_dedup_lock("sess-1", "msg", ["b", "a"])
key_ab = mock_redis_1.set.call_args.args[0]
mock_redis_2 = _patch_redis(mocker, set_returns=True)
await acquire_dedup_lock("sess-1", "msg", ["a", "b"])
key_ba = mock_redis_2.set.call_args.args[0]
assert key_ab == key_ba
@pytest.mark.asyncio
async def test_release_deletes_key(
mocker: pytest_mock.MockerFixture,
) -> None:
"""release() calls Redis delete exactly once."""
mock_redis = _patch_redis(mocker, set_returns=True)
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
await lock.release()
mock_redis.delete.assert_called_once()
@pytest.mark.asyncio
async def test_release_swallows_redis_error(
mocker: pytest_mock.MockerFixture,
) -> None:
"""release() must not raise even when Redis delete fails."""
mock_redis = _patch_redis(mocker, set_returns=True)
mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down"))
lock = await acquire_dedup_lock("sess-1", "hello", None)
assert lock is not None
await lock.release() # must not raise
mock_redis.delete.assert_called_once()

View File

@@ -302,6 +302,7 @@ async def record_token_usage(
*,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
model_cost_multiplier: float = 1.0,
) -> None:
"""Record token usage for a user across all windows.
@@ -315,12 +316,17 @@ async def record_token_usage(
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
from the API response). Cache counts are passed separately.
``model_cost_multiplier`` scales the final weighted total to reflect
relative model cost. Use 5.0 for Opus (5× more expensive than Sonnet)
so that Opus turns deplete the rate limit faster, proportional to cost.
Args:
user_id: The user's ID.
prompt_tokens: Uncached input tokens.
completion_tokens: Output tokens.
cache_read_tokens: Tokens served from prompt cache (10% cost).
cache_creation_tokens: Tokens written to prompt cache (25% cost).
model_cost_multiplier: Relative model cost factor (1.0 = Sonnet, 5.0 = Opus).
"""
prompt_tokens = max(0, prompt_tokens)
completion_tokens = max(0, completion_tokens)
@@ -332,7 +338,9 @@ async def record_token_usage(
+ round(cache_creation_tokens * 0.25)
+ round(cache_read_tokens * 0.1)
)
total = weighted_input + completion_tokens
total = round(
(weighted_input + completion_tokens) * max(1.0, model_cost_multiplier)
)
if total <= 0:
return
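
As a concrete check of the weighting, here is a worked example with illustrative token counts, assuming an Opus turn (multiplier 5.0):

# Worked example of the weighted-total formula above; all numbers are
# chosen for illustration, not taken from real traffic.
prompt_tokens = 1_000          # uncached input, weight 1.0
cache_creation_tokens = 4_000  # weight 0.25 -> 1_000
cache_read_tokens = 20_000     # weight 0.10 -> 2_000
completion_tokens = 500

weighted_input = (
    prompt_tokens
    + round(cache_creation_tokens * 0.25)
    + round(cache_read_tokens * 0.1)
)  # 1_000 + 1_000 + 2_000 = 4_000
total = round((weighted_input + completion_tokens) * max(1.0, 5.0))
assert total == 22_500  # (4_000 + 500) * 5
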
@@ -340,11 +348,12 @@ async def record_token_usage(
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
)
logger.info(
"Recording token usage for %s: raw=%d, weighted=%d "
"Recording token usage for %s: raw=%d, weighted=%d, multiplier=%.1fx "
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
user_id[:8],
raw_total,
total,
model_cost_multiplier,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,

View File

@@ -34,9 +34,13 @@ Steps:
always inspect the current graph first so you know exactly what to change.
Avoid using `include_graph=true` with broad keyword searches, as fetching
multiple graphs at once is expensive and consumes LLM context budget.
2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
2. **Discover blocks**: Call `find_block(query, include_schemas=true, for_agent_generation=true)` to
search for relevant blocks. This returns block IDs, names, descriptions,
and full input/output schemas.
and full input/output schemas. The `for_agent_generation=true` flag is
required to surface graph-only blocks such as AgentInputBlock,
AgentDropdownInputBlock, AgentOutputBlock, OrchestratorBlock,
WebhookBlock, and MCPToolBlock. (When running MCP tools interactively
in CoPilot outside agent generation, use `run_mcp_tool` instead.)
3. **Find library agents**: Call `find_library_agent` to discover reusable
agents that can be composed as sub-agents via `AgentExecutorBlock`.
4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas:
@@ -177,6 +181,12 @@ To compose agents using other agents as sub-agents:
### Using MCP Tools (MCPToolBlock)
> **Agent graph vs CoPilot direct execution**: This section covers embedding MCP
> tools as persistent nodes in an agent graph. When running MCP tools directly in
> CoPilot (outside agent generation), use `run_mcp_tool` instead — it handles
> server discovery and authentication interactively. Use `MCPToolBlock` here only
> when the user wants the MCP call baked into a reusable agent graph.
To use an MCP (Model Context Protocol) tool as a node in the agent:
1. The user must specify which MCP server URL and tool name they want
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)

View File

@@ -207,7 +207,7 @@ class TestConfigDefaults:
def test_max_budget_usd_default(self):
cfg = _make_config()
assert cfg.claude_agent_max_budget_usd == 15.0
assert cfg.claude_agent_max_budget_usd == 10.0
def test_max_thinking_tokens_default(self):
cfg = _make_config()

View File

@@ -56,7 +56,7 @@ from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings
from ..config import ChatConfig, CopilotMode
from ..config import ChatConfig, CopilotLlmModel, CopilotMode
from ..constants import (
COPILOT_ERROR_PREFIX,
COPILOT_RETRYABLE_ERROR_PREFIX,
@@ -132,6 +132,11 @@ _MAX_STREAM_ATTEMPTS = 3
# self-correct. The limit is generous to allow recovery attempts.
_EMPTY_TOOL_CALL_LIMIT = 5
# Cost multiplier for Opus model turns — Opus is ~5× more expensive than Sonnet
# ($15/$75 vs $3/$15 per M tokens). Applied to rate-limit counters so Opus
# turns deplete quota proportionally faster.
_OPUS_COST_MULTIPLIER = 5.0
# User-facing error shown when the empty-tool-call circuit breaker trips.
_CIRCUIT_BREAKER_ERROR_MSG = (
"AutoPilot was unable to complete the tool call "
@@ -674,6 +679,48 @@ def _resolve_fallback_model() -> str | None:
return _normalize_model_name(raw)
async def _resolve_model_and_multiplier(
model: "CopilotLlmModel | None",
session_id: str,
) -> tuple[str | None, float]:
"""Resolve the SDK model string and rate-limit cost multiplier for a turn.
Priority (highest first):
1. Explicit per-request ``model`` tier from the frontend toggle.
2. Global config default (``_resolve_sdk_model()``).
Returns a ``(sdk_model, cost_multiplier)`` pair.
``sdk_model`` is ``None`` when the Claude Code subscription default applies.
``cost_multiplier`` is 5.0 for Opus, 1.0 otherwise.
"""
sdk_model = _resolve_sdk_model()
if model == "advanced":
sdk_model = _normalize_model_name("anthropic/claude-opus-4-6")
logger.info(
"[SDK] [%s] Per-request model override: advanced (%s)",
session_id[:12] if session_id else "?",
sdk_model,
)
return sdk_model, _OPUS_COST_MULTIPLIER
if model == "standard":
# Reset to config default — respects subscription mode (None = CLI default).
sdk_model = _resolve_sdk_model()
logger.info(
"[SDK] [%s] Per-request model override: standard (%s)",
session_id[:12] if session_id else "?",
sdk_model or "subscription-default",
)
return sdk_model, 1.0
# No per-request override; derive multiplier from final resolved model.
cost_multiplier = (
_OPUS_COST_MULTIPLIER if sdk_model and "opus" in sdk_model else 1.0
)
return sdk_model, cost_multiplier
_MAX_TRANSIENT_BACKOFF_SECONDS = 30
@@ -1865,15 +1912,20 @@ async def _run_stream_attempt(
# cache_read_input_tokens = served from cache
# cache_creation_input_tokens = written to cache
if sdk_msg.usage:
state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
state.usage.cache_read_tokens += sdk_msg.usage.get(
"cache_read_input_tokens", 0
# Use `or 0` instead of a default in .get() because
# OpenRouter may include the key with a null value (e.g.
# {"cache_read_input_tokens": null}) for models that don't
# yet report cache tokens, making .get("key", 0) return
# None rather than the fallback 0.
state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens") or 0
state.usage.cache_read_tokens += (
sdk_msg.usage.get("cache_read_input_tokens") or 0
)
state.usage.cache_creation_tokens += sdk_msg.usage.get(
"cache_creation_input_tokens", 0
state.usage.cache_creation_tokens += (
sdk_msg.usage.get("cache_creation_input_tokens") or 0
)
state.usage.completion_tokens += sdk_msg.usage.get(
"output_tokens", 0
state.usage.completion_tokens += (
sdk_msg.usage.get("output_tokens") or 0
)
logger.info(
"%s Token usage: uncached=%d, cache_read=%d, "
@@ -2150,6 +2202,7 @@ async def stream_chat_completion_sdk(
file_ids: list[str] | None = None,
permissions: "CopilotPermissions | None" = None,
mode: CopilotMode | None = None,
model: CopilotLlmModel | None = None,
**_kwargs: Any,
) -> AsyncIterator[StreamBaseResponse]:
"""Stream chat completion using Claude Agent SDK.
@@ -2160,6 +2213,9 @@ async def stream_chat_completion_sdk(
saved to the SDK working directory for the Read tool.
mode: Accepted for signature compatibility with the baseline path.
The SDK path does not currently branch on this value.
model: Per-request model preference from the frontend toggle.
'advanced' → Claude Opus; 'standard' → global config default.
Takes priority over per-user LaunchDarkly targeting.
"""
_ = mode # SDK path ignores the requested mode.
@@ -2274,6 +2330,10 @@ async def stream_chat_completion_sdk(
turn_cache_creation_tokens = 0
turn_cost_usd: float | None = None
graphiti_enabled = False
# Defaults ensure the finally block can always reference these safely even when
# an early return (e.g. sdk_cwd error) skips their normal assignment below.
sdk_model: str | None = None
model_cost_multiplier: float = 1.0
# Make sure there is no more code between the lock acquisition and try-block.
try:
@@ -2487,7 +2547,10 @@ async def stream_chat_completion_sdk(
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
sdk_model = _resolve_sdk_model()
# Resolve model and cost multiplier (request tier → config default).
sdk_model, model_cost_multiplier = await _resolve_model_and_multiplier(
model, session_id
)
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
compaction = CompactionTracker()
@@ -3188,8 +3251,9 @@ async def stream_chat_completion_sdk(
cache_creation_tokens=turn_cache_creation_tokens,
log_prefix=log_prefix,
cost_usd=turn_cost_usd,
model=config.model,
model=sdk_model or config.model,
provider="anthropic",
model_cost_multiplier=model_cost_multiplier,
)
# --- Persist session messages ---

View File

@@ -20,7 +20,9 @@ from .service import (
_is_prompt_too_long,
_is_tool_only_message,
_iter_sdk_messages,
_normalize_model_name,
_reduce_context,
_TokenUsage,
)
# ---------------------------------------------------------------------------
@@ -350,3 +352,128 @@ class TestIsParallelContinuation:
msg = MagicMock(spec=AssistantMessage)
msg.content = [self._make_tool_block()]
assert _is_tool_only_message(msg) is True
# ---------------------------------------------------------------------------
# _normalize_model_name — used by per-request model override
# ---------------------------------------------------------------------------
class TestNormalizeModelName:
"""Unit tests for the model-name normalisation helper.
The per-request model toggle calls _normalize_model_name with either
``"anthropic/claude-opus-4-6"`` (for 'advanced') or ``config.model`` (for
'standard'). These tests verify the OpenRouter/provider-prefix stripping
that keeps the value compatible with the Claude CLI.
"""
def test_strips_anthropic_prefix(self):
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
def test_strips_openai_prefix(self):
assert _normalize_model_name("openai/gpt-4o") == "gpt-4o"
def test_strips_google_prefix(self):
assert _normalize_model_name("google/gemini-2.5-flash") == "gemini-2.5-flash"
def test_already_normalized_unchanged(self):
assert (
_normalize_model_name("claude-sonnet-4-20250514")
== "claude-sonnet-4-20250514"
)
def test_empty_string_unchanged(self):
assert _normalize_model_name("") == ""
def test_opus_model_roundtrip(self):
"""The exact string used for the 'opus' toggle strips correctly."""
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
def test_sonnet_openrouter_model(self):
"""Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly."""
assert _normalize_model_name("anthropic/claude-sonnet-4") == "claude-sonnet-4"
# ---------------------------------------------------------------------------
# _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug)
# ---------------------------------------------------------------------------
class TestTokenUsageNullSafety:
"""Verify that ResultMessage.usage dicts with null-valued cache fields
(as emitted by OpenRouter for the initial streaming event before real
token counts are available) do not crash the accumulator.
Before the fix, dict.get("cache_read_input_tokens", 0) returned None
when the key existed with a null value, causing 'int += None' TypeError.
"""
def _apply_usage(self, usage: dict, acc: _TokenUsage) -> None:
"""Mirror the production accumulation in sdk/service.py."""
acc.prompt_tokens += usage.get("input_tokens") or 0
acc.cache_read_tokens += usage.get("cache_read_input_tokens") or 0
acc.cache_creation_tokens += usage.get("cache_creation_input_tokens") or 0
acc.completion_tokens += usage.get("output_tokens") or 0
def test_null_cache_tokens_do_not_crash(self):
"""OpenRouter initial event: cache keys present with null value."""
usage = {
"input_tokens": 0,
"output_tokens": 0,
"cache_read_input_tokens": None,
"cache_creation_input_tokens": None,
}
acc = _TokenUsage()
self._apply_usage(usage, acc) # must not raise TypeError
assert acc.prompt_tokens == 0
assert acc.cache_read_tokens == 0
assert acc.cache_creation_tokens == 0
assert acc.completion_tokens == 0
def test_real_cache_tokens_are_accumulated(self):
"""OpenRouter final event: real cache token counts are captured."""
usage = {
"input_tokens": 10,
"output_tokens": 349,
"cache_read_input_tokens": 16600,
"cache_creation_input_tokens": 512,
}
acc = _TokenUsage()
self._apply_usage(usage, acc)
assert acc.prompt_tokens == 10
assert acc.cache_read_tokens == 16600
assert acc.cache_creation_tokens == 512
assert acc.completion_tokens == 349
def test_absent_cache_keys_default_to_zero(self):
"""Minimal usage dict without cache keys defaults correctly."""
usage = {"input_tokens": 5, "output_tokens": 20}
acc = _TokenUsage()
self._apply_usage(usage, acc)
assert acc.prompt_tokens == 5
assert acc.cache_read_tokens == 0
assert acc.cache_creation_tokens == 0
assert acc.completion_tokens == 20
def test_multi_turn_accumulation(self):
"""Null event followed by real event: only real tokens counted."""
null_event = {
"input_tokens": 0,
"output_tokens": 0,
"cache_read_input_tokens": None,
"cache_creation_input_tokens": None,
}
real_event = {
"input_tokens": 10,
"output_tokens": 349,
"cache_read_input_tokens": 16600,
"cache_creation_input_tokens": 512,
}
acc = _TokenUsage()
self._apply_usage(null_event, acc)
self._apply_usage(real_event, acc)
assert acc.prompt_tokens == 10
assert acc.cache_read_tokens == 16600
assert acc.cache_creation_tokens == 512
assert acc.completion_tokens == 349

View File

@@ -1149,3 +1149,50 @@ async def unsubscribe_from_session(
)
logger.debug(f"Successfully unsubscribed from session {session_id}")
async def disconnect_all_listeners(session_id: str) -> int:
"""Cancel every active listener task for *session_id*.
Called when the frontend switches away from a session and wants the
backend to release resources immediately rather than waiting for the
XREAD timeout.
Scope / limitations (best-effort optimisation, not a correctness primitive):
- Pod-local: ``_listener_sessions`` is in-memory. If the DELETE request
lands on a different worker than the one serving the SSE, no listener
is cancelled here — the SSE worker still releases on its XREAD timeout.
- Session-scoped (not subscriber-scoped): cancels every active listener
for the session on this pod. In the rare case a single user opens two
SSE connections to the same session on the same pod (e.g. two tabs),
both would be torn down. Cross-pod, subscriber-scoped cancellation
would require a Redis pub/sub fan-out with per-listener tokens; that
is not implemented here because the XREAD timeout already bounds the
worst case.
Returns the number of listener tasks that were cancelled.
"""
to_cancel: list[tuple[int, asyncio.Task]] = [
(qid, task)
for qid, (sid, task) in list(_listener_sessions.items())
if sid == session_id and not task.done()
]
for qid, task in to_cancel:
_listener_sessions.pop(qid, None)
task.cancel()
cancelled = 0
for _qid, task in to_cancel:
try:
await asyncio.wait_for(task, timeout=5.0)
except asyncio.CancelledError:
cancelled += 1
except asyncio.TimeoutError:
pass
except Exception as e:
logger.error(f"Error cancelling listener for session {session_id}: {e}")
if cancelled:
logger.info(f"Disconnected {cancelled} listener(s) for session {session_id}")
return cancelled

View File

@@ -0,0 +1,110 @@
"""Tests for disconnect_all_listeners in stream_registry."""
import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot import stream_registry
@pytest.fixture(autouse=True)
def _clear_listener_sessions():
stream_registry._listener_sessions.clear()
yield
stream_registry._listener_sessions.clear()
async def _sleep_forever():
try:
await asyncio.sleep(3600)
except asyncio.CancelledError:
raise
@pytest.mark.asyncio
async def test_disconnect_all_listeners_cancels_matching_session():
task_a = asyncio.create_task(_sleep_forever())
task_b = asyncio.create_task(_sleep_forever())
task_other = asyncio.create_task(_sleep_forever())
stream_registry._listener_sessions[1] = ("sess-1", task_a)
stream_registry._listener_sessions[2] = ("sess-1", task_b)
stream_registry._listener_sessions[3] = ("sess-other", task_other)
try:
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
assert cancelled == 2
assert task_a.cancelled()
assert task_b.cancelled()
assert not task_other.done()
# Matching entries are removed, non-matching entries remain.
assert 1 not in stream_registry._listener_sessions
assert 2 not in stream_registry._listener_sessions
assert 3 in stream_registry._listener_sessions
finally:
task_other.cancel()
try:
await task_other
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_disconnect_all_listeners_no_match_returns_zero():
task = asyncio.create_task(_sleep_forever())
stream_registry._listener_sessions[1] = ("sess-other", task)
try:
cancelled = await stream_registry.disconnect_all_listeners("sess-missing")
assert cancelled == 0
assert not task.done()
assert 1 in stream_registry._listener_sessions
finally:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_disconnect_all_listeners_skips_already_done_tasks():
async def _noop():
return None
done_task = asyncio.create_task(_noop())
await done_task
stream_registry._listener_sessions[1] = ("sess-1", done_task)
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
# Done tasks are filtered out before cancellation.
assert cancelled == 0
@pytest.mark.asyncio
async def test_disconnect_all_listeners_empty_registry():
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
assert cancelled == 0
@pytest.mark.asyncio
async def test_disconnect_all_listeners_timeout_not_counted():
"""Tasks that don't respond to cancellation (timeout) are not counted."""
task = asyncio.create_task(_sleep_forever())
stream_registry._listener_sessions[1] = ("sess-1", task)
with patch.object(
asyncio, "wait_for", new=AsyncMock(side_effect=asyncio.TimeoutError)
):
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
assert cancelled == 0
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

View File

@@ -96,6 +96,7 @@ async def persist_and_record_usage(
cost_usd: float | str | None = None,
model: str | None = None,
provider: str = "open_router",
model_cost_multiplier: float = 1.0,
) -> int:
"""Persist token usage to session and record for rate limiting.
@@ -109,6 +110,9 @@ async def persist_and_record_usage(
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
cost_usd: Optional cost for logging (float from SDK, str otherwise).
provider: Cost provider name (e.g. "anthropic", "open_router").
model_cost_multiplier: Relative model cost factor for rate limiting
(1.0 = Sonnet/default, 5.0 = Opus). Scales the token counter so
more expensive models deplete the rate limit proportionally faster.
Returns:
The computed total_tokens (prompt + completion; cache excluded).
@@ -163,6 +167,7 @@ async def persist_and_record_usage(
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
model_cost_multiplier=model_cost_multiplier,
)
except Exception as usage_err:
logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)

View File

@@ -230,6 +230,7 @@ class TestRateLimitRecording:
completion_tokens=50,
cache_read_tokens=1000,
cache_creation_tokens=200,
model_cost_multiplier=1.0,
)
@pytest.mark.asyncio

View File

@@ -74,6 +74,15 @@ class FindBlockTool(BaseTool):
"description": "Include full input/output schemas (for agent JSON generation).",
"default": False,
},
"for_agent_generation": {
"type": "boolean",
"description": (
"Set to true when searching for blocks to use inside an agent graph "
"(e.g. AgentInputBlock, AgentOutputBlock, OrchestratorBlock). "
"Bypasses the CoPilot-only filter so graph-only blocks are visible."
),
"default": False,
},
},
"required": ["query"],
}
@@ -88,6 +97,7 @@ class FindBlockTool(BaseTool):
session: ChatSession,
query: str = "",
include_schemas: bool = False,
for_agent_generation: bool = False,
**kwargs,
) -> ToolResponseBase:
"""Search for blocks matching the query.
@@ -97,6 +107,8 @@ class FindBlockTool(BaseTool):
session: Chat session
query: Search query
include_schemas: Whether to include block schemas in results
for_agent_generation: When True, bypasses the CoPilot exclusion filter
so graph-only blocks (INPUT, OUTPUT, ORCHESTRATOR, etc.) are visible.
Returns:
BlockListResponse: List of matching blocks
@@ -123,34 +135,36 @@ class FindBlockTool(BaseTool):
suggestions=["Search for an alternative block by name"],
session_id=session_id,
)
if (
is_excluded = (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):
if block.block_type == BlockType.MCP_TOOL:
)
if is_excluded:
# Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are
# exposed when building an agent graph so the LLM can inspect
# their schemas and wire them as nodes. In CoPilot direct use
# they are not executable — guide the LLM to the right tool.
if not for_agent_generation:
if block.block_type == BlockType.MCP_TOOL:
message = (
f"Block '{block.name}' (ID: {block.id}) cannot be "
"run directly in CoPilot. Use run_mcp_tool for "
"interactive MCP execution, or call find_block with "
"for_agent_generation=true to embed it in an agent graph."
)
else:
message = (
f"Block '{block.name}' (ID: {block.id}) is not available "
"in CoPilot. It can only be used within agent graphs."
)
return NoResultsResponse(
message=(
f"Block '{block.name}' (ID: {block.id}) is not "
"runnable through find_block/run_block. Use "
"run_mcp_tool instead."
),
message=message,
suggestions=[
"Use run_mcp_tool to discover and run this MCP tool",
"Search for an alternative block by name",
"Use this block in an agent graph instead",
],
session_id=session_id,
)
return NoResultsResponse(
message=(
f"Block '{block.name}' (ID: {block.id}) is not available "
"in CoPilot. It can only be used within agent graphs."
),
suggestions=[
"Search for an alternative block by name",
"Use this block in an agent graph instead",
],
session_id=session_id,
)
# Check block-level permissions — hide denied blocks entirely
perms = get_current_permissions()
@@ -221,8 +235,9 @@ class FindBlockTool(BaseTool):
if not block or block.disabled:
continue
# Skip blocks excluded from CoPilot (graph-only blocks)
if (
# Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are
# skipped in CoPilot direct use but surfaced for agent graph building.
if not for_agent_generation and (
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
):

View File

@@ -12,7 +12,7 @@ from .find_block import (
COPILOT_EXCLUDED_BLOCK_TYPES,
FindBlockTool,
)
from .models import BlockListResponse
from .models import BlockListResponse, NoResultsResponse
_TEST_USER_ID = "test-user-find-block"
@@ -166,6 +166,194 @@ class TestFindBlockFiltering:
assert len(response.blocks) == 1
assert response.blocks[0].id == "normal-block-id"
@pytest.mark.asyncio(loop_scope="session")
async def test_for_agent_generation_exposes_excluded_blocks_in_search(self):
"""With for_agent_generation=True, excluded block types appear in search results."""
session = make_session(user_id=_TEST_USER_ID)
search_results = [
{"content_id": "input-block-id", "score": 0.9},
{"content_id": "output-block-id", "score": 0.8},
]
input_block = make_mock_block("input-block-id", "Agent Input", BlockType.INPUT)
output_block = make_mock_block(
"output-block-id", "Agent Output", BlockType.OUTPUT
)
def mock_get_block(block_id):
return {
"input-block-id": input_block,
"output-block-id": output_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
):
with patch(
"backend.copilot.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="agent input",
for_agent_generation=True,
)
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 2
block_ids = {b.id for b in response.blocks}
assert "input-block-id" in block_ids
assert "output-block-id" in block_ids
@pytest.mark.asyncio(loop_scope="session")
async def test_mcp_tool_exposed_with_for_agent_generation_in_search(self):
"""MCP_TOOL blocks appear in search results when for_agent_generation=True."""
session = make_session(user_id=_TEST_USER_ID)
search_results = [
{"content_id": "mcp-block-id", "score": 0.9},
{"content_id": "standard-block-id", "score": 0.8},
]
mcp_block = make_mock_block("mcp-block-id", "MCP Tool", BlockType.MCP_TOOL)
standard_block = make_mock_block(
"standard-block-id", "Normal Block", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
"mcp-block-id": mcp_block,
"standard-block-id": standard_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
):
with patch(
"backend.copilot.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="mcp tool",
for_agent_generation=True,
)
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 2
assert any(b.id == "mcp-block-id" for b in response.blocks)
assert any(b.id == "standard-block-id" for b in response.blocks)
@pytest.mark.asyncio(loop_scope="session")
async def test_mcp_tool_excluded_without_for_agent_generation_in_search(self):
"""MCP_TOOL blocks are excluded from search in normal CoPilot mode."""
session = make_session(user_id=_TEST_USER_ID)
search_results = [
{"content_id": "mcp-block-id", "score": 0.9},
{"content_id": "standard-block-id", "score": 0.8},
]
mcp_block = make_mock_block("mcp-block-id", "MCP Tool", BlockType.MCP_TOOL)
standard_block = make_mock_block(
"standard-block-id", "Normal Block", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
"mcp-block-id": mcp_block,
"standard-block-id": standard_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
):
with patch(
"backend.copilot.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="mcp tool",
for_agent_generation=False,
)
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 1
assert response.blocks[0].id == "standard-block-id"
@pytest.mark.asyncio(loop_scope="session")
async def test_for_agent_generation_exposes_excluded_ids_in_search(self):
"""With for_agent_generation=True, excluded block IDs appear in search results."""
session = make_session(user_id=_TEST_USER_ID)
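# Use any concrete ID from the copilot exclusion set as the top search hit.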
orchestrator_id = next(iter(COPILOT_EXCLUDED_BLOCK_IDS))
search_results = [
{"content_id": orchestrator_id, "score": 0.9},
{"content_id": "normal-block-id", "score": 0.8},
]
orchestrator_block = make_mock_block(
orchestrator_id, "Orchestrator", BlockType.STANDARD
)
normal_block = make_mock_block(
"normal-block-id", "Normal Block", BlockType.STANDARD
)
def mock_get_block(block_id):
return {
orchestrator_id: orchestrator_block,
"normal-block-id": normal_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
):
with patch(
"backend.copilot.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query="orchestrator",
for_agent_generation=True,
)
assert isinstance(response, BlockListResponse)
assert len(response.blocks) == 2
block_ids = {b.id for b in response.blocks}
assert orchestrator_id in block_ids
assert "normal-block-id" in block_ids
@pytest.mark.asyncio(loop_scope="session")
async def test_response_size_average_chars_per_block(self):
"""Measure average chars per block in the serialized response."""
@@ -549,8 +737,6 @@ class TestFindBlockDirectLookup:
user_id=_TEST_USER_ID, session=session, query=block_id
)
from .models import NoResultsResponse
assert isinstance(response, NoResultsResponse)
@pytest.mark.asyncio(loop_scope="session")
@@ -571,8 +757,6 @@ class TestFindBlockDirectLookup:
user_id=_TEST_USER_ID, session=session, query=block_id
)
from .models import NoResultsResponse
assert isinstance(response, NoResultsResponse)
assert "disabled" in response.message.lower()
@@ -592,8 +776,6 @@ class TestFindBlockDirectLookup:
user_id=_TEST_USER_ID, session=session, query=block_id
)
from .models import NoResultsResponse
assert isinstance(response, NoResultsResponse)
assert "not available" in response.message.lower()
@@ -613,7 +795,74 @@ class TestFindBlockDirectLookup:
user_id=_TEST_USER_ID, session=session, query=orchestrator_id
)
from .models import NoResultsResponse
assert isinstance(response, NoResultsResponse)
assert "not available" in response.message.lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_uuid_lookup_excluded_block_type_allowed_with_for_agent_generation(
self,
):
"""With for_agent_generation=True, excluded block types (INPUT) are visible."""
session = make_session(user_id=_TEST_USER_ID)
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
block = make_mock_block(block_id, "Agent Input Block", BlockType.INPUT)
with patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query=block_id,
for_agent_generation=True,
)
assert isinstance(response, BlockListResponse)
assert response.count == 1
assert response.blocks[0].id == block_id
@pytest.mark.asyncio(loop_scope="session")
async def test_uuid_lookup_mcp_tool_exposed_with_for_agent_generation(self):
"""MCP_TOOL blocks are returned by UUID lookup when for_agent_generation=True."""
session = make_session(user_id=_TEST_USER_ID)
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
block = make_mock_block(block_id, "MCP Tool", BlockType.MCP_TOOL)
with patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query=block_id,
for_agent_generation=True,
)
assert isinstance(response, BlockListResponse)
assert response.blocks[0].id == block_id
@pytest.mark.asyncio(loop_scope="session")
async def test_uuid_lookup_mcp_tool_excluded_without_for_agent_generation(self):
"""MCP_TOOL blocks are excluded by UUID lookup in normal CoPilot mode."""
session = make_session(user_id=_TEST_USER_ID)
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
block = make_mock_block(block_id, "MCP Tool", BlockType.MCP_TOOL)
with patch(
"backend.copilot.tools.find_block.get_block",
return_value=block,
):
tool = FindBlockTool()
response = await tool._execute(
user_id=_TEST_USER_ID,
session=session,
query=block_id,
for_agent_generation=False,
)
assert isinstance(response, NoResultsResponse)
assert "run_mcp_tool" in response.message

View File

@@ -1,6 +1,7 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import { IMPERSONATION_HEADER_NAME } from "@/lib/constants";
import { getCopilotAuthHeaders } from "../helpers";
import { getCopilotAuthHeaders, getSendSuppressionReason } from "../helpers";
import type { UIMessage } from "ai";
vi.mock("@/lib/supabase/actions", () => ({
getWebSocketToken: vi.fn(),
@@ -72,3 +73,71 @@ describe("getCopilotAuthHeaders", () => {
);
});
});
// ─── getSendSuppressionReason ─────────────────────────────────────────────────
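// Builds just enough of a UIMessage for the suppression check: a user role and one text part.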
function makeUserMsg(text: string): UIMessage {
return {
id: "msg-1",
role: "user",
content: text,
parts: [{ type: "text", text }],
} as UIMessage;
}
describe("getSendSuppressionReason", () => {
it("returns null when no dedup context exists (fresh ref)", () => {
const result = getSendSuppressionReason({
text: "hello",
isReconnectScheduled: false,
lastSubmittedText: null,
messages: [],
});
expect(result).toBeNull();
});
it("returns 'reconnecting' when reconnect is scheduled regardless of text", () => {
const result = getSendSuppressionReason({
text: "hello",
isReconnectScheduled: true,
lastSubmittedText: null,
messages: [],
});
expect(result).toBe("reconnecting");
});
it("returns 'duplicate' when same text was submitted and is the last user message", () => {
// This is the core regression test: after a successful turn the ref
// is intentionally NOT cleared to null, so submitting the same text
// again is caught here.
const result = getSendSuppressionReason({
text: "hello",
isReconnectScheduled: false,
lastSubmittedText: "hello",
messages: [makeUserMsg("hello")],
});
expect(result).toBe("duplicate");
});
it("returns null when same ref text but different last user message (different question)", () => {
// User asked "hello" before, got a reply, then asked a different question
// — the last user message in chat is now different, so no suppression.
const result = getSendSuppressionReason({
text: "hello",
isReconnectScheduled: false,
lastSubmittedText: "hello",
messages: [makeUserMsg("hello"), makeUserMsg("something else")],
});
expect(result).toBeNull();
});
it("returns null when text differs from lastSubmittedText", () => {
const result = getSendSuppressionReason({
text: "new question",
isReconnectScheduled: false,
lastSubmittedText: "old question",
messages: [makeUserMsg("old question")],
});
expect(result).toBeNull();
});
});

View File

@@ -1,4 +1,4 @@
import { describe, expect, it, beforeEach, vi } from "vitest";
import { describe, expect, it, beforeEach, afterEach, vi } from "vitest";
import { useCopilotUIStore } from "../store";
vi.mock("@sentry/nextjs", () => ({
@@ -22,7 +22,8 @@ describe("useCopilotUIStore", () => {
isNotificationsEnabled: false,
isSoundEnabled: true,
showNotificationDialog: false,
copilotMode: "extended_thinking",
copilotChatMode: "extended_thinking",
copilotLlmModel: "standard",
});
});
@@ -154,35 +155,52 @@ describe("useCopilotUIStore", () => {
});
});
describe("copilotMode", () => {
describe("copilotChatMode", () => {
it("defaults to extended_thinking", () => {
expect(useCopilotUIStore.getState().copilotMode).toBe(
expect(useCopilotUIStore.getState().copilotChatMode).toBe(
"extended_thinking",
);
});
it("sets mode to fast", () => {
useCopilotUIStore.getState().setCopilotMode("fast");
expect(useCopilotUIStore.getState().copilotMode).toBe("fast");
useCopilotUIStore.getState().setCopilotChatMode("fast");
expect(useCopilotUIStore.getState().copilotChatMode).toBe("fast");
});
it("sets mode back to extended_thinking", () => {
useCopilotUIStore.getState().setCopilotMode("fast");
useCopilotUIStore.getState().setCopilotMode("extended_thinking");
expect(useCopilotUIStore.getState().copilotMode).toBe(
useCopilotUIStore.getState().setCopilotChatMode("fast");
useCopilotUIStore.getState().setCopilotChatMode("extended_thinking");
expect(useCopilotUIStore.getState().copilotChatMode).toBe(
"extended_thinking",
);
});
it("does not persist mode to localStorage", () => {
useCopilotUIStore.getState().setCopilotMode("fast");
expect(window.localStorage.getItem("copilot-mode")).toBeNull();
it("persists mode to localStorage", () => {
useCopilotUIStore.getState().setCopilotChatMode("fast");
expect(window.localStorage.getItem("copilot-mode")).toBe("fast");
});
});
describe("copilotLlmModel", () => {
it("defaults to standard", () => {
expect(useCopilotUIStore.getState().copilotLlmModel).toBe("standard");
});
it("sets model to advanced", () => {
useCopilotUIStore.getState().setCopilotLlmModel("advanced");
expect(useCopilotUIStore.getState().copilotLlmModel).toBe("advanced");
});
it("persists model to localStorage", () => {
useCopilotUIStore.getState().setCopilotLlmModel("advanced");
expect(window.localStorage.getItem("copilot-model")).toBe("advanced");
});
});
describe("clearCopilotLocalData", () => {
it("resets state and clears localStorage keys", () => {
useCopilotUIStore.getState().setCopilotMode("fast");
useCopilotUIStore.getState().setCopilotChatMode("fast");
useCopilotUIStore.getState().setCopilotLlmModel("advanced");
useCopilotUIStore.getState().setNotificationsEnabled(true);
useCopilotUIStore.getState().toggleSound();
useCopilotUIStore.getState().addCompletedSession("s1");
@@ -190,7 +208,8 @@ describe("useCopilotUIStore", () => {
useCopilotUIStore.getState().clearCopilotLocalData();
const state = useCopilotUIStore.getState();
expect(state.copilotMode).toBe("extended_thinking");
expect(state.copilotChatMode).toBe("extended_thinking");
expect(state.copilotLlmModel).toBe("standard");
expect(state.isNotificationsEnabled).toBe(false);
expect(state.isSoundEnabled).toBe(true);
expect(state.completedSessionIDs.size).toBe(0);
@@ -198,6 +217,8 @@ describe("useCopilotUIStore", () => {
window.localStorage.getItem("copilot-notifications-enabled"),
).toBeNull();
expect(window.localStorage.getItem("copilot-sound-enabled")).toBeNull();
expect(window.localStorage.getItem("copilot-mode")).toBeNull();
expect(window.localStorage.getItem("copilot-model")).toBeNull();
expect(
window.localStorage.getItem("copilot-completed-sessions"),
).toBeNull();
@@ -222,3 +243,24 @@ describe("useCopilotUIStore", () => {
});
});
});
describe("useCopilotUIStore localStorage initialisation", () => {
afterEach(() => {
vi.resetModules();
window.localStorage.clear();
});
it("reads fast chat mode from localStorage on store creation", async () => {
window.localStorage.setItem("copilot-mode", "fast");
vi.resetModules();
const { useCopilotUIStore: fresh } = await import("../store");
expect(fresh.getState().copilotChatMode).toBe("fast");
});
it("reads advanced model from localStorage on store creation", async () => {
window.localStorage.setItem("copilot-model", "advanced");
vi.resetModules();
const { useCopilotUIStore: fresh } = await import("../store");
expect(fresh.getState().copilotLlmModel).toBe("advanced");
});
});

View File

@@ -13,6 +13,7 @@ import { ChangeEvent, useEffect, useState } from "react";
import { AttachmentMenu } from "./components/AttachmentMenu";
import { DryRunToggleButton } from "./components/DryRunToggleButton";
import { FileChips } from "./components/FileChips";
import { ModelToggleButton } from "./components/ModelToggleButton";
import { ModeToggleButton } from "./components/ModeToggleButton";
import { RecordingButton } from "./components/RecordingButton";
import { RecordingIndicator } from "./components/RecordingIndicator";
@@ -50,16 +51,22 @@ export function ChatInput({
onDroppedFilesConsumed,
hasSession = false,
}: Props) {
const { copilotMode, setCopilotMode, isDryRun, setIsDryRun } =
useCopilotUIStore();
const {
copilotChatMode,
setCopilotChatMode,
copilotLlmModel,
setCopilotLlmModel,
isDryRun,
setIsDryRun,
} = useCopilotUIStore();
const showModeToggle = useGetFlag(Flag.CHAT_MODE_OPTION);
const showDryRunToggle = showModeToggle;
const [files, setFiles] = useState<File[]>([]);
function handleToggleMode() {
const next =
copilotMode === "extended_thinking" ? "fast" : "extended_thinking";
setCopilotMode(next);
copilotChatMode === "extended_thinking" ? "fast" : "extended_thinking";
setCopilotChatMode(next);
toast({
title:
next === "fast"
@@ -72,6 +79,21 @@ export function ChatInput({
});
}
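// Flip the model tier and confirm the switch with a toast.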
function handleToggleModel() {
const next = copilotLlmModel === "advanced" ? "standard" : "advanced";
setCopilotLlmModel(next);
toast({
title:
next === "advanced"
? "Switched to Advanced model"
: "Switched to Standard model",
description:
next === "advanced"
? "Using the highest-capability model."
: "Using the balanced standard model.",
});
}
function handleToggleDryRun() {
const next = !isDryRun;
setIsDryRun(next);
@@ -198,10 +220,16 @@ export function ChatInput({
/>
{showModeToggle && !isStreaming && (
<ModeToggleButton
mode={copilotMode}
mode={copilotChatMode}
onToggle={handleToggleMode}
/>
)}
{showModeToggle && !isStreaming && (
<ModelToggleButton
model={copilotLlmModel}
onToggle={handleToggleModel}
/>
)}
{showDryRunToggle && (!hasSession || isDryRun) && (
<DryRunToggleButton
isDryRun={isDryRun}

View File

@@ -8,14 +8,21 @@ import { afterEach, describe, expect, it, vi } from "vitest";
import { ChatInput } from "../ChatInput";
let mockCopilotMode = "extended_thinking";
const mockSetCopilotMode = vi.fn((mode: string) => {
const mockSetCopilotChatMode = vi.fn((mode: string) => {
mockCopilotMode = mode;
});
let mockCopilotLlmModel = "standard";
const mockSetCopilotLlmModel = vi.fn((model: string) => {
mockCopilotLlmModel = model;
});
vi.mock("@/app/(platform)/copilot/store", () => ({
useCopilotUIStore: () => ({
copilotMode: mockCopilotMode,
setCopilotMode: mockSetCopilotMode,
copilotChatMode: mockCopilotMode,
setCopilotChatMode: mockSetCopilotChatMode,
copilotLlmModel: mockCopilotLlmModel,
setCopilotLlmModel: mockSetCopilotLlmModel,
initialPrompt: null,
setInitialPrompt: vi.fn(),
}),
@@ -107,6 +114,7 @@ afterEach(() => {
cleanup();
vi.clearAllMocks();
mockCopilotMode = "extended_thinking";
mockCopilotLlmModel = "standard";
});
describe("ChatInput mode toggle", () => {
@@ -141,7 +149,7 @@ describe("ChatInput mode toggle", () => {
mockCopilotMode = "extended_thinking";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to fast mode/i));
expect(mockSetCopilotMode).toHaveBeenCalledWith("fast");
expect(mockSetCopilotChatMode).toHaveBeenCalledWith("fast");
});
it("toggles from fast to extended_thinking on click", () => {
@@ -149,7 +157,7 @@ describe("ChatInput mode toggle", () => {
mockCopilotMode = "fast";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to extended thinking/i));
expect(mockSetCopilotMode).toHaveBeenCalledWith("extended_thinking");
expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking");
});
it("hides toggle button when streaming", () => {
@@ -187,3 +195,69 @@ describe("ChatInput mode toggle", () => {
);
});
});
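// Mirrors the mode-toggle suite above, exercising the model-tier button.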
describe("ChatInput model toggle", () => {
it("renders model toggle button when flag is enabled", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} />);
expect(screen.getByLabelText(/switch to advanced model/i)).toBeDefined();
});
it("does not render model toggle when flag is disabled", () => {
mockFlagValue = false;
render(<ChatInput onSend={mockOnSend} />);
expect(
screen.queryByLabelText(/switch to (advanced|standard) model/i),
).toBeNull();
});
it("toggles from standard to advanced on click", () => {
mockFlagValue = true;
mockCopilotLlmModel = "standard";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to advanced model/i));
expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("advanced");
});
it("toggles from advanced to standard on click", () => {
mockFlagValue = true;
mockCopilotLlmModel = "advanced";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard");
});
it("hides model toggle when streaming", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} isStreaming />);
expect(
screen.queryByLabelText(/switch to (advanced|standard) model/i),
).toBeNull();
});
it("shows a toast when switching to advanced", async () => {
const { toast } = await import("@/components/molecules/Toast/use-toast");
mockFlagValue = true;
mockCopilotLlmModel = "standard";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to advanced model/i));
expect(toast).toHaveBeenCalledWith(
expect.objectContaining({
title: expect.stringMatching(/switched to advanced model/i),
}),
);
});
it("shows a toast when switching to standard", async () => {
const { toast } = await import("@/components/molecules/Toast/use-toast");
mockFlagValue = true;
mockCopilotLlmModel = "advanced";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to standard model/i));
expect(toast).toHaveBeenCalledWith(
expect.objectContaining({
title: expect.stringMatching(/switched to standard model/i),
}),
);
});
});

View File

@@ -0,0 +1,38 @@
"use client";
import { cn } from "@/lib/utils";
import { Cpu } from "@phosphor-icons/react";
import type { CopilotLlmModel } from "../../../store";
interface Props {
model: CopilotLlmModel;
onToggle: () => void;
}
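/** Compact toggle for the per-request model tier; the "Advanced" label renders only when active. */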
export function ModelToggleButton({ model, onToggle }: Props) {
const isAdvanced = model === "advanced";
return (
<button
type="button"
aria-pressed={isAdvanced}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isAdvanced
? "bg-sky-100 text-sky-900 hover:bg-sky-200"
: "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700",
)}
aria-label={
isAdvanced ? "Switch to Standard model" : "Switch to Advanced model"
}
title={
isAdvanced
? "Advanced model — highest capability (click to switch to Standard)"
: "Standard model — click to switch to Advanced"
}
>
<Cpu size={14} />
{isAdvanced && "Advanced"}
</button>
);
}

View File

@@ -0,0 +1,36 @@
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import { ModelToggleButton } from "../ModelToggleButton";
afterEach(cleanup);
describe("ModelToggleButton", () => {
it("shows no label when model is standard", () => {
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
expect(screen.queryByText("Advanced")).toBeNull();
});
it("shows Advanced label when model is advanced", () => {
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
expect(screen.getByText("Advanced")).toBeTruthy();
});
it("calls onToggle when clicked", () => {
const onToggle = vi.fn();
render(<ModelToggleButton model="standard" onToggle={onToggle} />);
fireEvent.click(screen.getByRole("button"));
expect(onToggle).toHaveBeenCalledTimes(1);
});
it("sets aria-pressed=false for standard", () => {
render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
const btn = screen.getByLabelText("Switch to Advanced model");
expect(btn.getAttribute("aria-pressed")).toBe("false");
});
it("sets aria-pressed=true for advanced", () => {
render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
const btn = screen.getByLabelText("Switch to Standard model");
expect(btn.getAttribute("aria-pressed")).toBe("true");
});
});

View File

@@ -2,6 +2,8 @@ import { getSystemHeaders } from "@/lib/impersonation";
import { getWebSocketToken } from "@/lib/supabase/actions";
import type { UIMessage } from "ai";
import { deleteV2DisconnectSessionStream } from "@/app/api/__generated__/endpoints/chat/chat";
export const ORIGINAL_TITLE = "AutoGPT";
/**
@@ -154,7 +156,18 @@ export function shouldSuppressDuplicateSend(
}
/**
* Deduplicate messages by ID and by content fingerprint.
* Fire-and-forget: tell the backend to release XREAD listeners for a session.
*
* Called on session switch so the backend doesn't wait for its 5-10 s timeout
* before cleaning up. Failures are silently ignored — the backend will
* eventually clean up on its own.
*/
export function disconnectSessionStream(sessionId: string): void {
deleteV2DisconnectSessionStream(sessionId).catch(() => {});
}
/**
* Deduplicate messages by ID and by consecutive content fingerprint.
*
* ID dedup catches exact duplicates within the same source.
* Content dedup uses a composite key of `role + preceding-user-message-id +

View File

@@ -53,6 +53,9 @@ export const DEFAULT_PANEL_WIDTH = 600;
/** Autopilot response mode. */
export type CopilotMode = "extended_thinking" | "fast";
/** Per-request model tier. 'standard' = current default; 'advanced' = highest-capability. */
export type CopilotLlmModel = "standard" | "advanced";
const isClient = typeof window !== "undefined";
function getPersistedWidth(): number {
@@ -134,8 +137,12 @@ interface CopilotUIState {
goBackArtifact: () => void;
/** Autopilot mode: 'extended_thinking' (default) or 'fast'. */
copilotMode: CopilotMode;
setCopilotMode: (mode: CopilotMode) => void;
copilotChatMode: CopilotMode;
setCopilotChatMode: (mode: CopilotMode) => void;
/** Model tier: 'standard' (default) or 'advanced' (highest-capability). */
copilotLlmModel: CopilotLlmModel;
setCopilotLlmModel: (model: CopilotLlmModel) => void;
/** Developer dry-run mode: sessions created with dry_run=true. */
isDryRun: boolean;
@@ -298,9 +305,22 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
};
}),
copilotMode: "extended_thinking",
setCopilotMode: (mode) => {
set({ copilotMode: mode });
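// Hydrate from localStorage at store creation; anything other than "fast" falls back to the default.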
copilotChatMode: (() => {
const saved = isClient ? storage.get(Key.COPILOT_MODE) : null;
return saved === "fast" ? "fast" : "extended_thinking";
})(),
setCopilotChatMode: (mode) => {
storage.set(Key.COPILOT_MODE, mode);
set({ copilotChatMode: mode });
},
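// Same hydration pattern for the model tier: only "advanced" is honoured from storage.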
copilotLlmModel: (() => {
const saved = isClient ? storage.get(Key.COPILOT_MODEL) : null;
return saved === "advanced" ? "advanced" : "standard";
})(),
setCopilotLlmModel: (model) => {
storage.set(Key.COPILOT_MODEL, model);
set({ copilotLlmModel: model });
},
isDryRun: isClient && storage.get(Key.COPILOT_DRY_RUN) === "true",
@@ -322,6 +342,8 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
storage.clean(Key.COPILOT_ARTIFACT_PANEL_WIDTH);
storage.clean(Key.COPILOT_COMPLETED_SESSIONS);
storage.clean(Key.COPILOT_DRY_RUN);
storage.clean(Key.COPILOT_MODE);
storage.clean(Key.COPILOT_MODEL);
set({
completedSessionIDs: new Set<string>(),
isNotificationsEnabled: false,
@@ -334,7 +356,8 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
activeArtifact: null,
history: [],
},
copilotMode: "extended_thinking",
copilotChatMode: "extended_thinking",
copilotLlmModel: "standard",
isDryRun: false,
});
if (isClient) {

View File

@@ -42,7 +42,8 @@ export function useCopilotPage() {
setSessionToDelete,
isDrawerOpen,
setDrawerOpen,
copilotMode,
copilotChatMode,
copilotLlmModel,
isDryRun,
} = useCopilotUIStore();
@@ -78,7 +79,8 @@ export function useCopilotPage() {
hydratedMessages,
hasActiveStream,
refetchSession,
copilotMode: isModeToggleEnabled ? copilotMode : undefined,
copilotMode: isModeToggleEnabled ? copilotChatMode : undefined,
copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
});
const { olderMessages, hasMore, isLoadingMore, loadMore } =

View File

@@ -17,8 +17,9 @@ import {
hasActiveBackendStream,
resolveInProgressTools,
getSendSuppressionReason,
disconnectSessionStream,
} from "./helpers";
import type { CopilotMode } from "./store";
import type { CopilotLlmModel, CopilotMode } from "./store";
const RECONNECT_BASE_DELAY_MS = 1_000;
const RECONNECT_MAX_ATTEMPTS = 3;
@@ -33,6 +34,8 @@ interface UseCopilotStreamArgs {
refetchSession: () => Promise<{ data?: unknown }>;
/** Autopilot mode to use for requests. `undefined` = let backend decide via feature flags. */
copilotMode: CopilotMode | undefined;
/** Model tier override. `undefined` = let backend decide. */
copilotModel: CopilotLlmModel | undefined;
}
export function useCopilotStream({
@@ -41,17 +44,20 @@ export function useCopilotStream({
hasActiveStream,
refetchSession,
copilotMode,
copilotModel,
}: UseCopilotStreamArgs) {
const queryClient = useQueryClient();
const [rateLimitMessage, setRateLimitMessage] = useState<string | null>(null);
function dismissRateLimit() {
setRateLimitMessage(null);
}
// Use a ref for copilotMode so the transport closure always reads the
// latest value without recreating the DefaultChatTransport (which would
// Use refs for copilotMode and copilotModel so the transport closure always reads
// the latest value without recreating the DefaultChatTransport (which would
// reset useChat's internal Chat instance and break mid-session streaming).
const copilotModeRef = useRef(copilotMode);
copilotModeRef.current = copilotMode;
const copilotModelRef = useRef(copilotModel);
copilotModelRef.current = copilotModel;
// Connect directly to the Python backend for SSE, bypassing the Next.js
// serverless proxy. This eliminates the Vercel 800s function timeout that
@@ -83,6 +89,7 @@ export function useCopilotStream({
context: null,
file_ids: fileIds && fileIds.length > 0 ? fileIds : null,
mode: copilotModeRef.current ?? null,
model: copilotModelRef.current ?? null,
},
headers: await getCopilotAuthHeaders(),
};
@@ -147,16 +154,15 @@ export function useCopilotStream({
reconnectTimerRef.current = setTimeout(() => {
isReconnectScheduledRef.current = false;
setIsReconnectScheduled(false);
// Strip any stale in-progress assistant message before resuming.
// The backend replays from "0-0", so the partial message would
// otherwise sit alongside the fully-replayed version.
// Strip the stale in-progress assistant message before resuming;
// the backend replays from "0-0", so keeping it would duplicate parts.
setMessages((prev) => {
if (prev.length > 0 && prev[prev.length - 1].role === "assistant") {
return prev.slice(0, -1);
}
return prev;
});
resumeStream();
resumeStreamRef.current();
}, delay);
}
@@ -254,6 +260,14 @@ export function useCopilotStream({
},
});
// Keep stable refs to sdkStop and resumeStream so that async callbacks
// (session-switch cleanup, wake re-sync, reconnect timer) always call the
// latest version without stale-closure bugs.
const sdkStopRef = useRef(sdkStop);
sdkStopRef.current = sdkStop;
const resumeStreamRef = useRef(resumeStream);
resumeStreamRef.current = resumeStream;
// Wrap sdkSendMessage to guard against re-sending the user message during a
// reconnect cycle. If the session already has the message (i.e. we are in a
// reconnect/resume flow), only GET-resume is safe — never re-POST.
@@ -380,7 +394,7 @@ export function useCopilotStream({
}
return prev;
});
await resumeStream();
await resumeStreamRef.current();
}
// If !backendActive, the refetch will update hydratedMessages via
// React Query, and the hydration effect below will merge them in.
@@ -403,7 +417,7 @@ export function useCopilotStream({
return () => {
document.removeEventListener("visibilitychange", onVisibilityChange);
};
}, [refetchSession, setMessages, resumeStream]);
}, [refetchSession, setMessages]);
// Hydrate messages from REST API when not actively streaming
useEffect(() => {
@@ -419,8 +433,34 @@ export function useCopilotStream({
// Track resume state per session
const hasResumedRef = useRef<Map<string, boolean>>(new Map());
// Clean up reconnect state on session switch
// Clean up reconnect state on session switch.
// Abort the old stream's in-flight fetch and tell the backend to release
// its XREAD listeners immediately (fire-and-forget).
const prevStreamSessionRef = useRef(sessionId);
useEffect(() => {
const prevSid = prevStreamSessionRef.current;
prevStreamSessionRef.current = sessionId;
const isSwitching = Boolean(prevSid && prevSid !== sessionId);
if (isSwitching) {
// Mark BEFORE stopping so the old stream's async onError (which fires
// after the abort) sees the flag and short-circuits the reconnect path.
// Without this, the AbortError can queue a reconnect against the new
// session's `sessionId` (captured in the fresh onError closure).
isUserStoppingRef.current = true;
sdkStopRef.current();
disconnectSessionStream(prevSid!);
// Schedule the reset as a task (not a microtask) so it runs AFTER the
// aborted fetch's onError has fired — otherwise the new session would
// be stuck with the "user stopping" flag set, preventing auto-resume
// when hydration detects an active backend stream.
setTimeout(() => {
isUserStoppingRef.current = false;
}, 0);
} else {
isUserStoppingRef.current = false;
}
clearTimeout(reconnectTimerRef.current);
reconnectTimerRef.current = undefined;
reconnectAttemptsRef.current = 0;
@@ -428,7 +468,6 @@ export function useCopilotStream({
setIsReconnectScheduled(false);
setRateLimitMessage(null);
hasShownDisconnectToast.current = false;
isUserStoppingRef.current = false;
lastSubmittedMsgRef.current = null;
setReconnectExhausted(false);
setIsSyncing(false);
@@ -458,7 +497,12 @@ export function useCopilotStream({
if (status === "ready") {
reconnectAttemptsRef.current = 0;
hasShownDisconnectToast.current = false;
lastSubmittedMsgRef.current = null;
// Intentionally NOT clearing lastSubmittedMsgRef here: keeping the last
// submitted text prevents getSendSuppressionReason from allowing a
// duplicate POST of the same message immediately after a successful turn
// (the "duplicate" branch checks both the ref and the visible last user
// message, so legitimate re-sends after a different reply are still
// allowed).
setReconnectExhausted(false);
}
}
@@ -495,15 +539,8 @@ export function useCopilotStream({
return prev;
});
resumeStream();
}, [
sessionId,
hasActiveStream,
hydratedMessages,
status,
resumeStream,
setMessages,
]);
resumeStreamRef.current();
}, [sessionId, hasActiveStream, hydratedMessages, status, setMessages]);
// Clear messages when session is null
useEffect(() => {

View File

@@ -1606,6 +1606,35 @@
}
},
"/api/chat/sessions/{session_id}/stream": {
"delete": {
"tags": ["v2", "chat", "chat"],
"summary": "Disconnect Session Stream",
"description": "Disconnect all active SSE listeners for a session.\n\nCalled by the frontend when the user switches away from a chat so the\nbackend releases XREAD listeners immediately rather than waiting for\nthe 5-10 s timeout.",
"operationId": "deleteV2DisconnectSessionStream",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "session_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Session Id" }
}
],
"responses": {
"204": { "description": "Successful Response" },
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
},
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Resume Session Stream",
@@ -13931,6 +13960,14 @@
],
"title": "Mode",
"description": "Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. If None, uses the server default (extended_thinking)."
},
"model": {
"anyOf": [
{ "type": "string", "enum": ["standard", "advanced"] },
{ "type": "null" }
],
"title": "Model",
"description": "Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. If None, the server applies per-user LD targeting then falls back to config."
}
},
"type": "object",

View File

@@ -17,6 +17,7 @@ export enum Key {
COPILOT_NOTIFICATION_DIALOG_DISMISSED = "copilot-notification-dialog-dismissed",
COPILOT_ARTIFACT_PANEL_WIDTH = "copilot-artifact-panel-width",
COPILOT_MODE = "copilot-mode",
COPILOT_MODEL = "copilot-model",
COPILOT_COMPLETED_SESSIONS = "copilot-completed-sessions",
COPILOT_DRY_RUN = "copilot-dry-run",
}