fix(copilot): address round 1 review on pending-messages feature

Critical fix — the SDK mid-stream injection was structurally broken.
``ClaudeSDKClient.receive_response()`` explicitly returns after the
first ``ResultMessage``, so re-issuing ``client.query()`` and setting
``acc.stream_completed = False`` could never restart the iteration —
the next ``__anext__`` raised ``StopAsyncIteration`` and the injected
turn's response was never consumed.  Replaced the broken mid-stream
path with a turn-start drain that works for both baseline and SDK.
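
A minimal sketch of the failure mode, assuming only the public
``claude_agent_sdk`` surface named above (the handler bodies are
placeholders, not service code):

```python
from claude_agent_sdk import ClaudeSDKClient


async def demo() -> None:
    async with ClaudeSDKClient() as client:
        await client.query("first message")
        async for msg in client.receive_response():
            ...  # receive_response() returns after yielding the ResultMessage
        # The iterator above is exhausted here. Re-issuing client.query(...)
        # and clearing a caller-side "stream completed" flag cannot revive
        # it; the follow-up's reply is only seen by a fresh iterator.
        await client.query("injected follow-up")
        async for msg in client.receive_response():  # new iterator required
            ...
```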

### Changes

**Atomic push via Lua EVAL** (``pending_messages.py``)
- Replace the ``RPUSH`` + ``LTRIM`` + ``EXPIRE`` + ``LLEN`` pipeline
  (which was ``transaction=False`` and racy against concurrent
  ``LPOP``) with a single Lua script so the push is atomic.
- Drop the unused ``enqueued_at`` field.
- Add 16k ``max_length`` cap on ``PendingMessage.content``.

**Baseline path** (``baseline/service.py``)
- Drain at turn start (atomic ``LPOP``): any message queued while the
  session was idle or between turns is picked up before the first
  LLM call (see the drain sketch after this list).
- Mid-loop drain now skips the final ``tool_call_loop`` yield
  (``finished_naturally=True``) — draining there would append a user
  message the loop is about to exit past, silently losing it.
- Inject via ``format_pending_as_user_message`` so file IDs + context
  are preserved in both ``openai_messages`` and the persisted session
  transcript (previously the DB copy lost file/context metadata).
- Remove the ``finally`` ``clear_pending_messages`` — atomic drain at
  turn start means any late push belongs to the next turn; clearing
  here would racily clobber it.
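
The drain helper itself is outside this diff; a rough sketch of its
assumed shape (names taken from the imports in the hunks below, body
illustrative) shows why LPOP-with-count is race-free: a concurrent push
lands either wholly before the pop (and is returned) or wholly after it
(and stays queued for the next turn).

```python
async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
    # One LPOP with a count empties the buffer in a single atomic Redis
    # command (redis-py accepts lpop(key, count) and returns a list).
    redis = await get_redis_async()
    raw = await redis.lpop(_buffer_key(session_id), MAX_PENDING_MESSAGES)
    if not raw:
        return []
    return [PendingMessage.model_validate_json(item) for item in raw]
```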

**SDK path** (``sdk/service.py``)
- Remove the broken mid-stream injection block entirely.
- Drain at turn start (same atomic ``LPOP``) and merge the drained
  messages into ``current_message`` before ``_build_query_message``,
  so the SDK CLI sees them as part of the initial user message.
- Remove the ``finally`` ``clear_pending_messages``.
- Delete the unused ``_combine_pending_messages`` helper.

**Endpoint** (``api/features/chat/routes.py``)
- Enforce ``check_rate_limit`` / ``get_global_rate_limits`` — was
  bypassing per-user daily/weekly token limits that ``/stream``
  enforces.
- ``QueuePendingMessageRequest`` gets ``extra="forbid"`` and
  ``message: max_length=16_000``.
- Push-first, persist-second: if the Redis push fails we raise 5xx;
  previously the session DB got an orphan user message with no
  corresponding queued entry and a retry would duplicate it.
- Log a warning when file-ID sanitisation drops entries that aren't in
  the caller's workspace.
- Persisted message content now uses ``format_pending_as_user_message``
  so the session copy matches what the model actually sees on drain
  (sketched after this list).
- Response returns ``buffer_length``, ``max_buffer_length``, and
  ``turn_in_flight`` so the frontend can show accurate feedback about
  whether the message will hit the current turn or the next one.
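
``format_pending_as_user_message`` is defined outside this diff. A
hypothetical sketch of the shape the endpoint and both executors rely
on (an OpenAI-style user message whose content embeds the metadata; the
exact marker strings are assumptions):

```python
def format_pending_as_user_message(pm: PendingMessage) -> dict[str, str]:
    # Fold file and page-context metadata into the content string so the
    # persisted transcript matches what the model sees on drain.
    parts = [pm.content]
    if pm.file_ids:
        parts.append("[Attached files: " + ", ".join(pm.file_ids) + "]")
    if pm.context and pm.context.get("url"):
        parts.append(f"[Page context: {pm.context['url']}]")
    return {"role": "user", "content": "\n\n".join(parts)}
```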

**Tests** (``pending_messages_test.py``)
- ``_FakeRedis.eval`` emulates the Lua push script so the existing
  push/drain/cap tests keep working under the new atomic path.
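
A hypothetical test in the suite's style (test name assumed) showing the
invariant the fake keeps checkable:

```python
import pytest


@pytest.mark.asyncio
async def test_fake_eval_enforces_cap() -> None:
    fake = _FakeRedis()
    # Push one message more than the cap; eval mirrors
    # RPUSH + LTRIM(-N, -1), i.e. keep only the newest N entries.
    for i in range(MAX_PENDING_MESSAGES + 1):
        length = await fake.eval(
            _PUSH_LUA, 1, "buf", f"m{i}", str(MAX_PENDING_MESSAGES), "3600"
        )
    assert length == MAX_PENDING_MESSAGES
    assert fake.lists["buf"][0] == "m1"  # oldest ("m0") was evicted
```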

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

#### ``api/features/chat/routes.py``

```diff
@@ -32,6 +32,7 @@ from backend.copilot.model import (
 from backend.copilot.pending_messages import (
     MAX_PENDING_MESSAGES,
     PendingMessage,
+    format_pending_as_user_message,
     push_pending_message,
 )
 from backend.copilot.rate_limit import (
@@ -133,15 +134,28 @@ class QueuePendingMessageRequest(BaseModel):
     rounds.
     """

-    message: str = Field(min_length=1)
-    context: dict[str, str] | None = None
+    model_config = ConfigDict(extra="forbid")
+
+    message: str = Field(min_length=1, max_length=16_000)
+    context: dict[str, str] | None = Field(
+        default=None,
+        description="Optional page context: expected keys are 'url' and 'content'.",
+    )
+    file_ids: list[str] | None = Field(default=None, max_length=20)


 class QueuePendingMessageResponse(BaseModel):
-    """Response for the pending-message endpoint."""
+    """Response for the pending-message endpoint.
+
+    Clients should rely on ``queued`` / ``buffer_length`` / ``turn_in_flight``
+    — the ``detail`` field is human-readable and may change without notice.
+    """

     queued: bool
     buffer_length: int
-    message: str
+    max_buffer_length: int
+    turn_in_flight: bool
+    detail: str


 class CreateSessionRequest(BaseModel):
@@ -1051,32 +1065,44 @@ async def queue_pending_message(
     When a user sends a follow-up message while a turn is still
     streaming, we don't want to block them or start a separate turn —
-    this endpoint appends the message to a per-session pending buffer
-    that the executor currently processing the turn will drain between
-    tool-call rounds, injecting it into the conversation before the
-    model's next LLM call.
+    this endpoint appends the message to a per-session pending buffer.
+    The executor currently running the turn (baseline path) drains the
+    buffer between tool-call rounds and appends the message to the
+    conversation before the next LLM call. On the SDK path the buffer
+    is drained at the *start* of the next turn (the long-lived
+    ``ClaudeSDKClient.receive_response`` iterator returns after a
+    ``ResultMessage`` so there is no safe point to inject mid-stream
+    into an existing connection).

-    Returns 202 with the new buffer length on success. If the buffer
-    is full (``MAX_PENDING_MESSAGES``), the oldest pending message is
-    evicted to make room for the new one — the newest message always
-    wins.
+    Intended for the frontend "send while streaming" flow. If no turn
+    is currently in flight the message is still queued — the next turn
+    the user starts will pick it up before its first LLM call.
+
+    Returns 202. Enforces the same per-user daily/weekly token rate
+    limit as the regular ``/stream`` endpoint so a client can't bypass
+    it by batching messages through here.
     """
     await _validate_and_get_session(session_id, user_id)

-    # Persist the message to the session immediately so it shows up in
-    # the transcript even before the executor drains the buffer.
-    chat_msg = ChatMessage(role="user", content=request.message)
+    # Pre-turn rate-limit check — mirrors stream_chat_post. Without
+    # this, a client could bypass per-turn token limits by batching
+    # their extra context through this endpoint while a cheap stream
+    # is in flight.
+    if user_id:
+        try:
+            daily_limit, weekly_limit, _tier = await get_global_rate_limits(
+                user_id, config.daily_token_limit, config.weekly_token_limit
+            )
+            await check_rate_limit(
+                user_id=user_id,
+                daily_token_limit=daily_limit,
+                weekly_token_limit=weekly_limit,
+            )
+        except RateLimitExceeded as e:
+            raise HTTPException(status_code=429, detail=str(e)) from e

     if user_id:
         track_user_message(
             user_id=user_id,
             session_id=session_id,
             message_length=len(request.message),
         )
-    await append_and_save_message(session_id, chat_msg)

     # Sanitise file IDs to the user's own workspace (same logic as
     # stream_chat_post) so injection doesn't surface other users' files.
@@ -1093,7 +1119,18 @@ async def queue_pending_message(
             }
         )
         sanitized_file_ids = [wf.id for wf in files]
+        if len(sanitized_file_ids) != len(valid_ids):
+            logger.warning(
+                "queue_pending_message: dropped %d file id(s) not in "
+                "caller's workspace (session=%s)",
+                len(valid_ids) - len(sanitized_file_ids),
+                session_id,
+            )

+    # Push to Redis BEFORE writing to the session DB. If the push
+    # fails we raise 5xx and the client retries; ``append_and_save_message``
+    # would otherwise leave an orphan user message persisted with no
+    # corresponding queued pending entry, and a retry would duplicate it.
     pending = PendingMessage(
         content=request.message,
         file_ids=sanitized_file_ids,
@@ -1101,12 +1138,32 @@ async def queue_pending_message(
     )
     buffer_length = await push_pending_message(session_id, pending)

+    # Persist the message into the session transcript only after the
+    # push succeeds. The message content embeds file/context metadata
+    # via format_pending_as_user_message so the DB copy matches what
+    # the model will actually see on drain.
+    chat_msg = ChatMessage(
+        role="user",
+        content=format_pending_as_user_message(pending)["content"],
+    )
+    await append_and_save_message(session_id, chat_msg)
+
+    # Check whether a turn is currently running for UX feedback.
+    active_session = await stream_registry.get_session(session_id)
+    turn_in_flight = bool(active_session and active_session.status == "running")

     return QueuePendingMessageResponse(
         queued=True,
         buffer_length=buffer_length,
-        message=(
-            f"Queued — will be injected into the current turn "
-            f"(buffer: {buffer_length}/{MAX_PENDING_MESSAGES})"
-        ),
+        max_buffer_length=MAX_PENDING_MESSAGES,
+        turn_in_flight=turn_in_flight,
+        detail=(
+            (
+                "Queued — will be injected into the current turn."
+                if turn_in_flight
+                else "Queued — will be injected at the start of the next turn."
+            )
+            + f" buffer={buffer_length}/{MAX_PENDING_MESSAGES}"
+        ),
     )
```

#### ``baseline/service.py``

```diff
@@ -36,7 +36,6 @@ from backend.copilot.model import (
     upsert_chat_session,
 )
 from backend.copilot.pending_messages import (
-    clear_pending_messages,
     drain_pending_messages,
     format_pending_as_user_message,
 )
@@ -933,6 +932,23 @@ async def stream_chat_completion_baseline(
             message_length=len(message or ""),
         )

+    # Drain any messages the user queued via POST /messages/pending
+    # while this session was idle (or during a previous turn whose
+    # mid-loop drains missed them). Atomic LPOP guarantees that a
+    # concurrent push lands *after* the drain and stays queued for the
+    # next turn instead of being lost. Prepended to the session so
+    # the initial LLM call sees them.
+    drained_at_start = await drain_pending_messages(session_id)
+    if drained_at_start:
+        logger.info(
+            "[Baseline] Draining %d pending message(s) at turn start for session %s",
+            len(drained_at_start),
+            session_id,
+        )
+        for _pm in drained_at_start:
+            _content = format_pending_as_user_message(_pm)["content"]
+            maybe_append_user_message(session, _content, is_user_message=True)
+        session = await upsert_chat_session(session)

     # Select model based on the per-request mode. 'fast' downgrades to
@@ -1168,16 +1184,32 @@ async def stream_chat_completion_baseline(
                 # Inject any messages the user queued while the turn was
                 # running. ``tool_call_loop`` mutates ``openai_messages``
                 # in-place, so appending here means the model sees the new
-                # messages before its next LLM call. Also persist them to
-                # the ChatSession so they're part of the durable transcript.
+                # messages on its next LLM call.
+                #
+                # IMPORTANT: skip when the loop has already finished (no
+                # more LLM calls are coming). Draining here would silently
+                # lose the message because ``tool_call_loop`` is about to
+                # return on the next ``async for`` step — the user would
+                # see a 202 from the pending endpoint but the model would
+                # never actually read the text. Those messages stay in
+                # the buffer and will be picked up at the start of the
+                # next turn.
+                if loop_result is None or loop_result.finished_naturally:
+                    continue
                 pending = await drain_pending_messages(session_id)
                 if pending:
                     for pm in pending:
+                        # ``format_pending_as_user_message`` embeds file
+                        # attachments and context URL/page content into the
+                        # content string so the in-session transcript is
+                        # a faithful copy of what the model actually saw.
+                        formatted = format_pending_as_user_message(pm)
+                        content_for_db = formatted["content"]
                         maybe_append_user_message(
-                            session, pm.content, is_user_message=True
+                            session, content_for_db, is_user_message=True
                         )
-                        openai_messages.append(format_pending_as_user_message(pm))
-                        transcript_builder.append_user(content=pm.content)
+                        openai_messages.append(formatted)
+                        transcript_builder.append_user(content=content_for_db)
                     try:
                         await upsert_chat_session(session)
                     except Exception as persist_err:
@@ -1234,19 +1266,10 @@ async def stream_chat_completion_baseline(
             yield StreamError(errorText=error_msg, code="baseline_error")
             # Still persist whatever we got
     finally:
-        # Safety net — if the stream exited early (error, timeout, etc.)
-        # we may still have queued pending messages in the buffer. Drop
-        # them so they don't leak into the next turn. During normal
-        # completion the tool-call loop drain will already have cleared
-        # the buffer, so this is a no-op in the happy path.
-        try:
-            await clear_pending_messages(session_id)
-        except Exception as clear_err:
-            logger.warning(
-                "[Baseline] Failed to clear pending messages for %s: %s",
-                session_id,
-                clear_err,
-            )
+        # Pending messages are drained atomically at turn start and
+        # between tool rounds, so there's nothing to clear in finally.
+        # Any message pushed after the final drain window stays in the
+        # buffer and gets picked up at the start of the next turn.

         # Set cost attributes on OTEL span before closing
         if _trace_ctx is not None:
```

#### ``pending_messages.py``

```diff
@@ -23,7 +23,6 @@ buffer is trimmed to the latest ``MAX_PENDING_MESSAGES`` on every push.

 import json
 import logging
-import time
 from typing import Any, cast

 from pydantic import BaseModel, Field
@@ -49,11 +48,9 @@ _PENDING_TTL_SECONDS = 3600  # 1 hour — matches stream_ttl default
 class PendingMessage(BaseModel):
     """A user message queued for injection into an in-flight turn."""

-    content: str = Field(min_length=1)
+    content: str = Field(min_length=1, max_length=16_000)
     file_ids: list[str] = Field(default_factory=list)
     context: dict[str, str] | None = None
-    # Unix epoch seconds at enqueue time, for ordering and debugging.
-    enqueued_at: float = Field(default_factory=time.time)


 def _buffer_key(session_id: str) -> str:
@@ -64,31 +61,50 @@ def _notify_channel(session_id: str) -> str:
     return f"{_PENDING_CHANNEL_PREFIX}{session_id}"


+# Lua script: push-then-trim-then-expire-then-length, atomically.
+# Running these four commands via a single EVAL guarantees a concurrent
+# LPOP drain lands either entirely before the push (returns 0 from
+# our earlier LLEN) or entirely after it (sees the new message) —
+# never in the middle of a partial state.
+_PUSH_LUA = """
+redis.call('RPUSH', KEYS[1], ARGV[1])
+redis.call('LTRIM', KEYS[1], -tonumber(ARGV[2]), -1)
+redis.call('EXPIRE', KEYS[1], tonumber(ARGV[3]))
+return redis.call('LLEN', KEYS[1])
+"""


 async def push_pending_message(
     session_id: str,
     message: PendingMessage,
 ) -> int:
-    """Append a pending message to the session's buffer.
+    """Append a pending message to the session's buffer atomically.

     Returns the new buffer length. Enforces ``MAX_PENDING_MESSAGES`` by
     trimming from the left (oldest) — the newest message always wins if
     the user has been typing faster than the copilot can drain.
+
+    The push + trim + expire + llen are wrapped in a single Lua EVAL so
+    concurrent LPOP drains from the executor never observe a partial
+    state.
     """
     redis = await get_redis_async()
     key = _buffer_key(session_id)
     payload = message.model_dump_json()

-    # Push + trim + expire in a pipeline so the three writes land atomically
-    # enough for this use case (pipelining doesn't guarantee atomicity
-    # across commands but ordering is preserved).
-    async with redis.pipeline(transaction=False) as pipe:
-        pipe.rpush(key, payload)
-        pipe.ltrim(key, -MAX_PENDING_MESSAGES, -1)
-        pipe.expire(key, _PENDING_TTL_SECONDS)
-        pipe.llen(key)
-        results = await pipe.execute()
-    new_length = int(results[-1])
+    new_length = int(
+        await cast(
+            "Any",
+            redis.eval(
+                _PUSH_LUA,
+                1,
+                key,
+                payload,
+                str(MAX_PENDING_MESSAGES),
+                str(_PENDING_TTL_SECONDS),
+            ),
+        )
+    )

     # Fire-and-forget notify. Subscribers use this as a wake-up hint;
     # the buffer itself is authoritative so a lost notify is harmless.
```

#### ``pending_messages_test.py``

```diff
@@ -79,6 +79,25 @@ class _FakeRedis:
     def pipeline(self, transaction: bool = False) -> _FakePipeline:
         return _FakePipeline(self)

+    async def eval(self, script: str, num_keys: int, *args: Any) -> Any:
+        """Emulate the push Lua script.
+
+        The real Lua script runs atomically in Redis; the fake
+        implementation just runs the equivalent list operations in
+        order and returns the final LLEN. That's enough to exercise
+        the cap + ordering invariants the tests care about.
+        """
+        key = args[0]
+        payload = args[1]
+        max_len = int(args[2])
+        # ARGV[3] is TTL — fake doesn't enforce expiry
+        lst = self.lists.setdefault(key, [])
+        lst.append(payload)
+        if len(lst) > max_len:
+            # RPUSH + LTRIM(-N, -1) = keep only last N
+            self.lists[key] = lst[-max_len:]
+        return len(self.lists[key])
+
     async def publish(self, channel: str, payload: str) -> int:
         self.published.append((channel, payload))
         return 1
```

#### ``sdk/service.py``

```diff
@@ -35,9 +35,8 @@ from pydantic import BaseModel
 from backend.copilot.context import get_workspace_manager
 from backend.copilot.pending_messages import (
     PendingMessage,
-    clear_pending_messages,
     drain_pending_messages,
     format_pending_as_user_message,
 )
 from backend.copilot.permissions import apply_tool_permissions
 from backend.copilot.rate_limit import get_user_tier
@@ -218,25 +217,6 @@ def _is_prompt_too_long(err: BaseException) -> bool:
     return False


-def _combine_pending_messages(pending: list[PendingMessage]) -> str:
-    """Merge drained pending messages into a single user-message body.
-
-    The Claude Agent SDK's ``client.query()`` takes a plain string (or
-    an async iterable); the simplest way to preserve ordering across
-    multiple drained messages is to concatenate them with a separator
-    and send a single ``query()`` call. If there's only one message,
-    its ``content`` is returned verbatim so the transcript stays clean.
-    """
-    if len(pending) == 1:
-        return pending[0].content
-    parts: list[str] = []
-    for idx, msg in enumerate(pending, start=1):
-        header = f"[Additional message {idx}]" if idx > 1 else ""
-        body = msg.content
-        parts.append(f"{header}\n{body}".lstrip("\n") if header else body)
-    return "\n\n".join(parts)
-
-
 def _is_sdk_disconnect_error(exc: BaseException) -> bool:
     """Return True if *exc* is an expected SDK cleanup error from client disconnect.
@@ -1808,39 +1788,6 @@ async def _run_stream_attempt(
                     _msgs_since_flush = 0

                 if acc.stream_completed:
-                    # Before exiting the loop, check if the user queued any
-                    # follow-up messages while this turn was running. If so,
-                    # send them to the same live SDK client as a new query
-                    # and reset the stream completion state so we keep
-                    # consuming CLI messages. This avoids releasing the
-                    # cluster lock and requeueing — the pending messages
-                    # flow directly into the existing conversation.
-                    pending = await drain_pending_messages(ctx.session_id)
-                    if pending:
-                        logger.info(
-                            "%s Injecting %d pending message(s) mid-turn",
-                            ctx.log_prefix,
-                            len(pending),
-                        )
-                        injected_text = _combine_pending_messages(pending)
-                        injected_chat_msg = ChatMessage(role="user", content=injected_text)
-                        ctx.session.messages.append(injected_chat_msg)
-                        state.transcript_builder.append_user(content=injected_text)
-                        try:
-                            await asyncio.shield(upsert_chat_session(ctx.session))
-                        except Exception as persist_err:
-                            logger.warning(
-                                "%s Failed to persist injected pending message: %s",
-                                ctx.log_prefix,
-                                persist_err,
-                            )
-                        await client.query(injected_text, session_id=ctx.session_id)
-                        # Reset turn-level state so the next ResultMessage
-                        # ends the injected turn cleanly instead of
-                        # re-completing the previous one.
-                        acc.stream_completed = False
-                        _last_real_msg_time = time.monotonic()
-                        continue
                     break
     finally:
         await _safe_close_sdk_client(sdk_client, ctx.log_prefix)
@@ -2328,6 +2275,28 @@ async def stream_chat_completion_sdk(
     if last_user:
         current_message = last_user[-1].content or ""

+    # Drain any messages the user queued via POST /messages/pending
+    # while the previous turn was running (or since the session was
+    # idle). Messages are drained ATOMICALLY — one LPOP with count
+    # removes them all at once, so a concurrent push lands *after*
+    # the drain and stays queued for the next turn instead of being
+    # lost between LPOP and clear. File IDs and context are
+    # preserved via format_pending_as_user_message.
+    pending_at_start = await drain_pending_messages(session_id)
+    if pending_at_start:
+        logger.info(
+            "%s Draining %d pending message(s) at turn start",
+            log_prefix,
+            len(pending_at_start),
+        )
+        pending_texts: list[str] = [
+            format_pending_as_user_message(pm)["content"] for pm in pending_at_start
+        ]
+        if current_message.strip():
+            current_message = current_message + "\n\n" + "\n\n".join(pending_texts)
+        else:
+            current_message = "\n\n".join(pending_texts)

     if not current_message.strip():
         yield StreamError(
             errorText="Message cannot be empty.",
@@ -2783,17 +2752,10 @@ async def stream_chat_completion_sdk(
         raise
     finally:
-        # Safety net — drop any pending messages still in the buffer.
-        # During normal completion the mid-turn drain already cleared
-        # them; this handles early exits (errors, cancellation, retry).
-        try:
-            await clear_pending_messages(session_id)
-        except Exception as _clear_err:
-            logger.warning(
-                "Failed to clear pending messages for %s: %s",
-                session_id,
-                _clear_err,
-            )
+        # Pending messages are drained atomically at the start of each
+        # turn (see drain_pending_messages call above), so there's
+        # nothing to clean up here — any message pushed after that
+        # point belongs to the next turn.

         # --- Close OTEL context (with cost attributes) ---
         if _otel_ctx is not None:
```