mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Merge branch 'feat/copilot-pending-messages' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs
This commit is contained in:
@@ -4,7 +4,7 @@ import asyncio
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
@@ -29,6 +29,12 @@ from backend.copilot.model import (
|
||||
get_user_sessions,
|
||||
update_session_title,
|
||||
)
|
||||
from backend.copilot.pending_messages import (
|
||||
MAX_PENDING_MESSAGES,
|
||||
PendingMessage,
|
||||
PendingMessageContext,
|
||||
push_pending_message,
|
||||
)
|
||||
from backend.copilot.rate_limit import (
|
||||
CoPilotUsageStatus,
|
||||
RateLimitExceeded,
|
||||
# Case-insensitive matcher for canonical 8-4-4-4-12 UUID strings; used to
# pre-filter client-supplied file IDs before any DB lookup.
_UUID_RE = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)

# Call-frequency cap for the pending-message endpoint. The token-budget
# check in queue_pending_message guards against overspend, but does not
# prevent rapid-fire pushes from a client with a large budget. This cap
# (per user, per 60-second window) limits the rate a caller can hammer the
# endpoint independently of token consumption.
_PENDING_CALL_LIMIT = 30  # pushes per minute per user
_PENDING_CALL_WINDOW_SECONDS = 60
# Redis key namespace for the per-user call counters.
_PENDING_CALL_KEY_PREFIX = "copilot:pending:calls:"

# Maximum lengths for pending-message context fields (url: 2 KB, content: 32 KB).
# Enforced by QueuePendingMessageRequest._validate_context_length.
_CONTEXT_URL_MAX_LENGTH = 2_000
_CONTEXT_CONTENT_MAX_LENGTH = 32_000

# Lua script for atomic INCR + conditional EXPIRE.
# Using a single EVAL ensures the counter never persists without a TTL —
# a bare INCR followed by a separate EXPIRE can leave the key without
# an expiry if the process crashes between the two commands.
_CALL_INCR_LUA = """
local count = redis.call('INCR', KEYS[1])
if count == 1 then
    redis.call('EXPIRE', KEYS[1], tonumber(ARGV[1]))
end
return count
"""
|
||||
|
||||
|
||||
async def _validate_and_get_session(
|
||||
session_id: str,
|
||||
@@ -96,6 +128,29 @@ async def _validate_and_get_session(
|
||||
return session
|
||||
|
||||
|
||||
async def _resolve_workspace_files(
    user_id: str,
    file_ids: list[str],
) -> list[UserWorkspaceFile]:
    """Resolve *file_ids* to ``UserWorkspaceFile`` rows owned by *user_id*.

    Entries that are not syntactically valid UUIDs are silently dropped, and
    the survivors are looked up scoped to the caller's workspace so one user
    can never reference another user's files. Returns an empty list when
    nothing matches. Shared by the stream and pending-message endpoints.
    """
    # re.fullmatch returns None for non-UUIDs, so filter() drops them.
    candidates = list(filter(_UUID_RE.fullmatch, file_ids))
    if not candidates:
        return []
    workspace = await get_or_create_workspace(user_id)
    scope = {
        "id": {"in": candidates},
        "workspaceId": workspace.id,
        "isDeleted": False,
    }
    return await UserWorkspaceFile.prisma().find_many(where=scope)
|
||||
|
||||
|
||||
# FastAPI router grouping the chat/copilot endpoints under the "chat" tag.
router = APIRouter(
    tags=["chat"],
)
|
||||
@@ -119,6 +174,61 @@ class StreamChatRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class QueuePendingMessageRequest(BaseModel):
    """Request body for queueing a message into an in-flight turn.

    In contrast to ``StreamChatRequest`` this does **not** kick off a new
    turn: the message lands in a per-session pending buffer which the
    executor running the current turn drains between tool rounds.
    """

    model_config = ConfigDict(extra="forbid")

    message: str = Field(min_length=1, max_length=16_000)
    context: PendingMessageContext | None = Field(
        default=None,
        description="Optional page context with 'url' and 'content' fields.",
    )
    file_ids: list[str] | None = Field(default=None, max_length=20)

    @field_validator("context")
    @classmethod
    def _validate_context_length(
        cls, v: PendingMessageContext | None
    ) -> PendingMessageContext | None:
        # Guard against LLM context-window stuffing via oversized page
        # payloads. The caps live in module-level constants so they stay
        # visible to callers and documentation.
        if v is not None:
            for attr, cap in (
                ("url", _CONTEXT_URL_MAX_LENGTH),
                ("content", _CONTEXT_CONTENT_MAX_LENGTH),
            ):
                value = getattr(v, attr)
                if value and len(value) > cap:
                    raise ValueError(
                        f"context.{attr} exceeds maximum length of {cap} characters"
                    )
        return v
|
||||
|
||||
|
||||
class QueuePendingMessageResponse(BaseModel):
    """Payload returned by the pending-message endpoint.

    Attributes:
        buffer_length: Number of messages in the session's pending buffer
            after this push.
        max_buffer_length: Server-side per-session cap on the buffer.
        turn_in_flight: Whether a copilot turn was running when we checked.
            Purely informational for UX feedback — even when ``False`` the
            message is still queued and the next turn drains it.
    """

    buffer_length: int
    max_buffer_length: int
    turn_in_flight: bool
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""Request model for creating a new chat session.
|
||||
|
||||
@@ -786,33 +896,21 @@ async def stream_chat_post(
|
||||
# Also sanitise file_ids so only validated, workspace-scoped IDs are
|
||||
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
|
||||
sanitized_file_ids: list[str] | None = None
|
||||
if request.file_ids and user_id:
|
||||
# Filter to valid UUIDs only to prevent DB abuse
|
||||
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
|
||||
|
||||
if valid_ids:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Batch query instead of N+1
|
||||
files = await UserWorkspaceFile.prisma().find_many(
|
||||
where={
|
||||
"id": {"in": valid_ids},
|
||||
"workspaceId": workspace.id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
if request.file_ids:
|
||||
files = await _resolve_workspace_files(user_id, request.file_ids)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
# Only keep IDs that actually exist in the user's workspace
|
||||
sanitized_file_ids = [wf.id for wf in files] or None
|
||||
file_lines: list[str] = [
|
||||
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
|
||||
for wf in files
|
||||
]
|
||||
if file_lines:
|
||||
files_block = (
|
||||
"\n\n[Attached files]\n"
|
||||
+ "\n".join(file_lines)
|
||||
+ "\nUse read_workspace_file with the file_id to access file contents."
|
||||
)
|
||||
request.message += files_block
|
||||
request.message += files_block
|
||||
|
||||
# Atomically append user message to session BEFORE creating task to avoid
|
||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
||||
@@ -1012,6 +1110,135 @@ async def stream_chat_post(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/sessions/{session_id}/messages/pending",
    response_model=QueuePendingMessageResponse,
    status_code=202,
    responses={
        404: {"description": "Session not found or access denied"},
        429: {"description": "Token rate-limit or call-frequency cap exceeded"},
    },
)
async def queue_pending_message(
    session_id: str,
    request: QueuePendingMessageRequest,
    user_id: str = Security(auth.get_user_id),
):
    """Queue a new user message into an in-flight copilot turn.

    When a user sends a follow-up message while a turn is still
    streaming, we don't want to block them or start a separate turn —
    this endpoint appends the message to a per-session pending buffer.
    The executor currently running the turn (baseline path) drains the
    buffer between tool-call rounds and appends the message to the
    conversation before the next LLM call. On the SDK path the buffer
    is drained at the *start* of the next turn (the long-lived
    ``ClaudeSDKClient.receive_response`` iterator returns after a
    ``ResultMessage`` so there is no safe point to inject mid-stream
    into an existing connection).

    Returns 202. Enforces the same per-user daily/weekly token rate
    limit as the regular ``/stream`` endpoint so a client can't bypass
    it by batching messages through here.

    Args:
        session_id: Target chat session; must exist and belong to the caller.
        request: Validated message payload (message text, optional page
            context, optional workspace file IDs).
        user_id: Authenticated caller, injected by FastAPI security.

    Raises:
        HTTPException: 404 if the session check fails, 429 on token
            rate-limit or call-frequency cap.
    """
    # 404s here if the session is missing or owned by someone else.
    await _validate_and_get_session(session_id, user_id)

    # Pre-turn rate-limit check — mirrors stream_chat_post. Without
    # this, a client could bypass per-turn token limits by batching
    # their extra context through this endpoint while a cheap stream
    # is in flight.
    # user_id is guaranteed non-empty by Security(auth.get_user_id) — no guard needed.
    try:
        daily_limit, weekly_limit, _tier = await get_global_rate_limits(
            user_id, config.daily_token_limit, config.weekly_token_limit
        )
        await check_rate_limit(
            user_id=user_id,
            daily_token_limit=daily_limit,
            weekly_token_limit=weekly_limit,
        )
    except RateLimitExceeded as e:
        raise HTTPException(status_code=429, detail=str(e)) from e

    # Call-frequency cap: prevent rapid-fire pushes that would bypass the
    # token-budget check (which only fires per-turn, not per-push).
    # Uses an atomic Lua EVAL (INCR + EXPIRE) so the key can never be
    # orphaned without a TTL; fails open if Redis is down.
    try:
        _redis = await get_redis_async()
        _call_key = f"{_PENDING_CALL_KEY_PREFIX}{user_id}"
        _call_count = int(
            await cast(
                "Any",
                _redis.eval(
                    _CALL_INCR_LUA,
                    1,
                    _call_key,
                    str(_PENDING_CALL_WINDOW_SECONDS),
                ),
            )
        )
        if _call_count > _PENDING_CALL_LIMIT:
            raise HTTPException(
                status_code=429,
                detail=f"Too many pending messages: limit is {_PENDING_CALL_LIMIT} per {_PENDING_CALL_WINDOW_SECONDS}s",
            )
    except HTTPException:
        # Re-raise our own 429 before the broad fail-open handler below.
        raise
    except Exception:
        logger.warning(
            "queue_pending_message: rate-limit check failed, failing open"
        )  # non-fatal

    # Sanitise file IDs to the user's own workspace so injection doesn't
    # surface other users' files. _resolve_workspace_files handles UUID
    # filtering and the workspace-scoped DB lookup.
    sanitized_file_ids: list[str] = []
    if request.file_ids:
        # Count UUID-valid inputs separately so the log below reflects only
        # IDs dropped by the workspace check, not malformed ones.
        valid_id_count = sum(1 for fid in request.file_ids if _UUID_RE.fullmatch(fid))
        files = await _resolve_workspace_files(user_id, request.file_ids)
        sanitized_file_ids = [wf.id for wf in files]
        if len(sanitized_file_ids) != valid_id_count:
            logger.warning(
                "queue_pending_message: dropped %d file id(s) not in "
                "caller's workspace (session=%s)",
                valid_id_count - len(sanitized_file_ids),
                session_id,
            )

    # Redis is the single source of truth for pending messages. We do
    # NOT persist to ``session.messages`` here — the drain-at-start
    # path in the baseline/SDK executor is the sole writer for pending
    # content. Persisting both here AND in the drain would cause
    # double injection (executor sees the message in ``session.messages``
    # *and* drains it from Redis) unless we also dedupe. The dedup in
    # ``maybe_append_user_message`` only checks trailing same-role
    # repeats, so relying on it is fragile. Keeping the endpoint
    # Redis-only avoids the whole consistency-bug class.
    pending = PendingMessage(
        content=request.message,
        file_ids=sanitized_file_ids,
        context=request.context,
    )
    buffer_length = await push_pending_message(session_id, pending)

    # Analytics only — does not affect queueing.
    track_user_message(
        user_id=user_id,
        session_id=session_id,
        message_length=len(request.message),
    )

    # Check whether a turn is currently running for UX feedback.
    active_session = await stream_registry.get_session(session_id)
    turn_in_flight = bool(active_session and active_session.status == "running")

    return QueuePendingMessageResponse(
        buffer_length=buffer_length,
        max_buffer_length=MAX_PENDING_MESSAGES,
        turn_in_flight=turn_in_flight,
    )
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
|
||||
@@ -579,3 +579,300 @@ class TestStreamChatRequestModeValidation:
|
||||
|
||||
req = StreamChatRequest(message="hi")
|
||||
assert req.mode is None
|
||||
|
||||
|
||||
# ─── QueuePendingMessageRequest validation ────────────────────────────
|
||||
|
||||
|
||||
class TestQueuePendingMessageRequest:
    """Field-validation coverage for QueuePendingMessageRequest."""

    def test_accepts_valid_message(self) -> None:
        from backend.api.features.chat.routes import QueuePendingMessageRequest

        parsed = QueuePendingMessageRequest(message="hello")
        assert parsed.message == "hello"

    def test_rejects_empty_message(self) -> None:
        import pydantic

        from backend.api.features.chat.routes import QueuePendingMessageRequest

        with pytest.raises(pydantic.ValidationError):
            QueuePendingMessageRequest(message="")

    def test_rejects_message_over_limit(self) -> None:
        import pydantic

        from backend.api.features.chat.routes import QueuePendingMessageRequest

        too_long = "x" * 16_001
        with pytest.raises(pydantic.ValidationError):
            QueuePendingMessageRequest(message=too_long)

    def test_accepts_valid_context(self) -> None:
        from backend.api.features.chat.routes import QueuePendingMessageRequest

        parsed = QueuePendingMessageRequest(
            message="hi",
            context={"url": "https://example.com", "content": "page text"},
        )
        assert parsed.context is not None
        assert parsed.context.url == "https://example.com"

    def test_rejects_context_url_over_limit(self) -> None:
        import pydantic

        from backend.api.features.chat.routes import QueuePendingMessageRequest

        oversized = {"url": "https://example.com/" + "x" * 2_000}
        with pytest.raises(pydantic.ValidationError, match="url"):
            QueuePendingMessageRequest(message="hi", context=oversized)

    def test_rejects_context_content_over_limit(self) -> None:
        import pydantic

        from backend.api.features.chat.routes import QueuePendingMessageRequest

        oversized = {"content": "x" * 32_001}
        with pytest.raises(pydantic.ValidationError, match="content"):
            QueuePendingMessageRequest(message="hi", context=oversized)

    def test_rejects_extra_fields(self) -> None:
        """extra='forbid' must reject unknown fields."""
        import pydantic

        from backend.api.features.chat.routes import QueuePendingMessageRequest

        with pytest.raises(pydantic.ValidationError):
            QueuePendingMessageRequest(message="hi", unknown_field="bad")  # type: ignore[call-arg]

    def test_accepts_up_to_20_file_ids(self) -> None:
        from backend.api.features.chat.routes import QueuePendingMessageRequest

        ids = [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)]
        parsed = QueuePendingMessageRequest(message="hi", file_ids=ids)
        assert parsed.file_ids is not None
        assert len(parsed.file_ids) == 20

    def test_rejects_more_than_20_file_ids(self) -> None:
        import pydantic

        from backend.api.features.chat.routes import QueuePendingMessageRequest

        ids = [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)]
        with pytest.raises(pydantic.ValidationError):
            QueuePendingMessageRequest(message="hi", file_ids=ids)
|
||||
|
||||
|
||||
# ─── queue_pending_message endpoint ──────────────────────────────────
|
||||
|
||||
|
||||
def _mock_pending_internals(
    mocker: pytest_mock.MockerFixture,
    *,
    session_exists: bool = True,
    call_count: int = 1,
):
    """Stub every async collaborator the pending-message endpoint touches.

    ``session_exists`` toggles the 404 path; ``call_count`` is what the
    mocked Redis EVAL reports, letting tests drive the frequency cap.
    """
    if session_exists:
        session_stub = mocker.MagicMock()
        session_stub.id = "sess-1"
        mocker.patch(
            "backend.api.features.chat.routes._validate_and_get_session",
            new_callable=AsyncMock,
            return_value=session_stub,
        )
    else:
        mocker.patch(
            "backend.api.features.chat.routes._validate_and_get_session",
            side_effect=fastapi.HTTPException(
                status_code=404, detail="Session not found."
            ),
        )
    # Token rate limiting: zero limits, and the check itself passes.
    mocker.patch(
        "backend.api.features.chat.routes.get_global_rate_limits",
        new_callable=AsyncMock,
        return_value=(0, 0, None),
    )
    mocker.patch(
        "backend.api.features.chat.routes.check_rate_limit",
        new_callable=AsyncMock,
        return_value=None,
    )
    # Per-user call-frequency limiter (atomic Lua EVAL against Redis).
    redis_stub = mocker.MagicMock()
    redis_stub.eval = mocker.AsyncMock(return_value=call_count)
    mocker.patch(
        "backend.api.features.chat.routes.get_redis_async",
        new_callable=AsyncMock,
        return_value=redis_stub,
    )
    mocker.patch(
        "backend.api.features.chat.routes.track_user_message",
        return_value=None,
    )
    mocker.patch(
        "backend.api.features.chat.routes.push_pending_message",
        new_callable=AsyncMock,
        return_value=1,
    )
    # No active turn by default — turn_in_flight should come back False.
    registry_stub = mocker.MagicMock()
    registry_stub.get_session = mocker.AsyncMock(return_value=None)
    mocker.patch(
        "backend.api.features.chat.routes.stream_registry",
        registry_stub,
    )
|
||||
|
||||
|
||||
def test_queue_pending_message_returns_202(mocker: pytest_mock.MockerFixture) -> None:
    """Happy path: a valid message is accepted with 202 and buffer stats."""
    _mock_pending_internals(mocker)

    resp = client.post(
        "/sessions/sess-1/messages/pending",
        json={"message": "follow-up"},
    )

    assert resp.status_code == 202
    body = resp.json()
    assert body["buffer_length"] == 1
    assert body["turn_in_flight"] is False
|
||||
|
||||
|
||||
def test_queue_pending_message_empty_body_returns_422() -> None:
    """An empty message never reaches route logic — Pydantic rejects it."""
    resp = client.post(
        "/sessions/sess-1/messages/pending",
        json={"message": ""},
    )

    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_queue_pending_message_missing_message_returns_422() -> None:
    """Omitting the required 'message' field yields a 422."""
    resp = client.post(
        "/sessions/sess-1/messages/pending",
        json={},
    )

    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_queue_pending_message_session_not_found_returns_404(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """A missing session — or one owned by another user — yields 404."""
    _mock_pending_internals(mocker, session_exists=False)

    resp = client.post(
        "/sessions/bad-sess/messages/pending",
        json={"message": "hi"},
    )

    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_queue_pending_message_rate_limited_returns_429(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Exceeding the token rate limit surfaces as a 429."""
    from backend.copilot.rate_limit import RateLimitExceeded

    _mock_pending_internals(mocker)
    # Override the passing stub with one that raises the limit error.
    mocker.patch(
        "backend.api.features.chat.routes.check_rate_limit",
        side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)),
    )

    resp = client.post(
        "/sessions/sess-1/messages/pending",
        json={"message": "hi"},
    )

    assert resp.status_code == 429
|
||||
|
||||
|
||||
def test_queue_pending_message_call_frequency_limit_returns_429(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """Exceeding the per-user call-frequency cap surfaces as a 429."""
    from backend.api.features.chat.routes import _PENDING_CALL_LIMIT

    _mock_pending_internals(mocker, call_count=_PENDING_CALL_LIMIT + 1)

    resp = client.post(
        "/sessions/sess-1/messages/pending",
        json={"message": "hi"},
    )

    assert resp.status_code == 429
    assert "Too many pending messages" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_queue_pending_message_context_url_too_long_returns_422() -> None:
    """A context.url beyond the 2 KB cap is rejected with 422."""
    payload = {
        "message": "hi",
        "context": {"url": "https://example.com/" + "x" * 2_000},
    }
    resp = client.post("/sessions/sess-1/messages/pending", json=payload)

    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_queue_pending_message_context_content_too_long_returns_422() -> None:
    """A context.content beyond the 32 KB cap is rejected with 422."""
    payload = {
        "message": "hi",
        "context": {"content": "x" * 32_001},
    }
    resp = client.post("/sessions/sess-1/messages/pending", json=payload)

    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_queue_pending_message_too_many_file_ids_returns_422() -> None:
    """Payloads carrying more than 20 file_ids are rejected with 422."""
    payload = {
        "message": "hi",
        "file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
    }
    resp = client.post("/sessions/sess-1/messages/pending", json=payload)

    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_queue_pending_message_file_ids_scoped_to_workspace(
    mocker: pytest_mock.MockerFixture,
) -> None:
    """File IDs are UUID-filtered and workspace-scoped before the push."""
    _mock_pending_internals(mocker)
    mocker.patch(
        "backend.api.features.chat.routes.get_or_create_workspace",
        new_callable=AsyncMock,
        return_value=type("W", (), {"id": "ws-1"})(),
    )
    prisma_stub = mocker.MagicMock()
    prisma_stub.find_many = mocker.AsyncMock(return_value=[])
    mocker.patch(
        "prisma.models.UserWorkspaceFile.prisma",
        return_value=prisma_stub,
    )
    fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

    client.post(
        "/sessions/sess-1/messages/pending",
        json={"message": "hi", "file_ids": [fid, "not-a-uuid"]},
    )

    # The non-UUID entry must be gone and the query scoped to ws-1.
    where = prisma_stub.find_many.call_args[1]["where"]
    assert where["id"]["in"] == [fid]
    assert where["workspaceId"] == "ws-1"
    assert where["isDeleted"] is False
|
||||
|
||||
@@ -36,6 +36,10 @@ from backend.copilot.model import (
|
||||
maybe_append_user_message,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from backend.copilot.pending_messages import (
|
||||
drain_pending_messages,
|
||||
format_pending_as_user_message,
|
||||
)
|
||||
from backend.copilot.prompting import get_baseline_supplement, get_graphiti_supplement
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
@@ -341,6 +345,11 @@ class _BaselineStreamState:
|
||||
cost_usd: float | None = None
|
||||
thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper)
|
||||
session_messages: list[ChatMessage] = field(default_factory=list)
|
||||
# Tracks how much of ``assistant_text`` has already been flushed to
|
||||
# ``session.messages`` via mid-loop pending drains, so the ``finally``
|
||||
# block only appends the *new* assistant text (avoiding duplication of
|
||||
# round-1 text when round-1 entries were cleared from session_messages).
|
||||
_flushed_assistant_text_len: int = 0
|
||||
|
||||
|
||||
async def _baseline_llm_caller(
|
||||
@@ -930,7 +939,54 @@ async def stream_chat_completion_baseline(
|
||||
message_length=len(message or ""),
|
||||
)
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
# Capture count *before* the pending drain so is_first_turn and the
|
||||
# transcript staleness check are not skewed by queued messages.
|
||||
_pre_drain_msg_count = len(session.messages)
|
||||
|
||||
# Drain any messages the user queued via POST /messages/pending
|
||||
# while this session was idle (or during a previous turn whose
|
||||
# mid-loop drains missed them). Atomic LPOP guarantees that a
|
||||
# concurrent push lands *after* the drain and stays queued for the
|
||||
# next turn instead of being lost.
|
||||
try:
|
||||
drained_at_start = await drain_pending_messages(session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[Baseline] drain_pending_messages failed at turn start, skipping",
|
||||
exc_info=True,
|
||||
)
|
||||
drained_at_start = []
|
||||
# Pre-compute formatted content once per message so we don't call
|
||||
# format_pending_as_user_message twice (once for session.messages and
|
||||
# once for transcript_builder below).
|
||||
drained_at_start_content: list[str] = []
|
||||
if drained_at_start:
|
||||
logger.info(
|
||||
"[Baseline] Draining %d pending message(s) at turn start for session %s",
|
||||
len(drained_at_start),
|
||||
session_id,
|
||||
)
|
||||
for pm in drained_at_start:
|
||||
content = format_pending_as_user_message(pm)["content"]
|
||||
drained_at_start_content.append(content)
|
||||
# Append directly — pending messages are atomically-popped from
|
||||
# Redis and are never stale-cache duplicates, so the
|
||||
# maybe_append_user_message dedup is wrong here.
|
||||
session.messages.append(ChatMessage(role="user", content=content))
|
||||
|
||||
# Persist the drained pending messages (if any) plus the current user
|
||||
# message. Wrap in try/except so a transient DB failure here does not
|
||||
# silently discard messages that were already popped from Redis — the
|
||||
# turn can still proceed using the in-memory session.messages, and a
|
||||
# later resume/replay will backfill from the DB on the next turn.
|
||||
try:
|
||||
session = await upsert_chat_session(session)
|
||||
except Exception as _persist_err:
|
||||
logger.warning(
|
||||
"[Baseline] Failed to persist session at turn start "
|
||||
"(pending drain may not be durable): %s",
|
||||
_persist_err,
|
||||
)
|
||||
|
||||
# Select model based on the per-request mode. 'fast' downgrades to
|
||||
# the cheaper/faster model; everything else keeps the default.
|
||||
@@ -959,7 +1015,9 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
# Use the pre-drain count so queued pending messages don't incorrectly
|
||||
# flip is_first_turn to False on an actual first turn.
|
||||
is_first_turn = _pre_drain_msg_count <= 1
|
||||
# Gate context fetch on both first turn AND user message so that assistant-
|
||||
# role calls (e.g. tool-result submissions) on the first turn don't trigger
|
||||
# a needless DB lookup for user understanding.
|
||||
@@ -970,14 +1028,18 @@ async def stream_chat_completion_baseline(
|
||||
prompt_task = _build_cacheable_system_prompt(None)
|
||||
|
||||
# Run download + prompt build concurrently — both are independent I/O
|
||||
# on the request critical path.
|
||||
if user_id and len(session.messages) > 1:
|
||||
# on the request critical path. Use the pre-drain count so pending
|
||||
# messages drained at turn start don't spuriously trigger a transcript
|
||||
# load on an actual first turn.
|
||||
if user_id and _pre_drain_msg_count > 1:
|
||||
transcript_covers_prefix, (base_system_prompt, understanding) = (
|
||||
await asyncio.gather(
|
||||
_load_prior_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
session_msg_count=len(session.messages),
|
||||
# Use pre-drain count so pending messages don't falsely
|
||||
# mark the stored transcript as stale and prevent upload.
|
||||
session_msg_count=_pre_drain_msg_count,
|
||||
transcript_builder=transcript_builder,
|
||||
),
|
||||
prompt_task,
|
||||
@@ -989,6 +1051,15 @@ async def stream_chat_completion_baseline(
|
||||
# Append user message to transcript after context injection below so the
|
||||
# transcript receives the prefixed message when user context is available.
|
||||
|
||||
# Mirror any messages drained at turn start (see above) into the
|
||||
# transcript — otherwise the loaded prior transcript would be
|
||||
# missing them and a mid-turn upload could leave a malformed
|
||||
# assistant-after-assistant structure on the next turn.
|
||||
# Reuse the pre-computed content strings to avoid calling
|
||||
# format_pending_as_user_message a second time.
|
||||
for _drained_content in drained_at_start_content:
|
||||
transcript_builder.append_user(content=_drained_content)
|
||||
|
||||
# Generate title for new sessions
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
@@ -1009,8 +1080,10 @@ async def stream_chat_completion_baseline(
|
||||
graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else ""
|
||||
system_prompt = base_system_prompt + get_baseline_supplement() + graphiti_supplement
|
||||
|
||||
# Warm context: pre-load relevant facts from Graphiti on first turn
|
||||
if graphiti_enabled and user_id and len(session.messages) <= 1:
|
||||
# Warm context: pre-load relevant facts from Graphiti on first turn.
|
||||
# Use the pre-drain count so pending messages drained at turn start
|
||||
# don't prevent warm context injection on an actual first turn.
|
||||
if graphiti_enabled and user_id and _pre_drain_msg_count <= 1:
|
||||
from backend.copilot.graphiti.context import fetch_warm_context
|
||||
|
||||
warm_ctx = await fetch_warm_context(user_id, message or "")
|
||||
@@ -1203,6 +1276,91 @@ async def stream_chat_completion_baseline(
|
||||
yield evt
|
||||
state.pending_events.clear()
|
||||
|
||||
# Inject any messages the user queued while the turn was
|
||||
# running. ``tool_call_loop`` mutates ``openai_messages``
|
||||
# in-place, so appending here means the model sees the new
|
||||
# messages on its next LLM call.
|
||||
#
|
||||
# IMPORTANT: skip when the loop has already finished (no
|
||||
# more LLM calls are coming). ``tool_call_loop`` yields
|
||||
# a final ``ToolCallLoopResult`` on both paths:
|
||||
# - natural finish: ``finished_naturally=True``
|
||||
# - hit max_iterations: ``finished_naturally=False``
|
||||
# and ``iterations >= max_iterations``
|
||||
# In either case the loop is about to return on the next
|
||||
# ``async for`` step, so draining here would silently
|
||||
# lose the message (the user sees 202 but the model never
|
||||
# reads the text). Those messages stay in the buffer and
|
||||
# get picked up at the start of the next turn.
|
||||
if loop_result is None:
|
||||
continue
|
||||
is_final_yield = (
|
||||
loop_result.finished_naturally
|
||||
or loop_result.iterations >= _MAX_TOOL_ROUNDS
|
||||
)
|
||||
if is_final_yield:
|
||||
continue
|
||||
try:
|
||||
pending = await drain_pending_messages(session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Mid-loop drain_pending_messages failed for session %s",
|
||||
session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
pending = []
|
||||
if pending:
|
||||
# Flush any buffered assistant/tool messages from completed
|
||||
# rounds into session.messages BEFORE appending the pending
|
||||
# user message. ``_baseline_conversation_updater`` only
|
||||
# records assistant+tool rounds into ``state.session_messages``
|
||||
# — they are normally batch-flushed in the finally block.
|
||||
# Without this in-order flush, the mid-loop pending user
|
||||
# message lands before the preceding round's assistant/tool
|
||||
# entries, producing chronologically-wrong session.messages
|
||||
# on persist (user interposed between an assistant tool_call
|
||||
# and its tool-result), which breaks OpenAI tool-call ordering
|
||||
# invariants on the next turn's replay.
|
||||
for _buffered in state.session_messages:
|
||||
session.messages.append(_buffered)
|
||||
state.session_messages.clear()
|
||||
# Record how much assistant_text has been covered by the
|
||||
# structured entries just flushed, so the finally block's
|
||||
# final-text dedup doesn't re-append rounds already persisted.
|
||||
state._flushed_assistant_text_len = len(state.assistant_text)
|
||||
|
||||
for pm in pending:
|
||||
# ``format_pending_as_user_message`` embeds file
|
||||
# attachments and context URL/page content into the
|
||||
# content string so the in-session transcript is
|
||||
# a faithful copy of what the model actually saw.
|
||||
formatted = format_pending_as_user_message(pm)
|
||||
content_for_db = formatted["content"]
|
||||
# Append directly — pending messages are atomically-popped
|
||||
# from Redis and are never stale-cache duplicates, so the
|
||||
# maybe_append_user_message dedup is wrong here and would
|
||||
# cause openai_messages/transcript to diverge from session.
|
||||
session.messages.append(
|
||||
ChatMessage(role="user", content=content_for_db)
|
||||
)
|
||||
openai_messages.append(formatted)
|
||||
transcript_builder.append_user(content=content_for_db)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as persist_err:
|
||||
logger.warning(
|
||||
"[Baseline] Failed to persist pending messages for "
|
||||
"session %s: %s",
|
||||
session_id,
|
||||
persist_err,
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] Injected %d pending message(s) into "
|
||||
"session %s mid-turn",
|
||||
len(pending),
|
||||
session_id,
|
||||
)
|
||||
|
||||
if loop_result and not loop_result.finished_naturally:
|
||||
limit_msg = (
|
||||
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
|
||||
@@ -1243,6 +1401,11 @@ async def stream_chat_completion_baseline(
|
||||
yield StreamError(errorText=error_msg, code="baseline_error")
|
||||
# Still persist whatever we got
|
||||
finally:
|
||||
# Pending messages are drained atomically at turn start and
|
||||
# between tool rounds, so there's nothing to clear in finally.
|
||||
# Any message pushed after the final drain window stays in the
|
||||
# buffer and gets picked up at the start of the next turn.
|
||||
|
||||
# Set cost attributes on OTEL span before closing
|
||||
if _trace_ctx is not None:
|
||||
try:
|
||||
@@ -1312,7 +1475,11 @@ async def stream_chat_completion_baseline(
|
||||
# no tool calls, i.e. the natural finish). Only add it if the
|
||||
# conversation updater didn't already record it as part of a
|
||||
# tool-call round (which would have empty response_text).
|
||||
final_text = state.assistant_text
|
||||
# Only consider assistant text produced AFTER the last mid-loop
|
||||
# flush. ``_flushed_assistant_text_len`` tracks the prefix already
|
||||
# persisted via structured session_messages during mid-loop pending
|
||||
# drains; including it here would duplicate those rounds.
|
||||
final_text = state.assistant_text[state._flushed_assistant_text_len :]
|
||||
if state.session_messages:
|
||||
# Strip text already captured in tool-call round messages
|
||||
recorded = "".join(
|
||||
|
||||
@@ -828,3 +828,204 @@ class TestBaselineCostExtraction:
|
||||
|
||||
# response was never assigned so cost extraction must not raise
|
||||
assert state.cost_usd is None
|
||||
|
||||
|
||||
class TestMidLoopPendingFlushOrdering:
    """Regression test for the mid-loop pending drain ordering invariant.

    ``_baseline_conversation_updater`` records assistant+tool entries from
    each tool-call round into ``state.session_messages``; the finally block
    of ``stream_chat_completion_baseline`` batch-flushes them into
    ``session.messages`` at the end of the turn.

    The mid-loop pending drain appends pending user messages directly to
    ``session.messages``. Without flushing ``state.session_messages`` first,
    the pending user message lands BEFORE the preceding round's assistant+
    tool entries in the final persisted ``session.messages`` — which
    produces a malformed tool-call/tool-result ordering on the next turn's
    replay.

    This test documents the invariant by replaying the production flush
    sequence against an in-memory state.
    """

    def test_flush_then_append_preserves_chronological_order(self):
        """Mid-loop drain must flush state.session_messages before appending
        the pending user message, so the final order matches the
        chronological execution order.
        """
        # Initial state: user turn already appended by maybe_append_user_message
        session_messages: list[ChatMessage] = [
            ChatMessage(role="user", content="original user turn"),
        ]
        state = _BaselineStreamState()

        # Round 1 completes: conversation_updater buffers assistant+tool
        # entries into state.session_messages (but does NOT write to
        # session.messages yet).
        builder = TranscriptBuilder()
        builder.append_user("original user turn")
        response = LLMLoopResponse(
            response_text="calling search",
            tool_calls=[LLMToolCall(id="tc_1", name="search", arguments="{}")],
            raw_response=None,
            prompt_tokens=0,
            completion_tokens=0,
        )
        tool_results = [
            ToolCallResult(
                tool_call_id="tc_1", tool_name="search", content="search output"
            ),
        ]
        openai_messages: list = []
        _baseline_conversation_updater(
            openai_messages,
            response,
            tool_results=tool_results,
            transcript_builder=builder,
            state=state,
            model="test-model",
        )
        # state.session_messages should now hold the round-1 assistant + tool
        assert len(state.session_messages) == 2
        assert state.session_messages[0].role == "assistant"
        assert state.session_messages[1].role == "tool"

        # --- Mid-loop pending drain (production code pattern) ---
        # Flush first, THEN append pending. This is the ordering fix.
        for _buffered in state.session_messages:
            session_messages.append(_buffered)
        state.session_messages.clear()
        session_messages.append(
            ChatMessage(role="user", content="pending mid-loop message")
        )

        # Round 2 completes: new assistant+tool entries buffer again.
        response2 = LLMLoopResponse(
            response_text="another call",
            tool_calls=[LLMToolCall(id="tc_2", name="calc", arguments="{}")],
            raw_response=None,
            prompt_tokens=0,
            completion_tokens=0,
        )
        tool_results2 = [
            ToolCallResult(
                tool_call_id="tc_2", tool_name="calc", content="calc output"
            ),
        ]
        _baseline_conversation_updater(
            openai_messages,
            response2,
            tool_results=tool_results2,
            transcript_builder=builder,
            state=state,
            model="test-model",
        )

        # --- Finally-block flush (end of turn) ---
        # Mirrors the production finally block: remaining buffered entries
        # are appended after everything flushed mid-loop.
        for msg in state.session_messages:
            session_messages.append(msg)

        # Assert chronological order: original user, round-1 assistant,
        # round-1 tool, pending user, round-2 assistant, round-2 tool.
        assert [m.role for m in session_messages] == [
            "user",
            "assistant",
            "tool",
            "user",
            "assistant",
            "tool",
        ]
        assert session_messages[0].content == "original user turn"
        assert session_messages[3].content == "pending mid-loop message"
        # The assistant message carrying tool_call tc_1 must be immediately
        # followed by its tool result — no user message interposed.
        assert session_messages[1].role == "assistant"
        assert session_messages[1].tool_calls is not None
        assert session_messages[1].tool_calls[0]["id"] == "tc_1"
        assert session_messages[2].role == "tool"
        assert session_messages[2].tool_call_id == "tc_1"
        # Same invariant for the round after the pending user.
        assert session_messages[4].tool_calls is not None
        assert session_messages[4].tool_calls[0]["id"] == "tc_2"
        assert session_messages[5].tool_call_id == "tc_2"

    def test_flushed_assistant_text_len_prevents_duplicate_final_text(self):
        """After mid-loop drain clears state.session_messages, the finally
        block must not re-append assistant text from rounds already flushed.

        ``state.assistant_text`` accumulates ALL rounds' text, but
        ``state.session_messages`` only holds entries from rounds AFTER the
        last mid-loop flush. Without ``_flushed_assistant_text_len``, the
        ``finally`` block's ``startswith(recorded)`` check fails because
        ``recorded`` only covers post-flush rounds, and the full
        ``assistant_text`` is appended — duplicating pre-flush rounds.
        """
        state = _BaselineStreamState()
        session_messages: list[ChatMessage] = [
            ChatMessage(role="user", content="user turn"),
        ]

        # Simulate round 1 text accumulation (as _bound_llm_caller does)
        state.assistant_text += "calling search"

        # Round 1 conversation_updater buffers structured entries
        builder = TranscriptBuilder()
        builder.append_user("user turn")
        response1 = LLMLoopResponse(
            response_text="calling search",
            tool_calls=[LLMToolCall(id="tc_1", name="search", arguments="{}")],
            raw_response=None,
            prompt_tokens=0,
            completion_tokens=0,
        )
        _baseline_conversation_updater(
            [],
            response1,
            tool_results=[
                ToolCallResult(
                    tool_call_id="tc_1", tool_name="search", content="result"
                )
            ],
            transcript_builder=builder,
            state=state,
            model="test-model",
        )

        # Mid-loop drain: flush + clear + record flushed text length
        for _buffered in state.session_messages:
            session_messages.append(_buffered)
        state.session_messages.clear()
        state._flushed_assistant_text_len = len(state.assistant_text)
        session_messages.append(ChatMessage(role="user", content="pending message"))

        # Simulate round 2 text accumulation
        state.assistant_text += "final answer"

        # Round 2: natural finish (no tool calls → no session_messages entry)

        # --- Finally block logic (production code) ---
        for msg in state.session_messages:
            session_messages.append(msg)

        # Only text produced after the last mid-loop flush is considered.
        final_text = state.assistant_text[state._flushed_assistant_text_len :]
        if state.session_messages:
            recorded = "".join(
                m.content or "" for m in state.session_messages if m.role == "assistant"
            )
            if final_text.startswith(recorded):
                final_text = final_text[len(recorded) :]
        if final_text.strip():
            session_messages.append(ChatMessage(role="assistant", content=final_text))

        # The final assistant message should only contain round-2 text,
        # not the round-1 text that was already flushed mid-loop.
        assistant_msgs = [m for m in session_messages if m.role == "assistant"]
        # Round-1 structured assistant (from mid-loop flush)
        assert assistant_msgs[0].content == "calling search"
        assert assistant_msgs[0].tool_calls is not None
        # Round-2 final text (from finally block)
        assert assistant_msgs[1].content == "final answer"
        assert assistant_msgs[1].tool_calls is None
        # Crucially: only 2 assistant messages, not 3 (no duplicate)
        assert len(assistant_msgs) == 2
|
||||
|
||||
222
autogpt_platform/backend/backend/copilot/pending_messages.py
Normal file
222
autogpt_platform/backend/backend/copilot/pending_messages.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Pending-message buffer for in-flight copilot turns.
|
||||
|
||||
When a user sends a new message while a copilot turn is already executing,
|
||||
instead of blocking the frontend (or queueing a brand-new turn after the
|
||||
current one finishes), we want the new message to be *injected into the
|
||||
running turn* — appended between tool-call rounds so the model sees it
|
||||
before its next LLM call.
|
||||
|
||||
This module provides the cross-process buffer that makes that possible:
|
||||
|
||||
- **Producer** (chat API route): pushes a pending message to Redis and
|
||||
publishes a notification on a pub/sub channel.
|
||||
- **Consumer** (executor running the turn): on each tool-call round,
|
||||
drains the buffer and appends the pending messages to the conversation.
|
||||
|
||||
The Redis list is the durable store; the pub/sub channel is a fast
|
||||
wake-up hint for long-idle consumers (not used by default, but available
|
||||
for future blocking-wait semantics).
|
||||
|
||||
A hard cap of ``MAX_PENDING_MESSAGES`` per session prevents abuse. The
|
||||
buffer is trimmed to the latest ``MAX_PENDING_MESSAGES`` on every push.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-session cap. Higher values risk a runaway consumer; lower values
|
||||
# risk dropping user input under heavy typing. 10 was chosen as a
|
||||
# reasonable ceiling — a user typing faster than the copilot can drain
|
||||
# between tool rounds is already an unusual usage pattern.
|
||||
MAX_PENDING_MESSAGES = 10
|
||||
|
||||
# Redis key + TTL. The buffer is ephemeral: if a turn completes or the
|
||||
# executor dies, the pending messages should either have been drained
|
||||
# already or are safe to drop (the user can resend).
|
||||
_PENDING_KEY_PREFIX = "copilot:pending:"
|
||||
_PENDING_CHANNEL_PREFIX = "copilot:pending:notify:"
|
||||
_PENDING_TTL_SECONDS = 3600 # 1 hour — matches stream_ttl default
|
||||
|
||||
# Payload sent on the pub/sub notify channel. Subscribers treat any
|
||||
# message as a wake-up hint; the value itself is not meaningful.
|
||||
_NOTIFY_PAYLOAD = "1"
|
||||
|
||||
|
||||
class PendingMessageContext(BaseModel, extra="forbid"):
    """Structured page context attached to a pending message.

    ``extra="forbid"`` rejects unknown keys so malformed client payloads
    fail validation instead of being silently accepted.
    """

    # Page URL the user was viewing when the message was queued.
    url: str | None = None
    # Extracted page text; embedded into the model-visible content by
    # format_pending_as_user_message.
    content: str | None = None
|
||||
|
||||
|
||||
class PendingMessage(BaseModel):
    """A user message queued for injection into an in-flight turn."""

    # User-typed text; bounds reject empty and oversized payloads.
    content: str = Field(min_length=1, max_length=16_000)
    # Workspace file ids referenced by the message (may be empty).
    file_ids: list[str] = Field(default_factory=list)
    # Optional structured page context (URL + extracted content).
    context: PendingMessageContext | None = None
|
||||
|
||||
|
||||
def _buffer_key(session_id: str) -> str:
    """Redis list key that holds *session_id*'s pending-message buffer."""
    return _PENDING_KEY_PREFIX + session_id
|
||||
|
||||
|
||||
def _notify_channel(session_id: str) -> str:
    """Pub/sub channel on which a push for *session_id* is announced."""
    return _PENDING_CHANNEL_PREFIX + session_id
|
||||
|
||||
|
||||
# Lua script: push-then-trim-then-expire-then-length, atomically.
|
||||
# Redis serializes EVAL commands, so a concurrent ``LPOP`` drain
|
||||
# observes either the pre-push or post-push state of the list — never
|
||||
# a partial state where the RPUSH has landed but LTRIM hasn't run.
|
||||
_PUSH_LUA = """
|
||||
redis.call('RPUSH', KEYS[1], ARGV[1])
|
||||
redis.call('LTRIM', KEYS[1], -tonumber(ARGV[2]), -1)
|
||||
redis.call('EXPIRE', KEYS[1], tonumber(ARGV[3]))
|
||||
return redis.call('LLEN', KEYS[1])
|
||||
"""
|
||||
|
||||
|
||||
async def push_pending_message(
    session_id: str,
    message: PendingMessage,
) -> int:
    """Append a pending message to the session's buffer atomically.

    Returns the new buffer length. Enforces ``MAX_PENDING_MESSAGES`` by
    trimming from the left (oldest) — the newest message always wins if
    the user has been typing faster than the copilot can drain.

    The push + trim + expire + llen are wrapped in a single Lua EVAL so
    concurrent LPOP drains from the executor never observe a partial
    state.
    """
    client = await get_redis_async()
    serialized = message.model_dump_json()

    # Single atomic round trip: RPUSH + LTRIM + EXPIRE + LLEN (_PUSH_LUA).
    # The awaitable is cast because redis-py's eval overloads collapse the
    # return type under strict type checking.
    eval_awaitable = cast(
        "Any",
        client.eval(
            _PUSH_LUA,
            1,
            _buffer_key(session_id),
            serialized,
            str(MAX_PENDING_MESSAGES),
            str(_PENDING_TTL_SECONDS),
        ),
    )
    buffer_len = int(await eval_awaitable)

    # Fire-and-forget notify. Subscribers use this as a wake-up hint;
    # the buffer itself is authoritative so a lost notify is harmless.
    try:
        await client.publish(_notify_channel(session_id), _NOTIFY_PAYLOAD)
    except Exception as e:  # pragma: no cover
        logger.warning("pending_messages: publish failed for %s: %s", session_id, e)

    logger.info(
        "pending_messages: pushed message to session=%s (buffer_len=%d)",
        session_id,
        buffer_len,
    )
    return buffer_len
|
||||
|
||||
|
||||
async def drain_pending_messages(session_id: str) -> list[PendingMessage]:
    """Atomically pop all pending messages for *session_id*.

    Returns them in enqueue order (oldest first). Uses ``LPOP`` with a
    count so the read+delete is a single Redis round trip. If the list
    is empty or missing, returns ``[]``.
    """
    client = await get_redis_async()

    # LPOP with a count (Redis 6.2+) yields None for a missing key, an
    # empty list if we race an empty key, or the popped items. The
    # awaitable is cast because redis-py's async lpop overload with a
    # count collapses the return type in pyright; this keeps strict
    # type-checking clean without changing runtime behaviour.
    popped = await cast(
        "Any",
        client.lpop(_buffer_key(session_id), MAX_PENDING_MESSAGES),
    )
    if not popped:
        return []

    parsed: list[PendingMessage] = []
    for raw in popped:
        # redis-py may hand back bytes or str depending on decode_responses.
        text = raw.decode("utf-8") if isinstance(raw, bytes) else str(raw)
        try:
            parsed.append(PendingMessage(**json.loads(text)))
        except (json.JSONDecodeError, ValidationError, TypeError, ValueError) as e:
            # Malformed entries are dropped, not fatal — the rest of the
            # buffer still drains.
            logger.warning(
                "pending_messages: dropping malformed entry for %s: %s",
                session_id,
                e,
            )

    if parsed:
        logger.info(
            "pending_messages: drained %d messages for session=%s",
            len(parsed),
            session_id,
        )
    return parsed
|
||||
|
||||
|
||||
async def peek_pending_count(session_id: str) -> int:
    """Return the current buffer length without consuming it."""
    client = await get_redis_async()
    return int(await cast("Any", client.llen(_buffer_key(session_id))))
|
||||
|
||||
|
||||
async def clear_pending_messages(session_id: str) -> None:
    """Drop the session's pending buffer.

    Not part of the normal turn flow — the atomic ``LPOP`` drain at
    turn start is the primary consumer, and any push that arrives after
    the drain window belongs to the next turn by definition. Retained
    as an operator/debug escape hatch for manually clearing a stuck
    session and as a fixture in the unit tests.
    """
    client = await get_redis_async()
    await client.delete(_buffer_key(session_id))
|
||||
|
||||
|
||||
def format_pending_as_user_message(message: "PendingMessage") -> dict[str, Any]:
    """Shape a ``PendingMessage`` into the OpenAI-format user message dict.

    Used by the baseline tool-call loop when injecting the buffered
    message into the conversation. Context/file metadata (if any) is
    embedded into the content so the model sees everything in one block.
    """
    body = message.content
    ctx = message.context
    if ctx is not None:
        if ctx.url:
            body += f"\n\n[Page URL: {ctx.url}]"
        if ctx.content:
            body += f"\n\n[Page content]\n{ctx.content}"
    if message.file_ids:
        listing = "\n".join(f"- file_id={fid}" for fid in message.file_ids)
        body += (
            "\n\n[Attached files]\n"
            + listing
            + "\nUse read_workspace_file with the file_id to access file contents."
        )
    return {"role": "user", "content": body}
|
||||
@@ -0,0 +1,246 @@
|
||||
"""Tests for the copilot pending-messages buffer.
|
||||
|
||||
Uses a fake async Redis client so the tests don't require a real Redis
|
||||
instance (the backend test suite's DB/Redis fixtures are heavyweight
|
||||
and pull in the full app startup).
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot import pending_messages as pm_module
|
||||
from backend.copilot.pending_messages import (
|
||||
MAX_PENDING_MESSAGES,
|
||||
PendingMessage,
|
||||
PendingMessageContext,
|
||||
clear_pending_messages,
|
||||
drain_pending_messages,
|
||||
format_pending_as_user_message,
|
||||
peek_pending_count,
|
||||
push_pending_message,
|
||||
)
|
||||
|
||||
# ── Fake Redis ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeRedis:
    """Minimal in-memory stand-in for the async Redis client."""

    def __init__(self) -> None:
        # Values are ``str | bytes`` because real redis-py returns
        # bytes when ``decode_responses=False``; the drain path must
        # handle both and our tests exercise both.
        self.lists: dict[str, list[str | bytes]] = {}
        self.published: list[tuple[str, str]] = []

    async def eval(self, script: str, num_keys: int, *args: Any) -> Any:
        """Emulate the push Lua script.

        The real Lua script runs atomically in Redis; the fake
        implementation just runs the equivalent list operations in
        order and returns the final LLEN. That's enough to exercise
        the cap + ordering invariants the tests care about.
        """
        key, payload, cap = args[0], args[1], int(args[2])
        # ARGV[3] is TTL — fake doesn't enforce expiry
        entries = self.lists.setdefault(key, [])
        entries.append(payload)
        if len(entries) > cap:
            # RPUSH + LTRIM(-N, -1) = keep only last N
            self.lists[key] = entries[-cap:]
        return len(self.lists[key])

    async def publish(self, channel: str, payload: str) -> int:
        # Record the notification so tests can assert on it.
        self.published.append((channel, payload))
        return 1

    async def lpop(self, key: str, count: int) -> list[str | bytes] | None:
        entries = self.lists.get(key)
        if not entries:
            return None
        taken = entries[:count]
        self.lists[key] = entries[count:]
        return taken

    async def llen(self, key: str) -> int:
        return len(self.lists.get(key, []))

    async def delete(self, key: str) -> int:
        if key not in self.lists:
            return 0
        del self.lists[key]
        return 1
|
||||
|
||||
|
||||
@pytest.fixture()
def fake_redis(monkeypatch: pytest.MonkeyPatch) -> _FakeRedis:
    """Route the module's Redis accessor to an in-memory fake.

    Patches ``pm_module.get_redis_async`` so the push/drain helpers
    under test never touch a real Redis instance.
    """
    redis = _FakeRedis()

    async def _get_redis_async() -> _FakeRedis:
        return redis

    monkeypatch.setattr(pm_module, "get_redis_async", _get_redis_async)
    return redis
|
||||
|
||||
|
||||
# ── Basic push / drain ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_push_and_drain_single_message(fake_redis: _FakeRedis) -> None:
    """A single pushed message round-trips through push → peek → drain."""
    length = await push_pending_message("sess1", PendingMessage(content="hello"))
    assert length == 1
    assert await peek_pending_count("sess1") == 1

    drained = await drain_pending_messages("sess1")
    assert len(drained) == 1
    assert drained[0].content == "hello"
    # Drain consumes the buffer — a second peek must see it empty.
    assert await peek_pending_count("sess1") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_push_and_drain_preserves_order(fake_redis: _FakeRedis) -> None:
    """Messages drain in enqueue order (oldest first)."""
    for i in range(3):
        await push_pending_message("sess2", PendingMessage(content=f"msg {i}"))

    drained = await drain_pending_messages("sess2")
    assert [m.content for m in drained] == ["msg 0", "msg 1", "msg 2"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_empty_returns_empty_list(fake_redis: _FakeRedis) -> None:
    """Draining a session with no buffer yields [] rather than raising."""
    assert await drain_pending_messages("nope") == []
|
||||
|
||||
|
||||
# ── Buffer cap ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cap_drops_oldest_when_exceeded(fake_redis: _FakeRedis) -> None:
    """The per-session cap keeps only the newest MAX_PENDING_MESSAGES."""
    # Push MAX_PENDING_MESSAGES + 3 messages
    for i in range(MAX_PENDING_MESSAGES + 3):
        await push_pending_message("sess3", PendingMessage(content=f"m{i}"))

    # Buffer should be clamped to MAX
    assert await peek_pending_count("sess3") == MAX_PENDING_MESSAGES

    drained = await drain_pending_messages("sess3")
    assert len(drained) == MAX_PENDING_MESSAGES
    # Oldest 3 dropped — we should only see m3..m(MAX+2)
    assert drained[0].content == "m3"
    assert drained[-1].content == f"m{MAX_PENDING_MESSAGES + 2}"
|
||||
|
||||
|
||||
# ── Clear ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_clear_removes_buffer(fake_redis: _FakeRedis) -> None:
    """clear_pending_messages drops every queued entry for the session."""
    await push_pending_message("sess4", PendingMessage(content="x"))
    await push_pending_message("sess4", PendingMessage(content="y"))
    await clear_pending_messages("sess4")
    assert await peek_pending_count("sess4") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_clear_is_idempotent(fake_redis: _FakeRedis) -> None:
    """Clearing twice is safe — backed by DEL, which is a no-op on a missing key."""
    # Clearing an already-empty buffer should not raise
    await clear_pending_messages("sess_empty")
    await clear_pending_messages("sess_empty")
|
||||
|
||||
|
||||
# ── Publish hook ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_push_publishes_notification(fake_redis: _FakeRedis) -> None:
    """Each push emits a wake-up hint on the session's notify channel."""
    await push_pending_message("sess5", PendingMessage(content="hi"))
    assert ("copilot:pending:notify:sess5", "1") in fake_redis.published
|
||||
|
||||
|
||||
# ── Format helper ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_format_pending_plain_text() -> None:
    """With no context or file ids, the content passes through unchanged."""
    msg = PendingMessage(content="just text")
    out = format_pending_as_user_message(msg)
    assert out == {"role": "user", "content": "just text"}
|
||||
|
||||
|
||||
def test_format_pending_with_context_url() -> None:
    """A context URL is embedded after the user text."""
    msg = PendingMessage(
        content="see this page",
        context=PendingMessageContext(url="https://example.com"),
    )
    out = format_pending_as_user_message(msg)
    content = out["content"]
    assert out["role"] == "user"
    assert "see this page" in content
    # The URL should appear verbatim in the [Page URL: ...] block.
    assert "[Page URL: https://example.com]" in content
|
||||
|
||||
|
||||
def test_format_pending_with_file_ids() -> None:
    """Attached file ids are listed in the formatted content."""
    msg = PendingMessage(content="look here", file_ids=["a", "b"])
    out = format_pending_as_user_message(msg)
    assert "file_id=a" in out["content"]
    assert "file_id=b" in out["content"]
|
||||
|
||||
|
||||
def test_format_pending_with_all_fields() -> None:
    """All fields (content + context url/content + file_ids) should all appear."""
    msg = PendingMessage(
        content="summarise this",
        context=PendingMessageContext(
            url="https://example.com/page",
            content="headline text",
        ),
        file_ids=["f1", "f2"],
    )
    out = format_pending_as_user_message(msg)
    body = out["content"]
    # Every section is concatenated into the single content string.
    assert out["role"] == "user"
    assert "summarise this" in body
    assert "[Page URL: https://example.com/page]" in body
    assert "[Page content]\nheadline text" in body
    assert "file_id=f1" in body
    assert "file_id=f2" in body
|
||||
|
||||
|
||||
# ── Malformed payload handling ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_skips_malformed_entries(
    fake_redis: _FakeRedis,
) -> None:
    """Malformed JSON entries are dropped; valid neighbours still drain."""
    # Seed the fake with a mix of valid and malformed payloads
    fake_redis.lists["copilot:pending:bad"] = [
        json.dumps({"content": "valid"}),
        "{not valid json",
        json.dumps({"content": "also valid", "file_ids": ["a"]}),
    ]
    drained = await drain_pending_messages("bad")
    assert len(drained) == 2
    assert drained[0].content == "valid"
    assert drained[1].content == "also valid"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_decodes_bytes_payloads(
    fake_redis: _FakeRedis,
) -> None:
    """Real redis-py returns ``bytes`` when ``decode_responses=False``.

    Seed the fake with bytes values to exercise the ``decode("utf-8")``
    branch in ``drain_pending_messages`` so a regression there doesn't
    slip past CI.
    """
    fake_redis.lists["copilot:pending:bytes_sess"] = [
        json.dumps({"content": "from bytes"}).encode("utf-8"),
    ]
    drained = await drain_pending_messages("bytes_sess")
    assert len(drained) == 1
    assert drained[0].content == "from bytes"
|
||||
@@ -226,6 +226,111 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
|
||||
assert was_compacted is False # mock returns False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_session_msg_ceiling_prevents_pending_duplication():
|
||||
"""session_msg_ceiling stops pending messages from leaking into the gap.
|
||||
|
||||
Scenario: transcript covers 2 messages, session has 2 historical + 1 current
|
||||
+ 2 pending drained at turn start. Without the ceiling the gap would include
|
||||
the pending messages AND current_message already has them → duplication.
|
||||
With session_msg_ceiling=3 (pre-drain count) the gap slice is empty and
|
||||
only current_message carries the pending content.
|
||||
"""
|
||||
# session.messages after drain: [hist1, hist2, current_msg, pending1, pending2]
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="hist1"),
|
||||
ChatMessage(role="assistant", content="hist2"),
|
||||
ChatMessage(role="user", content="current msg with pending1 pending2"),
|
||||
ChatMessage(role="user", content="pending1"),
|
||||
ChatMessage(role="user", content="pending2"),
|
||||
]
|
||||
)
|
||||
# transcript covers hist1+hist2 (2 messages); pre-drain count was 3 (includes current_msg)
|
||||
result, was_compacted = await _build_query_message(
|
||||
"current msg with pending1 pending2",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=2,
|
||||
session_id="test-session",
|
||||
session_msg_ceiling=3, # len(session.messages) before drain
|
||||
)
|
||||
# Gap should be empty (transcript_msg_count == ceiling - 1), so no history prepended
|
||||
assert result == "current msg with pending1 pending2"
|
||||
assert was_compacted is False
|
||||
# Pending messages must NOT appear in gap context
|
||||
assert "pending1" not in result.split("current msg")[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_session_msg_ceiling_preserves_real_gap():
|
||||
"""session_msg_ceiling still surfaces a genuine stale-transcript gap.
|
||||
|
||||
Scenario: transcript covers 2 messages, session has 4 historical + 1 current
|
||||
+ 2 pending. Ceiling = 5 (pre-drain). Real gap = messages 2-3 (hist3, hist4).
|
||||
"""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="hist1"),
|
||||
ChatMessage(role="assistant", content="hist2"),
|
||||
ChatMessage(role="user", content="hist3"),
|
||||
ChatMessage(role="assistant", content="hist4"),
|
||||
ChatMessage(role="user", content="current"),
|
||||
ChatMessage(role="user", content="pending1"),
|
||||
ChatMessage(role="user", content="pending2"),
|
||||
]
|
||||
)
|
||||
result, was_compacted = await _build_query_message(
|
||||
"current",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=2,
|
||||
session_id="test-session",
|
||||
session_msg_ceiling=5, # pre-drain: [hist1..hist4, current]
|
||||
)
|
||||
# Gap = session.messages[2:4] = [hist3, hist4]
|
||||
assert "<conversation_history>" in result
|
||||
assert "hist3" in result
|
||||
assert "hist4" in result
|
||||
assert "Now, the user says:\ncurrent" in result
|
||||
# Pending messages must NOT appear in gap
|
||||
assert "pending1" not in result
|
||||
assert "pending2" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_session_msg_ceiling_suppresses_spurious_no_resume_fallback():
|
||||
"""session_msg_ceiling prevents the no-resume compression fallback from
|
||||
firing on the first turn of a session when pending messages inflate msg_count.
|
||||
|
||||
Scenario: fresh session (1 message) + 1 pending message drained at turn start.
|
||||
Without the ceiling: msg_count=2 > 1 → fallback triggers → pending message
|
||||
leaked into history → wrong context sent to model.
|
||||
With session_msg_ceiling=1 (pre-drain count): effective_count=1, 1 > 1 is False
|
||||
→ fallback does not trigger → current_message returned as-is.
|
||||
"""
|
||||
# session.messages after drain: [current_msg, pending_msg]
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="What is 2 plus 2?"),
|
||||
ChatMessage(role="user", content="What is 7 plus 7?"), # pending
|
||||
]
|
||||
)
|
||||
result, was_compacted = await _build_query_message(
|
||||
"What is 2 plus 2?\n\nWhat is 7 plus 7?",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
session_msg_ceiling=1, # pre-drain: only 1 message existed
|
||||
)
|
||||
# Should return current_message directly without wrapping in history context
|
||||
assert result == "What is 2 plus 2?\n\nWhat is 7 plus 7?"
|
||||
assert was_compacted is False
|
||||
# Pending question must NOT appear in a spurious history section
|
||||
assert "<conversation_history>" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
|
||||
"""When compression actually compacts, was_compacted should be True."""
|
||||
|
||||
@@ -1031,6 +1031,12 @@ def _make_sdk_patches(
|
||||
),
|
||||
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
|
||||
(f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)),
|
||||
# Stub pending-message drain so retry tests don't hit Redis.
|
||||
# Returns an empty list → no mid-turn injection happens.
|
||||
(
|
||||
f"{_SVC}.drain_pending_messages",
|
||||
dict(new_callable=AsyncMock, return_value=[]),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -34,6 +34,10 @@ from opentelemetry import trace as otel_trace
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.copilot.pending_messages import (
|
||||
drain_pending_messages,
|
||||
format_pending_as_user_message,
|
||||
)
|
||||
from backend.copilot.permissions import apply_tool_permissions
|
||||
from backend.copilot.rate_limit import get_user_tier
|
||||
from backend.copilot.transcript import (
|
||||
@@ -955,17 +959,33 @@ async def _build_query_message(
|
||||
use_resume: bool,
|
||||
transcript_msg_count: int,
|
||||
session_id: str,
|
||||
*,
|
||||
session_msg_ceiling: int | None = None,
|
||||
) -> tuple[str, bool]:
|
||||
"""Build the query message with appropriate context.
|
||||
|
||||
Args:
|
||||
session_msg_ceiling: If provided, treat ``session.messages`` as if it
|
||||
only has this many entries when computing the gap slice. Pass
|
||||
``len(session.messages)`` captured *before* appending any pending
|
||||
messages so that mid-turn drains do not skew the gap calculation
|
||||
and cause pending messages to be duplicated in both the gap context
|
||||
and ``current_message``.
|
||||
|
||||
Returns:
|
||||
Tuple of (query_message, was_compacted).
|
||||
"""
|
||||
msg_count = len(session.messages)
|
||||
# Use the ceiling if supplied (prevents pending-message duplication when
|
||||
# messages were appended to session.messages after the drain but before
|
||||
# this function is called).
|
||||
effective_count = (
|
||||
session_msg_ceiling if session_msg_ceiling is not None else msg_count
|
||||
)
|
||||
|
||||
if use_resume and transcript_msg_count > 0:
|
||||
if transcript_msg_count < msg_count - 1:
|
||||
gap = session.messages[transcript_msg_count:-1]
|
||||
if transcript_msg_count < effective_count - 1:
|
||||
gap = session.messages[transcript_msg_count : effective_count - 1]
|
||||
compressed, was_compressed = await _compress_messages(gap)
|
||||
gap_context = _format_conversation_context(compressed)
|
||||
if gap_context:
|
||||
@@ -981,12 +1001,14 @@ async def _build_query_message(
|
||||
f"{gap_context}\n\nNow, the user says:\n{current_message}",
|
||||
was_compressed,
|
||||
)
|
||||
elif not use_resume and msg_count > 1:
|
||||
elif not use_resume and effective_count > 1:
|
||||
logger.warning(
|
||||
f"[SDK] Using compression fallback for session "
|
||||
f"{session_id} ({msg_count} messages) — no transcript for --resume"
|
||||
f"{session_id} ({effective_count} messages) — no transcript for --resume"
|
||||
)
|
||||
compressed, was_compressed = await _compress_messages(
|
||||
session.messages[: effective_count - 1]
|
||||
)
|
||||
compressed, was_compressed = await _compress_messages(session.messages[:-1])
|
||||
history_context = _format_conversation_context(compressed)
|
||||
if history_context:
|
||||
return (
|
||||
@@ -2042,6 +2064,7 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
async def _fetch_transcript():
|
||||
"""Download transcript for --resume if applicable."""
|
||||
assert session is not None # narrowed at line 1898
|
||||
if not (
|
||||
config.claude_agent_use_resume and user_id and len(session.messages) > 1
|
||||
):
|
||||
@@ -2288,6 +2311,69 @@ async def stream_chat_completion_sdk(
|
||||
if last_user:
|
||||
current_message = last_user[-1].content or ""
|
||||
|
||||
# Capture the message count *before* draining so _build_query_message
|
||||
# can compute the gap slice without including the newly-drained pending
|
||||
# messages. Pending messages are both appended to session.messages AND
|
||||
# concatenated into current_message; without the ceiling the gap slice
|
||||
# would extend into the pending messages and duplicate them in the
|
||||
# model's input context (gap_context + current_message both containing
|
||||
# them).
|
||||
_pre_drain_msg_count = len(session.messages)
|
||||
|
||||
# Drain any messages the user queued via POST /messages/pending
|
||||
# while the previous turn was running (or since the session was
|
||||
# idle). Messages are drained ATOMICALLY — one LPOP with count
|
||||
# removes them all at once, so a concurrent push lands *after*
|
||||
# the drain and stays queued for the next turn instead of being
|
||||
# lost between LPOP and clear. File IDs and context are
|
||||
# preserved via format_pending_as_user_message.
|
||||
#
|
||||
# The drained content is concatenated into ``current_message``
|
||||
# so the SDK CLI sees it in the new user message, AND appended
|
||||
# directly to ``session.messages`` (no dedup — pending messages are
|
||||
# atomically-popped from Redis and are never stale-cache duplicates)
|
||||
# so the durable transcript records it too. Session is persisted
|
||||
# immediately after the drain so a crash doesn't lose the messages.
|
||||
# The endpoint deliberately does NOT persist to session.messages —
|
||||
# Redis is the single source of truth until this drain runs.
|
||||
try:
|
||||
pending_at_start = await drain_pending_messages(session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"%s drain_pending_messages failed at turn start, skipping",
|
||||
log_prefix,
|
||||
exc_info=True,
|
||||
)
|
||||
pending_at_start = []
|
||||
if pending_at_start:
|
||||
logger.info(
|
||||
"%s Draining %d pending message(s) at turn start",
|
||||
log_prefix,
|
||||
len(pending_at_start),
|
||||
)
|
||||
pending_texts: list[str] = [
|
||||
format_pending_as_user_message(pm)["content"] for pm in pending_at_start
|
||||
]
|
||||
for pt in pending_texts:
|
||||
# Append directly — pending messages are atomically-popped from
|
||||
# Redis and are never stale-cache duplicates, so the
|
||||
# maybe_append_user_message dedup is wrong here.
|
||||
session.messages.append(ChatMessage(role="user", content=pt))
|
||||
if current_message.strip():
|
||||
current_message = current_message + "\n\n" + "\n\n".join(pending_texts)
|
||||
else:
|
||||
current_message = "\n\n".join(pending_texts)
|
||||
# Persist immediately so a crash between here and the finally block
|
||||
# doesn't lose messages that were already drained from Redis.
|
||||
try:
|
||||
session = await upsert_chat_session(session)
|
||||
except Exception as _persist_err:
|
||||
logger.warning(
|
||||
"%s Failed to persist drained pending messages: %s",
|
||||
log_prefix,
|
||||
_persist_err,
|
||||
)
|
||||
|
||||
if not current_message.strip():
|
||||
yield StreamError(
|
||||
errorText="Message cannot be empty.",
|
||||
@@ -2301,6 +2387,7 @@ async def stream_chat_completion_sdk(
|
||||
use_resume,
|
||||
transcript_msg_count,
|
||||
session_id,
|
||||
session_msg_ceiling=_pre_drain_msg_count,
|
||||
)
|
||||
# On the first turn inject user context into the message instead of the
|
||||
# system prompt — the system prompt is now static (same for all users)
|
||||
@@ -2438,6 +2525,7 @@ async def stream_chat_completion_sdk(
|
||||
state.use_resume,
|
||||
state.transcript_msg_count,
|
||||
session_id,
|
||||
session_msg_ceiling=_pre_drain_msg_count,
|
||||
)
|
||||
if attachments.hint:
|
||||
state.query_message = f"{state.query_message}\n\n{attachments.hint}"
|
||||
@@ -2767,6 +2855,11 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
raise
|
||||
finally:
|
||||
# Pending messages are drained atomically at the start of each
|
||||
# turn (see drain_pending_messages call above), so there's
|
||||
# nothing to clean up here — any message pushed after that
|
||||
# point belongs to the next turn.
|
||||
|
||||
# --- Close OTEL context (with cost attributes) ---
|
||||
if _otel_ctx is not None:
|
||||
try:
|
||||
|
||||
@@ -1605,6 +1605,60 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}/messages/pending": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Queue Pending Message",
|
||||
"description": "Queue a new user message into an in-flight copilot turn.\n\nWhen a user sends a follow-up message while a turn is still\nstreaming, we don't want to block them or start a separate turn —\nthis endpoint appends the message to a per-session pending buffer.\nThe executor currently running the turn (baseline path) drains the\nbuffer between tool-call rounds and appends the message to the\nconversation before the next LLM call. On the SDK path the buffer\nis drained at the *start* of the next turn (the long-lived\n``ClaudeSDKClient.receive_response`` iterator returns after a\n``ResultMessage`` so there is no safe point to inject mid-stream\ninto an existing connection).\n\nReturns 202. Enforces the same per-user daily/weekly token rate\nlimit as the regular ``/stream`` endpoint so a client can't bypass\nit by batching messages through here.",
|
||||
"operationId": "postV2QueuePendingMessage",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Session Id" }
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/QueuePendingMessageRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"202": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/QueuePendingMessageResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"404": { "description": "Session not found or access denied" },
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"429": {
|
||||
"description": "Token rate-limit or call-frequency cap exceeded"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}/stream": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -12124,6 +12178,22 @@
|
||||
"title": "PendingHumanReviewModel",
|
||||
"description": "Response model for pending human review data.\n\nRepresents a human review request that is awaiting user action.\nContains all necessary information for a user to review and approve\nor reject data from a Human-in-the-Loop block execution.\n\nAttributes:\n id: Unique identifier for the review record\n user_id: ID of the user who must perform the review\n node_exec_id: ID of the node execution that created this review\n node_id: ID of the node definition (for grouping reviews from same node)\n graph_exec_id: ID of the graph execution containing the node\n graph_id: ID of the graph template being executed\n graph_version: Version number of the graph template\n payload: The actual data payload awaiting review\n instructions: Instructions or message for the reviewer\n editable: Whether the reviewer can edit the data\n status: Current review status (WAITING, APPROVED, or REJECTED)\n review_message: Optional message from the reviewer\n created_at: Timestamp when review was created\n updated_at: Timestamp when review was last modified\n reviewed_at: Timestamp when review was completed (if applicable)"
|
||||
},
|
||||
"PendingMessageContext": {
|
||||
"properties": {
|
||||
"url": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Url"
|
||||
},
|
||||
"content": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Content"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"type": "object",
|
||||
"title": "PendingMessageContext",
|
||||
"description": "Structured page context attached to a pending message."
|
||||
},
|
||||
"PlatformCostDashboard": {
|
||||
"properties": {
|
||||
"by_provider": {
|
||||
@@ -12668,6 +12738,53 @@
|
||||
"required": ["providers", "pagination"],
|
||||
"title": "ProviderResponse"
|
||||
},
|
||||
"QueuePendingMessageRequest": {
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"maxLength": 16000,
|
||||
"minLength": 1,
|
||||
"title": "Message"
|
||||
},
|
||||
"context": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/PendingMessageContext" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"description": "Optional page context with 'url' and 'content' fields."
|
||||
},
|
||||
"file_ids": {
|
||||
"anyOf": [
|
||||
{
|
||||
"items": { "type": "string" },
|
||||
"type": "array",
|
||||
"maxItems": 20
|
||||
},
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "File Ids"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"type": "object",
|
||||
"required": ["message"],
|
||||
"title": "QueuePendingMessageRequest",
|
||||
"description": "Request model for queueing a message into an in-flight turn.\n\nUnlike ``StreamChatRequest`` this endpoint does **not** start a new\nturn — the message is appended to a per-session pending buffer that\nthe executor currently processing the turn will drain between tool\nrounds."
|
||||
},
|
||||
"QueuePendingMessageResponse": {
|
||||
"properties": {
|
||||
"buffer_length": { "type": "integer", "title": "Buffer Length" },
|
||||
"max_buffer_length": {
|
||||
"type": "integer",
|
||||
"title": "Max Buffer Length"
|
||||
},
|
||||
"turn_in_flight": { "type": "boolean", "title": "Turn In Flight" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["buffer_length", "max_buffer_length", "turn_in_flight"],
|
||||
"title": "QueuePendingMessageResponse",
|
||||
"description": "Response for the pending-message endpoint.\n\n- ``buffer_length``: how many messages are now in the session's\n pending buffer (after this push)\n- ``max_buffer_length``: the per-session cap (server-side constant)\n- ``turn_in_flight``: ``True`` if a copilot turn was running when\n we checked — purely informational for UX feedback. Even when\n ``False`` the message is still queued: the next turn drains it."
|
||||
},
|
||||
"RateLimitResetResponse": {
|
||||
"properties": {
|
||||
"success": { "type": "boolean", "title": "Success" },
|
||||
|
||||
Reference in New Issue
Block a user