fix(copilot): address round 2 review on pending-messages feature

Critical: the SDK path was double-injecting. The endpoint persisted the
message to ``session.messages`` AND the executor drained it from Redis
and concatenated it into ``current_message`` — the LLM saw each queued
message twice (once via the compacted history / gap context that
``_build_query_message`` pulls from ``session.messages``, once via the
new query). The baseline path avoided this via the
``maybe_append_user_message`` dedup; the SDK path had no equivalent guard.
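
For reference, a minimal sketch of the trailing-repeat dedup that masked
the bug on the baseline path (hypothetical reimplementation; the helper
name and ``is_user_message`` kwarg come from this diff, the body is
assumed):

```python
from dataclasses import dataclass, field


@dataclass
class ChatMessage:
    role: str
    content: str


@dataclass
class Session:
    messages: list[ChatMessage] = field(default_factory=list)


def maybe_append_user_message(
    session: Session, content: str, is_user_message: bool = True
) -> bool:
    """Append ``content`` to the transcript unless it repeats the
    trailing message of the same role (the dedup guard)."""
    role = "user" if is_user_message else "assistant"
    last = session.messages[-1] if session.messages else None
    if last is not None and last.role == role and last.content == content:
        return False  # trailing same-role repeat: drop it
    session.messages.append(ChatMessage(role=role, content=content))
    return True
```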

### Fix: Redis is the single source of truth

- Endpoint no longer persists to ``session.messages``.  It only
  pushes to Redis and returns.
- Baseline drain-at-start calls ``maybe_append_user_message`` (its dedup
  is now a safety net, not the primary guard).
- SDK drain-at-start calls ``maybe_append_user_message`` too, so the
  durable transcript records the queued messages.  The concatenation
  into ``current_message`` stays so the SDK CLI sees the content in
  the first user message of the new turn (see the sketch after this
  list).
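
Condensed, the shared drain-at-start sequence now looks roughly like
this (a sketch, not a verbatim excerpt; it assumes the repo's
``drain_pending_messages`` / ``format_pending_as_user_message`` helpers
and the ``Session`` shape from the sketch above):

```python
async def drain_at_start(
    session: Session, session_id: str, current_message: str
) -> str:
    # Redis is the only store for queued messages until this drain runs.
    pending = await drain_pending_messages(session_id)
    if not pending:
        return current_message
    texts = [format_pending_as_user_message(pm)["content"] for pm in pending]
    # Durable transcript: record each queued message exactly once.
    for text in texts:
        maybe_append_user_message(session, text, is_user_message=True)
    # SDK path only: also surface the content in the new turn's query so
    # the SDK CLI sees it in the first user message.
    joined = "\n\n".join(texts)
    return f"{current_message}\n\n{joined}" if current_message.strip() else joined
```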

### Fix: baseline max-iterations silent loss

``tool_call_loop`` yields a final result with ``finished_naturally=False``
when ``iteration == max_iterations`` and then returns.  Previously the
drain only skipped when ``finished_naturally=True``, so messages drained
on the max-iterations final yield were appended to ``openai_messages``
and silently lost (the loop was already exiting).  Now the drain also
skips when ``loop_result.iterations >= _MAX_TOOL_ROUNDS``.
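
The guard leans on ``tool_call_loop``'s yield contract, reconstructed
schematically below (``run_one_llm_round`` is a hypothetical stand-in;
only the result fields and the final-yield behavior come from the code):

```python
from dataclasses import dataclass
from typing import AsyncIterator


@dataclass
class ToolCallLoopResult:
    finished_naturally: bool
    iterations: int


async def run_one_llm_round() -> bool:
    """Hypothetical stand-in: True when the model made no tool calls."""
    raise NotImplementedError


async def tool_call_loop(max_iterations: int) -> AsyncIterator[ToolCallLoopResult]:
    for iteration in range(1, max_iterations + 1):
        done = await run_one_llm_round()
        if done:
            # Natural finish: final yield carries finished_naturally=True.
            yield ToolCallLoopResult(finished_naturally=True, iterations=iteration)
            return
        # When iteration == max_iterations this is ALSO the final yield,
        # but with finished_naturally=False; draining on it loses the
        # messages because no further LLM call will ever send them.
        yield ToolCallLoopResult(finished_naturally=False, iterations=iteration)
```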

### API response cleanup

- ``QueuePendingMessageResponse``: dropped ``queued`` (always True) and
  ``detail`` (human-readable, clients shouldn't parse).  Kept
  ``buffer_length``, ``max_buffer_length``, and ``turn_in_flight``.
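
With ``detail`` gone, a client composes its own copy from the structured
fields, e.g. (sketch; the model class is the one in this diff):

```python
def queue_feedback(resp: QueuePendingMessageResponse) -> str:
    # turn_in_flight is informational: the message is queued either way.
    when = "current turn" if resp.turn_in_flight else "next turn"
    return f"Queued for the {when} ({resp.buffer_length}/{resp.max_buffer_length})"
```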

### Tests

- Removed dead ``_FakePipeline`` class (the code switched to Lua EVAL
  in round 1 so the pipeline fake was unused).
- Added ``test_drain_decodes_bytes_payloads`` so the ``bytes → str``
  decode branch in ``drain_pending_messages`` is actually exercised
  (real redis-py returns bytes when ``decode_responses=False``).
- Updated ``_FakeRedis.lists`` type hint to ``list[str | bytes]``.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Author: majdyz
Date:   2026-04-10 15:57:57 +00:00
Parent: cafe49f295
Commit: f140e73150

4 changed files with 65 additions and 88 deletions

@@ -32,7 +32,6 @@ from backend.copilot.model import (
 from backend.copilot.pending_messages import (
     MAX_PENDING_MESSAGES,
     PendingMessage,
-    format_pending_as_user_message,
     push_pending_message,
 )
 from backend.copilot.rate_limit import (
@@ -147,15 +146,17 @@ class QueuePendingMessageRequest(BaseModel):
 class QueuePendingMessageResponse(BaseModel):
     """Response for the pending-message endpoint.
 
-    Clients should rely on ``queued`` / ``buffer_length`` / ``turn_in_flight``
-    — the ``detail`` field is human-readable and may change without notice.
+    - ``buffer_length``: how many messages are now in the session's
+      pending buffer (after this push)
+    - ``max_buffer_length``: the per-session cap (server-side constant)
+    - ``turn_in_flight``: ``True`` if a copilot turn was running when
+      we checked — purely informational for UX feedback. Even when
+      ``False`` the message is still queued: the next turn drains it.
     """
 
-    queued: bool
     buffer_length: int
     max_buffer_length: int
     turn_in_flight: bool
-    detail: str
 
 
 class CreateSessionRequest(BaseModel):
@@ -1127,10 +1128,15 @@ async def queue_pending_message(
         session_id,
     )
 
-    # Push to Redis BEFORE writing to the session DB. If the push
-    # fails we raise 5xx and the client retries; ``append_and_save_message``
-    # would otherwise leave an orphan user message persisted with no
-    # corresponding queued pending entry, and a retry would duplicate it.
+    # Redis is the single source of truth for pending messages. We do
+    # NOT persist to ``session.messages`` here — the drain-at-start
+    # path in the baseline/SDK executor is the sole writer for pending
+    # content. Persisting both here AND in the drain would cause
+    # double injection (executor sees the message in ``session.messages``
+    # *and* drains it from Redis) unless we also dedupe. The dedup in
+    # ``maybe_append_user_message`` only checks trailing same-role
+    # repeats, so relying on it is fragile. Keeping the endpoint
+    # Redis-only avoids the whole consistency-bug class.
     pending = PendingMessage(
         content=request.message,
         file_ids=sanitized_file_ids,
@@ -1138,33 +1144,14 @@
     )
 
     buffer_length = await push_pending_message(session_id, pending)
 
-    # Persist the message into the session transcript only after the
-    # push succeeds. The message content embeds file/context metadata
-    # via format_pending_as_user_message so the DB copy matches what
-    # the model will actually see on drain.
-    chat_msg = ChatMessage(
-        role="user",
-        content=format_pending_as_user_message(pending)["content"],
-    )
-    await append_and_save_message(session_id, chat_msg)
-
     # Check whether a turn is currently running for UX feedback.
     active_session = await stream_registry.get_session(session_id)
     turn_in_flight = bool(active_session and active_session.status == "running")
 
     return QueuePendingMessageResponse(
-        queued=True,
         buffer_length=buffer_length,
         max_buffer_length=MAX_PENDING_MESSAGES,
         turn_in_flight=turn_in_flight,
-        detail=(
-            (
-                "Queued — will be injected into the current turn."
-                if turn_in_flight
-                else "Queued — will be injected at the start of the next turn."
-            )
-            + f" buffer={buffer_length}/{MAX_PENDING_MESSAGES}"
-        ),
     )


@@ -1187,14 +1187,23 @@ async def stream_chat_completion_baseline(
             # messages on its next LLM call.
             #
             # IMPORTANT: skip when the loop has already finished (no
-            # more LLM calls are coming). Draining here would silently
-            # lose the message because ``tool_call_loop`` is about to
-            # return on the next ``async for`` step — the user would
-            # see a 202 from the pending endpoint but the model would
-            # never actually read the text. Those messages stay in
-            # the buffer and will be picked up at the start of the
-            # next turn.
-            if loop_result is None or loop_result.finished_naturally:
+            # more LLM calls are coming). ``tool_call_loop`` yields
+            # a final ``ToolCallLoopResult`` on both paths:
+            #   - natural finish: ``finished_naturally=True``
+            #   - hit max_iterations: ``finished_naturally=False``
+            #     and ``iterations >= max_iterations``
+            # In either case the loop is about to return on the next
+            # ``async for`` step, so draining here would silently
+            # lose the message (the user sees 202 but the model never
+            # reads the text). Those messages stay in the buffer and
+            # get picked up at the start of the next turn.
+            if loop_result is None:
+                continue
+            is_final_yield = (
+                loop_result.finished_naturally
+                or loop_result.iterations >= _MAX_TOOL_ROUNDS
+            )
+            if is_final_yield:
                 continue
             pending = await drain_pending_messages(session_id)
             if pending:


@@ -24,61 +24,14 @@ from backend.copilot.pending_messages import (
 # ── Fake Redis ──────────────────────────────────────────────────────
 
 
-class _FakePipeline:
-    def __init__(self, parent: "_FakeRedis") -> None:
-        self._parent = parent
-        self._ops: list[tuple[str, tuple[Any, ...]]] = []
-
-    async def __aenter__(self) -> "_FakePipeline":
-        return self
-
-    async def __aexit__(self, *args: object) -> None:
-        return None
-
-    def rpush(self, key: str, value: Any) -> None:
-        self._ops.append(("rpush", (key, value)))
-
-    def ltrim(self, key: str, start: int, stop: int) -> None:
-        self._ops.append(("ltrim", (key, start, stop)))
-
-    def expire(self, key: str, ttl: int) -> None:
-        self._ops.append(("expire", (key, ttl)))
-
-    def llen(self, key: str) -> None:
-        self._ops.append(("llen", (key,)))
-
-    async def execute(self) -> list[Any]:
-        results: list[Any] = []
-        for op, args in self._ops:
-            if op == "rpush":
-                key, value = args
-                self._parent.lists.setdefault(key, []).append(value)
-                results.append(len(self._parent.lists[key]))
-            elif op == "ltrim":
-                key, start, stop = args
-                lst = self._parent.lists.get(key, [])
-                # Emulate Redis LTRIM (-N, -1) = last N
-                if start < 0 and stop == -1:
-                    self._parent.lists[key] = lst[start:]
-                else:
-                    self._parent.lists[key] = lst[start : stop + 1]
-                results.append(True)
-            elif op == "expire":
-                results.append(True)
-            elif op == "llen":
-                key = args[0]
-                results.append(len(self._parent.lists.get(key, [])))
-        return results
-
-
 class _FakeRedis:
     def __init__(self) -> None:
-        self.lists: dict[str, list[str]] = {}
+        # Values are ``str | bytes`` because real redis-py returns
+        # bytes when ``decode_responses=False``; the drain path must
+        # handle both and our tests exercise both.
+        self.lists: dict[str, list[str | bytes]] = {}
         self.published: list[tuple[str, str]] = []
 
-    def pipeline(self, transaction: bool = False) -> _FakePipeline:
-        return _FakePipeline(self)
-
     async def eval(self, script: str, num_keys: int, *args: Any) -> Any:
         """Emulate the push Lua script.
@@ -102,7 +55,7 @@ class _FakeRedis:
         self.published.append((channel, payload))
         return 1
 
-    async def lpop(self, key: str, count: int) -> list[str] | None:
+    async def lpop(self, key: str, count: int) -> list[str | bytes] | None:
         lst = self.lists.get(key)
         if not lst:
             return None
@@ -250,3 +203,21 @@ async def test_drain_skips_malformed_entries(
     assert len(drained) == 2
     assert drained[0].content == "valid"
     assert drained[1].content == "also valid"
+
+
+@pytest.mark.asyncio
+async def test_drain_decodes_bytes_payloads(
+    fake_redis: _FakeRedis,
+) -> None:
+    """Real redis-py returns ``bytes`` when ``decode_responses=False``.
+
+    Seed the fake with bytes values to exercise the ``decode("utf-8")``
+    branch in ``drain_pending_messages`` so a regression there doesn't
+    slip past CI.
+    """
+    fake_redis.lists["copilot:pending:bytes_sess"] = [
+        json.dumps({"content": "from bytes"}).encode("utf-8"),
+    ]
+    drained = await drain_pending_messages("bytes_sess")
+    assert len(drained) == 1
+    assert drained[0].content == "from bytes"


@@ -2282,6 +2282,14 @@ async def stream_chat_completion_sdk(
     # the drain and stays queued for the next turn instead of being
     # lost between LPOP and clear. File IDs and context are
     # preserved via format_pending_as_user_message.
+    #
+    # The drained content is concatenated into ``current_message``
+    # so the SDK CLI sees it in the new user message, AND appended
+    # to ``session.messages`` (via ``maybe_append_user_message``,
+    # which dedupes trailing same-role repeats) so the durable
+    # transcript records it too. The endpoint deliberately does
+    # NOT persist to session.messages — Redis is the single source
+    # of truth until this drain runs.
     pending_at_start = await drain_pending_messages(session_id)
     if pending_at_start:
         logger.info(
@@ -2292,6 +2300,8 @@ async def stream_chat_completion_sdk(
         pending_texts: list[str] = [
             format_pending_as_user_message(pm)["content"] for pm in pending_at_start
         ]
+        for _pt in pending_texts:
+            maybe_append_user_message(session, _pt, is_user_message=True)
         if current_message.strip():
             current_message = current_message + "\n\n" + "\n\n".join(pending_texts)
         else: