Merge branch 'feat/copilot-pending-messages' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs

majdyz
2026-04-13 08:01:45 +00:00
10 changed files with 1721 additions and 40 deletions

View File

@@ -4,7 +4,7 @@ import asyncio
import logging
import re
from collections.abc import AsyncGenerator
from typing import Annotated
from typing import Annotated, Any, cast
from uuid import uuid4
from autogpt_libs import auth
@@ -29,6 +29,12 @@ from backend.copilot.model import (
get_user_sessions,
update_session_title,
)
from backend.copilot.pending_messages import (
MAX_PENDING_MESSAGES,
PendingMessage,
PendingMessageContext,
push_pending_message,
)
from backend.copilot.rate_limit import (
CoPilotUsageStatus,
RateLimitExceeded,
@@ -84,6 +90,32 @@ _UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
# Call-frequency cap for the pending-message endpoint. The token-budget
# check in queue_pending_message guards against overspend, but does not
# prevent rapid-fire pushes from a client with a large budget. This cap
# (per user, per 60-second window) limits the rate a caller can hammer the
# endpoint independently of token consumption.
_PENDING_CALL_LIMIT = 30 # pushes per minute per user
_PENDING_CALL_WINDOW_SECONDS = 60
_PENDING_CALL_KEY_PREFIX = "copilot:pending:calls:"
# Maximum lengths for pending-message context fields (url: 2 KB, content: 32 KB).
# Enforced by QueuePendingMessageRequest._validate_context_length.
_CONTEXT_URL_MAX_LENGTH = 2_000
_CONTEXT_CONTENT_MAX_LENGTH = 32_000
# Lua script for atomic INCR + conditional EXPIRE.
# Using a single EVAL ensures the counter never persists without a TTL —
# a bare INCR followed by a separate EXPIRE can leave the key without
# an expiry if the process crashes between the two commands.
_CALL_INCR_LUA = """
local count = redis.call('INCR', KEYS[1])
if count == 1 then
redis.call('EXPIRE', KEYS[1], tonumber(ARGV[1]))
end
return count
"""
async def _validate_and_get_session(
session_id: str,
@@ -96,6 +128,29 @@ async def _validate_and_get_session(
return session
async def _resolve_workspace_files(
user_id: str,
file_ids: list[str],
) -> list[UserWorkspaceFile]:
"""Filter *file_ids* to UUID-valid entries that exist in the caller's workspace.
Returns the matching ``UserWorkspaceFile`` records (empty list if none pass).
Used by both the stream and pending-message endpoints to prevent callers from
referencing other users' files.
"""
valid_ids = [fid for fid in file_ids if _UUID_RE.fullmatch(fid)]
if not valid_ids:
return []
workspace = await get_or_create_workspace(user_id)
return await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
router = APIRouter(
tags=["chat"],
)
@@ -119,6 +174,61 @@ class StreamChatRequest(BaseModel):
)
class QueuePendingMessageRequest(BaseModel):
"""Request model for queueing a message into an in-flight turn.
Unlike ``StreamChatRequest`` this endpoint does **not** start a new
turn — the message is appended to a per-session pending buffer that
the executor currently processing the turn will drain between tool
rounds.
"""
model_config = ConfigDict(extra="forbid")
message: str = Field(min_length=1, max_length=16_000)
context: PendingMessageContext | None = Field(
default=None,
description="Optional page context with 'url' and 'content' fields.",
)
file_ids: list[str] | None = Field(default=None, max_length=20)
@field_validator("context")
@classmethod
def _validate_context_length(
cls, v: PendingMessageContext | None
) -> PendingMessageContext | None:
if v is None:
return v
# Cap context values to prevent LLM context-window stuffing via
# large page payloads. Limits are module-level constants so
# they are visible to callers and documentation.
if v.url and len(v.url) > _CONTEXT_URL_MAX_LENGTH:
raise ValueError(
f"context.url exceeds maximum length of {_CONTEXT_URL_MAX_LENGTH} characters"
)
if v.content and len(v.content) > _CONTEXT_CONTENT_MAX_LENGTH:
raise ValueError(
f"context.content exceeds maximum length of {_CONTEXT_CONTENT_MAX_LENGTH} characters"
)
return v
class QueuePendingMessageResponse(BaseModel):
"""Response for the pending-message endpoint.
- ``buffer_length``: how many messages are now in the session's
pending buffer (after this push)
- ``max_buffer_length``: the per-session cap (server-side constant)
- ``turn_in_flight``: ``True`` if a copilot turn was running when
we checked — purely informational for UX feedback. Even when
``False`` the message is still queued: the next turn drains it.
"""
buffer_length: int
max_buffer_length: int
turn_in_flight: bool
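# Example client call (a sketch; the base URL, bearer token, and session id are
# placeholders, not values defined anywhere in this codebase):
#
#     import httpx
#
#     async with httpx.AsyncClient(base_url="https://<backend-host>/<chat-prefix>") as client:
#         resp = await client.post(
#             "/sessions/<session-id>/messages/pending",
#             json={"message": "also include last week's numbers"},
#             headers={"Authorization": "Bearer <token>"},
#         )
#         assert resp.status_code == 202
#         body = resp.json()
#         # body["buffer_length"], body["max_buffer_length"], body["turn_in_flight"]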
class CreateSessionRequest(BaseModel):
"""Request model for creating a new chat session.
@@ -786,33 +896,21 @@ async def stream_chat_post(
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
if request.file_ids:
files = await _resolve_workspace_files(user_id, request.file_ids)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
request.message += files_block
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
@@ -1012,6 +1110,135 @@ async def stream_chat_post(
)
@router.post(
"/sessions/{session_id}/messages/pending",
response_model=QueuePendingMessageResponse,
status_code=202,
responses={
404: {"description": "Session not found or access denied"},
429: {"description": "Token rate-limit or call-frequency cap exceeded"},
},
)
async def queue_pending_message(
session_id: str,
request: QueuePendingMessageRequest,
user_id: str = Security(auth.get_user_id),
):
"""Queue a new user message into an in-flight copilot turn.
When a user sends a follow-up message while a turn is still
streaming, we don't want to block them or start a separate turn —
this endpoint appends the message to a per-session pending buffer.
The executor currently running the turn (baseline path) drains the
buffer between tool-call rounds and appends the message to the
conversation before the next LLM call. On the SDK path the buffer
is drained at the *start* of the next turn (the long-lived
``ClaudeSDKClient.receive_response`` iterator returns after a
``ResultMessage`` so there is no safe point to inject mid-stream
into an existing connection).
Returns 202. Enforces the same per-user daily/weekly token rate
limit as the regular ``/stream`` endpoint so a client can't bypass
it by batching messages through here.
"""
await _validate_and_get_session(session_id, user_id)
# Pre-turn rate-limit check — mirrors stream_chat_post. Without
# this, a client could bypass per-turn token limits by batching
# their extra context through this endpoint while a cheap stream
# is in flight.
# user_id is guaranteed non-empty by Security(auth.get_user_id) — no guard needed.
try:
daily_limit, weekly_limit, _tier = await get_global_rate_limits(
user_id, config.daily_token_limit, config.weekly_token_limit
)
await check_rate_limit(
user_id=user_id,
daily_token_limit=daily_limit,
weekly_token_limit=weekly_limit,
)
except RateLimitExceeded as e:
raise HTTPException(status_code=429, detail=str(e)) from e
# Call-frequency cap: prevent rapid-fire pushes that would bypass the
# token-budget check (which only fires per-turn, not per-push).
# Uses an atomic Lua EVAL (INCR + EXPIRE) so the key can never be
# orphaned without a TTL; fails open if Redis is down.
try:
_redis = await get_redis_async()
_call_key = f"{_PENDING_CALL_KEY_PREFIX}{user_id}"
_call_count = int(
await cast(
"Any",
_redis.eval(
_CALL_INCR_LUA,
1,
_call_key,
str(_PENDING_CALL_WINDOW_SECONDS),
),
)
)
if _call_count > _PENDING_CALL_LIMIT:
raise HTTPException(
status_code=429,
detail=f"Too many pending messages: limit is {_PENDING_CALL_LIMIT} per {_PENDING_CALL_WINDOW_SECONDS}s",
)
except HTTPException:
raise
except Exception:
logger.warning(
"queue_pending_message: rate-limit check failed, failing open"
) # non-fatal
# Sanitise file IDs to the user's own workspace so injection doesn't
# surface other users' files. _resolve_workspace_files handles UUID
# filtering and the workspace-scoped DB lookup.
sanitized_file_ids: list[str] = []
if request.file_ids:
valid_id_count = sum(1 for fid in request.file_ids if _UUID_RE.fullmatch(fid))
files = await _resolve_workspace_files(user_id, request.file_ids)
sanitized_file_ids = [wf.id for wf in files]
if len(sanitized_file_ids) != valid_id_count:
logger.warning(
"queue_pending_message: dropped %d file id(s) not in "
"caller's workspace (session=%s)",
valid_id_count - len(sanitized_file_ids),
session_id,
)
# Redis is the single source of truth for pending messages. We do
# NOT persist to ``session.messages`` here — the drain-at-start
# path in the baseline/SDK executor is the sole writer for pending
# content. Persisting both here AND in the drain would cause
# double injection (executor sees the message in ``session.messages``
# *and* drains it from Redis) unless we also dedupe. The dedup in
# ``maybe_append_user_message`` only checks trailing same-role
# repeats, so relying on it is fragile. Keeping the endpoint
# Redis-only avoids the whole consistency-bug class.
pending = PendingMessage(
content=request.message,
file_ids=sanitized_file_ids,
context=request.context,
)
buffer_length = await push_pending_message(session_id, pending)
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
# Check whether a turn is currently running for UX feedback.
active_session = await stream_registry.get_session(session_id)
turn_in_flight = bool(active_session and active_session.status == "running")
return QueuePendingMessageResponse(
buffer_length=buffer_length,
max_buffer_length=MAX_PENDING_MESSAGES,
turn_in_flight=turn_in_flight,
)
@router.get(
"/sessions/{session_id}/stream",
)

View File

@@ -579,3 +579,300 @@ class TestStreamChatRequestModeValidation:
req = StreamChatRequest(message="hi")
assert req.mode is None
# ─── QueuePendingMessageRequest validation ────────────────────────────
class TestQueuePendingMessageRequest:
"""Unit tests for QueuePendingMessageRequest field validation."""
def test_accepts_valid_message(self) -> None:
from backend.api.features.chat.routes import QueuePendingMessageRequest
req = QueuePendingMessageRequest(message="hello")
assert req.message == "hello"
def test_rejects_empty_message(self) -> None:
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError):
QueuePendingMessageRequest(message="")
def test_rejects_message_over_limit(self) -> None:
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError):
QueuePendingMessageRequest(message="x" * 16_001)
def test_accepts_valid_context(self) -> None:
from backend.api.features.chat.routes import QueuePendingMessageRequest
req = QueuePendingMessageRequest(
message="hi",
context={"url": "https://example.com", "content": "page text"},
)
assert req.context is not None
assert req.context.url == "https://example.com"
def test_rejects_context_url_over_limit(self) -> None:
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError, match="url"):
QueuePendingMessageRequest(
message="hi",
context={"url": "https://example.com/" + "x" * 2_000},
)
def test_rejects_context_content_over_limit(self) -> None:
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError, match="content"):
QueuePendingMessageRequest(
message="hi",
context={"content": "x" * 32_001},
)
def test_rejects_extra_fields(self) -> None:
"""extra='forbid' should reject unknown fields."""
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError):
QueuePendingMessageRequest(message="hi", unknown_field="bad") # type: ignore[call-arg]
def test_accepts_up_to_20_file_ids(self) -> None:
from backend.api.features.chat.routes import QueuePendingMessageRequest
req = QueuePendingMessageRequest(
message="hi",
file_ids=[f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
)
assert req.file_ids is not None
assert len(req.file_ids) == 20
def test_rejects_more_than_20_file_ids(self) -> None:
import pydantic
from backend.api.features.chat.routes import QueuePendingMessageRequest
with pytest.raises(pydantic.ValidationError):
QueuePendingMessageRequest(
message="hi",
file_ids=[f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
)
# ─── queue_pending_message endpoint ──────────────────────────────────
def _mock_pending_internals(
mocker: pytest_mock.MockerFixture,
*,
session_exists: bool = True,
call_count: int = 1,
):
"""Mock all async dependencies for the pending-message endpoint."""
if session_exists:
mock_session = mocker.MagicMock()
mock_session.id = "sess-1"
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
new_callable=AsyncMock,
return_value=mock_session,
)
else:
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
side_effect=fastapi.HTTPException(
status_code=404, detail="Session not found."
),
)
mocker.patch(
"backend.api.features.chat.routes.get_global_rate_limits",
new_callable=AsyncMock,
return_value=(0, 0, None),
)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
new_callable=AsyncMock,
return_value=None,
)
# Mock Redis for per-user call-frequency rate limit (atomic Lua EVAL)
mock_redis = mocker.MagicMock()
mock_redis.eval = mocker.AsyncMock(return_value=call_count)
mocker.patch(
"backend.api.features.chat.routes.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.push_pending_message",
new_callable=AsyncMock,
return_value=1,
)
mock_registry = mocker.MagicMock()
mock_registry.get_session = mocker.AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
def test_queue_pending_message_returns_202(mocker: pytest_mock.MockerFixture) -> None:
"""Happy path: valid message returns 202 with buffer_length."""
_mock_pending_internals(mocker)
response = client.post(
"/sessions/sess-1/messages/pending",
json={"message": "follow-up"},
)
assert response.status_code == 202
data = response.json()
assert data["buffer_length"] == 1
assert data["turn_in_flight"] is False
def test_queue_pending_message_empty_body_returns_422() -> None:
"""Empty message must be rejected by Pydantic before hitting any route logic."""
response = client.post(
"/sessions/sess-1/messages/pending",
json={"message": ""},
)
assert response.status_code == 422
def test_queue_pending_message_missing_message_returns_422() -> None:
"""Missing 'message' field returns 422."""
response = client.post(
"/sessions/sess-1/messages/pending",
json={},
)
assert response.status_code == 422
def test_queue_pending_message_session_not_found_returns_404(
mocker: pytest_mock.MockerFixture,
) -> None:
"""If the session doesn't exist or belong to the user, returns 404."""
_mock_pending_internals(mocker, session_exists=False)
response = client.post(
"/sessions/bad-sess/messages/pending",
json={"message": "hi"},
)
assert response.status_code == 404
def test_queue_pending_message_rate_limited_returns_429(
mocker: pytest_mock.MockerFixture,
) -> None:
"""When rate limit is exceeded, endpoint returns 429."""
from backend.copilot.rate_limit import RateLimitExceeded
_mock_pending_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.check_rate_limit",
side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
)
response = client.post(
"/sessions/sess-1/messages/pending",
json={"message": "hi"},
)
assert response.status_code == 429
def test_queue_pending_message_call_frequency_limit_returns_429(
mocker: pytest_mock.MockerFixture,
) -> None:
"""When per-user call frequency limit is exceeded, endpoint returns 429."""
from backend.api.features.chat.routes import _PENDING_CALL_LIMIT
_mock_pending_internals(mocker, call_count=_PENDING_CALL_LIMIT + 1)
response = client.post(
"/sessions/sess-1/messages/pending",
json={"message": "hi"},
)
assert response.status_code == 429
assert "Too many pending messages" in response.json()["detail"]
def test_queue_pending_message_context_url_too_long_returns_422() -> None:
"""context.url over 2 KB is rejected."""
response = client.post(
"/sessions/sess-1/messages/pending",
json={
"message": "hi",
"context": {"url": "https://example.com/" + "x" * 2_000},
},
)
assert response.status_code == 422
def test_queue_pending_message_context_content_too_long_returns_422() -> None:
"""context.content over 32 KB is rejected."""
response = client.post(
"/sessions/sess-1/messages/pending",
json={
"message": "hi",
"context": {"content": "x" * 32_001},
},
)
assert response.status_code == 422
def test_queue_pending_message_too_many_file_ids_returns_422() -> None:
"""More than 20 file_ids should be rejected."""
response = client.post(
"/sessions/sess-1/messages/pending",
json={
"message": "hi",
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
},
)
assert response.status_code == 422
def test_queue_pending_message_file_ids_scoped_to_workspace(
mocker: pytest_mock.MockerFixture,
) -> None:
"""File IDs must be sanitized to the user's workspace before push."""
_mock_pending_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
new_callable=AsyncMock,
return_value=type("W", (), {"id": "ws-1"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
client.post(
"/sessions/sess-1/messages/pending",
json={"message": "hi", "file_ids": [fid, "not-a-uuid"]},
)
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["id"]["in"] == [fid]
assert call_kwargs["where"]["workspaceId"] == "ws-1"
assert call_kwargs["where"]["isDeleted"] is False

View File

@@ -36,6 +36,10 @@ from backend.copilot.model import (
maybe_append_user_message,
upsert_chat_session,
)
from backend.copilot.pending_messages import (
drain_pending_messages,
format_pending_as_user_message,
)
from backend.copilot.prompting import get_baseline_supplement, get_graphiti_supplement
from backend.copilot.response_model import (
StreamBaseResponse,
@@ -341,6 +345,11 @@ class _BaselineStreamState:
cost_usd: float | None = None
thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper)
session_messages: list[ChatMessage] = field(default_factory=list)
# Tracks how much of ``assistant_text`` has already been flushed to
# ``session.messages`` via mid-loop pending drains, so the ``finally``
# block only appends the *new* assistant text (avoiding duplication of
# round-1 text when round-1 entries were cleared from session_messages).
_flushed_assistant_text_len: int = 0
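# Illustrative arithmetic (made-up values): if round 1 streamed "calling search"
# and a mid-loop drain flushed it, _flushed_assistant_text_len is set to 14, so
# after round 2 streams "final answer" the finally block appends only
# assistant_text[14:] == "final answer" instead of the full concatenation.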
async def _baseline_llm_caller(
@@ -930,7 +939,54 @@ async def stream_chat_completion_baseline(
message_length=len(message or ""),
)
session = await upsert_chat_session(session)
# Capture count *before* the pending drain so is_first_turn and the
# transcript staleness check are not skewed by queued messages.
_pre_drain_msg_count = len(session.messages)
# Drain any messages the user queued via POST /messages/pending
# while this session was idle (or during a previous turn whose
# mid-loop drains missed them). Atomic LPOP guarantees that a
# concurrent push lands *after* the drain and stays queued for the
# next turn instead of being lost.
try:
drained_at_start = await drain_pending_messages(session_id)
except Exception:
logger.warning(
"[Baseline] drain_pending_messages failed at turn start, skipping",
exc_info=True,
)
drained_at_start = []
# Pre-compute formatted content once per message so we don't call
# format_pending_as_user_message twice (once for session.messages and
# once for transcript_builder below).
drained_at_start_content: list[str] = []
if drained_at_start:
logger.info(
"[Baseline] Draining %d pending message(s) at turn start for session %s",
len(drained_at_start),
session_id,
)
for pm in drained_at_start:
content = format_pending_as_user_message(pm)["content"]
drained_at_start_content.append(content)
# Append directly — pending messages are atomically-popped from
# Redis and are never stale-cache duplicates, so the
# maybe_append_user_message dedup is wrong here.
session.messages.append(ChatMessage(role="user", content=content))
# Persist the drained pending messages (if any) plus the current user
# message. Wrap in try/except so a transient DB failure here does not
# silently discard messages that were already popped from Redis — the
# turn can still proceed using the in-memory session.messages, and a
# later resume/replay will backfill from the DB on the next turn.
try:
session = await upsert_chat_session(session)
except Exception as _persist_err:
logger.warning(
"[Baseline] Failed to persist session at turn start "
"(pending drain may not be durable): %s",
_persist_err,
)
# Select model based on the per-request mode. 'fast' downgrades to
# the cheaper/faster model; everything else keeps the default.
@@ -959,7 +1015,9 @@ async def stream_chat_completion_baseline(
# Build system prompt only on the first turn to avoid mid-conversation
# changes from concurrent chats updating business understanding.
# Use the pre-drain count so queued pending messages don't incorrectly
# flip is_first_turn to False on an actual first turn.
is_first_turn = _pre_drain_msg_count <= 1
# Gate context fetch on both first turn AND user message so that assistant-
# role calls (e.g. tool-result submissions) on the first turn don't trigger
# a needless DB lookup for user understanding.
@@ -970,14 +1028,18 @@ async def stream_chat_completion_baseline(
prompt_task = _build_cacheable_system_prompt(None)
# Run download + prompt build concurrently — both are independent I/O
# on the request critical path. Use the pre-drain count so pending
# messages drained at turn start don't spuriously trigger a transcript
# load on an actual first turn.
if user_id and _pre_drain_msg_count > 1:
transcript_covers_prefix, (base_system_prompt, understanding) = (
await asyncio.gather(
_load_prior_transcript(
user_id=user_id,
session_id=session_id,
# Use pre-drain count so pending messages don't falsely
# mark the stored transcript as stale and prevent upload.
session_msg_count=_pre_drain_msg_count,
transcript_builder=transcript_builder,
),
prompt_task,
@@ -989,6 +1051,15 @@ async def stream_chat_completion_baseline(
# Append user message to transcript after context injection below so the
# transcript receives the prefixed message when user context is available.
# Mirror any messages drained at turn start (see above) into the
# transcript — otherwise the loaded prior transcript would be
# missing them and a mid-turn upload could leave a malformed
# assistant-after-assistant structure on the next turn.
# Reuse the pre-computed content strings to avoid calling
# format_pending_as_user_message a second time.
for _drained_content in drained_at_start_content:
transcript_builder.append_user(content=_drained_content)
# Generate title for new sessions
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
@@ -1009,8 +1080,10 @@ async def stream_chat_completion_baseline(
graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else ""
system_prompt = base_system_prompt + get_baseline_supplement() + graphiti_supplement
# Warm context: pre-load relevant facts from Graphiti on first turn.
# Use the pre-drain count so pending messages drained at turn start
# don't prevent warm context injection on an actual first turn.
if graphiti_enabled and user_id and _pre_drain_msg_count <= 1:
from backend.copilot.graphiti.context import fetch_warm_context
warm_ctx = await fetch_warm_context(user_id, message or "")
@@ -1203,6 +1276,91 @@ async def stream_chat_completion_baseline(
yield evt
state.pending_events.clear()
# Inject any messages the user queued while the turn was
# running. ``tool_call_loop`` mutates ``openai_messages``
# in-place, so appending here means the model sees the new
# messages on its next LLM call.
#
# IMPORTANT: skip when the loop has already finished (no
# more LLM calls are coming). ``tool_call_loop`` yields
# a final ``ToolCallLoopResult`` on both paths:
# - natural finish: ``finished_naturally=True``
# - hit max_iterations: ``finished_naturally=False``
# and ``iterations >= max_iterations``
# In either case the loop is about to return on the next
# ``async for`` step, so draining here would silently
# lose the message (the user sees 202 but the model never
# reads the text). Those messages stay in the buffer and
# get picked up at the start of the next turn.
if loop_result is None:
continue
is_final_yield = (
loop_result.finished_naturally
or loop_result.iterations >= _MAX_TOOL_ROUNDS
)
if is_final_yield:
continue
try:
pending = await drain_pending_messages(session_id)
except Exception:
logger.warning(
"Mid-loop drain_pending_messages failed for session %s",
session_id,
exc_info=True,
)
pending = []
if pending:
# Flush any buffered assistant/tool messages from completed
# rounds into session.messages BEFORE appending the pending
# user message. ``_baseline_conversation_updater`` only
# records assistant+tool rounds into ``state.session_messages``
# — they are normally batch-flushed in the finally block.
# Without this in-order flush, the mid-loop pending user
# message lands before the preceding round's assistant/tool
# entries, producing chronologically-wrong session.messages
# on persist (user interposed between an assistant tool_call
# and its tool-result), which breaks OpenAI tool-call ordering
# invariants on the next turn's replay.
for _buffered in state.session_messages:
session.messages.append(_buffered)
state.session_messages.clear()
# Record how much assistant_text has been covered by the
# structured entries just flushed, so the finally block's
# final-text dedup doesn't re-append rounds already persisted.
state._flushed_assistant_text_len = len(state.assistant_text)
for pm in pending:
# ``format_pending_as_user_message`` embeds file
# attachments and context URL/page content into the
# content string so the in-session transcript is
# a faithful copy of what the model actually saw.
formatted = format_pending_as_user_message(pm)
content_for_db = formatted["content"]
# Append directly — pending messages are atomically-popped
# from Redis and are never stale-cache duplicates, so the
# maybe_append_user_message dedup is wrong here and would
# cause openai_messages/transcript to diverge from session.
session.messages.append(
ChatMessage(role="user", content=content_for_db)
)
openai_messages.append(formatted)
transcript_builder.append_user(content=content_for_db)
try:
await upsert_chat_session(session)
except Exception as persist_err:
logger.warning(
"[Baseline] Failed to persist pending messages for "
"session %s: %s",
session_id,
persist_err,
)
logger.info(
"[Baseline] Injected %d pending message(s) into "
"session %s mid-turn",
len(pending),
session_id,
)
if loop_result and not loop_result.finished_naturally:
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
@@ -1243,6 +1401,11 @@ async def stream_chat_completion_baseline(
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
# Pending messages are drained atomically at turn start and
# between tool rounds, so there's nothing to clear in finally.
# Any message pushed after the final drain window stays in the
# buffer and gets picked up at the start of the next turn.
# Set cost attributes on OTEL span before closing
if _trace_ctx is not None:
try:
@@ -1312,7 +1475,11 @@ async def stream_chat_completion_baseline(
# no tool calls, i.e. the natural finish). Only add it if the
# conversation updater didn't already record it as part of a
# tool-call round (which would have empty response_text).
# Only consider assistant text produced AFTER the last mid-loop
# flush. ``_flushed_assistant_text_len`` tracks the prefix already
# persisted via structured session_messages during mid-loop pending
# drains; including it here would duplicate those rounds.
final_text = state.assistant_text[state._flushed_assistant_text_len :]
if state.session_messages:
# Strip text already captured in tool-call round messages
recorded = "".join(

View File

@@ -828,3 +828,204 @@ class TestBaselineCostExtraction:
# response was never assigned so cost extraction must not raise
assert state.cost_usd is None
class TestMidLoopPendingFlushOrdering:
"""Regression test for the mid-loop pending drain ordering invariant.
``_baseline_conversation_updater`` records assistant+tool entries from
each tool-call round into ``state.session_messages``; the finally block
of ``stream_chat_completion_baseline`` batch-flushes them into
``session.messages`` at the end of the turn.
The mid-loop pending drain appends pending user messages directly to
``session.messages``. Without flushing ``state.session_messages`` first,
the pending user message lands BEFORE the preceding round's assistant+
tool entries in the final persisted ``session.messages`` — which
produces a malformed tool-call/tool-result ordering on the next turn's
replay.
This test documents the invariant by replaying the production flush
sequence against an in-memory state.
"""
def test_flush_then_append_preserves_chronological_order(self):
"""Mid-loop drain must flush state.session_messages before appending
the pending user message, so the final order matches the
chronological execution order.
"""
# Initial state: user turn already appended by maybe_append_user_message
session_messages: list[ChatMessage] = [
ChatMessage(role="user", content="original user turn"),
]
state = _BaselineStreamState()
# Round 1 completes: conversation_updater buffers assistant+tool
# entries into state.session_messages (but does NOT write to
# session.messages yet).
builder = TranscriptBuilder()
builder.append_user("original user turn")
response = LLMLoopResponse(
response_text="calling search",
tool_calls=[LLMToolCall(id="tc_1", name="search", arguments="{}")],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(
tool_call_id="tc_1", tool_name="search", content="search output"
),
]
openai_messages: list = []
_baseline_conversation_updater(
openai_messages,
response,
tool_results=tool_results,
transcript_builder=builder,
state=state,
model="test-model",
)
# state.session_messages should now hold the round-1 assistant + tool
assert len(state.session_messages) == 2
assert state.session_messages[0].role == "assistant"
assert state.session_messages[1].role == "tool"
# --- Mid-loop pending drain (production code pattern) ---
# Flush first, THEN append pending. This is the ordering fix.
for _buffered in state.session_messages:
session_messages.append(_buffered)
state.session_messages.clear()
session_messages.append(
ChatMessage(role="user", content="pending mid-loop message")
)
# Round 2 completes: new assistant+tool entries buffer again.
response2 = LLMLoopResponse(
response_text="another call",
tool_calls=[LLMToolCall(id="tc_2", name="calc", arguments="{}")],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results2 = [
ToolCallResult(
tool_call_id="tc_2", tool_name="calc", content="calc output"
),
]
_baseline_conversation_updater(
openai_messages,
response2,
tool_results=tool_results2,
transcript_builder=builder,
state=state,
model="test-model",
)
# --- Finally-block flush (end of turn) ---
for msg in state.session_messages:
session_messages.append(msg)
# Assert chronological order: original user, round-1 assistant,
# round-1 tool, pending user, round-2 assistant, round-2 tool.
assert [m.role for m in session_messages] == [
"user",
"assistant",
"tool",
"user",
"assistant",
"tool",
]
assert session_messages[0].content == "original user turn"
assert session_messages[3].content == "pending mid-loop message"
# The assistant message carrying tool_call tc_1 must be immediately
# followed by its tool result — no user message interposed.
assert session_messages[1].role == "assistant"
assert session_messages[1].tool_calls is not None
assert session_messages[1].tool_calls[0]["id"] == "tc_1"
assert session_messages[2].role == "tool"
assert session_messages[2].tool_call_id == "tc_1"
# Same invariant for the round after the pending user.
assert session_messages[4].tool_calls is not None
assert session_messages[4].tool_calls[0]["id"] == "tc_2"
assert session_messages[5].tool_call_id == "tc_2"
def test_flushed_assistant_text_len_prevents_duplicate_final_text(self):
"""After mid-loop drain clears state.session_messages, the finally
block must not re-append assistant text from rounds already flushed.
``state.assistant_text`` accumulates ALL rounds' text, but
``state.session_messages`` only holds entries from rounds AFTER the
last mid-loop flush. Without ``_flushed_assistant_text_len``, the
``finally`` block's ``startswith(recorded)`` check fails because
``recorded`` only covers post-flush rounds, and the full
``assistant_text`` is appended — duplicating pre-flush rounds.
"""
state = _BaselineStreamState()
session_messages: list[ChatMessage] = [
ChatMessage(role="user", content="user turn"),
]
# Simulate round 1 text accumulation (as _bound_llm_caller does)
state.assistant_text += "calling search"
# Round 1 conversation_updater buffers structured entries
builder = TranscriptBuilder()
builder.append_user("user turn")
response1 = LLMLoopResponse(
response_text="calling search",
tool_calls=[LLMToolCall(id="tc_1", name="search", arguments="{}")],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
_baseline_conversation_updater(
[],
response1,
tool_results=[
ToolCallResult(
tool_call_id="tc_1", tool_name="search", content="result"
)
],
transcript_builder=builder,
state=state,
model="test-model",
)
# Mid-loop drain: flush + clear + record flushed text length
for _buffered in state.session_messages:
session_messages.append(_buffered)
state.session_messages.clear()
state._flushed_assistant_text_len = len(state.assistant_text)
session_messages.append(ChatMessage(role="user", content="pending message"))
# Simulate round 2 text accumulation
state.assistant_text += "final answer"
# Round 2: natural finish (no tool calls → no session_messages entry)
# --- Finally block logic (production code) ---
for msg in state.session_messages:
session_messages.append(msg)
final_text = state.assistant_text[state._flushed_assistant_text_len :]
if state.session_messages:
recorded = "".join(
m.content or "" for m in state.session_messages if m.role == "assistant"
)
if final_text.startswith(recorded):
final_text = final_text[len(recorded) :]
if final_text.strip():
session_messages.append(ChatMessage(role="assistant", content=final_text))
# The final assistant message should only contain round-2 text,
# not the round-1 text that was already flushed mid-loop.
assistant_msgs = [m for m in session_messages if m.role == "assistant"]
# Round-1 structured assistant (from mid-loop flush)
assert assistant_msgs[0].content == "calling search"
assert assistant_msgs[0].tool_calls is not None
# Round-2 final text (from finally block)
assert assistant_msgs[1].content == "final answer"
assert assistant_msgs[1].tool_calls is None
# Crucially: only 2 assistant messages, not 3 (no duplicate)
assert len(assistant_msgs) == 2

View File

@@ -0,0 +1,222 @@
"""Pending-message buffer for in-flight copilot turns.
When a user sends a new message while a copilot turn is already executing,
instead of blocking the frontend (or queueing a brand-new turn after the
current one finishes), we want the new message to be *injected into the
running turn* — appended between tool-call rounds so the model sees it
before its next LLM call.
This module provides the cross-process buffer that makes that possible:
- **Producer** (chat API route): pushes a pending message to Redis and
publishes a notification on a pub/sub channel.
- **Consumer** (executor running the turn): on each tool-call round,
drains the buffer and appends the pending messages to the conversation.
The Redis list is the durable store; the pub/sub channel is a fast
wake-up hint for long-idle consumers (not used by default, but available
for future blocking-wait semantics).
A hard cap of ``MAX_PENDING_MESSAGES`` per session prevents abuse. The
buffer is trimmed to the latest ``MAX_PENDING_MESSAGES`` on every push.
"""
import json
import logging
from typing import Any, cast
from pydantic import BaseModel, Field, ValidationError
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# Per-session cap. Higher values risk a runaway consumer; lower values
# risk dropping user input under heavy typing. 10 was chosen as a
# reasonable ceiling — a user typing faster than the copilot can drain
# between tool rounds is already an unusual usage pattern.
MAX_PENDING_MESSAGES = 10
# Redis key + TTL. The buffer is ephemeral: if a turn completes or the
# executor dies, the pending messages should either have been drained
# already or are safe to drop (the user can resend).
_PENDING_KEY_PREFIX = "copilot:pending:"
_PENDING_CHANNEL_PREFIX = "copilot:pending:notify:"
_PENDING_TTL_SECONDS = 3600 # 1 hour — matches stream_ttl default
# Payload sent on the pub/sub notify channel. Subscribers treat any
# message as a wake-up hint; the value itself is not meaningful.
_NOTIFY_PAYLOAD = "1"
class PendingMessageContext(BaseModel, extra="forbid"):
"""Structured page context attached to a pending message."""
url: str | None = None
content: str | None = None
class PendingMessage(BaseModel):
"""A user message queued for injection into an in-flight turn."""
content: str = Field(min_length=1, max_length=16_000)
file_ids: list[str] = Field(default_factory=list)
context: PendingMessageContext | None = None
def _buffer_key(session_id: str) -> str:
return f"{_PENDING_KEY_PREFIX}{session_id}"
def _notify_channel(session_id: str) -> str:
return f"{_PENDING_CHANNEL_PREFIX}{session_id}"
# Lua script: push-then-trim-then-expire-then-length, atomically.
# Redis serializes EVAL commands, so a concurrent ``LPOP`` drain
# observes either the pre-push or post-push state of the list — never
# a partial state where the RPUSH has landed but LTRIM hasn't run.
_PUSH_LUA = """
redis.call('RPUSH', KEYS[1], ARGV[1])
redis.call('LTRIM', KEYS[1], -tonumber(ARGV[2]), -1)
redis.call('EXPIRE', KEYS[1], tonumber(ARGV[3]))
return redis.call('LLEN', KEYS[1])
"""
async def push_pending_message(
session_id: str,
message: PendingMessage,
) -> int:
"""Append a pending message to the session's buffer atomically.
Returns the new buffer length. Enforces ``MAX_PENDING_MESSAGES`` by
trimming from the left (oldest) — the newest message always wins if
the user has been typing faster than the copilot can drain.
The push + trim + expire + llen are wrapped in a single Lua EVAL so
concurrent LPOP drains from the executor never observe a partial
state.
"""
redis = await get_redis_async()
key = _buffer_key(session_id)
payload = message.model_dump_json()
new_length = int(
await cast(
"Any",
redis.eval(
_PUSH_LUA,
1,
key,
payload,
str(MAX_PENDING_MESSAGES),
str(_PENDING_TTL_SECONDS),
),
)
)
# Fire-and-forget notify. Subscribers use this as a wake-up hint;
# the buffer itself is authoritative so a lost notify is harmless.
try:
await redis.publish(_notify_channel(session_id), _NOTIFY_PAYLOAD)
except Exception as e: # pragma: no cover
logger.warning("pending_messages: publish failed for %s: %s", session_id, e)
logger.info(
"pending_messages: pushed message to session=%s (buffer_len=%d)",
session_id,
new_length,
)
return new_length
async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
"""Atomically pop all pending messages for *session_id*.
Returns them in enqueue order (oldest first). Uses ``LPOP`` with a
count so the read+delete is a single Redis round trip. If the list
is empty or missing, returns ``[]``.
"""
redis = await get_redis_async()
key = _buffer_key(session_id)
# Redis LPOP with count (Redis 6.2+) returns None for missing key,
# empty list if we somehow race an empty key, or the popped items.
# redis-py's async lpop overload with a count collapses the return
# type in pyright; cast the awaitable so strict type-check stays
# clean without changing runtime behaviour.
lpop_result = await cast(
"Any",
redis.lpop(key, MAX_PENDING_MESSAGES),
)
if not lpop_result:
return []
raw_popped: list[Any] = list(lpop_result)
# redis-py may return bytes or str depending on decode_responses.
decoded: list[str] = [
item.decode("utf-8") if isinstance(item, bytes) else str(item)
for item in raw_popped
]
messages: list[PendingMessage] = []
for payload in decoded:
try:
messages.append(PendingMessage(**json.loads(payload)))
except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
logger.warning(
"pending_messages: dropping malformed entry for %s: %s",
session_id,
e,
)
if messages:
logger.info(
"pending_messages: drained %d messages for session=%s",
len(messages),
session_id,
)
return messages
async def peek_pending_count(session_id: str) -> int:
"""Return the current buffer length without consuming it."""
redis = await get_redis_async()
length = await cast("Any", redis.llen(_buffer_key(session_id)))
return int(length)
async def clear_pending_messages(session_id: str) -> None:
"""Drop the session's pending buffer.
Not called by the normal turn flow — the atomic ``LPOP`` drain at
turn start is the primary consumer, and any push that arrives
after the drain window belongs to the next turn by definition.
Retained as an operator/debug escape hatch for manually clearing a
stuck session and as a fixture in the unit tests.
"""
redis = await get_redis_async()
await redis.delete(_buffer_key(session_id))
def format_pending_as_user_message(message: PendingMessage) -> dict[str, Any]:
"""Shape a ``PendingMessage`` into the OpenAI-format user message dict.
Used by the baseline tool-call loop when injecting the buffered
message into the conversation. Context/file metadata (if any) is
embedded into the content so the model sees everything in one block.
"""
parts: list[str] = [message.content]
if message.context:
if message.context.url:
parts.append(f"\n\n[Page URL: {message.context.url}]")
if message.context.content:
parts.append(f"\n\n[Page content]\n{message.context.content}")
if message.file_ids:
parts.append(
"\n\n[Attached files]\n"
+ "\n".join(f"- file_id={fid}" for fid in message.file_ids)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
return {"role": "user", "content": "".join(parts)}

View File

@@ -0,0 +1,246 @@
"""Tests for the copilot pending-messages buffer.
Uses a fake async Redis client so the tests don't require a real Redis
instance (the backend test suite's DB/Redis fixtures are heavyweight
and pull in the full app startup).
"""
import json
from typing import Any
import pytest
from backend.copilot import pending_messages as pm_module
from backend.copilot.pending_messages import (
MAX_PENDING_MESSAGES,
PendingMessage,
PendingMessageContext,
clear_pending_messages,
drain_pending_messages,
format_pending_as_user_message,
peek_pending_count,
push_pending_message,
)
# ── Fake Redis ──────────────────────────────────────────────────────
class _FakeRedis:
def __init__(self) -> None:
# Values are ``str | bytes`` because real redis-py returns
# bytes when ``decode_responses=False``; the drain path must
# handle both and our tests exercise both.
self.lists: dict[str, list[str | bytes]] = {}
self.published: list[tuple[str, str]] = []
async def eval(self, script: str, num_keys: int, *args: Any) -> Any:
"""Emulate the push Lua script.
The real Lua script runs atomically in Redis; the fake
implementation just runs the equivalent list operations in
order and returns the final LLEN. That's enough to exercise
the cap + ordering invariants the tests care about.
"""
key = args[0]
payload = args[1]
max_len = int(args[2])
# ARGV[3] is TTL — fake doesn't enforce expiry
lst = self.lists.setdefault(key, [])
lst.append(payload)
if len(lst) > max_len:
# RPUSH + LTRIM(-N, -1) = keep only last N
self.lists[key] = lst[-max_len:]
return len(self.lists[key])
async def publish(self, channel: str, payload: str) -> int:
self.published.append((channel, payload))
return 1
async def lpop(self, key: str, count: int) -> list[str | bytes] | None:
lst = self.lists.get(key)
if not lst:
return None
popped = lst[:count]
self.lists[key] = lst[count:]
return popped
async def llen(self, key: str) -> int:
return len(self.lists.get(key, []))
async def delete(self, key: str) -> int:
if key in self.lists:
del self.lists[key]
return 1
return 0
@pytest.fixture()
def fake_redis(monkeypatch: pytest.MonkeyPatch) -> _FakeRedis:
redis = _FakeRedis()
async def _get_redis_async() -> _FakeRedis:
return redis
monkeypatch.setattr(pm_module, "get_redis_async", _get_redis_async)
return redis
# ── Basic push / drain ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_push_and_drain_single_message(fake_redis: _FakeRedis) -> None:
length = await push_pending_message("sess1", PendingMessage(content="hello"))
assert length == 1
assert await peek_pending_count("sess1") == 1
drained = await drain_pending_messages("sess1")
assert len(drained) == 1
assert drained[0].content == "hello"
assert await peek_pending_count("sess1") == 0
@pytest.mark.asyncio
async def test_push_and_drain_preserves_order(fake_redis: _FakeRedis) -> None:
for i in range(3):
await push_pending_message("sess2", PendingMessage(content=f"msg {i}"))
drained = await drain_pending_messages("sess2")
assert [m.content for m in drained] == ["msg 0", "msg 1", "msg 2"]
@pytest.mark.asyncio
async def test_drain_empty_returns_empty_list(fake_redis: _FakeRedis) -> None:
assert await drain_pending_messages("nope") == []
# ── Buffer cap ──────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_cap_drops_oldest_when_exceeded(fake_redis: _FakeRedis) -> None:
# Push MAX_PENDING_MESSAGES + 3 messages
for i in range(MAX_PENDING_MESSAGES + 3):
await push_pending_message("sess3", PendingMessage(content=f"m{i}"))
# Buffer should be clamped to MAX
assert await peek_pending_count("sess3") == MAX_PENDING_MESSAGES
drained = await drain_pending_messages("sess3")
assert len(drained) == MAX_PENDING_MESSAGES
# Oldest 3 dropped — we should only see m3..m(MAX+2)
assert drained[0].content == "m3"
assert drained[-1].content == f"m{MAX_PENDING_MESSAGES + 2}"
# ── Clear ───────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_clear_removes_buffer(fake_redis: _FakeRedis) -> None:
await push_pending_message("sess4", PendingMessage(content="x"))
await push_pending_message("sess4", PendingMessage(content="y"))
await clear_pending_messages("sess4")
assert await peek_pending_count("sess4") == 0
@pytest.mark.asyncio
async def test_clear_is_idempotent(fake_redis: _FakeRedis) -> None:
# Clearing an already-empty buffer should not raise
await clear_pending_messages("sess_empty")
await clear_pending_messages("sess_empty")
# ── Publish hook ────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_push_publishes_notification(fake_redis: _FakeRedis) -> None:
await push_pending_message("sess5", PendingMessage(content="hi"))
assert ("copilot:pending:notify:sess5", "1") in fake_redis.published
# ── Format helper ───────────────────────────────────────────────────
def test_format_pending_plain_text() -> None:
msg = PendingMessage(content="just text")
out = format_pending_as_user_message(msg)
assert out == {"role": "user", "content": "just text"}
def test_format_pending_with_context_url() -> None:
msg = PendingMessage(
content="see this page",
context=PendingMessageContext(url="https://example.com"),
)
out = format_pending_as_user_message(msg)
content = out["content"]
assert out["role"] == "user"
assert "see this page" in content
# The URL should appear verbatim in the [Page URL: ...] block.
assert "[Page URL: https://example.com]" in content
def test_format_pending_with_file_ids() -> None:
msg = PendingMessage(content="look here", file_ids=["a", "b"])
out = format_pending_as_user_message(msg)
assert "file_id=a" in out["content"]
assert "file_id=b" in out["content"]
def test_format_pending_with_all_fields() -> None:
"""All fields (content + context url/content + file_ids) should all appear."""
msg = PendingMessage(
content="summarise this",
context=PendingMessageContext(
url="https://example.com/page",
content="headline text",
),
file_ids=["f1", "f2"],
)
out = format_pending_as_user_message(msg)
body = out["content"]
assert out["role"] == "user"
assert "summarise this" in body
assert "[Page URL: https://example.com/page]" in body
assert "[Page content]\nheadline text" in body
assert "file_id=f1" in body
assert "file_id=f2" in body
# ── Malformed payload handling ──────────────────────────────────────
@pytest.mark.asyncio
async def test_drain_skips_malformed_entries(
fake_redis: _FakeRedis,
) -> None:
# Seed the fake with a mix of valid and malformed payloads
fake_redis.lists["copilot:pending:bad"] = [
json.dumps({"content": "valid"}),
"{not valid json",
json.dumps({"content": "also valid", "file_ids": ["a"]}),
]
drained = await drain_pending_messages("bad")
assert len(drained) == 2
assert drained[0].content == "valid"
assert drained[1].content == "also valid"
@pytest.mark.asyncio
async def test_drain_decodes_bytes_payloads(
fake_redis: _FakeRedis,
) -> None:
"""Real redis-py returns ``bytes`` when ``decode_responses=False``.
Seed the fake with bytes values to exercise the ``decode("utf-8")``
branch in ``drain_pending_messages`` so a regression there doesn't
slip past CI.
"""
fake_redis.lists["copilot:pending:bytes_sess"] = [
json.dumps({"content": "from bytes"}).encode("utf-8"),
]
drained = await drain_pending_messages("bytes_sess")
assert len(drained) == 1
assert drained[0].content == "from bytes"

View File

@@ -226,6 +226,111 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
assert was_compacted is False # mock returns False
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_prevents_pending_duplication():
"""session_msg_ceiling stops pending messages from leaking into the gap.
Scenario: transcript covers 2 messages, session has 2 historical + 1 current
+ 2 pending drained at turn start. Without the ceiling, the gap slice would
include the pending messages even though current_message already contains
them → duplication.
With session_msg_ceiling=3 (pre-drain count) the gap slice is empty and
only current_message carries the pending content.
"""
# session.messages after drain: [hist1, hist2, current_msg, pending1, pending2]
session = _make_session(
[
ChatMessage(role="user", content="hist1"),
ChatMessage(role="assistant", content="hist2"),
ChatMessage(role="user", content="current msg with pending1 pending2"),
ChatMessage(role="user", content="pending1"),
ChatMessage(role="user", content="pending2"),
]
)
# transcript covers hist1+hist2 (2 messages); pre-drain count was 3 (includes current_msg)
result, was_compacted = await _build_query_message(
"current msg with pending1 pending2",
session,
use_resume=True,
transcript_msg_count=2,
session_id="test-session",
session_msg_ceiling=3, # len(session.messages) before drain
)
# Gap should be empty (transcript_msg_count == ceiling - 1), so no history prepended
assert result == "current msg with pending1 pending2"
assert was_compacted is False
# Pending messages must NOT appear in gap context
assert "pending1" not in result.split("current msg")[0]
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_preserves_real_gap():
"""session_msg_ceiling still surfaces a genuine stale-transcript gap.
Scenario: transcript covers 2 messages, session has 4 historical + 1 current
+ 2 pending. Ceiling = 5 (pre-drain). Real gap = messages 2-3 (hist3, hist4).
"""
session = _make_session(
[
ChatMessage(role="user", content="hist1"),
ChatMessage(role="assistant", content="hist2"),
ChatMessage(role="user", content="hist3"),
ChatMessage(role="assistant", content="hist4"),
ChatMessage(role="user", content="current"),
ChatMessage(role="user", content="pending1"),
ChatMessage(role="user", content="pending2"),
]
)
result, was_compacted = await _build_query_message(
"current",
session,
use_resume=True,
transcript_msg_count=2,
session_id="test-session",
session_msg_ceiling=5, # pre-drain: [hist1..hist4, current]
)
# Gap = session.messages[2:4] = [hist3, hist4]
assert "<conversation_history>" in result
assert "hist3" in result
assert "hist4" in result
assert "Now, the user says:\ncurrent" in result
# Pending messages must NOT appear in gap
assert "pending1" not in result
assert "pending2" not in result
@pytest.mark.asyncio
async def test_build_query_session_msg_ceiling_suppresses_spurious_no_resume_fallback():
"""session_msg_ceiling prevents the no-resume compression fallback from
firing on the first turn of a session when pending messages inflate msg_count.
Scenario: fresh session (1 message) + 1 pending message drained at turn start.
Without the ceiling: msg_count=2 > 1 → fallback triggers → pending message
leaked into history → wrong context sent to model.
With session_msg_ceiling=1 (pre-drain count): effective_count=1, 1 > 1 is False
→ fallback does not trigger → current_message returned as-is.
"""
# session.messages after drain: [current_msg, pending_msg]
session = _make_session(
[
ChatMessage(role="user", content="What is 2 plus 2?"),
ChatMessage(role="user", content="What is 7 plus 7?"), # pending
]
)
result, was_compacted = await _build_query_message(
"What is 2 plus 2?\n\nWhat is 7 plus 7?",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
session_msg_ceiling=1, # pre-drain: only 1 message existed
)
# Should return current_message directly without wrapping in history context
assert result == "What is 2 plus 2?\n\nWhat is 7 plus 7?"
assert was_compacted is False
# Pending question must NOT appear in a spurious history section
assert "<conversation_history>" not in result
@pytest.mark.asyncio
async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
"""When compression actually compacts, was_compacted should be True."""

View File

@@ -1031,6 +1031,12 @@ def _make_sdk_patches(
),
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
(f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)),
# Stub pending-message drain so retry tests don't hit Redis.
# Returns an empty list → no mid-turn injection happens.
(
f"{_SVC}.drain_pending_messages",
dict(new_callable=AsyncMock, return_value=[]),
),
]

View File

@@ -34,6 +34,10 @@ from opentelemetry import trace as otel_trace
from pydantic import BaseModel
from backend.copilot.context import get_workspace_manager
from backend.copilot.pending_messages import (
drain_pending_messages,
format_pending_as_user_message,
)
from backend.copilot.permissions import apply_tool_permissions
from backend.copilot.rate_limit import get_user_tier
from backend.copilot.transcript import (
@@ -955,17 +959,33 @@ async def _build_query_message(
use_resume: bool,
transcript_msg_count: int,
session_id: str,
*,
session_msg_ceiling: int | None = None,
) -> tuple[str, bool]:
"""Build the query message with appropriate context.
Args:
session_msg_ceiling: If provided, treat ``session.messages`` as if it
only has this many entries when computing the gap slice. Pass
``len(session.messages)`` captured *before* appending any pending
messages so that mid-turn drains do not skew the gap calculation
and cause pending messages to be duplicated in both the gap context
and ``current_message``.
Returns:
Tuple of (query_message, was_compacted).
"""
msg_count = len(session.messages)
# Use the ceiling if supplied (prevents pending-message duplication when
# messages were appended to session.messages after the drain but before
# this function is called).
effective_count = (
session_msg_ceiling if session_msg_ceiling is not None else msg_count
)
if use_resume and transcript_msg_count > 0:
if transcript_msg_count < msg_count - 1:
gap = session.messages[transcript_msg_count:-1]
if transcript_msg_count < effective_count - 1:
gap = session.messages[transcript_msg_count : effective_count - 1]
compressed, was_compressed = await _compress_messages(gap)
gap_context = _format_conversation_context(compressed)
if gap_context:
@@ -981,12 +1001,14 @@ async def _build_query_message(
f"{gap_context}\n\nNow, the user says:\n{current_message}",
was_compressed,
)
elif not use_resume and msg_count > 1:
elif not use_resume and effective_count > 1:
logger.warning(
f"[SDK] Using compression fallback for session "
f"{session_id} ({msg_count} messages) — no transcript for --resume"
f"{session_id} ({effective_count} messages) — no transcript for --resume"
)
            compressed, was_compressed = await _compress_messages(session.messages[:-1])
            compressed, was_compressed = await _compress_messages(
                session.messages[: effective_count - 1]
            )
history_context = _format_conversation_context(compressed)
if history_context:
return (
@@ -2042,6 +2064,7 @@ async def stream_chat_completion_sdk(
async def _fetch_transcript():
"""Download transcript for --resume if applicable."""
assert session is not None # narrowed at line 1898
if not (
config.claude_agent_use_resume and user_id and len(session.messages) > 1
):
@@ -2288,6 +2311,69 @@ async def stream_chat_completion_sdk(
if last_user:
current_message = last_user[-1].content or ""
# Capture the message count *before* draining so _build_query_message
# can compute the gap slice without including the newly-drained pending
# messages. Pending messages are both appended to session.messages AND
# concatenated into current_message; without the ceiling the gap slice
# would extend into the pending messages and duplicate them in the
# model's input context (gap_context + current_message both containing
# them).
_pre_drain_msg_count = len(session.messages)
# Drain any messages the user queued via POST /messages/pending
# while the previous turn was running (or since the session was
# idle). Messages are drained ATOMICALLY — one LPOP with count
# removes them all at once, so a concurrent push lands *after*
# the drain and stays queued for the next turn instead of being
# lost between LPOP and clear. File IDs and context are
# preserved via format_pending_as_user_message.
#
# The drained content is concatenated into ``current_message``
# so the SDK CLI sees it in the new user message, AND appended
# directly to ``session.messages`` (no dedup — pending messages are
# atomically-popped from Redis and are never stale-cache duplicates)
# so the durable transcript records it too. Session is persisted
# immediately after the drain so a crash doesn't lose the messages.
# The endpoint deliberately does NOT persist to session.messages —
# Redis is the single source of truth until this drain runs.
try:
pending_at_start = await drain_pending_messages(session_id)
except Exception:
logger.warning(
"%s drain_pending_messages failed at turn start, skipping",
log_prefix,
exc_info=True,
)
pending_at_start = []
if pending_at_start:
logger.info(
"%s Draining %d pending message(s) at turn start",
log_prefix,
len(pending_at_start),
)
pending_texts: list[str] = [
format_pending_as_user_message(pm)["content"] for pm in pending_at_start
]
for pt in pending_texts:
# Append directly — pending messages are atomically-popped from
# Redis and are never stale-cache duplicates, so the
# maybe_append_user_message dedup is wrong here.
session.messages.append(ChatMessage(role="user", content=pt))
if current_message.strip():
current_message = current_message + "\n\n" + "\n\n".join(pending_texts)
else:
current_message = "\n\n".join(pending_texts)
# Persist immediately so a crash between here and the finally block
# doesn't lose messages that were already drained from Redis.
try:
session = await upsert_chat_session(session)
except Exception as _persist_err:
logger.warning(
"%s Failed to persist drained pending messages: %s",
log_prefix,
_persist_err,
)
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
@@ -2301,6 +2387,7 @@ async def stream_chat_completion_sdk(
use_resume,
transcript_msg_count,
session_id,
session_msg_ceiling=_pre_drain_msg_count,
)
# On the first turn inject user context into the message instead of the
# system prompt — the system prompt is now static (same for all users)
@@ -2438,6 +2525,7 @@ async def stream_chat_completion_sdk(
state.use_resume,
state.transcript_msg_count,
session_id,
session_msg_ceiling=_pre_drain_msg_count,
)
if attachments.hint:
state.query_message = f"{state.query_message}\n\n{attachments.hint}"
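For reference, the atomic drain described in the comments above reduces to a single LPOP with a count. The sketch below illustrates that pattern only; the key layout, cap value, and payload shape are assumptions rather than the actual backend.copilot.pending_messages implementation.

import json

from redis.asyncio import Redis

PENDING_KEY_PREFIX = "copilot:pending:messages:"  # assumed key layout
MAX_PENDING = 10  # illustrative stand-in for the real MAX_PENDING_MESSAGES


async def drain_pending_sketch(redis: Redis, session_id: str) -> list[dict]:
    # One LPOP with a count pops the whole buffer in a single Redis command,
    # so a concurrent push can only land before or after it and nothing is
    # lost between a pop and a separate delete.
    raw = await redis.lpop(f"{PENDING_KEY_PREFIX}{session_id}", MAX_PENDING)
    if not raw:
        return []
    return [json.loads(item) for item in raw]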
@@ -2767,6 +2855,11 @@ async def stream_chat_completion_sdk(
raise
finally:
# Pending messages are drained atomically at the start of each
# turn (see drain_pending_messages call above), so there's
# nothing to clean up here — any message pushed after that
# point belongs to the next turn.
# --- Close OTEL context (with cost attributes) ---
if _otel_ctx is not None:
try:

View File

@@ -1605,6 +1605,60 @@
}
}
},
"/api/chat/sessions/{session_id}/messages/pending": {
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Queue Pending Message",
"description": "Queue a new user message into an in-flight copilot turn.\n\nWhen a user sends a follow-up message while a turn is still\nstreaming, we don't want to block them or start a separate turn —\nthis endpoint appends the message to a per-session pending buffer.\nThe executor currently running the turn (baseline path) drains the\nbuffer between tool-call rounds and appends the message to the\nconversation before the next LLM call. On the SDK path the buffer\nis drained at the *start* of the next turn (the long-lived\n``ClaudeSDKClient.receive_response`` iterator returns after a\n``ResultMessage`` so there is no safe point to inject mid-stream\ninto an existing connection).\n\nReturns 202. Enforces the same per-user daily/weekly token rate\nlimit as the regular ``/stream`` endpoint so a client can't bypass\nit by batching messages through here.",
"operationId": "postV2QueuePendingMessage",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "session_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Session Id" }
}
],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/QueuePendingMessageRequest"
}
}
}
},
"responses": {
"202": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/QueuePendingMessageResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"404": { "description": "Session not found or access denied" },
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
},
"429": {
"description": "Token rate-limit or call-frequency cap exceeded"
}
}
}
},
"/api/chat/sessions/{session_id}/stream": {
"get": {
"tags": ["v2", "chat", "chat"],
@@ -12124,6 +12178,22 @@
"title": "PendingHumanReviewModel",
"description": "Response model for pending human review data.\n\nRepresents a human review request that is awaiting user action.\nContains all necessary information for a user to review and approve\nor reject data from a Human-in-the-Loop block execution.\n\nAttributes:\n id: Unique identifier for the review record\n user_id: ID of the user who must perform the review\n node_exec_id: ID of the node execution that created this review\n node_id: ID of the node definition (for grouping reviews from same node)\n graph_exec_id: ID of the graph execution containing the node\n graph_id: ID of the graph template being executed\n graph_version: Version number of the graph template\n payload: The actual data payload awaiting review\n instructions: Instructions or message for the reviewer\n editable: Whether the reviewer can edit the data\n status: Current review status (WAITING, APPROVED, or REJECTED)\n review_message: Optional message from the reviewer\n created_at: Timestamp when review was created\n updated_at: Timestamp when review was last modified\n reviewed_at: Timestamp when review was completed (if applicable)"
},
"PendingMessageContext": {
"properties": {
"url": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Url"
},
"content": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Content"
}
},
"additionalProperties": false,
"type": "object",
"title": "PendingMessageContext",
"description": "Structured page context attached to a pending message."
},
"PlatformCostDashboard": {
"properties": {
"by_provider": {
@@ -12668,6 +12738,53 @@
"required": ["providers", "pagination"],
"title": "ProviderResponse"
},
"QueuePendingMessageRequest": {
"properties": {
"message": {
"type": "string",
"maxLength": 16000,
"minLength": 1,
"title": "Message"
},
"context": {
"anyOf": [
{ "$ref": "#/components/schemas/PendingMessageContext" },
{ "type": "null" }
],
"description": "Optional page context with 'url' and 'content' fields."
},
"file_ids": {
"anyOf": [
{
"items": { "type": "string" },
"type": "array",
"maxItems": 20
},
{ "type": "null" }
],
"title": "File Ids"
}
},
"additionalProperties": false,
"type": "object",
"required": ["message"],
"title": "QueuePendingMessageRequest",
"description": "Request model for queueing a message into an in-flight turn.\n\nUnlike ``StreamChatRequest`` this endpoint does **not** start a new\nturn — the message is appended to a per-session pending buffer that\nthe executor currently processing the turn will drain between tool\nrounds."
},
"QueuePendingMessageResponse": {
"properties": {
"buffer_length": { "type": "integer", "title": "Buffer Length" },
"max_buffer_length": {
"type": "integer",
"title": "Max Buffer Length"
},
"turn_in_flight": { "type": "boolean", "title": "Turn In Flight" }
},
"type": "object",
"required": ["buffer_length", "max_buffer_length", "turn_in_flight"],
"title": "QueuePendingMessageResponse",
"description": "Response for the pending-message endpoint.\n\n- ``buffer_length``: how many messages are now in the session's\n pending buffer (after this push)\n- ``max_buffer_length``: the per-session cap (server-side constant)\n- ``turn_in_flight``: ``True`` if a copilot turn was running when\n we checked — purely informational for UX feedback. Even when\n ``False`` the message is still queued: the next turn drains it."
},
"RateLimitResetResponse": {
"properties": {
"success": { "type": "boolean", "title": "Success" },