diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 57a7b9a204..023e14f3dc 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -4,7 +4,7 @@ import asyncio import logging import re from collections.abc import AsyncGenerator -from typing import Annotated +from typing import Annotated, Any, cast from uuid import uuid4 from autogpt_libs import auth @@ -29,6 +29,12 @@ from backend.copilot.model import ( get_user_sessions, update_session_title, ) +from backend.copilot.pending_messages import ( + MAX_PENDING_MESSAGES, + PendingMessage, + PendingMessageContext, + push_pending_message, +) from backend.copilot.rate_limit import ( CoPilotUsageStatus, RateLimitExceeded, @@ -84,6 +90,32 @@ _UUID_RE = re.compile( r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I ) +# Call-frequency cap for the pending-message endpoint. The token-budget +# check in queue_pending_message guards against overspend, but does not +# prevent rapid-fire pushes from a client with a large budget. This cap +# (per user, per 60-second window) limits the rate a caller can hammer the +# endpoint independently of token consumption. +_PENDING_CALL_LIMIT = 30 # pushes per minute per user +_PENDING_CALL_WINDOW_SECONDS = 60 +_PENDING_CALL_KEY_PREFIX = "copilot:pending:calls:" + +# Maximum lengths for pending-message context fields (url: 2 KB, content: 32 KB). +# Enforced by QueuePendingMessageRequest._validate_context_length. +_CONTEXT_URL_MAX_LENGTH = 2_000 +_CONTEXT_CONTENT_MAX_LENGTH = 32_000 + +# Lua script for atomic INCR + conditional EXPIRE. +# Using a single EVAL ensures the counter never persists without a TTL — +# a bare INCR followed by a separate EXPIRE can leave the key without +# an expiry if the process crashes between the two commands. +_CALL_INCR_LUA = """ +local count = redis.call('INCR', KEYS[1]) +if count == 1 then + redis.call('EXPIRE', KEYS[1], tonumber(ARGV[1])) +end +return count +""" + async def _validate_and_get_session( session_id: str, @@ -96,6 +128,29 @@ async def _validate_and_get_session( return session +async def _resolve_workspace_files( + user_id: str, + file_ids: list[str], +) -> list[UserWorkspaceFile]: + """Filter *file_ids* to UUID-valid entries that exist in the caller's workspace. + + Returns the matching ``UserWorkspaceFile`` records (empty list if none pass). + Used by both the stream and pending-message endpoints to prevent callers from + referencing other users' files. + """ + valid_ids = [fid for fid in file_ids if _UUID_RE.fullmatch(fid)] + if not valid_ids: + return [] + workspace = await get_or_create_workspace(user_id) + return await UserWorkspaceFile.prisma().find_many( + where={ + "id": {"in": valid_ids}, + "workspaceId": workspace.id, + "isDeleted": False, + } + ) + + router = APIRouter( tags=["chat"], ) @@ -119,6 +174,61 @@ class StreamChatRequest(BaseModel): ) +class QueuePendingMessageRequest(BaseModel): + """Request model for queueing a message into an in-flight turn. + + Unlike ``StreamChatRequest`` this endpoint does **not** start a new + turn — the message is appended to a per-session pending buffer that + the executor currently processing the turn will drain between tool + rounds. 
+ """ + + model_config = ConfigDict(extra="forbid") + + message: str = Field(min_length=1, max_length=16_000) + context: PendingMessageContext | None = Field( + default=None, + description="Optional page context with 'url' and 'content' fields.", + ) + file_ids: list[str] | None = Field(default=None, max_length=20) + + @field_validator("context") + @classmethod + def _validate_context_length( + cls, v: PendingMessageContext | None + ) -> PendingMessageContext | None: + if v is None: + return v + # Cap context values to prevent LLM context-window stuffing via + # large page payloads. Limits are module-level constants so + # they are visible to callers and documentation. + if v.url and len(v.url) > _CONTEXT_URL_MAX_LENGTH: + raise ValueError( + f"context.url exceeds maximum length of {_CONTEXT_URL_MAX_LENGTH} characters" + ) + if v.content and len(v.content) > _CONTEXT_CONTENT_MAX_LENGTH: + raise ValueError( + f"context.content exceeds maximum length of {_CONTEXT_CONTENT_MAX_LENGTH} characters" + ) + return v + + +class QueuePendingMessageResponse(BaseModel): + """Response for the pending-message endpoint. + + - ``buffer_length``: how many messages are now in the session's + pending buffer (after this push) + - ``max_buffer_length``: the per-session cap (server-side constant) + - ``turn_in_flight``: ``True`` if a copilot turn was running when + we checked — purely informational for UX feedback. Even when + ``False`` the message is still queued: the next turn drains it. + """ + + buffer_length: int + max_buffer_length: int + turn_in_flight: bool + + class CreateSessionRequest(BaseModel): """Request model for creating a new chat session. @@ -786,33 +896,21 @@ async def stream_chat_post( # Also sanitise file_ids so only validated, workspace-scoped IDs are # forwarded downstream (e.g. to the executor via enqueue_copilot_turn). sanitized_file_ids: list[str] | None = None - if request.file_ids and user_id: - # Filter to valid UUIDs only to prevent DB abuse - valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)] - - if valid_ids: - workspace = await get_or_create_workspace(user_id) - # Batch query instead of N+1 - files = await UserWorkspaceFile.prisma().find_many( - where={ - "id": {"in": valid_ids}, - "workspaceId": workspace.id, - "isDeleted": False, - } + if request.file_ids: + files = await _resolve_workspace_files(user_id, request.file_ids) + # Only keep IDs that actually exist in the user's workspace + sanitized_file_ids = [wf.id for wf in files] or None + file_lines: list[str] = [ + f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}" + for wf in files + ] + if file_lines: + files_block = ( + "\n\n[Attached files]\n" + + "\n".join(file_lines) + + "\nUse read_workspace_file with the file_id to access file contents." ) - # Only keep IDs that actually exist in the user's workspace - sanitized_file_ids = [wf.id for wf in files] or None - file_lines: list[str] = [ - f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}" - for wf in files - ] - if file_lines: - files_block = ( - "\n\n[Attached files]\n" - + "\n".join(file_lines) - + "\nUse read_workspace_file with the file_id to access file contents." 
- ) - request.message += files_block + request.message += files_block # Atomically append user message to session BEFORE creating task to avoid # race condition where GET_SESSION sees task as "running" but message isn't @@ -1012,6 +1110,135 @@ async def stream_chat_post( ) +@router.post( + "/sessions/{session_id}/messages/pending", + response_model=QueuePendingMessageResponse, + status_code=202, + responses={ + 404: {"description": "Session not found or access denied"}, + 429: {"description": "Token rate-limit or call-frequency cap exceeded"}, + }, +) +async def queue_pending_message( + session_id: str, + request: QueuePendingMessageRequest, + user_id: str = Security(auth.get_user_id), +): + """Queue a new user message into an in-flight copilot turn. + + When a user sends a follow-up message while a turn is still + streaming, we don't want to block them or start a separate turn — + this endpoint appends the message to a per-session pending buffer. + The executor currently running the turn (baseline path) drains the + buffer between tool-call rounds and appends the message to the + conversation before the next LLM call. On the SDK path the buffer + is drained at the *start* of the next turn (the long-lived + ``ClaudeSDKClient.receive_response`` iterator returns after a + ``ResultMessage`` so there is no safe point to inject mid-stream + into an existing connection). + + Returns 202. Enforces the same per-user daily/weekly token rate + limit as the regular ``/stream`` endpoint so a client can't bypass + it by batching messages through here. + """ + await _validate_and_get_session(session_id, user_id) + + # Pre-turn rate-limit check — mirrors stream_chat_post. Without + # this, a client could bypass per-turn token limits by batching + # their extra context through this endpoint while a cheap stream + # is in flight. + # user_id is guaranteed non-empty by Security(auth.get_user_id) — no guard needed. + try: + daily_limit, weekly_limit, _tier = await get_global_rate_limits( + user_id, config.daily_token_limit, config.weekly_token_limit + ) + await check_rate_limit( + user_id=user_id, + daily_token_limit=daily_limit, + weekly_token_limit=weekly_limit, + ) + except RateLimitExceeded as e: + raise HTTPException(status_code=429, detail=str(e)) from e + + # Call-frequency cap: prevent rapid-fire pushes that would bypass the + # token-budget check (which only fires per-turn, not per-push). + # Uses an atomic Lua EVAL (INCR + EXPIRE) so the key can never be + # orphaned without a TTL; fails open if Redis is down. + try: + _redis = await get_redis_async() + _call_key = f"{_PENDING_CALL_KEY_PREFIX}{user_id}" + _call_count = int( + await cast( + "Any", + _redis.eval( + _CALL_INCR_LUA, + 1, + _call_key, + str(_PENDING_CALL_WINDOW_SECONDS), + ), + ) + ) + if _call_count > _PENDING_CALL_LIMIT: + raise HTTPException( + status_code=429, + detail=f"Too many pending messages: limit is {_PENDING_CALL_LIMIT} per {_PENDING_CALL_WINDOW_SECONDS}s", + ) + except HTTPException: + raise + except Exception: + logger.warning( + "queue_pending_message: rate-limit check failed, failing open" + ) # non-fatal + + # Sanitise file IDs to the user's own workspace so injection doesn't + # surface other users' files. _resolve_workspace_files handles UUID + # filtering and the workspace-scoped DB lookup. 
+ sanitized_file_ids: list[str] = [] + if request.file_ids: + valid_id_count = sum(1 for fid in request.file_ids if _UUID_RE.fullmatch(fid)) + files = await _resolve_workspace_files(user_id, request.file_ids) + sanitized_file_ids = [wf.id for wf in files] + if len(sanitized_file_ids) != valid_id_count: + logger.warning( + "queue_pending_message: dropped %d file id(s) not in " + "caller's workspace (session=%s)", + valid_id_count - len(sanitized_file_ids), + session_id, + ) + + # Redis is the single source of truth for pending messages. We do + # NOT persist to ``session.messages`` here — the drain-at-start + # path in the baseline/SDK executor is the sole writer for pending + # content. Persisting both here AND in the drain would cause + # double injection (executor sees the message in ``session.messages`` + # *and* drains it from Redis) unless we also dedupe. The dedup in + # ``maybe_append_user_message`` only checks trailing same-role + # repeats, so relying on it is fragile. Keeping the endpoint + # Redis-only avoids the whole consistency-bug class. + pending = PendingMessage( + content=request.message, + file_ids=sanitized_file_ids, + context=request.context, + ) + buffer_length = await push_pending_message(session_id, pending) + + track_user_message( + user_id=user_id, + session_id=session_id, + message_length=len(request.message), + ) + + # Check whether a turn is currently running for UX feedback. + active_session = await stream_registry.get_session(session_id) + turn_in_flight = bool(active_session and active_session.status == "running") + + return QueuePendingMessageResponse( + buffer_length=buffer_length, + max_buffer_length=MAX_PENDING_MESSAGES, + turn_in_flight=turn_in_flight, + ) + + @router.get( "/sessions/{session_id}/stream", ) diff --git a/autogpt_platform/backend/backend/api/features/chat/routes_test.py b/autogpt_platform/backend/backend/api/features/chat/routes_test.py index cd87fe611f..401d73bea3 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes_test.py @@ -579,3 +579,300 @@ class TestStreamChatRequestModeValidation: req = StreamChatRequest(message="hi") assert req.mode is None + + +# ─── QueuePendingMessageRequest validation ──────────────────────────── + + +class TestQueuePendingMessageRequest: + """Unit tests for QueuePendingMessageRequest field validation.""" + + def test_accepts_valid_message(self) -> None: + from backend.api.features.chat.routes import QueuePendingMessageRequest + + req = QueuePendingMessageRequest(message="hello") + assert req.message == "hello" + + def test_rejects_empty_message(self) -> None: + import pydantic + + from backend.api.features.chat.routes import QueuePendingMessageRequest + + with pytest.raises(pydantic.ValidationError): + QueuePendingMessageRequest(message="") + + def test_rejects_message_over_limit(self) -> None: + import pydantic + + from backend.api.features.chat.routes import QueuePendingMessageRequest + + with pytest.raises(pydantic.ValidationError): + QueuePendingMessageRequest(message="x" * 16_001) + + def test_accepts_valid_context(self) -> None: + from backend.api.features.chat.routes import QueuePendingMessageRequest + + req = QueuePendingMessageRequest( + message="hi", + context={"url": "https://example.com", "content": "page text"}, + ) + assert req.context is not None + assert req.context.url == "https://example.com" + + def test_rejects_context_url_over_limit(self) -> None: + import pydantic + + from 
backend.api.features.chat.routes import QueuePendingMessageRequest + + with pytest.raises(pydantic.ValidationError, match="url"): + QueuePendingMessageRequest( + message="hi", + context={"url": "https://example.com/" + "x" * 2_000}, + ) + + def test_rejects_context_content_over_limit(self) -> None: + import pydantic + + from backend.api.features.chat.routes import QueuePendingMessageRequest + + with pytest.raises(pydantic.ValidationError, match="content"): + QueuePendingMessageRequest( + message="hi", + context={"content": "x" * 32_001}, + ) + + def test_rejects_extra_fields(self) -> None: + """extra='forbid' should reject unknown fields.""" + import pydantic + + from backend.api.features.chat.routes import QueuePendingMessageRequest + + with pytest.raises(pydantic.ValidationError): + QueuePendingMessageRequest(message="hi", unknown_field="bad") # type: ignore[call-arg] + + def test_accepts_up_to_20_file_ids(self) -> None: + from backend.api.features.chat.routes import QueuePendingMessageRequest + + req = QueuePendingMessageRequest( + message="hi", + file_ids=[f"00000000-0000-0000-0000-{i:012d}" for i in range(20)], + ) + assert req.file_ids is not None + assert len(req.file_ids) == 20 + + def test_rejects_more_than_20_file_ids(self) -> None: + import pydantic + + from backend.api.features.chat.routes import QueuePendingMessageRequest + + with pytest.raises(pydantic.ValidationError): + QueuePendingMessageRequest( + message="hi", + file_ids=[f"00000000-0000-0000-0000-{i:012d}" for i in range(21)], + ) + + +# ─── queue_pending_message endpoint ────────────────────────────────── + + +def _mock_pending_internals( + mocker: pytest_mock.MockerFixture, + *, + session_exists: bool = True, + call_count: int = 1, +): + """Mock all async dependencies for the pending-message endpoint.""" + if session_exists: + mock_session = mocker.MagicMock() + mock_session.id = "sess-1" + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + new_callable=AsyncMock, + return_value=mock_session, + ) + else: + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + side_effect=fastapi.HTTPException( + status_code=404, detail="Session not found." 
+ ), + ) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(0, 0, None), + ) + mocker.patch( + "backend.api.features.chat.routes.check_rate_limit", + new_callable=AsyncMock, + return_value=None, + ) + # Mock Redis for per-user call-frequency rate limit (atomic Lua EVAL) + mock_redis = mocker.MagicMock() + mock_redis.eval = mocker.AsyncMock(return_value=call_count) + mocker.patch( + "backend.api.features.chat.routes.get_redis_async", + new_callable=AsyncMock, + return_value=mock_redis, + ) + mocker.patch( + "backend.api.features.chat.routes.track_user_message", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.push_pending_message", + new_callable=AsyncMock, + return_value=1, + ) + mock_registry = mocker.MagicMock() + mock_registry.get_session = mocker.AsyncMock(return_value=None) + mocker.patch( + "backend.api.features.chat.routes.stream_registry", + mock_registry, + ) + + +def test_queue_pending_message_returns_202(mocker: pytest_mock.MockerFixture) -> None: + """Happy path: valid message returns 202 with buffer_length.""" + _mock_pending_internals(mocker) + + response = client.post( + "/sessions/sess-1/messages/pending", + json={"message": "follow-up"}, + ) + + assert response.status_code == 202 + data = response.json() + assert data["buffer_length"] == 1 + assert data["turn_in_flight"] is False + + +def test_queue_pending_message_empty_body_returns_422() -> None: + """Empty message must be rejected by Pydantic before hitting any route logic.""" + response = client.post( + "/sessions/sess-1/messages/pending", + json={"message": ""}, + ) + assert response.status_code == 422 + + +def test_queue_pending_message_missing_message_returns_422() -> None: + """Missing 'message' field returns 422.""" + response = client.post( + "/sessions/sess-1/messages/pending", + json={}, + ) + assert response.status_code == 422 + + +def test_queue_pending_message_session_not_found_returns_404( + mocker: pytest_mock.MockerFixture, +) -> None: + """If the session doesn't exist or belong to the user, returns 404.""" + _mock_pending_internals(mocker, session_exists=False) + + response = client.post( + "/sessions/bad-sess/messages/pending", + json={"message": "hi"}, + ) + assert response.status_code == 404 + + +def test_queue_pending_message_rate_limited_returns_429( + mocker: pytest_mock.MockerFixture, +) -> None: + """When rate limit is exceeded, endpoint returns 429.""" + from backend.copilot.rate_limit import RateLimitExceeded + + _mock_pending_internals(mocker) + mocker.patch( + "backend.api.features.chat.routes.check_rate_limit", + side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)), + ) + + response = client.post( + "/sessions/sess-1/messages/pending", + json={"message": "hi"}, + ) + assert response.status_code == 429 + + +def test_queue_pending_message_call_frequency_limit_returns_429( + mocker: pytest_mock.MockerFixture, +) -> None: + """When per-user call frequency limit is exceeded, endpoint returns 429.""" + from backend.api.features.chat.routes import _PENDING_CALL_LIMIT + + _mock_pending_internals(mocker, call_count=_PENDING_CALL_LIMIT + 1) + + response = client.post( + "/sessions/sess-1/messages/pending", + json={"message": "hi"}, + ) + assert response.status_code == 429 + assert "Too many pending messages" in response.json()["detail"] + + +def test_queue_pending_message_context_url_too_long_returns_422() -> None: + """context.url over 2 KB is rejected.""" + response = 
client.post( + "/sessions/sess-1/messages/pending", + json={ + "message": "hi", + "context": {"url": "https://example.com/" + "x" * 2_000}, + }, + ) + assert response.status_code == 422 + + +def test_queue_pending_message_context_content_too_long_returns_422() -> None: + """context.content over 32 KB is rejected.""" + response = client.post( + "/sessions/sess-1/messages/pending", + json={ + "message": "hi", + "context": {"content": "x" * 32_001}, + }, + ) + assert response.status_code == 422 + + +def test_queue_pending_message_too_many_file_ids_returns_422() -> None: + """More than 20 file_ids should be rejected.""" + response = client.post( + "/sessions/sess-1/messages/pending", + json={ + "message": "hi", + "file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)], + }, + ) + assert response.status_code == 422 + + +def test_queue_pending_message_file_ids_scoped_to_workspace( + mocker: pytest_mock.MockerFixture, +) -> None: + """File IDs must be sanitized to the user's workspace before push.""" + _mock_pending_internals(mocker) + mocker.patch( + "backend.api.features.chat.routes.get_or_create_workspace", + new_callable=AsyncMock, + return_value=type("W", (), {"id": "ws-1"})(), + ) + mock_prisma = mocker.MagicMock() + mock_prisma.find_many = mocker.AsyncMock(return_value=[]) + mocker.patch( + "prisma.models.UserWorkspaceFile.prisma", + return_value=mock_prisma, + ) + fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + client.post( + "/sessions/sess-1/messages/pending", + json={"message": "hi", "file_ids": [fid, "not-a-uuid"]}, + ) + + call_kwargs = mock_prisma.find_many.call_args[1] + assert call_kwargs["where"]["id"]["in"] == [fid] + assert call_kwargs["where"]["workspaceId"] == "ws-1" + assert call_kwargs["where"]["isDeleted"] is False diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index 1f1fe42f59..ad54b20f97 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -36,6 +36,10 @@ from backend.copilot.model import ( maybe_append_user_message, upsert_chat_session, ) +from backend.copilot.pending_messages import ( + drain_pending_messages, + format_pending_as_user_message, +) from backend.copilot.prompting import get_baseline_supplement, get_graphiti_supplement from backend.copilot.response_model import ( StreamBaseResponse, @@ -341,6 +345,11 @@ class _BaselineStreamState: cost_usd: float | None = None thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper) session_messages: list[ChatMessage] = field(default_factory=list) + # Tracks how much of ``assistant_text`` has already been flushed to + # ``session.messages`` via mid-loop pending drains, so the ``finally`` + # block only appends the *new* assistant text (avoiding duplication of + # round-1 text when round-1 entries were cleared from session_messages). + _flushed_assistant_text_len: int = 0 async def _baseline_llm_caller( @@ -930,7 +939,54 @@ async def stream_chat_completion_baseline( message_length=len(message or ""), ) - session = await upsert_chat_session(session) + # Capture count *before* the pending drain so is_first_turn and the + # transcript staleness check are not skewed by queued messages. + _pre_drain_msg_count = len(session.messages) + + # Drain any messages the user queued via POST /messages/pending + # while this session was idle (or during a previous turn whose + # mid-loop drains missed them). 
Atomic LPOP guarantees that a + # concurrent push lands *after* the drain and stays queued for the + # next turn instead of being lost. + try: + drained_at_start = await drain_pending_messages(session_id) + except Exception: + logger.warning( + "[Baseline] drain_pending_messages failed at turn start, skipping", + exc_info=True, + ) + drained_at_start = [] + # Pre-compute formatted content once per message so we don't call + # format_pending_as_user_message twice (once for session.messages and + # once for transcript_builder below). + drained_at_start_content: list[str] = [] + if drained_at_start: + logger.info( + "[Baseline] Draining %d pending message(s) at turn start for session %s", + len(drained_at_start), + session_id, + ) + for pm in drained_at_start: + content = format_pending_as_user_message(pm)["content"] + drained_at_start_content.append(content) + # Append directly — pending messages are atomically-popped from + # Redis and are never stale-cache duplicates, so the + # maybe_append_user_message dedup is wrong here. + session.messages.append(ChatMessage(role="user", content=content)) + + # Persist the drained pending messages (if any) plus the current user + # message. Wrap in try/except so a transient DB failure here does not + # silently discard messages that were already popped from Redis — the + # turn can still proceed using the in-memory session.messages, and a + # later resume/replay will backfill from the DB on the next turn. + try: + session = await upsert_chat_session(session) + except Exception as _persist_err: + logger.warning( + "[Baseline] Failed to persist session at turn start " + "(pending drain may not be durable): %s", + _persist_err, + ) # Select model based on the per-request mode. 'fast' downgrades to # the cheaper/faster model; everything else keeps the default. @@ -959,7 +1015,9 @@ async def stream_chat_completion_baseline( # Build system prompt only on the first turn to avoid mid-conversation # changes from concurrent chats updating business understanding. - is_first_turn = len(session.messages) <= 1 + # Use the pre-drain count so queued pending messages don't incorrectly + # flip is_first_turn to False on an actual first turn. + is_first_turn = _pre_drain_msg_count <= 1 # Gate context fetch on both first turn AND user message so that assistant- # role calls (e.g. tool-result submissions) on the first turn don't trigger # a needless DB lookup for user understanding. @@ -970,14 +1028,18 @@ async def stream_chat_completion_baseline( prompt_task = _build_cacheable_system_prompt(None) # Run download + prompt build concurrently — both are independent I/O - # on the request critical path. - if user_id and len(session.messages) > 1: + # on the request critical path. Use the pre-drain count so pending + # messages drained at turn start don't spuriously trigger a transcript + # load on an actual first turn. + if user_id and _pre_drain_msg_count > 1: transcript_covers_prefix, (base_system_prompt, understanding) = ( await asyncio.gather( _load_prior_transcript( user_id=user_id, session_id=session_id, - session_msg_count=len(session.messages), + # Use pre-drain count so pending messages don't falsely + # mark the stored transcript as stale and prevent upload. + session_msg_count=_pre_drain_msg_count, transcript_builder=transcript_builder, ), prompt_task, @@ -989,6 +1051,15 @@ async def stream_chat_completion_baseline( # Append user message to transcript after context injection below so the # transcript receives the prefixed message when user context is available. 
+ # Mirror any messages drained at turn start (see above) into the + # transcript — otherwise the loaded prior transcript would be + # missing them and a mid-turn upload could leave a malformed + # assistant-after-assistant structure on the next turn. + # Reuse the pre-computed content strings to avoid calling + # format_pending_as_user_message a second time. + for _drained_content in drained_at_start_content: + transcript_builder.append_user(content=_drained_content) + # Generate title for new sessions if is_user_message and not session.title: user_messages = [m for m in session.messages if m.role == "user"] @@ -1009,8 +1080,10 @@ async def stream_chat_completion_baseline( graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else "" system_prompt = base_system_prompt + get_baseline_supplement() + graphiti_supplement - # Warm context: pre-load relevant facts from Graphiti on first turn - if graphiti_enabled and user_id and len(session.messages) <= 1: + # Warm context: pre-load relevant facts from Graphiti on first turn. + # Use the pre-drain count so pending messages drained at turn start + # don't prevent warm context injection on an actual first turn. + if graphiti_enabled and user_id and _pre_drain_msg_count <= 1: from backend.copilot.graphiti.context import fetch_warm_context warm_ctx = await fetch_warm_context(user_id, message or "") @@ -1203,6 +1276,91 @@ async def stream_chat_completion_baseline( yield evt state.pending_events.clear() + # Inject any messages the user queued while the turn was + # running. ``tool_call_loop`` mutates ``openai_messages`` + # in-place, so appending here means the model sees the new + # messages on its next LLM call. + # + # IMPORTANT: skip when the loop has already finished (no + # more LLM calls are coming). ``tool_call_loop`` yields + # a final ``ToolCallLoopResult`` on both paths: + # - natural finish: ``finished_naturally=True`` + # - hit max_iterations: ``finished_naturally=False`` + # and ``iterations >= max_iterations`` + # In either case the loop is about to return on the next + # ``async for`` step, so draining here would silently + # lose the message (the user sees 202 but the model never + # reads the text). Those messages stay in the buffer and + # get picked up at the start of the next turn. + if loop_result is None: + continue + is_final_yield = ( + loop_result.finished_naturally + or loop_result.iterations >= _MAX_TOOL_ROUNDS + ) + if is_final_yield: + continue + try: + pending = await drain_pending_messages(session_id) + except Exception: + logger.warning( + "Mid-loop drain_pending_messages failed for session %s", + session_id, + exc_info=True, + ) + pending = [] + if pending: + # Flush any buffered assistant/tool messages from completed + # rounds into session.messages BEFORE appending the pending + # user message. ``_baseline_conversation_updater`` only + # records assistant+tool rounds into ``state.session_messages`` + # — they are normally batch-flushed in the finally block. + # Without this in-order flush, the mid-loop pending user + # message lands before the preceding round's assistant/tool + # entries, producing chronologically-wrong session.messages + # on persist (user interposed between an assistant tool_call + # and its tool-result), which breaks OpenAI tool-call ordering + # invariants on the next turn's replay. 
+ for _buffered in state.session_messages: + session.messages.append(_buffered) + state.session_messages.clear() + # Record how much assistant_text has been covered by the + # structured entries just flushed, so the finally block's + # final-text dedup doesn't re-append rounds already persisted. + state._flushed_assistant_text_len = len(state.assistant_text) + + for pm in pending: + # ``format_pending_as_user_message`` embeds file + # attachments and context URL/page content into the + # content string so the in-session transcript is + # a faithful copy of what the model actually saw. + formatted = format_pending_as_user_message(pm) + content_for_db = formatted["content"] + # Append directly — pending messages are atomically-popped + # from Redis and are never stale-cache duplicates, so the + # maybe_append_user_message dedup is wrong here and would + # cause openai_messages/transcript to diverge from session. + session.messages.append( + ChatMessage(role="user", content=content_for_db) + ) + openai_messages.append(formatted) + transcript_builder.append_user(content=content_for_db) + try: + await upsert_chat_session(session) + except Exception as persist_err: + logger.warning( + "[Baseline] Failed to persist pending messages for " + "session %s: %s", + session_id, + persist_err, + ) + logger.info( + "[Baseline] Injected %d pending message(s) into " + "session %s mid-turn", + len(pending), + session_id, + ) + if loop_result and not loop_result.finished_naturally: limit_msg = ( f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds " @@ -1243,6 +1401,11 @@ async def stream_chat_completion_baseline( yield StreamError(errorText=error_msg, code="baseline_error") # Still persist whatever we got finally: + # Pending messages are drained atomically at turn start and + # between tool rounds, so there's nothing to clear in finally. + # Any message pushed after the final drain window stays in the + # buffer and gets picked up at the start of the next turn. + # Set cost attributes on OTEL span before closing if _trace_ctx is not None: try: @@ -1312,7 +1475,11 @@ async def stream_chat_completion_baseline( # no tool calls, i.e. the natural finish). Only add it if the # conversation updater didn't already record it as part of a # tool-call round (which would have empty response_text). - final_text = state.assistant_text + # Only consider assistant text produced AFTER the last mid-loop + # flush. ``_flushed_assistant_text_len`` tracks the prefix already + # persisted via structured session_messages during mid-loop pending + # drains; including it here would duplicate those rounds. + final_text = state.assistant_text[state._flushed_assistant_text_len :] if state.session_messages: # Strip text already captured in tool-call round messages recorded = "".join( diff --git a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py index ba1374b720..b67793076f 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py @@ -828,3 +828,204 @@ class TestBaselineCostExtraction: # response was never assigned so cost extraction must not raise assert state.cost_usd is None + + +class TestMidLoopPendingFlushOrdering: + """Regression test for the mid-loop pending drain ordering invariant. 
+ + ``_baseline_conversation_updater`` records assistant+tool entries from + each tool-call round into ``state.session_messages``; the finally block + of ``stream_chat_completion_baseline`` batch-flushes them into + ``session.messages`` at the end of the turn. + + The mid-loop pending drain appends pending user messages directly to + ``session.messages``. Without flushing ``state.session_messages`` first, + the pending user message lands BEFORE the preceding round's assistant+ + tool entries in the final persisted ``session.messages`` — which + produces a malformed tool-call/tool-result ordering on the next turn's + replay. + + This test documents the invariant by replaying the production flush + sequence against an in-memory state. + """ + + def test_flush_then_append_preserves_chronological_order(self): + """Mid-loop drain must flush state.session_messages before appending + the pending user message, so the final order matches the + chronological execution order. + """ + # Initial state: user turn already appended by maybe_append_user_message + session_messages: list[ChatMessage] = [ + ChatMessage(role="user", content="original user turn"), + ] + state = _BaselineStreamState() + + # Round 1 completes: conversation_updater buffers assistant+tool + # entries into state.session_messages (but does NOT write to + # session.messages yet). + builder = TranscriptBuilder() + builder.append_user("original user turn") + response = LLMLoopResponse( + response_text="calling search", + tool_calls=[LLMToolCall(id="tc_1", name="search", arguments="{}")], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + tool_results = [ + ToolCallResult( + tool_call_id="tc_1", tool_name="search", content="search output" + ), + ] + openai_messages: list = [] + _baseline_conversation_updater( + openai_messages, + response, + tool_results=tool_results, + transcript_builder=builder, + state=state, + model="test-model", + ) + # state.session_messages should now hold the round-1 assistant + tool + assert len(state.session_messages) == 2 + assert state.session_messages[0].role == "assistant" + assert state.session_messages[1].role == "tool" + + # --- Mid-loop pending drain (production code pattern) --- + # Flush first, THEN append pending. This is the ordering fix. + for _buffered in state.session_messages: + session_messages.append(_buffered) + state.session_messages.clear() + session_messages.append( + ChatMessage(role="user", content="pending mid-loop message") + ) + + # Round 2 completes: new assistant+tool entries buffer again. + response2 = LLMLoopResponse( + response_text="another call", + tool_calls=[LLMToolCall(id="tc_2", name="calc", arguments="{}")], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + tool_results2 = [ + ToolCallResult( + tool_call_id="tc_2", tool_name="calc", content="calc output" + ), + ] + _baseline_conversation_updater( + openai_messages, + response2, + tool_results=tool_results2, + transcript_builder=builder, + state=state, + model="test-model", + ) + + # --- Finally-block flush (end of turn) --- + for msg in state.session_messages: + session_messages.append(msg) + + # Assert chronological order: original user, round-1 assistant, + # round-1 tool, pending user, round-2 assistant, round-2 tool. 
+ assert [m.role for m in session_messages] == [ + "user", + "assistant", + "tool", + "user", + "assistant", + "tool", + ] + assert session_messages[0].content == "original user turn" + assert session_messages[3].content == "pending mid-loop message" + # The assistant message carrying tool_call tc_1 must be immediately + # followed by its tool result — no user message interposed. + assert session_messages[1].role == "assistant" + assert session_messages[1].tool_calls is not None + assert session_messages[1].tool_calls[0]["id"] == "tc_1" + assert session_messages[2].role == "tool" + assert session_messages[2].tool_call_id == "tc_1" + # Same invariant for the round after the pending user. + assert session_messages[4].tool_calls is not None + assert session_messages[4].tool_calls[0]["id"] == "tc_2" + assert session_messages[5].tool_call_id == "tc_2" + + def test_flushed_assistant_text_len_prevents_duplicate_final_text(self): + """After mid-loop drain clears state.session_messages, the finally + block must not re-append assistant text from rounds already flushed. + + ``state.assistant_text`` accumulates ALL rounds' text, but + ``state.session_messages`` only holds entries from rounds AFTER the + last mid-loop flush. Without ``_flushed_assistant_text_len``, the + ``finally`` block's ``startswith(recorded)`` check fails because + ``recorded`` only covers post-flush rounds, and the full + ``assistant_text`` is appended — duplicating pre-flush rounds. + """ + state = _BaselineStreamState() + session_messages: list[ChatMessage] = [ + ChatMessage(role="user", content="user turn"), + ] + + # Simulate round 1 text accumulation (as _bound_llm_caller does) + state.assistant_text += "calling search" + + # Round 1 conversation_updater buffers structured entries + builder = TranscriptBuilder() + builder.append_user("user turn") + response1 = LLMLoopResponse( + response_text="calling search", + tool_calls=[LLMToolCall(id="tc_1", name="search", arguments="{}")], + raw_response=None, + prompt_tokens=0, + completion_tokens=0, + ) + _baseline_conversation_updater( + [], + response1, + tool_results=[ + ToolCallResult( + tool_call_id="tc_1", tool_name="search", content="result" + ) + ], + transcript_builder=builder, + state=state, + model="test-model", + ) + + # Mid-loop drain: flush + clear + record flushed text length + for _buffered in state.session_messages: + session_messages.append(_buffered) + state.session_messages.clear() + state._flushed_assistant_text_len = len(state.assistant_text) + session_messages.append(ChatMessage(role="user", content="pending message")) + + # Simulate round 2 text accumulation + state.assistant_text += "final answer" + + # Round 2: natural finish (no tool calls → no session_messages entry) + + # --- Finally block logic (production code) --- + for msg in state.session_messages: + session_messages.append(msg) + + final_text = state.assistant_text[state._flushed_assistant_text_len :] + if state.session_messages: + recorded = "".join( + m.content or "" for m in state.session_messages if m.role == "assistant" + ) + if final_text.startswith(recorded): + final_text = final_text[len(recorded) :] + if final_text.strip(): + session_messages.append(ChatMessage(role="assistant", content=final_text)) + + # The final assistant message should only contain round-2 text, + # not the round-1 text that was already flushed mid-loop. 
+ assistant_msgs = [m for m in session_messages if m.role == "assistant"] + # Round-1 structured assistant (from mid-loop flush) + assert assistant_msgs[0].content == "calling search" + assert assistant_msgs[0].tool_calls is not None + # Round-2 final text (from finally block) + assert assistant_msgs[1].content == "final answer" + assert assistant_msgs[1].tool_calls is None + # Crucially: only 2 assistant messages, not 3 (no duplicate) + assert len(assistant_msgs) == 2 diff --git a/autogpt_platform/backend/backend/copilot/pending_messages.py b/autogpt_platform/backend/backend/copilot/pending_messages.py new file mode 100644 index 0000000000..20f673215d --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/pending_messages.py @@ -0,0 +1,222 @@ +"""Pending-message buffer for in-flight copilot turns. + +When a user sends a new message while a copilot turn is already executing, +instead of blocking the frontend (or queueing a brand-new turn after the +current one finishes), we want the new message to be *injected into the +running turn* — appended between tool-call rounds so the model sees it +before its next LLM call. + +This module provides the cross-process buffer that makes that possible: + +- **Producer** (chat API route): pushes a pending message to Redis and + publishes a notification on a pub/sub channel. +- **Consumer** (executor running the turn): on each tool-call round, + drains the buffer and appends the pending messages to the conversation. + +The Redis list is the durable store; the pub/sub channel is a fast +wake-up hint for long-idle consumers (not used by default, but available +for future blocking-wait semantics). + +A hard cap of ``MAX_PENDING_MESSAGES`` per session prevents abuse. The +buffer is trimmed to the latest ``MAX_PENDING_MESSAGES`` on every push. +""" + +import json +import logging +from typing import Any, cast + +from pydantic import BaseModel, Field, ValidationError + +from backend.data.redis_client import get_redis_async + +logger = logging.getLogger(__name__) + +# Per-session cap. Higher values risk a runaway consumer; lower values +# risk dropping user input under heavy typing. 10 was chosen as a +# reasonable ceiling — a user typing faster than the copilot can drain +# between tool rounds is already an unusual usage pattern. +MAX_PENDING_MESSAGES = 10 + +# Redis key + TTL. The buffer is ephemeral: if a turn completes or the +# executor dies, the pending messages should either have been drained +# already or are safe to drop (the user can resend). +_PENDING_KEY_PREFIX = "copilot:pending:" +_PENDING_CHANNEL_PREFIX = "copilot:pending:notify:" +_PENDING_TTL_SECONDS = 3600 # 1 hour — matches stream_ttl default + +# Payload sent on the pub/sub notify channel. Subscribers treat any +# message as a wake-up hint; the value itself is not meaningful. 
+_NOTIFY_PAYLOAD = "1" + + +class PendingMessageContext(BaseModel, extra="forbid"): + """Structured page context attached to a pending message.""" + + url: str | None = None + content: str | None = None + + +class PendingMessage(BaseModel): + """A user message queued for injection into an in-flight turn.""" + + content: str = Field(min_length=1, max_length=16_000) + file_ids: list[str] = Field(default_factory=list) + context: PendingMessageContext | None = None + + +def _buffer_key(session_id: str) -> str: + return f"{_PENDING_KEY_PREFIX}{session_id}" + + +def _notify_channel(session_id: str) -> str: + return f"{_PENDING_CHANNEL_PREFIX}{session_id}" + + +# Lua script: push-then-trim-then-expire-then-length, atomically. +# Redis serializes EVAL commands, so a concurrent ``LPOP`` drain +# observes either the pre-push or post-push state of the list — never +# a partial state where the RPUSH has landed but LTRIM hasn't run. +_PUSH_LUA = """ +redis.call('RPUSH', KEYS[1], ARGV[1]) +redis.call('LTRIM', KEYS[1], -tonumber(ARGV[2]), -1) +redis.call('EXPIRE', KEYS[1], tonumber(ARGV[3])) +return redis.call('LLEN', KEYS[1]) +""" + + +async def push_pending_message( + session_id: str, + message: PendingMessage, +) -> int: + """Append a pending message to the session's buffer atomically. + + Returns the new buffer length. Enforces ``MAX_PENDING_MESSAGES`` by + trimming from the left (oldest) — the newest message always wins if + the user has been typing faster than the copilot can drain. + + The push + trim + expire + llen are wrapped in a single Lua EVAL so + concurrent LPOP drains from the executor never observe a partial + state. + """ + redis = await get_redis_async() + key = _buffer_key(session_id) + payload = message.model_dump_json() + + new_length = int( + await cast( + "Any", + redis.eval( + _PUSH_LUA, + 1, + key, + payload, + str(MAX_PENDING_MESSAGES), + str(_PENDING_TTL_SECONDS), + ), + ) + ) + + # Fire-and-forget notify. Subscribers use this as a wake-up hint; + # the buffer itself is authoritative so a lost notify is harmless. + try: + await redis.publish(_notify_channel(session_id), _NOTIFY_PAYLOAD) + except Exception as e: # pragma: no cover + logger.warning("pending_messages: publish failed for %s: %s", session_id, e) + + logger.info( + "pending_messages: pushed message to session=%s (buffer_len=%d)", + session_id, + new_length, + ) + return new_length + + +async def drain_pending_messages(session_id: str) -> list[PendingMessage]: + """Atomically pop all pending messages for *session_id*. + + Returns them in enqueue order (oldest first). Uses ``LPOP`` with a + count so the read+delete is a single Redis round trip. If the list + is empty or missing, returns ``[]``. + """ + redis = await get_redis_async() + key = _buffer_key(session_id) + + # Redis LPOP with count (Redis 6.2+) returns None for missing key, + # empty list if we somehow race an empty key, or the popped items. + # redis-py's async lpop overload with a count collapses the return + # type in pyright; cast the awaitable so strict type-check stays + # clean without changing runtime behaviour. + lpop_result = await cast( + "Any", + redis.lpop(key, MAX_PENDING_MESSAGES), + ) + if not lpop_result: + return [] + raw_popped: list[Any] = list(lpop_result) + + # redis-py may return bytes or str depending on decode_responses. 
+ decoded: list[str] = [ + item.decode("utf-8") if isinstance(item, bytes) else str(item) + for item in raw_popped + ] + + messages: list[PendingMessage] = [] + for payload in decoded: + try: + messages.append(PendingMessage(**json.loads(payload))) + except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e: + logger.warning( + "pending_messages: dropping malformed entry for %s: %s", + session_id, + e, + ) + + if messages: + logger.info( + "pending_messages: drained %d messages for session=%s", + len(messages), + session_id, + ) + return messages + + +async def peek_pending_count(session_id: str) -> int: + """Return the current buffer length without consuming it.""" + redis = await get_redis_async() + length = await cast("Any", redis.llen(_buffer_key(session_id))) + return int(length) + + +async def clear_pending_messages(session_id: str) -> None: + """Drop the session's pending buffer. + + Not called by the normal turn flow — the atomic ``LPOP`` drain at + turn start is the primary consumer, and any push that arrives + after the drain window belongs to the next turn by definition. + Retained as an operator/debug escape hatch for manually clearing a + stuck session and as a fixture in the unit tests. + """ + redis = await get_redis_async() + await redis.delete(_buffer_key(session_id)) + + +def format_pending_as_user_message(message: PendingMessage) -> dict[str, Any]: + """Shape a ``PendingMessage`` into the OpenAI-format user message dict. + + Used by the baseline tool-call loop when injecting the buffered + message into the conversation. Context/file metadata (if any) is + embedded into the content so the model sees everything in one block. + """ + parts: list[str] = [message.content] + if message.context: + if message.context.url: + parts.append(f"\n\n[Page URL: {message.context.url}]") + if message.context.content: + parts.append(f"\n\n[Page content]\n{message.context.content}") + if message.file_ids: + parts.append( + "\n\n[Attached files]\n" + + "\n".join(f"- file_id={fid}" for fid in message.file_ids) + + "\nUse read_workspace_file with the file_id to access file contents." + ) + return {"role": "user", "content": "".join(parts)} diff --git a/autogpt_platform/backend/backend/copilot/pending_messages_test.py b/autogpt_platform/backend/backend/copilot/pending_messages_test.py new file mode 100644 index 0000000000..cd3f6b7c43 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/pending_messages_test.py @@ -0,0 +1,246 @@ +"""Tests for the copilot pending-messages buffer. + +Uses a fake async Redis client so the tests don't require a real Redis +instance (the backend test suite's DB/Redis fixtures are heavyweight +and pull in the full app startup). +""" + +import json +from typing import Any + +import pytest + +from backend.copilot import pending_messages as pm_module +from backend.copilot.pending_messages import ( + MAX_PENDING_MESSAGES, + PendingMessage, + PendingMessageContext, + clear_pending_messages, + drain_pending_messages, + format_pending_as_user_message, + peek_pending_count, + push_pending_message, +) + +# ── Fake Redis ────────────────────────────────────────────────────── + + +class _FakeRedis: + def __init__(self) -> None: + # Values are ``str | bytes`` because real redis-py returns + # bytes when ``decode_responses=False``; the drain path must + # handle both and our tests exercise both. 
+ self.lists: dict[str, list[str | bytes]] = {} + self.published: list[tuple[str, str]] = [] + + async def eval(self, script: str, num_keys: int, *args: Any) -> Any: + """Emulate the push Lua script. + + The real Lua script runs atomically in Redis; the fake + implementation just runs the equivalent list operations in + order and returns the final LLEN. That's enough to exercise + the cap + ordering invariants the tests care about. + """ + key = args[0] + payload = args[1] + max_len = int(args[2]) + # ARGV[3] is TTL — fake doesn't enforce expiry + lst = self.lists.setdefault(key, []) + lst.append(payload) + if len(lst) > max_len: + # RPUSH + LTRIM(-N, -1) = keep only last N + self.lists[key] = lst[-max_len:] + return len(self.lists[key]) + + async def publish(self, channel: str, payload: str) -> int: + self.published.append((channel, payload)) + return 1 + + async def lpop(self, key: str, count: int) -> list[str | bytes] | None: + lst = self.lists.get(key) + if not lst: + return None + popped = lst[:count] + self.lists[key] = lst[count:] + return popped + + async def llen(self, key: str) -> int: + return len(self.lists.get(key, [])) + + async def delete(self, key: str) -> int: + if key in self.lists: + del self.lists[key] + return 1 + return 0 + + +@pytest.fixture() +def fake_redis(monkeypatch: pytest.MonkeyPatch) -> _FakeRedis: + redis = _FakeRedis() + + async def _get_redis_async() -> _FakeRedis: + return redis + + monkeypatch.setattr(pm_module, "get_redis_async", _get_redis_async) + return redis + + +# ── Basic push / drain ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_push_and_drain_single_message(fake_redis: _FakeRedis) -> None: + length = await push_pending_message("sess1", PendingMessage(content="hello")) + assert length == 1 + assert await peek_pending_count("sess1") == 1 + + drained = await drain_pending_messages("sess1") + assert len(drained) == 1 + assert drained[0].content == "hello" + assert await peek_pending_count("sess1") == 0 + + +@pytest.mark.asyncio +async def test_push_and_drain_preserves_order(fake_redis: _FakeRedis) -> None: + for i in range(3): + await push_pending_message("sess2", PendingMessage(content=f"msg {i}")) + + drained = await drain_pending_messages("sess2") + assert [m.content for m in drained] == ["msg 0", "msg 1", "msg 2"] + + +@pytest.mark.asyncio +async def test_drain_empty_returns_empty_list(fake_redis: _FakeRedis) -> None: + assert await drain_pending_messages("nope") == [] + + +# ── Buffer cap ────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_cap_drops_oldest_when_exceeded(fake_redis: _FakeRedis) -> None: + # Push MAX_PENDING_MESSAGES + 3 messages + for i in range(MAX_PENDING_MESSAGES + 3): + await push_pending_message("sess3", PendingMessage(content=f"m{i}")) + + # Buffer should be clamped to MAX + assert await peek_pending_count("sess3") == MAX_PENDING_MESSAGES + + drained = await drain_pending_messages("sess3") + assert len(drained) == MAX_PENDING_MESSAGES + # Oldest 3 dropped — we should only see m3..m(MAX+2) + assert drained[0].content == "m3" + assert drained[-1].content == f"m{MAX_PENDING_MESSAGES + 2}" + + +# ── Clear ─────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_clear_removes_buffer(fake_redis: _FakeRedis) -> None: + await push_pending_message("sess4", PendingMessage(content="x")) + await push_pending_message("sess4", PendingMessage(content="y")) + await clear_pending_messages("sess4") + assert 
await peek_pending_count("sess4") == 0 + + +@pytest.mark.asyncio +async def test_clear_is_idempotent(fake_redis: _FakeRedis) -> None: + # Clearing an already-empty buffer should not raise + await clear_pending_messages("sess_empty") + await clear_pending_messages("sess_empty") + + +# ── Publish hook ──────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_push_publishes_notification(fake_redis: _FakeRedis) -> None: + await push_pending_message("sess5", PendingMessage(content="hi")) + assert ("copilot:pending:notify:sess5", "1") in fake_redis.published + + +# ── Format helper ─────────────────────────────────────────────────── + + +def test_format_pending_plain_text() -> None: + msg = PendingMessage(content="just text") + out = format_pending_as_user_message(msg) + assert out == {"role": "user", "content": "just text"} + + +def test_format_pending_with_context_url() -> None: + msg = PendingMessage( + content="see this page", + context=PendingMessageContext(url="https://example.com"), + ) + out = format_pending_as_user_message(msg) + content = out["content"] + assert out["role"] == "user" + assert "see this page" in content + # The URL should appear verbatim in the [Page URL: ...] block. + assert "[Page URL: https://example.com]" in content + + +def test_format_pending_with_file_ids() -> None: + msg = PendingMessage(content="look here", file_ids=["a", "b"]) + out = format_pending_as_user_message(msg) + assert "file_id=a" in out["content"] + assert "file_id=b" in out["content"] + + +def test_format_pending_with_all_fields() -> None: + """All fields (content + context url/content + file_ids) should all appear.""" + msg = PendingMessage( + content="summarise this", + context=PendingMessageContext( + url="https://example.com/page", + content="headline text", + ), + file_ids=["f1", "f2"], + ) + out = format_pending_as_user_message(msg) + body = out["content"] + assert out["role"] == "user" + assert "summarise this" in body + assert "[Page URL: https://example.com/page]" in body + assert "[Page content]\nheadline text" in body + assert "file_id=f1" in body + assert "file_id=f2" in body + + +# ── Malformed payload handling ────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_drain_skips_malformed_entries( + fake_redis: _FakeRedis, +) -> None: + # Seed the fake with a mix of valid and malformed payloads + fake_redis.lists["copilot:pending:bad"] = [ + json.dumps({"content": "valid"}), + "{not valid json", + json.dumps({"content": "also valid", "file_ids": ["a"]}), + ] + drained = await drain_pending_messages("bad") + assert len(drained) == 2 + assert drained[0].content == "valid" + assert drained[1].content == "also valid" + + +@pytest.mark.asyncio +async def test_drain_decodes_bytes_payloads( + fake_redis: _FakeRedis, +) -> None: + """Real redis-py returns ``bytes`` when ``decode_responses=False``. + + Seed the fake with bytes values to exercise the ``decode("utf-8")`` + branch in ``drain_pending_messages`` so a regression there doesn't + slip past CI. 
+ """ + fake_redis.lists["copilot:pending:bytes_sess"] = [ + json.dumps({"content": "from bytes"}).encode("utf-8"), + ] + drained = await drain_pending_messages("bytes_sess") + assert len(drained) == 1 + assert drained[0].content == "from bytes" diff --git a/autogpt_platform/backend/backend/copilot/sdk/query_builder_test.py b/autogpt_platform/backend/backend/copilot/sdk/query_builder_test.py index 57f037baba..4a7bf01823 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/query_builder_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/query_builder_test.py @@ -226,6 +226,111 @@ async def test_build_query_no_resume_multi_message(monkeypatch): assert was_compacted is False # mock returns False +@pytest.mark.asyncio +async def test_build_query_session_msg_ceiling_prevents_pending_duplication(): + """session_msg_ceiling stops pending messages from leaking into the gap. + + Scenario: transcript covers 2 messages, session has 2 historical + 1 current + + 2 pending drained at turn start. Without the ceiling the gap would include + the pending messages AND current_message already has them → duplication. + With session_msg_ceiling=3 (pre-drain count) the gap slice is empty and + only current_message carries the pending content. + """ + # session.messages after drain: [hist1, hist2, current_msg, pending1, pending2] + session = _make_session( + [ + ChatMessage(role="user", content="hist1"), + ChatMessage(role="assistant", content="hist2"), + ChatMessage(role="user", content="current msg with pending1 pending2"), + ChatMessage(role="user", content="pending1"), + ChatMessage(role="user", content="pending2"), + ] + ) + # transcript covers hist1+hist2 (2 messages); pre-drain count was 3 (includes current_msg) + result, was_compacted = await _build_query_message( + "current msg with pending1 pending2", + session, + use_resume=True, + transcript_msg_count=2, + session_id="test-session", + session_msg_ceiling=3, # len(session.messages) before drain + ) + # Gap should be empty (transcript_msg_count == ceiling - 1), so no history prepended + assert result == "current msg with pending1 pending2" + assert was_compacted is False + # Pending messages must NOT appear in gap context + assert "pending1" not in result.split("current msg")[0] + + +@pytest.mark.asyncio +async def test_build_query_session_msg_ceiling_preserves_real_gap(): + """session_msg_ceiling still surfaces a genuine stale-transcript gap. + + Scenario: transcript covers 2 messages, session has 4 historical + 1 current + + 2 pending. Ceiling = 5 (pre-drain). Real gap = messages 2-3 (hist3, hist4). 
+ """ + session = _make_session( + [ + ChatMessage(role="user", content="hist1"), + ChatMessage(role="assistant", content="hist2"), + ChatMessage(role="user", content="hist3"), + ChatMessage(role="assistant", content="hist4"), + ChatMessage(role="user", content="current"), + ChatMessage(role="user", content="pending1"), + ChatMessage(role="user", content="pending2"), + ] + ) + result, was_compacted = await _build_query_message( + "current", + session, + use_resume=True, + transcript_msg_count=2, + session_id="test-session", + session_msg_ceiling=5, # pre-drain: [hist1..hist4, current] + ) + # Gap = session.messages[2:4] = [hist3, hist4] + assert "" in result + assert "hist3" in result + assert "hist4" in result + assert "Now, the user says:\ncurrent" in result + # Pending messages must NOT appear in gap + assert "pending1" not in result + assert "pending2" not in result + + +@pytest.mark.asyncio +async def test_build_query_session_msg_ceiling_suppresses_spurious_no_resume_fallback(): + """session_msg_ceiling prevents the no-resume compression fallback from + firing on the first turn of a session when pending messages inflate msg_count. + + Scenario: fresh session (1 message) + 1 pending message drained at turn start. + Without the ceiling: msg_count=2 > 1 → fallback triggers → pending message + leaked into history → wrong context sent to model. + With session_msg_ceiling=1 (pre-drain count): effective_count=1, 1 > 1 is False + → fallback does not trigger → current_message returned as-is. + """ + # session.messages after drain: [current_msg, pending_msg] + session = _make_session( + [ + ChatMessage(role="user", content="What is 2 plus 2?"), + ChatMessage(role="user", content="What is 7 plus 7?"), # pending + ] + ) + result, was_compacted = await _build_query_message( + "What is 2 plus 2?\n\nWhat is 7 plus 7?", + session, + use_resume=False, + transcript_msg_count=0, + session_id="test-session", + session_msg_ceiling=1, # pre-drain: only 1 message existed + ) + # Should return current_message directly without wrapping in history context + assert result == "What is 2 plus 2?\n\nWhat is 7 plus 7?" + assert was_compacted is False + # Pending question must NOT appear in a spurious history section + assert "" not in result + + @pytest.mark.asyncio async def test_build_query_no_resume_multi_message_compacted(monkeypatch): """When compression actually compacts, was_compacted should be True.""" diff --git a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py index fd831214a6..710daf626a 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py @@ -1031,6 +1031,12 @@ def _make_sdk_patches( ), (f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)), (f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)), + # Stub pending-message drain so retry tests don't hit Redis. + # Returns an empty list → no mid-turn injection happens. 
+ ( + f"{_SVC}.drain_pending_messages", + dict(new_callable=AsyncMock, return_value=[]), + ), ] diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index 5ee6bba8ca..08a0f6c08b 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -34,6 +34,10 @@ from opentelemetry import trace as otel_trace from pydantic import BaseModel from backend.copilot.context import get_workspace_manager +from backend.copilot.pending_messages import ( + drain_pending_messages, + format_pending_as_user_message, +) from backend.copilot.permissions import apply_tool_permissions from backend.copilot.rate_limit import get_user_tier from backend.copilot.transcript import ( @@ -955,17 +959,33 @@ async def _build_query_message( use_resume: bool, transcript_msg_count: int, session_id: str, + *, + session_msg_ceiling: int | None = None, ) -> tuple[str, bool]: """Build the query message with appropriate context. + Args: + session_msg_ceiling: If provided, treat ``session.messages`` as if it + only has this many entries when computing the gap slice. Pass + ``len(session.messages)`` captured *before* appending any pending + messages so that mid-turn drains do not skew the gap calculation + and cause pending messages to be duplicated in both the gap context + and ``current_message``. + Returns: Tuple of (query_message, was_compacted). """ msg_count = len(session.messages) + # Use the ceiling if supplied (prevents pending-message duplication when + # messages were appended to session.messages after the drain but before + # this function is called). + effective_count = ( + session_msg_ceiling if session_msg_ceiling is not None else msg_count + ) if use_resume and transcript_msg_count > 0: - if transcript_msg_count < msg_count - 1: - gap = session.messages[transcript_msg_count:-1] + if transcript_msg_count < effective_count - 1: + gap = session.messages[transcript_msg_count : effective_count - 1] compressed, was_compressed = await _compress_messages(gap) gap_context = _format_conversation_context(compressed) if gap_context: @@ -981,12 +1001,14 @@ async def _build_query_message( f"{gap_context}\n\nNow, the user says:\n{current_message}", was_compressed, ) - elif not use_resume and msg_count > 1: + elif not use_resume and effective_count > 1: logger.warning( f"[SDK] Using compression fallback for session " - f"{session_id} ({msg_count} messages) — no transcript for --resume" + f"{session_id} ({effective_count} messages) — no transcript for --resume" + ) + compressed, was_compressed = await _compress_messages( + session.messages[: effective_count - 1] ) - compressed, was_compressed = await _compress_messages(session.messages[:-1]) history_context = _format_conversation_context(compressed) if history_context: return ( @@ -2042,6 +2064,7 @@ async def stream_chat_completion_sdk( async def _fetch_transcript(): """Download transcript for --resume if applicable.""" + assert session is not None # narrowed at line 1898 if not ( config.claude_agent_use_resume and user_id and len(session.messages) > 1 ): @@ -2288,6 +2311,69 @@ async def stream_chat_completion_sdk( if last_user: current_message = last_user[-1].content or "" + # Capture the message count *before* draining so _build_query_message + # can compute the gap slice without including the newly-drained pending + # messages. 
Pending messages are both appended to session.messages AND + # concatenated into current_message; without the ceiling the gap slice + # would extend into the pending messages and duplicate them in the + # model's input context (gap_context + current_message both containing + # them). + _pre_drain_msg_count = len(session.messages) + + # Drain any messages the user queued via POST /messages/pending + # while the previous turn was running (or since the session was + # idle). Messages are drained ATOMICALLY — one LPOP with count + # removes them all at once, so a concurrent push lands *after* + # the drain and stays queued for the next turn instead of being + # lost between LPOP and clear. File IDs and context are + # preserved via format_pending_as_user_message. + # + # The drained content is concatenated into ``current_message`` + # so the SDK CLI sees it in the new user message, AND appended + # directly to ``session.messages`` (no dedup — pending messages are + # atomically-popped from Redis and are never stale-cache duplicates) + # so the durable transcript records it too. Session is persisted + # immediately after the drain so a crash doesn't lose the messages. + # The endpoint deliberately does NOT persist to session.messages — + # Redis is the single source of truth until this drain runs. + try: + pending_at_start = await drain_pending_messages(session_id) + except Exception: + logger.warning( + "%s drain_pending_messages failed at turn start, skipping", + log_prefix, + exc_info=True, + ) + pending_at_start = [] + if pending_at_start: + logger.info( + "%s Draining %d pending message(s) at turn start", + log_prefix, + len(pending_at_start), + ) + pending_texts: list[str] = [ + format_pending_as_user_message(pm)["content"] for pm in pending_at_start + ] + for pt in pending_texts: + # Append directly — pending messages are atomically-popped from + # Redis and are never stale-cache duplicates, so the + # maybe_append_user_message dedup is wrong here. + session.messages.append(ChatMessage(role="user", content=pt)) + if current_message.strip(): + current_message = current_message + "\n\n" + "\n\n".join(pending_texts) + else: + current_message = "\n\n".join(pending_texts) + # Persist immediately so a crash between here and the finally block + # doesn't lose messages that were already drained from Redis. + try: + session = await upsert_chat_session(session) + except Exception as _persist_err: + logger.warning( + "%s Failed to persist drained pending messages: %s", + log_prefix, + _persist_err, + ) + if not current_message.strip(): yield StreamError( errorText="Message cannot be empty.", @@ -2301,6 +2387,7 @@ async def stream_chat_completion_sdk( use_resume, transcript_msg_count, session_id, + session_msg_ceiling=_pre_drain_msg_count, ) # On the first turn inject user context into the message instead of the # system prompt — the system prompt is now static (same for all users) @@ -2438,6 +2525,7 @@ async def stream_chat_completion_sdk( state.use_resume, state.transcript_msg_count, session_id, + session_msg_ceiling=_pre_drain_msg_count, ) if attachments.hint: state.query_message = f"{state.query_message}\n\n{attachments.hint}" @@ -2767,6 +2855,11 @@ async def stream_chat_completion_sdk( raise finally: + # Pending messages are drained atomically at the start of each + # turn (see drain_pending_messages call above), so there's + # nothing to clean up here — any message pushed after that + # point belongs to the next turn. 
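+ # (If the drain itself failed at turn start we logged a warning and
+ # carried on with an empty list; anything still sitting in the Redis
+ # buffer is picked up the next time drain_pending_messages runs.)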
+ # --- Close OTEL context (with cost attributes) --- if _otel_ctx is not None: try: diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 446b2eb079..1b3b1b75f2 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1605,6 +1605,60 @@ } } }, + "/api/chat/sessions/{session_id}/messages/pending": { + "post": { + "tags": ["v2", "chat", "chat"], + "summary": "Queue Pending Message", + "description": "Queue a new user message into an in-flight copilot turn.\n\nWhen a user sends a follow-up message while a turn is still\nstreaming, we don't want to block them or start a separate turn —\nthis endpoint appends the message to a per-session pending buffer.\nThe executor currently running the turn (baseline path) drains the\nbuffer between tool-call rounds and appends the message to the\nconversation before the next LLM call. On the SDK path the buffer\nis drained at the *start* of the next turn (the long-lived\n``ClaudeSDKClient.receive_response`` iterator returns after a\n``ResultMessage`` so there is no safe point to inject mid-stream\ninto an existing connection).\n\nReturns 202. Enforces the same per-user daily/weekly token rate\nlimit as the regular ``/stream`` endpoint so a client can't bypass\nit by batching messages through here.", + "operationId": "postV2QueuePendingMessage", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "session_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Session Id" } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/QueuePendingMessageRequest" + } + } + } + }, + "responses": { + "202": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/QueuePendingMessageResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "404": { "description": "Session not found or access denied" }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + }, + "429": { + "description": "Token rate-limit or call-frequency cap exceeded" + } + } + } + }, "/api/chat/sessions/{session_id}/stream": { "get": { "tags": ["v2", "chat", "chat"], @@ -12124,6 +12178,22 @@ "title": "PendingHumanReviewModel", "description": "Response model for pending human review data.\n\nRepresents a human review request that is awaiting user action.\nContains all necessary information for a user to review and approve\nor reject data from a Human-in-the-Loop block execution.\n\nAttributes:\n id: Unique identifier for the review record\n user_id: ID of the user who must perform the review\n node_exec_id: ID of the node execution that created this review\n node_id: ID of the node definition (for grouping reviews from same node)\n graph_exec_id: ID of the graph execution containing the node\n graph_id: ID of the graph template being executed\n graph_version: Version number of the graph template\n payload: The actual data payload awaiting review\n instructions: Instructions or message for the reviewer\n editable: Whether the reviewer can edit the data\n status: Current review status (WAITING, APPROVED, or REJECTED)\n review_message: Optional message from the reviewer\n created_at: Timestamp when review was created\n 
updated_at: Timestamp when review was last modified\n reviewed_at: Timestamp when review was completed (if applicable)" }, + "PendingMessageContext": { + "properties": { + "url": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Url" + }, + "content": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Content" + } + }, + "additionalProperties": false, + "type": "object", + "title": "PendingMessageContext", + "description": "Structured page context attached to a pending message." + }, "PlatformCostDashboard": { "properties": { "by_provider": { @@ -12668,6 +12738,53 @@ "required": ["providers", "pagination"], "title": "ProviderResponse" }, + "QueuePendingMessageRequest": { + "properties": { + "message": { + "type": "string", + "maxLength": 16000, + "minLength": 1, + "title": "Message" + }, + "context": { + "anyOf": [ + { "$ref": "#/components/schemas/PendingMessageContext" }, + { "type": "null" } + ], + "description": "Optional page context with 'url' and 'content' fields." + }, + "file_ids": { + "anyOf": [ + { + "items": { "type": "string" }, + "type": "array", + "maxItems": 20 + }, + { "type": "null" } + ], + "title": "File Ids" + } + }, + "additionalProperties": false, + "type": "object", + "required": ["message"], + "title": "QueuePendingMessageRequest", + "description": "Request model for queueing a message into an in-flight turn.\n\nUnlike ``StreamChatRequest`` this endpoint does **not** start a new\nturn — the message is appended to a per-session pending buffer that\nthe executor currently processing the turn will drain between tool\nrounds." + }, + "QueuePendingMessageResponse": { + "properties": { + "buffer_length": { "type": "integer", "title": "Buffer Length" }, + "max_buffer_length": { + "type": "integer", + "title": "Max Buffer Length" + }, + "turn_in_flight": { "type": "boolean", "title": "Turn In Flight" } + }, + "type": "object", + "required": ["buffer_length", "max_buffer_length", "turn_in_flight"], + "title": "QueuePendingMessageResponse", + "description": "Response for the pending-message endpoint.\n\n- ``buffer_length``: how many messages are now in the session's\n pending buffer (after this push)\n- ``max_buffer_length``: the per-session cap (server-side constant)\n- ``turn_in_flight``: ``True`` if a copilot turn was running when\n we checked — purely informational for UX feedback. Even when\n ``False`` the message is still queued: the next turn drains it." + }, "RateLimitResetResponse": { "properties": { "success": { "type": "boolean", "title": "Success" },