From cafe49f29580dec1b1ae368beed6f81a4bcada80 Mon Sep 17 00:00:00 2001
From: majdyz
Date: Fri, 10 Apr 2026 15:37:40 +0000
Subject: [PATCH] fix(copilot): address round 1 review on pending-messages
 feature
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Critical fix — the SDK mid-stream injection was structurally broken.
``ClaudeSDKClient.receive_response()`` explicitly returns after the first
``ResultMessage``, so re-issuing ``client.query()`` and setting
``acc.stream_completed = False`` could never restart the iteration — the
next ``__anext__`` raised ``StopAsyncIteration`` and the injected turn's
response was never consumed. Replaced the broken mid-stream path with a
turn-start drain that works for both baseline and SDK. A minimal repro of
the broken iteration is sketched under "Reviewer sketches" below.

### Changes

**Atomic push via Lua EVAL** (``pending_messages.py``)

- Replace the ``RPUSH`` + ``LTRIM`` + ``EXPIRE`` + ``LLEN`` pipeline
  (which was ``transaction=False`` and racy against concurrent ``LPOP``)
  with a single Lua script so the push is atomic.
- Drop the unused ``enqueued_at`` field.
- Add a 16k ``max_length`` cap on ``PendingMessage.content``.

**Baseline path** (``baseline/service.py``)

- Drain at turn start (atomic ``LPOP``): any message queued while the
  session was idle or between turns is picked up before the first LLM
  call.
- The mid-loop drain now skips the final ``tool_call_loop`` yield
  (``finished_naturally=True``) — draining there would append a user
  message the loop is about to exit past, silently losing it.
- Inject via ``format_pending_as_user_message`` so file IDs and context
  are preserved in both ``openai_messages`` and the persisted session
  transcript (previously the DB copy lost file/context metadata).
- Remove the ``finally`` ``clear_pending_messages`` — the atomic drain at
  turn start means any late push belongs to the next turn; clearing here
  would racily clobber it.

**SDK path** (``sdk/service.py``)

- Remove the broken mid-stream injection block entirely.
- Drain at turn start (same atomic ``LPOP``) and merge the drained
  messages into ``current_message`` before ``_build_query_message``, so
  the SDK CLI sees them as part of the initial user message.
- Remove the ``finally`` ``clear_pending_messages``.
- Delete the unused ``_combine_pending_messages`` helper.

**Endpoint** (``api/features/chat/routes.py``)

- Enforce ``check_rate_limit`` / ``get_global_rate_limits`` — the
  endpoint was bypassing the per-user daily/weekly token limits that
  ``/stream`` enforces.
- ``QueuePendingMessageRequest`` gets ``extra="forbid"`` and
  ``message: max_length=16_000``.
- Push-first, persist-second: if the Redis push fails we raise a 5xx;
  previously the session DB got an orphan user message with no
  corresponding queued entry, and a retry would duplicate it.
- Log a warning when sanitised file IDs drop unknown entries.
- Persisted message content now uses ``format_pending_as_user_message``
  so the session copy matches what the model actually sees on drain.
- The response returns ``buffer_length``, ``max_buffer_length``, and
  ``turn_in_flight`` so the frontend can show accurate feedback about
  whether the message will hit the current turn or the next one.

**Tests** (``pending_messages_test.py``)

- ``_FakeRedis.eval`` emulates the Lua push script so the existing
  push/drain/cap tests keep working under the new atomic path.
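### Reviewer sketches

A minimal repro of why the old mid-stream injection could never work.
This is a sketch using a plain async generator in place of the SDK
client; the failure is a property of finished async generators, not of
anything SDK-specific:

```python
import asyncio


async def receive_response():
    """Stand-in for ``ClaudeSDKClient.receive_response()``: yields until
    the ResultMessage, then returns."""
    yield "AssistantMessage"
    yield "ResultMessage"  # receive_response() returns after this


async def main() -> None:
    stream = receive_response()
    async for _msg in stream:
        pass  # consume through the ResultMessage; the generator is done
    # The removed code reset acc.stream_completed and issued a new
    # query() here, then looped back to this same iterator. A finished
    # async generator cannot be restarted:
    try:
        await stream.__anext__()
    except StopAsyncIteration:
        print("injected turn's response would never be consumed")


asyncio.run(main())
```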
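The push/drain contract the Lua script guarantees, sketched against
redis-py's asyncio client. The key name and payload are illustrative
stand-ins, not the production ``_buffer_key`` values:

```python
import asyncio

from redis.asyncio import Redis

# Same four commands as the patch's _PUSH_LUA, executed by the server
# as one atomic unit: push, cap the list, refresh TTL, report length.
PUSH_LUA = """
redis.call('RPUSH', KEYS[1], ARGV[1])
redis.call('LTRIM', KEYS[1], -tonumber(ARGV[2]), -1)
redis.call('EXPIRE', KEYS[1], tonumber(ARGV[3]))
return redis.call('LLEN', KEYS[1])
"""


async def demo() -> None:
    r = Redis()
    key = "pending:demo-session"  # illustrative, not the real buffer key
    length = int(await r.eval(PUSH_LUA, 1, key, '{"content":"hi"}', "10", "3600"))
    # The drain side mirrors drain_pending_messages: one LPOP with a
    # count removes everything at once, so a racing push lands either
    # wholly before this call (drained now) or wholly after it (queued
    # for the next turn); there is no half-drained state.
    drained = await r.lpop(key, length) or []
    print(length, drained)


asyncio.run(demo())
```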
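How a client would consume the new response fields. The route path is a
placeholder (the real path is defined in ``api/features/chat/routes.py``
and is not visible in this diff):

```python
import httpx


async def queue_followup(base_url: str, session_id: str, text: str) -> None:
    # Placeholder path; substitute the actual pending-messages route.
    url = f"{base_url}/sessions/{session_id}/messages/pending"
    async with httpx.AsyncClient() as client:
        resp = await client.post(url, json={"message": text})
        resp.raise_for_status()  # 429 surfaces the shared rate limit
        body = resp.json()
        # turn_in_flight drives the UX copy: current turn vs. next turn.
        target = "current turn" if body["turn_in_flight"] else "next turn"
        print(f"queued for the {target} "
              f"({body['buffer_length']}/{body['max_buffer_length']})")
```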
Co-Authored-By: Claude Opus 4.6 (1M context)
---
 .../backend/api/features/chat/routes.py      | 101 ++++++++++++++----
 .../backend/copilot/baseline/service.py      |  61 +++++++----
 .../backend/copilot/pending_messages.py      |  48 ++++++---
 .../backend/copilot/pending_messages_test.py |  19 ++++
 .../backend/backend/copilot/sdk/service.py   |  92 +++++----------
 5 files changed, 199 insertions(+), 122 deletions(-)

diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py
index a1eebdd6e3..b2269b0964 100644
--- a/autogpt_platform/backend/backend/api/features/chat/routes.py
+++ b/autogpt_platform/backend/backend/api/features/chat/routes.py
@@ -32,6 +32,7 @@ from backend.copilot.model import (
 from backend.copilot.pending_messages import (
     MAX_PENDING_MESSAGES,
     PendingMessage,
+    format_pending_as_user_message,
     push_pending_message,
 )
 from backend.copilot.rate_limit import (
@@ -133,15 +134,28 @@ class QueuePendingMessageRequest(BaseModel):
     rounds.
     """
 
-    message: str = Field(min_length=1)
-    context: dict[str, str] | None = None
+    model_config = ConfigDict(extra="forbid")
+
+    message: str = Field(min_length=1, max_length=16_000)
+    context: dict[str, str] | None = Field(
+        default=None,
+        description="Optional page context: expected keys are 'url' and 'content'.",
+    )
     file_ids: list[str] | None = Field(default=None, max_length=20)
 
 
 class QueuePendingMessageResponse(BaseModel):
+    """Response for the pending-message endpoint.
+
+    Clients should rely on ``queued`` / ``buffer_length`` / ``turn_in_flight``
+    — the ``detail`` field is human-readable and may change without notice.
+    """
+
     queued: bool
     buffer_length: int
-    message: str
+    max_buffer_length: int
+    turn_in_flight: bool
+    detail: str
 
 
 class CreateSessionRequest(BaseModel):
@@ -1051,32 +1065,44 @@ async def queue_pending_message(
 
     When a user sends a follow-up message while a turn is still
     streaming, we don't want to block them or start a separate turn —
-    this endpoint appends the message to a per-session pending buffer
-    that the executor currently processing the turn will drain between
-    tool-call rounds, injecting it into the conversation before the
-    model's next LLM call.
+    this endpoint appends the message to a per-session pending buffer.
+    The executor currently running the turn (baseline path) drains the
+    buffer between tool-call rounds and appends the message to the
+    conversation before the next LLM call. On the SDK path the buffer
+    is drained at the *start* of the next turn (the long-lived
+    ``ClaudeSDKClient.receive_response`` iterator returns after a
+    ``ResultMessage`` so there is no safe point to inject mid-stream
+    into an existing connection).
 
-    Returns 202 with the new buffer length on success. If the buffer
-    is full (``MAX_PENDING_MESSAGES``), the oldest pending message is
-    evicted to make room for the new one — the newest message always
-    wins.
-
-    Intended for the frontend "send while streaming" flow. If no turn
-    is currently in flight the message is still queued — the next turn
-    the user starts will pick it up before its first LLM call.
+    Returns 202. Enforces the same per-user daily/weekly token rate
+    limit as the regular ``/stream`` endpoint so a client can't bypass
+    it by batching messages through here.
     """
     await _validate_and_get_session(session_id, user_id)
 
-    # Persist the message to the session immediately so it shows up in
-    # the transcript even before the executor drains the buffer.
-    chat_msg = ChatMessage(role="user", content=request.message)
+    # Pre-turn rate-limit check — mirrors stream_chat_post. Without
+    # this, a client could bypass per-turn token limits by batching
+    # their extra context through this endpoint while a cheap stream
+    # is in flight.
+    if user_id:
+        try:
+            daily_limit, weekly_limit, _tier = await get_global_rate_limits(
+                user_id, config.daily_token_limit, config.weekly_token_limit
+            )
+            await check_rate_limit(
+                user_id=user_id,
+                daily_token_limit=daily_limit,
+                weekly_token_limit=weekly_limit,
+            )
+        except RateLimitExceeded as e:
+            raise HTTPException(status_code=429, detail=str(e)) from e
+
     if user_id:
         track_user_message(
             user_id=user_id,
            session_id=session_id,
             message_length=len(request.message),
         )
-    await append_and_save_message(session_id, chat_msg)
 
     # Sanitise file IDs to the user's own workspace (same logic as
     # stream_chat_post) so injection doesn't surface other users' files.
@@ -1093,7 +1119,18 @@
             }
         )
         sanitized_file_ids = [wf.id for wf in files]
+        if len(sanitized_file_ids) != len(valid_ids):
+            logger.warning(
+                "queue_pending_message: dropped %d file id(s) not in "
+                "caller's workspace (session=%s)",
+                len(valid_ids) - len(sanitized_file_ids),
+                session_id,
+            )
 
+    # Push to Redis BEFORE writing to the session DB. If the push
+    # fails we raise 5xx and the client retries; ``append_and_save_message``
+    # would otherwise leave an orphan user message persisted with no
+    # corresponding queued pending entry, and a retry would duplicate it.
     pending = PendingMessage(
         content=request.message,
         file_ids=sanitized_file_ids,
@@ -1101,12 +1138,32 @@
     )
     buffer_length = await push_pending_message(session_id, pending)
 
+    # Persist the message into the session transcript only after the
+    # push succeeds. The message content embeds file/context metadata
+    # via format_pending_as_user_message so the DB copy matches what
+    # the model will actually see on drain.
+    chat_msg = ChatMessage(
+        role="user",
+        content=format_pending_as_user_message(pending)["content"],
+    )
+    await append_and_save_message(session_id, chat_msg)
+
+    # Check whether a turn is currently running for UX feedback.
+    active_session = await stream_registry.get_session(session_id)
+    turn_in_flight = bool(active_session and active_session.status == "running")
+
     return QueuePendingMessageResponse(
         queued=True,
         buffer_length=buffer_length,
-        message=(
-            f"Queued — will be injected into the current turn "
-            f"(buffer: {buffer_length}/{MAX_PENDING_MESSAGES})"
+        max_buffer_length=MAX_PENDING_MESSAGES,
+        turn_in_flight=turn_in_flight,
+        detail=(
+            (
+                "Queued — will be injected into the current turn."
+                if turn_in_flight
+                else "Queued — will be injected at the start of the next turn."
+            )
+            + f" buffer={buffer_length}/{MAX_PENDING_MESSAGES}"
         ),
     )
 
diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py
index 1658d93eb1..bb800c10c7 100644
--- a/autogpt_platform/backend/backend/copilot/baseline/service.py
+++ b/autogpt_platform/backend/backend/copilot/baseline/service.py
@@ -36,7 +36,6 @@ from backend.copilot.model import (
     upsert_chat_session,
 )
 from backend.copilot.pending_messages import (
-    clear_pending_messages,
     drain_pending_messages,
     format_pending_as_user_message,
 )
@@ -933,6 +932,23 @@ async def stream_chat_completion_baseline(
             message_length=len(message or ""),
         )
 
+    # Drain any messages the user queued via POST /messages/pending
+    # while this session was idle (or during a previous turn whose
+    # mid-loop drains missed them). Atomic LPOP guarantees that a
+    # concurrent push lands *after* the drain and stays queued for the
+    # next turn instead of being lost. Added to the session here so
+    # the initial LLM call sees them.
+    drained_at_start = await drain_pending_messages(session_id)
+    if drained_at_start:
+        logger.info(
+            "[Baseline] Draining %d pending message(s) at turn start "
+            "for session %s",
+            len(drained_at_start),
+            session_id,
+        )
+        for _pm in drained_at_start:
+            _content = format_pending_as_user_message(_pm)["content"]
+            maybe_append_user_message(session, _content, is_user_message=True)
+
     session = await upsert_chat_session(session)
 
     # Select model based on the per-request mode. 'fast' downgrades to
@@ -1168,16 +1184,32 @@
                 # Inject any messages the user queued while the turn was
                 # running. ``tool_call_loop`` mutates ``openai_messages``
                 # in-place, so appending here means the model sees the new
-                # messages before its next LLM call. Also persist them to
-                # the ChatSession so they're part of the durable transcript.
+                # messages on its next LLM call.
+                #
+                # IMPORTANT: skip when the loop has already finished (no
+                # more LLM calls are coming). Draining here would silently
+                # lose the message because ``tool_call_loop`` is about to
+                # return on the next ``async for`` step — the user would
+                # see a 202 from the pending endpoint but the model would
+                # never actually read the text. Those messages stay in
+                # the buffer and will be picked up at the start of the
+                # next turn.
+                if loop_result is None or loop_result.finished_naturally:
+                    continue
                 pending = await drain_pending_messages(session_id)
                 if pending:
                     for pm in pending:
+                        # ``format_pending_as_user_message`` embeds file
+                        # attachments and context URL/page content into the
+                        # content string so the in-session transcript is
+                        # a faithful copy of what the model actually saw.
+                        formatted = format_pending_as_user_message(pm)
+                        content_for_db = formatted["content"]
                         maybe_append_user_message(
-                            session, pm.content, is_user_message=True
+                            session, content_for_db, is_user_message=True
                         )
-                        openai_messages.append(format_pending_as_user_message(pm))
-                        transcript_builder.append_user(content=pm.content)
+                        openai_messages.append(formatted)
+                        transcript_builder.append_user(content=content_for_db)
                     try:
                         await upsert_chat_session(session)
                     except Exception as persist_err:
@@ -1234,19 +1266,10 @@
             yield StreamError(errorText=error_msg, code="baseline_error")
             # Still persist whatever we got
     finally:
-        # Safety net — if the stream exited early (error, timeout, etc.)
-        # we may still have queued pending messages in the buffer. Drop
Drop - # them so they don't leak into the next turn. During normal - # completion the tool-call loop drain will already have cleared - # the buffer, so this is a no-op in the happy path. - try: - await clear_pending_messages(session_id) - except Exception as clear_err: - logger.warning( - "[Baseline] Failed to clear pending messages for %s: %s", - session_id, - clear_err, - ) + # Pending messages are drained atomically at turn start and + # between tool rounds, so there's nothing to clear in finally. + # Any message pushed after the final drain window stays in the + # buffer and gets picked up at the start of the next turn. # Set cost attributes on OTEL span before closing if _trace_ctx is not None: diff --git a/autogpt_platform/backend/backend/copilot/pending_messages.py b/autogpt_platform/backend/backend/copilot/pending_messages.py index 0930a87e2d..ea0ae6bc4c 100644 --- a/autogpt_platform/backend/backend/copilot/pending_messages.py +++ b/autogpt_platform/backend/backend/copilot/pending_messages.py @@ -23,7 +23,6 @@ buffer is trimmed to the latest ``MAX_PENDING_MESSAGES`` on every push. import json import logging -import time from typing import Any, cast from pydantic import BaseModel, Field @@ -49,11 +48,9 @@ _PENDING_TTL_SECONDS = 3600 # 1 hour — matches stream_ttl default class PendingMessage(BaseModel): """A user message queued for injection into an in-flight turn.""" - content: str = Field(min_length=1) + content: str = Field(min_length=1, max_length=16_000) file_ids: list[str] = Field(default_factory=list) context: dict[str, str] | None = None - # Unix epoch seconds at enqueue time, for ordering and debugging. - enqueued_at: float = Field(default_factory=time.time) def _buffer_key(session_id: str) -> str: @@ -64,31 +61,50 @@ def _notify_channel(session_id: str) -> str: return f"{_PENDING_CHANNEL_PREFIX}{session_id}" +# Lua script: push-then-trim-then-expire-then-length, atomically. +# Running these four commands via a single EVAL guarantees a concurrent +# LPOP drain lands either entirely before the push (returns 0 from +# our earlier LLEN) or entirely after it (sees the new message) — +# never in the middle of a partial state. +_PUSH_LUA = """ +redis.call('RPUSH', KEYS[1], ARGV[1]) +redis.call('LTRIM', KEYS[1], -tonumber(ARGV[2]), -1) +redis.call('EXPIRE', KEYS[1], tonumber(ARGV[3])) +return redis.call('LLEN', KEYS[1]) +""" + + async def push_pending_message( session_id: str, message: PendingMessage, ) -> int: - """Append a pending message to the session's buffer. + """Append a pending message to the session's buffer atomically. Returns the new buffer length. Enforces ``MAX_PENDING_MESSAGES`` by trimming from the left (oldest) — the newest message always wins if the user has been typing faster than the copilot can drain. + + The push + trim + expire + llen are wrapped in a single Lua EVAL so + concurrent LPOP drains from the executor never observe a partial + state. """ redis = await get_redis_async() key = _buffer_key(session_id) payload = message.model_dump_json() - # Push + trim + expire in a pipeline so the three writes land atomically - # enough for this use case (pipelining doesn't guarantee atomicity - # across commands but ordering is preserved). 
-    async with redis.pipeline(transaction=False) as pipe:
-        pipe.rpush(key, payload)
-        pipe.ltrim(key, -MAX_PENDING_MESSAGES, -1)
-        pipe.expire(key, _PENDING_TTL_SECONDS)
-        pipe.llen(key)
-        results = await pipe.execute()
-
-    new_length = int(results[-1])
+    new_length = int(
+        await cast(
+            "Any",
+            redis.eval(
+                _PUSH_LUA,
+                1,
+                key,
+                payload,
+                str(MAX_PENDING_MESSAGES),
+                str(_PENDING_TTL_SECONDS),
+            ),
+        )
+    )
 
     # Fire-and-forget notify. Subscribers use this as a wake-up hint;
     # the buffer itself is authoritative so a lost notify is harmless.
diff --git a/autogpt_platform/backend/backend/copilot/pending_messages_test.py b/autogpt_platform/backend/backend/copilot/pending_messages_test.py
index b03906f52a..7fec16c708 100644
--- a/autogpt_platform/backend/backend/copilot/pending_messages_test.py
+++ b/autogpt_platform/backend/backend/copilot/pending_messages_test.py
@@ -79,6 +79,25 @@ class _FakeRedis:
     def pipeline(self, transaction: bool = False) -> _FakePipeline:
         return _FakePipeline(self)
 
+    async def eval(self, script: str, num_keys: int, *args: Any) -> Any:
+        """Emulate the push Lua script.
+
+        The real Lua script runs atomically in Redis; the fake
+        implementation just runs the equivalent list operations in
+        order and returns the final LLEN. That's enough to exercise
+        the cap + ordering invariants the tests care about.
+        """
+        key = args[0]
+        payload = args[1]
+        max_len = int(args[2])
+        # ARGV[3] is TTL — fake doesn't enforce expiry
+        lst = self.lists.setdefault(key, [])
+        lst.append(payload)
+        if len(lst) > max_len:
+            # RPUSH + LTRIM(-N, -1) = keep only last N
+            self.lists[key] = lst[-max_len:]
+        return len(self.lists[key])
+
     async def publish(self, channel: str, payload: str) -> int:
         self.published.append((channel, payload))
         return 1
diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py
index feaaabe0ce..7d13b24925 100644
--- a/autogpt_platform/backend/backend/copilot/sdk/service.py
+++ b/autogpt_platform/backend/backend/copilot/sdk/service.py
@@ -35,9 +35,8 @@ from pydantic import BaseModel
 
 from backend.copilot.context import get_workspace_manager
 from backend.copilot.pending_messages import (
-    PendingMessage,
-    clear_pending_messages,
     drain_pending_messages,
+    format_pending_as_user_message,
 )
 from backend.copilot.permissions import apply_tool_permissions
 from backend.copilot.rate_limit import get_user_tier
@@ -218,25 +217,6 @@ def _is_prompt_too_long(err: BaseException) -> bool:
     return False
 
 
-def _combine_pending_messages(pending: list[PendingMessage]) -> str:
-    """Merge drained pending messages into a single user-message body.
-
-    The Claude Agent SDK's ``client.query()`` takes a plain string (or
-    an async iterable); the simplest way to preserve ordering across
-    multiple drained messages is to concatenate them with a separator
-    and send a single ``query()`` call. If there's only one message,
-    its ``content`` is returned verbatim so the transcript stays clean.
-    """
-    if len(pending) == 1:
-        return pending[0].content
-    parts: list[str] = []
-    for idx, msg in enumerate(pending, start=1):
-        header = f"[Additional message {idx}]" if idx > 1 else ""
-        body = msg.content
-        parts.append(f"{header}\n{body}".lstrip("\n") if header else body)
-    return "\n\n".join(parts)
-
-
 def _is_sdk_disconnect_error(exc: BaseException) -> bool:
     """Return True if *exc* is an expected SDK cleanup error from client
     disconnect.
@@ -1808,39 +1788,6 @@
                     _msgs_since_flush = 0
 
                 if acc.stream_completed:
-                    # Before exiting the loop, check if the user queued any
-                    # follow-up messages while this turn was running. If so,
-                    # send them to the same live SDK client as a new query
-                    # and reset the stream completion state so we keep
-                    # consuming CLI messages. This avoids releasing the
-                    # cluster lock and requeueing — the pending messages
-                    # flow directly into the existing conversation.
-                    pending = await drain_pending_messages(ctx.session_id)
-                    if pending:
-                        logger.info(
-                            "%s Injecting %d pending message(s) mid-turn",
-                            ctx.log_prefix,
-                            len(pending),
-                        )
-                        injected_text = _combine_pending_messages(pending)
-                        injected_chat_msg = ChatMessage(role="user", content=injected_text)
-                        ctx.session.messages.append(injected_chat_msg)
-                        state.transcript_builder.append_user(content=injected_text)
-                        try:
-                            await asyncio.shield(upsert_chat_session(ctx.session))
-                        except Exception as persist_err:
-                            logger.warning(
-                                "%s Failed to persist injected pending message: %s",
-                                ctx.log_prefix,
-                                persist_err,
-                            )
-                        await client.query(injected_text, session_id=ctx.session_id)
-                        # Reset turn-level state so the next ResultMessage
-                        # ends the injected turn cleanly instead of
-                        # re-completing the previous one.
-                        acc.stream_completed = False
-                        _last_real_msg_time = time.monotonic()
-                        continue
                     break
     finally:
         await _safe_close_sdk_client(sdk_client, ctx.log_prefix)
@@ -2328,6 +2275,28 @@ async def stream_chat_completion_sdk(
     if last_user:
         current_message = last_user[-1].content or ""
 
+    # Drain any messages the user queued via POST /messages/pending
+    # while the previous turn was running (or since the session was
+    # idle). Messages are drained ATOMICALLY — one LPOP with count
+    # removes them all at once, so a concurrent push lands *after*
+    # the drain and stays queued for the next turn instead of being
+    # lost between LPOP and clear. File IDs and context are
+    # preserved via format_pending_as_user_message.
+    pending_at_start = await drain_pending_messages(session_id)
+    if pending_at_start:
+        logger.info(
+            "%s Draining %d pending message(s) at turn start",
+            log_prefix,
+            len(pending_at_start),
+        )
+        pending_texts: list[str] = [
+            format_pending_as_user_message(pm)["content"] for pm in pending_at_start
+        ]
+        if current_message.strip():
+            current_message = current_message + "\n\n" + "\n\n".join(pending_texts)
+        else:
+            current_message = "\n\n".join(pending_texts)
+
     if not current_message.strip():
         yield StreamError(
             errorText="Message cannot be empty.",
@@ -2783,17 +2752,10 @@
         raise
 
     finally:
-        # Safety net — drop any pending messages still in the buffer.
-        # During normal completion the mid-turn drain already cleared
-        # them; this handles early exits (errors, cancellation, retry).
-        try:
-            await clear_pending_messages(session_id)
-        except Exception as _clear_err:
-            logger.warning(
-                "Failed to clear pending messages for %s: %s",
-                session_id,
-                _clear_err,
-            )
+        # Pending messages are drained atomically at the start of each
+        # turn (see drain_pending_messages call above), so there's
+        # nothing to clean up here — any message pushed after that
+        # point belongs to the next turn.
 
         # --- Close OTEL context (with cost attributes) ---
         if _otel_ctx is not None:
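
Reviewer note: a quick sanity check that the ``_FakeRedis.eval``
emulation preserves the "newest wins" cap the Lua script enforces.
This assumes ``_FakeRedis()`` constructs with no arguments and that the
test module is importable as shown:

```python
import asyncio

from backend.copilot.pending_messages_test import _FakeRedis


async def check_newest_wins() -> None:
    fake = _FakeRedis()
    # Push 15 payloads through the emulated script with a cap of 10.
    for i in range(15):
        length = await fake.eval("<script body unused>", 1, "k", f"m{i}", "10", "3600")
        assert length <= 10
    # LTRIM(-10, -1) semantics: the five oldest entries were evicted.
    assert fake.lists["k"] == [f"m{i}" for i in range(5, 15)]


asyncio.run(check_newest_wins())
```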