mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Merge branch 'dev' of https://github.com/Significant-Gravitas/AutoGPT into fix/copilot-static-system-prompt
This commit is contained in:
@@ -15,9 +15,10 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig, CopilotMode
|
||||
from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from backend.copilot.db import get_chat_messages_paginated
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.message_dedup import acquire_dedup_lock
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
@@ -140,6 +141,11 @@ class StreamChatRequest(BaseModel):
|
||||
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
|
||||
"If None, uses the server default (extended_thinking).",
|
||||
)
|
||||
model: CopilotLlmModel | None = Field(
|
||||
default=None,
|
||||
description="Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. "
|
||||
"If None, the server applies per-user LD targeting then falls back to config.",
|
||||
)
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
@@ -377,6 +383,31 @@ async def delete_session(
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.delete(
    "/sessions/{session_id}/stream",
    dependencies=[Security(auth.requires_user)],
    status_code=204,
)
async def disconnect_session_stream(
    session_id: str,
    user_id: Annotated[str, Security(auth.get_user_id)],
) -> Response:
    """Drop every active SSE listener attached to a chat session.

    The frontend calls this when the user navigates away from a chat, so the
    backend can release its XREAD listeners right away instead of waiting for
    the 5-10 s timeout.

    Args:
        session_id: Session whose listeners should be disconnected.
        user_id: Authenticated caller; used for the session lookup.

    Returns:
        An empty 204 response once the registry has been asked to disconnect.

    Raises:
        HTTPException: 404 when the session lookup for this
            (session_id, user_id) pair yields nothing.
    """
    owned_session = await get_chat_session(session_id, user_id)
    if owned_session:
        await stream_registry.disconnect_all_listeners(session_id)
        return Response(status_code=204)

    # Lookup failed: either the session does not exist or it is not the caller's.
    raise HTTPException(
        status_code=404,
        detail=f"Session {session_id} not found or access denied",
    )
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/title",
|
||||
summary="Update session title",
|
||||
@@ -811,6 +842,9 @@ async def stream_chat_post(
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
# Capture the original message text BEFORE any mutation (attachment enrichment)
|
||||
# so the idempotency hash is stable across retries.
|
||||
original_message = request.message
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
@@ -839,60 +873,91 @@ async def stream_chat_post(
|
||||
)
|
||||
request.message += files_block
|
||||
|
||||
# ── Idempotency guard ────────────────────────────────────────────────────
|
||||
# Blocks duplicate executor tasks from concurrent/retried POSTs.
|
||||
# See backend/copilot/message_dedup.py for the full lifecycle description.
|
||||
dedup_lock = None
|
||||
if request.is_user_message:
|
||||
dedup_lock = await acquire_dedup_lock(
|
||||
session_id, original_message, sanitized_file_ids
|
||||
)
|
||||
if dedup_lock is None and (original_message or sanitized_file_ids):
|
||||
|
||||
async def _empty_sse() -> AsyncGenerator[str, None]:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_empty_sse(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Connection": "keep-alive",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
||||
# message loss from concurrent requests.
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
if request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
#
|
||||
# If any of these operations raises, release the dedup lock before propagating
|
||||
# so subsequent retries are not blocked for 30 s.
|
||||
try:
|
||||
if request.message:
|
||||
message = ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
await append_and_save_message(session_id, message)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
if request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
||||
await append_and_save_message(session_id, message)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
# Create a task in the stream registry for reconnection support
|
||||
turn_id = str(uuid4())
|
||||
log_meta["turn_id"] = turn_id
|
||||
|
||||
session_create_start = time.perf_counter()
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
turn_id=turn_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
session_create_start = time.perf_counter()
|
||||
await stream_registry.create_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream",
|
||||
tool_name="chat",
|
||||
turn_id=turn_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_meta,
|
||||
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
|
||||
subscribe_from_id = "0-0"
|
||||
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=request.message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
mode=request.mode,
|
||||
)
|
||||
await enqueue_copilot_turn(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=request.message,
|
||||
turn_id=turn_id,
|
||||
is_user_message=request.is_user_message,
|
||||
context=request.context,
|
||||
file_ids=sanitized_file_ids,
|
||||
mode=request.mode,
|
||||
model=request.model,
|
||||
)
|
||||
except Exception:
|
||||
if dedup_lock:
|
||||
await dedup_lock.release()
|
||||
raise
|
||||
|
||||
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||
logger.info(
|
||||
@@ -900,6 +965,9 @@ async def stream_chat_post(
|
||||
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||
)
|
||||
|
||||
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
|
||||
subscribe_from_id = "0-0"
|
||||
|
||||
# SSE endpoint that subscribes to the task's stream
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import time as time_module
|
||||
@@ -913,6 +981,12 @@ async def stream_chat_post(
|
||||
subscriber_queue = None
|
||||
first_chunk_yielded = False
|
||||
chunks_yielded = 0
|
||||
# True for every exit path except GeneratorExit (client disconnect).
|
||||
# On disconnect the backend turn is still running — releasing the lock
|
||||
# there would reopen the infra-retry duplicate window. The 30 s TTL
|
||||
# is the fallback. All other exits (normal finish, early return, error)
|
||||
# should release so the user can re-send the same message.
|
||||
release_dedup_lock_on_exit = True
|
||||
try:
|
||||
# Subscribe from the position we captured before enqueuing
|
||||
# This avoids replaying old messages while catching all new ones
|
||||
@@ -924,8 +998,7 @@ async def stream_chat_post(
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
return # finally releases dedup_lock
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
logger.info(
|
||||
@@ -954,7 +1027,6 @@ async def stream_chat_post(
|
||||
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
total_time = time_module.perf_counter() - event_gen_start
|
||||
logger.info(
|
||||
@@ -968,7 +1040,8 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
break
|
||||
break # finally releases dedup_lock
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
@@ -983,7 +1056,7 @@ async def stream_chat_post(
|
||||
}
|
||||
},
|
||||
)
|
||||
pass # Client disconnected - background task continues
|
||||
release_dedup_lock_on_exit = False
|
||||
except Exception as e:
|
||||
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||
logger.error(
|
||||
@@ -998,7 +1071,10 @@ async def stream_chat_post(
|
||||
code="stream_error",
|
||||
).to_sse()
|
||||
yield StreamFinish().to_sse()
|
||||
# finally releases dedup_lock
|
||||
finally:
|
||||
if dedup_lock and release_dedup_lock_on_exit:
|
||||
await dedup_lock.release()
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
|
||||
@@ -133,14 +133,30 @@ def test_stream_chat_rejects_too_many_file_ids():
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
def _mock_stream_internals(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
*,
|
||||
redis_set_returns: object = True,
|
||||
):
|
||||
"""Mock the async internals of stream_chat_post so tests can exercise
|
||||
validation and enrichment logic without needing Redis/RabbitMQ."""
|
||||
validation and enrichment logic without needing Redis/RabbitMQ.
|
||||
|
||||
Args:
|
||||
redis_set_returns: Value returned by the mocked Redis ``set`` call.
|
||||
``True`` (default) simulates a fresh key (new message);
|
||||
``None`` simulates a collision (duplicate blocked).
|
||||
|
||||
Returns:
|
||||
A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so
|
||||
callers can make additional assertions about side-effects.
|
||||
"""
|
||||
import types
|
||||
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes._validate_and_get_session",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
mock_save = mocker.patch(
|
||||
"backend.api.features.chat.routes.append_and_save_message",
|
||||
return_value=None,
|
||||
)
|
||||
@@ -150,7 +166,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
"backend.api.features.chat.routes.stream_registry",
|
||||
mock_registry,
|
||||
)
|
||||
mocker.patch(
|
||||
mock_enqueue = mocker.patch(
|
||||
"backend.api.features.chat.routes.enqueue_copilot_turn",
|
||||
return_value=None,
|
||||
)
|
||||
@@ -158,9 +174,18 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture):
|
||||
"backend.api.features.chat.routes.track_user_message",
|
||||
return_value=None,
|
||||
)
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=redis_set_returns)
|
||||
mocker.patch(
|
||||
"backend.copilot.message_dedup.get_redis_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_redis,
|
||||
)
|
||||
ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue)
|
||||
return ns
|
||||
|
||||
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture):
|
||||
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
|
||||
_mock_stream_internals(mocker)
|
||||
# Patch workspace lookup as imported by the routes module
|
||||
@@ -189,7 +214,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
|
||||
# ─── UUID format filtering ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockerFixture):
|
||||
"""Non-UUID strings in file_ids should be silently filtered out
|
||||
and NOT passed to the database query."""
|
||||
_mock_stream_internals(mocker)
|
||||
@@ -228,7 +253,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
|
||||
# ─── Cross-workspace file_ids ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockerFixture):
|
||||
"""The batch query should scope to the user's workspace."""
|
||||
_mock_stream_internals(mocker)
|
||||
mocker.patch(
|
||||
@@ -257,7 +282,7 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
|
||||
# ─── Rate limit → 429 ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerFixture):
|
||||
"""When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
@@ -278,7 +303,9 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFix
|
||||
assert "daily" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture):
|
||||
def test_stream_chat_returns_429_on_weekly_rate_limit(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
):
|
||||
"""When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
@@ -301,7 +328,7 @@ def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFi
|
||||
assert "resets in" in detail
|
||||
|
||||
|
||||
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture):
|
||||
def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture):
|
||||
"""The 429 response detail should include the human-readable reset time."""
|
||||
from backend.copilot.rate_limit import RateLimitExceeded
|
||||
|
||||
@@ -677,3 +704,279 @@ class TestStripInjectedContext:
|
||||
result = _strip_injected_context(msg)
|
||||
# Without a role, the helper short-circuits without touching content.
|
||||
assert result["content"] == "hello"
|
||||
|
||||
|
||||
# ─── Idempotency / duplicate-POST guard ──────────────────────────────
|
||||
|
||||
|
||||
def test_stream_chat_blocks_duplicate_post_returns_empty_sse(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A repeated POST of the same message inside the 30-s window is answered
    with an empty SSE stream (StreamFinish + [DONE]), letting the frontend
    close the turn without a ghost response being produced."""
    # A Redis ``set`` returning None means the NX key already exists (collision).
    mocks = _mock_stream_internals(mocker, redis_set_returns=None)
    payload = {"message": "duplicate message", "is_user_message": True}

    response = client.post("/sessions/sess-dup/stream", json=payload)

    assert response.status_code == 200
    sse_text = response.text
    # Both the finish event and the SSE terminator must be present.
    assert '"finish"' in sse_text
    assert "[DONE]" in sse_text
    # The AI SDK protocol header marks this as a valid stream for the frontend.
    assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1"
    # A blocked duplicate must not persist the message or enqueue a turn.
    mocks.save.assert_not_called()
    mocks.enqueue.assert_not_called()
|
||||
|
||||
|
||||
def test_stream_chat_first_post_proceeds_normally(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """When the Redis NX key is set successfully (first POST), the request
    follows the regular streaming path with no early return."""
    mocks = _mock_stream_internals(mocker, redis_set_returns=True)
    payload = {"message": "first message", "is_user_message": True}

    response = client.post("/sessions/sess-new/stream", json=payload)

    assert response.status_code == 200
    # Exactly one Redis ``set`` call, and it must carry nx=True.
    mocks.redis.set.assert_called_once()
    set_call = mocks.redis.set.call_args
    assert set_call.kwargs.get("nx") is True
|
||||
|
||||
|
||||
def test_stream_chat_dedup_skipped_for_non_user_messages(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Requests with is_user_message=False never hit the dedup guard: they
    are injected programmatically and must always be processed."""
    # redis_set_returns=None would block a user message; here it must be irrelevant.
    mocks = _mock_stream_internals(mocker, redis_set_returns=None)
    payload = {"message": "system context", "is_user_message": False}

    response = client.post("/sessions/sess-sys/stream", json=payload)

    assert response.status_code == 200
    # The guard is skipped entirely, so Redis ``set`` is never invoked.
    mocks.redis.set.assert_not_called()
|
||||
|
||||
|
||||
def test_stream_chat_dedup_hash_uses_original_message_not_mutated(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """The dedup key is derived from the message as originally sent, not from
    the text after the [Attached files] block has been appended.

    A file_id is included in the request so the route really does append the
    attachment block; the key must still hash the untouched message."""
    import hashlib

    mocks = _mock_stream_internals(mocker, redis_set_returns=True)
    file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

    # Stub the workspace lookup and the prisma query so an attachment is found
    # and the message mutation path is exercised.
    workspace_stub = type("W", (), {"id": "ws-1"})()
    mocker.patch(
        "backend.api.features.chat.routes.get_or_create_workspace",
        return_value=workspace_stub,
    )
    file_attrs = {
        "id": file_id,
        "name": "doc.pdf",
        "mimeType": "application/pdf",
        "sizeBytes": 1024,
    }
    file_stub = type("F", (), file_attrs)()
    prisma_stub = mocker.MagicMock()
    prisma_stub.find_many = mocker.AsyncMock(return_value=[file_stub])
    mocker.patch(
        "prisma.models.UserWorkspaceFile.prisma",
        return_value=prisma_stub,
    )

    payload = {
        "message": "plain message",
        "is_user_message": True,
        "file_ids": [file_id],
    }
    response = client.post("/sessions/sess-hash/stream", json=payload)

    assert response.status_code == 200
    mocks.redis.set.assert_called_once()
    dedup_key = mocks.redis.set.call_args.args[0]

    # Expected key: original message text plus the sorted file IDs.
    hash_input = f"sess-hash:plain message:{file_id}"
    expected_hash = hashlib.sha256(hash_input.encode()).hexdigest()[:16]
    expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}"
    assert dedup_key == expected_key, (
        f"Dedup key {dedup_key!r} does not match expected {expected_key!r} — "
        "hash may be using mutated message or wrong inputs"
    )
|
||||
|
||||
|
||||
def test_stream_chat_dedup_key_released_after_stream_finish(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Once the turn ends, the dedup Redis key is deleted.

    With subscribe_to_session returning None the route emits StreamFinish
    straight away; the key must then be released so the user may send the
    same message again."""
    from unittest.mock import AsyncMock as _AsyncMock

    # Internals are patched by hand so subscribe_to_session can be controlled.
    routes = "backend.api.features.chat.routes"
    for attr in (
        "_validate_and_get_session",
        "append_and_save_message",
        "enqueue_copilot_turn",
        "track_user_message",
    ):
        mocker.patch(f"{routes}.{attr}", return_value=None)

    registry_stub = mocker.MagicMock()
    registry_stub.create_session = _AsyncMock(return_value=None)
    # None selects the early-finish path in the route.
    registry_stub.subscribe_to_session = _AsyncMock(return_value=None)
    mocker.patch(f"{routes}.stream_registry", registry_stub)

    redis_stub = mocker.AsyncMock()
    redis_stub.set = _AsyncMock(return_value=True)
    mocker.patch(
        "backend.copilot.message_dedup.get_redis_async",
        new_callable=_AsyncMock,
        return_value=redis_stub,
    )

    payload = {"message": "hello", "is_user_message": True}
    response = client.post("/sessions/sess-finish/stream", json=payload)

    assert response.status_code == 200
    assert '"finish"' in response.text
    # Releasing the key is what permits an intentional re-send.
    redis_stub.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_stream_chat_dedup_key_released_even_when_redis_delete_raises(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A failing Redis delete on the early-finish path (subscriber_queue is
    None) must not crash the route: the release attempt is made and its
    error is swallowed."""
    from unittest.mock import AsyncMock as _AsyncMock

    routes = "backend.api.features.chat.routes"
    for attr in (
        "_validate_and_get_session",
        "append_and_save_message",
        "enqueue_copilot_turn",
        "track_user_message",
    ):
        mocker.patch(f"{routes}.{attr}", return_value=None)

    registry_stub = mocker.MagicMock()
    registry_stub.create_session = _AsyncMock(return_value=None)
    registry_stub.subscribe_to_session = _AsyncMock(return_value=None)
    mocker.patch(f"{routes}.stream_registry", registry_stub)

    redis_stub = mocker.AsyncMock()
    redis_stub.set = _AsyncMock(return_value=True)
    # Force the delete to fail so the error-suppression branch runs.
    redis_stub.delete = _AsyncMock(side_effect=RuntimeError("redis gone"))
    mocker.patch(
        "backend.copilot.message_dedup.get_redis_async",
        new_callable=_AsyncMock,
        return_value=redis_stub,
    )

    # The request must complete even though the delete raises.
    payload = {"message": "hello", "is_user_message": True}
    response = client.post("/sessions/sess-finish-err/stream", json=payload)

    assert response.status_code == 200
    assert '"finish"' in response.text
    # The delete was attempted exactly once; its failure stayed contained.
    redis_stub.delete.assert_called_once()
|
||||
|
||||
|
||||
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
|
||||
|
||||
|
||||
def test_disconnect_stream_returns_204_and_awaits_registry(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """DELETE on an existing session answers 204 and awaits the registry's
    disconnect_all_listeners with that session id."""
    routes = "backend.api.features.chat.routes"
    found_session = MagicMock()
    mocker.patch(
        f"{routes}.get_chat_session",
        new_callable=AsyncMock,
        return_value=found_session,
    )
    disconnect_spy = mocker.patch(
        f"{routes}.stream_registry.disconnect_all_listeners",
        new_callable=AsyncMock,
        return_value=2,
    )

    response = client.delete("/sessions/sess-1/stream")

    assert response.status_code == 204
    disconnect_spy.assert_awaited_once_with("sess-1")
|
||||
|
||||
|
||||
def test_disconnect_stream_returns_404_when_session_missing(
    mocker: pytest_mock.MockerFixture,
    test_user_id: str,
) -> None:
    """DELETE answers 404 and leaves the registry untouched when the session
    lookup returns None."""
    routes = "backend.api.features.chat.routes"
    mocker.patch(
        f"{routes}.get_chat_session",
        new_callable=AsyncMock,
        return_value=None,
    )
    disconnect_spy = mocker.patch(
        f"{routes}.stream_registry.disconnect_all_listeners",
        new_callable=AsyncMock,
    )

    response = client.delete("/sessions/unknown-session/stream")

    assert response.status_code == 404
    disconnect_spy.assert_not_awaited()
|
||||
|
||||
@@ -16,6 +16,13 @@ from backend.util.clients import OPENROUTER_BASE_URL
|
||||
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
|
||||
CopilotMode = Literal["fast", "extended_thinking"]
|
||||
|
||||
# Per-request model tier set by the frontend model toggle.
|
||||
# 'standard' uses the global config default (currently Sonnet).
|
||||
# 'advanced' forces the highest-capability model (currently Opus).
|
||||
# None means no preference — falls through to LD per-user targeting, then config.
|
||||
# Using tier names instead of model names keeps the contract model-agnostic.
|
||||
CopilotLlmModel = Literal["standard", "advanced"]
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
@@ -163,12 +170,12 @@ class ChatConfig(BaseSettings):
|
||||
"CHAT_CLAUDE_AGENT_MAX_TURNS env var if your workflows need more.",
|
||||
)
|
||||
claude_agent_max_budget_usd: float = Field(
|
||||
default=15.0,
|
||||
default=10.0,
|
||||
ge=0.01,
|
||||
le=1000.0,
|
||||
description="Maximum spend in USD per SDK query. The CLI attempts "
|
||||
"to wrap up gracefully when this budget is reached. "
|
||||
"Set to $15 to allow most tasks to complete (p50=$5.37, p75=$13.07). "
|
||||
"Set to $10 to allow most tasks to complete (p50=$5.37, p75=$13.07). "
|
||||
"Override via CHAT_CLAUDE_AGENT_MAX_BUDGET_USD env var.",
|
||||
)
|
||||
claude_agent_max_thinking_tokens: int = Field(
|
||||
|
||||
@@ -351,6 +351,7 @@ class CoPilotProcessor:
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
mode=effective_mode,
|
||||
model=entry.model,
|
||||
)
|
||||
async for chunk in stream_registry.stream_and_publish(
|
||||
session_id=entry.session_id,
|
||||
|
||||
@@ -9,7 +9,7 @@ import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import CopilotMode
|
||||
from backend.copilot.config import CopilotLlmModel, CopilotMode
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
|
||||
@@ -160,6 +160,9 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
mode: CopilotMode | None = None
|
||||
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
|
||||
|
||||
model: CopilotLlmModel | None = None
|
||||
"""Per-request model tier: 'standard' or 'advanced'. None = server default."""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
@@ -180,6 +183,7 @@ async def enqueue_copilot_turn(
|
||||
context: dict[str, str] | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
mode: CopilotMode | None = None,
|
||||
model: CopilotLlmModel | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
@@ -192,6 +196,7 @@ async def enqueue_copilot_turn(
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
file_ids: Optional workspace file IDs attached to the user's message
|
||||
mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default.
|
||||
model: Per-request model tier ('standard' or 'advanced'). None = server default.
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
@@ -204,6 +209,7 @@ async def enqueue_copilot_turn(
|
||||
context=context,
|
||||
file_ids=file_ids,
|
||||
mode=mode,
|
||||
model=model,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
|
||||
71
autogpt_platform/backend/backend/copilot/message_dedup.py
Normal file
71
autogpt_platform/backend/backend/copilot/message_dedup.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Per-request idempotency lock for the /stream endpoint.
|
||||
|
||||
Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s
|
||||
rolling-deploy retries, nginx upstream retries, rapid double-clicks).
|
||||
|
||||
Lifecycle
|
||||
---------
|
||||
1. ``acquire_dedup_lock()`` — computes a stable hash of (session_id, message, file_ids)
|
||||
and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or
|
||||
``None`` when the key already exists (duplicate request).
|
||||
2. ``release()`` — deletes the key. Must be called on turn completion or turn
|
||||
error so the next legitimate send is never blocked.
|
||||
3. On client disconnect (``GeneratorExit``) the lock must NOT be released —
|
||||
the backend turn is still running, and releasing would reopen the duplicate
|
||||
window for infra-level retries. The 30 s TTL is the safety net.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_KEY_PREFIX = "chat:msg_dedup"
|
||||
_TTL_SECONDS = 30
|
||||
|
||||
|
||||
class _DedupLock:
|
||||
def __init__(self, key: str, redis) -> None:
|
||||
self._key = key
|
||||
self._redis = redis
|
||||
|
||||
async def release(self) -> None:
|
||||
"""Best-effort key deletion. The TTL handles failures silently."""
|
||||
try:
|
||||
await self._redis.delete(self._key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def acquire_dedup_lock(
    session_id: str,
    message: str | None,
    file_ids: list[str] | None,
) -> _DedupLock | None:
    """Acquire the idempotency lock for this (session, message, files) tuple.

    Args:
        session_id: Chat session the message belongs to; part of the Redis key.
        message: User message text, or ``None``/empty when only files are sent.
        file_ids: Attached file IDs; order does not affect the resulting key.

    Returns:
        A ``_DedupLock`` when the lock is freshly acquired (first request).
        ``None`` when a duplicate is detected (lock already held).
        ``None`` when there is nothing to deduplicate (no message, no files).
        Callers that must tell the two ``None`` cases apart have to check for
        an empty request themselves before calling.
    """
    if not message and not file_ids:
        return None

    # Length-prefix every component before hashing so field boundaries stay
    # unambiguous. A plain ":"-joined string makes ("a:b", no files) and
    # ("a", ["b"]) hash identically, which would drop a legitimate request as
    # a "duplicate".
    hasher = hashlib.sha256()
    for part in (session_id, message or "", *sorted(file_ids or [])):
        encoded = part.encode()
        hasher.update(len(encoded).to_bytes(8, "big"))
        hasher.update(encoded)
    content_hash = hasher.hexdigest()[:16]
    key = f"{_KEY_PREFIX}:{session_id}:{content_hash}"

    redis = await get_redis_async()
    # NX makes check-and-set atomic; the TTL bounds how long a key that is
    # never released can keep blocking an identical message.
    acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True)
    if not acquired:
        logger.warning(
            "[STREAM] Duplicate user message blocked for session %s, "
            "hash=%s — returning empty SSE",
            session_id,
            content_hash,
        )
        return None

    return _DedupLock(key, redis)
|
||||
@@ -0,0 +1,94 @@
|
||||
"""Unit tests for backend.copilot.message_dedup."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
|
||||
from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock
|
||||
|
||||
|
||||
def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns):
    """Replace ``get_redis_async`` in message_dedup with a stub client.

    The stub's ``set`` coroutine resolves to *set_returns*. The stub itself is
    returned so tests can inspect the calls made on it.
    """
    fake_client = AsyncMock()
    fake_client.set = AsyncMock(return_value=set_returns)
    patch_options = {"new_callable": AsyncMock, "return_value": fake_client}
    mocker.patch("backend.copilot.message_dedup.get_redis_async", **patch_options)
    return fake_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_acquire_returns_none_when_no_message_no_files(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """An empty request short-circuits: None is returned and Redis is untouched."""
    redis_stub = _patch_redis(mocker, set_returns=True)

    outcome = await acquire_dedup_lock("sess-1", None, None)

    redis_stub.set.assert_not_called()
    assert outcome is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_acquire_returns_lock_on_first_request(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """First request acquires the lock and returns a _DedupLock.

    Besides the key prefix, this pins the two arguments that make the key a
    lock at all: ``nx=True`` (atomic set-if-absent) and a positive ``ex`` TTL
    (so an unreleased key cannot block forever).
    """
    mock_redis = _patch_redis(mocker, set_returns=True)

    lock = await acquire_dedup_lock("sess-1", "hello", None)

    assert lock is not None
    mock_redis.set.assert_called_once()
    set_call = mock_redis.set.call_args
    assert set_call.args[0].startswith(f"{_KEY_PREFIX}:sess-1:")
    assert set_call.kwargs["nx"] is True
    assert set_call.kwargs["ex"] > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_acquire_returns_none_on_duplicate(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """When the NX set reports the key already exists, the caller gets None."""
    _patch_redis(mocker, set_returns=None)

    outcome = await acquire_dedup_lock("sess-1", "hello", None)

    assert outcome is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_acquire_key_stable_across_file_order(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """The same files in a different order must map to the same Redis key."""
    keys = []
    for ordering in (["b", "a"], ["a", "b"]):
        # A fresh stub per call keeps each recorded ``set`` call isolated.
        redis_stub = _patch_redis(mocker, set_returns=True)
        await acquire_dedup_lock("sess-1", "msg", ordering)
        keys.append(redis_stub.set.call_args.args[0])

    first_key, second_key = keys
    assert first_key == second_key
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_release_deletes_key(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """release() deletes exactly the key that acquire set, exactly once."""
    mock_redis = _patch_redis(mocker, set_returns=True)

    lock = await acquire_dedup_lock("sess-1", "hello", None)
    assert lock is not None
    acquired_key = mock_redis.set.call_args.args[0]

    await lock.release()

    # Checking the argument guards against releasing some other key, which a
    # bare "called once" assertion would let through.
    mock_redis.delete.assert_called_once_with(acquired_key)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_release_swallows_redis_error(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A failing Redis delete must not propagate out of release()."""
    redis_stub = _patch_redis(mocker, set_returns=True)
    failing_delete = AsyncMock(side_effect=RuntimeError("redis down"))
    redis_stub.delete = failing_delete

    held = await acquire_dedup_lock("sess-1", "hello", None)
    assert held is not None

    await held.release()  # must not raise

    redis_stub.delete.assert_called_once()
|
||||
@@ -302,6 +302,7 @@ async def record_token_usage(
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
model_cost_multiplier: float = 1.0,
|
||||
) -> None:
|
||||
"""Record token usage for a user across all windows.
|
||||
|
||||
@@ -315,12 +316,17 @@ async def record_token_usage(
|
||||
``prompt_tokens`` should be the *uncached* input count (``input_tokens``
|
||||
from the API response). Cache counts are passed separately.
|
||||
|
||||
``model_cost_multiplier`` scales the final weighted total to reflect
|
||||
relative model cost. Use 5.0 for Opus (5× more expensive than Sonnet)
|
||||
so that Opus turns deplete the rate limit faster, proportional to cost.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
prompt_tokens: Uncached input tokens.
|
||||
completion_tokens: Output tokens.
|
||||
cache_read_tokens: Tokens served from prompt cache (10% cost).
|
||||
cache_creation_tokens: Tokens written to prompt cache (25% cost).
|
||||
model_cost_multiplier: Relative model cost factor (1.0 = Sonnet, 5.0 = Opus).
|
||||
"""
|
||||
prompt_tokens = max(0, prompt_tokens)
|
||||
completion_tokens = max(0, completion_tokens)
|
||||
@@ -332,7 +338,9 @@ async def record_token_usage(
|
||||
+ round(cache_creation_tokens * 0.25)
|
||||
+ round(cache_read_tokens * 0.1)
|
||||
)
|
||||
total = weighted_input + completion_tokens
|
||||
total = round(
|
||||
(weighted_input + completion_tokens) * max(1.0, model_cost_multiplier)
|
||||
)
|
||||
if total <= 0:
|
||||
return
|
||||
|
||||
@@ -340,11 +348,12 @@ async def record_token_usage(
|
||||
prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens
|
||||
)
|
||||
logger.info(
|
||||
"Recording token usage for %s: raw=%d, weighted=%d "
|
||||
"Recording token usage for %s: raw=%d, weighted=%d, multiplier=%.1fx "
|
||||
"(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)",
|
||||
user_id[:8],
|
||||
raw_total,
|
||||
total,
|
||||
model_cost_multiplier,
|
||||
prompt_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_tokens,
|
||||
|
||||
@@ -34,9 +34,13 @@ Steps:
|
||||
always inspect the current graph first so you know exactly what to change.
|
||||
Avoid using `include_graph=true` with broad keyword searches, as fetching
|
||||
multiple graphs at once is expensive and consumes LLM context budget.
|
||||
2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to
|
||||
2. **Discover blocks**: Call `find_block(query, include_schemas=true, for_agent_generation=true)` to
|
||||
search for relevant blocks. This returns block IDs, names, descriptions,
|
||||
and full input/output schemas.
|
||||
and full input/output schemas. The `for_agent_generation=true` flag is
|
||||
required to surface graph-only blocks such as AgentInputBlock,
|
||||
AgentDropdownInputBlock, AgentOutputBlock, OrchestratorBlock,
|
||||
WebhookBlock, and MCPToolBlock. (When running MCP tools interactively
|
||||
in CoPilot outside agent generation, use `run_mcp_tool` instead.)
|
||||
3. **Find library agents**: Call `find_library_agent` to discover reusable
|
||||
agents that can be composed as sub-agents via `AgentExecutorBlock`.
|
||||
4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas:
|
||||
@@ -177,6 +181,12 @@ To compose agents using other agents as sub-agents:
|
||||
|
||||
### Using MCP Tools (MCPToolBlock)
|
||||
|
||||
> **Agent graph vs CoPilot direct execution**: This section covers embedding MCP
|
||||
> tools as persistent nodes in an agent graph. When running MCP tools directly in
|
||||
> CoPilot (outside agent generation), use `run_mcp_tool` instead — it handles
|
||||
> server discovery and authentication interactively. Use `MCPToolBlock` here only
|
||||
> when the user wants the MCP call baked into a reusable agent graph.
|
||||
|
||||
To use an MCP (Model Context Protocol) tool as a node in the agent:
|
||||
1. The user must specify which MCP server URL and tool name they want
|
||||
2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`)
|
||||
|
||||
@@ -207,7 +207,7 @@ class TestConfigDefaults:
|
||||
|
||||
def test_max_budget_usd_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_max_budget_usd == 15.0
|
||||
assert cfg.claude_agent_max_budget_usd == 10.0
|
||||
|
||||
def test_max_thinking_tokens_default(self):
|
||||
cfg = _make_config()
|
||||
|
||||
@@ -56,7 +56,7 @@ from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ..config import ChatConfig, CopilotMode
|
||||
from ..config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from ..constants import (
|
||||
COPILOT_ERROR_PREFIX,
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX,
|
||||
@@ -132,6 +132,11 @@ _MAX_STREAM_ATTEMPTS = 3
|
||||
# self-correct. The limit is generous to allow recovery attempts.
|
||||
_EMPTY_TOOL_CALL_LIMIT = 5
|
||||
|
||||
# Cost multiplier for Opus model turns — Opus is ~5× more expensive than Sonnet
|
||||
# ($15/$75 vs $3/$15 per M tokens). Applied to rate-limit counters so Opus
|
||||
# turns deplete quota proportionally faster.
|
||||
_OPUS_COST_MULTIPLIER = 5.0
|
||||
|
||||
# User-facing error shown when the empty-tool-call circuit breaker trips.
|
||||
_CIRCUIT_BREAKER_ERROR_MSG = (
|
||||
"AutoPilot was unable to complete the tool call "
|
||||
@@ -674,6 +679,48 @@ def _resolve_fallback_model() -> str | None:
|
||||
return _normalize_model_name(raw)
|
||||
|
||||
|
||||
async def _resolve_model_and_multiplier(
|
||||
model: "CopilotLlmModel | None",
|
||||
session_id: str,
|
||||
) -> tuple[str | None, float]:
|
||||
"""Resolve the SDK model string and rate-limit cost multiplier for a turn.
|
||||
|
||||
Priority (highest first):
|
||||
1. Explicit per-request ``model`` tier from the frontend toggle.
|
||||
2. Global config default (``_resolve_sdk_model()``).
|
||||
|
||||
Returns a ``(sdk_model, cost_multiplier)`` pair.
|
||||
``sdk_model`` is ``None`` when the Claude Code subscription default applies.
|
||||
``cost_multiplier`` is 5.0 for Opus, 1.0 otherwise.
|
||||
"""
|
||||
sdk_model = _resolve_sdk_model()
|
||||
|
||||
if model == "advanced":
|
||||
sdk_model = _normalize_model_name("anthropic/claude-opus-4-6")
|
||||
logger.info(
|
||||
"[SDK] [%s] Per-request model override: advanced (%s)",
|
||||
session_id[:12] if session_id else "?",
|
||||
sdk_model,
|
||||
)
|
||||
return sdk_model, _OPUS_COST_MULTIPLIER
|
||||
|
||||
if model == "standard":
|
||||
# Reset to config default — respects subscription mode (None = CLI default).
|
||||
sdk_model = _resolve_sdk_model()
|
||||
logger.info(
|
||||
"[SDK] [%s] Per-request model override: standard (%s)",
|
||||
session_id[:12] if session_id else "?",
|
||||
sdk_model or "subscription-default",
|
||||
)
|
||||
return sdk_model, 1.0
|
||||
|
||||
# No per-request override; derive multiplier from final resolved model.
|
||||
cost_multiplier = (
|
||||
_OPUS_COST_MULTIPLIER if sdk_model and "opus" in sdk_model else 1.0
|
||||
)
|
||||
return sdk_model, cost_multiplier
|
||||
|
||||
|
||||
_MAX_TRANSIENT_BACKOFF_SECONDS = 30
|
||||
|
||||
|
||||
@@ -1865,15 +1912,20 @@ async def _run_stream_attempt(
|
||||
# cache_read_input_tokens = served from cache
|
||||
# cache_creation_input_tokens = written to cache
|
||||
if sdk_msg.usage:
|
||||
state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens", 0)
|
||||
state.usage.cache_read_tokens += sdk_msg.usage.get(
|
||||
"cache_read_input_tokens", 0
|
||||
# Use `or 0` instead of a default in .get() because
|
||||
# OpenRouter may include the key with a null value (e.g.
|
||||
# {"cache_read_input_tokens": null}) for models that don't
|
||||
# yet report cache tokens, making .get("key", 0) return
|
||||
# None rather than the fallback 0.
|
||||
state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens") or 0
|
||||
state.usage.cache_read_tokens += (
|
||||
sdk_msg.usage.get("cache_read_input_tokens") or 0
|
||||
)
|
||||
state.usage.cache_creation_tokens += sdk_msg.usage.get(
|
||||
"cache_creation_input_tokens", 0
|
||||
state.usage.cache_creation_tokens += (
|
||||
sdk_msg.usage.get("cache_creation_input_tokens") or 0
|
||||
)
|
||||
state.usage.completion_tokens += sdk_msg.usage.get(
|
||||
"output_tokens", 0
|
||||
state.usage.completion_tokens += (
|
||||
sdk_msg.usage.get("output_tokens") or 0
|
||||
)
|
||||
logger.info(
|
||||
"%s Token usage: uncached=%d, cache_read=%d, "
|
||||
@@ -2150,6 +2202,7 @@ async def stream_chat_completion_sdk(
|
||||
file_ids: list[str] | None = None,
|
||||
permissions: "CopilotPermissions | None" = None,
|
||||
mode: CopilotMode | None = None,
|
||||
model: CopilotLlmModel | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncIterator[StreamBaseResponse]:
|
||||
"""Stream chat completion using Claude Agent SDK.
|
||||
@@ -2160,6 +2213,9 @@ async def stream_chat_completion_sdk(
|
||||
saved to the SDK working directory for the Read tool.
|
||||
mode: Accepted for signature compatibility with the baseline path.
|
||||
The SDK path does not currently branch on this value.
|
||||
model: Per-request model preference from the frontend toggle.
|
||||
'advanced' → Claude Opus; 'standard' → global config default.
|
||||
Takes priority over per-user LaunchDarkly targeting.
|
||||
"""
|
||||
_ = mode # SDK path ignores the requested mode.
|
||||
|
||||
@@ -2274,6 +2330,10 @@ async def stream_chat_completion_sdk(
|
||||
turn_cache_creation_tokens = 0
|
||||
turn_cost_usd: float | None = None
|
||||
graphiti_enabled = False
|
||||
# Defaults ensure the finally block can always reference these safely even when
|
||||
# an early return (e.g. sdk_cwd error) skips their normal assignment below.
|
||||
sdk_model: str | None = None
|
||||
model_cost_multiplier: float = 1.0
|
||||
|
||||
# Make sure there is no more code between the lock acquisition and try-block.
|
||||
try:
|
||||
@@ -2487,7 +2547,10 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
|
||||
|
||||
sdk_model = _resolve_sdk_model()
|
||||
# Resolve model and cost multiplier (request tier → config default).
|
||||
sdk_model, model_cost_multiplier = await _resolve_model_and_multiplier(
|
||||
model, session_id
|
||||
)
|
||||
|
||||
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
|
||||
compaction = CompactionTracker()
|
||||
@@ -3188,8 +3251,9 @@ async def stream_chat_completion_sdk(
|
||||
cache_creation_tokens=turn_cache_creation_tokens,
|
||||
log_prefix=log_prefix,
|
||||
cost_usd=turn_cost_usd,
|
||||
model=config.model,
|
||||
model=sdk_model or config.model,
|
||||
provider="anthropic",
|
||||
model_cost_multiplier=model_cost_multiplier,
|
||||
)
|
||||
|
||||
# --- Persist session messages ---
|
||||
|
||||
@@ -20,7 +20,9 @@ from .service import (
|
||||
_is_prompt_too_long,
|
||||
_is_tool_only_message,
|
||||
_iter_sdk_messages,
|
||||
_normalize_model_name,
|
||||
_reduce_context,
|
||||
_TokenUsage,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -350,3 +352,128 @@ class TestIsParallelContinuation:
|
||||
msg = MagicMock(spec=AssistantMessage)
|
||||
msg.content = [self._make_tool_block()]
|
||||
assert _is_tool_only_message(msg) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _normalize_model_name — used by per-request model override
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNormalizeModelName:
    """Pin the provider-prefix stripping done by ``_normalize_model_name``.

    The per-request model toggle feeds provider-prefixed names such as
    ``"anthropic/claude-opus-4-6"`` through this helper; the bare model name
    is what stays compatible with the Claude CLI.
    """

    def test_strips_anthropic_prefix(self):
        stripped = _normalize_model_name("anthropic/claude-opus-4-6")
        assert stripped == "claude-opus-4-6"

    def test_strips_openai_prefix(self):
        stripped = _normalize_model_name("openai/gpt-4o")
        assert stripped == "gpt-4o"

    def test_strips_google_prefix(self):
        stripped = _normalize_model_name("google/gemini-2.5-flash")
        assert stripped == "gemini-2.5-flash"

    def test_already_normalized_unchanged(self):
        bare_name = "claude-sonnet-4-20250514"
        assert _normalize_model_name(bare_name) == "claude-sonnet-4-20250514"

    def test_empty_string_unchanged(self):
        stripped = _normalize_model_name("")
        assert stripped == ""

    def test_opus_model_roundtrip(self):
        """The exact string used for the Opus ('advanced') tier strips correctly."""
        opus_name = "anthropic/claude-opus-4-6"
        assert _normalize_model_name(opus_name) == "claude-opus-4-6"

    def test_sonnet_openrouter_model(self):
        """An OpenRouter-prefixed Sonnet name strips cleanly."""
        sonnet_name = "anthropic/claude-sonnet-4"
        assert _normalize_model_name(sonnet_name) == "claude-sonnet-4"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTokenUsageNullSafety:
    """Usage dicts whose cache fields are present but null must not crash.

    OpenRouter can emit an initial streaming event with
    ``"cache_read_input_tokens": None`` before real counts are available.
    ``dict.get(key, 0)`` returns that ``None`` (the key exists), so the old
    ``+=`` raised ``TypeError``; ``dict.get(key) or 0`` does not.
    """

    def _apply_usage(self, usage: dict, acc: _TokenUsage) -> None:
        """Mirror the production accumulation in sdk/service.py."""
        field_for_key = (
            ("input_tokens", "prompt_tokens"),
            ("cache_read_input_tokens", "cache_read_tokens"),
            ("cache_creation_input_tokens", "cache_creation_tokens"),
            ("output_tokens", "completion_tokens"),
        )
        for usage_key, field in field_for_key:
            increment = usage.get(usage_key) or 0
            setattr(acc, field, getattr(acc, field) + increment)

    def test_null_cache_tokens_do_not_crash(self):
        """OpenRouter initial event: cache keys present with null value."""
        totals = _TokenUsage()
        initial_event = {
            "input_tokens": 0,
            "output_tokens": 0,
            "cache_read_input_tokens": None,
            "cache_creation_input_tokens": None,
        }

        self._apply_usage(initial_event, totals)  # must not raise TypeError

        assert totals.prompt_tokens == 0
        assert totals.cache_read_tokens == 0
        assert totals.cache_creation_tokens == 0
        assert totals.completion_tokens == 0

    def test_real_cache_tokens_are_accumulated(self):
        """OpenRouter final event: real cache token counts are captured."""
        totals = _TokenUsage()
        final_event = {
            "input_tokens": 10,
            "output_tokens": 349,
            "cache_read_input_tokens": 16600,
            "cache_creation_input_tokens": 512,
        }

        self._apply_usage(final_event, totals)

        assert totals.prompt_tokens == 10
        assert totals.cache_read_tokens == 16600
        assert totals.cache_creation_tokens == 512
        assert totals.completion_tokens == 349

    def test_absent_cache_keys_default_to_zero(self):
        """Minimal usage dict without cache keys defaults correctly."""
        totals = _TokenUsage()

        self._apply_usage({"input_tokens": 5, "output_tokens": 20}, totals)

        assert totals.prompt_tokens == 5
        assert totals.cache_read_tokens == 0
        assert totals.cache_creation_tokens == 0
        assert totals.completion_tokens == 20

    def test_multi_turn_accumulation(self):
        """Null event followed by real event: only real tokens counted."""
        totals = _TokenUsage()
        events = [
            {
                "input_tokens": 0,
                "output_tokens": 0,
                "cache_read_input_tokens": None,
                "cache_creation_input_tokens": None,
            },
            {
                "input_tokens": 10,
                "output_tokens": 349,
                "cache_read_input_tokens": 16600,
                "cache_creation_input_tokens": 512,
            },
        ]

        for event in events:
            self._apply_usage(event, totals)

        assert totals.prompt_tokens == 10
        assert totals.cache_read_tokens == 16600
        assert totals.cache_creation_tokens == 512
        assert totals.completion_tokens == 349
|
||||
|
||||
@@ -1149,3 +1149,50 @@ async def unsubscribe_from_session(
|
||||
)
|
||||
|
||||
logger.debug(f"Successfully unsubscribed from session {session_id}")
|
||||
|
||||
|
||||
async def disconnect_all_listeners(session_id: str) -> int:
    """Cancel every active listener task for *session_id*.

    Called when the frontend switches away from a session and wants the
    backend to release resources immediately rather than waiting for the
    XREAD timeout.

    Scope / limitations (best-effort optimisation, not a correctness primitive):
    - Pod-local: ``_listener_sessions`` is in-memory. If the DELETE request
      lands on a different worker than the one serving the SSE, no listener
      is cancelled here — the SSE worker still releases on its XREAD timeout.
    - Session-scoped (not subscriber-scoped): cancels every active listener
      for the session on this pod. In the rare case a single user opens two
      SSE connections to the same session on the same pod (e.g. two tabs),
      both would be torn down. Cross-pod, subscriber-scoped cancellation
      would require a Redis pub/sub fan-out with per-listener tokens; that
      is not implemented here because the XREAD timeout already bounds the
      worst case.

    Args:
        session_id: Session whose listener tasks should be cancelled.

    Returns:
        The number of listener tasks confirmed cancelled. Listeners that do
        not finish within the per-task timeout, or that end with an error,
        are not counted.

    Raises:
        asyncio.CancelledError: If the *calling* task is cancelled while
            waiting for a listener that has not finished yet.
    """
    # Snapshot the registry so it can be mutated while selecting entries.
    to_cancel: list[tuple[int, asyncio.Task]] = [
        (qid, task)
        for qid, (sid, task) in list(_listener_sessions.items())
        if sid == session_id and not task.done()
    ]

    # Request cancellation of every listener first so they unwind
    # concurrently, then wait for each one below.
    for qid, task in to_cancel:
        _listener_sessions.pop(qid, None)
        task.cancel()

    cancelled = 0
    for _qid, task in to_cancel:
        try:
            # shield() keeps wait_for (and cancellation of *this* coroutine)
            # from touching the listener again; the listener was already
            # cancelled above and only needs to be awaited.
            await asyncio.wait_for(asyncio.shield(task), timeout=5.0)
        except asyncio.CancelledError:
            if not task.cancelled():
                # The listener is still running, so this CancelledError is
                # aimed at our caller — propagate it instead of swallowing it.
                raise
            cancelled += 1
        except asyncio.TimeoutError:
            logger.warning(
                "Listener for session %s did not stop within 5s of cancellation",
                session_id,
            )
        except Exception as e:
            logger.error(
                "Error cancelling listener for session %s: %s",
                session_id,
                e,
                exc_info=True,
            )

    if cancelled:
        logger.info("Disconnected %d listener(s) for session %s", cancelled, session_id)
    return cancelled
|
||||
|
||||
110
autogpt_platform/backend/backend/copilot/stream_registry_test.py
Normal file
110
autogpt_platform/backend/backend/copilot/stream_registry_test.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Tests for disconnect_all_listeners in stream_registry."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clear_listener_sessions():
    """Empty the module-level listener registry before and after every test."""

    def _reset() -> None:
        stream_registry._listener_sessions.clear()

    _reset()
    yield
    _reset()
|
||||
|
||||
|
||||
async def _sleep_forever() -> None:
    """Sleep until cancelled; stands in for a long-lived listener task.

    Cancellation propagates naturally out of ``asyncio.sleep``; no explicit
    catch-and-re-raise is needed.
    """
    await asyncio.sleep(3600)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_disconnect_all_listeners_cancels_matching_session():
    """Only listeners registered under the target session are cancelled and removed."""
    first = asyncio.create_task(_sleep_forever())
    second = asyncio.create_task(_sleep_forever())
    unrelated = asyncio.create_task(_sleep_forever())

    stream_registry._listener_sessions.update(
        {
            1: ("sess-1", first),
            2: ("sess-1", second),
            3: ("sess-other", unrelated),
        }
    )

    try:
        count = await stream_registry.disconnect_all_listeners("sess-1")

        assert count == 2
        assert first.cancelled()
        assert second.cancelled()
        assert not unrelated.done()
        # Matching entries are removed, non-matching entries remain.
        registry = stream_registry._listener_sessions
        assert 1 not in registry
        assert 2 not in registry
        assert 3 in registry
    finally:
        # Always tear down the surviving task so it does not leak past the test.
        unrelated.cancel()
        try:
            await unrelated
        except asyncio.CancelledError:
            pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_disconnect_all_listeners_no_match_returns_zero():
    """A session with no registered listeners leaves other sessions untouched."""
    bystander = asyncio.create_task(_sleep_forever())
    stream_registry._listener_sessions[1] = ("sess-other", bystander)

    try:
        count = await stream_registry.disconnect_all_listeners("sess-missing")

        assert count == 0
        assert not bystander.done()
        assert 1 in stream_registry._listener_sessions
    finally:
        bystander.cancel()
        try:
            await bystander
        except asyncio.CancelledError:
            pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_disconnect_all_listeners_skips_already_done_tasks():
    """Listener tasks that already finished are neither cancelled nor counted."""

    async def _finish_immediately():
        return None

    finished = asyncio.create_task(_finish_immediately())
    await finished
    stream_registry._listener_sessions[1] = ("sess-1", finished)

    count = await stream_registry.disconnect_all_listeners("sess-1")

    # Done tasks are filtered out before cancellation.
    assert count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_disconnect_all_listeners_empty_registry():
    """With nothing registered at all, the call is a no-op that reports zero."""
    assert await stream_registry.disconnect_all_listeners("sess-1") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_disconnect_all_listeners_timeout_not_counted():
    """Tasks that don't respond to cancellation (timeout) are not counted."""
    stuck = asyncio.create_task(_sleep_forever())
    stream_registry._listener_sessions[1] = ("sess-1", stuck)

    # Force every wait to time out, simulating a listener that never unwinds.
    always_times_out = AsyncMock(side_effect=asyncio.TimeoutError)
    with patch.object(asyncio, "wait_for", new=always_times_out):
        count = await stream_registry.disconnect_all_listeners("sess-1")

    assert count == 0

    stuck.cancel()
    try:
        await stuck
    except asyncio.CancelledError:
        pass
|
||||
@@ -96,6 +96,7 @@ async def persist_and_record_usage(
|
||||
cost_usd: float | str | None = None,
|
||||
model: str | None = None,
|
||||
provider: str = "open_router",
|
||||
model_cost_multiplier: float = 1.0,
|
||||
) -> int:
|
||||
"""Persist token usage to session and record for rate limiting.
|
||||
|
||||
@@ -109,6 +110,9 @@ async def persist_and_record_usage(
|
||||
log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]").
|
||||
cost_usd: Optional cost for logging (float from SDK, str otherwise).
|
||||
provider: Cost provider name (e.g. "anthropic", "open_router").
|
||||
model_cost_multiplier: Relative model cost factor for rate limiting
|
||||
(1.0 = Sonnet/default, 5.0 = Opus). Scales the token counter so
|
||||
more expensive models deplete the rate limit proportionally faster.
|
||||
|
||||
Returns:
|
||||
The computed total_tokens (prompt + completion; cache excluded).
|
||||
@@ -163,6 +167,7 @@ async def persist_and_record_usage(
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
model_cost_multiplier=model_cost_multiplier,
|
||||
)
|
||||
except Exception as usage_err:
|
||||
logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)
|
||||
|
||||
@@ -230,6 +230,7 @@ class TestRateLimitRecording:
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=1000,
|
||||
cache_creation_tokens=200,
|
||||
model_cost_multiplier=1.0,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -74,6 +74,15 @@ class FindBlockTool(BaseTool):
|
||||
"description": "Include full input/output schemas (for agent JSON generation).",
|
||||
"default": False,
|
||||
},
|
||||
"for_agent_generation": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Set to true when searching for blocks to use inside an agent graph "
|
||||
"(e.g. AgentInputBlock, AgentOutputBlock, OrchestratorBlock). "
|
||||
"Bypasses the CoPilot-only filter so graph-only blocks are visible."
|
||||
),
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
@@ -88,6 +97,7 @@ class FindBlockTool(BaseTool):
|
||||
session: ChatSession,
|
||||
query: str = "",
|
||||
include_schemas: bool = False,
|
||||
for_agent_generation: bool = False,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for blocks matching the query.
|
||||
@@ -97,6 +107,8 @@ class FindBlockTool(BaseTool):
|
||||
session: Chat session
|
||||
query: Search query
|
||||
include_schemas: Whether to include block schemas in results
|
||||
for_agent_generation: When True, bypasses the CoPilot exclusion filter
|
||||
so graph-only blocks (INPUT, OUTPUT, ORCHESTRATOR, etc.) are visible.
|
||||
|
||||
Returns:
|
||||
BlockListResponse: List of matching blocks
|
||||
@@ -123,34 +135,36 @@ class FindBlockTool(BaseTool):
|
||||
suggestions=["Search for an alternative block by name"],
|
||||
session_id=session_id,
|
||||
)
|
||||
if (
|
||||
is_excluded = (
|
||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
):
|
||||
if block.block_type == BlockType.MCP_TOOL:
|
||||
)
|
||||
if is_excluded:
|
||||
# Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are
|
||||
# exposed when building an agent graph so the LLM can inspect
|
||||
# their schemas and wire them as nodes. In CoPilot direct use
|
||||
# they are not executable — guide the LLM to the right tool.
|
||||
if not for_agent_generation:
|
||||
if block.block_type == BlockType.MCP_TOOL:
|
||||
message = (
|
||||
f"Block '{block.name}' (ID: {block.id}) cannot be "
|
||||
"run directly in CoPilot. Use run_mcp_tool for "
|
||||
"interactive MCP execution, or call find_block with "
|
||||
"for_agent_generation=true to embed it in an agent graph."
|
||||
)
|
||||
else:
|
||||
message = (
|
||||
f"Block '{block.name}' (ID: {block.id}) is not available "
|
||||
"in CoPilot. It can only be used within agent graphs."
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' (ID: {block.id}) is not "
|
||||
"runnable through find_block/run_block. Use "
|
||||
"run_mcp_tool instead."
|
||||
),
|
||||
message=message,
|
||||
suggestions=[
|
||||
"Use run_mcp_tool to discover and run this MCP tool",
|
||||
"Search for an alternative block by name",
|
||||
"Use this block in an agent graph instead",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' (ID: {block.id}) is not available "
|
||||
"in CoPilot. It can only be used within agent graphs."
|
||||
),
|
||||
suggestions=[
|
||||
"Search for an alternative block by name",
|
||||
"Use this block in an agent graph instead",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check block-level permissions — hide denied blocks entirely
|
||||
perms = get_current_permissions()
|
||||
@@ -221,8 +235,9 @@ class FindBlockTool(BaseTool):
|
||||
if not block or block.disabled:
|
||||
continue
|
||||
|
||||
# Skip blocks excluded from CoPilot (graph-only blocks)
|
||||
if (
|
||||
# Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are
|
||||
# skipped in CoPilot direct use but surfaced for agent graph building.
|
||||
if not for_agent_generation and (
|
||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
):
|
||||
|
||||
@@ -12,7 +12,7 @@ from .find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
FindBlockTool,
|
||||
)
|
||||
from .models import BlockListResponse
|
||||
from .models import BlockListResponse, NoResultsResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-find-block"
|
||||
|
||||
@@ -166,6 +166,194 @@ class TestFindBlockFiltering:
|
||||
assert len(response.blocks) == 1
|
||||
assert response.blocks[0].id == "normal-block-id"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_for_agent_generation_exposes_excluded_blocks_in_search(self):
    """Excluded block types show up in search when for_agent_generation=True.

    Both search hits resolve to INPUT/OUTPUT blocks; with the flag set the
    tool must return both of them in a ``BlockListResponse``.
    """
    session = make_session(user_id=_TEST_USER_ID)

    hits = [
        {"content_id": "input-block-id", "score": 0.9},
        {"content_id": "output-block-id", "score": 0.8},
    ]
    # Registry backing the patched get_block; unknown IDs resolve to None.
    blocks_by_id = {
        "input-block-id": make_mock_block(
            "input-block-id", "Agent Input", BlockType.INPUT
        ),
        "output-block-id": make_mock_block(
            "output-block-id", "Agent Output", BlockType.OUTPUT
        ),
    }

    search_db = MagicMock()
    search_db.unified_hybrid_search = AsyncMock(return_value=(hits, 2))

    search_patch = patch(
        "backend.copilot.tools.find_block.search",
        return_value=search_db,
    )
    get_block_patch = patch(
        "backend.copilot.tools.find_block.get_block",
        side_effect=lambda block_id: blocks_by_id.get(block_id),
    )
    with search_patch, get_block_patch:
        response = await FindBlockTool()._execute(
            user_id=_TEST_USER_ID,
            session=session,
            query="agent input",
            for_agent_generation=True,
        )

    assert isinstance(response, BlockListResponse)
    assert len(response.blocks) == 2
    returned_ids = {b.id for b in response.blocks}
    assert "input-block-id" in returned_ids
    assert "output-block-id" in returned_ids
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_mcp_tool_exposed_with_for_agent_generation_in_search(self):
    """An MCP_TOOL hit is kept in search results when for_agent_generation=True.

    The search returns one MCP_TOOL block and one STANDARD block; both must be
    present in the response.
    """
    session = make_session(user_id=_TEST_USER_ID)

    hits = [
        {"content_id": "mcp-block-id", "score": 0.9},
        {"content_id": "standard-block-id", "score": 0.8},
    ]
    # Registry backing the patched get_block; unknown IDs resolve to None.
    blocks_by_id = {
        "mcp-block-id": make_mock_block(
            "mcp-block-id", "MCP Tool", BlockType.MCP_TOOL
        ),
        "standard-block-id": make_mock_block(
            "standard-block-id", "Normal Block", BlockType.STANDARD
        ),
    }

    search_db = MagicMock()
    search_db.unified_hybrid_search = AsyncMock(return_value=(hits, 2))

    search_patch = patch(
        "backend.copilot.tools.find_block.search",
        return_value=search_db,
    )
    get_block_patch = patch(
        "backend.copilot.tools.find_block.get_block",
        side_effect=lambda block_id: blocks_by_id.get(block_id),
    )
    with search_patch, get_block_patch:
        response = await FindBlockTool()._execute(
            user_id=_TEST_USER_ID,
            session=session,
            query="mcp tool",
            for_agent_generation=True,
        )

    assert isinstance(response, BlockListResponse)
    assert len(response.blocks) == 2
    returned_ids = [b.id for b in response.blocks]
    assert "mcp-block-id" in returned_ids
    assert "standard-block-id" in returned_ids
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_mcp_tool_excluded_without_for_agent_generation_in_search(self):
    """An MCP_TOOL hit is dropped from search results in normal CoPilot mode.

    With for_agent_generation=False only the STANDARD block may remain.
    """
    session = make_session(user_id=_TEST_USER_ID)

    hits = [
        {"content_id": "mcp-block-id", "score": 0.9},
        {"content_id": "standard-block-id", "score": 0.8},
    ]
    # Registry backing the patched get_block; unknown IDs resolve to None.
    blocks_by_id = {
        "mcp-block-id": make_mock_block(
            "mcp-block-id", "MCP Tool", BlockType.MCP_TOOL
        ),
        "standard-block-id": make_mock_block(
            "standard-block-id", "Normal Block", BlockType.STANDARD
        ),
    }

    search_db = MagicMock()
    search_db.unified_hybrid_search = AsyncMock(return_value=(hits, 2))

    search_patch = patch(
        "backend.copilot.tools.find_block.search",
        return_value=search_db,
    )
    get_block_patch = patch(
        "backend.copilot.tools.find_block.get_block",
        side_effect=lambda block_id: blocks_by_id.get(block_id),
    )
    with search_patch, get_block_patch:
        response = await FindBlockTool()._execute(
            user_id=_TEST_USER_ID,
            session=session,
            query="mcp tool",
            for_agent_generation=False,
        )

    assert isinstance(response, BlockListResponse)
    assert len(response.blocks) == 1
    only_block = response.blocks[0]
    assert only_block.id == "standard-block-id"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_for_agent_generation_exposes_excluded_ids_in_search(self):
    """Excluded block IDs show up in search when for_agent_generation=True.

    One hit uses an ID taken from ``COPILOT_EXCLUDED_BLOCK_IDS``; it must be
    returned alongside the ordinary block.
    """
    session = make_session(user_id=_TEST_USER_ID)
    excluded_id = next(iter(COPILOT_EXCLUDED_BLOCK_IDS))

    hits = [
        {"content_id": excluded_id, "score": 0.9},
        {"content_id": "normal-block-id", "score": 0.8},
    ]
    # Registry backing the patched get_block; unknown IDs resolve to None.
    blocks_by_id = {
        excluded_id: make_mock_block(
            excluded_id, "Orchestrator", BlockType.STANDARD
        ),
        "normal-block-id": make_mock_block(
            "normal-block-id", "Normal Block", BlockType.STANDARD
        ),
    }

    search_db = MagicMock()
    search_db.unified_hybrid_search = AsyncMock(return_value=(hits, 2))

    search_patch = patch(
        "backend.copilot.tools.find_block.search",
        return_value=search_db,
    )
    get_block_patch = patch(
        "backend.copilot.tools.find_block.get_block",
        side_effect=lambda block_id: blocks_by_id.get(block_id),
    )
    with search_patch, get_block_patch:
        response = await FindBlockTool()._execute(
            user_id=_TEST_USER_ID,
            session=session,
            query="orchestrator",
            for_agent_generation=True,
        )

    assert isinstance(response, BlockListResponse)
    assert len(response.blocks) == 2
    returned_ids = {b.id for b in response.blocks}
    assert excluded_id in returned_ids
    assert "normal-block-id" in returned_ids
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_response_size_average_chars_per_block(self):
|
||||
"""Measure average chars per block in the serialized response."""
|
||||
@@ -549,8 +737,6 @@ class TestFindBlockDirectLookup:
|
||||
user_id=_TEST_USER_ID, session=session, query=block_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@@ -571,8 +757,6 @@ class TestFindBlockDirectLookup:
|
||||
user_id=_TEST_USER_ID, session=session, query=block_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
assert "disabled" in response.message.lower()
|
||||
|
||||
@@ -592,8 +776,6 @@ class TestFindBlockDirectLookup:
|
||||
user_id=_TEST_USER_ID, session=session, query=block_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
assert "not available" in response.message.lower()
|
||||
|
||||
@@ -613,7 +795,74 @@ class TestFindBlockDirectLookup:
|
||||
user_id=_TEST_USER_ID, session=session, query=orchestrator_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
assert "not available" in response.message.lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_uuid_lookup_excluded_block_type_allowed_with_for_agent_generation(
    self,
):
    """A UUID lookup of an INPUT block succeeds when for_agent_generation=True."""
    session = make_session(user_id=_TEST_USER_ID)
    lookup_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
    input_block = make_mock_block(lookup_id, "Agent Input Block", BlockType.INPUT)

    # get_block always resolves to the INPUT block for this test.
    get_block_patch = patch(
        "backend.copilot.tools.find_block.get_block",
        return_value=input_block,
    )
    with get_block_patch:
        response = await FindBlockTool()._execute(
            user_id=_TEST_USER_ID,
            session=session,
            query=lookup_id,
            for_agent_generation=True,
        )

    assert isinstance(response, BlockListResponse)
    assert response.count == 1
    assert response.blocks[0].id == lookup_id
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_uuid_lookup_mcp_tool_exposed_with_for_agent_generation(self):
    """A UUID lookup of an MCP_TOOL block succeeds when for_agent_generation=True."""
    session = make_session(user_id=_TEST_USER_ID)
    lookup_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
    mcp_block = make_mock_block(lookup_id, "MCP Tool", BlockType.MCP_TOOL)

    # get_block always resolves to the MCP_TOOL block for this test.
    get_block_patch = patch(
        "backend.copilot.tools.find_block.get_block",
        return_value=mcp_block,
    )
    with get_block_patch:
        response = await FindBlockTool()._execute(
            user_id=_TEST_USER_ID,
            session=session,
            query=lookup_id,
            for_agent_generation=True,
        )

    assert isinstance(response, BlockListResponse)
    first_block = response.blocks[0]
    assert first_block.id == lookup_id
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_uuid_lookup_mcp_tool_excluded_without_for_agent_generation(self):
    """A UUID lookup of an MCP_TOOL block is refused in normal CoPilot mode.

    The refusal must be a ``NoResultsResponse`` whose message mentions
    ``run_mcp_tool``.
    """
    session = make_session(user_id=_TEST_USER_ID)
    lookup_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
    mcp_block = make_mock_block(lookup_id, "MCP Tool", BlockType.MCP_TOOL)

    # get_block always resolves to the MCP_TOOL block for this test.
    get_block_patch = patch(
        "backend.copilot.tools.find_block.get_block",
        return_value=mcp_block,
    )
    with get_block_patch:
        response = await FindBlockTool()._execute(
            user_id=_TEST_USER_ID,
            session=session,
            query=lookup_id,
            for_agent_generation=False,
        )

    assert isinstance(response, NoResultsResponse)
    assert "run_mcp_tool" in response.message
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { IMPERSONATION_HEADER_NAME } from "@/lib/constants";
|
||||
import { getCopilotAuthHeaders } from "../helpers";
|
||||
import { getCopilotAuthHeaders, getSendSuppressionReason } from "../helpers";
|
||||
import type { UIMessage } from "ai";
|
||||
|
||||
vi.mock("@/lib/supabase/actions", () => ({
|
||||
getWebSocketToken: vi.fn(),
|
||||
@@ -72,3 +73,71 @@ describe("getCopilotAuthHeaders", () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── getSendSuppressionReason ─────────────────────────────────────────────────
|
||||
|
||||
/**
 * Build a minimal user-role chat message whose content and single text part
 * both carry `text`. The id is fixed to "msg-1".
 */
function makeUserMsg(text: string): UIMessage {
  const parts = [{ type: "text" as const, text }];
  return { id: "msg-1", role: "user", content: text, parts } as UIMessage;
}
|
||||
|
||||
// Covers the three outcomes of getSendSuppressionReason: null (send allowed),
// "reconnecting", and "duplicate".
describe("getSendSuppressionReason", () => {
  it("returns null when no dedup context exists (fresh ref)", () => {
    expect(
      getSendSuppressionReason({
        text: "hello",
        isReconnectScheduled: false,
        lastSubmittedText: null,
        messages: [],
      }),
    ).toBeNull();
  });

  it("returns 'reconnecting' when reconnect is scheduled regardless of text", () => {
    expect(
      getSendSuppressionReason({
        text: "hello",
        isReconnectScheduled: true,
        lastSubmittedText: null,
        messages: [],
      }),
    ).toBe("reconnecting");
  });

  it("returns 'duplicate' when same text was submitted and is the last user message", () => {
    // Core regression case: the last-submitted ref is deliberately NOT reset
    // to null after a successful turn, so re-sending the same text must be
    // caught here.
    const history = [makeUserMsg("hello")];
    expect(
      getSendSuppressionReason({
        text: "hello",
        isReconnectScheduled: false,
        lastSubmittedText: "hello",
        messages: history,
      }),
    ).toBe("duplicate");
  });

  it("returns null when same ref text but different last user message (different question)", () => {
    // The user asked "hello", got a reply, then asked something else; the
    // last user message in the chat now differs, so nothing is suppressed.
    const history = [makeUserMsg("hello"), makeUserMsg("something else")];
    expect(
      getSendSuppressionReason({
        text: "hello",
        isReconnectScheduled: false,
        lastSubmittedText: "hello",
        messages: history,
      }),
    ).toBeNull();
  });

  it("returns null when text differs from lastSubmittedText", () => {
    const history = [makeUserMsg("old question")];
    expect(
      getSendSuppressionReason({
        text: "new question",
        isReconnectScheduled: false,
        lastSubmittedText: "old question",
        messages: history,
      }),
    ).toBeNull();
  });
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { describe, expect, it, beforeEach, vi } from "vitest";
|
||||
import { describe, expect, it, beforeEach, afterEach, vi } from "vitest";
|
||||
import { useCopilotUIStore } from "../store";
|
||||
|
||||
vi.mock("@sentry/nextjs", () => ({
|
||||
@@ -22,7 +22,8 @@ describe("useCopilotUIStore", () => {
|
||||
isNotificationsEnabled: false,
|
||||
isSoundEnabled: true,
|
||||
showNotificationDialog: false,
|
||||
copilotMode: "extended_thinking",
|
||||
copilotChatMode: "extended_thinking",
|
||||
copilotLlmModel: "standard",
|
||||
});
|
||||
});
|
||||
|
||||
@@ -154,35 +155,52 @@ describe("useCopilotUIStore", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("copilotMode", () => {
|
||||
describe("copilotChatMode", () => {
|
||||
it("defaults to extended_thinking", () => {
|
||||
expect(useCopilotUIStore.getState().copilotMode).toBe(
|
||||
expect(useCopilotUIStore.getState().copilotChatMode).toBe(
|
||||
"extended_thinking",
|
||||
);
|
||||
});
|
||||
|
||||
it("sets mode to fast", () => {
|
||||
useCopilotUIStore.getState().setCopilotMode("fast");
|
||||
expect(useCopilotUIStore.getState().copilotMode).toBe("fast");
|
||||
useCopilotUIStore.getState().setCopilotChatMode("fast");
|
||||
expect(useCopilotUIStore.getState().copilotChatMode).toBe("fast");
|
||||
});
|
||||
|
||||
it("sets mode back to extended_thinking", () => {
|
||||
useCopilotUIStore.getState().setCopilotMode("fast");
|
||||
useCopilotUIStore.getState().setCopilotMode("extended_thinking");
|
||||
expect(useCopilotUIStore.getState().copilotMode).toBe(
|
||||
useCopilotUIStore.getState().setCopilotChatMode("fast");
|
||||
useCopilotUIStore.getState().setCopilotChatMode("extended_thinking");
|
||||
expect(useCopilotUIStore.getState().copilotChatMode).toBe(
|
||||
"extended_thinking",
|
||||
);
|
||||
});
|
||||
|
||||
it("does not persist mode to localStorage", () => {
|
||||
useCopilotUIStore.getState().setCopilotMode("fast");
|
||||
expect(window.localStorage.getItem("copilot-mode")).toBeNull();
|
||||
it("persists mode to localStorage", () => {
|
||||
useCopilotUIStore.getState().setCopilotChatMode("fast");
|
||||
expect(window.localStorage.getItem("copilot-mode")).toBe("fast");
|
||||
});
|
||||
});
|
||||
|
||||
// Model-tier state: default value, setter, and localStorage persistence.
describe("copilotLlmModel", () => {
  const currentState = () => useCopilotUIStore.getState();

  it("defaults to standard", () => {
    expect(currentState().copilotLlmModel).toBe("standard");
  });

  it("sets model to advanced", () => {
    currentState().setCopilotLlmModel("advanced");
    expect(currentState().copilotLlmModel).toBe("advanced");
  });

  it("persists model to localStorage", () => {
    currentState().setCopilotLlmModel("advanced");
    const persisted = window.localStorage.getItem("copilot-model");
    expect(persisted).toBe("advanced");
  });
});
|
||||
|
||||
describe("clearCopilotLocalData", () => {
|
||||
it("resets state and clears localStorage keys", () => {
|
||||
useCopilotUIStore.getState().setCopilotMode("fast");
|
||||
useCopilotUIStore.getState().setCopilotChatMode("fast");
|
||||
useCopilotUIStore.getState().setCopilotLlmModel("advanced");
|
||||
useCopilotUIStore.getState().setNotificationsEnabled(true);
|
||||
useCopilotUIStore.getState().toggleSound();
|
||||
useCopilotUIStore.getState().addCompletedSession("s1");
|
||||
@@ -190,7 +208,8 @@ describe("useCopilotUIStore", () => {
|
||||
useCopilotUIStore.getState().clearCopilotLocalData();
|
||||
|
||||
const state = useCopilotUIStore.getState();
|
||||
expect(state.copilotMode).toBe("extended_thinking");
|
||||
expect(state.copilotChatMode).toBe("extended_thinking");
|
||||
expect(state.copilotLlmModel).toBe("standard");
|
||||
expect(state.isNotificationsEnabled).toBe(false);
|
||||
expect(state.isSoundEnabled).toBe(true);
|
||||
expect(state.completedSessionIDs.size).toBe(0);
|
||||
@@ -198,6 +217,8 @@ describe("useCopilotUIStore", () => {
|
||||
window.localStorage.getItem("copilot-notifications-enabled"),
|
||||
).toBeNull();
|
||||
expect(window.localStorage.getItem("copilot-sound-enabled")).toBeNull();
|
||||
expect(window.localStorage.getItem("copilot-mode")).toBeNull();
|
||||
expect(window.localStorage.getItem("copilot-model")).toBeNull();
|
||||
expect(
|
||||
window.localStorage.getItem("copilot-completed-sessions"),
|
||||
).toBeNull();
|
||||
@@ -222,3 +243,24 @@ describe("useCopilotUIStore", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initial store values are read from localStorage when the store module is
// evaluated, so each test seeds storage and then re-imports a fresh module.
describe("useCopilotUIStore localStorage initialisation", () => {
  afterEach(() => {
    vi.resetModules();
    window.localStorage.clear();
  });

  it("reads fast chat mode from localStorage on store creation", async () => {
    window.localStorage.setItem("copilot-mode", "fast");
    vi.resetModules();
    const storeModule = await import("../store");
    const state = storeModule.useCopilotUIStore.getState();
    expect(state.copilotChatMode).toBe("fast");
  });

  it("reads advanced model from localStorage on store creation", async () => {
    window.localStorage.setItem("copilot-model", "advanced");
    vi.resetModules();
    const storeModule = await import("../store");
    const state = storeModule.useCopilotUIStore.getState();
    expect(state.copilotLlmModel).toBe("advanced");
  });
});
|
||||
|
||||
@@ -13,6 +13,7 @@ import { ChangeEvent, useEffect, useState } from "react";
|
||||
import { AttachmentMenu } from "./components/AttachmentMenu";
|
||||
import { DryRunToggleButton } from "./components/DryRunToggleButton";
|
||||
import { FileChips } from "./components/FileChips";
|
||||
import { ModelToggleButton } from "./components/ModelToggleButton";
|
||||
import { ModeToggleButton } from "./components/ModeToggleButton";
|
||||
import { RecordingButton } from "./components/RecordingButton";
|
||||
import { RecordingIndicator } from "./components/RecordingIndicator";
|
||||
@@ -50,16 +51,22 @@ export function ChatInput({
|
||||
onDroppedFilesConsumed,
|
||||
hasSession = false,
|
||||
}: Props) {
|
||||
const { copilotMode, setCopilotMode, isDryRun, setIsDryRun } =
|
||||
useCopilotUIStore();
|
||||
const {
|
||||
copilotChatMode,
|
||||
setCopilotChatMode,
|
||||
copilotLlmModel,
|
||||
setCopilotLlmModel,
|
||||
isDryRun,
|
||||
setIsDryRun,
|
||||
} = useCopilotUIStore();
|
||||
const showModeToggle = useGetFlag(Flag.CHAT_MODE_OPTION);
|
||||
const showDryRunToggle = showModeToggle;
|
||||
const [files, setFiles] = useState<File[]>([]);
|
||||
|
||||
function handleToggleMode() {
|
||||
const next =
|
||||
copilotMode === "extended_thinking" ? "fast" : "extended_thinking";
|
||||
setCopilotMode(next);
|
||||
copilotChatMode === "extended_thinking" ? "fast" : "extended_thinking";
|
||||
setCopilotChatMode(next);
|
||||
toast({
|
||||
title:
|
||||
next === "fast"
|
||||
@@ -72,6 +79,21 @@ export function ChatInput({
|
||||
});
|
||||
}
|
||||
|
||||
// Flip the model tier between "standard" and "advanced", store the new
// value, and confirm the switch with a toast.
function handleToggleModel() {
  if (copilotLlmModel === "advanced") {
    setCopilotLlmModel("standard");
    toast({
      title: "Switched to Standard model",
      description: "Using the balanced standard model.",
    });
    return;
  }

  setCopilotLlmModel("advanced");
  toast({
    title: "Switched to Advanced model",
    description: "Using the highest-capability model.",
  });
}
|
||||
|
||||
function handleToggleDryRun() {
|
||||
const next = !isDryRun;
|
||||
setIsDryRun(next);
|
||||
@@ -198,10 +220,16 @@ export function ChatInput({
|
||||
/>
|
||||
{showModeToggle && !isStreaming && (
|
||||
<ModeToggleButton
|
||||
mode={copilotMode}
|
||||
mode={copilotChatMode}
|
||||
onToggle={handleToggleMode}
|
||||
/>
|
||||
)}
|
||||
{showModeToggle && !isStreaming && (
|
||||
<ModelToggleButton
|
||||
model={copilotLlmModel}
|
||||
onToggle={handleToggleModel}
|
||||
/>
|
||||
)}
|
||||
{showDryRunToggle && (!hasSession || isDryRun) && (
|
||||
<DryRunToggleButton
|
||||
isDryRun={isDryRun}
|
||||
|
||||
@@ -8,14 +8,21 @@ import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { ChatInput } from "../ChatInput";
|
||||
|
||||
let mockCopilotMode = "extended_thinking";
|
||||
const mockSetCopilotMode = vi.fn((mode: string) => {
|
||||
const mockSetCopilotChatMode = vi.fn((mode: string) => {
|
||||
mockCopilotMode = mode;
|
||||
});
|
||||
|
||||
let mockCopilotLlmModel = "standard";
|
||||
const mockSetCopilotLlmModel = vi.fn((model: string) => {
|
||||
mockCopilotLlmModel = model;
|
||||
});
|
||||
|
||||
vi.mock("@/app/(platform)/copilot/store", () => ({
|
||||
useCopilotUIStore: () => ({
|
||||
copilotMode: mockCopilotMode,
|
||||
setCopilotMode: mockSetCopilotMode,
|
||||
copilotChatMode: mockCopilotMode,
|
||||
setCopilotChatMode: mockSetCopilotChatMode,
|
||||
copilotLlmModel: mockCopilotLlmModel,
|
||||
setCopilotLlmModel: mockSetCopilotLlmModel,
|
||||
initialPrompt: null,
|
||||
setInitialPrompt: vi.fn(),
|
||||
}),
|
||||
@@ -107,6 +114,7 @@ afterEach(() => {
|
||||
cleanup();
|
||||
vi.clearAllMocks();
|
||||
mockCopilotMode = "extended_thinking";
|
||||
mockCopilotLlmModel = "standard";
|
||||
});
|
||||
|
||||
describe("ChatInput mode toggle", () => {
|
||||
@@ -141,7 +149,7 @@ describe("ChatInput mode toggle", () => {
|
||||
mockCopilotMode = "extended_thinking";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
fireEvent.click(screen.getByLabelText(/switch to fast mode/i));
|
||||
expect(mockSetCopilotMode).toHaveBeenCalledWith("fast");
|
||||
expect(mockSetCopilotChatMode).toHaveBeenCalledWith("fast");
|
||||
});
|
||||
|
||||
it("toggles from fast to extended_thinking on click", () => {
|
||||
@@ -149,7 +157,7 @@ describe("ChatInput mode toggle", () => {
|
||||
mockCopilotMode = "fast";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
fireEvent.click(screen.getByLabelText(/switch to extended thinking/i));
|
||||
expect(mockSetCopilotMode).toHaveBeenCalledWith("extended_thinking");
|
||||
expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking");
|
||||
});
|
||||
|
||||
it("hides toggle button when streaming", () => {
|
||||
@@ -187,3 +195,69 @@ describe("ChatInput mode toggle", () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// Model toggle in ChatInput: visibility behind the feature flag, hiding while
// streaming, toggling in both directions, and the confirmation toast.
describe("ChatInput model toggle", () => {
  it("renders model toggle button when flag is enabled", () => {
    mockFlagValue = true;
    render(<ChatInput onSend={mockOnSend} />);
    const toggle = screen.getByLabelText(/switch to advanced model/i);
    expect(toggle).toBeDefined();
  });

  it("does not render model toggle when flag is disabled", () => {
    mockFlagValue = false;
    render(<ChatInput onSend={mockOnSend} />);
    const toggle = screen.queryByLabelText(
      /switch to (advanced|standard) model/i,
    );
    expect(toggle).toBeNull();
  });

  it("toggles from standard to advanced on click", () => {
    mockFlagValue = true;
    mockCopilotLlmModel = "standard";
    render(<ChatInput onSend={mockOnSend} />);
    const toggle = screen.getByLabelText(/switch to advanced model/i);
    fireEvent.click(toggle);
    expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("advanced");
  });

  it("toggles from advanced to standard on click", () => {
    mockFlagValue = true;
    mockCopilotLlmModel = "advanced";
    render(<ChatInput onSend={mockOnSend} />);
    const toggle = screen.getByLabelText(/switch to standard model/i);
    fireEvent.click(toggle);
    expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard");
  });

  it("hides model toggle when streaming", () => {
    mockFlagValue = true;
    render(<ChatInput onSend={mockOnSend} isStreaming />);
    const toggle = screen.queryByLabelText(
      /switch to (advanced|standard) model/i,
    );
    expect(toggle).toBeNull();
  });

  it("shows a toast when switching to advanced", async () => {
    const toastModule = await import("@/components/molecules/Toast/use-toast");
    mockFlagValue = true;
    mockCopilotLlmModel = "standard";
    render(<ChatInput onSend={mockOnSend} />);
    fireEvent.click(screen.getByLabelText(/switch to advanced model/i));
    const expectedToast = expect.objectContaining({
      title: expect.stringMatching(/switched to advanced model/i),
    });
    expect(toastModule.toast).toHaveBeenCalledWith(expectedToast);
  });

  it("shows a toast when switching to standard", async () => {
    const toastModule = await import("@/components/molecules/Toast/use-toast");
    mockFlagValue = true;
    mockCopilotLlmModel = "advanced";
    render(<ChatInput onSend={mockOnSend} />);
    fireEvent.click(screen.getByLabelText(/switch to standard model/i));
    const expectedToast = expect.objectContaining({
      title: expect.stringMatching(/switched to standard model/i),
    });
    expect(toastModule.toast).toHaveBeenCalledWith(expectedToast);
  });
});
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Cpu } from "@phosphor-icons/react";
|
||||
import type { CopilotLlmModel } from "../../../store";
|
||||
|
||||
/** Props accepted by ModelToggleButton. */
interface Props {
  /** Currently selected model tier. */
  model: CopilotLlmModel;
  /** Called when the button is clicked. */
  onToggle: () => void;
}

/**
 * Icon button that reflects and toggles the model tier.
 *
 * In the "advanced" state the button is highlighted, shows an "Advanced"
 * label, and reports aria-pressed=true. In the "standard" state only the
 * icon is shown.
 */
export function ModelToggleButton({ model, onToggle }: Props) {
  const isAdvanced = model === "advanced";

  const toneClasses = isAdvanced
    ? "bg-sky-100 text-sky-900 hover:bg-sky-200"
    : "text-neutral-500 hover:bg-neutral-100 hover:text-neutral-700";
  // The accessible name describes the action the click will perform.
  const ariaLabel = isAdvanced
    ? "Switch to Standard model"
    : "Switch to Advanced model";
  const tooltip = isAdvanced
    ? "Advanced model — highest capability (click to switch to Standard)"
    : "Standard model — click to switch to Advanced";

  return (
    <button
      type="button"
      aria-pressed={isAdvanced}
      onClick={onToggle}
      className={cn(
        "inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
        toneClasses,
      )}
      aria-label={ariaLabel}
      title={tooltip}
    >
      <Cpu size={14} />
      {isAdvanced && "Advanced"}
    </button>
  );
}
|
||||
@@ -0,0 +1,36 @@
|
||||
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { ModelToggleButton } from "../ModelToggleButton";
|
||||
|
||||
afterEach(cleanup);
|
||||
|
||||
// Rendering and interaction checks for ModelToggleButton in both model tiers.
describe("ModelToggleButton", () => {
  it("shows no label when model is standard", () => {
    render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
    const label = screen.queryByText("Advanced");
    expect(label).toBeNull();
  });

  it("shows Advanced label when model is advanced", () => {
    render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
    const label = screen.getByText("Advanced");
    expect(label).toBeTruthy();
  });

  it("calls onToggle when clicked", () => {
    const handleToggle = vi.fn();
    render(<ModelToggleButton model="standard" onToggle={handleToggle} />);
    const button = screen.getByRole("button");
    fireEvent.click(button);
    expect(handleToggle).toHaveBeenCalledTimes(1);
  });

  it("sets aria-pressed=false for standard", () => {
    render(<ModelToggleButton model="standard" onToggle={vi.fn()} />);
    const pressed = screen
      .getByLabelText("Switch to Advanced model")
      .getAttribute("aria-pressed");
    expect(pressed).toBe("false");
  });

  it("sets aria-pressed=true for advanced", () => {
    render(<ModelToggleButton model="advanced" onToggle={vi.fn()} />);
    const pressed = screen
      .getByLabelText("Switch to Standard model")
      .getAttribute("aria-pressed");
    expect(pressed).toBe("true");
  });
});
|
||||
@@ -2,6 +2,8 @@ import { getSystemHeaders } from "@/lib/impersonation";
|
||||
import { getWebSocketToken } from "@/lib/supabase/actions";
|
||||
import type { UIMessage } from "ai";
|
||||
|
||||
import { deleteV2DisconnectSessionStream } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
|
||||
export const ORIGINAL_TITLE = "AutoGPT";
|
||||
|
||||
/**
|
||||
@@ -154,7 +156,18 @@ export function shouldSuppressDuplicateSend(
|
||||
}
|
||||
|
||||
/**
|
||||
* Deduplicate messages by ID and by content fingerprint.
|
||||
* Fire-and-forget: tell the backend to release XREAD listeners for a session.
|
||||
*
|
||||
* Called on session switch so the backend doesn't wait for its 5-10 s timeout
|
||||
* before cleaning up. Failures are silently ignored — the backend will
|
||||
* eventually clean up on its own.
|
||||
*/
|
||||
export function disconnectSessionStream(sessionId: string): void {
  // Best-effort request: the promise is not awaited and any rejection is
  // deliberately swallowed so callers never see a failure.
  const pendingDisconnect = deleteV2DisconnectSessionStream(sessionId);
  pendingDisconnect.catch(() => {});
}
|
||||
|
||||
/**
|
||||
* Deduplicate messages by ID and by consecutive content fingerprint.
|
||||
*
|
||||
* ID dedup catches exact duplicates within the same source.
|
||||
* Content dedup uses a composite key of `role + preceding-user-message-id +
|
||||
|
||||
@@ -53,6 +53,9 @@ export const DEFAULT_PANEL_WIDTH = 600;
|
||||
/** Autopilot response mode. */
|
||||
export type CopilotMode = "extended_thinking" | "fast";
|
||||
|
||||
/** Per-request model tier. 'standard' = current default; 'advanced' = highest-capability. */
|
||||
export type CopilotLlmModel = "standard" | "advanced";
|
||||
|
||||
const isClient = typeof window !== "undefined";
|
||||
|
||||
function getPersistedWidth(): number {
|
||||
@@ -134,8 +137,12 @@ interface CopilotUIState {
|
||||
goBackArtifact: () => void;
|
||||
|
||||
/** Autopilot mode: 'extended_thinking' (default) or 'fast'. */
|
||||
copilotMode: CopilotMode;
|
||||
setCopilotMode: (mode: CopilotMode) => void;
|
||||
copilotChatMode: CopilotMode;
|
||||
setCopilotChatMode: (mode: CopilotMode) => void;
|
||||
|
||||
/** Model tier: 'standard' (default) or 'advanced' (highest-capability). */
|
||||
copilotLlmModel: CopilotLlmModel;
|
||||
setCopilotLlmModel: (model: CopilotLlmModel) => void;
|
||||
|
||||
/** Developer dry-run mode: sessions created with dry_run=true. */
|
||||
isDryRun: boolean;
|
||||
@@ -298,9 +305,22 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
};
|
||||
}),
|
||||
|
||||
copilotMode: "extended_thinking",
|
||||
setCopilotMode: (mode) => {
|
||||
set({ copilotMode: mode });
|
||||
copilotChatMode: (() => {
|
||||
const saved = isClient ? storage.get(Key.COPILOT_MODE) : null;
|
||||
return saved === "fast" ? "fast" : "extended_thinking";
|
||||
})(),
|
||||
setCopilotChatMode: (mode) => {
|
||||
storage.set(Key.COPILOT_MODE, mode);
|
||||
set({ copilotChatMode: mode });
|
||||
},
|
||||
|
||||
copilotLlmModel: (() => {
  // Restore the persisted tier when running in the browser. Anything other
  // than the exact string "advanced" (no stored value, an unknown value, or
  // a non-client render where storage is skipped) resolves to "standard".
  const persisted = isClient ? storage.get(Key.COPILOT_MODEL) : null;
  if (persisted === "advanced") return "advanced";
  return "standard";
})(),
|
||||
setCopilotLlmModel: (nextModel) => {
  // Persist first, then publish to the store, so the saved value is what the
  // initializer reads back on the next load.
  storage.set(Key.COPILOT_MODEL, nextModel);
  set({ copilotLlmModel: nextModel });
},
|
||||
|
||||
isDryRun: isClient && storage.get(Key.COPILOT_DRY_RUN) === "true",
|
||||
@@ -322,6 +342,8 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
storage.clean(Key.COPILOT_ARTIFACT_PANEL_WIDTH);
|
||||
storage.clean(Key.COPILOT_COMPLETED_SESSIONS);
|
||||
storage.clean(Key.COPILOT_DRY_RUN);
|
||||
storage.clean(Key.COPILOT_MODE);
|
||||
storage.clean(Key.COPILOT_MODEL);
|
||||
set({
|
||||
completedSessionIDs: new Set<string>(),
|
||||
isNotificationsEnabled: false,
|
||||
@@ -334,7 +356,8 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
activeArtifact: null,
|
||||
history: [],
|
||||
},
|
||||
copilotMode: "extended_thinking",
|
||||
copilotChatMode: "extended_thinking",
|
||||
copilotLlmModel: "standard",
|
||||
isDryRun: false,
|
||||
});
|
||||
if (isClient) {
|
||||
|
||||
@@ -42,7 +42,8 @@ export function useCopilotPage() {
|
||||
setSessionToDelete,
|
||||
isDrawerOpen,
|
||||
setDrawerOpen,
|
||||
copilotMode,
|
||||
copilotChatMode,
|
||||
copilotLlmModel,
|
||||
isDryRun,
|
||||
} = useCopilotUIStore();
|
||||
|
||||
@@ -78,7 +79,8 @@ export function useCopilotPage() {
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
refetchSession,
|
||||
copilotMode: isModeToggleEnabled ? copilotMode : undefined,
|
||||
copilotMode: isModeToggleEnabled ? copilotChatMode : undefined,
|
||||
copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
|
||||
});
|
||||
|
||||
const { olderMessages, hasMore, isLoadingMore, loadMore } =
|
||||
|
||||
@@ -17,8 +17,9 @@ import {
|
||||
hasActiveBackendStream,
|
||||
resolveInProgressTools,
|
||||
getSendSuppressionReason,
|
||||
disconnectSessionStream,
|
||||
} from "./helpers";
|
||||
import type { CopilotMode } from "./store";
|
||||
import type { CopilotLlmModel, CopilotMode } from "./store";
|
||||
|
||||
const RECONNECT_BASE_DELAY_MS = 1_000;
|
||||
const RECONNECT_MAX_ATTEMPTS = 3;
|
||||
@@ -33,6 +34,8 @@ interface UseCopilotStreamArgs {
|
||||
refetchSession: () => Promise<{ data?: unknown }>;
|
||||
/** Autopilot mode to use for requests. `undefined` = let backend decide via feature flags. */
|
||||
copilotMode: CopilotMode | undefined;
|
||||
/** Model tier override. `undefined` = let backend decide. */
|
||||
copilotModel: CopilotLlmModel | undefined;
|
||||
}
|
||||
|
||||
export function useCopilotStream({
|
||||
@@ -41,17 +44,20 @@ export function useCopilotStream({
|
||||
hasActiveStream,
|
||||
refetchSession,
|
||||
copilotMode,
|
||||
copilotModel,
|
||||
}: UseCopilotStreamArgs) {
|
||||
const queryClient = useQueryClient();
|
||||
const [rateLimitMessage, setRateLimitMessage] = useState<string | null>(null);
|
||||
function dismissRateLimit() {
|
||||
setRateLimitMessage(null);
|
||||
}
|
||||
// Use a ref for copilotMode so the transport closure always reads the
|
||||
// latest value without recreating the DefaultChatTransport (which would
|
||||
// Use refs for copilotMode and copilotModel so the transport closure always reads
|
||||
// the latest value without recreating the DefaultChatTransport (which would
|
||||
// reset useChat's internal Chat instance and break mid-session streaming).
|
||||
const copilotModeRef = useRef(copilotMode);
|
||||
copilotModeRef.current = copilotMode;
|
||||
const copilotModelRef = useRef(copilotModel);
|
||||
copilotModelRef.current = copilotModel;
|
||||
|
||||
// Connect directly to the Python backend for SSE, bypassing the Next.js
|
||||
// serverless proxy. This eliminates the Vercel 800s function timeout that
|
||||
@@ -83,6 +89,7 @@ export function useCopilotStream({
|
||||
context: null,
|
||||
file_ids: fileIds && fileIds.length > 0 ? fileIds : null,
|
||||
mode: copilotModeRef.current ?? null,
|
||||
model: copilotModelRef.current ?? null,
|
||||
},
|
||||
headers: await getCopilotAuthHeaders(),
|
||||
};
|
||||
@@ -147,16 +154,15 @@ export function useCopilotStream({
|
||||
reconnectTimerRef.current = setTimeout(() => {
|
||||
isReconnectScheduledRef.current = false;
|
||||
setIsReconnectScheduled(false);
|
||||
// Strip any stale in-progress assistant message before resuming.
|
||||
// The backend replays from "0-0", so the partial message would
|
||||
// otherwise sit alongside the fully-replayed version.
|
||||
// Strip the stale in-progress assistant message before resuming —
|
||||
// the backend replays from "0-0", so keeping it would duplicate parts.
|
||||
setMessages((prev) => {
|
||||
if (prev.length > 0 && prev[prev.length - 1].role === "assistant") {
|
||||
return prev.slice(0, -1);
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
resumeStream();
|
||||
resumeStreamRef.current();
|
||||
}, delay);
|
||||
}
|
||||
|
||||
@@ -254,6 +260,14 @@ export function useCopilotStream({
|
||||
},
|
||||
});
|
||||
|
||||
// Keep stable refs to sdkStop and resumeStream so that async callbacks
|
||||
// (session-switch cleanup, wake re-sync, reconnect timer) always call the
|
||||
// latest version without stale-closure bugs.
|
||||
const sdkStopRef = useRef(sdkStop);
|
||||
sdkStopRef.current = sdkStop;
|
||||
const resumeStreamRef = useRef(resumeStream);
|
||||
resumeStreamRef.current = resumeStream;
|
||||
|
||||
// Wrap sdkSendMessage to guard against re-sending the user message during a
|
||||
// reconnect cycle. If the session already has the message (i.e. we are in a
|
||||
// reconnect/resume flow), only GET-resume is safe — never re-POST.
|
||||
@@ -380,7 +394,7 @@ export function useCopilotStream({
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
await resumeStream();
|
||||
await resumeStreamRef.current();
|
||||
}
|
||||
// If !backendActive, the refetch will update hydratedMessages via
|
||||
// React Query, and the hydration effect below will merge them in.
|
||||
@@ -403,7 +417,7 @@ export function useCopilotStream({
|
||||
return () => {
|
||||
document.removeEventListener("visibilitychange", onVisibilityChange);
|
||||
};
|
||||
}, [refetchSession, setMessages, resumeStream]);
|
||||
}, [refetchSession, setMessages]);
|
||||
|
||||
// Hydrate messages from REST API when not actively streaming
|
||||
useEffect(() => {
|
||||
@@ -419,8 +433,34 @@ export function useCopilotStream({
|
||||
// Track resume state per session
|
||||
const hasResumedRef = useRef<Map<string, boolean>>(new Map());
|
||||
|
||||
// Clean up reconnect state on session switch
|
||||
// Clean up reconnect state on session switch.
|
||||
// Abort the old stream's in-flight fetch and tell the backend to release
|
||||
// its XREAD listeners immediately (fire-and-forget).
|
||||
const prevStreamSessionRef = useRef(sessionId);
|
||||
useEffect(() => {
|
||||
const prevSid = prevStreamSessionRef.current;
|
||||
prevStreamSessionRef.current = sessionId;
|
||||
|
||||
const isSwitching = Boolean(prevSid && prevSid !== sessionId);
|
||||
if (isSwitching) {
|
||||
// Mark BEFORE stopping so the old stream's async onError (which fires
|
||||
// after the abort) sees the flag and short-circuits the reconnect path.
|
||||
// Without this, the AbortError can queue a reconnect against the new
|
||||
// session's `sessionId` (captured in the fresh onError closure).
|
||||
isUserStoppingRef.current = true;
|
||||
sdkStopRef.current();
|
||||
disconnectSessionStream(prevSid!);
|
||||
// Schedule the reset as a task (not a microtask) so it runs AFTER the
|
||||
// aborted fetch's onError has fired — otherwise the new session would
|
||||
// be stuck with the "user stopping" flag set, preventing auto-resume
|
||||
// when hydration detects an active backend stream.
|
||||
setTimeout(() => {
|
||||
isUserStoppingRef.current = false;
|
||||
}, 0);
|
||||
} else {
|
||||
isUserStoppingRef.current = false;
|
||||
}
|
||||
|
||||
clearTimeout(reconnectTimerRef.current);
|
||||
reconnectTimerRef.current = undefined;
|
||||
reconnectAttemptsRef.current = 0;
|
||||
@@ -428,7 +468,6 @@ export function useCopilotStream({
|
||||
setIsReconnectScheduled(false);
|
||||
setRateLimitMessage(null);
|
||||
hasShownDisconnectToast.current = false;
|
||||
isUserStoppingRef.current = false;
|
||||
lastSubmittedMsgRef.current = null;
|
||||
setReconnectExhausted(false);
|
||||
setIsSyncing(false);
|
||||
@@ -458,7 +497,12 @@ export function useCopilotStream({
|
||||
if (status === "ready") {
|
||||
reconnectAttemptsRef.current = 0;
|
||||
hasShownDisconnectToast.current = false;
|
||||
lastSubmittedMsgRef.current = null;
|
||||
// Intentionally NOT clearing lastSubmittedMsgRef here: keeping the last
|
||||
// submitted text prevents getSendSuppressionReason from allowing a
|
||||
// duplicate POST of the same message immediately after a successful turn
|
||||
// (the "duplicate" branch checks both the ref and the visible last user
|
||||
// message, so legitimate re-sends after a different reply are still
|
||||
// allowed).
|
||||
setReconnectExhausted(false);
|
||||
}
|
||||
}
|
||||
@@ -495,15 +539,8 @@ export function useCopilotStream({
|
||||
return prev;
|
||||
});
|
||||
|
||||
resumeStream();
|
||||
}, [
|
||||
sessionId,
|
||||
hasActiveStream,
|
||||
hydratedMessages,
|
||||
status,
|
||||
resumeStream,
|
||||
setMessages,
|
||||
]);
|
||||
resumeStreamRef.current();
|
||||
}, [sessionId, hasActiveStream, hydratedMessages, status, setMessages]);
|
||||
|
||||
// Clear messages when session is null
|
||||
useEffect(() => {
|
||||
|
||||
@@ -1606,6 +1606,35 @@
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}/stream": {
|
||||
"delete": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Disconnect Session Stream",
|
||||
"description": "Disconnect all active SSE listeners for a session.\n\nCalled by the frontend when the user switches away from a chat so the\nbackend releases XREAD listeners immediately rather than waiting for\nthe 5-10 s timeout.",
|
||||
"operationId": "deleteV2DisconnectSessionStream",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Session Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": { "description": "Successful Response" },
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Resume Session Stream",
|
||||
@@ -13931,6 +13960,14 @@
|
||||
],
|
||||
"title": "Mode",
|
||||
"description": "Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. If None, uses the server default (extended_thinking)."
|
||||
},
|
||||
"model": {
|
||||
"anyOf": [
|
||||
{ "type": "string", "enum": ["standard", "advanced"] },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Model",
|
||||
"description": "Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. If None, the server applies per-user LD targeting then falls back to config."
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
|
||||
@@ -17,6 +17,7 @@ export enum Key {
|
||||
COPILOT_NOTIFICATION_DIALOG_DISMISSED = "copilot-notification-dialog-dismissed",
|
||||
COPILOT_ARTIFACT_PANEL_WIDTH = "copilot-artifact-panel-width",
|
||||
COPILOT_MODE = "copilot-mode",
|
||||
COPILOT_MODEL = "copilot-model",
|
||||
COPILOT_COMPLETED_SESSIONS = "copilot-completed-sessions",
|
||||
COPILOT_DRY_RUN = "copilot-dry-run",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user